clone
This commit is contained in:
12
server/listeners/auth/auth.go
Normal file
12
server/listeners/auth/auth.go
Normal file
@ -0,0 +1,12 @@
|
||||
package auth
|
||||
|
||||
// Controller is an interface for authentication controllers.
|
||||
type Controller interface {
|
||||
|
||||
// Authenticate authenticates a user on CONNECT and returns true if a user is
|
||||
// allowed to join the server.
|
||||
Authenticate(user, password []byte) bool
|
||||
|
||||
// ACL returns true if a user has read or write access to a given topic.
|
||||
ACL(user []byte, topic string, write bool) bool
|
||||
}
|
31
server/listeners/auth/defaults.go
Normal file
31
server/listeners/auth/defaults.go
Normal file
@ -0,0 +1,31 @@
|
||||
package auth
|
||||
|
||||
// Allow is an auth controller which allows access to all connections and topics.
|
||||
type Allow struct{}
|
||||
|
||||
// Authenticate returns true if a username and password are acceptable. Allow always
|
||||
// returns true.
|
||||
func (a *Allow) Authenticate(user, password []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ACL returns true if a user has access permissions to read or write on a topic.
|
||||
// Allow always returns true.
|
||||
func (a *Allow) ACL(user []byte, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Disallow is an auth controller which disallows access to all connections and topics.
|
||||
type Disallow struct{}
|
||||
|
||||
// Authenticate returns true if a username and password are acceptable. Disallow always
|
||||
// returns false.
|
||||
func (d *Disallow) Authenticate(user, password []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ACL returns true if a user has access permissions to read or write on a topic.
|
||||
// Disallow always returns false.
|
||||
func (d *Disallow) ACL(user []byte, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
55
server/listeners/auth/defaults_test.go
Normal file
55
server/listeners/auth/defaults_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowAuth(t *testing.T) {
|
||||
ac := new(Allow)
|
||||
require.Equal(t, true, ac.Authenticate([]byte("user"), []byte("pass")))
|
||||
}
|
||||
|
||||
func BenchmarkAllowAuth(b *testing.B) {
|
||||
ac := new(Allow)
|
||||
for n := 0; n < b.N; n++ {
|
||||
ac.Authenticate([]byte("user"), []byte("pass"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowACL(t *testing.T) {
|
||||
ac := new(Allow)
|
||||
require.Equal(t, true, ac.ACL([]byte("user"), "topic", true))
|
||||
}
|
||||
|
||||
func BenchmarkAllowACL(b *testing.B) {
|
||||
ac := new(Allow)
|
||||
for n := 0; n < b.N; n++ {
|
||||
ac.ACL([]byte("user"), "pass", true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisallowAuth(t *testing.T) {
|
||||
ac := new(Disallow)
|
||||
require.Equal(t, false, ac.Authenticate([]byte("user"), []byte("pass")))
|
||||
}
|
||||
|
||||
func BenchmarkDisallowAuth(b *testing.B) {
|
||||
ac := new(Disallow)
|
||||
for n := 0; n < b.N; n++ {
|
||||
ac.Authenticate([]byte("user"), []byte("pass"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisallowACL(t *testing.T) {
|
||||
ac := new(Disallow)
|
||||
require.Equal(t, false, ac.ACL([]byte("user"), "topic", true))
|
||||
}
|
||||
|
||||
func BenchmarkDisallowACL(b *testing.B) {
|
||||
ac := new(Disallow)
|
||||
for n := 0; n < b.N; n++ {
|
||||
ac.ACL([]byte("user"), "pass", true)
|
||||
}
|
||||
}
|
124
server/listeners/http_sysinfo.go
Normal file
124
server/listeners/http_sysinfo.go
Normal file
@ -0,0 +1,124 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
)
|
||||
|
||||
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
config *Config // configuration values for the listener.
|
||||
system *system.Info // pointers to the server data.
|
||||
listen *http.Server // the http server.
|
||||
end uint32 // ensure the close methods are only called once.}
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
func NewHTTPStats(id, address string) *HTTPStats {
|
||||
return &HTTPStats{
|
||||
id: id,
|
||||
address: address,
|
||||
config: &Config{
|
||||
Auth: new(auth.Allow),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the configuration values for the listener config.
|
||||
func (l *HTTPStats) SetConfig(config *Config) {
|
||||
l.Lock()
|
||||
if config != nil {
|
||||
l.config = config
|
||||
|
||||
// If a config has been passed without an auth controller,
|
||||
// it may be a mistake, so disallow all traffic.
|
||||
if l.config.Auth == nil {
|
||||
l.config.Auth = new(auth.Disallow)
|
||||
}
|
||||
}
|
||||
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *HTTPStats) ID() string {
|
||||
l.RLock()
|
||||
id := l.id
|
||||
l.RUnlock()
|
||||
return id
|
||||
}
|
||||
|
||||
// Listen starts listening on the listener's network address.
|
||||
func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
l.system = s
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.jsonHandler)
|
||||
l.listen = &http.Server{
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPStats) Serve(establish EstablishFunc) {
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *HTTPStats) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
|
||||
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
info, err := json.MarshalIndent(l.system, "", "\t")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Write(info)
|
||||
}
|
198
server/listeners/http_sysinfo_test.go
Normal file
198
server/listeners/http_sysinfo_test.go
Normal file
@ -0,0 +1,198 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPStats(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testPort, l.address)
|
||||
}
|
||||
|
||||
func BenchmarkNewHTTPStats(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewHTTPStats("t1", testPort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPStatsSetConfig(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
})
|
||||
require.NotNil(t, l.config)
|
||||
require.NotNil(t, l.config.Auth)
|
||||
require.Equal(t, new(auth.Allow), l.config.Auth)
|
||||
|
||||
// Switch to disallow on bad config set.
|
||||
l.SetConfig(new(Config))
|
||||
require.NotNil(t, l.config)
|
||||
require.NotNil(t, l.config.Auth)
|
||||
require.Equal(t, new(auth.Disallow), l.config.Auth)
|
||||
}
|
||||
|
||||
func BenchmarkHTTPStatsSetConfig(b *testing.B) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.SetConfig(new(Config))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPStatsID(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func BenchmarkHTTPStatsID(b *testing.B) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.ID()
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPStatsListen(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
err := l.Listen(new(system.Info))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.system)
|
||||
require.NotNil(t, l.listen)
|
||||
require.Equal(t, testPort, l.listen.Addr)
|
||||
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLSConfig(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(new(system.Info))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLS(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(new(system.Info))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLSInvalid(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: []byte("abcde"),
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(new(system.Info))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeAndClose(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
err := l.Listen(&system.Info{
|
||||
Version: "test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
resp, err := http.Get("http://localhost" + testPort)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
v := new(system.Info)
|
||||
err = json.Unmarshal(body, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", v.Version)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Get("http://localhost" + testPort)
|
||||
require.Error(t, err)
|
||||
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(&system.Info{
|
||||
Version: "test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestHTTPStatsJSONHandler(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
err := l.Listen(&system.Info{
|
||||
Version: "test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
l.jsonHandler(w, nil)
|
||||
resp := w.Result()
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
|
||||
v := new(system.Info)
|
||||
err = json.Unmarshal(body, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", v.Version)
|
||||
}
|
152
server/listeners/listeners.go
Normal file
152
server/listeners/listeners.go
Normal file
@ -0,0 +1,152 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
)
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
// Auth controller containing auth and ACL logic for
|
||||
// allowing or denying access to the server and topics.
|
||||
Auth auth.Controller
|
||||
|
||||
// TLS certficates and settings for the connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
TLS *TLS
|
||||
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener.
|
||||
// See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// TLS contains the TLS certificates and settings for the listener connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
type TLS struct {
|
||||
Certificate []byte // the body of a public certificate.
|
||||
PrivateKey []byte // the body of a private key.
|
||||
}
|
||||
|
||||
// EstablishFunc is a callback function for establishing new clients.
|
||||
type EstablishFunc func(id string, c net.Conn, ac auth.Controller) error
|
||||
|
||||
// CloseFunc is a callback function for closing all listener clients.
|
||||
type CloseFunc func(id string)
|
||||
|
||||
// Listener is an interface for network listeners. A network listener listens
|
||||
// for incoming client connections and adds them to the server.
|
||||
type Listener interface {
|
||||
SetConfig(*Config) // set the listener config.
|
||||
Listen(s *system.Info) error // open the network address.
|
||||
Serve(EstablishFunc) // starting actively listening for new connections.
|
||||
ID() string // return the id of the listener.
|
||||
Close(CloseFunc) // stop and close the listener.
|
||||
}
|
||||
|
||||
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
system *system.Info // pointers to system info.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new instance of Listeners.
|
||||
func New(s *system.Info) *Listeners {
|
||||
return &Listeners{
|
||||
internal: map[string]Listener{},
|
||||
system: s,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new listener to the listeners map, keyed on id.
|
||||
func (l *Listeners) Add(val Listener) {
|
||||
l.Lock()
|
||||
l.internal[val.ID()] = val
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// Get returns the value of a listener if it exists.
|
||||
func (l *Listeners) Get(id string) (Listener, bool) {
|
||||
l.RLock()
|
||||
val, ok := l.internal[id]
|
||||
l.RUnlock()
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the listeners map.
|
||||
func (l *Listeners) Len() int {
|
||||
l.RLock()
|
||||
val := len(l.internal)
|
||||
l.RUnlock()
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a listener from the internal map.
|
||||
func (l *Listeners) Delete(id string) {
|
||||
l.Lock()
|
||||
delete(l.internal, id)
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// Serve starts a listener serving from the internal map.
|
||||
func (l *Listeners) Serve(id string, establisher EstablishFunc) {
|
||||
l.RLock()
|
||||
listener := l.internal[id]
|
||||
l.RUnlock()
|
||||
|
||||
go func(e EstablishFunc) {
|
||||
defer l.wg.Done()
|
||||
l.wg.Add(1)
|
||||
listener.Serve(e)
|
||||
}(establisher)
|
||||
}
|
||||
|
||||
// ServeAll starts all listeners serving from the internal map.
|
||||
func (l *Listeners) ServeAll(establisher EstablishFunc) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Serve(id, establisher)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops a listener from the internal map.
|
||||
func (l *Listeners) Close(id string, closer CloseFunc) {
|
||||
l.RLock()
|
||||
listener := l.internal[id]
|
||||
l.RUnlock()
|
||||
listener.Close(closer)
|
||||
}
|
||||
|
||||
// CloseAll iterates and closes all registered listeners.
|
||||
func (l *Listeners) CloseAll(closer CloseFunc) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Close(id, closer)
|
||||
}
|
||||
l.wg.Wait()
|
||||
}
|
245
server/listeners/listeners_test.go
Normal file
245
server/listeners/listeners_test.go
Normal file
@ -0,0 +1,245 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New(nil)
|
||||
require.NotNil(t, l.internal)
|
||||
}
|
||||
|
||||
func BenchmarkNewListeners(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
New(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddListener(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
}
|
||||
|
||||
func BenchmarkAddListener(b *testing.B) {
|
||||
l := New(nil)
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.Add(mocked)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetListener(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
l.Add(NewMockListener("t2", ":1882"))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
|
||||
g, ok := l.Get("t1")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, g.ID(), "t1")
|
||||
}
|
||||
|
||||
func BenchmarkGetListener(b *testing.B) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.Get("t1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLenListener(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
l.Add(NewMockListener("t2", ":1882"))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
require.Equal(t, 2, l.Len())
|
||||
}
|
||||
|
||||
func BenchmarkLenListener(b *testing.B) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.Len()
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteListener(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
|
||||
l.Delete("t1")
|
||||
_, ok := l.Get("t1")
|
||||
require.Equal(t, false, ok)
|
||||
require.Nil(t, l.internal["t1"])
|
||||
}
|
||||
|
||||
func BenchmarkDeleteListener(b *testing.B) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.Delete("t1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeListener(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func BenchmarkServeListener(b *testing.B) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.Serve("t1", MockEstablisher)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeAllListeners(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
l.Add(NewMockListener("t2", ":1882"))
|
||||
l.Add(NewMockListener("t3", ":1882"))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
l.Close("t2", MockCloser)
|
||||
l.Close("t3", MockCloser)
|
||||
|
||||
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.Equal(t, false, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func BenchmarkServeAllListeners(b *testing.B) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
l.Add(NewMockListener("t2", ":1883"))
|
||||
l.Add(NewMockListener("t3", ":1884"))
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.ServeAll(MockEstablisher)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseListener(t *testing.T) {
|
||||
l := New(nil)
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
l.Add(mocked)
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close("t1", func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func BenchmarkCloseListener(b *testing.B) {
|
||||
l := New(nil)
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
l.Add(mocked)
|
||||
l.Serve("t1", MockEstablisher)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.internal["t1"].(*MockListener).done = make(chan bool)
|
||||
l.Close("t1", MockCloser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseAllListeners(t *testing.T) {
|
||||
l := New(nil)
|
||||
l.Add(NewMockListener("t1", ":1882"))
|
||||
l.Add(NewMockListener("t2", ":1882"))
|
||||
l.Add(NewMockListener("t3", ":1882"))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
closed := make(map[string]bool)
|
||||
l.CloseAll(func(id string) {
|
||||
closed[id] = true
|
||||
})
|
||||
require.Contains(t, closed, "t1")
|
||||
require.Contains(t, closed, "t2")
|
||||
require.Contains(t, closed, "t3")
|
||||
require.Equal(t, true, closed["t1"])
|
||||
require.Equal(t, true, closed["t2"])
|
||||
require.Equal(t, true, closed["t3"])
|
||||
}
|
||||
|
||||
func BenchmarkCloseAllListeners(b *testing.B) {
|
||||
l := New(nil)
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
l.Add(mocked)
|
||||
l.Serve("t1", MockEstablisher)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.internal["t1"].(*MockListener).done = make(chan bool)
|
||||
l.Close("t1", MockCloser)
|
||||
}
|
||||
}
|
100
server/listeners/mock.go
Normal file
100
server/listeners/mock.go
Normal file
@ -0,0 +1,100 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
)
|
||||
|
||||
// MockCloser is a function signature which can be used in testing.
|
||||
func MockCloser(id string) {}
|
||||
|
||||
// MockEstablisher is a function signature which can be used in testing.
|
||||
func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockListener is a mock listener for establishing client connections.
|
||||
type MockListener struct {
|
||||
sync.RWMutex
|
||||
id string // the id of the listener.
|
||||
address string // the network address the listener binds to.
|
||||
Config *Config // configuration for the listener.
|
||||
done chan bool // indicate the listener is done.
|
||||
Serving bool // indicate the listener is serving.
|
||||
Listening bool // indiciate the listener is listening.
|
||||
ErrListen bool // throw an error on listen.
|
||||
}
|
||||
|
||||
// NewMockListener returns a new instance of MockListener
|
||||
func NewMockListener(id, address string) *MockListener {
|
||||
return &MockListener{
|
||||
id: id,
|
||||
address: address,
|
||||
done: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve serves the mock listener.
|
||||
func (l *MockListener) Serve(establisher EstablishFunc) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Listen begins listening for incoming traffic.
|
||||
func (l *MockListener) Listen(s *system.Info) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
}
|
||||
|
||||
l.Lock()
|
||||
l.Listening = true
|
||||
l.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConfig sets the configuration values of the mock listener.
|
||||
func (l *MockListener) SetConfig(config *Config) {
|
||||
l.Lock()
|
||||
l.Config = config
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// ID returns the id of the mock listener.
|
||||
func (l *MockListener) ID() string {
|
||||
l.RLock()
|
||||
id := l.id
|
||||
l.RUnlock()
|
||||
return id
|
||||
}
|
||||
|
||||
// Close closes the mock listener.
|
||||
func (l *MockListener) Close(closer CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Serving = false
|
||||
closer(l.id)
|
||||
close(l.done)
|
||||
}
|
||||
|
||||
// IsServing indicates whether the mock listener is serving.
|
||||
func (l *MockListener) IsServing() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Serving
|
||||
}
|
||||
|
||||
// IsListening indicates whether the mock listener is listening.
|
||||
func (l *MockListener) IsListening() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Listening
|
||||
}
|
89
server/listeners/mock_test.go
Normal file
89
server/listeners/mock_test.go
Normal file
@ -0,0 +1,89 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
)
|
||||
|
||||
func TestMockEstablisher(t *testing.T) {
|
||||
_, w := net.Pipe()
|
||||
err := MockEstablisher("t1", w, new(auth.Allow))
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
}
|
||||
|
||||
func TestNewMockListener(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, ":1882", mocked.address)
|
||||
}
|
||||
|
||||
func TestNewMockListenerListen(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, ":1882", mocked.address)
|
||||
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
err := mocked.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, mocked.IsListening())
|
||||
}
|
||||
func TestNewMockListenerListenFailure(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
mocked.ErrListen = true
|
||||
err := mocked.Listen(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMockListenerServe(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
mocked.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
|
||||
require.Equal(t, true, mocked.IsServing())
|
||||
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
|
||||
mocked.Listen(nil)
|
||||
}
|
||||
|
||||
func TestMockListenerSetConfig(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1883")
|
||||
mocked.SetConfig(new(Config))
|
||||
require.NotNil(t, mocked.Config)
|
||||
}
|
||||
|
||||
func TestMockListenerClose(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsListening(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsServing(t *testing.T) {
|
||||
mocked := NewMockListener("t1", ":1882")
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
}
|
127
server/listeners/tcp.go
Normal file
127
server/listeners/tcp.go
Normal file
@ -0,0 +1,127 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
)
|
||||
|
||||
// TCP is a listener for establishing client connections on basic TCP protocol.
|
||||
type TCP struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
protocol string // the TCP protocol to use.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
config *Config // configuration values for the listener.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
func NewTCP(id, address string) *TCP {
|
||||
return &TCP{
|
||||
id: id,
|
||||
protocol: "tcp",
|
||||
address: address,
|
||||
config: &Config{ // default configuration.
|
||||
Auth: new(auth.Allow),
|
||||
TLS: new(TLS),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the configuration values for the listener config.
|
||||
func (l *TCP) SetConfig(config *Config) {
|
||||
l.Lock()
|
||||
if config != nil {
|
||||
l.config = config
|
||||
|
||||
// If a config has been passed without an auth controller,
|
||||
// it may be a mistake, so disallow all traffic.
|
||||
if l.config.Auth == nil {
|
||||
l.config.Auth = new(auth.Disallow)
|
||||
}
|
||||
}
|
||||
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *TCP) ID() string {
|
||||
l.RLock()
|
||||
id := l.id
|
||||
l.RUnlock()
|
||||
return id
|
||||
}
|
||||
|
||||
// Listen starts listening on the listener's network address.
|
||||
func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
} else if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen(l.protocol, l.address)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFunc) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
_ = establish(l.id, conn, l.config.Auth)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *TCP) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
228
server/listeners/tcp_test.go
Normal file
228
server/listeners/tcp_test.go
Normal file
@ -0,0 +1,228 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testPort = ":22222"
|
||||
)
|
||||
|
||||
func TestNewTCP(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testPort, l.address)
|
||||
}
|
||||
|
||||
func BenchmarkNewTCP(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewTCP("t1", testPort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPSetConfig(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
})
|
||||
require.NotNil(t, l.config)
|
||||
require.NotNil(t, l.config.Auth)
|
||||
require.Equal(t, new(auth.Allow), l.config.Auth)
|
||||
|
||||
// Switch to disallow on bad config set.
|
||||
l.SetConfig(new(Config))
|
||||
require.NotNil(t, l.config)
|
||||
require.NotNil(t, l.config.Auth)
|
||||
require.Equal(t, new(auth.Disallow), l.config.Auth)
|
||||
}
|
||||
|
||||
func BenchmarkTCPSetConfig(b *testing.B) {
|
||||
l := NewTCP("t1", testPort)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.SetConfig(new(Config))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPID(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func BenchmarkTCPID(b *testing.B) {
|
||||
l := NewTCP("t1", testPort)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.ID()
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPListen(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewTCP("t2", testPort)
|
||||
err = l2.Listen(nil)
|
||||
require.Error(t, err)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLSConfig(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLS(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLSInvalid(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: []byte("abcde"),
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTCPServeAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPCloseError(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.listen.Close()
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPServeEnd(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
l.Close(MockCloser)
|
||||
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPEstablishThenError(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
|
||||
established <- true
|
||||
return errors.New("testing") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial(l.protocol, l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPEstablishButEnding(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
l.end = 1
|
||||
|
||||
o := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
|
||||
return nil
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
net.Dial(l.protocol, l.listen.Addr().String())
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
|
||||
}
|
178
server/listeners/websocket.go
Normal file
178
server/listeners/websocket.go
Normal file
@ -0,0 +1,178 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidMessage indicates that a message payload was not valid.
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
|
||||
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
|
||||
// websocket compliant connection.
|
||||
wsUpgrader = &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
// Websocket is a listener for establishing websocket connections.
|
||||
type Websocket struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
config *Config // configuration values for the listener.
|
||||
listen *http.Server // an http server for serving websocket connections.
|
||||
establish EstablishFunc // the server's establish connection handler.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// wsConn is a websocket connection which satisfies the net.Conn interface.
|
||||
// Inspired by
|
||||
type wsConn struct {
|
||||
net.Conn
|
||||
c *websocket.Conn
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns
|
||||
// the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (n int, err error) {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return
|
||||
}
|
||||
|
||||
return r.Read(p)
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (n int, err error) {
|
||||
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close signals the underlying websocket conn to close.
|
||||
func (ws *wsConn) Close() error {
|
||||
return ws.Conn.Close()
|
||||
}
|
||||
|
||||
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
|
||||
func NewWebsocket(id, address string) *Websocket {
|
||||
return &Websocket{
|
||||
id: id,
|
||||
address: address,
|
||||
config: &Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: new(TLS),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the configuration values for the listener config.
|
||||
func (l *Websocket) SetConfig(config *Config) {
|
||||
l.Lock()
|
||||
if config != nil {
|
||||
l.config = config
|
||||
|
||||
// If a config has been passed without an auth controller,
|
||||
// it may be a mistake, so disallow all traffic.
|
||||
if l.config.Auth == nil {
|
||||
l.config.Auth = new(auth.Disallow)
|
||||
}
|
||||
}
|
||||
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Websocket) ID() string {
|
||||
l.RLock()
|
||||
id := l.id
|
||||
l.RUnlock()
|
||||
return id
|
||||
}
|
||||
|
||||
// Listen starts listening on the listener's network address.
|
||||
func (l *Websocket) Listen(s *system.Info) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.handler)
|
||||
l.listen = &http.Server{
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth)
|
||||
}
|
||||
|
||||
// Serve starts waiting for new Websocket connections, and calls the connection
|
||||
// establishment callback for any received.
|
||||
func (l *Websocket) Serve(establish EstablishFunc) {
|
||||
l.establish = establish
|
||||
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Websocket) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
181
server/listeners/websocket_test.go
Normal file
181
server/listeners/websocket_test.go
Normal file
@ -0,0 +1,181 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWsConnClose(t *testing.T) {
|
||||
r, _ := net.Pipe()
|
||||
ws := &wsConn{r, new(websocket.Conn)}
|
||||
err := ws.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNewWebsocket(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testPort, l.address)
|
||||
}
|
||||
|
||||
func BenchmarkNewWebsocket(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewWebsocket("t1", testPort)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketSetConfig(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
})
|
||||
require.NotNil(t, l.config)
|
||||
require.NotNil(t, l.config.Auth)
|
||||
require.Equal(t, new(auth.Allow), l.config.Auth)
|
||||
|
||||
// Switch to disallow on bad config set.
|
||||
l.SetConfig(new(Config))
|
||||
require.NotNil(t, l.config)
|
||||
require.NotNil(t, l.config.Auth)
|
||||
require.Equal(t, new(auth.Disallow), l.config.Auth)
|
||||
}
|
||||
|
||||
func BenchmarkWebsocketSetConfig(b *testing.B) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.SetConfig(new(Config))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketID(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func BenchmarkWebsocketID(b *testing.B) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
for n := 0; n < b.N; n++ {
|
||||
l.ID()
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketListen(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
require.Nil(t, l.listen)
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLSConfig(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLSInvalid(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: []byte("abcde"),
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWebsocketServeAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.Listen(nil)
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestWebsocketUpgrade(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.Listen(nil)
|
||||
e := make(chan bool)
|
||||
l.establish = func(id string, c net.Conn, ac auth.Controller) error {
|
||||
e <- true
|
||||
return nil
|
||||
}
|
||||
s := httptest.NewServer(http.HandlerFunc(l.handler))
|
||||
u := "ws" + strings.TrimPrefix(s.URL, "http")
|
||||
ws, _, err := websocket.DefaultDialer.Dial(u, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, <-e)
|
||||
|
||||
s.Close()
|
||||
ws.Close()
|
||||
|
||||
}
|
Reference in New Issue
Block a user