565 lines
14 KiB
Go
565 lines
14 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/rs/xid"
|
|
|
|
"github.com/mochi-co/mqtt/server/events"
|
|
"github.com/mochi-co/mqtt/server/internal/circ"
|
|
"github.com/mochi-co/mqtt/server/internal/packets"
|
|
"github.com/mochi-co/mqtt/server/internal/topics"
|
|
"github.com/mochi-co/mqtt/server/listeners/auth"
|
|
"github.com/mochi-co/mqtt/server/system"
|
|
)
|
|
|
|
// defaultKeepalive is the keepalive period, in seconds, that a new client
// starts with until Identify replaces it with the value from its packet.
var defaultKeepalive uint16 = 10

// ErrConnectionClosed is returned when operating on a closed connection. It is
// also recorded as the stop cause when Stop is called without an error.
var ErrConnectionClosed = errors.New("connection not open")
|
|
|
|
// Clients contains a map of the clients known by the broker. The embedded
// RWMutex guards the internal map; all access goes through the methods below.
type Clients struct {
	sync.RWMutex
	internal map[string]*Client // clients known by the broker, keyed on client id.
}
|
|
|
|
// New returns an instance of Clients.
|
|
func New() *Clients {
|
|
return &Clients{
|
|
internal: make(map[string]*Client),
|
|
}
|
|
}
|
|
|
|
// Add adds a new client to the clients map, keyed on client id.
|
|
func (cl *Clients) Add(val *Client) {
|
|
cl.Lock()
|
|
cl.internal[val.ID] = val
|
|
cl.Unlock()
|
|
}
|
|
|
|
// Get returns the value of a client if it exists.
|
|
func (cl *Clients) Get(id string) (*Client, bool) {
|
|
cl.RLock()
|
|
val, ok := cl.internal[id]
|
|
cl.RUnlock()
|
|
return val, ok
|
|
}
|
|
|
|
// Len returns the length of the clients map.
|
|
func (cl *Clients) Len() int {
|
|
cl.RLock()
|
|
val := len(cl.internal)
|
|
cl.RUnlock()
|
|
return val
|
|
}
|
|
|
|
// Delete removes a client from the internal map.
|
|
func (cl *Clients) Delete(id string) {
|
|
cl.Lock()
|
|
delete(cl.internal, id)
|
|
cl.Unlock()
|
|
}
|
|
|
|
// GetByListener returns clients matching a listener id.
|
|
func (cl *Clients) GetByListener(id string) []*Client {
|
|
clients := make([]*Client, 0, cl.Len())
|
|
cl.RLock()
|
|
for _, v := range cl.internal {
|
|
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
|
|
clients = append(clients, v)
|
|
}
|
|
}
|
|
cl.RUnlock()
|
|
return clients
|
|
}
|
|
|
|
// Client contains information about a client known by the broker.
type Client struct {
	State         State                // the operational state of the client.
	LWT           LWT                  // the last will and testament for the client (set by Identify when the will flag is present).
	Inflight      *Inflight            // a map of in-flight qos messages, keyed on packet id.
	sync.RWMutex                       // used by NoteSubscription/ForgetSubscription to guard Subscriptions.
	Username      []byte               // the username the client authenticated with.
	AC            auth.Controller      // an auth controller inherited from the listener.
	Listener      string               // the id of the listener the client is connected to.
	ID            string               // the client id.
	conn          net.Conn             // the net.Conn used to establish the connection; nil for stubs from NewClientStub.
	R             *circ.Reader         // a reader for reading incoming bytes.
	W             *circ.Writer         // a writer for writing outgoing bytes.
	Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains, filter -> qos.
	systemInfo    *system.Info         // pointers to server system info (byte/message counters).
	packetID      uint32               // the most recently issued packet id; accessed atomically by NextPacketID.
	keepalive     uint16               // the keepalive in seconds; the conn deadline is set to 1.5x this value.
	CleanSession  bool                 // indicates if the client expects a clean-session.
}
|
|
|
|
// State tracks the state of the client: the lifecycle of its reader/writer
// goroutines (see Start) and its one-time shutdown (see Stop).
type State struct {
	started   *sync.WaitGroup // tracks the goroutines which have been started; Start waits on it.
	endedW    *sync.WaitGroup // tracks when the writer goroutine has ended; Stop waits on it before closing the conn.
	endedR    *sync.WaitGroup // tracks when the reader goroutine has ended; Stop waits on it after closing the conn.
	Done      uint32          // atomic flag: 1 once the client has closed (stubs start at 1), 0 otherwise.
	endOnce   sync.Once       // ensures the shutdown sequence in Stop runs only once.
	stopCause atomic.Value    // the error recorded by Stop (ErrConnectionClosed if none); read via StopCause.
}
|
|
|
|
// NewClient returns a new instance of Client.
|
|
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
|
|
cl := &Client{
|
|
conn: c,
|
|
R: r,
|
|
W: w,
|
|
systemInfo: s,
|
|
keepalive: defaultKeepalive,
|
|
Inflight: &Inflight{
|
|
internal: make(map[uint16]InflightMessage),
|
|
},
|
|
Subscriptions: make(map[string]byte),
|
|
State: State{
|
|
started: new(sync.WaitGroup),
|
|
endedW: new(sync.WaitGroup),
|
|
endedR: new(sync.WaitGroup),
|
|
},
|
|
}
|
|
|
|
cl.refreshDeadline(cl.keepalive)
|
|
|
|
return cl
|
|
}
|
|
|
|
// NewClientStub returns an instance of Client with basic initializations. This
|
|
// method is typically called by the persistence restoration system.
|
|
func NewClientStub(s *system.Info) *Client {
|
|
return &Client{
|
|
Inflight: &Inflight{
|
|
internal: make(map[uint16]InflightMessage),
|
|
},
|
|
Subscriptions: make(map[string]byte),
|
|
State: State{
|
|
Done: 1,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Identify sets the identification values of a client instance.
|
|
func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
|
|
cl.Listener = lid
|
|
cl.AC = ac
|
|
|
|
cl.ID = pk.ClientIdentifier
|
|
if cl.ID == "" {
|
|
cl.ID = xid.New().String()
|
|
}
|
|
|
|
cl.R.ID = cl.ID + " READER"
|
|
cl.W.ID = cl.ID + " WRITER"
|
|
|
|
cl.Username = pk.Username
|
|
cl.CleanSession = pk.CleanSession
|
|
cl.keepalive = pk.Keepalive
|
|
|
|
if pk.WillFlag {
|
|
cl.LWT = LWT{
|
|
Topic: pk.WillTopic,
|
|
Message: pk.WillMessage,
|
|
Qos: pk.WillQos,
|
|
Retain: pk.WillRetain,
|
|
}
|
|
}
|
|
|
|
cl.refreshDeadline(cl.keepalive)
|
|
}
|
|
|
|
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
|
func (cl *Client) refreshDeadline(keepalive uint16) {
|
|
if cl.conn != nil {
|
|
var expiry time.Time // Nil time can be used to disable deadline if keepalive = 0
|
|
if keepalive > 0 {
|
|
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
|
|
}
|
|
_ = cl.conn.SetDeadline(expiry)
|
|
}
|
|
}
|
|
|
|
// Info returns an event-version of a client, containing minimal information.
|
|
func (cl *Client) Info() events.Client {
|
|
addr := "unknown"
|
|
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
|
|
addr = cl.conn.RemoteAddr().String()
|
|
}
|
|
return events.Client{
|
|
ID: cl.ID,
|
|
Remote: addr,
|
|
Username: cl.Username,
|
|
CleanSession: cl.CleanSession,
|
|
Listener: cl.Listener,
|
|
}
|
|
}
|
|
|
|
// NextPacketID returns the next packet id for a client, looping back to 0
|
|
// if the maximum ID has been reached.
|
|
func (cl *Client) NextPacketID() uint32 {
|
|
i := atomic.LoadUint32(&cl.packetID)
|
|
if i == uint32(65535) || i == uint32(0) {
|
|
atomic.StoreUint32(&cl.packetID, 1)
|
|
return 1
|
|
}
|
|
|
|
return atomic.AddUint32(&cl.packetID, 1)
|
|
}
|
|
|
|
// NoteSubscription makes a note of a subscription for the client.
|
|
func (cl *Client) NoteSubscription(filter string, qos byte) {
|
|
cl.Lock()
|
|
cl.Subscriptions[filter] = qos
|
|
cl.Unlock()
|
|
}
|
|
|
|
// ForgetSubscription forgests a subscription note for the client.
|
|
func (cl *Client) ForgetSubscription(filter string) {
|
|
cl.Lock()
|
|
delete(cl.Subscriptions, filter)
|
|
cl.Unlock()
|
|
}
|
|
|
|
// Start begins the client goroutines reading and writing packets, and blocks
// until both goroutines have begun running:
//   - the writer goroutine passes outbound bytes from cl.W to the connection
//     (cl.W.WriteTo);
//   - the reader goroutine fills cl.R with inbound bytes from the connection
//     (cl.R.ReadFrom).
//
// When either call returns, its goroutine marks its ended WaitGroup done and
// calls Stop with the error, prefixed with "writer: " or "reader: ". Stop then
// shuts down the other side as well.
func (cl *Client) Start() {
	cl.State.started.Add(2)
	cl.State.endedW.Add(1)
	cl.State.endedR.Add(1)

	go func() {
		cl.State.started.Done()
		_, err := cl.W.WriteTo(cl.conn)
		if err != nil {
			err = fmt.Errorf("writer: %w", err)
		}
		// Mark the writer ended before calling Stop: Stop waits on endedW,
		// so calling it first from this goroutine would deadlock.
		cl.State.endedW.Done()
		cl.Stop(err)
	}()

	go func() {
		cl.State.started.Done()
		_, err := cl.R.ReadFrom(cl.conn)
		if err != nil {
			err = fmt.Errorf("reader: %w", err)
		}
		// Same ordering as the writer: Stop waits on endedR.
		cl.State.endedR.Done()
		cl.Stop(err)
	}()

	cl.State.started.Wait()
}
|
|
|
|
// ClearBuffers sets the read/write buffers to nil so they can be
|
|
// deallocated automatically when no longer in use.
|
|
func (cl *Client) ClearBuffers() {
|
|
cl.R = nil
|
|
cl.W = nil
|
|
}
|
|
|
|
// Stop instructs the client to shut down all processing goroutines and disconnect.
// A cause error may be passed to identify the reason for stopping. If err is
// nil, ErrConnectionClosed is recorded instead; the recorded cause is
// available from StopCause.
//
// Stop returns immediately once the client is marked Done. Otherwise the
// shutdown sequence runs at most once (guarded by endOnce). Concurrent
// callers block until that single run completes, and only the err of the
// call that actually ran it is recorded.
func (cl *Client) Stop(err error) {
	if atomic.LoadUint32(&cl.State.Done) == 1 {
		return
	}

	cl.State.endOnce.Do(func() {
		// Signal both buffers to stop. NOTE(review): the exact semantics live
		// in the circ package; presumably this unblocks WriteTo/ReadFrom.
		cl.R.Stop()
		cl.W.Stop()

		// Wait for the writer goroutine launched by Start to finish before
		// the connection is closed underneath it.
		cl.State.endedW.Wait()

		_ = cl.conn.Close() // omit close error

		// Closing the conn unblocks a pending read; wait for the reader
		// goroutine to exit before marking the client done.
		cl.State.endedR.Wait()
		atomic.StoreUint32(&cl.State.Done, 1)

		if err == nil {
			err = ErrConnectionClosed
		}
		cl.State.stopCause.Store(err)
	})
}
|
|
|
|
// StopCause returns the reason the client connection was stopped, if any.
|
|
func (cl *Client) StopCause() error {
|
|
if cl.State.stopCause.Load() == nil {
|
|
return nil
|
|
}
|
|
return cl.State.stopCause.Load().(error)
|
|
}
|
|
|
|
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
|
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
|
p, err := cl.R.Read(1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = fh.Decode(p[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// The remaining length value can be up to 5 bytes. Read through each byte
|
|
// looking for continue values, and if found increase the read. Otherwise
|
|
// decode the bytes that were legit.
|
|
buf := make([]byte, 0, 6)
|
|
i := 1
|
|
n := 2
|
|
for ; n < 6; n++ {
|
|
p, err = cl.R.Read(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
buf = append(buf, p[i])
|
|
|
|
// If it's not a continuation flag, end here.
|
|
if p[i] < 128 {
|
|
break
|
|
}
|
|
|
|
// If i has reached 4 without a length terminator, return a protocol violation.
|
|
i++
|
|
if i == 4 {
|
|
return packets.ErrOversizedLengthIndicator
|
|
}
|
|
}
|
|
|
|
// Calculate and store the remaining length of the packet payload.
|
|
rem, _ := binary.Uvarint(buf)
|
|
fh.Remaining = int(rem)
|
|
|
|
// Having successfully read n bytes, commit the tail forward.
|
|
cl.R.CommitTail(n)
|
|
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
|
|
|
|
return nil
|
|
}
|
|
|
|
// Read loops forever reading new packets from a client connection until
|
|
// an error is encountered (or the connection is closed).
|
|
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
|
|
for {
|
|
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
|
|
return nil
|
|
}
|
|
|
|
cl.refreshDeadline(cl.keepalive)
|
|
fh := new(packets.FixedHeader)
|
|
err := cl.ReadFixedHeader(fh)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pk, err := cl.ReadPacket(fh)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = packetHandler(cl, pk) // Process inbound packet.
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReadPacket reads the remaining buffer into an MQTT packet.
//
// It reads fh.Remaining bytes from cl.R, decodes them according to
// fh.Type, and then commits those bytes. The commit happens even when decoding
// fails, so the malformed bytes are consumed. MessagesRecv is incremented on
// every call, BytesRecv by the number of body bytes read, and PublishRecv for
// each successfully decoded PUBLISH.
//
// When fh.Remaining is 0, the packet is returned with only its fixed header
// set and no decoder runs.
// NOTE(review): a zero-length packet of a type that normally carries a body
// (e.g. Publish or Subscribe) is therefore returned without error; confirm
// that the packet handler rejects it.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
	atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)

	pk.FixedHeader = *fh
	if pk.FixedHeader.Remaining == 0 {
		return
	}

	p, err := cl.R.Read(pk.FixedHeader.Remaining)
	if err != nil {
		return pk, err
	}
	atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))

	// Decode the remaining packet values using a fresh copy of the bytes,
	// otherwise the next packet will change the data of this one.
	px := append([]byte{}, p[:]...)

	switch pk.FixedHeader.Type {
	case packets.Connect:
		err = pk.ConnectDecode(px)
	case packets.Connack:
		err = pk.ConnackDecode(px)
	case packets.Publish:
		err = pk.PublishDecode(px)
		if err == nil {
			atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
		}
	case packets.Puback:
		err = pk.PubackDecode(px)
	case packets.Pubrec:
		err = pk.PubrecDecode(px)
	case packets.Pubrel:
		err = pk.PubrelDecode(px)
	case packets.Pubcomp:
		err = pk.PubcompDecode(px)
	case packets.Subscribe:
		err = pk.SubscribeDecode(px)
	case packets.Suback:
		err = pk.SubackDecode(px)
	case packets.Unsubscribe:
		err = pk.UnsubscribeDecode(px)
	case packets.Unsuback:
		err = pk.UnsubackDecode(px)
	case packets.Pingreq:
		// No body to decode.
	case packets.Pingresp:
		// No body to decode.
	case packets.Disconnect:
		// No body to decode.
	default:
		err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
	}

	// Consume the body bytes whether or not decoding succeeded.
	cl.R.CommitTail(pk.FixedHeader.Remaining)

	return
}
|
|
|
|
// WritePacket encodes and writes a packet to the client.
//
// The packet is encoded according to pk.FixedHeader.Type, and the encoded
// bytes are appended to the client's outbound buffer cl.W under cl.W.Mu. It
// does not write to the connection directly; the writer goroutine started by
// Start does that. It returns the number of bytes written to the buffer.
//
// ErrConnectionClosed is returned once the client is Done. PublishSent is
// incremented whenever a PUBLISH encodes successfully. BytesSent and
// MessagesSent are updated, and the connection deadline is refreshed, only
// after the buffer write succeeds.
// NOTE(review): refreshing the deadline on outbound traffic means a client
// that only receives messages is never timed out by its keepalive; confirm
// this is intended.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
	if atomic.LoadUint32(&cl.State.Done) == 1 {
		return 0, ErrConnectionClosed
	}

	// Serialise writers so that encoded packets are not interleaved in cl.W.
	cl.W.Mu.Lock()
	defer cl.W.Mu.Unlock()

	buf := new(bytes.Buffer)
	switch pk.FixedHeader.Type {
	case packets.Connect:
		err = pk.ConnectEncode(buf)
	case packets.Connack:
		err = pk.ConnackEncode(buf)
	case packets.Publish:
		err = pk.PublishEncode(buf)
		if err == nil {
			atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
		}
	case packets.Puback:
		err = pk.PubackEncode(buf)
	case packets.Pubrec:
		err = pk.PubrecEncode(buf)
	case packets.Pubrel:
		err = pk.PubrelEncode(buf)
	case packets.Pubcomp:
		err = pk.PubcompEncode(buf)
	case packets.Subscribe:
		err = pk.SubscribeEncode(buf)
	case packets.Suback:
		err = pk.SubackEncode(buf)
	case packets.Unsubscribe:
		err = pk.UnsubscribeEncode(buf)
	case packets.Unsuback:
		err = pk.UnsubackEncode(buf)
	case packets.Pingreq:
		err = pk.PingreqEncode(buf)
	case packets.Pingresp:
		err = pk.PingrespEncode(buf)
	case packets.Disconnect:
		err = pk.DisconnectEncode(buf)
	default:
		err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
	}
	if err != nil {
		return
	}

	// Write the packet bytes to the client byte buffer.
	n, err = cl.W.Write(buf.Bytes())
	if err != nil {
		return
	}

	atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
	atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)

	cl.refreshDeadline(cl.keepalive)

	return
}
|
|
|
|
// LWT contains the last will and testament details for a client connection.
// Identify fills it from the client's packet when the will flag is set.
type LWT struct {
	Message []byte // the message that shall be sent when the client disconnects.
	Topic   string // the topic the will message shall be sent to.
	Qos     byte   // the quality of service desired.
	Retain  bool   // indicates whether the will message should be retained
}
|
|
|
|
// InflightMessage contains data about a packet which is currently in-flight,
// i.e. awaiting acknowledgement. Stored in an Inflight map keyed on packet id.
type InflightMessage struct {
	Packet  packets.Packet // the packet currently in-flight.
	Sent    int64          // the last time the message was sent (for retries); presumably unix seconds — confirm with callers.
	Resends int            // the number of times the message was attempted to be sent.
}
|
|
|
|
// Inflight is a map of InflightMessage keyed on packet id. The embedded
// RWMutex guards internal; access it only through the methods below.
type Inflight struct {
	sync.RWMutex
	internal map[uint16]InflightMessage // internal contains the inflight messages.
}
|
|
|
|
// Set stores the packet of an Inflight message, keyed on message id. Returns
|
|
// true if the inflight message was new.
|
|
func (i *Inflight) Set(key uint16, in InflightMessage) bool {
|
|
i.Lock()
|
|
_, ok := i.internal[key]
|
|
i.internal[key] = in
|
|
i.Unlock()
|
|
return !ok
|
|
}
|
|
|
|
// Get returns the value of an in-flight message if it exists.
|
|
func (i *Inflight) Get(key uint16) (InflightMessage, bool) {
|
|
i.RLock()
|
|
val, ok := i.internal[key]
|
|
i.RUnlock()
|
|
return val, ok
|
|
}
|
|
|
|
// Len returns the size of the in-flight messages map.
|
|
func (i *Inflight) Len() int {
|
|
i.RLock()
|
|
v := len(i.internal)
|
|
i.RUnlock()
|
|
return v
|
|
}
|
|
|
|
// GetAll returns all the in-flight messages.
|
|
func (i *Inflight) GetAll() map[uint16]InflightMessage {
|
|
i.RLock()
|
|
defer i.RUnlock()
|
|
return i.internal
|
|
}
|
|
|
|
// Delete removes an in-flight message from the map. Returns true if the
|
|
// message existed.
|
|
func (i *Inflight) Delete(key uint16) bool {
|
|
i.Lock()
|
|
_, ok := i.internal[key]
|
|
delete(i.internal, key)
|
|
i.Unlock()
|
|
return ok
|
|
}
|