mqtt/server/listeners/websocket.go

179 lines
4.3 KiB
Go
Raw Normal View History

2022-08-15 23:06:20 +03:00
package listeners
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
var (
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
// Websocket is a listener for establishing websocket connections.
type Websocket struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
listen *http.Server // an http server for serving websocket connections.
establish EstablishFunc // the server's establish connection handler.
end uint32 // ensure the close methods are only called once.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
// Inspired by
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns
// the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
}
return r.Read(p)
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
}
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string) *Websocket {
return &Websocket{
id: id,
address: address,
config: &Config{
Auth: new(auth.Allow),
TLS: new(TLS),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *Websocket) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *Websocket) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *Websocket) Listen(s *system.Info) error {
mux := http.NewServeMux()
mux.HandleFunc("/", l.handler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth)
}
// Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFunc) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}