This commit is contained in:
Кобелев Андрей Андреевич
2022-08-15 23:06:20 +03:00
commit 4ea1a3802b
532 changed files with 211522 additions and 0 deletions

View File

@ -0,0 +1,12 @@
package auth
// Controller is an interface for authentication controllers.
type Controller interface {
// Authenticate authenticates a user on CONNECT and returns true if a user is
// allowed to join the server.
Authenticate(user, password []byte) bool
// ACL returns true if a user has read or write access to a given topic.
ACL(user []byte, topic string, write bool) bool
}

View File

@ -0,0 +1,31 @@
package auth
// Allow is an auth controller which allows access to all connections and topics.
type Allow struct{}
// Authenticate returns true if a username and password are acceptable. Allow always
// returns true.
func (a *Allow) Authenticate(user, password []byte) bool {
return true
}
// ACL returns true if a user has access permissions to read or write on a topic.
// Allow always returns true.
func (a *Allow) ACL(user []byte, topic string, write bool) bool {
return true
}
// Disallow is an auth controller which disallows access to all connections and topics.
type Disallow struct{}
// Authenticate returns true if a username and password are acceptable. Disallow always
// returns false.
func (d *Disallow) Authenticate(user, password []byte) bool {
return false
}
// ACL returns true if a user has access permissions to read or write on a topic.
// Disallow always returns false.
func (d *Disallow) ACL(user []byte, topic string, write bool) bool {
return false
}

View File

@ -0,0 +1,55 @@
package auth
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAllowAuth(t *testing.T) {
ac := new(Allow)
require.Equal(t, true, ac.Authenticate([]byte("user"), []byte("pass")))
}
func BenchmarkAllowAuth(b *testing.B) {
ac := new(Allow)
for n := 0; n < b.N; n++ {
ac.Authenticate([]byte("user"), []byte("pass"))
}
}
func TestAllowACL(t *testing.T) {
ac := new(Allow)
require.Equal(t, true, ac.ACL([]byte("user"), "topic", true))
}
func BenchmarkAllowACL(b *testing.B) {
ac := new(Allow)
for n := 0; n < b.N; n++ {
ac.ACL([]byte("user"), "pass", true)
}
}
func TestDisallowAuth(t *testing.T) {
ac := new(Disallow)
require.Equal(t, false, ac.Authenticate([]byte("user"), []byte("pass")))
}
func BenchmarkDisallowAuth(b *testing.B) {
ac := new(Disallow)
for n := 0; n < b.N; n++ {
ac.Authenticate([]byte("user"), []byte("pass"))
}
}
func TestDisallowACL(t *testing.T) {
ac := new(Disallow)
require.Equal(t, false, ac.ACL([]byte("user"), "topic", true))
}
func BenchmarkDisallowACL(b *testing.B) {
ac := new(Disallow)
for n := 0; n < b.N; n++ {
ac.ACL([]byte("user"), "pass", true)
}
}

View File

@ -0,0 +1,124 @@
package listeners
import (
"context"
"crypto/tls"
"encoding/json"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
system *system.Info // pointers to the server data.
listen *http.Server // the http server.
end uint32 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string) *HTTPStats {
return &HTTPStats{
id: id,
address: address,
config: &Config{
Auth: new(auth.Allow),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *HTTPStats) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *HTTPStats) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *HTTPStats) Listen(s *system.Info) error {
l.system = s
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFunc) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info, err := json.MarshalIndent(l.system, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
return
}
w.Write(info)
}

View File

@ -0,0 +1,198 @@
package listeners
import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
"github.com/stretchr/testify/require"
)
func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewHTTPStats(b *testing.B) {
for n := 0; n < b.N; n++ {
NewHTTPStats("t1", testPort)
}
}
func TestHTTPStatsSetConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkHTTPStatsSetConfig(b *testing.B) {
l := NewHTTPStats("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkHTTPStatsID(b *testing.B) {
l := NewHTTPStats("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestHTTPStatsListen(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.system)
require.NotNil(t, l.listen)
require.Equal(t, testPort, l.listen.Addr)
l.listen.Close()
}
func TestHTTPStatsListenTLSConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLS(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLSInvalid(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(new(system.Info))
require.Error(t, err)
}
func TestHTTPStatsServeAndClose(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
resp, err := http.Get("http://localhost" + testPort)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost" + testPort)
require.Error(t, err)
<-o
}
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestHTTPStatsJSONHandler(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
w := httptest.NewRecorder()
l.jsonHandler(w, nil)
resp := w.Result()
body, _ := ioutil.ReadAll(resp.Body)
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
}

View File

@ -0,0 +1,152 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// Config contains configuration values for a listener.
type Config struct {
// Auth controller containing auth and ACL logic for
// allowing or denying access to the server and topics.
Auth auth.Controller
// TLS certficates and settings for the connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
TLS *TLS
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// TLS contains the TLS certificates and settings for the listener connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
type TLS struct {
Certificate []byte // the body of a public certificate.
PrivateKey []byte // the body of a private key.
}
// EstablishFunc is a callback function for establishing new clients.
type EstablishFunc func(id string, c net.Conn, ac auth.Controller) error
// CloseFunc is a callback function for closing all listener clients.
type CloseFunc func(id string)
// Listener is an interface for network listeners. A network listener listens
// for incoming client connections and adds them to the server.
type Listener interface {
SetConfig(*Config) // set the listener config.
Listen(s *system.Info) error // open the network address.
Serve(EstablishFunc) // starting actively listening for new connections.
ID() string // return the id of the listener.
Close(CloseFunc) // stop and close the listener.
}
// Listeners contains the network listeners for the broker.
type Listeners struct {
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
system *system.Info // pointers to system info.
sync.RWMutex
}
// New returns a new instance of Listeners.
func New(s *system.Info) *Listeners {
return &Listeners{
internal: map[string]Listener{},
system: s,
}
}
// Add adds a new listener to the listeners map, keyed on id.
func (l *Listeners) Add(val Listener) {
l.Lock()
l.internal[val.ID()] = val
l.Unlock()
}
// Get returns the value of a listener if it exists.
func (l *Listeners) Get(id string) (Listener, bool) {
l.RLock()
val, ok := l.internal[id]
l.RUnlock()
return val, ok
}
// Len returns the length of the listeners map.
func (l *Listeners) Len() int {
l.RLock()
val := len(l.internal)
l.RUnlock()
return val
}
// Delete removes a listener from the internal map.
func (l *Listeners) Delete(id string) {
l.Lock()
delete(l.internal, id)
l.Unlock()
}
// Serve starts a listener serving from the internal map.
func (l *Listeners) Serve(id string, establisher EstablishFunc) {
l.RLock()
listener := l.internal[id]
l.RUnlock()
go func(e EstablishFunc) {
defer l.wg.Done()
l.wg.Add(1)
listener.Serve(e)
}(establisher)
}
// ServeAll starts all listeners serving from the internal map.
func (l *Listeners) ServeAll(establisher EstablishFunc) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Serve(id, establisher)
}
}
// Close stops a listener from the internal map.
func (l *Listeners) Close(id string, closer CloseFunc) {
l.RLock()
listener := l.internal[id]
l.RUnlock()
listener.Close(closer)
}
// CloseAll iterates and closes all registered listeners.
func (l *Listeners) CloseAll(closer CloseFunc) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Close(id, closer)
}
l.wg.Wait()
}

View File

@ -0,0 +1,245 @@
package listeners
import (
"crypto/tls"
"log"
"testing"
"time"
"github.com/stretchr/testify/require"
)
var (
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New(nil)
require.NotNil(t, l.internal)
}
func BenchmarkNewListeners(b *testing.B) {
for n := 0; n < b.N; n++ {
New(nil)
}
}
func TestAddListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
require.Contains(t, l.internal, "t1")
}
func BenchmarkAddListener(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
for n := 0; n < b.N; n++ {
l.Add(mocked)
}
}
func TestGetListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
g, ok := l.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, g.ID(), "t1")
}
func BenchmarkGetListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Get("t1")
}
}
func TestLenListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
require.Equal(t, 2, l.Len())
}
func BenchmarkLenListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Len()
}
}
func TestDeleteListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
require.Contains(t, l.internal, "t1")
l.Delete("t1")
_, ok := l.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, l.internal["t1"])
}
func BenchmarkDeleteListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Delete("t1")
}
}
func TestServeListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing())
}
func BenchmarkServeListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Serve("t1", MockEstablisher)
}
}
func TestServeAllListeners(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882"))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
l.Close("t2", MockCloser)
l.Close("t3", MockCloser)
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing())
require.Equal(t, false, l.internal["t2"].(*MockListener).IsServing())
require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing())
}
func BenchmarkServeAllListeners(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1883"))
l.Add(NewMockListener("t3", ":1884"))
for n := 0; n < b.N; n++ {
l.ServeAll(MockEstablisher)
}
}
func TestCloseListener(t *testing.T) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
var closed bool
l.Close("t1", func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func BenchmarkCloseListener(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
for n := 0; n < b.N; n++ {
l.internal["t1"].(*MockListener).done = make(chan bool)
l.Close("t1", MockCloser)
}
}
func TestCloseAllListeners(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882"))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing())
closed := make(map[string]bool)
l.CloseAll(func(id string) {
closed[id] = true
})
require.Contains(t, closed, "t1")
require.Contains(t, closed, "t2")
require.Contains(t, closed, "t3")
require.Equal(t, true, closed["t1"])
require.Equal(t, true, closed["t2"])
require.Equal(t, true, closed["t3"])
}
func BenchmarkCloseAllListeners(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
for n := 0; n < b.N; n++ {
l.internal["t1"].(*MockListener).done = make(chan bool)
l.Close("t1", MockCloser)
}
}

100
server/listeners/mock.go Normal file
View File

@ -0,0 +1,100 @@
package listeners
import (
"fmt"
"net"
"sync"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// MockCloser is a function signature which can be used in testing.
func MockCloser(id string) {}
// MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
return nil
}
// MockListener is a mock listener for establishing client connections.
type MockListener struct {
sync.RWMutex
id string // the id of the listener.
address string // the network address the listener binds to.
Config *Config // configuration for the listener.
done chan bool // indicate the listener is done.
Serving bool // indicate the listener is serving.
Listening bool // indiciate the listener is listening.
ErrListen bool // throw an error on listen.
}
// NewMockListener returns a new instance of MockListener
func NewMockListener(id, address string) *MockListener {
return &MockListener{
id: id,
address: address,
done: make(chan bool),
}
}
// Serve serves the mock listener.
func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for range l.done {
return
}
}
// Listen begins listening for incoming traffic.
func (l *MockListener) Listen(s *system.Info) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}
l.Lock()
l.Listening = true
l.Unlock()
return nil
}
// SetConfig sets the configuration values of the mock listener.
func (l *MockListener) SetConfig(config *Config) {
l.Lock()
l.Config = config
l.Unlock()
}
// ID returns the id of the mock listener.
func (l *MockListener) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Close closes the mock listener.
func (l *MockListener) Close(closer CloseFunc) {
l.Lock()
defer l.Unlock()
l.Serving = false
closer(l.id)
close(l.done)
}
// IsServing indicates whether the mock listener is serving.
func (l *MockListener) IsServing() bool {
l.Lock()
defer l.Unlock()
return l.Serving
}
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()
return l.Listening
}

View File

@ -0,0 +1,89 @@
package listeners
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w, new(auth.Allow))
require.NoError(t, err)
w.Close()
}
func TestNewMockListener(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, "t1", mocked.id)
require.Equal(t, ":1882", mocked.address)
}
func TestNewMockListenerListen(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, "t1", mocked.id)
require.Equal(t, ":1882", mocked.address)
require.Equal(t, false, mocked.IsListening())
err := mocked.Listen(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening())
}
func TestNewMockListenerListenFailure(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
mocked.ErrListen = true
err := mocked.Listen(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsServing())
o := make(chan bool)
go func(o chan bool) {
mocked.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
require.Equal(t, true, mocked.IsServing())
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
mocked.Listen(nil)
}
func TestMockListenerSetConfig(t *testing.T) {
mocked := NewMockListener("t1", ":1883")
mocked.SetConfig(new(Config))
require.NotNil(t, mocked.Config)
}
func TestMockListenerClose(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestNewMockListenerIsListening(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsListening())
}
func TestNewMockListenerIsServing(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsServing())
}

127
server/listeners/tcp.go Normal file
View File

@ -0,0 +1,127 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct {
sync.RWMutex
id string // the internal id of the listener.
protocol string // the TCP protocol to use.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
config *Config // configuration values for the listener.
end uint32 // ensure the close methods are only called once.
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
func NewTCP(id, address string) *TCP {
return &TCP{
id: id,
protocol: "tcp",
address: address,
config: &Config{ // default configuration.
Auth: new(auth.Allow),
TLS: new(TLS),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *TCP) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *TCP) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *TCP) Listen(s *system.Info) error {
var err error
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else if l.config.TLSConfig != nil {
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen(l.protocol, l.address)
}
if err != nil {
return err
}
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFunc) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
_ = establish(l.id, conn, l.config.Auth)
}()
}
}
}
// Close closes the listener and any client connections.
func (l *TCP) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@ -0,0 +1,228 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/stretchr/testify/require"
)
const (
testPort = ":22222"
)
func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewTCP(b *testing.B) {
for n := 0; n < b.N; n++ {
NewTCP("t1", testPort)
}
}
func TestTCPSetConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkTCPSetConfig(b *testing.B) {
l := NewTCP("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestTCPID(t *testing.T) {
l := NewTCP("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkTCPID(b *testing.B) {
l := NewTCP("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestTCPListen(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
l2 := NewTCP("t2", testPort)
err = l2.Listen(nil)
require.Error(t, err)
l.listen.Close()
}
func TestTCPListenTLSConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLS(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLSInvalid(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.Error(t, err)
}
func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPCloseError(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.listen.Close()
l.Close(MockCloser)
<-o
}
func TestTCPServeEnd(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
l.Close(MockCloser)
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
return nil
})
}
func TestTCPEstablishThenError(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
established <- true
return errors.New("testing") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial(l.protocol, l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}
func TestTCPEstablishButEnding(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
l.end = 1
o := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
return nil
})
o <- true
}()
net.Dial(l.protocol, l.listen.Addr().String())
time.Sleep(time.Millisecond)
l.Close(MockCloser)
<-o
}

View File

@ -0,0 +1,178 @@
package listeners
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
var (
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
// Websocket is a listener for establishing websocket connections.
type Websocket struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
listen *http.Server // an http server for serving websocket connections.
establish EstablishFunc // the server's establish connection handler.
end uint32 // ensure the close methods are only called once.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
// Inspired by
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns
// the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
}
return r.Read(p)
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
}
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string) *Websocket {
return &Websocket{
id: id,
address: address,
config: &Config{
Auth: new(auth.Allow),
TLS: new(TLS),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *Websocket) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *Websocket) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *Websocket) Listen(s *system.Info) error {
mux := http.NewServeMux()
mux.HandleFunc("/", l.handler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth)
}
// Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFunc) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}

View File

@ -0,0 +1,181 @@
package listeners
import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/stretchr/testify/require"
)
func TestWsConnClose(t *testing.T) {
r, _ := net.Pipe()
ws := &wsConn{r, new(websocket.Conn)}
err := ws.Close()
require.NoError(t, err)
}
func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewWebsocket(b *testing.B) {
for n := 0; n < b.N; n++ {
NewWebsocket("t1", testPort)
}
}
func TestWebsocketSetConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkWebsocketSetConfig(b *testing.B) {
l := NewWebsocket("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkWebsocketID(b *testing.B) {
l := NewWebsocket("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestWebsocketListen(t *testing.T) {
l := NewWebsocket("t1", testPort)
require.Nil(t, l.listen)
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen)
}
func TestWebsocketListenTLSConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLS(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLSInvalid(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.Error(t, err)
}
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.Listen(nil)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.Listen(nil)
e := make(chan bool)
l.establish = func(id string, c net.Conn, ac auth.Controller) error {
e <- true
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
u := "ws" + strings.TrimPrefix(s.URL, "http")
ws, _, err := websocket.DefaultDialer.Dial(u, nil)
require.NoError(t, err)
require.Equal(t, true, <-e)
s.Close()
ws.Close()
}