clone
This commit is contained in:
116
server/internal/packets/codec.go
Normal file
116
server/internal/packets/codec.go
Normal file
@ -0,0 +1,116 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// bytesToString provides a zero-alloc, no-copy byte to string conversion.
|
||||
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
|
||||
func bytesToString(bs []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&bs))
|
||||
}
|
||||
|
||||
// decodeUint16 extracts the value of two bytes from a byte array.
|
||||
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
|
||||
if len(buf) < offset+2 {
|
||||
return 0, 0, ErrOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
|
||||
}
|
||||
|
||||
// decodeString extracts a string from a byte array, beginning at an offset.
|
||||
func decodeString(buf []byte, offset int) (string, int, error) {
|
||||
b, n, err := decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if !validUTF8(b) {
|
||||
return "", 0, ErrOffsetStrInvalidUTF8
|
||||
}
|
||||
|
||||
return bytesToString(b), n, nil
|
||||
}
|
||||
|
||||
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
|
||||
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
|
||||
length, next, err := decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return make([]byte, 0, 0), 0, err
|
||||
}
|
||||
|
||||
if next+int(length) > len(buf) {
|
||||
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
|
||||
}
|
||||
|
||||
// Note: there is no validUTF8() test for []byte payloads
|
||||
|
||||
return buf[next : next+int(length)], next + int(length), nil
|
||||
}
|
||||
|
||||
// decodeByte extracts the value of a byte from a byte array.
|
||||
func decodeByte(buf []byte, offset int) (byte, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return 0, 0, ErrOffsetByteOutOfRange
|
||||
}
|
||||
return buf[offset], offset + 1, nil
|
||||
}
|
||||
|
||||
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
|
||||
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return false, 0, ErrOffsetBoolOutOfRange
|
||||
}
|
||||
return 1&buf[offset] > 0, offset + 1, nil
|
||||
}
|
||||
|
||||
// encodeBool returns a byte instead of a bool.
|
||||
func encodeBool(b bool) byte {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
|
||||
func encodeBytes(val []byte) []byte {
|
||||
// In many circumstances the number of bytes being encoded is small.
|
||||
// Setting the cap to a low amount allows us to account for those without
|
||||
// triggering allocation growth on append unless we need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, val...)
|
||||
}
|
||||
|
||||
// encodeUint16 encodes a uint16 value to a byte array.
|
||||
func encodeUint16(val uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeString encodes a string to a byte array.
|
||||
func encodeString(val string) []byte {
|
||||
// Like encodeBytes, we set the cap to a small number to avoid
|
||||
// triggering allocation growth on append unless we absolutely need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, []byte(val)...)
|
||||
}
|
||||
|
||||
// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically
|
||||
// conforming to the MQTT specification requirements.
|
||||
func validUTF8(b []byte) bool {
|
||||
// [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8...
|
||||
if !utf8.Valid(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
// [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000...
|
||||
// ...
|
||||
return true
|
||||
|
||||
}
|
386
server/internal/packets/codec_test.go
Normal file
386
server/internal/packets/codec_test.go
Normal file
@ -0,0 +1,386 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
b := []byte{'a', 'b', 'c'}
|
||||
require.Equal(t, "abc", bytesToString(b))
|
||||
}
|
||||
|
||||
func BenchmarkBytesToString(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
bytesToString([]byte{'a', 'b', 'c'})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: "a/b/c/d",
|
||||
},
|
||||
{
|
||||
offset: 14,
|
||||
rawBytes: []byte{
|
||||
byte(Connect << 4), 17, // Fixed header
|
||||
0, 6, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
|
||||
3, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'h', 'e', 'y', // Client ID "zen"},
|
||||
},
|
||||
result: "hey",
|
||||
},
|
||||
{
|
||||
offset: 2,
|
||||
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
|
||||
result: "1/2/3/4/a/b/c/d/e/^/@/!",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
|
||||
result: "x/y/z",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 5,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 9,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 17,
|
||||
rawBytes: []byte{
|
||||
byte(Connect << 4), 0, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 6, // Will Topic - MSB+LSB
|
||||
'l',
|
||||
},
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrOffsetStrInvalidUTF8,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeString(b *testing.B) {
|
||||
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeString(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBytes(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result []uint8
|
||||
next int
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 0,
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 8,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeBytes(b *testing.B) {
|
||||
in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeBytes(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByte(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint8
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
|
||||
result: uint8(0x00),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x04),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x4d),
|
||||
offset: 2,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x51),
|
||||
offset: 3,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 80, 82, 84},
|
||||
offset: 8,
|
||||
shouldFail: ErrOffsetByteOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeByte(b *testing.B) {
|
||||
in := []byte{0, 4, 77, 81, 84, 84}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeByte(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint16(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint16
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x07),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x761),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+2, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeUint16(b *testing.B) {
|
||||
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeUint16(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByteBool(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result bool
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0x00, 0x00},
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
offset: 5,
|
||||
shouldFail: ErrOffsetBoolOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, 1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeByteBool(b *testing.B) {
|
||||
in := []byte{0x00, 0x00}
|
||||
for n := 0; n < b.N; n++ {
|
||||
decodeByteBool(in, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeBool(t *testing.T) {
|
||||
result := encodeBool(true)
|
||||
require.Equal(t, byte(1), result, "Incorrect encoded value; not true")
|
||||
|
||||
result = encodeBool(false)
|
||||
require.Equal(t, byte(0), result, "Incorrect encoded value; not false")
|
||||
|
||||
// Check failure.
|
||||
result = encodeBool(false)
|
||||
require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value")
|
||||
}
|
||||
|
||||
func BenchmarkEncodeBool(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeBool(true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeBytes(t *testing.T) {
|
||||
result := encodeBytes([]byte("testing"))
|
||||
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value")
|
||||
|
||||
result = encodeBytes([]byte("testing"))
|
||||
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value")
|
||||
}
|
||||
|
||||
func BenchmarkEncodeBytes(b *testing.B) {
|
||||
bb := []byte("testing")
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeBytes(bb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeUint16(t *testing.T) {
|
||||
result := encodeUint16(0)
|
||||
require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0")
|
||||
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767")
|
||||
|
||||
result = encodeUint16(65535)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535")
|
||||
}
|
||||
|
||||
func BenchmarkEncodeUint16(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeUint16(32767)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeString(t *testing.T) {
|
||||
result := encodeString("testing")
|
||||
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing")
|
||||
|
||||
result = encodeString("")
|
||||
require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null")
|
||||
|
||||
result = encodeString("a")
|
||||
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a")
|
||||
|
||||
result = encodeString("b")
|
||||
require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b")
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkEncodeString(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeString("benchmarking")
|
||||
}
|
||||
}
|
59
server/internal/packets/fixedheader.go
Normal file
59
server/internal/packets/fixedheader.go
Normal file
@ -0,0 +1,59 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
|
||||
type FixedHeader struct {
|
||||
Remaining int // the number of remaining bytes in the payload.
|
||||
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte // indicates the quality of service expected.
|
||||
Dup bool // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// Decode extracts the specification bits from the header byte.
|
||||
func (fh *FixedHeader) Decode(headerByte byte) error {
|
||||
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
|
||||
|
||||
switch fh.Type {
|
||||
case Publish:
|
||||
fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate.
|
||||
fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag.
|
||||
fh.Retain = headerByte&0x01 > 0 // Extract retain flag.
|
||||
case Pubrel:
|
||||
fh.Qos = (headerByte >> 1) & 0x03
|
||||
case Subscribe:
|
||||
fh.Qos = (headerByte >> 1) & 0x03
|
||||
case Unsubscribe:
|
||||
fh.Qos = (headerByte >> 1) & 0x03
|
||||
default:
|
||||
if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 {
|
||||
return ErrInvalidFlags
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(buf *bytes.Buffer, length int64) {
|
||||
for {
|
||||
digit := byte(length % 128)
|
||||
length /= 128
|
||||
if length > 0 {
|
||||
digit |= 0x80
|
||||
}
|
||||
buf.WriteByte(digit)
|
||||
if length == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
220
server/internal/packets/fixedheader_test.go
Normal file
220
server/internal/packets/fixedheader_test.go
Normal file
@ -0,0 +1,220 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fixedHeaderTable struct {
|
||||
rawBytes []byte
|
||||
header FixedHeader
|
||||
packetError bool
|
||||
flagError bool
|
||||
}
|
||||
|
||||
var fixedHeaderExpected = []fixedHeaderTable{
|
||||
{
|
||||
rawBytes: []byte{Connect << 4, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connack << 4, 0x00},
|
||||
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Puback << 4, 0x00},
|
||||
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Suback << 4, 0x00},
|
||||
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
||||
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
||||
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
|
||||
// remaining length
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x0a},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
|
||||
packetError: true,
|
||||
},
|
||||
|
||||
// Invalid flags for packet
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
}
|
||||
|
||||
func TestFixedHeaderEncode(t *testing.T) {
|
||||
for i, wanted := range fixedHeaderExpected {
|
||||
buf := new(bytes.Buffer)
|
||||
wanted.header.Encode(buf)
|
||||
if wanted.flagError == false {
|
||||
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes)
|
||||
require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFixedHeaderEncode(b *testing.B) {
|
||||
buf := new(bytes.Buffer)
|
||||
for n := 0; n < b.N; n++ {
|
||||
fixedHeaderExpected[0].header.Encode(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixedHeaderDecode(t *testing.T) {
|
||||
for i, wanted := range fixedHeaderExpected {
|
||||
fh := new(FixedHeader)
|
||||
err := fh.Decode(wanted.rawBytes[0])
|
||||
if wanted.flagError {
|
||||
require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
|
||||
} else {
|
||||
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
|
||||
require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFixedHeaderDecode(b *testing.B) {
|
||||
fh := new(FixedHeader)
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := fh.Decode(fixedHeaderExpected[0].rawBytes[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
tt := []struct {
|
||||
have int64
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
120,
|
||||
[]byte{0x78},
|
||||
},
|
||||
{
|
||||
math.MaxInt64,
|
||||
[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range tt {
|
||||
buf := new(bytes.Buffer)
|
||||
encodeLength(buf, wanted.have)
|
||||
require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeLength(b *testing.B) {
|
||||
buf := new(bytes.Buffer)
|
||||
for n := 0; n < b.N; n++ {
|
||||
encodeLength(buf, 120)
|
||||
}
|
||||
}
|
670
server/internal/packets/packets.go
Normal file
670
server/internal/packets/packets.go
Normal file
@ -0,0 +1,670 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// All of the valid packet types and their packet identifier.
|
||||
const (
|
||||
Reserved byte = iota
|
||||
Connect // 1
|
||||
Connack // 2
|
||||
Publish // 3
|
||||
Puback // 4
|
||||
Pubrec // 5
|
||||
Pubrel // 6
|
||||
Pubcomp // 7
|
||||
Subscribe // 8
|
||||
Suback // 9
|
||||
Unsubscribe // 10
|
||||
Unsuback // 11
|
||||
Pingreq // 12
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
|
||||
Accepted byte = 0x00
|
||||
Failed byte = 0xFF
|
||||
CodeConnectBadProtocolVersion byte = 0x01
|
||||
CodeConnectBadClientID byte = 0x02
|
||||
CodeConnectServerUnavailable byte = 0x03
|
||||
CodeConnectBadAuthValues byte = 0x04
|
||||
CodeConnectNotAuthorised byte = 0x05
|
||||
CodeConnectNetworkError byte = 0xFE
|
||||
CodeConnectProtocolViolation byte = 0xFF
|
||||
ErrSubAckNetworkError byte = 0x80
|
||||
)
|
||||
|
||||
var (
|
||||
// CONNECT
|
||||
ErrMalformedProtocolName = errors.New("malformed packet: protocol name")
|
||||
ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version")
|
||||
ErrMalformedFlags = errors.New("malformed packet: flags")
|
||||
ErrMalformedKeepalive = errors.New("malformed packet: keepalive")
|
||||
ErrMalformedClientID = errors.New("malformed packet: client id")
|
||||
ErrMalformedWillTopic = errors.New("malformed packet: will topic")
|
||||
ErrMalformedWillMessage = errors.New("malformed packet: will message")
|
||||
ErrMalformedUsername = errors.New("malformed packet: username")
|
||||
ErrMalformedPassword = errors.New("malformed packet: password")
|
||||
|
||||
// CONNACK
|
||||
ErrMalformedSessionPresent = errors.New("malformed packet: session present")
|
||||
ErrMalformedReturnCode = errors.New("malformed packet: return code")
|
||||
|
||||
// PUBLISH
|
||||
ErrMalformedTopic = errors.New("malformed packet: topic name")
|
||||
ErrMalformedPacketID = errors.New("malformed packet: packet id")
|
||||
|
||||
// SUBSCRIBE
|
||||
ErrMalformedQoS = errors.New("malformed packet: qos")
|
||||
|
||||
// PACKETS
|
||||
ErrProtocolViolation = errors.New("protocol violation")
|
||||
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
|
||||
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
|
||||
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
|
||||
ErrOffsetUintOutOfRange = errors.New("offset uint out of range")
|
||||
ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8")
|
||||
ErrInvalidFlags = errors.New("invalid flags set for packet")
|
||||
ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator")
|
||||
ErrMissingPacketID = errors.New("missing packet id")
|
||||
ErrSurplusPacketID = errors.New("surplus packet id")
|
||||
)
|
||||
|
||||
// Packet is an MQTT packet. Instead of providing a packet interface and variant
|
||||
// packet structs, this is a single concrete packet type to cover all packet
|
||||
// types, which allows us to take advantage of various compiler optimizations.
|
||||
type Packet struct {
|
||||
FixedHeader FixedHeader
|
||||
AllowClients []string // For use with OnMessage event hook.
|
||||
Topics []string
|
||||
ReturnCodes []byte
|
||||
ProtocolName []byte
|
||||
Qoss []byte
|
||||
Payload []byte
|
||||
Username []byte
|
||||
Password []byte
|
||||
WillMessage []byte
|
||||
ClientIdentifier string
|
||||
TopicName string
|
||||
WillTopic string
|
||||
PacketID uint16
|
||||
Keepalive uint16
|
||||
ReturnCode byte
|
||||
ProtocolVersion byte
|
||||
WillQos byte
|
||||
ReservedBit byte
|
||||
CleanSession bool
|
||||
WillFlag bool
|
||||
WillRetain bool
|
||||
UsernameFlag bool
|
||||
PasswordFlag bool
|
||||
SessionPresent bool
|
||||
}
|
||||
|
||||
// ConnectEncode encodes a connect packet.
|
||||
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
protoName := encodeBytes(pk.ProtocolName)
|
||||
protoVersion := pk.ProtocolVersion
|
||||
flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7
|
||||
keepalive := encodeUint16(pk.Keepalive)
|
||||
clientID := encodeString(pk.ClientIdentifier)
|
||||
|
||||
var willTopic, willFlag, usernameFlag, passwordFlag []byte
|
||||
|
||||
// If will flag is set, add topic and message.
|
||||
if pk.WillFlag {
|
||||
willTopic = encodeString(pk.WillTopic)
|
||||
willFlag = encodeBytes(pk.WillMessage)
|
||||
}
|
||||
|
||||
// If username flag is set, add username.
|
||||
if pk.UsernameFlag {
|
||||
usernameFlag = encodeBytes(pk.Username)
|
||||
}
|
||||
|
||||
// If password flag is set, add password.
|
||||
if pk.PasswordFlag {
|
||||
passwordFlag = encodeBytes(pk.Password)
|
||||
}
|
||||
|
||||
// Get a length for the connect header. This is not super pretty, but it works.
|
||||
pk.FixedHeader.Remaining =
|
||||
len(protoName) + 1 + 1 + len(keepalive) + len(clientID) +
|
||||
len(willTopic) + len(willFlag) +
|
||||
len(usernameFlag) + len(passwordFlag)
|
||||
|
||||
pk.FixedHeader.Encode(buf)
|
||||
|
||||
// Eschew magic for readability.
|
||||
buf.Write(protoName)
|
||||
buf.WriteByte(protoVersion)
|
||||
buf.WriteByte(flag)
|
||||
buf.Write(keepalive)
|
||||
buf.Write(clientID)
|
||||
buf.Write(willTopic)
|
||||
buf.Write(willFlag)
|
||||
buf.Write(usernameFlag)
|
||||
buf.Write(passwordFlag)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectDecode decodes a connect packet.
|
||||
func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Unpack protocol name and version.
|
||||
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
|
||||
}
|
||||
|
||||
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
|
||||
}
|
||||
// Unpack flags byte.
|
||||
flags, offset, err := decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
|
||||
}
|
||||
pk.ReservedBit = 1 & flags
|
||||
pk.CleanSession = 1&(flags>>1) > 0
|
||||
pk.WillFlag = 1&(flags>>2) > 0
|
||||
pk.WillQos = 3 & (flags >> 3) // this one is not a bool
|
||||
pk.WillRetain = 1&(flags>>5) > 0
|
||||
pk.PasswordFlag = 1&(flags>>6) > 0
|
||||
pk.UsernameFlag = 1&(flags>>7) > 0
|
||||
|
||||
// Get keepalive interval.
|
||||
pk.Keepalive, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
|
||||
}
|
||||
|
||||
// Get client ID.
|
||||
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
|
||||
}
|
||||
|
||||
// Get Last Will and Testament topic and message if applicable.
|
||||
if pk.WillFlag {
|
||||
pk.WillTopic, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
|
||||
}
|
||||
|
||||
pk.WillMessage, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// Get username and password if applicable.
|
||||
if pk.UsernameFlag {
|
||||
pk.Username, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
|
||||
}
|
||||
}
|
||||
|
||||
if pk.PasswordFlag {
|
||||
pk.Password, _, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// ConnectValidate ensures the connect packet is compliant.
|
||||
func (pk *Packet) ConnectValidate() (b byte, err error) {
|
||||
|
||||
// End if protocol name is bad.
|
||||
if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 &&
|
||||
bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if protocol version is bad.
|
||||
if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) ||
|
||||
(bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) {
|
||||
return CodeConnectBadProtocolVersion, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if reserved bit is not 0.
|
||||
if pk.ReservedBit != 0 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if ClientID is too long.
|
||||
if len(pk.ClientIdentifier) > 65535 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if password flag is set without a username.
|
||||
if pk.PasswordFlag && !pk.UsernameFlag {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if Username or Password is too long.
|
||||
if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
|
||||
return CodeConnectProtocolViolation, ErrProtocolViolation
|
||||
}
|
||||
|
||||
// End if client id isn't set and clean session is false.
|
||||
if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
|
||||
return CodeConnectBadClientID, ErrProtocolViolation
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// ConnackEncode encodes a Connack packet.
|
||||
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.WriteByte(encodeBool(pk.SessionPresent))
|
||||
buf.WriteByte(pk.ReturnCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnackDecode decodes a Connack packet.
|
||||
func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
|
||||
}
|
||||
|
||||
pk.ReturnCode, _, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisconnectEncode encodes a Disconnect packet.
|
||||
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Encode(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PingreqEncode encodes a Pingreq packet.
|
||||
func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Encode(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PingrespEncode encodes a Pingresp packet.
|
||||
func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Encode(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubackEncode encodes a Puback packet.
|
||||
func (pk *Packet) PubackEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubackDecode decodes a Puback packet.
|
||||
func (pk *Packet) PubackDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubcompEncode encodes a Pubcomp packet.
|
||||
func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubcompDecode decodes a Pubcomp packet.
|
||||
func (pk *Packet) PubcompDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishEncode encodes a Publish packet.
|
||||
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
topicName := encodeString(pk.TopicName)
|
||||
var packetID []byte
|
||||
|
||||
// Add PacketID if QOS is set.
|
||||
// [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
|
||||
if pk.FixedHeader.Qos > 0 {
|
||||
|
||||
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.PacketID == 0 {
|
||||
return ErrMissingPacketID
|
||||
}
|
||||
|
||||
packetID = encodeUint16(pk.PacketID)
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload)
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(topicName)
|
||||
buf.Write(packetID)
|
||||
buf.Write(pk.Payload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishDecode extracts the data values from the packet.
|
||||
func (pk *Packet) PublishDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
pk.TopicName, offset, err = decodeString(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
|
||||
// If QOS decode Packet ID.
|
||||
if pk.FixedHeader.Qos > 0 {
|
||||
pk.PacketID, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
}
|
||||
|
||||
pk.Payload = buf[offset:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishCopy creates a new instance of Publish packet bearing the
|
||||
// same payload and destination topic, but with an empty header for
|
||||
// inheriting new QoS flags, etc.
|
||||
func (pk *Packet) PublishCopy() Packet {
|
||||
return Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Retain: pk.FixedHeader.Retain,
|
||||
},
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
}
|
||||
}
|
||||
|
||||
// PublishValidate validates a publish packet.
|
||||
func (pk *Packet) PublishValidate() (byte, error) {
|
||||
|
||||
// @SPEC [MQTT-2.3.1-1]
|
||||
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
|
||||
return Failed, ErrMissingPacketID
|
||||
}
|
||||
|
||||
// @SPEC [MQTT-2.3.1-5]
|
||||
// A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
|
||||
if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 {
|
||||
return Failed, ErrSurplusPacketID
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// PubrecEncode encodes a Pubrec packet.
|
||||
func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubrecDecode decodes a Pubrec packet.
|
||||
func (pk *Packet) PubrecDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubrelEncode encodes a Pubrel packet.
|
||||
func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// PubrelDecode decodes a Pubrel packet.
|
||||
func (pk *Packet) PubrelDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubackEncode encodes a Suback packet.
|
||||
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
packetID := encodeUint16(pk.PacketID)
|
||||
pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length.
|
||||
pk.FixedHeader.Encode(buf)
|
||||
|
||||
buf.Write(packetID) // Encode Packet ID.
|
||||
buf.Write(pk.ReturnCodes) // Encode granted QOS flags.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubackDecode decodes a Suback packet.
|
||||
func (pk *Packet) SubackDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Get Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Get Granted QOS flags.
|
||||
pk.ReturnCodes = buf[offset:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeEncode encodes a Subscribe packet.
|
||||
func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
// Add the Packet ID.
|
||||
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.PacketID == 0 {
|
||||
return ErrMissingPacketID
|
||||
}
|
||||
|
||||
packetID := encodeUint16(pk.PacketID)
|
||||
|
||||
// Count topics lengths and associated QOS flags.
|
||||
var topicsLen int
|
||||
for _, topic := range pk.Topics {
|
||||
topicsLen += len(encodeString(topic)) + 1
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = len(packetID) + topicsLen
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(packetID)
|
||||
|
||||
// Add all provided topic names and associated QOS flags.
|
||||
for i, topic := range pk.Topics {
|
||||
buf.Write(encodeString(topic))
|
||||
buf.WriteByte(pk.Qoss[i])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeDecode decodes a Subscribe packet.
|
||||
func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Get the Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Keep decoding until there's no space left.
|
||||
for offset < len(buf) {
|
||||
|
||||
// Decode Topic Name.
|
||||
var topic string
|
||||
topic, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
pk.Topics = append(pk.Topics, topic)
|
||||
|
||||
// Decode QOS flag.
|
||||
var qos byte
|
||||
qos, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
|
||||
}
|
||||
|
||||
// Ensure QoS byte is within range.
|
||||
if !(qos >= 0 && qos <= 2) {
|
||||
//if !validateQoS(qos) {
|
||||
return ErrMalformedQoS
|
||||
}
|
||||
|
||||
pk.Qoss = append(pk.Qoss, qos)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeValidate ensures the packet is compliant.
|
||||
func (pk *Packet) SubscribeValidate() (byte, error) {
|
||||
// @SPEC [MQTT-2.3.1-1].
|
||||
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
|
||||
return Failed, ErrMissingPacketID
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// UnsubackEncode encodes an Unsuback packet.
|
||||
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
pk.FixedHeader.Remaining = 2
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(encodeUint16(pk.PacketID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubackDecode decodes an Unsuback packet.
|
||||
func (pk *Packet) UnsubackDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubscribeEncode encodes an Unsubscribe packet.
|
||||
func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
// Add the Packet ID.
|
||||
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.PacketID == 0 {
|
||||
return ErrMissingPacketID
|
||||
}
|
||||
|
||||
packetID := encodeUint16(pk.PacketID)
|
||||
|
||||
// Count topics lengths.
|
||||
var topicsLen int
|
||||
for _, topic := range pk.Topics {
|
||||
topicsLen += len(encodeString(topic))
|
||||
}
|
||||
|
||||
pk.FixedHeader.Remaining = len(packetID) + topicsLen
|
||||
pk.FixedHeader.Encode(buf)
|
||||
buf.Write(packetID)
|
||||
|
||||
// Add all provided topic names.
|
||||
for _, topic := range pk.Topics {
|
||||
buf.Write(encodeString(topic))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubscribeDecode decodes an Unsubscribe packet.
|
||||
func (pk *Packet) UnsubscribeDecode(buf []byte) error {
|
||||
var offset int
|
||||
var err error
|
||||
|
||||
// Get the Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Keep decoding until there's no space left.
|
||||
for offset < len(buf) {
|
||||
var t string
|
||||
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
|
||||
if len(t) > 0 {
|
||||
pk.Topics = append(pk.Topics, t)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// UnsubscribeValidate validates an Unsubscribe packet.
|
||||
func (pk *Packet) UnsubscribeValidate() (byte, error) {
|
||||
// @SPEC [MQTT-2.3.1-1].
|
||||
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
|
||||
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
|
||||
return Failed, ErrMissingPacketID
|
||||
}
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// FormatID returns the PacketID field as a decimal integer.
|
||||
func (pk *Packet) FormatID() string {
|
||||
return strconv.FormatUint(uint64(pk.PacketID), 10)
|
||||
}
|
1416
server/internal/packets/packets_tables_test.go
Normal file
1416
server/internal/packets/packets_tables_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1091
server/internal/packets/packets_test.go
Normal file
1091
server/internal/packets/packets_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user