179 lines
4.3 KiB
Go
179 lines
4.3 KiB
Go
package listeners
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"github.com/mochi-co/mqtt/server/listeners/auth"
|
|
"github.com/mochi-co/mqtt/server/system"
|
|
)
|
|
|
|
var (
|
|
// ErrInvalidMessage indicates that a message payload was not valid.
|
|
ErrInvalidMessage = errors.New("message type not binary")
|
|
|
|
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
|
|
// websocket compliant connection.
|
|
wsUpgrader = &websocket.Upgrader{
|
|
Subprotocols: []string{"mqtt"},
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
)
|
|
|
|
// Websocket is a listener for establishing websocket connections.
|
|
type Websocket struct {
|
|
sync.RWMutex
|
|
id string // the internal id of the listener.
|
|
address string // the network address to bind to.
|
|
config *Config // configuration values for the listener.
|
|
listen *http.Server // an http server for serving websocket connections.
|
|
establish EstablishFunc // the server's establish connection handler.
|
|
end uint32 // ensure the close methods are only called once.
|
|
}
|
|
|
|
// wsConn is a websocket connection which satisfies the net.Conn interface.
|
|
// Inspired by
|
|
type wsConn struct {
|
|
net.Conn
|
|
c *websocket.Conn
|
|
}
|
|
|
|
// Read reads the next span of bytes from the websocket connection and returns
|
|
// the number of bytes read.
|
|
func (ws *wsConn) Read(p []byte) (n int, err error) {
|
|
op, r, err := ws.c.NextReader()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if op != websocket.BinaryMessage {
|
|
err = ErrInvalidMessage
|
|
return
|
|
}
|
|
|
|
return r.Read(p)
|
|
}
|
|
|
|
// Write writes bytes to the websocket connection.
|
|
func (ws *wsConn) Write(p []byte) (n int, err error) {
|
|
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
// Close signals the underlying websocket conn to close.
|
|
func (ws *wsConn) Close() error {
|
|
return ws.Conn.Close()
|
|
}
|
|
|
|
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
|
|
func NewWebsocket(id, address string) *Websocket {
|
|
return &Websocket{
|
|
id: id,
|
|
address: address,
|
|
config: &Config{
|
|
Auth: new(auth.Allow),
|
|
TLS: new(TLS),
|
|
},
|
|
}
|
|
}
|
|
|
|
// SetConfig sets the configuration values for the listener config.
|
|
func (l *Websocket) SetConfig(config *Config) {
|
|
l.Lock()
|
|
if config != nil {
|
|
l.config = config
|
|
|
|
// If a config has been passed without an auth controller,
|
|
// it may be a mistake, so disallow all traffic.
|
|
if l.config.Auth == nil {
|
|
l.config.Auth = new(auth.Disallow)
|
|
}
|
|
}
|
|
|
|
l.Unlock()
|
|
}
|
|
|
|
// ID returns the id of the listener.
|
|
func (l *Websocket) ID() string {
|
|
l.RLock()
|
|
id := l.id
|
|
l.RUnlock()
|
|
return id
|
|
}
|
|
|
|
// Listen starts listening on the listener's network address.
|
|
func (l *Websocket) Listen(s *system.Info) error {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/", l.handler)
|
|
l.listen = &http.Server{
|
|
Addr: l.address,
|
|
Handler: mux,
|
|
}
|
|
|
|
// The following logic is deprecated in favour of passing through the tls.Config
|
|
// value directly, however it remains in order to provide backwards compatibility.
|
|
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
|
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
|
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
l.listen.TLSConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
} else {
|
|
l.listen.TLSConfig = l.config.TLSConfig
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
|
c, err := wsUpgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer c.Close()
|
|
|
|
l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth)
|
|
}
|
|
|
|
// Serve starts waiting for new Websocket connections, and calls the connection
|
|
// establishment callback for any received.
|
|
func (l *Websocket) Serve(establish EstablishFunc) {
|
|
l.establish = establish
|
|
|
|
if l.listen.TLSConfig != nil {
|
|
l.listen.ListenAndServeTLS("", "")
|
|
} else {
|
|
l.listen.ListenAndServe()
|
|
}
|
|
}
|
|
|
|
// Close closes the listener and any client connections.
|
|
func (l *Websocket) Close(closeClients CloseFunc) {
|
|
l.Lock()
|
|
defer l.Unlock()
|
|
|
|
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
l.listen.Shutdown(ctx)
|
|
}
|
|
|
|
closeClients(l.id)
|
|
}
|