Кобелев Андрей Андреевич
2022-08-15 23:06:20 +03:00
commit 4ea1a3802b
532 changed files with 211522 additions and 0 deletions

View File

@ -0,0 +1,212 @@
package circ
import (
"errors"
"io"
"sync"
"sync/atomic"
)
var (
// DefaultBufferSize is the default size of the buffer in bytes.
DefaultBufferSize int = 1024 * 256
// DefaultBlockSize is the default size per R/W block in bytes.
DefaultBlockSize int = 1024 * 8
// ErrOutOfRange indicates that the index was out of range.
ErrOutOfRange = errors.New("Indexes out of range")
// ErrInsufficientBytes indicates that there were not enough bytes to return.
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// Buffer is a circular buffer for reading and writing messages.
type Buffer struct {
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
done uint32 // indicates that the buffer is closed.
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
// NewWriter instead of this function.
func NewBuffer(size, block int) *Buffer {
if size == 0 {
size = DefaultBufferSize
}
if block == 0 {
block = DefaultBlockSize
}
if size < 2*block {
size = 2 * block
}
return &Buffer{
size: size,
mask: size - 1,
block: block,
buf: make([]byte, size),
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
}
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) *Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := &Buffer{
size: l,
mask: l - 1,
block: block,
buf: buf,
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
return b
}
// GetPos will return the tail and head positions of the buffer.
// This method is for use with testing.
func (b *Buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
}
// SetPos sets the head and tail of the buffer.
func (b *Buffer) SetPos(tail, head int64) {
atomic.StoreInt64(&b.tail, tail)
atomic.StoreInt64(&b.head, head)
}
// Get returns the internal buffer.
func (b *Buffer) Get() []byte {
b.Mu.Lock()
defer b.Mu.Unlock()
return b.buf
}
// Set writes bytes to a range of indexes in the byte buffer.
func (b *Buffer) Set(p []byte, start, end int) error {
b.Mu.Lock()
defer b.Mu.Unlock()
if end > b.size || start > b.size {
return ErrOutOfRange
}
o := 0
for i := start; i < end; i++ {
b.buf[i] = p[o]
o++
}
return nil
}
// Index returns the buffer-relative (ring) index of an absolute position.
func (b *Buffer) Index(i int64) int {
return b.mask & int(i)
}
// awaitEmpty will block until there are at least n bytes of free space between
// the head and the tail (looking forward).
func (b *Buffer) awaitEmpty(n int) error {
// If the head has wrapped behind the tail, and next will overrun tail,
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadUint32(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
b.rcond.Wait()
}
b.rcond.L.Unlock()
return nil
}
// awaitFilled will block until there are at least n bytes between the
// tail and the head (looking forward).
func (b *Buffer) awaitFilled(n int) error {
// Because awaitEmpty prevents the head from overrunning the tail on write,
// we can simply ensure there are enough bytes between the
// forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadUint32(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
b.wcond.Wait()
}
b.wcond.L.Unlock()
return nil
}
// checkEmpty returns true if there are at least n bytes of free space between
// the head and the tail (i.e. writing n more bytes will not overrun the tail).
func (b *Buffer) checkEmpty(n int) bool {
head := atomic.LoadInt64(&b.head)
next := head + int64(n)
tail := atomic.LoadInt64(&b.tail)
if next-tail > int64(b.size) {
return false
}
return true
}
// checkFilled returns true if there are at least n bytes between the tail and
// the head.
func (b *Buffer) checkFilled(n int) bool {
if atomic.LoadInt64(&b.tail)+int64(n) <= atomic.LoadInt64(&b.head) {
return true
}
return false
}
// CommitTail moves the tail position of the buffer forward by n bytes.
func (b *Buffer) CommitTail(n int) {
atomic.AddInt64(&b.tail, int64(n))
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
// CapDelta returns the difference between the head and tail.
func (b *Buffer) CapDelta() int {
return int(atomic.LoadInt64(&b.head) - atomic.LoadInt64(&b.tail))
}
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreUint32(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}
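For illustration, a minimal sketch (a hypothetical example test, not one of the files in this commit) of how the forever-increasing head/tail positions map onto the ring through the size-1 bitmask; this relies on the size being a power of two, which the defaults above are:
// Hypothetical example test; assumes a power-of-two size, as the defaults are.
package circ

import "fmt"

func ExampleBuffer_Index() {
    buf := NewBuffer(16, 4) // size 16, so mask is 15 (0b1111)

    for _, pos := range []int64{0, 4, 15, 16, 17, 61446} {
        // Index is equivalent to pos % size when size is a power of two.
        fmt.Println(pos, "->", buf.Index(pos))
    }
    // Output:
    // 0 -> 0
    // 4 -> 4
    // 15 -> 15
    // 16 -> 0
    // 17 -> 1
    // 61446 -> 6
}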

View File

@ -0,0 +1,317 @@
package circ
import (
//"fmt"
"sync/atomic"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/require"
)
func TestNewBuffer(t *testing.T) {
var size int = 16
var block int = 4
buf := NewBuffer(size, block)
require.NotNil(t, buf.buf)
require.NotNil(t, buf.rcond)
require.NotNil(t, buf.wcond)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewBuffer0Size(t *testing.T) {
buf := NewBuffer(0, 0)
require.NotNil(t, buf.buf)
require.Equal(t, DefaultBufferSize, buf.size)
require.Equal(t, DefaultBlockSize, buf.block)
}
func TestNewBufferUndersize(t *testing.T) {
buf := NewBuffer(DefaultBlockSize+10, DefaultBlockSize)
require.NotNil(t, buf.buf)
require.Equal(t, DefaultBlockSize*2, buf.size)
require.Equal(t, DefaultBlockSize, buf.block)
}
func TestNewBufferFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestNewBufferFromSlice0Size(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(0, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestAtomicAlignment(t *testing.T) {
var b Buffer
offset := unsafe.Offsetof(b.head)
require.Equalf(t, uintptr(0), offset%8,
"head requires 64-bit alignment for atomic: offset %d", offset)
offset = unsafe.Offsetof(b.tail)
require.Equalf(t, uintptr(0), offset%8,
"tail requires 64-bit alignment for atomic: offset %d", offset)
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
require.Equal(t, int64(0), tail)
require.Equal(t, int64(0), head)
atomic.StoreInt64(&buf.tail, 3)
atomic.StoreInt64(&buf.head, 11)
tail, head = buf.GetPos()
require.Equal(t, int64(3), tail)
require.Equal(t, int64(11), head)
}
func TestGet(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, make([]byte, 16), buf.Get())
buf.buf[0] = 1
buf.buf[15] = 1
require.Equal(t, []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, buf.Get())
}
func TestSetPos(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, int64(0), atomic.LoadInt64(&buf.tail))
require.Equal(t, int64(0), atomic.LoadInt64(&buf.head))
buf.SetPos(4, 8)
require.Equal(t, int64(4), atomic.LoadInt64(&buf.tail))
require.Equal(t, int64(8), atomic.LoadInt64(&buf.head))
}
func TestSet(t *testing.T) {
buf := NewBuffer(16, 4)
err := buf.Set([]byte{1, 1, 1, 1}, 17, 19)
require.Error(t, err)
err = buf.Set([]byte{1, 1, 1, 1}, 4, 8)
require.NoError(t, err)
require.Equal(t, []byte{0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, buf.buf)
}
func TestIndex(t *testing.T) {
buf := NewBuffer(1024, 4)
require.Equal(t, 512, buf.Index(512))
require.Equal(t, 0, buf.Index(1024))
require.Equal(t, 6, buf.Index(1030))
require.Equal(t, 6, buf.Index(61446))
}
func TestAwaitFilled(t *testing.T) {
tests := []struct {
tail int64
head int64
n int
await int
desc string
}{
{tail: 0, head: 4, n: 4, await: 1, desc: "OK 0, 4"},
{tail: 8, head: 11, n: 4, await: 1, desc: "OK 8, 11"},
{tail: 102, head: 103, n: 4, await: 3, desc: "OK 102, 103"},
}
for i, tt := range tests {
//fmt.Println(i)
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan error)
go func() {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.AddInt64(&buf.head, int64(tt.await))
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestAwaitFilledEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}
func TestAwaitEmptyOK(t *testing.T) {
tests := []struct {
tail int64
head int64
await int
desc string
}{
{tail: 0, head: 0, await: 0, desc: "OK 0, 0"},
{tail: 0, head: 5, await: 0, desc: "OK 0, 5"},
{tail: 0, head: 14, await: 3, desc: "OK wrap 0, 14 "},
{tail: 22, head: 35, await: 2, desc: "OK wrap 22, 35"},
{tail: 15, head: 17, await: 7, desc: "OK 15,2"},
{tail: 0, head: 10, await: 2, desc: "OK 0, 10"},
{tail: 1, head: 15, await: 4, desc: "OK 2, 14"},
}
for i, tt := range tests {
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan error)
go func() {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.AddInt64(&buf.tail, int64(tt.await))
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestAwaitEmptyEnded(t *testing.T) {
buf := NewBuffer(16, 4)
buf.SetPos(1, 15)
o := make(chan error)
go func() {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
require.Error(t, <-o)
}
func TestCheckEmpty(t *testing.T) {
buf := NewBuffer(16, 4)
tests := []struct {
head int64
tail int64
want bool
desc string
}{
{tail: 0, head: 0, want: true, desc: "0, 0 true"},
{tail: 3, head: 4, want: true, desc: "3, 4 true"},
{tail: 15, head: 17, want: true, desc: "15, 17(1) true"},
{tail: 1, head: 30, want: false, desc: "1, 30(14) false"},
{tail: 15, head: 30, want: false, desc: "15, 30(14) false; head has caught up to tail"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
require.Equal(t, tt.want, buf.checkEmpty(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
}
}
func TestCheckFilled(t *testing.T) {
buf := NewBuffer(16, 4)
tests := []struct {
head int64
tail int64
want bool
desc string
}{
{tail: 0, head: 0, want: false, desc: "0, 0 false"},
{tail: 0, head: 4, want: true, desc: "0, 4 true"},
{tail: 14, head: 16, want: false, desc: "14,16 false"},
{tail: 14, head: 18, want: true, desc: "14,18 true"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
require.Equal(t, tt.want, buf.checkFilled(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
}
}
func TestCommitTail(t *testing.T) {
tests := []struct {
tail int64
head int64
n int
next int64
await int
desc string
}{
{tail: 0, head: 5, n: 4, next: 4, await: 0, desc: "OK 0, 4"},
{tail: 0, head: 5, n: 6, next: 6, await: 1, desc: "OK 0, 5"},
}
for i, tt := range tests {
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
go func() {
buf.CommitTail(tt.n)
}()
time.Sleep(time.Millisecond)
for j := 0; j < tt.await; j++ {
atomic.AddInt64(&buf.head, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
}
require.Equal(t, tt.next, atomic.LoadInt64(&buf.tail), "Next tail mismatch [i:%d] %s", i, tt.desc)
}
}
/*
func TestCommitTailEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}
*/
func TestCapDelta(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, 0, buf.CapDelta())
buf.SetPos(10, 15)
require.Equal(t, 5, buf.CapDelta())
}
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, uint32(1), buf.done)
}

View File

@ -0,0 +1,49 @@
package circ
import (
"sync"
"sync/atomic"
)
// BytesPool is a pool of []byte.
type BytesPool struct {
// int64/uint64 fields have to be the first words in the struct in order
// to be 64-bit aligned on 32-bit architectures.
used int64 // access atomically
pool *sync.Pool
}
// NewBytesPool returns a pool of []byte backed by a sync.Pool.
func NewBytesPool(n int) *BytesPool {
if n == 0 {
n = DefaultBufferSize
}
return &BytesPool{
pool: &sync.Pool{
New: func() interface{} {
return make([]byte, n)
},
},
}
}
// Get returns a pooled []byte.
func (b *BytesPool) Get() []byte {
atomic.AddInt64(&b.used, 1)
return b.pool.Get().([]byte)
}
// Put puts the byte slice back into the pool.
func (b *BytesPool) Put(x []byte) {
for i := range x {
x[i] = 0
}
b.pool.Put(x)
atomic.AddInt64(&b.used, -1)
}
// InUse returns the number of pool blocks in use.
func (b *BytesPool) InUse() int64 {
return atomic.LoadInt64(&b.used)
}
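For illustration, a hypothetical example test (not part of this diff) sketching the intended Get/Put lifecycle of BytesPool:
// Hypothetical example test for BytesPool usage.
package circ

import "fmt"

func ExampleBytesPool() {
    pool := NewBytesPool(256) // every pooled slice is 256 bytes long

    p := pool.Get() // take a zeroed slice from the pool
    p[0] = 'x'      // ...use it as a scratch buffer...
    fmt.Println(pool.InUse())

    pool.Put(p) // zeroes the slice and returns it to the pool
    fmt.Println(pool.InUse())
    // Output:
    // 1
    // 0
}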

View File

@ -0,0 +1,49 @@
package circ
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBytesPool(t *testing.T) {
bpool := NewBytesPool(256)
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBytesPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBytesPool(256)
}
}
func TestNewBytesPoolGet(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, make([]byte, 256), buf)
require.Equal(t, int64(1), bpool.InUse())
}
func BenchmarkBytesPoolGet(b *testing.B) {
bpool := NewBytesPool(256)
for n := 0; n < b.N; n++ {
bpool.Get()
}
}
func TestNewBytesPoolPut(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, int64(1), bpool.InUse())
bpool.Put(buf)
require.Equal(t, int64(0), bpool.InUse())
}
func BenchmarkBytesPoolPut(b *testing.B) {
bpool := NewBytesPool(256)
buf := bpool.Get()
for n := 0; n < b.N; n++ {
bpool.Put(buf)
}
}

View File

@ -0,0 +1,96 @@
package circ
import (
"io"
"sync/atomic"
)
// Reader is a circular buffer for reading data from an io.Reader.
type Reader struct {
*Buffer
}
// NewReader returns a new Circular Reader.
func NewReader(size, block int) *Reader {
b := NewBuffer(size, block)
b.ID = "\treader"
return &Reader{
b,
}
}
// NewReaderFromSlice returns a new Circular Reader using a pre-existing
// byte slice.
func NewReaderFromSlice(block int, p []byte) *Reader {
b := NewBufferFromSlice(block, p)
b.ID = "\treader"
return &Reader{
b,
}
}
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreUint32(&b.State, 1)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadUint32(&b.done) == 1 {
return total, nil
}
// Wait until there's enough capacity in the buffer before
// trying to read more bytes from the io.Reader.
err := b.awaitEmpty(b.block)
if err != nil {
// b.done is the only error condition for awaitEmpty,
// so loop around and return properly.
continue
}
// If the block will overrun the circle end, just fill up
// and collect the rest on the next pass.
start := b.Index(atomic.LoadInt64(&b.head))
end := start + b.block
if end > b.size {
end = b.size
}
// Read into the buffer between the start and end indexes only.
n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
return total, err
}
// Move the head forward however many bytes were read.
atomic.AddInt64(&b.head, int64(n))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}
}
// Read reads n bytes from the buffer, and will block until at least n bytes
// exist in the buffer to read.
func (b *Buffer) Read(n int) (p []byte, err error) {
err = b.awaitFilled(n)
if err != nil {
return
}
tail := atomic.LoadInt64(&b.tail)
next := tail + int64(n)
// If the read overruns the buffer, get everything until the end
// and then whatever is left from the start.
if b.Index(tail) > b.Index(next) {
b.tmp = b.buf[b.Index(tail):]
b.tmp = append(b.tmp, b.buf[:b.Index(next)]...)
} else {
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
}
return b.tmp, nil
}
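For illustration, a hypothetical example test (not part of this diff) sketching the intended flow: ReadFrom fills the ring from an io.Reader, Read returns buffered bytes without advancing the tail, and the caller commits the tail once the bytes are processed:
// Hypothetical example test for the Reader flow.
package circ

import (
    "bytes"
    "fmt"
)

func ExampleReader() {
    r := NewReader(16, 4)

    // ReadFrom keeps filling the ring and advancing the head; it returns io.EOF
    // here because the finite source is drained.
    _, err := r.ReadFrom(bytes.NewReader([]byte("abcd")))
    fmt.Println(err)

    // Read blocks until at least n bytes sit between tail and head, then
    // returns them without moving the tail.
    p, _ := r.Read(4)
    fmt.Printf("%s\n", p)

    // The caller commits the tail once the bytes have been processed.
    r.CommitTail(4)
    // Output:
    // EOF
    // abcd
}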

View File

@ -0,0 +1,129 @@
package circ
import (
"bytes"
"errors"
"io"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewReader(t *testing.T) {
var size = 16
var block = 4
buf := NewReader(size, block)
require.NotNil(t, buf.buf)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewReaderFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestReadFrom(t *testing.T) {
buf := NewReader(16, 4)
b4 := bytes.Repeat([]byte{'-'}, 4)
br := bytes.NewReader(b4)
_, err := buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
require.Equal(t, int64(4), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(8), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(12), buf.head)
}
func TestReadFromWrap(t *testing.T) {
buf := NewReader(16, 4)
buf.buf = bytes.Repeat([]byte{'-'}, 16)
buf.SetPos(8, 14)
br := bytes.NewReader(bytes.Repeat([]byte{'/'}, 8))
o := make(chan error)
go func() {
_, err := buf.ReadFrom(br)
o <- err
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
}()
<-o
require.Equal(t, []byte{'/', '/', '/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '/', '/'}, buf.Get())
require.Equal(t, int64(22), atomic.LoadInt64(&buf.head))
require.Equal(t, 6, buf.Index(atomic.LoadInt64(&buf.head)))
}
func TestReadOK(t *testing.T) {
buf := NewReader(16, 4)
buf.buf = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
tests := []struct {
tail int64
head int64
n int
bytes []byte
desc string
}{
{tail: 0, head: 4, n: 4, bytes: []byte{'a', 'b', 'c', 'd'}, desc: "0, 4 OK"},
{tail: 3, head: 15, n: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, desc: "3, 15 OK"},
{tail: 14, head: 15, n: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, desc: "14, 2 wrapped OK"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
o := make(chan []byte)
go func() {
p, _ := buf.Read(tt.n)
o <- p
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.head, buf.head+int64(tt.n))
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
done := <-o
require.Equal(t, tt.bytes, done, "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
}
}
func TestReadEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
_, err := buf.Read(4)
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}

View File

@ -0,0 +1,106 @@
package circ
import (
"io"
"log"
"sync/atomic"
)
// Writer is a circular buffer for writing data to an io.Writer.
type Writer struct {
*Buffer
}
// NewWriter returns a pointer to a new Circular Writer.
func NewWriter(size, block int) *Writer {
b := NewBuffer(size, block)
b.ID = "writer"
return &Writer{
b,
}
}
// NewWriterFromSlice returns a new Circular Writer using a pre-existing
// byte slice.
func NewWriterFromSlice(block int, p []byte) *Writer {
b := NewBufferFromSlice(block, p)
b.ID = "writer"
return &Writer{
b,
}
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
atomic.StoreUint32(&b.State, 2)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}
// Read from the buffer until there is at least 1 byte to write.
err = b.awaitFilled(1)
if err != nil {
return
}
// Get all the bytes between the tail and head, wrapping if necessary.
tail := atomic.LoadInt64(&b.tail)
rTail := b.Index(tail)
rHead := b.Index(atomic.LoadInt64(&b.head))
n := b.CapDelta()
p := make([]byte, 0, n)
if rTail > rHead {
p = append(p, b.buf[rTail:]...)
p = append(p, b.buf[:rHead]...)
} else {
p = append(p, b.buf[rTail:rHead]...)
}
n, err = w.Write(p)
total += int64(n)
if err != nil {
log.Println("error writing to buffer io.Writer;", err)
return
}
// Move the tail forward the bytes written and broadcast change.
atomic.StoreInt64(&b.tail, tail+int64(n))
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
}
// Write writes p to the buffer, returning the number of bytes written.
// The bytes written to the buffer are picked up by WriteTo.
func (b *Writer) Write(p []byte) (total int, err error) {
err = b.awaitEmpty(len(p))
if err != nil {
return
}
total = b.writeBytes(p)
atomic.AddInt64(&b.head, int64(total))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
return
}
// writeBytes writes bytes to the buffer starting at the head position, and
// returns the number of bytes written. This function does not wait for
// capacity and will overwrite any existing bytes.
func (b *Writer) writeBytes(p []byte) int {
var o int
var n int
for i := 0; i < len(p); i++ {
o = b.Index(atomic.LoadInt64(&b.head) + int64(i))
b.buf[o] = p[i]
n++
}
return n
}
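For illustration, a hypothetical example test (not part of this diff) sketching the counterpart flow: Write stages outgoing bytes in the ring and WriteTo drains them to an io.Writer, returning io.EOF once the buffer is stopped and empty:
// Hypothetical example test for the Writer flow.
package circ

import (
    "bytes"
    "fmt"
)

func ExampleWriter() {
    w := NewWriter(16, 4)

    // Write stages the bytes in the ring and advances the head.
    n, _ := w.Write([]byte("hello"))
    fmt.Println(n)

    // Stop the buffer so WriteTo returns (with io.EOF) once the ring is drained.
    w.Stop()
    var out bytes.Buffer
    total, err := w.WriteTo(&out)
    fmt.Println(total, err, out.String())
    // Output:
    // 5
    // 5 EOF hello
}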

View File

@ -0,0 +1,155 @@
package circ
import (
"bufio"
"bytes"
"net"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewWriter(t *testing.T) {
var size = 16
var block = 4
buf := NewWriter(size, block)
require.NotNil(t, buf.buf)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewWriterFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestWriteTo(t *testing.T) {
tests := []struct {
tail int64
head int64
bytes []byte
await int
total int
err error
desc string
}{
{tail: 0, head: 5, bytes: []byte{'a', 'b', 'c', 'd', 'e'}, desc: "0,5 OK"},
{tail: 14, head: 21, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd', 'e'}, desc: "14,16(2) OK"},
}
for i, tt := range tests {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
buf := NewWriter(16, 4)
buf.Set(bb, 0, 16)
buf.SetPos(tt.tail, tt.head)
var b bytes.Buffer
w := bufio.NewWriter(&b)
nc := make(chan int64)
go func() {
n, _ := buf.WriteTo(w)
nc <- n
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
w.Flush()
require.Equal(t, tt.bytes, b.Bytes(), "Written bytes mismatch [i:%d] %s", i, tt.desc)
}
}
func TestWriteToEndedFirst(t *testing.T) {
buf := NewWriter(16, 4)
buf.done = 1
var b bytes.Buffer
w := bufio.NewWriter(&b)
_, err := buf.WriteTo(w)
require.Error(t, err)
}
func TestWriteToBadWriter(t *testing.T) {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
buf := NewWriter(16, 4)
buf.Set(bb, 0, 16)
buf.SetPos(0, 6)
r, w := net.Pipe()
w.Close()
_, err := buf.WriteTo(w)
require.Error(t, err)
r.Close()
}
func TestWrite(t *testing.T) {
tests := []struct {
tail int64
head int64
rHead int64
bytes []byte
want []byte
desc string
}{
{tail: 0, head: 0, rHead: 4, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, desc: "0>4 OK"},
{tail: 4, head: 14, rHead: 2, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a', 'b'}, desc: "14>2 OK"},
}
for i, tt := range tests {
buf := NewWriter(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan []interface{})
go func() {
nn, err := buf.Write(tt.bytes)
o <- []interface{}{nn, err}
}()
done := <-o
require.Equal(t, tt.want, buf.buf, "Wanted written mismatch [i:%d] %s", i, tt.desc)
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestWriteEnded(t *testing.T) {
buf := NewWriter(16, 4)
buf.SetPos(15, 30)
buf.done = 1
_, err := buf.Write([]byte{'a', 'b', 'c', 'd'})
require.Error(t, err)
}
func TestWriteBytes(t *testing.T) {
tests := []struct {
tail int64
head int64
bytes []byte
want []byte
start int
desc string
}{
{tail: 0, head: 0, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0}, desc: "0,4 OK"},
{tail: 6, head: 6, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 'a', 'b'}, desc: "6,2 OK wrapped"},
}
for i, tt := range tests {
buf := NewWriter(8, 4)
buf.SetPos(tt.tail, tt.head)
n := buf.writeBytes(tt.bytes)
require.Equal(t, tt.want, buf.buf, "Buffer mismatch [i:%d] %s", i, tt.desc)
require.Equal(t, len(tt.bytes), n)
}
}

View File

@ -0,0 +1,564 @@
package clients
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
var (
// defaultKeepalive is the default connection keepalive value in seconds.
defaultKeepalive uint16 = 10
// ErrConnectionClosed is returned when operating on a closed
// connection and/or when no error cause has been given.
ErrConnectionClosed = errors.New("connection not open")
)
// Clients contains a map of the clients known by the broker.
type Clients struct {
sync.RWMutex
internal map[string]*Client // clients known by the broker, keyed on client id.
}
// New returns an instance of Clients.
func New() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
cl.internal[val.ID] = val
cl.Unlock()
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
val, ok := cl.internal[id]
cl.RUnlock()
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
val := len(cl.internal)
cl.RUnlock()
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
delete(cl.internal, id)
cl.Unlock()
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
cl.RUnlock()
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
State State // the operational state of the client.
LWT LWT // the last will and testament for the client.
Inflight *Inflight // a map of in-flight qos messages.
sync.RWMutex // mutex
Username []byte // the username the client authenticated with.
AC auth.Controller // an auth controller inherited from the listener.
Listener string // the id of the listener the client is connected to.
ID string // the client id.
conn net.Conn // the net.Conn used to establish the connection.
R *circ.Reader // a reader for reading incoming bytes.
W *circ.Writer // a writer for writing outgoing bytes.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
systemInfo *system.Info // pointers to server system info.
packetID uint32 // the current highest packetID.
keepalive uint16 // the number of seconds the connection can wait.
CleanSession bool // indicates if the client expects a clean-session.
}
// State tracks the state of the client.
type State struct {
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
Done uint32 // atomic counter which indicates that the client has closed.
endOnce sync.Once // only end once.
stopCause atomic.Value // reason for stopping.
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
R: r,
W: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
State: State{
started: new(sync.WaitGroup),
endedW: new(sync.WaitGroup),
endedR: new(sync.WaitGroup),
},
}
cl.refreshDeadline(cl.keepalive)
return cl
}
// NewClientStub returns an instance of Client with basic initializations. This
// method is typically called by the persistence restoration system.
func NewClientStub(s *system.Info) *Client {
return &Client{
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
State: State{
Done: 1,
},
}
}
// Identify sets the identification values of a client instance.
func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
cl.Listener = lid
cl.AC = ac
cl.ID = pk.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String()
}
cl.R.ID = cl.ID + " READER"
cl.W.ID = cl.ID + " WRITER"
cl.Username = pk.Username
cl.CleanSession = pk.CleanSession
cl.keepalive = pk.Keepalive
if pk.WillFlag {
cl.LWT = LWT{
Topic: pk.WillTopic,
Message: pk.WillMessage,
Qos: pk.WillQos,
Retain: pk.WillRetain,
}
}
cl.refreshDeadline(cl.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.conn != nil {
var expiry time.Time // a zero time value disables the deadline when keepalive is 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
}
_ = cl.conn.SetDeadline(expiry)
}
}
// Info returns an event-version of a client, containing minimal information.
func (cl *Client) Info() events.Client {
addr := "unknown"
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Username: cl.Username,
CleanSession: cl.CleanSession,
Listener: cl.Listener,
}
}
// NextPacketID returns the next packet id for a client, looping back to 1
// once the maximum ID has been reached.
func (cl *Client) NextPacketID() uint32 {
i := atomic.LoadUint32(&cl.packetID)
if i == uint32(65535) || i == uint32(0) {
atomic.StoreUint32(&cl.packetID, 1)
return 1
}
return atomic.AddUint32(&cl.packetID, 1)
}
// NoteSubscription makes a note of a subscription for the client.
func (cl *Client) NoteSubscription(filter string, qos byte) {
cl.Lock()
cl.Subscriptions[filter] = qos
cl.Unlock()
}
// ForgetSubscription forgets a subscription note for the client.
func (cl *Client) ForgetSubscription(filter string) {
cl.Lock()
delete(cl.Subscriptions, filter)
cl.Unlock()
}
// Start begins the client goroutines reading and writing packets.
func (cl *Client) Start() {
cl.State.started.Add(2)
cl.State.endedW.Add(1)
cl.State.endedR.Add(1)
go func() {
cl.State.started.Done()
_, err := cl.W.WriteTo(cl.conn)
if err != nil {
err = fmt.Errorf("writer: %w", err)
}
cl.State.endedW.Done()
cl.Stop(err)
}()
go func() {
cl.State.started.Done()
_, err := cl.R.ReadFrom(cl.conn)
if err != nil {
err = fmt.Errorf("reader: %w", err)
}
cl.State.endedR.Done()
cl.Stop(err)
}()
cl.State.started.Wait()
}
// ClearBuffers sets the read/write buffers to nil so they can be
// deallocated automatically when no longer in use.
func (cl *Client) ClearBuffers() {
cl.R = nil
cl.W = nil
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
// A cause error may be passed to identify the reason for stopping.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return
}
cl.State.endOnce.Do(func() {
cl.R.Stop()
cl.W.Stop()
cl.State.endedW.Wait()
_ = cl.conn.Close() // omit close error
cl.State.endedR.Wait()
atomic.StoreUint32(&cl.State.Done, 1)
if err == nil {
err = ErrConnectionClosed
}
cl.State.stopCause.Store(err)
})
}
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
p, err := cl.R.Read(1)
if err != nil {
return err
}
err = fh.Decode(p[0])
if err != nil {
return err
}
// The remaining length value can be up to 4 bytes. Read through each byte
// looking for continuation bits, and if one is found increase the read by a byte.
// Otherwise, decode the length bytes that have been read.
buf := make([]byte, 0, 6)
i := 1
n := 2
for ; n < 6; n++ {
p, err = cl.R.Read(n)
if err != nil {
return err
}
buf = append(buf, p[i])
// If it's not a continuation flag, end here.
if p[i] < 128 {
break
}
// If i has reached 4 without a length terminator, return a protocol violation.
i++
if i == 4 {
return packets.ErrOversizedLengthIndicator
}
}
// Calculate and store the remaining length of the packet payload.
rem, _ := binary.Uvarint(buf)
fh.Remaining = int(rem)
// Having successfully read n bytes, commit the tail forward.
cl.R.CommitTail(n)
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
// Read loops forever reading new packets from a client connection until
// an error is encountered (or the connection is closed).
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
for {
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
return nil
}
cl.refreshDeadline(cl.keepalive)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = packetHandler(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
return
}
p, err := cl.R.Read(pk.FixedHeader.Remaining)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Disconnect:
default:
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
cl.R.CommitTail(pk.FixedHeader.Remaining)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
cl.W.Mu.Lock()
defer cl.W.Mu.Unlock()
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
default:
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
if err != nil {
return
}
// Write the packet bytes to the client byte buffer.
n, err = cl.W.Write(buf.Bytes())
if err != nil {
return
}
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)
return
}
// LWT contains the last will and testament details for a client connection.
type LWT struct {
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained.
}
// InflightMessage contains data about a packet which is currently in-flight.
type InflightMessage struct {
Packet packets.Packet // the packet currently in-flight.
Sent int64 // the last time the message was sent (for retries) in unixtime.
Resends int // the number of times the message was attempted to be sent.
}
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]InflightMessage // internal contains the inflight messages.
}
// Set stores the packet of an Inflight message, keyed on message id. Returns
// true if the inflight message was new.
func (i *Inflight) Set(key uint16, in InflightMessage) bool {
i.Lock()
_, ok := i.internal[key]
i.internal[key] = in
i.Unlock()
return !ok
}
// Get returns the value of an in-flight message if it exists.
func (i *Inflight) Get(key uint16) (InflightMessage, bool) {
i.RLock()
val, ok := i.internal[key]
i.RUnlock()
return val, ok
}
// Len returns the size of the in-flight messages map.
func (i *Inflight) Len() int {
i.RLock()
v := len(i.internal)
i.RUnlock()
return v
}
// GetAll returns all the in-flight messages.
func (i *Inflight) GetAll() map[uint16]InflightMessage {
i.RLock()
defer i.RUnlock()
return i.internal
}
// Delete removes an in-flight message from the map. Returns true if the
// message existed.
func (i *Inflight) Delete(key uint16) bool {
i.Lock()
_, ok := i.internal[key]
delete(i.internal, key)
i.Unlock()
return ok
}
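For illustration, a hypothetical example test (not part of this diff) sketching the packet ID wrap-around behaviour of NextPacketID:
// Hypothetical example test for packet ID allocation.
package clients

import "fmt"

func ExampleClient_NextPacketID() {
    cl := &Client{} // bare client; no network connection is needed for this

    fmt.Println(cl.NextPacketID()) // the first ID issued is 1
    fmt.Println(cl.NextPacketID()) // then 2, 3, ...

    // Force the counter to the ceiling; the next call wraps back to 1.
    cl.packetID = 65535
    fmt.Println(cl.NextPacketID())
    // Output:
    // 1
    // 2
    // 1
}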

File diff suppressed because it is too large

View File

@ -0,0 +1,116 @@
package packets
import (
"encoding/binary"
"unicode/utf8"
"unsafe"
)
// bytesToString provides a zero-alloc, no-copy byte to string conversion.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// decodeUint16 extracts the value of two bytes from a byte array.
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
if len(buf) < offset+2 {
return 0, 0, ErrOffsetUintOutOfRange
}
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
}
// decodeString extracts a string from a byte array, beginning at an offset.
func decodeString(buf []byte, offset int) (string, int, error) {
b, n, err := decodeBytes(buf, offset)
if err != nil {
return "", 0, err
}
if !validUTF8(b) {
return "", 0, ErrOffsetStrInvalidUTF8
}
return bytesToString(b), n, nil
}
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
length, next, err := decodeUint16(buf, offset)
if err != nil {
return make([]byte, 0, 0), 0, err
}
if next+int(length) > len(buf) {
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
}
// Note: there is no validUTF8() test for []byte payloads
return buf[next : next+int(length)], next + int(length), nil
}
// decodeByte extracts the value of a byte from a byte array.
func decodeByte(buf []byte, offset int) (byte, int, error) {
if len(buf) <= offset {
return 0, 0, ErrOffsetByteOutOfRange
}
return buf[offset], offset + 1, nil
}
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
if len(buf) <= offset {
return false, 0, ErrOffsetBoolOutOfRange
}
return 1&buf[offset] > 0, offset + 1, nil
}
// encodeBool encodes a bool value as a byte (1 for true, 0 for false).
func encodeBool(b bool) byte {
if b {
return 1
}
return 0
}
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
func encodeBytes(val []byte) []byte {
// In many circumstances the number of bytes being encoded is small.
// Setting the cap to a low amount allows us to account for those without
// triggering allocation growth on append unless we need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, val...)
}
// encodeUint16 encodes a uint16 value to a byte array.
func encodeUint16(val uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, val)
return buf
}
// encodeString encodes a string to a byte array.
func encodeString(val string) []byte {
// Like encodeBytes, we set the cap to a small number to avoid
// triggering allocation growth on append unless we absolutely need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, []byte(val)...)
}
// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically
// conforming to the MQTT specification requirements.
func validUTF8(b []byte) bool {
// [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8...
if !utf8.Valid(b) {
return false
}
// [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000...
// ...
return true
}
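For illustration, a hypothetical example test (not part of this diff) sketching the length-prefixed string encoding used by these codecs: a two-byte big-endian length followed by the UTF-8 bytes:
// Hypothetical example test for the string codec round trip.
package packets

import "fmt"

func Example_stringEncoding() {
    enc := encodeString("a/b/c")
    fmt.Println(enc[:2], string(enc[2:])) // 2-byte length prefix, then the bytes

    // decodeString reads the prefix at the offset, validates the UTF-8, and
    // returns the string plus the offset of the next field.
    s, next, err := decodeString(enc, 0)
    fmt.Println(s, next, err)
    // Output:
    // [0 5] a/b/c
    // a/b/c 7 <nil>
}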

View File

@ -0,0 +1,386 @@
package packets
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func BenchmarkBytesToString(b *testing.B) {
for n := 0; n < b.N; n++ {
bytesToString([]byte{'a', 'b', 'c'})
}
}
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result string
offset int
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: "a/b/c/d",
},
{
offset: 14,
rawBytes: []byte{
byte(Connect << 4), 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "hey"
},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
shouldFail: ErrOffsetUintOutOfRange,
},
{
offset: 17,
rawBytes: []byte{
byte(Connect << 4), 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func BenchmarkDecodeString(b *testing.B) {
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
for n := 0; n < b.N; n++ {
decodeString(in, 0)
}
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 0,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 8,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func BenchmarkDecodeBytes(b *testing.B) {
in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}
for n := 0; n < b.N; n++ {
decodeBytes(in, 0)
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
offset: 8,
shouldFail: ErrOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
func BenchmarkDecodeByte(b *testing.B) {
in := []byte{0, 4, 77, 81, 84, 84}
for n := 0; n < b.N; n++ {
decodeByte(in, 0)
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
func BenchmarkDecodeUint16(b *testing.B) {
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
for n := 0; n < b.N; n++ {
decodeUint16(in, 0)
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: ErrOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}
func BenchmarkDecodeByteBool(b *testing.B) {
in := []byte{0x00, 0x00}
for n := 0; n < b.N; n++ {
decodeByteBool(in, 0)
}
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result, "Incorrect encoded value; not true")
result = encodeBool(false)
require.Equal(t, byte(0), result, "Incorrect encoded value; not false")
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value")
}
func BenchmarkEncodeBool(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeBool(true)
}
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value")
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value")
}
func BenchmarkEncodeBytes(b *testing.B) {
bb := []byte("testing")
for n := 0; n < b.N; n++ {
encodeBytes(bb)
}
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0")
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767")
result = encodeUint16(65535)
require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535")
}
func BenchmarkEncodeUint16(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeUint16(32767)
}
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing")
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null")
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a")
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b")
}
func BenchmarkEncodeString(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeString("benchmarking")
}
}

View File

@ -0,0 +1,59 @@
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, int64(fh.Remaining))
}
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(headerByte byte) error {
fh.Type = headerByte >> 4 // Get the message type from the first 4 bits.
switch fh.Type {
case Publish:
fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate.
fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag.
fh.Retain = headerByte&0x01 > 0 // Extract retain flag.
case Pubrel:
fh.Qos = (headerByte >> 1) & 0x03
case Subscribe:
fh.Qos = (headerByte >> 1) & 0x03
case Unsubscribe:
fh.Qos = (headerByte >> 1) & 0x03
default:
if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 {
return ErrInvalidFlags
}
}
return nil
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int64) {
for {
digit := byte(length % 128)
length /= 128
if length > 0 {
digit |= 0x80
}
buf.WriteByte(digit)
if length == 0 {
break
}
}
}
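For illustration, a hypothetical example test (not part of this diff) sketching the variable-length remaining-length encoding produced by encodeLength: seven bits of the value per byte, with the high bit marking a continuation:
// Hypothetical example test for the remaining-length encoding.
package packets

import (
    "bytes"
    "fmt"
)

func Example_encodeLength() {
    for _, n := range []int64{10, 127, 128, 978, 20102} {
        buf := new(bytes.Buffer)
        encodeLength(buf, n)
        fmt.Printf("%d -> [% x]\n", n, buf.Bytes())
    }
    // Output:
    // 10 -> [0a]
    // 127 -> [7f]
    // 128 -> [80 01]
    // 978 -> [d2 07]
    // 20102 -> [86 9d 01]
}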

View File

@ -0,0 +1,220 @@
package packets
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
rawBytes []byte
header FixedHeader
packetError bool
flagError bool
}
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
flagError: true,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.flagError == false {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes)
require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkFixedHeaderEncode(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
fixedHeaderExpected[0].header.Encode(buf)
}
}
func TestFixedHeaderDecode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.flagError {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkFixedHeaderDecode(b *testing.B) {
fh := new(FixedHeader)
for n := 0; n < b.N; n++ {
err := fh.Decode(fixedHeaderExpected[0].rawBytes[0])
if err != nil {
panic(err)
}
}
}
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int64
want []byte
}{
{
120,
[]byte{0x78},
},
{
math.MaxInt64,
[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
},
}
for i, wanted := range tt {
buf := new(bytes.Buffer)
encodeLength(buf, wanted.have)
require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %d", i, wanted.have)
}
}
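// Illustrative sketch (not part of the original file): one way to produce the
// base-128 encoding exercised by TestEncodeLength above, assuming encodeLength
// follows the MQTT remaining-length scheme. exampleEncodeLength is a
// hypothetical helper, not the library's implementation.
func exampleEncodeLength(buf *bytes.Buffer, length int64) {
	for {
		digit := byte(length % 128) // low 7 bits of the current value
		length /= 128
		if length > 0 {
			digit |= 0x80 // set the continuation bit when more bytes follow
		}
		buf.WriteByte(digit)
		if length == 0 {
			return
		}
	}
}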
func BenchmarkEncodeLength(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
encodeLength(buf, 120)
}
}


@ -0,0 +1,670 @@
package packets
import (
"bytes"
"errors"
"fmt"
"strconv"
)
// All of the valid packet types and their identifiers, plus connection and subscription return codes.
const (
Reserved byte = iota
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Accepted byte = 0x00
Failed byte = 0xFF
CodeConnectBadProtocolVersion byte = 0x01
CodeConnectBadClientID byte = 0x02
CodeConnectServerUnavailable byte = 0x03
CodeConnectBadAuthValues byte = 0x04
CodeConnectNotAuthorised byte = 0x05
CodeConnectNetworkError byte = 0xFE
CodeConnectProtocolViolation byte = 0xFF
ErrSubAckNetworkError byte = 0x80
)
var (
// CONNECT
ErrMalformedProtocolName = errors.New("malformed packet: protocol name")
ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version")
ErrMalformedFlags = errors.New("malformed packet: flags")
ErrMalformedKeepalive = errors.New("malformed packet: keepalive")
ErrMalformedClientID = errors.New("malformed packet: client id")
ErrMalformedWillTopic = errors.New("malformed packet: will topic")
ErrMalformedWillMessage = errors.New("malformed packet: will message")
ErrMalformedUsername = errors.New("malformed packet: username")
ErrMalformedPassword = errors.New("malformed packet: password")
// CONNACK
ErrMalformedSessionPresent = errors.New("malformed packet: session present")
ErrMalformedReturnCode = errors.New("malformed packet: return code")
// PUBLISH
ErrMalformedTopic = errors.New("malformed packet: topic name")
ErrMalformedPacketID = errors.New("malformed packet: packet id")
// SUBSCRIBE
ErrMalformedQoS = errors.New("malformed packet: qos")
// PACKETS
ErrProtocolViolation = errors.New("protocol violation")
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
ErrOffsetUintOutOfRange = errors.New("offset uint out of range")
ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8")
ErrInvalidFlags = errors.New("invalid flags set for packet")
ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator")
ErrMissingPacketID = errors.New("missing packet id")
ErrSurplusPacketID = errors.New("surplus packet id")
)
// Packet is an MQTT packet. Instead of providing a packet interface and variant
// packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations.
type Packet struct {
FixedHeader FixedHeader
AllowClients []string // For use with OnMessage event hook.
Topics []string
ReturnCodes []byte
ProtocolName []byte
Qoss []byte
Payload []byte
Username []byte
Password []byte
WillMessage []byte
ClientIdentifier string
TopicName string
WillTopic string
PacketID uint16
Keepalive uint16
ReturnCode byte
ProtocolVersion byte
WillQos byte
ReservedBit byte
CleanSession bool
WillFlag bool
WillRetain bool
UsernameFlag bool
PasswordFlag bool
SessionPresent bool
}
// ConnectEncode encodes a connect packet.
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
protoName := encodeBytes(pk.ProtocolName)
protoVersion := pk.ProtocolVersion
flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7
keepalive := encodeUint16(pk.Keepalive)
clientID := encodeString(pk.ClientIdentifier)
var willTopic, willFlag, usernameFlag, passwordFlag []byte
// If will flag is set, add topic and message.
if pk.WillFlag {
willTopic = encodeString(pk.WillTopic)
willFlag = encodeBytes(pk.WillMessage)
}
// If username flag is set, add username.
if pk.UsernameFlag {
usernameFlag = encodeBytes(pk.Username)
}
// If password flag is set, add password.
if pk.PasswordFlag {
passwordFlag = encodeBytes(pk.Password)
}
// Get a length for the connect header. This is not super pretty, but it works.
pk.FixedHeader.Remaining =
len(protoName) + 1 + 1 + len(keepalive) + len(clientID) +
len(willTopic) + len(willFlag) +
len(usernameFlag) + len(passwordFlag)
pk.FixedHeader.Encode(buf)
// Eschew magic for readability.
buf.Write(protoName)
buf.WriteByte(protoVersion)
buf.WriteByte(flag)
buf.Write(keepalive)
buf.Write(clientID)
buf.Write(willTopic)
buf.Write(willFlag)
buf.Write(usernameFlag)
buf.Write(passwordFlag)
return nil
}
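// Illustrative note (not part of the original file): with the bit layout above,
// a client connecting with CleanSession, a username, and a password (no will)
// produces flags = 1<<1 | 1<<6 | 1<<7 = 0xC2.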
// ConnectDecode decodes a connect packet.
func (pk *Packet) ConnectDecode(buf []byte) error {
var offset int
var err error
// Unpack protocol name and version.
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
}
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
}
// Unpack flags byte.
flags, offset, err := decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
}
pk.ReservedBit = 1 & flags
pk.CleanSession = 1&(flags>>1) > 0
pk.WillFlag = 1&(flags>>2) > 0
pk.WillQos = 3 & (flags >> 3) // this one is not a bool
pk.WillRetain = 1&(flags>>5) > 0
pk.PasswordFlag = 1&(flags>>6) > 0
pk.UsernameFlag = 1&(flags>>7) > 0
// Get keepalive interval.
pk.Keepalive, offset, err = decodeUint16(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
}
// Get client ID.
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
}
// Get Last Will and Testament topic and message if applicable.
if pk.WillFlag {
pk.WillTopic, offset, err = decodeString(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
}
pk.WillMessage, offset, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
}
}
// Get username and password if applicable.
if pk.UsernameFlag {
pk.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
}
}
if pk.PasswordFlag {
pk.Password, _, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
}
return nil
}
// ConnectValidate ensures the connect packet is compliant.
func (pk *Packet) ConnectValidate() (b byte, err error) {
// End if protocol name is bad.
if !bytes.Equal(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) &&
!bytes.Equal(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if protocol version is bad.
if (bytes.Equal(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) && pk.ProtocolVersion != 3) ||
(bytes.Equal(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) && pk.ProtocolVersion != 4) {
return CodeConnectBadProtocolVersion, ErrProtocolViolation
}
// End if reserved bit is not 0.
if pk.ReservedBit != 0 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if ClientID is too long.
if len(pk.ClientIdentifier) > 65535 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if password flag is set without a username.
if pk.PasswordFlag && !pk.UsernameFlag {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if Username or Password is too long.
if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if client id isn't set and clean session is false.
if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
return CodeConnectBadClientID, ErrProtocolViolation
}
return Accepted, nil
}
// ConnackEncode encodes a Connack packet.
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.WriteByte(encodeBool(pk.SessionPresent))
buf.WriteByte(pk.ReturnCode)
return nil
}
// ConnackDecode decodes a Connack packet.
func (pk *Packet) ConnackDecode(buf []byte) error {
var offset int
var err error
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, _, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}
return nil
}
// DisconnectEncode encodes a Disconnect packet.
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PingreqEncode encodes a Pingreq packet.
func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PingrespEncode encodes a Pingresp packet.
func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PubackEncode encodes a Puback packet.
func (pk *Packet) PubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubackDecode decodes a Puback packet.
func (pk *Packet) PubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// PubcompEncode encodes a Pubcomp packet.
func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubcompDecode decodes a Pubcomp packet.
func (pk *Packet) PubcompDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// PublishEncode encodes a Publish packet.
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
topicName := encodeString(pk.TopicName)
var packetID []byte
// Add PacketID if QOS is set.
// [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
if pk.FixedHeader.Qos > 0 {
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID = encodeUint16(pk.PacketID)
}
pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload)
pk.FixedHeader.Encode(buf)
buf.Write(topicName)
buf.Write(packetID)
buf.Write(pk.Payload)
return nil
}
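// Illustrative sketch (not part of the original file): encoding a minimal
// QoS 0 publish with the single concrete Packet type. The topic and payload
// are invented for the example; a QoS > 0 publish would also need a non-zero
// PacketID, per the checks above.
func examplePublishEncode() ([]byte, error) {
	pk := Packet{
		FixedHeader: FixedHeader{Type: Publish},
		TopicName:   "sensors/temp",
		Payload:     []byte("21.5"),
	}
	buf := new(bytes.Buffer)
	if err := pk.PublishEncode(buf); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}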
// PublishDecode extracts the data values from the packet.
func (pk *Packet) PublishDecode(buf []byte) error {
var offset int
var err error
pk.TopicName, offset, err = decodeString(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
// If QOS decode Packet ID.
if pk.FixedHeader.Qos > 0 {
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
}
pk.Payload = buf[offset:]
return nil
}
// PublishCopy creates a new instance of Publish packet bearing the
// same payload and destination topic, but with an empty header for
// inheriting new QoS flags, etc.
func (pk *Packet) PublishCopy() Packet {
return Packet{
FixedHeader: FixedHeader{
Type: Publish,
Retain: pk.FixedHeader.Retain,
},
TopicName: pk.TopicName,
Payload: pk.Payload,
}
}
// PublishValidate validates a publish packet.
func (pk *Packet) PublishValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1]
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
// @SPEC [MQTT-2.3.1-5]
// A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 {
return Failed, ErrSurplusPacketID
}
return Accepted, nil
}
// PubrecEncode encodes a Pubrec packet.
func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubrecDecode decodes a Pubrec packet.
func (pk *Packet) PubrecDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// PubrelEncode encodes a Pubrel packet.
func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubrelDecode decodes a Pubrel packet.
func (pk *Packet) PubrelDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// SubackEncode encodes a Suback packet.
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
packetID := encodeUint16(pk.PacketID)
pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length.
pk.FixedHeader.Encode(buf)
buf.Write(packetID) // Encode Packet ID.
buf.Write(pk.ReturnCodes) // Encode granted QOS flags.
return nil
}
// SubackDecode decodes a Suback packet.
func (pk *Packet) SubackDecode(buf []byte) error {
var offset int
var err error
// Get Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Get Granted QOS flags.
pk.ReturnCodes = buf[offset:]
return nil
}
// SubscribeEncode encodes a Subscribe packet.
func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
// Add the Packet ID.
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID := encodeUint16(pk.PacketID)
// Count topics lengths and associated QOS flags.
var topicsLen int
for _, topic := range pk.Topics {
topicsLen += len(encodeString(topic)) + 1
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names and associated QOS flags.
for i, topic := range pk.Topics {
buf.Write(encodeString(topic))
buf.WriteByte(pk.Qoss[i])
}
return nil
}
// SubscribeDecode decodes a Subscribe packet.
func (pk *Packet) SubscribeDecode(buf []byte) error {
var offset int
var err error
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
for offset < len(buf) {
// Decode Topic Name.
var topic string
topic, offset, err = decodeString(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
pk.Topics = append(pk.Topics, topic)
// Decode QOS flag.
var qos byte
qos, offset, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
}
// Ensure QoS byte is within range.
if qos > 2 { // qos is a byte, so only the upper bound needs checking
return ErrMalformedQoS
}
pk.Qoss = append(pk.Qoss, qos)
}
return nil
}
// SubscribeValidate ensures the packet is compliant.
func (pk *Packet) SubscribeValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1].
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
return Accepted, nil
}
// UnsubackEncode encodes an Unsuback packet.
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// UnsubackDecode decodes an Unsuback packet.
func (pk *Packet) UnsubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// UnsubscribeEncode encodes an Unsubscribe packet.
func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
// Add the Packet ID.
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID := encodeUint16(pk.PacketID)
// Count topics lengths.
var topicsLen int
for _, topic := range pk.Topics {
topicsLen += len(encodeString(topic))
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names.
for _, topic := range pk.Topics {
buf.Write(encodeString(topic))
}
return nil
}
// UnsubscribeDecode decodes an Unsubscribe packet.
func (pk *Packet) UnsubscribeDecode(buf []byte) error {
var offset int
var err error
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
for offset < len(buf) {
var t string
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
if len(t) > 0 {
pk.Topics = append(pk.Topics, t)
}
}
return nil
}
// UnsubscribeValidate validates an Unsubscribe packet.
func (pk *Packet) UnsubscribeValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1].
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
return Accepted, nil
}
// FormatID returns the PacketID field as a decimal string.
func (pk *Packet) FormatID() string {
return strconv.FormatUint(uint64(pk.PacketID), 10)
}

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,344 @@
package topics
import (
"strings"
"sync"
"github.com/mochi-co/mqtt/server/internal/packets"
)
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions map[string]byte
// Index is a prefix/trie tree containing topic subscribers and retained messages.
type Index struct {
mu sync.RWMutex // a mutex for locking the whole index.
Root *Leaf // a leaf containing a message and more leaves.
}
// New returns a pointer to a new instance of Index.
func New() *Index {
return &Index{
Root: &Leaf{
Leaves: make(map[string]*Leaf),
Clients: make(map[string]byte),
},
}
}
// RetainMessage saves a message payload to the end of a topic branch. Returns
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *Index) RetainMessage(msg packets.Packet) int64 {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(msg.TopicName)
// If there is a payload, we can store it.
if len(msg.Payload) > 0 {
n.Message = msg
return 1
}
// Otherwise, we are unsetting it.
// If there was a previous retained message, return -1 instead of 0.
var r int64
if len(n.Message.Payload) > 0 && n.Message.FixedHeader.Retain {
r = -1
}
x.unpoperate(msg.TopicName, "", true)
return r
}
// Subscribe creates a subscription filter for a client. Returns true if the
// subscription was new.
func (x *Index) Subscribe(filter, client string, qos byte) bool {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(filter)
_, ok := n.Clients[client]
n.Clients[client] = qos
n.Filter = filter
return !ok
}
// Unsubscribe removes a subscription filter for a client. Returns true if the
// unsubscribe action was successful and the subscription existed.
func (x *Index) Unsubscribe(filter, client string) bool {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(filter)
_, ok := n.Clients[client]
return x.unpoperate(filter, client, false) && ok
}
// unpoperate steps backward through a trie sequence and removes any orphaned
// nodes. If a client id is specified, it will unsubscribe a client. If message
// is true, it will delete a retained message.
func (x *Index) unpoperate(filter string, client string, message bool) bool {
var d int // Walk to end leaf.
var particle string
var hasNext = true
e := x.Root
for hasNext {
particle, hasNext = isolateParticle(filter, d)
d++
e = e.Leaves[particle]
// If the topic part doesn't exist in the tree, there's nothing
// left to do.
if e == nil {
return false
}
}
// Step backward removing client and orphaned leaves.
var key string
var orphaned bool
var end = true
for e.Parent != nil {
key = e.Key
// Wipe the client from this leaf if it's the filter end.
if end {
if client != "" {
delete(e.Clients, client)
}
if message {
e.Message = packets.Packet{}
}
end = false
}
// If this leaf is empty, note it as orphaned.
orphaned = len(e.Clients) == 0 && len(e.Leaves) == 0 && !e.Message.FixedHeader.Retain
// Traverse up the branch.
e = e.Parent
// If the leaf we just came from was empty, delete it.
if orphaned {
delete(e.Leaves, key)
}
}
return true
}
// poperate iterates and populates through a topic/filter path, instantiating
// leaves as it goes and returning the final leaf in the branch.
// poperate is a more enjoyable word than iterpop.
func (x *Index) poperate(topic string) *Leaf {
var d int
var particle string
var hasNext = true
n := x.Root
for hasNext {
particle, hasNext = isolateParticle(topic, d)
d++
child := n.Leaves[particle]
if child == nil {
child = &Leaf{
Key: particle,
Parent: n,
Leaves: make(map[string]*Leaf),
Clients: make(map[string]byte),
}
n.Leaves[particle] = child
}
n = child
}
return n
}
// Subscribers returns a map of clients who are subscribed to matching filters.
func (x *Index) Subscribers(topic string) Subscriptions {
x.mu.RLock()
defer x.mu.RUnlock()
return x.Root.scanSubscribers(topic, 0, make(Subscriptions))
}
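// Illustrative sketch (not part of the original file): a typical subscribe and
// lookup round trip using the Index methods above. The topic names are
// invented for the example.
func exampleSubscriptionLookup() Subscriptions {
	idx := New()
	idx.Subscribe("sensors/+/temp", "client-1", 1) // single-level wildcard
	idx.Subscribe("sensors/#", "client-2", 0)      // wildhash
	// Returns Subscriptions{"client-1": 1, "client-2": 0}.
	return idx.Subscribers("sensors/kitchen/temp")
}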
// Messages returns a slice of retained topic messages which match a filter.
func (x *Index) Messages(filter string) []packets.Packet {
// ReLeaf("messages", x.Root, 0)
x.mu.RLock()
defer x.mu.RUnlock()
return x.Root.scanMessages(filter, 0, make([]packets.Packet, 0, 32))
}
// Leaf is a child node on the tree.
type Leaf struct {
Message packets.Packet // a message which has been retained for a specific topic.
Key string // the key that was used to create the leaf.
Filter string // the path of the topic filter being matched.
Parent *Leaf // a pointer to the parent node for the leaf.
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
Clients map[string]byte // a map of client ids subscribed to the topic.
}
// scanSubscribers recursively steps through a branch of leaves finding clients who
// have subscription filters matching a topic, and their highest QoS byte.
func (l *Leaf) scanSubscribers(topic string, d int, clients Subscriptions) Subscriptions {
part, hasNext := isolateParticle(topic, d)
// For either the topic part, a +, or a #, follow the branch.
for _, particle := range []string{part, "+", "#"} {
// Topics beginning with the reserved $ character are restricted from
// being returned for top level wildcards.
if d == 0 && len(part) > 0 && part[0] == '$' && (particle == "+" || particle == "#") {
continue
}
if child, ok := l.Leaves[particle]; ok {
// We're only interested in getting clients from the final
// element in the topic, or those with wildhashes.
if !hasNext || particle == "#" {
// Capture the highest QOS byte for any client with a filter
// matching the topic.
for client, qos := range child.Clients {
if ex, ok := clients[client]; !ok || ex < qos {
clients[client] = qos
}
}
// Make sure we also capture any clients who are listening
// to this topic via path/#
if !hasNext {
if extra, ok := child.Leaves["#"]; ok {
for client, qos := range extra.Clients {
if ex, ok := clients[client]; !ok || ex < qos {
clients[client] = qos
}
}
}
}
}
// If this branch has hit a wildhash, just return immediately.
if particle == "#" {
return clients
} else if hasNext {
clients = child.scanSubscribers(topic, d+1, clients)
}
}
}
return clients
}
// scanMessages recursively steps through a branch of leaves finding retained messages
// that match a topic filter. Setting `d` to -1 will enable wildhash mode, and will
// recursively check ALL child leaves in every subsequent branch.
func (l *Leaf) scanMessages(filter string, d int, messages []packets.Packet) []packets.Packet {
// If a wildhash mode has been set, continue recursively checking through all
// child leaves regardless of their particle key.
if d == -1 {
for _, child := range l.Leaves {
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
messages = child.scanMessages(filter, -1, messages)
}
return messages
}
// Otherwise, we'll get the particle for d in the filter.
particle, hasNext := isolateParticle(filter, d)
// If there's no more particles after this one, then take the messages from
// these topics.
if !hasNext {
// Wildcards and Wildhashes must be checked first, otherwise they
// may be detected as standard particles, and not act properly.
if particle == "+" || particle == "#" {
// If it's a wildcard or wildhash, get messages from all the child
// leaves. This branch captures messages at the wildhash position
// itself, whereas the d == -1 block collects subsequent messages
// further down the branch.
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
}
} else if child, ok := l.Leaves[particle]; ok {
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
}
} else {
// If it's not the last particle, branch out to the next leaves, scanning
// all available if it's a wildcard, or just one if it's a specific particle.
if particle == "+" {
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
messages = child.scanMessages(filter, d+1, messages)
}
} else if child, ok := l.Leaves[particle]; ok {
messages = child.scanMessages(filter, d+1, messages)
}
}
// If the particle was a wildhash, scan all the child leaves setting the
// d value to wildhash mode.
if particle == "#" {
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
messages = child.scanMessages(filter, -1, messages)
}
}
return messages
}
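// Illustrative note (not part of the original file): given the wildhash
// handling above, Messages("q/#") walks the "q" branch and then switches to
// wildhash mode (d == -1), collecting every retained message beneath it.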
// isolateParticle extracts the topic particle between the d'th and (d+1)'th '/' separator without allocating.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
for i := 0; end > -1 && i <= d; i++ {
end = strings.IndexRune(filter, '/')
if d > -1 && i == d && end > -1 {
hasNext = true
particle = filter[next:end]
} else if end > -1 {
hasNext = false
filter = filter[end+1:]
} else {
hasNext = false
particle = filter[next:]
}
}
return
}
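// Illustrative note (not part of the original file): for example,
// isolateParticle("path/to/my/mqtt", 1) returns ("to", true), and
// isolateParticle("path/to/my/mqtt", 3) returns ("mqtt", false).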
// ReLeaf is a dev function for showing the trie leafs.
/*
func ReLeaf(m string, leaf *Leaf, d int) {
for k, v := range leaf.Leaves {
fmt.Println(m, d, strings.Repeat(" ", d), k)
ReLeaf(m, v, d+1)
}
}
*/


@ -0,0 +1,494 @@
package topics
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/server/internal/packets"
)
func TestNew(t *testing.T) {
index := New()
require.NotNil(t, index)
require.NotNil(t, index.Root)
}
func BenchmarkNew(b *testing.B) {
for n := 0; n < b.N; n++ {
New()
}
}
func TestPoperate(t *testing.T) {
index := New()
child := index.poperate("path/to/my/mqtt")
require.Equal(t, "mqtt", child.Key)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
child = index.poperate("a/b/c/d/e")
require.Equal(t, "e", child.Key)
child = index.poperate("a/b/c/c/a")
require.Equal(t, "a", child.Key)
}
func BenchmarkPoperate(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.poperate("path/to/my/mqtt")
}
}
func TestUnpoperate(t *testing.T) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
pk := packets.Packet{TopicName: "path/to/retained/message", Payload: []byte{'h', 'e', 'l', 'l', 'o'}}
index.RetainMessage(pk)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"].Message)
pk2 := packets.Packet{TopicName: "path/to/my/mqtt", Payload: []byte{'s', 'h', 'a', 'r', 'e', 'd'}}
index.RetainMessage(pk2)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
index.unpoperate("path/to/my/mqtt", "", true) // delete retained
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message.FixedHeader.Retain)
index.unpoperate("path/to/my/mqtt", "client-1", false) // unsubscribe client
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
index.unpoperate("path/to/retained/message", "", true) // delete retained
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves, "my")
index.unpoperate("path/to/whatever", "client-1", false) // unsubscribe client
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
//require.Empty(t, index.Root.Leaves["path"])
}
func BenchmarkUnpoperate(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.poperate("path/to/my/mqtt")
}
}
func TestRetainMessage(t *testing.T) {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
TopicName: "path/to/my/mqtt",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
}
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
TopicName: "path/to/another/mqtt",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
}
index := New()
q := index.RetainMessage(pk)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients["client-1"])
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
// The same message already exists, but we're not doing a deep-copy check, so it's considered
// to be a new message.
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
// Delete retained
pk3 := packets.Packet{TopicName: "path/to/another/mqtt", Payload: []byte{}}
q = index.RetainMessage(pk3)
require.Equal(t, int64(-1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
// Second Delete retained
q = index.RetainMessage(pk3)
require.Equal(t, int64(0), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
}
func BenchmarkRetainMessage(b *testing.B) {
index := New()
pk := packets.Packet{TopicName: "path/to/another/mqtt"}
for n := 0; n < b.N; n++ {
index.RetainMessage(pk)
}
}
func TestSubscribeOK(t *testing.T) {
index := New()
q := index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Equal(t, false, q)
q = index.Subscribe("path/to/my/mqtt", "client-2", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/+", "client-2", 0)
require.Equal(t, true, q)
q = index.Subscribe("#", "client-3", 0)
require.Equal(t, true, q)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Equal(t, "path/to/my/mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Filter)
require.Equal(t, "mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Key)
require.Equal(t, index.Root.Leaves["path"], index.Root.Leaves["path"].Leaves["to"].Parent)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["+"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
}
func BenchmarkSubscribe(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.Subscribe("path/to/mqtt/basic", "client-1", 0)
}
}
func TestUnsubscribeA(t *testing.T) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Subscribe("path/to/+/mqtt", "client-1", 0)
index.Subscribe("path/to/stuff", "client-1", 0)
index.Subscribe("path/to/stuff", "client-2", 0)
index.Subscribe("#", "client-3", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
ok := index.Unsubscribe("path/to/my/mqtt", "client-1")
require.Equal(t, true, ok)
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
ok = index.Unsubscribe("path/to/stuff", "client-1")
require.Equal(t, true, ok)
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "client-1")
require.Equal(t, false, ok)
}
func TestUnsubscribeCascade(t *testing.T) {
index := New()
index.Subscribe("a/b/c", "client-1", 0)
index.Subscribe("a/b/c/e/e", "client-1", 0)
ok := index.Unsubscribe("a/b/c/e/e", "client-1")
require.Equal(t, true, ok)
require.NotEmpty(t, index.Root.Leaves)
require.Contains(t, index.Root.Leaves["a"].Leaves["b"].Leaves["c"].Clients, "client-1")
}
// This benchmark performs a Subscribe followed by an Unsubscribe on each iteration.
func BenchmarkUnsubscribe(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Unsubscribe("path/to/mqtt/basic", "client-1")
}
}
func TestSubscribersFind(t *testing.T) {
tt := []struct {
filter string
topic string
len int
}{
{
filter: "a",
topic: "a",
len: 1,
},
{
filter: "a/",
topic: "a",
len: 0,
},
{
filter: "a/",
topic: "a/",
len: 1,
},
{
filter: "/a",
topic: "/a",
len: 1,
},
{
filter: "path/to/my/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "path/to/+/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/to/+/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/+/+/+",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/+/+/#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "zen/#",
topic: "zen",
len: 1,
},
{
filter: "+/+/#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "path/to/",
topic: "path/to/my/mqtt",
len: 0,
},
{
filter: "#/stuff",
topic: "path/to/my/mqtt",
len: 0,
},
{
filter: "$SYS/#",
topic: "$SYS/info",
len: 1,
},
{
filter: "#",
topic: "$SYS/info",
len: 0,
},
{
filter: "+/info",
topic: "$SYS/info",
len: 0,
},
}
for i, check := range tt {
index := New()
index.Subscribe(check.filter, "client-1", 0)
clients := index.Subscribers(check.topic)
//spew.Dump(clients)
require.Equal(t, check.len, len(clients), "Unexpected clients len at %d %s %s", i, check.filter, check.topic)
}
}
func BenchmarkSubscribers(b *testing.B) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Subscribe("path/to/+/mqtt", "client-1", 0)
index.Subscribe("something/things/stuff/+", "client-1", 0)
index.Subscribe("path/to/stuff", "client-2", 0)
index.Subscribe("#", "client-3", 0)
for n := 0; n < b.N; n++ {
index.Subscribers("path/to/testing/mqtt")
}
}
func TestIsolateParticle(t *testing.T) {
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
require.Equal(t, "to", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
require.Equal(t, "my", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
require.Equal(t, "mqtt", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("/path/", 0)
require.Equal(t, "", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 1)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 2)
require.Equal(t, "", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
require.Equal(t, "+", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
require.Equal(t, "+", particle)
require.Equal(t, false, hasNext)
}
func BenchmarkIsolateParticle(b *testing.B) {
for n := 0; n < b.N; n++ {
isolateParticle("path/to/my/mqtt", 3)
}
}
func TestMessagesPattern(t *testing.T) {
tt := []struct {
packet packets.Packet
filter string
len int
}{
{
packets.Packet{TopicName: "a/b/c/d", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"a/b/c/d",
1,
},
{
packets.Packet{TopicName: "a/b/c/e", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"a/+/c/+",
2,
},
{
packets.Packet{TopicName: "a/b/d/f", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"+/+/+/+",
3,
},
{
packets.Packet{TopicName: "q/w/e/r/t/y", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"q/w/e/#",
1,
},
{
packets.Packet{TopicName: "q/w/x/r/t/x", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"q/#",
2,
},
{
packets.Packet{TopicName: "asd", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"asd",
1,
},
{
packets.Packet{TopicName: "$SYS/testing", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"#",
8,
},
{
packets.Packet{TopicName: "$SYS/test", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"+/testing",
0,
},
{
packets.Packet{TopicName: "$SYS/info", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"$SYS/info",
1,
},
{
packets.Packet{TopicName: "$SYS/b", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"$SYS/#",
4,
},
{
packets.Packet{TopicName: "asd/fgh/jkl", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"#",
8,
},
{
packets.Packet{TopicName: "stuff/asdadsa/dsfdsafdsadfsa/dsfdsf/sdsadas", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"stuff/#/things", // indexer will ignore trailing /things
1,
},
}
index := New()
for _, check := range tt {
index.RetainMessage(check.packet)
}
for i, check := range tt {
messages := index.Messages(check.filter)
require.Equal(t, check.len, len(messages), "Unexpected messages len at %d %s %s", i, check.filter, check.packet.TopicName)
}
}
func TestMessagesFind(t *testing.T) {
index := New()
index.RetainMessage(packets.Packet{TopicName: "a/a", Payload: []byte{'a'}, FixedHeader: packets.FixedHeader{Retain: true}})
index.RetainMessage(packets.Packet{TopicName: "a/b", Payload: []byte{'b'}, FixedHeader: packets.FixedHeader{Retain: true}})
messages := index.Messages("a/a")
require.Equal(t, 1, len(messages))
messages = index.Messages("a/+")
require.Equal(t, 2, len(messages))
}
func BenchmarkMessages(b *testing.B) {
index := New()
index.RetainMessage(packets.Packet{TopicName: "path/to/my/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "path/to/another/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "path/a/some/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "what/is"})
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
for n := 0; n < b.N; n++ {
index.Messages("path/to/+/mqtt")
}
}


@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This is temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}


@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}