clone
This commit is contained in:
212
server/internal/circ/buffer.go
Normal file
212
server/internal/circ/buffer.go
Normal file
@ -0,0 +1,212 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBufferSize is the default size of the buffer in bytes.
|
||||
DefaultBufferSize int = 1024 * 256
|
||||
|
||||
// DefaultBlockSize is the default size per R/W block in bytes.
|
||||
DefaultBlockSize int = 1024 * 8
|
||||
|
||||
// ErrOutOfRange indicates that the index was out of range.
|
||||
ErrOutOfRange = errors.New("Indexes out of range")
|
||||
|
||||
// ErrInsufficientBytes indicates that there were not enough bytes to return.
|
||||
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
|
||||
)
|
||||
|
||||
// Buffer is a circular buffer for reading and writing messages.
|
||||
type Buffer struct {
|
||||
buf []byte // the bytes buffer.
|
||||
tmp []byte // a temporary buffer.
|
||||
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
|
||||
ID string // the identifier of the buffer. This is used in debug output.
|
||||
head int64 // the current position in the sequence - a forever increasing index.
|
||||
tail int64 // the committed position in the sequence - a forever increasing index.
|
||||
rcond *sync.Cond // the sync condition for the buffer reader.
|
||||
wcond *sync.Cond // the sync condition for the buffer writer.
|
||||
size int // the size of the buffer.
|
||||
mask int // a bitmask of the buffer size (size-1).
|
||||
block int // the size of the R/W block.
|
||||
done uint32 // indicates that the buffer is closed.
|
||||
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
}
|
||||
|
||||
// NewBuffer returns a new instance of buffer. You should call NewReader or
|
||||
// NewWriter instead of this function.
|
||||
func NewBuffer(size, block int) *Buffer {
|
||||
if size == 0 {
|
||||
size = DefaultBufferSize
|
||||
}
|
||||
|
||||
if block == 0 {
|
||||
block = DefaultBlockSize
|
||||
}
|
||||
|
||||
if size < 2*block {
|
||||
size = 2 * block
|
||||
}
|
||||
|
||||
return &Buffer{
|
||||
size: size,
|
||||
mask: size - 1,
|
||||
block: block,
|
||||
buf: make([]byte, size),
|
||||
rcond: sync.NewCond(new(sync.Mutex)),
|
||||
wcond: sync.NewCond(new(sync.Mutex)),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBufferFromSlice returns a new instance of buffer using a
|
||||
// pre-existing byte slice.
|
||||
func NewBufferFromSlice(block int, buf []byte) *Buffer {
|
||||
l := len(buf)
|
||||
|
||||
if block == 0 {
|
||||
block = DefaultBlockSize
|
||||
}
|
||||
|
||||
b := &Buffer{
|
||||
size: l,
|
||||
mask: l - 1,
|
||||
block: block,
|
||||
buf: buf,
|
||||
rcond: sync.NewCond(new(sync.Mutex)),
|
||||
wcond: sync.NewCond(new(sync.Mutex)),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// GetPos will return the tail and head positions of the buffer.
|
||||
// This method is for use with testing.
|
||||
func (b *Buffer) GetPos() (int64, int64) {
|
||||
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
|
||||
}
|
||||
|
||||
// SetPos sets the head and tail of the buffer.
|
||||
func (b *Buffer) SetPos(tail, head int64) {
|
||||
atomic.StoreInt64(&b.tail, tail)
|
||||
atomic.StoreInt64(&b.head, head)
|
||||
}
|
||||
|
||||
// Get returns the internal buffer.
|
||||
func (b *Buffer) Get() []byte {
|
||||
b.Mu.Lock()
|
||||
defer b.Mu.Unlock()
|
||||
return b.buf
|
||||
}
|
||||
|
||||
// Set writes bytes to a range of indexes in the byte buffer.
|
||||
func (b *Buffer) Set(p []byte, start, end int) error {
|
||||
b.Mu.Lock()
|
||||
defer b.Mu.Unlock()
|
||||
|
||||
if end > b.size || start > b.size {
|
||||
return ErrOutOfRange
|
||||
}
|
||||
|
||||
o := 0
|
||||
for i := start; i < end; i++ {
|
||||
b.buf[i] = p[o]
|
||||
o++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Index returns the buffer-relative index of an integer.
|
||||
func (b *Buffer) Index(i int64) int {
|
||||
return b.mask & int(i)
|
||||
}
|
||||
|
||||
// awaitEmpty will block until there is at least n bytes between
|
||||
// the head and the tail (looking forward).
|
||||
func (b *Buffer) awaitEmpty(n int) error {
|
||||
// If the head has wrapped behind the tail, and next will overrun tail,
|
||||
// then wait until tail has moved.
|
||||
b.rcond.L.Lock()
|
||||
for !b.checkEmpty(n) {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.rcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
b.rcond.Wait()
|
||||
}
|
||||
b.rcond.L.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// awaitFilled will block until there are at least n bytes between the
|
||||
// tail and the head (looking forward).
|
||||
func (b *Buffer) awaitFilled(n int) error {
|
||||
// Because awaitCapacity prevents the head from overrunning the t
|
||||
// able on write, we can simply ensure there is enough space
|
||||
// the forever-incrementing tail and head integers.
|
||||
b.wcond.L.Lock()
|
||||
for !b.checkFilled(n) {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.wcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
b.wcond.Wait()
|
||||
}
|
||||
b.wcond.L.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkEmpty returns true if there are at least n bytes between the head and
|
||||
// the tail.
|
||||
func (b *Buffer) checkEmpty(n int) bool {
|
||||
head := atomic.LoadInt64(&b.head)
|
||||
next := head + int64(n)
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
if next-tail > int64(b.size) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// checkFilled returns true if there are at least n bytes between the tail and
|
||||
// the head.
|
||||
func (b *Buffer) checkFilled(n int) bool {
|
||||
if atomic.LoadInt64(&b.tail)+int64(n) <= atomic.LoadInt64(&b.head) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CommitTail moves the tail position of the buffer n bytes.
|
||||
func (b *Buffer) CommitTail(n int) {
|
||||
atomic.AddInt64(&b.tail, int64(n))
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
}
|
||||
|
||||
// CapDelta returns the difference between the head and tail.
|
||||
func (b *Buffer) CapDelta() int {
|
||||
return int(atomic.LoadInt64(&b.head) - atomic.LoadInt64(&b.tail))
|
||||
}
|
||||
|
||||
// Stop signals the buffer to stop processing.
|
||||
func (b *Buffer) Stop() {
|
||||
atomic.StoreUint32(&b.done, 1)
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
b.wcond.L.Lock()
|
||||
b.wcond.Broadcast()
|
||||
b.wcond.L.Unlock()
|
||||
}
|
317
server/internal/circ/buffer_test.go
Normal file
317
server/internal/circ/buffer_test.go
Normal file
@ -0,0 +1,317 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
//"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBuffer(t *testing.T) {
|
||||
var size int = 16
|
||||
var block int = 4
|
||||
buf := NewBuffer(size, block)
|
||||
|
||||
require.NotNil(t, buf.buf)
|
||||
require.NotNil(t, buf.rcond)
|
||||
require.NotNil(t, buf.wcond)
|
||||
require.Equal(t, size, len(buf.buf))
|
||||
require.Equal(t, size, buf.size)
|
||||
require.Equal(t, block, buf.block)
|
||||
}
|
||||
|
||||
func TestNewBuffer0Size(t *testing.T) {
|
||||
buf := NewBuffer(0, 0)
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, DefaultBufferSize, buf.size)
|
||||
require.Equal(t, DefaultBlockSize, buf.block)
|
||||
}
|
||||
|
||||
func TestNewBufferUndersize(t *testing.T) {
|
||||
buf := NewBuffer(DefaultBlockSize+10, DefaultBlockSize)
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, DefaultBlockSize*2, buf.size)
|
||||
require.Equal(t, DefaultBlockSize, buf.block)
|
||||
}
|
||||
|
||||
func TestNewBufferFromSlice(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestNewBufferFromSlice0Size(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewBufferFromSlice(0, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestAtomicAlignment(t *testing.T) {
|
||||
var b Buffer
|
||||
|
||||
offset := unsafe.Offsetof(b.head)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"head requires 64-bit alignment for atomic: offset %d", offset)
|
||||
|
||||
offset = unsafe.Offsetof(b.tail)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"tail requires 64-bit alignment for atomic: offset %d", offset)
|
||||
}
|
||||
|
||||
func TestGetPos(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
tail, head := buf.GetPos()
|
||||
require.Equal(t, int64(0), tail)
|
||||
require.Equal(t, int64(0), head)
|
||||
|
||||
atomic.StoreInt64(&buf.tail, 3)
|
||||
atomic.StoreInt64(&buf.head, 11)
|
||||
|
||||
tail, head = buf.GetPos()
|
||||
require.Equal(t, int64(3), tail)
|
||||
require.Equal(t, int64(11), head)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
require.Equal(t, make([]byte, 16), buf.Get())
|
||||
|
||||
buf.buf[0] = 1
|
||||
buf.buf[15] = 1
|
||||
require.Equal(t, []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, buf.Get())
|
||||
}
|
||||
|
||||
func TestSetPos(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&buf.tail))
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&buf.head))
|
||||
|
||||
buf.SetPos(4, 8)
|
||||
require.Equal(t, int64(4), atomic.LoadInt64(&buf.tail))
|
||||
require.Equal(t, int64(8), atomic.LoadInt64(&buf.head))
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
err := buf.Set([]byte{1, 1, 1, 1}, 17, 19)
|
||||
require.Error(t, err)
|
||||
|
||||
err = buf.Set([]byte{1, 1, 1, 1}, 4, 8)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte{0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, buf.buf)
|
||||
}
|
||||
|
||||
func TestIndex(t *testing.T) {
|
||||
buf := NewBuffer(1024, 4)
|
||||
require.Equal(t, 512, buf.Index(512))
|
||||
require.Equal(t, 0, buf.Index(1024))
|
||||
require.Equal(t, 6, buf.Index(1030))
|
||||
require.Equal(t, 6, buf.Index(61446))
|
||||
}
|
||||
|
||||
func TestAwaitFilled(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
n int
|
||||
await int
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 4, n: 4, await: 1, desc: "OK 0, 4"},
|
||||
{tail: 8, head: 11, n: 4, await: 1, desc: "OK 8, 11"},
|
||||
{tail: 102, head: 103, n: 4, await: 3, desc: "OK 102, 103"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
//fmt.Println(i)
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- buf.awaitFilled(4)
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.AddInt64(&buf.head, int64(tt.await))
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAwaitFilledEnded(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- buf.awaitFilled(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
require.Error(t, <-o)
|
||||
}
|
||||
|
||||
func TestAwaitEmptyOK(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
await int
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 0, await: 0, desc: "OK 0, 0"},
|
||||
{tail: 0, head: 5, await: 0, desc: "OK 0, 5"},
|
||||
{tail: 0, head: 14, await: 3, desc: "OK wrap 0, 14 "},
|
||||
{tail: 22, head: 35, await: 2, desc: "OK wrap 0, 14 "},
|
||||
{tail: 15, head: 17, await: 7, desc: "OK 15,2"},
|
||||
{tail: 0, head: 10, await: 2, desc: "OK 0, 10"},
|
||||
{tail: 1, head: 15, await: 4, desc: "OK 2, 14"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- buf.awaitEmpty(4)
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.AddInt64(&buf.tail, int64(tt.await))
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
|
||||
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAwaitEmptyEnded(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.SetPos(1, 15)
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- buf.awaitEmpty(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
|
||||
require.Error(t, <-o)
|
||||
}
|
||||
|
||||
func TestCheckEmpty(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
|
||||
tests := []struct {
|
||||
head int64
|
||||
tail int64
|
||||
want bool
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 0, want: true, desc: "0, 0 true"},
|
||||
{tail: 3, head: 4, want: true, desc: "4, 3 true"},
|
||||
{tail: 15, head: 17, want: true, desc: "15, 17(1) true"},
|
||||
{tail: 1, head: 30, want: false, desc: "1, 30(14) false"},
|
||||
{tail: 15, head: 30, want: false, desc: "15, 30(14) false; head has caught up to tail"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
require.Equal(t, tt.want, buf.checkEmpty(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFilled(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
|
||||
tests := []struct {
|
||||
head int64
|
||||
tail int64
|
||||
want bool
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 0, want: false, desc: "0, 0 false"},
|
||||
{tail: 0, head: 4, want: true, desc: "0, 4 true"},
|
||||
{tail: 14, head: 16, want: false, desc: "14,16 false"},
|
||||
{tail: 14, head: 18, want: true, desc: "14,16 true"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
require.Equal(t, tt.want, buf.checkFilled(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestCommitTail(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
n int
|
||||
next int64
|
||||
await int
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 5, n: 4, next: 4, await: 0, desc: "OK 0, 4"},
|
||||
{tail: 0, head: 5, n: 6, next: 6, await: 1, desc: "OK 0, 5"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
go func() {
|
||||
buf.CommitTail(tt.n)
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
for j := 0; j < tt.await; j++ {
|
||||
atomic.AddInt64(&buf.head, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
}
|
||||
require.Equal(t, tt.next, atomic.LoadInt64(&buf.tail), "Next tail mismatch [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestCommitTailEnded(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- buf.CommitTail(5)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
require.Error(t, <-o)
|
||||
}
|
||||
*/
|
||||
func TestCapDelta(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
|
||||
require.Equal(t, 0, buf.CapDelta())
|
||||
|
||||
buf.SetPos(10, 15)
|
||||
require.Equal(t, 5, buf.CapDelta())
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.Stop()
|
||||
require.Equal(t, uint32(1), buf.done)
|
||||
}
|
49
server/internal/circ/pool.go
Normal file
49
server/internal/circ/pool.go
Normal file
@ -0,0 +1,49 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// BytesPool is a pool of []byte.
|
||||
type BytesPool struct {
|
||||
// int64/uint64 has to the first words in order
|
||||
// to be 64-aligned on 32-bit architectures.
|
||||
used int64 // access atomically
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
// NewBytesPool returns a sync.pool of []byte.
|
||||
func NewBytesPool(n int) *BytesPool {
|
||||
if n == 0 {
|
||||
n = DefaultBufferSize
|
||||
}
|
||||
|
||||
return &BytesPool{
|
||||
pool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, n)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a pooled bytes.Buffer.
|
||||
func (b *BytesPool) Get() []byte {
|
||||
atomic.AddInt64(&b.used, 1)
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
// Put puts the byte slice back into the pool.
|
||||
func (b *BytesPool) Put(x []byte) {
|
||||
for i := range x {
|
||||
x[i] = 0
|
||||
}
|
||||
b.pool.Put(x)
|
||||
atomic.AddInt64(&b.used, -1)
|
||||
}
|
||||
|
||||
// InUse returns the number of pool blocks in use.
|
||||
func (b *BytesPool) InUse() int64 {
|
||||
return atomic.LoadInt64(&b.used)
|
||||
}
|
49
server/internal/circ/pool_test.go
Normal file
49
server/internal/circ/pool_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBytesPool(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
require.NotNil(t, bpool.pool)
|
||||
}
|
||||
|
||||
func BenchmarkNewBytesPool(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewBytesPool(256)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBytesPoolGet(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
|
||||
require.Equal(t, make([]byte, 256), buf)
|
||||
require.Equal(t, int64(1), bpool.InUse())
|
||||
}
|
||||
|
||||
func BenchmarkBytesPoolGet(b *testing.B) {
|
||||
bpool := NewBytesPool(256)
|
||||
for n := 0; n < b.N; n++ {
|
||||
bpool.Get()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBytesPoolPut(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
require.Equal(t, int64(1), bpool.InUse())
|
||||
bpool.Put(buf)
|
||||
require.Equal(t, int64(0), bpool.InUse())
|
||||
}
|
||||
|
||||
func BenchmarkBytesPoolPut(b *testing.B) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
for n := 0; n < b.N; n++ {
|
||||
bpool.Put(buf)
|
||||
}
|
||||
}
|
96
server/internal/circ/reader.go
Normal file
96
server/internal/circ/reader.go
Normal file
@ -0,0 +1,96 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Reader is a circular buffer for reading data from an io.Reader.
|
||||
type Reader struct {
|
||||
*Buffer
|
||||
}
|
||||
|
||||
// NewReader returns a new Circular Reader.
|
||||
func NewReader(size, block int) *Reader {
|
||||
b := NewBuffer(size, block)
|
||||
b.ID = "\treader"
|
||||
return &Reader{
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
// NewReaderFromSlice returns a new Circular Reader using a pre-existing
|
||||
// byte slice.
|
||||
func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||
b := NewBufferFromSlice(block, p)
|
||||
b.ID = "\treader"
|
||||
return &Reader{
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||
// there is sufficient capacity to do so.
|
||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
atomic.StoreUint32(&b.State, 1)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// Wait until there's enough capacity in the buffer before
|
||||
// trying to read more bytes from the io.Reader.
|
||||
err := b.awaitEmpty(b.block)
|
||||
if err != nil {
|
||||
// b.done is the only error condition for awaitCapacity
|
||||
// so loop around and return properly.
|
||||
continue
|
||||
}
|
||||
|
||||
// If the block will overrun the circle end, just fill up
|
||||
// and collect the rest on the next pass.
|
||||
start := b.Index(atomic.LoadInt64(&b.head))
|
||||
end := start + b.block
|
||||
if end > b.size {
|
||||
end = b.size
|
||||
}
|
||||
|
||||
// Read into the buffer between the start and end indexes only.
|
||||
n, err := r.Read(b.buf[start:end])
|
||||
total += int64(n) // incr total bytes read.
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Move the head forward however many bytes were read.
|
||||
atomic.AddInt64(&b.head, int64(n))
|
||||
|
||||
b.wcond.L.Lock()
|
||||
b.wcond.Broadcast()
|
||||
b.wcond.L.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads n bytes from the buffer, and will block until at n bytes
|
||||
// exist in the buffer to read.
|
||||
func (b *Buffer) Read(n int) (p []byte, err error) {
|
||||
err = b.awaitFilled(n)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
next := tail + int64(n)
|
||||
|
||||
// If the read overruns the buffer, get everything until the end
|
||||
// and then whatever is left from the start.
|
||||
if b.Index(tail) > b.Index(next) {
|
||||
b.tmp = b.buf[b.Index(tail):]
|
||||
b.tmp = append(b.tmp, b.buf[:b.Index(next)]...)
|
||||
} else {
|
||||
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
|
||||
}
|
||||
|
||||
return b.tmp, nil
|
||||
}
|
129
server/internal/circ/reader_test.go
Normal file
129
server/internal/circ/reader_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewReader(t *testing.T) {
|
||||
var size = 16
|
||||
var block = 4
|
||||
buf := NewReader(size, block)
|
||||
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, size, len(buf.buf))
|
||||
require.Equal(t, size, buf.size)
|
||||
require.Equal(t, block, buf.block)
|
||||
}
|
||||
|
||||
func TestNewReaderFromSlice(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestReadFrom(t *testing.T) {
|
||||
buf := NewReader(16, 4)
|
||||
|
||||
b4 := bytes.Repeat([]byte{'-'}, 4)
|
||||
br := bytes.NewReader(b4)
|
||||
|
||||
_, err := buf.ReadFrom(br)
|
||||
require.True(t, errors.Is(err, io.EOF))
|
||||
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
|
||||
require.Equal(t, int64(4), buf.head)
|
||||
|
||||
br.Reset(b4)
|
||||
_, err = buf.ReadFrom(br)
|
||||
require.True(t, errors.Is(err, io.EOF))
|
||||
require.Equal(t, int64(8), buf.head)
|
||||
|
||||
br.Reset(b4)
|
||||
_, err = buf.ReadFrom(br)
|
||||
require.True(t, errors.Is(err, io.EOF))
|
||||
require.Equal(t, int64(12), buf.head)
|
||||
}
|
||||
|
||||
func TestReadFromWrap(t *testing.T) {
|
||||
buf := NewReader(16, 4)
|
||||
buf.buf = bytes.Repeat([]byte{'-'}, 16)
|
||||
buf.SetPos(8, 14)
|
||||
br := bytes.NewReader(bytes.Repeat([]byte{'/'}, 8))
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
_, err := buf.ReadFrom(br)
|
||||
o <- err
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
go func() {
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
}()
|
||||
<-o
|
||||
require.Equal(t, []byte{'/', '/', '/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '/', '/'}, buf.Get())
|
||||
require.Equal(t, int64(22), atomic.LoadInt64(&buf.head))
|
||||
require.Equal(t, 6, buf.Index(atomic.LoadInt64(&buf.head)))
|
||||
}
|
||||
|
||||
func TestReadOK(t *testing.T) {
|
||||
buf := NewReader(16, 4)
|
||||
buf.buf = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
||||
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
n int
|
||||
bytes []byte
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 4, n: 4, bytes: []byte{'a', 'b', 'c', 'd'}, desc: "0, 4 OK"},
|
||||
{tail: 3, head: 15, n: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, desc: "3, 15 OK"},
|
||||
{tail: 14, head: 15, n: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, desc: "14, 2 wrapped OK"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
o := make(chan []byte)
|
||||
go func() {
|
||||
p, _ := buf.Read(tt.n)
|
||||
o <- p
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.head, buf.head+int64(tt.n))
|
||||
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
done := <-o
|
||||
require.Equal(t, tt.bytes, done, "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEnded(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
_, err := buf.Read(4)
|
||||
o <- err
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
require.Error(t, <-o)
|
||||
}
|
106
server/internal/circ/writer.go
Normal file
106
server/internal/circ/writer.go
Normal file
@ -0,0 +1,106 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Writer is a circular buffer for writing data to an io.Writer.
|
||||
type Writer struct {
|
||||
*Buffer
|
||||
}
|
||||
|
||||
// NewWriter returns a pointer to a new Circular Writer.
|
||||
func NewWriter(size, block int) *Writer {
|
||||
b := NewBuffer(size, block)
|
||||
b.ID = "writer"
|
||||
return &Writer{
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriterFromSlice returns a new Circular Writer using a pre-existing
|
||||
// byte slice.
|
||||
func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||
b := NewBufferFromSlice(block, p)
|
||||
b.ID = "writer"
|
||||
return &Writer{
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
|
||||
atomic.StoreUint32(&b.State, 2)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
return total, io.EOF
|
||||
}
|
||||
|
||||
// Read from the buffer until there is at least 1 byte to write.
|
||||
err = b.awaitFilled(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get all the bytes between the tail and head, wrapping if necessary.
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
rTail := b.Index(tail)
|
||||
rHead := b.Index(atomic.LoadInt64(&b.head))
|
||||
n := b.CapDelta()
|
||||
p := make([]byte, 0, n)
|
||||
|
||||
if rTail > rHead {
|
||||
p = append(p, b.buf[rTail:]...)
|
||||
p = append(p, b.buf[:rHead]...)
|
||||
} else {
|
||||
p = append(p, b.buf[rTail:rHead]...)
|
||||
}
|
||||
|
||||
n, err = w.Write(p)
|
||||
total += int64(n)
|
||||
if err != nil {
|
||||
log.Println("error writing to buffer io.Writer;", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Move the tail forward the bytes written and broadcast change.
|
||||
atomic.StoreInt64(&b.tail, tail+int64(n))
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes the buffer to the buffer p, returning the number of bytes written.
|
||||
// The bytes written to the buffer are picked up by WriteTo.
|
||||
func (b *Writer) Write(p []byte) (total int, err error) {
|
||||
err = b.awaitEmpty(len(p))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
total = b.writeBytes(p)
|
||||
atomic.AddInt64(&b.head, int64(total))
|
||||
b.wcond.L.Lock()
|
||||
b.wcond.Broadcast()
|
||||
b.wcond.L.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// writeBytes writes bytes to the buffer from the start position, and returns
|
||||
// the new head position. This function does not wait for capacity and will
|
||||
// overwrite any existing bytes.
|
||||
func (b *Writer) writeBytes(p []byte) int {
|
||||
var o int
|
||||
var n int
|
||||
for i := 0; i < len(p); i++ {
|
||||
o = b.Index(atomic.LoadInt64(&b.head) + int64(i))
|
||||
b.buf[o] = p[i]
|
||||
n++
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
155
server/internal/circ/writer_test.go
Normal file
155
server/internal/circ/writer_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewWriter(t *testing.T) {
|
||||
var size = 16
|
||||
var block = 4
|
||||
buf := NewWriter(size, block)
|
||||
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, size, len(buf.buf))
|
||||
require.Equal(t, size, buf.size)
|
||||
require.Equal(t, block, buf.block)
|
||||
}
|
||||
|
||||
func TestNewWriterFromSlice(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestWriteTo(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
bytes []byte
|
||||
await int
|
||||
total int
|
||||
err error
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 5, bytes: []byte{'a', 'b', 'c', 'd', 'e'}, desc: "0,5 OK"},
|
||||
{tail: 14, head: 21, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd', 'e'}, desc: "14,16(2) OK"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
||||
buf := NewWriter(16, 4)
|
||||
buf.Set(bb, 0, 16)
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
|
||||
var b bytes.Buffer
|
||||
w := bufio.NewWriter(&b)
|
||||
|
||||
nc := make(chan int64)
|
||||
go func() {
|
||||
n, _ := buf.WriteTo(w)
|
||||
nc <- n
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
w.Flush()
|
||||
require.Equal(t, tt.bytes, b.Bytes(), "Written bytes mismatch [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToEndedFirst(t *testing.T) {
|
||||
buf := NewWriter(16, 4)
|
||||
buf.done = 1
|
||||
|
||||
var b bytes.Buffer
|
||||
w := bufio.NewWriter(&b)
|
||||
_, err := buf.WriteTo(w)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWriteToBadWriter(t *testing.T) {
|
||||
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
||||
buf := NewWriter(16, 4)
|
||||
buf.Set(bb, 0, 16)
|
||||
buf.SetPos(0, 6)
|
||||
r, w := net.Pipe()
|
||||
|
||||
w.Close()
|
||||
_, err := buf.WriteTo(w)
|
||||
require.Error(t, err)
|
||||
r.Close()
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
rHead int64
|
||||
bytes []byte
|
||||
want []byte
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 0, rHead: 4, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, desc: "0>4 OK"},
|
||||
{tail: 4, head: 14, rHead: 2, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a', 'b'}, desc: "14>2 OK"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf := NewWriter(16, 4)
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
|
||||
o := make(chan []interface{})
|
||||
go func() {
|
||||
nn, err := buf.Write(tt.bytes)
|
||||
o <- []interface{}{nn, err}
|
||||
}()
|
||||
|
||||
done := <-o
|
||||
require.Equal(t, tt.want, buf.buf, "Wanted written mismatch [i:%d] %s", i, tt.desc)
|
||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteEnded(t *testing.T) {
|
||||
buf := NewWriter(16, 4)
|
||||
buf.SetPos(15, 30)
|
||||
buf.done = 1
|
||||
|
||||
_, err := buf.Write([]byte{'a', 'b', 'c', 'd'})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWriteBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
bytes []byte
|
||||
want []byte
|
||||
start int
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 0, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0}, desc: "0,4 OK"},
|
||||
{tail: 6, head: 6, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 'a', 'b'}, desc: "6,2 OK wrapped"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf := NewWriter(8, 4)
|
||||
buf.SetPos(tt.tail, tt.head)
|
||||
n := buf.writeBytes(tt.bytes)
|
||||
|
||||
require.Equal(t, tt.want, buf.buf, "Buffer mistmatch [i:%d] %s", i, tt.desc)
|
||||
require.Equal(t, len(tt.bytes), n)
|
||||
}
|
||||
|
||||
}
|
564
server/internal/clients/clients.go
Normal file
564
server/internal/clients/clients.go
Normal file
@ -0,0 +1,564 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/events"
|
||||
"github.com/mochi-co/mqtt/server/internal/circ"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
"github.com/mochi-co/mqtt/server/internal/topics"
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/system"
|
||||
)
|
||||
|
||||
var (
|
||||
// defaultKeepalive is the default connection keepalive value in seconds.
|
||||
defaultKeepalive uint16 = 10
|
||||
|
||||
// ErrConnectionClosed is returned when operating on a closed
|
||||
// connection and/or when no error cause has been given.
|
||||
ErrConnectionClosed = errors.New("connection not open")
|
||||
)
|
||||
|
||||
// Clients contains a map of the clients known by the broker.
|
||||
type Clients struct {
|
||||
sync.RWMutex
|
||||
internal map[string]*Client // clients known by the broker, keyed on client id.
|
||||
}
|
||||
|
||||
// New returns an instance of Clients.
|
||||
func New() *Clients {
|
||||
return &Clients{
|
||||
internal: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new client to the clients map, keyed on client id.
|
||||
func (cl *Clients) Add(val *Client) {
|
||||
cl.Lock()
|
||||
cl.internal[val.ID] = val
|
||||
cl.Unlock()
|
||||
}
|
||||
|
||||
// Get returns the value of a client if it exists.
|
||||
func (cl *Clients) Get(id string) (*Client, bool) {
|
||||
cl.RLock()
|
||||
val, ok := cl.internal[id]
|
||||
cl.RUnlock()
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the clients map.
|
||||
func (cl *Clients) Len() int {
|
||||
cl.RLock()
|
||||
val := len(cl.internal)
|
||||
cl.RUnlock()
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a client from the internal map.
|
||||
func (cl *Clients) Delete(id string) {
|
||||
cl.Lock()
|
||||
delete(cl.internal, id)
|
||||
cl.Unlock()
|
||||
}
|
||||
|
||||
// GetByListener returns clients matching a listener id.
|
||||
func (cl *Clients) GetByListener(id string) []*Client {
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
cl.RLock()
|
||||
for _, v := range cl.internal {
|
||||
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
|
||||
clients = append(clients, v)
|
||||
}
|
||||
}
|
||||
cl.RUnlock()
|
||||
return clients
|
||||
}
|
||||
|
||||
// Client contains information about a client known by the broker.
|
||||
type Client struct {
|
||||
State State // the operational state of the client.
|
||||
LWT LWT // the last will and testament for the client.
|
||||
Inflight *Inflight // a map of in-flight qos messages.
|
||||
sync.RWMutex // mutex
|
||||
Username []byte // the username the client authenticated with.
|
||||
AC auth.Controller // an auth controller inherited from the listener.
|
||||
Listener string // the id of the listener the client is connected to.
|
||||
ID string // the client id.
|
||||
conn net.Conn // the net.Conn used to establish the connection.
|
||||
R *circ.Reader // a reader for reading incoming bytes.
|
||||
W *circ.Writer // a writer for writing outgoing bytes.
|
||||
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
|
||||
systemInfo *system.Info // pointers to server system info.
|
||||
packetID uint32 // the current highest packetID.
|
||||
keepalive uint16 // the number of seconds the connection can wait.
|
||||
CleanSession bool // indicates if the client expects a clean-session.
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
type State struct {
|
||||
started *sync.WaitGroup // tracks the goroutines which have been started.
|
||||
endedW *sync.WaitGroup // tracks when the writer has ended.
|
||||
endedR *sync.WaitGroup // tracks when the reader has ended.
|
||||
Done uint32 // atomic counter which indicates that the client has closed.
|
||||
endOnce sync.Once // only end once.
|
||||
stopCause atomic.Value // reason for stopping.
|
||||
}
|
||||
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
|
||||
cl := &Client{
|
||||
conn: c,
|
||||
R: r,
|
||||
W: w,
|
||||
systemInfo: s,
|
||||
keepalive: defaultKeepalive,
|
||||
Inflight: &Inflight{
|
||||
internal: make(map[uint16]InflightMessage),
|
||||
},
|
||||
Subscriptions: make(map[string]byte),
|
||||
State: State{
|
||||
started: new(sync.WaitGroup),
|
||||
endedW: new(sync.WaitGroup),
|
||||
endedR: new(sync.WaitGroup),
|
||||
},
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// NewClientStub returns an instance of Client with basic initializations. This
|
||||
// method is typically called by the persistence restoration system.
|
||||
func NewClientStub(s *system.Info) *Client {
|
||||
return &Client{
|
||||
Inflight: &Inflight{
|
||||
internal: make(map[uint16]InflightMessage),
|
||||
},
|
||||
Subscriptions: make(map[string]byte),
|
||||
State: State{
|
||||
Done: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Identify sets the identification values of a client instance.
|
||||
func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
|
||||
cl.Listener = lid
|
||||
cl.AC = ac
|
||||
|
||||
cl.ID = pk.ClientIdentifier
|
||||
if cl.ID == "" {
|
||||
cl.ID = xid.New().String()
|
||||
}
|
||||
|
||||
cl.R.ID = cl.ID + " READER"
|
||||
cl.W.ID = cl.ID + " WRITER"
|
||||
|
||||
cl.Username = pk.Username
|
||||
cl.CleanSession = pk.CleanSession
|
||||
cl.keepalive = pk.Keepalive
|
||||
|
||||
if pk.WillFlag {
|
||||
cl.LWT = LWT{
|
||||
Topic: pk.WillTopic,
|
||||
Message: pk.WillMessage,
|
||||
Qos: pk.WillQos,
|
||||
Retain: pk.WillRetain,
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
}
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
if cl.conn != nil {
|
||||
var expiry time.Time // Nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
|
||||
}
|
||||
_ = cl.conn.SetDeadline(expiry)
|
||||
}
|
||||
}
|
||||
|
||||
// Info returns an event-version of a client, containing minimal information.
|
||||
func (cl *Client) Info() events.Client {
|
||||
addr := "unknown"
|
||||
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
|
||||
addr = cl.conn.RemoteAddr().String()
|
||||
}
|
||||
return events.Client{
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Username: cl.Username,
|
||||
CleanSession: cl.CleanSession,
|
||||
Listener: cl.Listener,
|
||||
}
|
||||
}
|
||||
|
||||
// NextPacketID returns the next packet id for a client, looping back to 0
|
||||
// if the maximum ID has been reached.
|
||||
func (cl *Client) NextPacketID() uint32 {
|
||||
i := atomic.LoadUint32(&cl.packetID)
|
||||
if i == uint32(65535) || i == uint32(0) {
|
||||
atomic.StoreUint32(&cl.packetID, 1)
|
||||
return 1
|
||||
}
|
||||
|
||||
return atomic.AddUint32(&cl.packetID, 1)
|
||||
}
|
||||
|
||||
// NoteSubscription makes a note of a subscription for the client.
|
||||
func (cl *Client) NoteSubscription(filter string, qos byte) {
|
||||
cl.Lock()
|
||||
cl.Subscriptions[filter] = qos
|
||||
cl.Unlock()
|
||||
}
|
||||
|
||||
// ForgetSubscription forgests a subscription note for the client.
|
||||
func (cl *Client) ForgetSubscription(filter string) {
|
||||
cl.Lock()
|
||||
delete(cl.Subscriptions, filter)
|
||||
cl.Unlock()
|
||||
}
|
||||
|
||||
// Start begins the client goroutines reading and writing packets.
|
||||
func (cl *Client) Start() {
|
||||
cl.State.started.Add(2)
|
||||
cl.State.endedW.Add(1)
|
||||
cl.State.endedR.Add(1)
|
||||
|
||||
go func() {
|
||||
cl.State.started.Done()
|
||||
_, err := cl.W.WriteTo(cl.conn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("writer: %w", err)
|
||||
}
|
||||
cl.State.endedW.Done()
|
||||
cl.Stop(err)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
cl.State.started.Done()
|
||||
_, err := cl.R.ReadFrom(cl.conn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reader: %w", err)
|
||||
}
|
||||
cl.State.endedR.Done()
|
||||
cl.Stop(err)
|
||||
}()
|
||||
|
||||
cl.State.started.Wait()
|
||||
}
|
||||
|
||||
// ClearBuffers sets the read/write buffers to nil so they can be
|
||||
// deallocated automatically when no longer in use.
|
||||
func (cl *Client) ClearBuffers() {
|
||||
cl.R = nil
|
||||
cl.W = nil
|
||||
}
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
// A cause error may be passed to identfy the reason for stopping.
|
||||
func (cl *Client) Stop(err error) {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
cl.R.Stop()
|
||||
cl.W.Stop()
|
||||
|
||||
cl.State.endedW.Wait()
|
||||
|
||||
_ = cl.conn.Close() // omit close error
|
||||
|
||||
cl.State.endedR.Wait()
|
||||
atomic.StoreUint32(&cl.State.Done, 1)
|
||||
|
||||
if err == nil {
|
||||
err = ErrConnectionClosed
|
||||
}
|
||||
cl.State.stopCause.Store(err)
|
||||
})
|
||||
}
|
||||
|
||||
// StopCause returns the reason the client connection was stopped, if any.
|
||||
func (cl *Client) StopCause() error {
|
||||
if cl.State.stopCause.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
p, err := cl.R.Read(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fh.Decode(p[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The remaining length value can be up to 5 bytes. Read through each byte
|
||||
// looking for continue values, and if found increase the read. Otherwise
|
||||
// decode the bytes that were legit.
|
||||
buf := make([]byte, 0, 6)
|
||||
i := 1
|
||||
n := 2
|
||||
for ; n < 6; n++ {
|
||||
p, err = cl.R.Read(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf = append(buf, p[i])
|
||||
|
||||
// If it's not a continuation flag, end here.
|
||||
if p[i] < 128 {
|
||||
break
|
||||
}
|
||||
|
||||
// If i has reached 4 without a length terminator, return a protocol violation.
|
||||
i++
|
||||
if i == 4 {
|
||||
return packets.ErrOversizedLengthIndicator
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate and store the remaining length of the packet payload.
|
||||
rem, _ := binary.Uvarint(buf)
|
||||
fh.Remaining = int(rem)
|
||||
|
||||
// Having successfully read n bytes, commit the tail forward.
|
||||
cl.R.CommitTail(n)
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read loops forever reading new packets from a client connection until
|
||||
// an error is encountered (or the connection is closed).
|
||||
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
|
||||
for {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = packetHandler(cl, pk) // Process inbound packet.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
|
||||
|
||||
pk.FixedHeader = *fh
|
||||
if pk.FixedHeader.Remaining == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
p, err := cl.R.Read(pk.FixedHeader.Remaining)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
px := append([]byte{}, p[:]...)
|
||||
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectDecode(px)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackDecode(px)
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecDecode(px)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelDecode(px)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompDecode(px)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeDecode(px)
|
||||
case packets.Suback:
|
||||
err = pk.SubackDecode(px)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(px)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackDecode(px)
|
||||
case packets.Pingreq:
|
||||
case packets.Pingresp:
|
||||
case packets.Disconnect:
|
||||
default:
|
||||
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
|
||||
}
|
||||
|
||||
cl.R.CommitTail(pk.FixedHeader.Remaining)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return 0, ErrConnectionClosed
|
||||
}
|
||||
|
||||
cl.W.Mu.Lock()
|
||||
defer cl.W.Mu.Unlock()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case packets.Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case packets.Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case packets.Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
default:
|
||||
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write the packet bytes to the client byte buffer.
|
||||
n, err = cl.W.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// LWT contains the last will and testament details for a client connection.
|
||||
type LWT struct {
|
||||
Message []byte // the message that shall be sent when the client disconnects.
|
||||
Topic string // the topic the will message shall be sent to.
|
||||
Qos byte // the quality of service desired.
|
||||
Retain bool // indicates whether the will message should be retained
|
||||
}
|
||||
|
||||
// InflightMessage contains data about a packet which is currently in-flight.
|
||||
type InflightMessage struct {
|
||||
Packet packets.Packet // the packet currently in-flight.
|
||||
Sent int64 // the last time the message was sent (for retries) in unixtime.
|
||||
Resends int // the number of times the message was attempted to be sent.
|
||||
}
|
||||
|
||||
// Inflight is a map of InflightMessage keyed on packet id.
|
||||
type Inflight struct {
|
||||
sync.RWMutex
|
||||
internal map[uint16]InflightMessage // internal contains the inflight messages.
|
||||
}
|
||||
|
||||
// Set stores the packet of an Inflight message, keyed on message id. Returns
|
||||
// true if the inflight message was new.
|
||||
func (i *Inflight) Set(key uint16, in InflightMessage) bool {
|
||||
i.Lock()
|
||||
_, ok := i.internal[key]
|
||||
i.internal[key] = in
|
||||
i.Unlock()
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Get returns the value of an in-flight message if it exists.
|
||||
func (i *Inflight) Get(key uint16) (InflightMessage, bool) {
|
||||
i.RLock()
|
||||
val, ok := i.internal[key]
|
||||
i.RUnlock()
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the size of the in-flight messages map.
|
||||
func (i *Inflight) Len() int {
|
||||
i.RLock()
|
||||
v := len(i.internal)
|
||||
i.RUnlock()
|
||||
return v
|
||||
}
|
||||
|
||||
// GetAll returns all the in-flight messages.
|
||||
func (i *Inflight) GetAll() map[uint16]InflightMessage {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
return i.internal
|
||||
}
|
||||
|
||||
// Delete removes an in-flight message from the map. Returns true if the
|
||||
// message existed.
|
||||
func (i *Inflight) Delete(key uint16) bool {
|
||||
i.Lock()
|
||||
_, ok := i.internal[key]
|
||||
delete(i.internal, key)
|
||||
i.Unlock()
|
||||
return ok
|
||||
}
|
1097
server/internal/clients/clients_test.go
Normal file
1097
server/internal/clients/clients_test.go
Normal file
File diff suppressed because it is too large
Load Diff
116
server/internal/packets/codec.go
Normal file
116
server/internal/packets/codec.go
Normal file
@ -0,0 +1,116 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// bytesToString provides a zero-alloc, no-copy byte to string conversion.
|
||||
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
|
||||
func bytesToString(bs []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&bs))
|
||||
}
|
||||
|
||||
// decodeUint16 extracts the value of two bytes from a byte array.
|
||||
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
|
||||
if len(buf) < offset+2 {
|
||||
return 0, 0, ErrOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
|
||||
}
|
||||
|
||||
// decodeString extracts a string from a byte array, beginning at an offset.
|
||||
func decodeString(buf []byte, offset int) (string, int, error) {
|
||||
b, n, err := decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if !validUTF8(b) {
|
||||
return "", 0, ErrOffsetStrInvalidUTF8
|
||||
}
|
||||
|
||||
return bytesToString(b), n, nil
|
||||
}
|
||||
|
||||
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
|
||||
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
|
||||
length, next, err := decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return make([]byte, 0, 0), 0, err
|
||||
}
|
||||
|
||||
if next+int(length) > len(buf) {
|
||||
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
|
||||
}
|
||||
|
||||
// Note: there is no validUTF8() test for []byte payloads
|
||||
|
||||
return buf[next : next+int(length)], next + int(length), nil
|
||||
}
|
||||
|
||||
// decodeByte extracts the value of a byte from a byte array.
|
||||
func decodeByte(buf []byte, offset int) (byte, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return 0, 0, ErrOffsetByteOutOfRange
|
||||
}
|
||||
return buf[offset], offset + 1, nil
|
||||
}
|
||||
|
||||
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
|
||||
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return false, 0, ErrOffsetBoolOutOfRange
|
||||
}
|
||||
return 1&buf[offset] > 0, offset + 1, nil
|
||||
}
|
||||
|
||||
// encodeBool returns a byte instead of a bool.
|
||||
func encodeBool(b bool) byte {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
|
||||
func encodeBytes(val []byte) []byte {
|
||||
// In many circumstances the number of bytes being encoded is small.
|
||||
// Setting the cap to a low amount allows us to account for those without
|
||||
// triggering allocation growth on append unless we need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, val...)
|
||||
}
|
||||
|
||||
// encodeUint16 encodes a uint16 value to a byte array.
|
||||
func encodeUint16(val uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeString encodes a string to a byte array.
|
||||
func encodeString(val string) []byte {
|
||||
// Like encodeBytes, we set the cap to a small number to avoid
|
||||
// triggering allocation growth on append unless we absolutely need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, []byte(val)...)
|
||||
}
|
||||
|
||||
// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically
|
||||
// conforming to the MQTT specification requirements.
|
||||
func validUTF8(b []byte) bool {
|
||||
// [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8...
|
||||
if !utf8.Valid(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
// [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000...
|
||||
// ...
|
||||
return true
|
||||
|
||||
}
|
386
server/internal/packets/codec_test.go
Normal file
386
server/internal/packets/codec_test.go
Normal file
@ -0,0 +1,386 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
b := []byte{'a', 'b', 'c'}
|
||||
require.Equal(t, "abc", bytesToString(b))
|
||||
}
|
||||
|
||||
func BenchmarkBytesToString(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
bytesToString([]byte{'a', 'b', 'c'})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: "a/b/c/d",
|
||||
},
|
||||
{
|
||||
offset: 14,
|
||||
rawBytes: []byte{
|
||||
byte(Connect << 4), 17, // Fixed header
|
||||
0, 6, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
|
||||
3, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'h', 'e', 'y', // Client ID "zen"},
|
||||
},
|
||||
result: "hey",
|
||||
},
|
||||
{
|
||||
offset: 2,
|
||||
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
|
||||
result: "1/2/3/4/a/b/c/d/e/^/@/!",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
|
||||
result: "x/y/z",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 5,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 9,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 17,
|
||||
rawBytes: []byte{
|
||||
byte(Connect << 4), 0, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 6, // Will Topic - MSB+LSB
|
||||
'l',
|
||||
},
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrOffsetStrInvalidUTF8,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeString(b *testing.B) {
|
||||
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeString(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBytes(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result []uint8
|
||||
next int
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 0,
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 8,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeBytes(b *testing.B) {
|
||||
in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeBytes(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByte(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint8
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
|
||||
result: uint8(0x00),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x04),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x4d),
|
||||
offset: 2,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x51),
|
||||
offset: 3,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 80, 82, 84},
|
||||
offset: 8,
|
||||
shouldFail: ErrOffsetByteOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeByte(b *testing.B) {
|
||||
in := []byte{0, 4, 77, 81, 84, 84}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeByte(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint16(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint16
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x07),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x761),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+2, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeUint16(b *testing.B) {
|
||||
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeUint16(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByteBool(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result bool
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0x00, 0x00},
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
offset: 5,
|
||||
shouldFail: ErrOffsetBoolOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, 1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeByteBool(b *testing.B) {
|
||||
in := []byte{0x00, 0x00}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeByteBool(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeBool(t *testing.T) {
|
||||
result := encodeBool(true)
|
||||
require.Equal(t, byte(1), result, "Incorrect encoded value; not true")
|
||||
|
||||
result = encodeBool(false)
|
||||
require.Equal(t, byte(0), result, "Incorrect encoded value; not false")
|
||||
|
||||
// Check failure.
|
||||
result = encodeBool(false)
|
||||
require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value")
|
||||
}
|
||||
|
||||
func BenchmarkEncodeBool(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeBool(true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeBytes(t *testing.T) {
|
||||
result := encodeBytes([]byte("testing"))
|
||||
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value")
|
||||
|
||||
result = encodeBytes([]byte("testing"))
|
||||
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value")
|
||||
}
|
||||
|
||||
func BenchmarkEncodeBytes(b *testing.B) {
|
||||
bb := []byte("testing")
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeBytes(bb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeUint16(t *testing.T) {
|
||||
result := encodeUint16(0)
|
||||
require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0")
|
||||
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767")
|
||||
|
||||
result = encodeUint16(65535)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535")
|
||||
}
|
||||
|
||||
func BenchmarkEncodeUint16(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeUint16(32767)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeString(t *testing.T) {
|
||||
result := encodeString("testing")
|
||||
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing")
|
||||
|
||||
result = encodeString("")
|
||||
require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null")
|
||||
|
||||
result = encodeString("a")
|
||||
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a")
|
||||
|
||||
result = encodeString("b")
|
||||
require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b")
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkEncodeString(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeString("benchmarking")
|
||||
}
|
||||
}
|
59
server/internal/packets/fixedheader.go
Normal file
59
server/internal/packets/fixedheader.go
Normal file
@ -0,0 +1,59 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
|
||||
type FixedHeader struct {
|
||||
Remaining int // the number of remaining bytes in the payload.
|
||||
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte // indicates the quality of service expected.
|
||||
Dup bool // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// Decode extracts the specification bits from the header byte.
|
||||
func (fh *FixedHeader) Decode(headerByte byte) error {
|
||||
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
|
||||
|
||||
switch fh.Type {
|
||||
case Publish:
|
||||
fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate.
|
||||
fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag.
|
||||
fh.Retain = headerByte&0x01 > 0 // Extract retain flag.
|
||||
case Pubrel:
|
||||
fh.Qos = (headerByte >> 1) & 0x03
|
||||
case Subscribe:
|
||||
fh.Qos = (headerByte >> 1) & 0x03
|
||||
case Unsubscribe:
|
||||
fh.Qos = (headerByte >> 1) & 0x03
|
||||
default:
|
||||
if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 {
|
||||
return ErrInvalidFlags
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(buf *bytes.Buffer, length int64) {
|
||||
for {
|
||||
digit := byte(length % 128)
|
||||
length /= 128
|
||||
if length > 0 {
|
||||
digit |= 0x80
|
||||
}
|
||||
buf.WriteByte(digit)
|
||||
if length == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
220
server/internal/packets/fixedheader_test.go
Normal file
220
server/internal/packets/fixedheader_test.go
Normal file
@ -0,0 +1,220 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fixedHeaderTable struct {
|
||||
rawBytes []byte
|
||||
header FixedHeader
|
||||
packetError bool
|
||||
flagError bool
|
||||
}
|
||||
|
||||
var fixedHeaderExpected = []fixedHeaderTable{
|
||||
{
|
||||
rawBytes: []byte{Connect << 4, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connack << 4, 0x00},
|
||||
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Puback << 4, 0x00},
|
||||
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Suback << 4, 0x00},
|
||||
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
||||
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
||||
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
|
||||
// remaining length
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x0a},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
|
||||
packetError: true,
|
||||
},
|
||||
|
||||
// Invalid flags for packet
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
}
|
||||
|
||||
func TestFixedHeaderEncode(t *testing.T) {
|
||||
for i, wanted := range fixedHeaderExpected {
|
||||
buf := new(bytes.Buffer)
|
||||
wanted.header.Encode(buf)
|
||||
if wanted.flagError == false {
|
||||
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes)
|
||||
require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFixedHeaderEncode(b *testing.B) {
|
||||
buf := new(bytes.Buffer)
|
||||
for n := 0; n < b.N; n++ {
|
||||
fixedHeaderExpected[0].header.Encode(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixedHeaderDecode(t *testing.T) {
|
||||
for i, wanted := range fixedHeaderExpected {
|
||||
fh := new(FixedHeader)
|
||||
err := fh.Decode(wanted.rawBytes[0])
|
||||
if wanted.flagError {
|
||||
require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
|
||||
} else {
|
||||
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFixedHeaderDecode(b *testing.B) {
|
||||
fh := new(FixedHeader)
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := fh.Decode(fixedHeaderExpected[0].rawBytes[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
tt := []struct {
|
||||
have int64
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
120,
|
||||
[]byte{0x78},
|
||||
},
|
||||
{
|
||||
math.MaxInt64,
|
||||
[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range tt {
|
||||
buf := new(bytes.Buffer)
|
||||
encodeLength(buf, wanted.have)
|
||||
require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeLength(b *testing.B) {
|
||||
buf := new(bytes.Buffer)
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeLength(buf, 120)
|
||||
}
|
||||
}
|
670
server/internal/packets/packets.go
Normal file
670
server/internal/packets/packets.go
Normal file
@ -0,0 +1,670 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// All of the valid packet types and their packet identifier.
|
||||
const (
|
||||
Reserved byte = iota
|
||||
Connect // 1
|
||||
Connack // 2
|
||||
Publish // 3
|
||||
Puback // 4
|
||||
Pubrec // 5
|
||||
Pubrel // 6
|
||||
Pubcomp // 7
|
||||
Subscribe // 8
|
||||
Suback // 9
|
||||
Unsubscribe // 10
|
||||
Unsuback // 11
|
||||
Pingreq // 12
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
|
||||
Accepted byte = 0x00
|
||||
Failed byte = 0xFF
|
||||
CodeConnectBadProtocolVersion byte = 0x01
|
||||
CodeConnectBadClientID byte = 0x02
|
||||
CodeConnectServerUnavailable byte = 0x03
|
||||
CodeConnectBadAuthValues byte = 0x04
|
||||
CodeConnectNotAuthorised byte = 0x05
|
||||
CodeConnectNetworkError byte = 0xFE
|
||||
CodeConnectProtocolViolation byte = 0xFF
|
||||
ErrSubAckNetworkError byte = 0x80
|
||||
)
|
||||
|
||||
var (
|
||||
// CONNECT
|
||||
ErrMalformedProtocolName = errors.New("malformed packet: protocol name")
|
||||
ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version")
|
||||
ErrMalformedFlags = errors.New("malformed packet: flags")
|
||||
ErrMalformedKeepalive = errors.New("malformed packet: keepalive")
|
||||
ErrMalformedClientID = errors.New("malformed packet: client id")
|
||||
ErrMalformedWillTopic = errors.New("malformed packet: will topic")
|
||||
ErrMalformedWillMessage = errors.New("malformed packet: will message")
|
||||
ErrMalformedUsername = errors.New("malformed packet: username")
|
||||
ErrMalformedPassword = errors.New("malformed packet: password")
|
||||
|
||||
// CONNACK
|
||||
ErrMalformedSessionPresent = errors.New("malformed packet: session present")
|
||||
ErrMalformedReturnCode = errors.New("malformed packet: return code")
|
||||
|
||||
// PUBLISH
|
||||
ErrMalformedTopic = errors.New("malformed packet: topic name")
|
||||
ErrMalformedPacketID = errors.New("malformed packet: packet id")
|
||||
|
||||
// SUBSCRIBE
|
||||
ErrMalformedQoS = errors.New("malformed packet: qos")
|
||||
|
||||
// PACKETS
|
||||
ErrProtocolViolation = errors.New("protocol violation")
|
||||
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
|
||||
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
|
||||
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
|
||||
ErrOffsetUintOutOfRange = errors.New("offset uint out of range")
|
||||
ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8")
|
||||
ErrInvalidFlags = errors.New("invalid flags set for packet")
|
||||
ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator")
|
||||
ErrMissingPacketID = errors.New("missing packet id")
|
||||
ErrSurplusPacketID = errors.New("surplus packet id")
|
||||
)
|
||||
|
||||
// Packet is an MQTT packet. Instead of providing a packet interface and variant
|
||||
// packet structs, this is a single concrete packet type to cover all packet
|
||||
// types, which allows us to take advantage of various compiler optimizations.
|
||||
type Packet struct {
|
||||
FixedHeader FixedHeader
|
||||
AllowClients []string // For use with OnMessage event hook.
|
||||
Topics []string
|
||||
ReturnCodes []byte
|
||||
ProtocolName []byte
|
||||
Qoss []byte
|
||||
Payload []byte
|
||||
Username []byte
|
||||
Password []byte
|
||||
WillMessage []byte
|
||||
ClientIdentifier string
|
||||
TopicName string
|
||||
WillTopic string
|
||||
PacketID uint16
|
||||
Keepalive uint16
|
||||
ReturnCode byte
|
||||
ProtocolVersion byte
|
||||
WillQos byte
|
||||
ReservedBit byte
|
||||
CleanSession bool
|
||||
WillFlag bool
|
||||
WillRetain bool
|
||||
UsernameFlag bool
|
||||
PasswordFlag bool
|
||||
SessionPresent bool
|
||||
}
|
||||
|
||||
// ConnectEncode encodes a connect packet.
|
||||
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
protoName := encodeBytes(pk.ProtocolName)
|
||||
protoVersion := pk.ProtocolVersion
|
||||
flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7
|
||||
keepalive := encodeUint16(pk.Keepalive)
|
||||
clientID := encodeString(pk.ClientIdentifier)
|
||||
|
||||
var willTopic, willFlag, usernameFlag, passwordFlag []byte
|
||||
|
||||
// If will flag is set, add topic and message.
|
||||
if pk.WillFlag {
|
||||
willTopic = encodeString(pk.WillTopic)
|
||||
willFlag = encodeBytes(pk.WillMessage)
|
||||
}
|
||||
|
||||
// If username flag is set, add username.
|
||||
if pk.UsernameFlag {
|
||||
usernameFlag = encodeBytes(pk.Username)
|
||||
}
|
||||
|
||||
// If password flag is set, add password.
|
||||
if pk.PasswordFlag {
|
||||
passwordFlag = encodeBytes(pk.Password)
|
||||
}
|
||||
|
||||
// Get a length for the connect header. This is not super pretty, but it works.
|
||||
pk.FixedHeader.Remaining =
|
||||
len(protoName) + 1 + 1 + len(keepalive) + len(clientID) +
|
||||
len(willTopic) + len(willFlag) +
|
||||
len(usernameFlag) + len(passwordFlag)
|
||||
|
||||
pk.FixedHeader.Encode(buf)
|
||||
|
||||
// Eschew magic for readability.
|
||||
buf.Write(protoName)
|
||||
buf.WriteByte(protoVersion)
|
||||
buf.WriteByte(flag)
|
||||
buf.Write(keepalive)
|
||||
buf.Write(clientID)
|
||||
buf.Write(willTopic)
|
||||
buf.Write(willFlag)
|
||||
buf.Write(usernameFlag)
|
||||
buf.Write(passwordFlag)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectDecode decodes a connect packet.
|
||||
func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Unpack protocol name and version.
|
||||
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
|
||||
}
|
||||
|
||||
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
|
||||
}
|
||||
// Unpack flags byte.
|
||||
flags, offset, err := decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
|
||||
}
|
||||
pk.ReservedBit = 1 & flags
|
||||
pk.CleanSession = 1&(flags>>1) > 0
|
||||
pk.WillFlag = 1&(flags>>2) > 0
|
||||
pk.WillQos = 3 & (flags >> 3) // this one is not a bool
|
||||
pk.WillRetain = 1&(flags>>5) > 0
|
||||
pk.PasswordFlag = 1&(flags>>6) > 0
|
||||
pk.UsernameFlag = 1&(flags>>7) > 0
|
||||
|
||||
// Get keepalive interval.
|
||||
pk.Keepalive, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
|
||||
}
|
||||
|
||||
// Get client ID.
|
||||
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
|
||||
}
|
||||
|
||||
// Get Last Will and Testament topic and message if applicable.
|
||||
if pk.WillFlag {
|
||||
pk.WillTopic, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
|
||||
}
|
||||
|
||||
pk.WillMessage, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// Get username and password if applicable.
|
||||
if pk.UsernameFlag {
|
||||
pk.Username, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
|
||||
}
|
||||
}
|
||||
|
||||
if pk.PasswordFlag {
|
||||
pk.Password, _, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// ConnectValidate ensures the connect packet is compliant.
|
||||
func (pk *Packet) ConnectValidate() (b byte, err error) {
|
||||
|
||||
// End if protocol name is bad.
|
||||
if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 &&
|
||||
bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if protocol version is bad.
|
||||
if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) ||
|
||||
(bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) {
|
||||
return CodeConnectBadProtocolVersion, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if reserved bit is not 0.
|
||||
if pk.ReservedBit != 0 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if ClientID is too long.
|
||||
if len(pk.ClientIdentifier) > 65535 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if password flag is set without a username.
|
||||
if pk.PasswordFlag && !pk.UsernameFlag {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if Username or Password is too long.
|
||||
if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if client id isn't set and clean session is false.
|
||||
if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
|
||||
return CodeConnectBadClientID, ErrProtocolViolation
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// ConnackEncode encodes a Connack packet.
|
||||
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.WriteByte(encodeBool(pk.SessionPresent))
|
||||
buf.WriteByte(pk.ReturnCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnackDecode decodes a Connack packet.
|
||||
func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
|
||||
}
|
||||
|
||||
pk.ReturnCode, _, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisconnectEncode encodes a Disconnect packet.
|
||||
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Encode(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PingreqEncode encodes a Pingreq packet.
|
||||
func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Encode(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PingrespEncode encodes a Pingresp packet.
|
||||
func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Encode(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubackEncode encodes a Puback packet.
|
||||
func (pk *Packet) PubackEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubackDecode decodes a Puback packet.
|
||||
func (pk *Packet) PubackDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubcompEncode encodes a Pubcomp packet.
|
||||
func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubcompDecode decodes a Pubcomp packet.
|
||||
func (pk *Packet) PubcompDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishEncode encodes a Publish packet.
|
||||
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
topicName := encodeString(pk.TopicName)
|
||||
var packetID []byte
|
||||
|
||||
// Add PacketID if QOS is set.
|
||||
// [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
|
||||
if pk.FixedHeader.Qos > 0 {
|
||||
|
||||
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.PacketID == 0 {
|
||||
return ErrMissingPacketID
|
||||
}
|
||||
|
||||
packetID = encodeUint16(pk.PacketID)
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload)
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(topicName)
|
||||
buf.Write(packetID)
|
||||
buf.Write(pk.Payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishDecode extracts the data values from the packet.
|
||||
func (pk *Packet) PublishDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
pk.TopicName, offset, err = decodeString(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
|
||||
// If QOS decode Packet ID.
|
||||
if pk.FixedHeader.Qos > 0 {
|
||||
pk.PacketID, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
}
|
||||
|
||||
pk.Payload = buf[offset:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishCopy creates a new instance of Publish packet bearing the
|
||||
// same payload and destination topic, but with an empty header for
|
||||
// inheriting new QoS flags, etc.
|
||||
func (pk *Packet) PublishCopy() Packet {
|
||||
return Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Retain: pk.FixedHeader.Retain,
|
||||
},
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
}
|
||||
}
|
||||
|
||||
// PublishValidate validates a publish packet.
|
||||
func (pk *Packet) PublishValidate() (byte, error) {
|
||||
|
||||
// @SPEC [MQTT-2.3.1-1]
|
||||
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
|
||||
return Failed, ErrMissingPacketID
|
||||
}
|
||||
|
||||
// @SPEC [MQTT-2.3.1-5]
|
||||
// A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
|
||||
if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 {
|
||||
return Failed, ErrSurplusPacketID
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// PubrecEncode encodes a Pubrec packet.
|
||||
func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubrecDecode decodes a Pubrec packet.
|
||||
func (pk *Packet) PubrecDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubrelEncode encodes a Pubrel packet.
|
||||
func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubrelDecode decodes a Pubrel packet.
|
||||
func (pk *Packet) PubrelDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubackEncode encodes a Suback packet.
|
||||
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
packetID := encodeUint16(pk.PacketID)
|
||||
pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length.
|
||||
pk.FixedHeader.Encode(buf)
|
||||
|
||||
buf.Write(packetID) // Encode Packet ID.
|
||||
buf.Write(pk.ReturnCodes) // Encode granted QOS flags.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubackDecode decodes a Suback packet.
|
||||
func (pk *Packet) SubackDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Get Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Get Granted QOS flags.
|
||||
pk.ReturnCodes = buf[offset:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeEncode encodes a Subscribe packet.
|
||||
func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
// Add the Packet ID.
|
||||
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.PacketID == 0 {
|
||||
return ErrMissingPacketID
|
||||
}
|
||||
|
||||
packetID := encodeUint16(pk.PacketID)
|
||||
|
||||
// Count topics lengths and associated QOS flags.
|
||||
var topicsLen int
|
||||
for _, topic := range pk.Topics {
|
||||
topicsLen += len(encodeString(topic)) + 1
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = len(packetID) + topicsLen
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(packetID)
|
||||
|
||||
// Add all provided topic names and associated QOS flags.
|
||||
for i, topic := range pk.Topics {
|
||||
buf.Write(encodeString(topic))
|
||||
buf.WriteByte(pk.Qoss[i])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeDecode decodes a Subscribe packet.
|
||||
func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Get the Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Keep decoding until there's no space left.
|
||||
for offset < len(buf) {
|
||||
|
||||
// Decode Topic Name.
|
||||
var topic string
|
||||
topic, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
pk.Topics = append(pk.Topics, topic)
|
||||
|
||||
// Decode QOS flag.
|
||||
var qos byte
|
||||
qos, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
|
||||
}
|
||||
|
||||
// Ensure QoS byte is within range.
|
||||
if !(qos >= 0 && qos <= 2) {
|
||||
//if !validateQoS(qos) {
|
||||
return ErrMalformedQoS
|
||||
}
|
||||
|
||||
pk.Qoss = append(pk.Qoss, qos)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeValidate ensures the packet is compliant.
|
||||
func (pk *Packet) SubscribeValidate() (byte, error) {
|
||||
// @SPEC [MQTT-2.3.1-1].
|
||||
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
|
||||
return Failed, ErrMissingPacketID
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// UnsubackEncode encodes an Unsuback packet.
|
||||
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubackDecode decodes an Unsuback packet.
|
||||
func (pk *Packet) UnsubackDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubscribeEncode encodes an Unsubscribe packet.
|
||||
func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
// Add the Packet ID.
|
||||
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.PacketID == 0 {
|
||||
return ErrMissingPacketID
|
||||
}
|
||||
|
||||
packetID := encodeUint16(pk.PacketID)
|
||||
|
||||
// Count topics lengths.
|
||||
var topicsLen int
|
||||
for _, topic := range pk.Topics {
|
||||
topicsLen += len(encodeString(topic))
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = len(packetID) + topicsLen
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(packetID)
|
||||
|
||||
// Add all provided topic names.
|
||||
for _, topic := range pk.Topics {
|
||||
buf.Write(encodeString(topic))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubscribeDecode decodes an Unsubscribe packet.
|
||||
func (pk *Packet) UnsubscribeDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Get the Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Keep decoding until there's no space left.
|
||||
for offset < len(buf) {
|
||||
var t string
|
||||
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
|
||||
if len(t) > 0 {
|
||||
pk.Topics = append(pk.Topics, t)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// UnsubscribeValidate validates an Unsubscribe packet.
|
||||
func (pk *Packet) UnsubscribeValidate() (byte, error) {
|
||||
// @SPEC [MQTT-2.3.1-1].
|
||||
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
|
||||
return Failed, ErrMissingPacketID
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// FormatID returns the PacketID field as a decimal integer.
|
||||
func (pk *Packet) FormatID() string {
|
||||
return strconv.FormatUint(uint64(pk.PacketID), 10)
|
||||
}
|
1416
server/internal/packets/packets_tables_test.go
Normal file
1416
server/internal/packets/packets_tables_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1091
server/internal/packets/packets_test.go
Normal file
1091
server/internal/packets/packets_test.go
Normal file
File diff suppressed because it is too large
Load Diff
344
server/internal/topics/trie.go
Normal file
344
server/internal/topics/trie.go
Normal file
@ -0,0 +1,344 @@
|
||||
package topics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
)
|
||||
|
||||
// Subscriptions is a map of subscriptions keyed on client.
|
||||
type Subscriptions map[string]byte
|
||||
|
||||
// Index is a prefix/trie tree containing topic subscribers and retained messages.
|
||||
type Index struct {
|
||||
mu sync.RWMutex // a mutex for locking the whole index.
|
||||
Root *Leaf // a leaf containing a message and more leaves.
|
||||
}
|
||||
|
||||
// New returns a pointer to a new instance of Index.
|
||||
func New() *Index {
|
||||
return &Index{
|
||||
Root: &Leaf{
|
||||
Leaves: make(map[string]*Leaf),
|
||||
Clients: make(map[string]byte),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetainMessage saves a message payload to the end of a topic branch. Returns
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *Index) RetainMessage(msg packets.Packet) int64 {
|
||||
x.mu.Lock()
|
||||
defer x.mu.Unlock()
|
||||
n := x.poperate(msg.TopicName)
|
||||
|
||||
// If there is a payload, we can store it.
|
||||
if len(msg.Payload) > 0 {
|
||||
n.Message = msg
|
||||
return 1
|
||||
}
|
||||
|
||||
// Otherwise, we are unsetting it.
|
||||
// If there was a previous retained message, return -1 instead of 0.
|
||||
var r int64 = 0
|
||||
if len(n.Message.Payload) > 0 && n.Message.FixedHeader.Retain == true {
|
||||
r = -1
|
||||
}
|
||||
x.unpoperate(msg.TopicName, "", true)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Subscribe creates a subscription filter for a client. Returns true if the
|
||||
// subscription was new.
|
||||
func (x *Index) Subscribe(filter, client string, qos byte) bool {
|
||||
x.mu.Lock()
|
||||
defer x.mu.Unlock()
|
||||
|
||||
n := x.poperate(filter)
|
||||
_, ok := n.Clients[client]
|
||||
n.Clients[client] = qos
|
||||
n.Filter = filter
|
||||
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription filter for a client. Returns true if an
|
||||
// unsubscribe action successful and the subscription existed.
|
||||
func (x *Index) Unsubscribe(filter, client string) bool {
|
||||
x.mu.Lock()
|
||||
defer x.mu.Unlock()
|
||||
|
||||
n := x.poperate(filter)
|
||||
_, ok := n.Clients[client]
|
||||
|
||||
return x.unpoperate(filter, client, false) && ok
|
||||
}
|
||||
|
||||
// unpoperate steps backward through a trie sequence and removes any orphaned
|
||||
// nodes. If a client id is specified, it will unsubscribe a client. If message
|
||||
// is true, it will delete a retained message.
|
||||
func (x *Index) unpoperate(filter string, client string, message bool) bool {
|
||||
var d int // Walk to end leaf.
|
||||
var particle string
|
||||
var hasNext = true
|
||||
e := x.Root
|
||||
for hasNext {
|
||||
particle, hasNext = isolateParticle(filter, d)
|
||||
d++
|
||||
e, _ = e.Leaves[particle]
|
||||
|
||||
// If the topic part doesn't exist in the tree, there's nothing
|
||||
// left to do.
|
||||
if e == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Step backward removing client and orphaned leaves.
|
||||
var key string
|
||||
var orphaned bool
|
||||
var end = true
|
||||
for e.Parent != nil {
|
||||
key = e.Key
|
||||
|
||||
// Wipe the client from this leaf if it's the filter end.
|
||||
if end {
|
||||
if client != "" {
|
||||
delete(e.Clients, client)
|
||||
}
|
||||
if message {
|
||||
e.Message = packets.Packet{}
|
||||
}
|
||||
end = false
|
||||
}
|
||||
|
||||
// If this leaf is empty, note it as orphaned.
|
||||
orphaned = len(e.Clients) == 0 && len(e.Leaves) == 0 && !e.Message.FixedHeader.Retain
|
||||
|
||||
// Traverse up the branch.
|
||||
e = e.Parent
|
||||
|
||||
// If the leaf we just came from was empty, delete it.
|
||||
if orphaned {
|
||||
delete(e.Leaves, key)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
// poperate iterates and populates through a topic/filter path, instantiating
|
||||
// leaves as it goes and returning the final leaf in the branch.
|
||||
// poperate is a more enjoyable word than iterpop.
|
||||
func (x *Index) poperate(topic string) *Leaf {
|
||||
var d int
|
||||
var particle string
|
||||
var hasNext = true
|
||||
n := x.Root
|
||||
for hasNext {
|
||||
particle, hasNext = isolateParticle(topic, d)
|
||||
d++
|
||||
|
||||
child, _ := n.Leaves[particle]
|
||||
if child == nil {
|
||||
child = &Leaf{
|
||||
Key: particle,
|
||||
Parent: n,
|
||||
Leaves: make(map[string]*Leaf),
|
||||
Clients: make(map[string]byte),
|
||||
}
|
||||
n.Leaves[particle] = child
|
||||
}
|
||||
n = child
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Subscribers returns a map of clients who are subscribed to matching filters.
|
||||
func (x *Index) Subscribers(topic string) Subscriptions {
|
||||
x.mu.RLock()
|
||||
defer x.mu.RUnlock()
|
||||
return x.Root.scanSubscribers(topic, 0, make(Subscriptions))
|
||||
}
|
||||
|
||||
// Messages returns a slice of retained topic messages which match a filter.
|
||||
func (x *Index) Messages(filter string) []packets.Packet {
|
||||
// ReLeaf("messages", x.Root, 0)
|
||||
x.mu.RLock()
|
||||
defer x.mu.RUnlock()
|
||||
return x.Root.scanMessages(filter, 0, make([]packets.Packet, 0, 32))
|
||||
}
|
||||
|
||||
// Leaf is a child node on the tree.
|
||||
type Leaf struct {
|
||||
Message packets.Packet // a message which has been retained for a specific topic.
|
||||
Key string // the key that was used to create the leaf.
|
||||
Filter string // the path of the topic filter being matched.
|
||||
Parent *Leaf // a pointer to the parent node for the leaf.
|
||||
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
|
||||
Clients map[string]byte // a map of client ids subscribed to the topic.
|
||||
}
|
||||
|
||||
// scanSubscribers recursively steps through a branch of leaves finding clients who
|
||||
// have subscription filters matching a topic, and their highest QoS byte.
|
||||
func (l *Leaf) scanSubscribers(topic string, d int, clients Subscriptions) Subscriptions {
|
||||
part, hasNext := isolateParticle(topic, d)
|
||||
|
||||
// For either the topic part, a +, or a #, follow the branch.
|
||||
for _, particle := range []string{part, "+", "#"} {
|
||||
|
||||
// Topics beginning with the reserved $ character are restricted from
|
||||
// being returned for top level wildcards.
|
||||
if d == 0 && len(part) > 0 && part[0] == '$' && (particle == "+" || particle == "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
if child, ok := l.Leaves[particle]; ok {
|
||||
|
||||
// We're only interested in getting clients from the final
|
||||
// element in the topic, or those with wildhashes.
|
||||
if !hasNext || particle == "#" {
|
||||
|
||||
// Capture the highest QOS byte for any client with a filter
|
||||
// matching the topic.
|
||||
for client, qos := range child.Clients {
|
||||
if ex, ok := clients[client]; !ok || ex < qos {
|
||||
clients[client] = qos
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we also capture any client who are listening
|
||||
// to this topic via path/#
|
||||
if !hasNext {
|
||||
if extra, ok := child.Leaves["#"]; ok {
|
||||
for client, qos := range extra.Clients {
|
||||
if ex, ok := clients[client]; !ok || ex < qos {
|
||||
clients[client] = qos
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If this branch has hit a wildhash, just return immediately.
|
||||
if particle == "#" {
|
||||
return clients
|
||||
} else if hasNext {
|
||||
clients = child.scanSubscribers(topic, d+1, clients)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clients
|
||||
}
|
||||
|
||||
// scanMessages recursively steps through a branch of leaves finding retained messages
|
||||
// that match a topic filter. Setting `d` to -1 will enable wildhash mode, and will
|
||||
// recursively check ALL child leaves in every subsequent branch.
|
||||
func (l *Leaf) scanMessages(filter string, d int, messages []packets.Packet) []packets.Packet {
|
||||
|
||||
// If a wildhash mode has been set, continue recursively checking through all
|
||||
// child leaves regardless of their particle key.
|
||||
if d == -1 {
|
||||
for _, child := range l.Leaves {
|
||||
if child.Message.FixedHeader.Retain {
|
||||
messages = append(messages, child.Message)
|
||||
}
|
||||
messages = child.scanMessages(filter, -1, messages)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// Otherwise, we'll get the particle for d in the filter.
|
||||
particle, hasNext := isolateParticle(filter, d)
|
||||
|
||||
// If there's no more particles after this one, then take the messages from
|
||||
// these topics.
|
||||
if !hasNext {
|
||||
|
||||
// Wildcards and Wildhashes must be checked first, otherwise they
|
||||
// may be detected as standard particles, and not act properly.
|
||||
if particle == "+" || particle == "#" {
|
||||
|
||||
// Otherwise, if it's a wildcard or wildhash, get messages from all
|
||||
// the child leaves. This wildhash captures messages on the actual
|
||||
// wildhash position, whereas the d == -1 block collects subsequent
|
||||
// messages further down the branch.
|
||||
for _, child := range l.Leaves {
|
||||
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
|
||||
continue
|
||||
}
|
||||
if child.Message.FixedHeader.Retain {
|
||||
messages = append(messages, child.Message)
|
||||
}
|
||||
}
|
||||
} else if child, ok := l.Leaves[particle]; ok {
|
||||
if child.Message.FixedHeader.Retain {
|
||||
messages = append(messages, child.Message)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
// If it's not the last particle, branch out to the next leaves, scanning
|
||||
// all available if it's a wildcard, or just one if it's a specific particle.
|
||||
if particle == "+" {
|
||||
for _, child := range l.Leaves {
|
||||
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
|
||||
continue
|
||||
}
|
||||
messages = child.scanMessages(filter, d+1, messages)
|
||||
}
|
||||
} else if child, ok := l.Leaves[particle]; ok {
|
||||
messages = child.scanMessages(filter, d+1, messages)
|
||||
}
|
||||
}
|
||||
|
||||
// If the particle was a wildhash, scan all the child leaves setting the
|
||||
// d value to wildhash mode.
|
||||
if particle == "#" {
|
||||
for _, child := range l.Leaves {
|
||||
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
|
||||
continue
|
||||
}
|
||||
messages = child.scanMessages(filter, -1, messages)
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||
var next, end int
|
||||
for i := 0; end > -1 && i <= d; i++ {
|
||||
end = strings.IndexRune(filter, '/')
|
||||
if d > -1 && i == d && end > -1 {
|
||||
hasNext = true
|
||||
particle = filter[next:end]
|
||||
} else if end > -1 {
|
||||
hasNext = false
|
||||
filter = filter[end+1:]
|
||||
} else {
|
||||
hasNext = false
|
||||
particle = filter[next:]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ReLeaf is a dev function for showing the trie leafs.
|
||||
/*
|
||||
func ReLeaf(m string, leaf *Leaf, d int) {
|
||||
for k, v := range leaf.Leaves {
|
||||
fmt.Println(m, d, strings.Repeat(" ", d), k)
|
||||
ReLeaf(m, v, d+1)
|
||||
}
|
||||
}
|
||||
*/
|
494
server/internal/topics/trie_test.go
Normal file
494
server/internal/topics/trie_test.go
Normal file
@ -0,0 +1,494 @@
|
||||
package topics
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
index := New()
|
||||
require.NotNil(t, index)
|
||||
require.NotNil(t, index.Root)
|
||||
}
|
||||
|
||||
func BenchmarkNew(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
New()
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoperate(t *testing.T) {
|
||||
index := New()
|
||||
child := index.poperate("path/to/my/mqtt")
|
||||
require.Equal(t, "mqtt", child.Key)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
|
||||
child = index.poperate("a/b/c/d/e")
|
||||
require.Equal(t, "e", child.Key)
|
||||
child = index.poperate("a/b/c/c/a")
|
||||
require.Equal(t, "a", child.Key)
|
||||
}
|
||||
|
||||
func BenchmarkPoperate(b *testing.B) {
|
||||
index := New()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.poperate("path/to/my/mqtt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnpoperate(t *testing.T) {
|
||||
index := New()
|
||||
index.Subscribe("path/to/my/mqtt", "client-1", 0)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
|
||||
|
||||
index.Subscribe("path/to/another/mqtt", "client-1", 0)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
|
||||
|
||||
pk := packets.Packet{TopicName: "path/to/retained/message", Payload: []byte{'h', 'e', 'l', 'l', 'o'}}
|
||||
index.RetainMessage(pk)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"])
|
||||
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"].Message)
|
||||
|
||||
pk2 := packets.Packet{TopicName: "path/to/my/mqtt", Payload: []byte{'s', 'h', 'a', 'r', 'e', 'd'}}
|
||||
index.RetainMessage(pk2)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
|
||||
|
||||
index.unpoperate("path/to/my/mqtt", "", true) // delete retained
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
|
||||
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message.FixedHeader.Retain)
|
||||
|
||||
index.unpoperate("path/to/my/mqtt", "client-1", false) // unsubscribe client
|
||||
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
|
||||
|
||||
index.unpoperate("path/to/retained/message", "", true) // delete retained
|
||||
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves, "my")
|
||||
|
||||
index.unpoperate("path/to/whatever", "client-1", false) // unsubscribe client
|
||||
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
|
||||
|
||||
//require.Empty(t, index.Root.Leaves["path"])
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkUnpoperate(b *testing.B) {
|
||||
index := New()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.poperate("path/to/my/mqtt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetainMessage(t *testing.T) {
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
TopicName: "path/to/my/mqtt",
|
||||
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
|
||||
}
|
||||
pk2 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
TopicName: "path/to/another/mqtt",
|
||||
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
|
||||
}
|
||||
|
||||
index := New()
|
||||
q := index.RetainMessage(pk)
|
||||
require.Equal(t, int64(1), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
|
||||
|
||||
index.Subscribe("path/to/another/mqtt", "client-1", 0)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients["client-1"])
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
|
||||
|
||||
q = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
|
||||
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
|
||||
|
||||
// The same message already exists, but we're not doing a deep-copy check, so it's considered
|
||||
// to be a new message.
|
||||
q = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
|
||||
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
|
||||
|
||||
// Delete retained
|
||||
pk3 := packets.Packet{TopicName: "path/to/another/mqtt", Payload: []byte{}}
|
||||
q = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(-1), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
|
||||
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
|
||||
|
||||
// Second Delete retained
|
||||
q = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(0), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
|
||||
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkRetainMessage(b *testing.B) {
|
||||
index := New()
|
||||
pk := packets.Packet{TopicName: "path/to/another/mqtt"}
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.RetainMessage(pk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeOK(t *testing.T) {
|
||||
index := New()
|
||||
|
||||
q := index.Subscribe("path/to/my/mqtt", "client-1", 0)
|
||||
require.Equal(t, true, q)
|
||||
|
||||
q = index.Subscribe("path/to/my/mqtt", "client-1", 0)
|
||||
require.Equal(t, false, q)
|
||||
|
||||
q = index.Subscribe("path/to/my/mqtt", "client-2", 0)
|
||||
require.Equal(t, true, q)
|
||||
|
||||
q = index.Subscribe("path/to/another/mqtt", "client-1", 0)
|
||||
require.Equal(t, true, q)
|
||||
|
||||
q = index.Subscribe("path/+", "client-2", 0)
|
||||
require.Equal(t, true, q)
|
||||
|
||||
q = index.Subscribe("#", "client-3", 0)
|
||||
require.Equal(t, true, q)
|
||||
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
|
||||
require.Equal(t, "path/to/my/mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Filter)
|
||||
require.Equal(t, "mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Key)
|
||||
require.Equal(t, index.Root.Leaves["path"], index.Root.Leaves["path"].Leaves["to"].Parent)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-2")
|
||||
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["+"].Clients, "client-2")
|
||||
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
|
||||
}
|
||||
|
||||
func BenchmarkSubscribe(b *testing.B) {
|
||||
index := New()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribe("path/to/mqtt/basic", "client-1", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribeA(t *testing.T) {
|
||||
index := New()
|
||||
index.Subscribe("path/to/my/mqtt", "client-1", 0)
|
||||
index.Subscribe("path/to/+/mqtt", "client-1", 0)
|
||||
index.Subscribe("path/to/stuff", "client-1", 0)
|
||||
index.Subscribe("path/to/stuff", "client-2", 0)
|
||||
index.Subscribe("#", "client-3", 0)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
|
||||
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
|
||||
|
||||
ok := index.Unsubscribe("path/to/my/mqtt", "client-1")
|
||||
require.Equal(t, true, ok)
|
||||
|
||||
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
|
||||
|
||||
ok = index.Unsubscribe("path/to/stuff", "client-1")
|
||||
require.Equal(t, true, ok)
|
||||
|
||||
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
|
||||
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
|
||||
|
||||
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "client-1")
|
||||
require.Equal(t, false, ok)
|
||||
|
||||
}
|
||||
|
||||
func TestUnsubscribeCascade(t *testing.T) {
|
||||
index := New()
|
||||
index.Subscribe("a/b/c", "client-1", 0)
|
||||
index.Subscribe("a/b/c/e/e", "client-1", 0)
|
||||
|
||||
ok := index.Unsubscribe("a/b/c/e/e", "client-1")
|
||||
require.Equal(t, true, ok)
|
||||
require.NotEmpty(t, index.Root.Leaves)
|
||||
require.Contains(t, index.Root.Leaves["a"].Leaves["b"].Leaves["c"].Clients, "client-1")
|
||||
}
|
||||
|
||||
// This benchmark is Unsubscribe-Subscribe
|
||||
func BenchmarkUnsubscribe(b *testing.B) {
|
||||
index := New()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribe("path/to/my/mqtt", "client-1", 0)
|
||||
index.Unsubscribe("path/to/mqtt/basic", "client-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribersFind(t *testing.T) {
|
||||
tt := []struct {
|
||||
filter string
|
||||
topic string
|
||||
len int
|
||||
}{
|
||||
{
|
||||
filter: "a",
|
||||
topic: "a",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "a/",
|
||||
topic: "a",
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
filter: "a/",
|
||||
topic: "a/",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "/a",
|
||||
topic: "/a",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "path/to/my/mqtt",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "path/to/+/mqtt",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "+/to/+/mqtt",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "#",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "+/+/+/+",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "+/+/+/#",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "zen/#",
|
||||
topic: "zen",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "+/+/#",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "path/to/",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
filter: "#/stuff",
|
||||
topic: "path/to/my/mqtt",
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
filter: "$SYS/#",
|
||||
topic: "$SYS/info",
|
||||
len: 1,
|
||||
},
|
||||
{
|
||||
filter: "#",
|
||||
topic: "$SYS/info",
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
filter: "+/info",
|
||||
topic: "$SYS/info",
|
||||
len: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for i, check := range tt {
|
||||
index := New()
|
||||
index.Subscribe(check.filter, "client-1", 0)
|
||||
clients := index.Subscribers(check.topic)
|
||||
//spew.Dump(clients)
|
||||
require.Equal(t, check.len, len(clients), "Unexpected clients len at %d %s %s", i, check.filter, check.topic)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkSubscribers(b *testing.B) {
|
||||
index := New()
|
||||
index.Subscribe("path/to/my/mqtt", "client-1", 0)
|
||||
index.Subscribe("path/to/+/mqtt", "client-1", 0)
|
||||
index.Subscribe("something/things/stuff/+", "client-1", 0)
|
||||
index.Subscribe("path/to/stuff", "client-2", 0)
|
||||
index.Subscribe("#", "client-3", 0)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribers("path/to/testing/mqtt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsolateParticle(t *testing.T) {
|
||||
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
|
||||
require.Equal(t, "path", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
|
||||
require.Equal(t, "to", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
|
||||
require.Equal(t, "my", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
|
||||
require.Equal(t, "mqtt", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
|
||||
particle, hasNext = isolateParticle("/path/", 0)
|
||||
require.Equal(t, "", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("/path/", 1)
|
||||
require.Equal(t, "path", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("/path/", 2)
|
||||
require.Equal(t, "", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
|
||||
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
|
||||
require.Equal(t, "+", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
|
||||
require.Equal(t, "+", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
}
|
||||
|
||||
func BenchmarkIsolateParticle(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
isolateParticle("path/to/my/mqtt", 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPattern(t *testing.T) {
|
||||
tt := []struct {
|
||||
packet packets.Packet
|
||||
filter string
|
||||
len int
|
||||
}{
|
||||
{
|
||||
packets.Packet{TopicName: "a/b/c/d", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"a/b/c/d",
|
||||
1,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "a/b/c/e", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"a/+/c/+",
|
||||
2,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "a/b/d/f", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"+/+/+/+",
|
||||
3,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "q/w/e/r/t/y", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"q/w/e/#",
|
||||
1,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "q/w/x/r/t/x", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"q/#",
|
||||
2,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "asd", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"asd",
|
||||
1,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "$SYS/testing", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"#",
|
||||
8,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "$SYS/test", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"+/testing",
|
||||
0,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "$SYS/info", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"$SYS/info",
|
||||
1,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "$SYS/b", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"$SYS/#",
|
||||
4,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "asd/fgh/jkl", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"#",
|
||||
8,
|
||||
},
|
||||
{
|
||||
packets.Packet{TopicName: "stuff/asdadsa/dsfdsafdsadfsa/dsfdsf/sdsadas", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
|
||||
"stuff/#/things", // indexer will ignore trailing /things
|
||||
1,
|
||||
},
|
||||
}
|
||||
index := New()
|
||||
for _, check := range tt {
|
||||
index.RetainMessage(check.packet)
|
||||
}
|
||||
|
||||
for i, check := range tt {
|
||||
messages := index.Messages(check.filter)
|
||||
require.Equal(t, check.len, len(messages), "Unexpected messages len at %d %s %s", i, check.filter, check.packet.TopicName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesFind(t *testing.T) {
|
||||
index := New()
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/a", Payload: []byte{'a'}, FixedHeader: packets.FixedHeader{Retain: true}})
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b", Payload: []byte{'b'}, FixedHeader: packets.FixedHeader{Retain: true}})
|
||||
messages := index.Messages("a/a")
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
messages = index.Messages("a/+")
|
||||
require.Equal(t, 2, len(messages))
|
||||
}
|
||||
|
||||
func BenchmarkMessages(b *testing.B) {
|
||||
index := New()
|
||||
index.RetainMessage(packets.Packet{TopicName: "path/to/my/mqtt"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "path/to/another/mqtt"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "path/a/some/mqtt"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "what/is"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Messages("path/to/+/mqtt")
|
||||
}
|
||||
}
|
14
server/internal/utils/utils.go
Normal file
14
server/internal/utils/utils.go
Normal file
@ -0,0 +1,14 @@
|
||||
package utils
|
||||
|
||||
// InSliceString returns true if a string exists in a slice of strings.
|
||||
// This temporary and should be replaced with a function from the new
|
||||
// go slices package in 1.19 when available.
|
||||
// https://github.com/golang/go/issues/45955
|
||||
func InSliceString(sl []string, st string) bool {
|
||||
for _, v := range sl {
|
||||
if st == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
18
server/internal/utils/utils_test.go
Normal file
18
server/internal/utils/utils_test.go
Normal file
@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInSliceString(t *testing.T) {
|
||||
sl := []string{"a", "b", "c"}
|
||||
require.Equal(t, true, InSliceString(sl, "b"))
|
||||
|
||||
sl = []string{"a", "a", "a"}
|
||||
require.Equal(t, true, InSliceString(sl, "a"))
|
||||
|
||||
sl = []string{"a", "b", "c"}
|
||||
require.Equal(t, false, InSliceString(sl, "d"))
|
||||
}
|
Reference in New Issue
Block a user