This commit is contained in:
Кобелев Андрей Андреевич 2022-08-15 23:06:20 +03:00
commit 4ea1a3802b
532 changed files with 211522 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
cmd/mqtt
.DS_Store
server/persistence/bolt/testbolt.db
.vscode
cert

31
Dockerfile Normal file
View File

@ -0,0 +1,31 @@
FROM golang:1.18.0-alpine3.15 AS builder
RUN apk update
RUN apk add git
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . ./
RUN go build -o /app/mochi ./cmd
FROM alpine
WORKDIR /
COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ]

22
LICENSE.md Normal file
View File

@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2019 Jonathan Blake (mochi)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

323
README.md Normal file
View File

@ -0,0 +1,323 @@
<p align="center">
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
</p>
# Mochi MQTT
### A High-performance MQTT server in Go (v3.0 | v3.1.1)
Mochi MQTT is an embeddable high-performance MQTT broker server written in Go, and compliant with the MQTT v3.0 and v3.1.1 specification for the development of IoT and smarthome projects. The server can be used either as a standalone binary or embedded as a library in your own projects. Mochi MQTT message throughput is comparable with everyone's favourites such as Mosquitto, Mosca, and VerneMQ.
> #### 📦 💬 See Github Discussions for discussions about releases
> Ongoing discussion about current and future releases can be found at https://github.com/mochi-co/mqtt/discussions
#### What is MQTT?
MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks. [Learn more](https://mqtt.org/faq)
#### Mochi MQTT Features
- Paho MQTT 3.0 / 3.1.1 compatible.
- Full MQTT Feature-set (QoS, Retained, $SYS)
- Trie-based Subscription model.
- Ring Buffer packet codec.
- TCP, Websocket, (including SSL/TLS) and Dashboard listeners.
- Interfaces for Client Authentication and Topic access control.
- Bolt persistence and storage interfaces (see examples folder).
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (`OnMessage`, `onSubscribe`, `onUnsubscribe`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- ARM32 Compatible.
#### Roadmap
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker from Go
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. A docker image is coming soon.
```
cd cmd
go build -o mqtt && ./mqtt
```
#### Using Docker
A simple Dockerfile is provided for running the `cmd/main.go` Websocket, TCP, and Stats server:
```sh
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```
#### Package Quick Start
``` go
import (
mqtt "github.com/mochi-co/mqtt/server"
)
func main() {
// Create the new MQTT Server.
server := mqtt.NewServer(nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", ":1883")
// Add the listener to the server with default options (nil).
err := server.AddListener(tcp, nil)
if err != nil {
log.Fatal(err)
}
// Start the broker. Serve() is blocking - see examples folder
// for usage ideas.
err = server.Serve()
if err != nil {
log.Fatal(err)
}
}
```
Examples of running the broker with various configurations can be found in the `examples` folder.
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(id, address string)` - A TCP Listener, taking a unique ID and a network address to bind.
- `listeners.NewWebsocket(id, address string)` A Websocket Listener
- `listeners.NewHTTPStats()` An HTTP $SYS info dashboard
##### Configuring Network Listeners
When a listener is added to the server using `server.AddListener`, a `*listeners.Config` may be passed as the second argument.
##### Authentication and ACL
Authentication and ACL may be configured on a per-listener basis by providing an Auth Controller to the listener configuration. Custom Auth Controllers should satisfy the `auth.Controller` interface found in `listeners/auth`. Two default controllers are provided, `auth.Allow`, which allows all traffic, and `auth.Disallow`, which denies all traffic.
```go
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
```
> If no auth controller is provided in the listener configuration, the server will default to _Disallowing_ all traffic to prevent unintentional security issues.
##### SSL
SSL may be configured on both the TCP and Websocket listeners by providing a public-private PEM key pair to the listener configuration as `[]byte` slices.
```go
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: publicCertificate,
PrivateKey: privateKey,
},
})
```
> Note the mandatory inclusion of the Auth Controller!
#### Event Hooks
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
##### OnConnect
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
```
##### OnDisconnect
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
```
##### OnSubscribe
`server.Events.OnSubscribe` is called when a client subscribes to a new topic filter.
```go
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
}
```
##### OnUnsubscribe
`server.Events.OnUnsubscribe` is called when a client unsubscribes from a topic filter.
```go
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
}
```
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet (message) is received. The method receives the published message and information about the client who published it.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
##### OnProcessMessage
`server.Events.OnProcessMessage` is called before a publish packet (message) is processed. Specifically, the method callback is triggered after topic and ACL validation has occurred, but before the headers and payload are processed. You can use this if you want to programmatically change the data of the packet, such as setting it to retain, or altering the QoS flag.
If an error is returned, the packet will not be modified. and the existing packet will be used. If this is an unwanted outcome, the `mqtt.ErrRejectPacket` error can be returned from the callback, and the packet will be dropped/ignored, any further processing is abandoned.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
if string(pk.Payload) == "hello" {
pkx = pk
pkx.Payload = []byte("hello world")
return pkx, nil
}
return pk, nil
}
```
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
##### OnError
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.
##### OnStorage
`server.Events.OnStorage` is like `onError`, but receives the output of persistent storage methods.
#### Server Options
A few options can be passed to the `mqtt.NewServer(opts *Options)` function in order to override the default broker configuration. Currently these options are:
- BufferSize (default 1024 * 256 bytes) - The default value is sufficient for most messaging sizes, but if you are sending many kilobytes of data (such as images), you should increase this to a value of (n*s) where is the typical size of your message and n is the number of messages you may have backlogged for a client at any given time.
- BufferBlockSize (default 1024 * 8) - The minimum size in which R/W data will be allocated. If you are expecting only tiny or large payloads, you can alter this accordingly.
Any options which is not set or is `0` will use default values.
```go
opts := &mqtt.Options{
BufferSize: 512 * 1024,
BufferBlockSize: 16 * 1024,
}
s := mqtt.NewServer(opts)
```
> See `examples/tcp/main.go` for an example implementation.
#### Direct Publishing
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.
```go
// func (s *Server) Publish(topic string, payload []byte, retain bool) error
err := s.Publish("a/b/c", []byte("hello"), false)
if err != nil {
log.Fatal(err)
}
```
A working example can be found in the `examples/events` folder.
#### Data Persistence
Mochi MQTT provides a `persistence.Store` interface for developing and attaching persistent stores to the broker. The default persistence mechanism packaged with the broker is backed by [Bolt](https://github.com/etcd-io/bbolt) and can be enabled by assigning a `*bolt.Store` to the server.
```go
// import "github.com/mochi-co/mqtt/server/persistence/bolt"
err = server.AddStore(bolt.New("mochi.db", nil))
if err != nil {
log.Fatal(err)
}
```
> Persistence is on-demand (not flushed) and will potentially reduce throughput when compared to the standard in-memory store. Only use it if you need to maintain state through restarts.
#### Paho Interoperability Test
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the test with `python3 client_test.py` from the _interoperability_ folder.
#### Performance at v1.0.0
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a 13-inch, Early 2015 Macbook Pro (2.7 GHz Intel Core i5). Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. SEND = Publish throughput, RECV = Subscribe throughput.
> As usual, any performance benchmarks should be taken with a pinch of salt, but are shown to demonstrate typical throughput compared to the other leading MQTT brokers.
**Single Client, 10,000 messages**
_With only 1 client, there is no variation in throughput so the benchmark is reports the same number for high, low, and median._
![1 Client, 10,000 Messages](assets/benchmarkchart_1_10000.png "1 Client, 10,000 Messages")
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=1 -num-messages=10000`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND Max | 36505 | 30597 | 27202 | 32782 | 30125 |
| SEND Min | 36505 | 30597 | 27202 | 32782 | 30125 |
| SEND Median | 36505 | 30597 | 27202 |32782 | 30125 |
| RECV Max | 152221 | 59130 | 7879 | 17551 | 9145 |
| RECV Min | 152221 | 59130 | 7879 | 17551 | 9145 |
| RECV Median | 152221 | 59130 | 7879 | 17551 | 9145 |
**10 Clients, 1,000 Messages**
![10 Clients, 1,000 Messages](assets/benchmarkchart_10_1000.png "10 Clients, 1,000 Messages")
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=1000`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND Max | 37193 | 15775 | 17455 | 34138 | 36575 |
| SEND Min | 6529 | 6446 | 7714 | 8583 | 7383 |
| SEND Median | 15127 | 7813 | 10305 | 9887 | 8169 |
| RECV Max | 33535 | 3710 | 3022 | 4534 | 9411 |
| RECV Min | 7484 | 2661 | 1689 | 2021 | 2275 |
| RECV Median | 11427 | 3142 | 1831 | 2468 | 4692 |
**10 Clients, 10,000 Messages**
![10 Clients, 10000 Messages](assets/benchmarkchart_10_10000.png "10 Clients, 10000 Messages")
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND Max | 13153 | 13270 | 12229 | 13025 | 38446 |
| SEND Min | 8728 | 8513 | 8193 | 6483 | 3889 |
| SEND Median | 9045 | 9532 | 9252 | 8031 | 9210 |
| RECV Max | 20774 | 5052 | 2093 | 2071 | 43008 |
| RECV Min | 10718 |3995 | 1531 | 1673 | 18764 |
| RECV Median | 16339 | 4607 | 1620 | 1907 | 33524 |
**500 Clients, 100 Messages**
![500 Clients, 100 Messages](assets/benchmarkchart_500_100.png "500 Clients, 100 Messages")
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=500 -num-messages=100`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND Max | 70688 | 72686 | 71392 | 75336 | 73192 |
| SEND Min | 1021 | 2577 | 1603 | 8417 | 2344 |
| SEND Median | 49871 | 33076 | 33637 | 35200 | 31312 |
| RECV Max | 116163 | 4215 | 3427 | 5484 | 10100 |
| RECV Min | 1044 | 156 | 56 | 83 | 169 |
| RECV Median | 24398 | 208 | 94 | 413 | 474 |
## Contributions
Contributions and feedback are both welcomed and encouraged! Open an [issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

79
cmd/main.go Normal file
View File

@ -0,0 +1,79 @@
package main
import (
"crypto/tls"
"flag"
"fmt"
"log"
"os"
"os/signal"
"syscall"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
wsAddr := flag.String("ws", ":1882", "network address for Websocket listener")
infoAddr := flag.String("info", ":8080", "network address for web info dashboard listener")
certFile := flag.String("cert", "cert/fullchain.pem", "path the body of a public certificate")
keyFile := flag.String("key", "cert/privkey.pem", "path to body of a private key.")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println("Mochi MQTT Broker initializing...")
fmt.Println("TCP", *tcpAddr)
fmt.Println("Websocket", *wsAddr)
fmt.Println("Dashboard", *infoAddr)
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", *tcpAddr)
cer, err := tls.LoadX509KeyPair(*certFile, *keyFile)
if err != nil {
log.Println(err)
return
}
config := &tls.Config{Certificates: []tls.Certificate{cer}}
if err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: config,
}); err != nil {
log.Fatal(err)
}
ws := listeners.NewWebsocket("ws1", *wsAddr)
if err = server.AddListener(ws, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: config,
}); err != nil {
log.Fatal(err)
}
stats := listeners.NewHTTPStats("stats", *infoAddr)
if err = server.AddListener(stats, nil); err != nil {
log.Fatal(err)
}
go server.Serve()
fmt.Println("Started!")
<-done
fmt.Println("Caught Signal")
server.Close()
fmt.Println("Finished")
}

105
examples/auth/main.go Normal file
View File

@ -0,0 +1,105 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: &Auth{
Users: map[string]string{
"peach": "password1",
"melon": "password2",
"apple": "password3",
},
AllowedTopics: map[string][]string{
// Melon user only has access to melon topics.
// If you were implementing this in the real world, you might ensure
// that any topic prefixed with "melon" is allowed (see ACL func below).
"melon": {"melon/info", "melon/events"},
},
},
})
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}
// Auth is an example auth provider for the server. In the real world
// you are more likely to replace these fields with database/cache lookups
// to check against an auth list. As the Auth Controller is an interface, it can
// be built however you want, as long as it fulfils the interface signature.
type Auth struct {
Users map[string]string // A map of usernames (key) with passwords (value).
AllowedTopics map[string][]string // A map of usernames and topics
}
// Authenticate returns true if a username and password are acceptable.
func (a *Auth) Authenticate(user, password []byte) bool {
// If the user exists in the auth users map, and the password is correct,
// then they can connect to the server. In the real world, this could be a database
// or cached users lookup.
if pass, ok := a.Users[string(user)]; ok && pass == string(password) {
return true
}
return false
}
// ACL returns true if a user has access permissions to read or write on a topic.
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
// An example ACL - if the user has an entry in the auth allow list, then they are
// subject to ACL restrictions. Only let them use a topic if it's available for their
// user.
if topics, ok := a.AllowedTopics[string(user)]; ok {
for _, t := range topics {
// In the real world you might allow all topics prefixed with a user's username,
// or similar multi-topic filters.
if t == topic {
return true
}
}
return false
}
// Otherwise, allow all topics.
return true
}

View File

@ -0,0 +1,49 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.NewServer(nil)
stats := listeners.NewHTTPStats("stats", ":8080")
err := server.AddListener(stats, nil)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

103
examples/events/main.go Normal file
View File

@ -0,0 +1,103 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
if err != nil {
log.Fatal(err)
}
// Start the server
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Add OnConnect Event Hook
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
// Add OnDisconnect Event Hook
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
// Add OnSubscribe Event Hook
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
}
// Add OnUnsubscribe Event Hook
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk
if string(pk.Payload) == "hello" {
pkx.Payload = []byte("hello world")
fmt.Printf("< OnMessage modified message from client %s: %s\n", cl.ID, string(pkx.Payload))
} else {
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
}
// Example of using AllowClients to selectively deliver/drop messages.
// Only a client with the id of `allowed-client` will received messages on the topic.
if pkx.TopicName == "a/b/restricted" {
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
}
return pkx, nil
}
// Demonstration of directly publishing messages to a topic via the
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
for range time.Tick(time.Second * 10) {
server.Publish("direct/publish", []byte("scheduled message"), false)
fmt.Println("> issued direct message to direct/publish")
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

View File

@ -0,0 +1,59 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("PAHO Testing Suite"))
server := mqtt.New()
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(Auth),
})
if err != nil {
log.Fatal(err)
}
go server.Serve()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}
// Auth is an example auth provider for the server.
type Auth struct{}
// Authenticate returns true if a username and password are acceptable.
// Auth always returns true.
func (a *Auth) Authenticate(user, password []byte) bool {
return true
}
// ACL returns true if a user has access permissions to read or write on a topic.
// ACL is used to deny access to a specific topic to satisfy Test.test_subscribe_failure.
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
return topic != "test/nosubscribe"
}

View File

@ -0,0 +1,60 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/logrusorgru/aurora"
"go.etcd.io/bbolt"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/persistence/bolt"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("Persistence"))
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
if err != nil {
log.Fatal(err)
}
err = server.AddStore(bolt.New("mochi-test.db", &bbolt.Options{
Timeout: 500 * time.Millisecond,
}))
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

57
examples/tcp/main.go Normal file
View File

@ -0,0 +1,57 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
// An example of configuring various server options...
options := &mqtt.Options{
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
}
server := mqtt.NewServer(options)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

116
examples/tls/main.go Normal file
View File

@ -0,0 +1,116 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
var (
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err = server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
}
ws := listeners.NewWebsocket("ws1", ":1882")
err = server.AddListener(ws, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
}
stats := listeners.NewHTTPStats("stats", ":8080")
err = server.AddListener(stats, &listeners.Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfig,
})
if err != nil {
log.Fatal(err)
}
go server.Serve()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

View File

@ -0,0 +1,47 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.NewServer(nil)
ws := listeners.NewWebsocket("ws1", ":1882")
err := server.AddListener(ws, nil)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

21
go.mod Normal file
View File

@ -0,0 +1,21 @@
module github.com/mochi-co/mqtt
go 1.18
require (
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/rs/xid v1.4.0
github.com/stretchr/testify v1.7.1
go.etcd.io/bbolt v1.3.5
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

58
go.sum Normal file
View File

@ -0,0 +1,58 @@
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8=
github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0 h1:QPlSTtPE2k6PZPasQUbzuK3p9JbS+vMXYVto8g/yrsg=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d h1:L/IKR6COd7ubZrs2oTnTi73IhgqJ71c9s80WsQnh0Es=
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

78
server/events/events.go Normal file
View File

@ -0,0 +1,78 @@
package events
import (
"github.com/mochi-co/mqtt/server/internal/packets"
)
// Events provides callback handlers for different event hooks.
type Events struct {
OnProcessMessage // published message receieved before evaluation.
OnMessage // published message receieved.
OnError // server error.
OnConnect // client connected.
OnDisconnect // client disconnected.
OnSubscribe // topic subscription created.
OnUnsubscribe // topic subscription removed.
}
// Packets is an alias for packets.Packet.
type Packet packets.Packet
// Client contains limited information about a connected client.
type Client struct {
ID string
Remote string
Listener string
Username []byte
CleanSession bool
}
// Clientlike is an interface for Clients and client-like objects that
// are able to describe their client/listener IDs and remote address.
type Clientlike interface {
Info() Client
}
// OnProcessMessage is called when a publish message is received, allowing modification
// of the packet data after ACL checking has occurred but before any data is evaluated
// for processing - e.g. for changing the Retain flag. Note, this hook is ONLY called
// by connected client publishers, it is not triggered when using the direct
// s.Publish method. The function receives the sent message and the
// data of the client who published it, and allows the packet to be modified
// before it is dispatched to subscribers. If no modification is required, return
// the original packet data. If an error occurs, the original packet will
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
// processing of the packet.
type OnProcessMessage func(Client, Packet) (Packet, error)
// OnMessage function is called when a publish message is received. Note,
// this hook is ONLY called by connected client publishers, it is not triggered when
// using the direct s.Publish method. The function receives the sent message and the
// data of the client who published it, and allows the packet to be modified
// before it is dispatched to subscribers. If no modification is required, return
// the original packet data. If an error occurs, the original packet will
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
type OnMessage func(Client, Packet) (Packet, error)
// OnConnect is called when a client successfully connects to the broker.
type OnConnect func(Client, Packet)
// OnDisconnect is called when a client disconnects to the broker. An error value
// is passed to the function if the client disconnected abnormally, otherwise it
// will be nil on a normal disconnect.
type OnDisconnect func(Client, error)
// OnError is called when errors that will not be passed to
// OnDisconnect are handled by the server.
type OnError func(Client, error)
// OnSubscribe is called when a new subscription filter for a client is created.
type OnSubscribe func(filter string, cl Client, qos byte)
// OnUnsubscribe is called when an existing subscription filter for a client is removed.
type OnUnsubscribe func(filter string, cl Client)

View File

@ -0,0 +1,212 @@
package circ
import (
"errors"
"io"
"sync"
"sync/atomic"
)
var (
// DefaultBufferSize is the default size of the buffer in bytes.
DefaultBufferSize int = 1024 * 256
// DefaultBlockSize is the default size per R/W block in bytes.
DefaultBlockSize int = 1024 * 8
// ErrOutOfRange indicates that the index was out of range.
ErrOutOfRange = errors.New("Indexes out of range")
// ErrInsufficientBytes indicates that there were not enough bytes to return.
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// Buffer is a circular buffer for reading and writing messages.
type Buffer struct {
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
done uint32 // indicates that the buffer is closed.
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
// NewWriter instead of this function.
func NewBuffer(size, block int) *Buffer {
if size == 0 {
size = DefaultBufferSize
}
if block == 0 {
block = DefaultBlockSize
}
if size < 2*block {
size = 2 * block
}
return &Buffer{
size: size,
mask: size - 1,
block: block,
buf: make([]byte, size),
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
}
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) *Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := &Buffer{
size: l,
mask: l - 1,
block: block,
buf: buf,
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
return b
}
// GetPos will return the tail and head positions of the buffer.
// This method is for use with testing.
func (b *Buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
}
// SetPos sets the head and tail of the buffer.
func (b *Buffer) SetPos(tail, head int64) {
atomic.StoreInt64(&b.tail, tail)
atomic.StoreInt64(&b.head, head)
}
// Get returns the internal buffer.
func (b *Buffer) Get() []byte {
b.Mu.Lock()
defer b.Mu.Unlock()
return b.buf
}
// Set writes bytes to a range of indexes in the byte buffer.
func (b *Buffer) Set(p []byte, start, end int) error {
b.Mu.Lock()
defer b.Mu.Unlock()
if end > b.size || start > b.size {
return ErrOutOfRange
}
o := 0
for i := start; i < end; i++ {
b.buf[i] = p[o]
o++
}
return nil
}
// Index returns the buffer-relative index of an integer.
func (b *Buffer) Index(i int64) int {
return b.mask & int(i)
}
// awaitEmpty will block until there is at least n bytes between
// the head and the tail (looking forward).
func (b *Buffer) awaitEmpty(n int) error {
// If the head has wrapped behind the tail, and next will overrun tail,
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadUint32(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
b.rcond.Wait()
}
b.rcond.L.Unlock()
return nil
}
// awaitFilled will block until there are at least n bytes between the
// tail and the head (looking forward).
func (b *Buffer) awaitFilled(n int) error {
// Because awaitCapacity prevents the head from overrunning the t
// able on write, we can simply ensure there is enough space
// the forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadUint32(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
b.wcond.Wait()
}
b.wcond.L.Unlock()
return nil
}
// checkEmpty returns true if there are at least n bytes between the head and
// the tail.
func (b *Buffer) checkEmpty(n int) bool {
head := atomic.LoadInt64(&b.head)
next := head + int64(n)
tail := atomic.LoadInt64(&b.tail)
if next-tail > int64(b.size) {
return false
}
return true
}
// checkFilled returns true if there are at least n bytes between the tail and
// the head.
func (b *Buffer) checkFilled(n int) bool {
if atomic.LoadInt64(&b.tail)+int64(n) <= atomic.LoadInt64(&b.head) {
return true
}
return false
}
// CommitTail moves the tail position of the buffer n bytes.
func (b *Buffer) CommitTail(n int) {
atomic.AddInt64(&b.tail, int64(n))
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
// CapDelta returns the difference between the head and tail.
func (b *Buffer) CapDelta() int {
return int(atomic.LoadInt64(&b.head) - atomic.LoadInt64(&b.tail))
}
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreUint32(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}

View File

@ -0,0 +1,317 @@
package circ
import (
//"fmt"
"sync/atomic"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/require"
)
func TestNewBuffer(t *testing.T) {
var size int = 16
var block int = 4
buf := NewBuffer(size, block)
require.NotNil(t, buf.buf)
require.NotNil(t, buf.rcond)
require.NotNil(t, buf.wcond)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewBuffer0Size(t *testing.T) {
buf := NewBuffer(0, 0)
require.NotNil(t, buf.buf)
require.Equal(t, DefaultBufferSize, buf.size)
require.Equal(t, DefaultBlockSize, buf.block)
}
func TestNewBufferUndersize(t *testing.T) {
buf := NewBuffer(DefaultBlockSize+10, DefaultBlockSize)
require.NotNil(t, buf.buf)
require.Equal(t, DefaultBlockSize*2, buf.size)
require.Equal(t, DefaultBlockSize, buf.block)
}
func TestNewBufferFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestNewBufferFromSlice0Size(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(0, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestAtomicAlignment(t *testing.T) {
var b Buffer
offset := unsafe.Offsetof(b.head)
require.Equalf(t, uintptr(0), offset%8,
"head requires 64-bit alignment for atomic: offset %d", offset)
offset = unsafe.Offsetof(b.tail)
require.Equalf(t, uintptr(0), offset%8,
"tail requires 64-bit alignment for atomic: offset %d", offset)
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
require.Equal(t, int64(0), tail)
require.Equal(t, int64(0), head)
atomic.StoreInt64(&buf.tail, 3)
atomic.StoreInt64(&buf.head, 11)
tail, head = buf.GetPos()
require.Equal(t, int64(3), tail)
require.Equal(t, int64(11), head)
}
func TestGet(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, make([]byte, 16), buf.Get())
buf.buf[0] = 1
buf.buf[15] = 1
require.Equal(t, []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, buf.Get())
}
func TestSetPos(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, int64(0), atomic.LoadInt64(&buf.tail))
require.Equal(t, int64(0), atomic.LoadInt64(&buf.head))
buf.SetPos(4, 8)
require.Equal(t, int64(4), atomic.LoadInt64(&buf.tail))
require.Equal(t, int64(8), atomic.LoadInt64(&buf.head))
}
func TestSet(t *testing.T) {
buf := NewBuffer(16, 4)
err := buf.Set([]byte{1, 1, 1, 1}, 17, 19)
require.Error(t, err)
err = buf.Set([]byte{1, 1, 1, 1}, 4, 8)
require.NoError(t, err)
require.Equal(t, []byte{0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, buf.buf)
}
func TestIndex(t *testing.T) {
buf := NewBuffer(1024, 4)
require.Equal(t, 512, buf.Index(512))
require.Equal(t, 0, buf.Index(1024))
require.Equal(t, 6, buf.Index(1030))
require.Equal(t, 6, buf.Index(61446))
}
func TestAwaitFilled(t *testing.T) {
tests := []struct {
tail int64
head int64
n int
await int
desc string
}{
{tail: 0, head: 4, n: 4, await: 1, desc: "OK 0, 4"},
{tail: 8, head: 11, n: 4, await: 1, desc: "OK 8, 11"},
{tail: 102, head: 103, n: 4, await: 3, desc: "OK 102, 103"},
}
for i, tt := range tests {
//fmt.Println(i)
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan error)
go func() {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.AddInt64(&buf.head, int64(tt.await))
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestAwaitFilledEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}
func TestAwaitEmptyOK(t *testing.T) {
tests := []struct {
tail int64
head int64
await int
desc string
}{
{tail: 0, head: 0, await: 0, desc: "OK 0, 0"},
{tail: 0, head: 5, await: 0, desc: "OK 0, 5"},
{tail: 0, head: 14, await: 3, desc: "OK wrap 0, 14 "},
{tail: 22, head: 35, await: 2, desc: "OK wrap 0, 14 "},
{tail: 15, head: 17, await: 7, desc: "OK 15,2"},
{tail: 0, head: 10, await: 2, desc: "OK 0, 10"},
{tail: 1, head: 15, await: 4, desc: "OK 2, 14"},
}
for i, tt := range tests {
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan error)
go func() {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.AddInt64(&buf.tail, int64(tt.await))
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestAwaitEmptyEnded(t *testing.T) {
buf := NewBuffer(16, 4)
buf.SetPos(1, 15)
o := make(chan error)
go func() {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
require.Error(t, <-o)
}
func TestCheckEmpty(t *testing.T) {
buf := NewBuffer(16, 4)
tests := []struct {
head int64
tail int64
want bool
desc string
}{
{tail: 0, head: 0, want: true, desc: "0, 0 true"},
{tail: 3, head: 4, want: true, desc: "4, 3 true"},
{tail: 15, head: 17, want: true, desc: "15, 17(1) true"},
{tail: 1, head: 30, want: false, desc: "1, 30(14) false"},
{tail: 15, head: 30, want: false, desc: "15, 30(14) false; head has caught up to tail"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
require.Equal(t, tt.want, buf.checkEmpty(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
}
}
func TestCheckFilled(t *testing.T) {
buf := NewBuffer(16, 4)
tests := []struct {
head int64
tail int64
want bool
desc string
}{
{tail: 0, head: 0, want: false, desc: "0, 0 false"},
{tail: 0, head: 4, want: true, desc: "0, 4 true"},
{tail: 14, head: 16, want: false, desc: "14,16 false"},
{tail: 14, head: 18, want: true, desc: "14,16 true"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
require.Equal(t, tt.want, buf.checkFilled(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
}
}
func TestCommitTail(t *testing.T) {
tests := []struct {
tail int64
head int64
n int
next int64
await int
desc string
}{
{tail: 0, head: 5, n: 4, next: 4, await: 0, desc: "OK 0, 4"},
{tail: 0, head: 5, n: 6, next: 6, await: 1, desc: "OK 0, 5"},
}
for i, tt := range tests {
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
go func() {
buf.CommitTail(tt.n)
}()
time.Sleep(time.Millisecond)
for j := 0; j < tt.await; j++ {
atomic.AddInt64(&buf.head, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
}
require.Equal(t, tt.next, atomic.LoadInt64(&buf.tail), "Next tail mismatch [i:%d] %s", i, tt.desc)
}
}
/*
func TestCommitTailEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}
*/
func TestCapDelta(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, 0, buf.CapDelta())
buf.SetPos(10, 15)
require.Equal(t, 5, buf.CapDelta())
}
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, uint32(1), buf.done)
}

View File

@ -0,0 +1,49 @@
package circ
import (
"sync"
"sync/atomic"
)
// BytesPool is a pool of []byte.
type BytesPool struct {
// int64/uint64 has to the first words in order
// to be 64-aligned on 32-bit architectures.
used int64 // access atomically
pool *sync.Pool
}
// NewBytesPool returns a sync.pool of []byte.
func NewBytesPool(n int) *BytesPool {
if n == 0 {
n = DefaultBufferSize
}
return &BytesPool{
pool: &sync.Pool{
New: func() interface{} {
return make([]byte, n)
},
},
}
}
// Get returns a pooled bytes.Buffer.
func (b *BytesPool) Get() []byte {
atomic.AddInt64(&b.used, 1)
return b.pool.Get().([]byte)
}
// Put puts the byte slice back into the pool.
func (b *BytesPool) Put(x []byte) {
for i := range x {
x[i] = 0
}
b.pool.Put(x)
atomic.AddInt64(&b.used, -1)
}
// InUse returns the number of pool blocks in use.
func (b *BytesPool) InUse() int64 {
return atomic.LoadInt64(&b.used)
}

View File

@ -0,0 +1,49 @@
package circ
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBytesPool(t *testing.T) {
bpool := NewBytesPool(256)
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBytesPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBytesPool(256)
}
}
func TestNewBytesPoolGet(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, make([]byte, 256), buf)
require.Equal(t, int64(1), bpool.InUse())
}
func BenchmarkBytesPoolGet(b *testing.B) {
bpool := NewBytesPool(256)
for n := 0; n < b.N; n++ {
bpool.Get()
}
}
func TestNewBytesPoolPut(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, int64(1), bpool.InUse())
bpool.Put(buf)
require.Equal(t, int64(0), bpool.InUse())
}
func BenchmarkBytesPoolPut(b *testing.B) {
bpool := NewBytesPool(256)
buf := bpool.Get()
for n := 0; n < b.N; n++ {
bpool.Put(buf)
}
}

View File

@ -0,0 +1,96 @@
package circ
import (
"io"
"sync/atomic"
)
// Reader is a circular buffer for reading data from an io.Reader.
type Reader struct {
*Buffer
}
// NewReader returns a new Circular Reader.
func NewReader(size, block int) *Reader {
b := NewBuffer(size, block)
b.ID = "\treader"
return &Reader{
b,
}
}
// NewReaderFromSlice returns a new Circular Reader using a pre-existing
// byte slice.
func NewReaderFromSlice(block int, p []byte) *Reader {
b := NewBufferFromSlice(block, p)
b.ID = "\treader"
return &Reader{
b,
}
}
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreUint32(&b.State, 1)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadUint32(&b.done) == 1 {
return total, nil
}
// Wait until there's enough capacity in the buffer before
// trying to read more bytes from the io.Reader.
err := b.awaitEmpty(b.block)
if err != nil {
// b.done is the only error condition for awaitCapacity
// so loop around and return properly.
continue
}
// If the block will overrun the circle end, just fill up
// and collect the rest on the next pass.
start := b.Index(atomic.LoadInt64(&b.head))
end := start + b.block
if end > b.size {
end = b.size
}
// Read into the buffer between the start and end indexes only.
n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
return total, err
}
// Move the head forward however many bytes were read.
atomic.AddInt64(&b.head, int64(n))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}
}
// Read reads n bytes from the buffer, and will block until at n bytes
// exist in the buffer to read.
func (b *Buffer) Read(n int) (p []byte, err error) {
err = b.awaitFilled(n)
if err != nil {
return
}
tail := atomic.LoadInt64(&b.tail)
next := tail + int64(n)
// If the read overruns the buffer, get everything until the end
// and then whatever is left from the start.
if b.Index(tail) > b.Index(next) {
b.tmp = b.buf[b.Index(tail):]
b.tmp = append(b.tmp, b.buf[:b.Index(next)]...)
} else {
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
}
return b.tmp, nil
}

View File

@ -0,0 +1,129 @@
package circ
import (
"bytes"
"errors"
"io"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewReader(t *testing.T) {
var size = 16
var block = 4
buf := NewReader(size, block)
require.NotNil(t, buf.buf)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewReaderFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestReadFrom(t *testing.T) {
buf := NewReader(16, 4)
b4 := bytes.Repeat([]byte{'-'}, 4)
br := bytes.NewReader(b4)
_, err := buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
require.Equal(t, int64(4), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(8), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(12), buf.head)
}
func TestReadFromWrap(t *testing.T) {
buf := NewReader(16, 4)
buf.buf = bytes.Repeat([]byte{'-'}, 16)
buf.SetPos(8, 14)
br := bytes.NewReader(bytes.Repeat([]byte{'/'}, 8))
o := make(chan error)
go func() {
_, err := buf.ReadFrom(br)
o <- err
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
}()
<-o
require.Equal(t, []byte{'/', '/', '/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '/', '/'}, buf.Get())
require.Equal(t, int64(22), atomic.LoadInt64(&buf.head))
require.Equal(t, 6, buf.Index(atomic.LoadInt64(&buf.head)))
}
func TestReadOK(t *testing.T) {
buf := NewReader(16, 4)
buf.buf = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
tests := []struct {
tail int64
head int64
n int
bytes []byte
desc string
}{
{tail: 0, head: 4, n: 4, bytes: []byte{'a', 'b', 'c', 'd'}, desc: "0, 4 OK"},
{tail: 3, head: 15, n: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, desc: "3, 15 OK"},
{tail: 14, head: 15, n: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, desc: "14, 2 wrapped OK"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
o := make(chan []byte)
go func() {
p, _ := buf.Read(tt.n)
o <- p
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.head, buf.head+int64(tt.n))
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
done := <-o
require.Equal(t, tt.bytes, done, "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
}
}
func TestReadEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
_, err := buf.Read(4)
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}

View File

@ -0,0 +1,106 @@
package circ
import (
"io"
"log"
"sync/atomic"
)
// Writer is a circular buffer for writing data to an io.Writer.
type Writer struct {
*Buffer
}
// NewWriter returns a pointer to a new Circular Writer.
func NewWriter(size, block int) *Writer {
b := NewBuffer(size, block)
b.ID = "writer"
return &Writer{
b,
}
}
// NewWriterFromSlice returns a new Circular Writer using a pre-existing
// byte slice.
func NewWriterFromSlice(block int, p []byte) *Writer {
b := NewBufferFromSlice(block, p)
b.ID = "writer"
return &Writer{
b,
}
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
atomic.StoreUint32(&b.State, 2)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}
// Read from the buffer until there is at least 1 byte to write.
err = b.awaitFilled(1)
if err != nil {
return
}
// Get all the bytes between the tail and head, wrapping if necessary.
tail := atomic.LoadInt64(&b.tail)
rTail := b.Index(tail)
rHead := b.Index(atomic.LoadInt64(&b.head))
n := b.CapDelta()
p := make([]byte, 0, n)
if rTail > rHead {
p = append(p, b.buf[rTail:]...)
p = append(p, b.buf[:rHead]...)
} else {
p = append(p, b.buf[rTail:rHead]...)
}
n, err = w.Write(p)
total += int64(n)
if err != nil {
log.Println("error writing to buffer io.Writer;", err)
return
}
// Move the tail forward the bytes written and broadcast change.
atomic.StoreInt64(&b.tail, tail+int64(n))
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
}
// Write writes the buffer to the buffer p, returning the number of bytes written.
// The bytes written to the buffer are picked up by WriteTo.
func (b *Writer) Write(p []byte) (total int, err error) {
err = b.awaitEmpty(len(p))
if err != nil {
return
}
total = b.writeBytes(p)
atomic.AddInt64(&b.head, int64(total))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
return
}
// writeBytes writes bytes to the buffer from the start position, and returns
// the new head position. This function does not wait for capacity and will
// overwrite any existing bytes.
func (b *Writer) writeBytes(p []byte) int {
var o int
var n int
for i := 0; i < len(p); i++ {
o = b.Index(atomic.LoadInt64(&b.head) + int64(i))
b.buf[o] = p[i]
n++
}
return n
}

View File

@ -0,0 +1,155 @@
package circ
import (
"bufio"
"bytes"
"net"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewWriter(t *testing.T) {
var size = 16
var block = 4
buf := NewWriter(size, block)
require.NotNil(t, buf.buf)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewWriterFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestWriteTo(t *testing.T) {
tests := []struct {
tail int64
head int64
bytes []byte
await int
total int
err error
desc string
}{
{tail: 0, head: 5, bytes: []byte{'a', 'b', 'c', 'd', 'e'}, desc: "0,5 OK"},
{tail: 14, head: 21, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd', 'e'}, desc: "14,16(2) OK"},
}
for i, tt := range tests {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
buf := NewWriter(16, 4)
buf.Set(bb, 0, 16)
buf.SetPos(tt.tail, tt.head)
var b bytes.Buffer
w := bufio.NewWriter(&b)
nc := make(chan int64)
go func() {
n, _ := buf.WriteTo(w)
nc <- n
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
w.Flush()
require.Equal(t, tt.bytes, b.Bytes(), "Written bytes mismatch [i:%d] %s", i, tt.desc)
}
}
func TestWriteToEndedFirst(t *testing.T) {
buf := NewWriter(16, 4)
buf.done = 1
var b bytes.Buffer
w := bufio.NewWriter(&b)
_, err := buf.WriteTo(w)
require.Error(t, err)
}
func TestWriteToBadWriter(t *testing.T) {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
buf := NewWriter(16, 4)
buf.Set(bb, 0, 16)
buf.SetPos(0, 6)
r, w := net.Pipe()
w.Close()
_, err := buf.WriteTo(w)
require.Error(t, err)
r.Close()
}
func TestWrite(t *testing.T) {
tests := []struct {
tail int64
head int64
rHead int64
bytes []byte
want []byte
desc string
}{
{tail: 0, head: 0, rHead: 4, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, desc: "0>4 OK"},
{tail: 4, head: 14, rHead: 2, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a', 'b'}, desc: "14>2 OK"},
}
for i, tt := range tests {
buf := NewWriter(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan []interface{})
go func() {
nn, err := buf.Write(tt.bytes)
o <- []interface{}{nn, err}
}()
done := <-o
require.Equal(t, tt.want, buf.buf, "Wanted written mismatch [i:%d] %s", i, tt.desc)
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestWriteEnded(t *testing.T) {
buf := NewWriter(16, 4)
buf.SetPos(15, 30)
buf.done = 1
_, err := buf.Write([]byte{'a', 'b', 'c', 'd'})
require.Error(t, err)
}
func TestWriteBytes(t *testing.T) {
tests := []struct {
tail int64
head int64
bytes []byte
want []byte
start int
desc string
}{
{tail: 0, head: 0, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0}, desc: "0,4 OK"},
{tail: 6, head: 6, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 'a', 'b'}, desc: "6,2 OK wrapped"},
}
for i, tt := range tests {
buf := NewWriter(8, 4)
buf.SetPos(tt.tail, tt.head)
n := buf.writeBytes(tt.bytes)
require.Equal(t, tt.want, buf.buf, "Buffer mistmatch [i:%d] %s", i, tt.desc)
require.Equal(t, len(tt.bytes), n)
}
}

View File

@ -0,0 +1,564 @@
package clients
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
var (
// defaultKeepalive is the default connection keepalive value in seconds.
defaultKeepalive uint16 = 10
// ErrConnectionClosed is returned when operating on a closed
// connection and/or when no error cause has been given.
ErrConnectionClosed = errors.New("connection not open")
)
// Clients contains a map of the clients known by the broker.
type Clients struct {
sync.RWMutex
internal map[string]*Client // clients known by the broker, keyed on client id.
}
// New returns an instance of Clients.
func New() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
cl.internal[val.ID] = val
cl.Unlock()
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
val, ok := cl.internal[id]
cl.RUnlock()
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
val := len(cl.internal)
cl.RUnlock()
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
delete(cl.internal, id)
cl.Unlock()
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
cl.RUnlock()
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
State State // the operational state of the client.
LWT LWT // the last will and testament for the client.
Inflight *Inflight // a map of in-flight qos messages.
sync.RWMutex // mutex
Username []byte // the username the client authenticated with.
AC auth.Controller // an auth controller inherited from the listener.
Listener string // the id of the listener the client is connected to.
ID string // the client id.
conn net.Conn // the net.Conn used to establish the connection.
R *circ.Reader // a reader for reading incoming bytes.
W *circ.Writer // a writer for writing outgoing bytes.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
systemInfo *system.Info // pointers to server system info.
packetID uint32 // the current highest packetID.
keepalive uint16 // the number of seconds the connection can wait.
CleanSession bool // indicates if the client expects a clean-session.
}
// State tracks the state of the client.
type State struct {
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
Done uint32 // atomic counter which indicates that the client has closed.
endOnce sync.Once // only end once.
stopCause atomic.Value // reason for stopping.
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
R: r,
W: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
State: State{
started: new(sync.WaitGroup),
endedW: new(sync.WaitGroup),
endedR: new(sync.WaitGroup),
},
}
cl.refreshDeadline(cl.keepalive)
return cl
}
// NewClientStub returns an instance of Client with basic initializations. This
// method is typically called by the persistence restoration system.
func NewClientStub(s *system.Info) *Client {
return &Client{
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
State: State{
Done: 1,
},
}
}
// Identify sets the identification values of a client instance.
func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
cl.Listener = lid
cl.AC = ac
cl.ID = pk.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String()
}
cl.R.ID = cl.ID + " READER"
cl.W.ID = cl.ID + " WRITER"
cl.Username = pk.Username
cl.CleanSession = pk.CleanSession
cl.keepalive = pk.Keepalive
if pk.WillFlag {
cl.LWT = LWT{
Topic: pk.WillTopic,
Message: pk.WillMessage,
Qos: pk.WillQos,
Retain: pk.WillRetain,
}
}
cl.refreshDeadline(cl.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.conn != nil {
var expiry time.Time // Nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
}
_ = cl.conn.SetDeadline(expiry)
}
}
// Info returns an event-version of a client, containing minimal information.
func (cl *Client) Info() events.Client {
addr := "unknown"
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Username: cl.Username,
CleanSession: cl.CleanSession,
Listener: cl.Listener,
}
}
// NextPacketID returns the next packet id for a client, looping back to 0
// if the maximum ID has been reached.
func (cl *Client) NextPacketID() uint32 {
i := atomic.LoadUint32(&cl.packetID)
if i == uint32(65535) || i == uint32(0) {
atomic.StoreUint32(&cl.packetID, 1)
return 1
}
return atomic.AddUint32(&cl.packetID, 1)
}
// NoteSubscription makes a note of a subscription for the client.
func (cl *Client) NoteSubscription(filter string, qos byte) {
cl.Lock()
cl.Subscriptions[filter] = qos
cl.Unlock()
}
// ForgetSubscription forgests a subscription note for the client.
func (cl *Client) ForgetSubscription(filter string) {
cl.Lock()
delete(cl.Subscriptions, filter)
cl.Unlock()
}
// Start begins the client goroutines reading and writing packets.
func (cl *Client) Start() {
cl.State.started.Add(2)
cl.State.endedW.Add(1)
cl.State.endedR.Add(1)
go func() {
cl.State.started.Done()
_, err := cl.W.WriteTo(cl.conn)
if err != nil {
err = fmt.Errorf("writer: %w", err)
}
cl.State.endedW.Done()
cl.Stop(err)
}()
go func() {
cl.State.started.Done()
_, err := cl.R.ReadFrom(cl.conn)
if err != nil {
err = fmt.Errorf("reader: %w", err)
}
cl.State.endedR.Done()
cl.Stop(err)
}()
cl.State.started.Wait()
}
// ClearBuffers sets the read/write buffers to nil so they can be
// deallocated automatically when no longer in use.
func (cl *Client) ClearBuffers() {
cl.R = nil
cl.W = nil
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
// A cause error may be passed to identfy the reason for stopping.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return
}
cl.State.endOnce.Do(func() {
cl.R.Stop()
cl.W.Stop()
cl.State.endedW.Wait()
_ = cl.conn.Close() // omit close error
cl.State.endedR.Wait()
atomic.StoreUint32(&cl.State.Done, 1)
if err == nil {
err = ErrConnectionClosed
}
cl.State.stopCause.Store(err)
})
}
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
p, err := cl.R.Read(1)
if err != nil {
return err
}
err = fh.Decode(p[0])
if err != nil {
return err
}
// The remaining length value can be up to 5 bytes. Read through each byte
// looking for continue values, and if found increase the read. Otherwise
// decode the bytes that were legit.
buf := make([]byte, 0, 6)
i := 1
n := 2
for ; n < 6; n++ {
p, err = cl.R.Read(n)
if err != nil {
return err
}
buf = append(buf, p[i])
// If it's not a continuation flag, end here.
if p[i] < 128 {
break
}
// If i has reached 4 without a length terminator, return a protocol violation.
i++
if i == 4 {
return packets.ErrOversizedLengthIndicator
}
}
// Calculate and store the remaining length of the packet payload.
rem, _ := binary.Uvarint(buf)
fh.Remaining = int(rem)
// Having successfully read n bytes, commit the tail forward.
cl.R.CommitTail(n)
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
// Read loops forever reading new packets from a client connection until
// an error is encountered (or the connection is closed).
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
for {
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
return nil
}
cl.refreshDeadline(cl.keepalive)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = packetHandler(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
return
}
p, err := cl.R.Read(pk.FixedHeader.Remaining)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Disconnect:
default:
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
cl.R.CommitTail(pk.FixedHeader.Remaining)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
cl.W.Mu.Lock()
defer cl.W.Mu.Unlock()
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
default:
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
if err != nil {
return
}
// Write the packet bytes to the client byte buffer.
n, err = cl.W.Write(buf.Bytes())
if err != nil {
return
}
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)
return
}
// LWT contains the last will and testament details for a client connection.
type LWT struct {
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}
// InflightMessage contains data about a packet which is currently in-flight.
type InflightMessage struct {
Packet packets.Packet // the packet currently in-flight.
Sent int64 // the last time the message was sent (for retries) in unixtime.
Resends int // the number of times the message was attempted to be sent.
}
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]InflightMessage // internal contains the inflight messages.
}
// Set stores the packet of an Inflight message, keyed on message id. Returns
// true if the inflight message was new.
func (i *Inflight) Set(key uint16, in InflightMessage) bool {
i.Lock()
_, ok := i.internal[key]
i.internal[key] = in
i.Unlock()
return !ok
}
// Get returns the value of an in-flight message if it exists.
func (i *Inflight) Get(key uint16) (InflightMessage, bool) {
i.RLock()
val, ok := i.internal[key]
i.RUnlock()
return val, ok
}
// Len returns the size of the in-flight messages map.
func (i *Inflight) Len() int {
i.RLock()
v := len(i.internal)
i.RUnlock()
return v
}
// GetAll returns all the in-flight messages.
func (i *Inflight) GetAll() map[uint16]InflightMessage {
i.RLock()
defer i.RUnlock()
return i.internal
}
// Delete removes an in-flight message from the map. Returns true if the
// message existed.
func (i *Inflight) Delete(key uint16) bool {
i.Lock()
_, ok := i.internal[key]
delete(i.internal, key)
i.Unlock()
return ok
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,116 @@
package packets
import (
"encoding/binary"
"unicode/utf8"
"unsafe"
)
// bytesToString provides a zero-alloc, no-copy byte to string conversion.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// decodeUint16 extracts the value of two bytes from a byte array.
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
if len(buf) < offset+2 {
return 0, 0, ErrOffsetUintOutOfRange
}
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
}
// decodeString extracts a string from a byte array, beginning at an offset.
func decodeString(buf []byte, offset int) (string, int, error) {
b, n, err := decodeBytes(buf, offset)
if err != nil {
return "", 0, err
}
if !validUTF8(b) {
return "", 0, ErrOffsetStrInvalidUTF8
}
return bytesToString(b), n, nil
}
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
length, next, err := decodeUint16(buf, offset)
if err != nil {
return make([]byte, 0, 0), 0, err
}
if next+int(length) > len(buf) {
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
}
// Note: there is no validUTF8() test for []byte payloads
return buf[next : next+int(length)], next + int(length), nil
}
// decodeByte extracts the value of a byte from a byte array.
func decodeByte(buf []byte, offset int) (byte, int, error) {
if len(buf) <= offset {
return 0, 0, ErrOffsetByteOutOfRange
}
return buf[offset], offset + 1, nil
}
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
if len(buf) <= offset {
return false, 0, ErrOffsetBoolOutOfRange
}
return 1&buf[offset] > 0, offset + 1, nil
}
// encodeBool returns a byte instead of a bool.
func encodeBool(b bool) byte {
if b {
return 1
}
return 0
}
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
func encodeBytes(val []byte) []byte {
// In many circumstances the number of bytes being encoded is small.
// Setting the cap to a low amount allows us to account for those without
// triggering allocation growth on append unless we need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, val...)
}
// encodeUint16 encodes a uint16 value to a byte array.
func encodeUint16(val uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, val)
return buf
}
// encodeString encodes a string to a byte array.
func encodeString(val string) []byte {
// Like encodeBytes, we set the cap to a small number to avoid
// triggering allocation growth on append unless we absolutely need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, []byte(val)...)
}
// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically
// conforming to the MQTT specification requirements.
func validUTF8(b []byte) bool {
// [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8...
if !utf8.Valid(b) {
return false
}
// [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000...
// ...
return true
}

View File

@ -0,0 +1,386 @@
package packets
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func BenchmarkBytesToString(b *testing.B) {
for n := 0; n < b.N; n++ {
bytesToString([]byte{'a', 'b', 'c'})
}
}
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result string
offset int
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: "a/b/c/d",
},
{
offset: 14,
rawBytes: []byte{
byte(Connect << 4), 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
shouldFail: ErrOffsetUintOutOfRange,
},
{
offset: 17,
rawBytes: []byte{
byte(Connect << 4), 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func BenchmarkDecodeString(b *testing.B) {
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
for n := 0; n < b.N; n++ {
decodeString(in, 0)
}
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 0,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 8,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func BenchmarkDecodeBytes(b *testing.B) {
in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}
for n := 0; n < b.N; n++ {
decodeBytes(in, 0)
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
offset: 8,
shouldFail: ErrOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
func BenchmarkDecodeByte(b *testing.B) {
in := []byte{0, 4, 77, 81, 84, 84}
for n := 0; n < b.N; n++ {
decodeByte(in, 0)
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
func BenchmarkDecodeUint16(b *testing.B) {
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
for n := 0; n < b.N; n++ {
decodeUint16(in, 0)
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: ErrOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}
func BenchmarkDecodeByteBool(b *testing.B) {
in := []byte{0x00, 0x00}
for n := 0; n < b.N; n++ {
decodeByteBool(in, 0)
}
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result, "Incorrect encoded value; not true")
result = encodeBool(false)
require.Equal(t, byte(0), result, "Incorrect encoded value; not false")
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value")
}
func BenchmarkEncodeBool(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeBool(true)
}
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value")
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value")
}
func BenchmarkEncodeBytes(b *testing.B) {
bb := []byte("testing")
for n := 0; n < b.N; n++ {
encodeBytes(bb)
}
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0")
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767")
result = encodeUint16(65535)
require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535")
}
func BenchmarkEncodeUint16(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeUint16(32767)
}
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing")
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null")
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a")
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b")
}
func BenchmarkEncodeString(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeString("benchmarking")
}
}

View File

@ -0,0 +1,59 @@
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, int64(fh.Remaining))
}
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(headerByte byte) error {
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
switch fh.Type {
case Publish:
fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate.
fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag.
fh.Retain = headerByte&0x01 > 0 // Extract retain flag.
case Pubrel:
fh.Qos = (headerByte >> 1) & 0x03
case Subscribe:
fh.Qos = (headerByte >> 1) & 0x03
case Unsubscribe:
fh.Qos = (headerByte >> 1) & 0x03
default:
if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 {
return ErrInvalidFlags
}
}
return nil
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int64) {
for {
digit := byte(length % 128)
length /= 128
if length > 0 {
digit |= 0x80
}
buf.WriteByte(digit)
if length == 0 {
break
}
}
}

View File

@ -0,0 +1,220 @@
package packets
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
rawBytes []byte
header FixedHeader
packetError bool
flagError bool
}
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
flagError: true,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.flagError == false {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes)
require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkFixedHeaderEncode(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
fixedHeaderExpected[0].header.Encode(buf)
}
}
func TestFixedHeaderDecode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.flagError {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkFixedHeaderDecode(b *testing.B) {
fh := new(FixedHeader)
for n := 0; n < b.N; n++ {
err := fh.Decode(fixedHeaderExpected[0].rawBytes[0])
if err != nil {
panic(err)
}
}
}
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int64
want []byte
}{
{
120,
[]byte{0x78},
},
{
math.MaxInt64,
[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
},
}
for i, wanted := range tt {
buf := new(bytes.Buffer)
encodeLength(buf, wanted.have)
require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have)
}
}
func BenchmarkEncodeLength(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
encodeLength(buf, 120)
}
}

View File

@ -0,0 +1,670 @@
package packets
import (
"bytes"
"errors"
"fmt"
"strconv"
)
// All of the valid packet types and their packet identifier.
const (
Reserved byte = iota
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Accepted byte = 0x00
Failed byte = 0xFF
CodeConnectBadProtocolVersion byte = 0x01
CodeConnectBadClientID byte = 0x02
CodeConnectServerUnavailable byte = 0x03
CodeConnectBadAuthValues byte = 0x04
CodeConnectNotAuthorised byte = 0x05
CodeConnectNetworkError byte = 0xFE
CodeConnectProtocolViolation byte = 0xFF
ErrSubAckNetworkError byte = 0x80
)
var (
// CONNECT
ErrMalformedProtocolName = errors.New("malformed packet: protocol name")
ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version")
ErrMalformedFlags = errors.New("malformed packet: flags")
ErrMalformedKeepalive = errors.New("malformed packet: keepalive")
ErrMalformedClientID = errors.New("malformed packet: client id")
ErrMalformedWillTopic = errors.New("malformed packet: will topic")
ErrMalformedWillMessage = errors.New("malformed packet: will message")
ErrMalformedUsername = errors.New("malformed packet: username")
ErrMalformedPassword = errors.New("malformed packet: password")
// CONNACK
ErrMalformedSessionPresent = errors.New("malformed packet: session present")
ErrMalformedReturnCode = errors.New("malformed packet: return code")
// PUBLISH
ErrMalformedTopic = errors.New("malformed packet: topic name")
ErrMalformedPacketID = errors.New("malformed packet: packet id")
// SUBSCRIBE
ErrMalformedQoS = errors.New("malformed packet: qos")
// PACKETS
ErrProtocolViolation = errors.New("protocol violation")
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
ErrOffsetUintOutOfRange = errors.New("offset uint out of range")
ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8")
ErrInvalidFlags = errors.New("invalid flags set for packet")
ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator")
ErrMissingPacketID = errors.New("missing packet id")
ErrSurplusPacketID = errors.New("surplus packet id")
)
// Packet is an MQTT packet. Instead of providing a packet interface and variant
// packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations.
type Packet struct {
FixedHeader FixedHeader
AllowClients []string // For use with OnMessage event hook.
Topics []string
ReturnCodes []byte
ProtocolName []byte
Qoss []byte
Payload []byte
Username []byte
Password []byte
WillMessage []byte
ClientIdentifier string
TopicName string
WillTopic string
PacketID uint16
Keepalive uint16
ReturnCode byte
ProtocolVersion byte
WillQos byte
ReservedBit byte
CleanSession bool
WillFlag bool
WillRetain bool
UsernameFlag bool
PasswordFlag bool
SessionPresent bool
}
// ConnectEncode encodes a connect packet.
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
protoName := encodeBytes(pk.ProtocolName)
protoVersion := pk.ProtocolVersion
flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7
keepalive := encodeUint16(pk.Keepalive)
clientID := encodeString(pk.ClientIdentifier)
var willTopic, willFlag, usernameFlag, passwordFlag []byte
// If will flag is set, add topic and message.
if pk.WillFlag {
willTopic = encodeString(pk.WillTopic)
willFlag = encodeBytes(pk.WillMessage)
}
// If username flag is set, add username.
if pk.UsernameFlag {
usernameFlag = encodeBytes(pk.Username)
}
// If password flag is set, add password.
if pk.PasswordFlag {
passwordFlag = encodeBytes(pk.Password)
}
// Get a length for the connect header. This is not super pretty, but it works.
pk.FixedHeader.Remaining =
len(protoName) + 1 + 1 + len(keepalive) + len(clientID) +
len(willTopic) + len(willFlag) +
len(usernameFlag) + len(passwordFlag)
pk.FixedHeader.Encode(buf)
// Eschew magic for readability.
buf.Write(protoName)
buf.WriteByte(protoVersion)
buf.WriteByte(flag)
buf.Write(keepalive)
buf.Write(clientID)
buf.Write(willTopic)
buf.Write(willFlag)
buf.Write(usernameFlag)
buf.Write(passwordFlag)
return nil
}
// ConnectDecode decodes a connect packet.
func (pk *Packet) ConnectDecode(buf []byte) error {
var offset int
var err error
// Unpack protocol name and version.
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
}
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
}
// Unpack flags byte.
flags, offset, err := decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
}
pk.ReservedBit = 1 & flags
pk.CleanSession = 1&(flags>>1) > 0
pk.WillFlag = 1&(flags>>2) > 0
pk.WillQos = 3 & (flags >> 3) // this one is not a bool
pk.WillRetain = 1&(flags>>5) > 0
pk.PasswordFlag = 1&(flags>>6) > 0
pk.UsernameFlag = 1&(flags>>7) > 0
// Get keepalive interval.
pk.Keepalive, offset, err = decodeUint16(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
}
// Get client ID.
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
}
// Get Last Will and Testament topic and message if applicable.
if pk.WillFlag {
pk.WillTopic, offset, err = decodeString(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
}
pk.WillMessage, offset, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
}
}
// Get username and password if applicable.
if pk.UsernameFlag {
pk.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
}
}
if pk.PasswordFlag {
pk.Password, _, err = decodeBytes(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
}
return nil
}
// ConnectValidate ensures the connect packet is compliant.
func (pk *Packet) ConnectValidate() (b byte, err error) {
// End if protocol name is bad.
if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 &&
bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if protocol version is bad.
if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) ||
(bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) {
return CodeConnectBadProtocolVersion, ErrProtocolViolation
}
// End if reserved bit is not 0.
if pk.ReservedBit != 0 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if ClientID is too long.
if len(pk.ClientIdentifier) > 65535 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if password flag is set without a username.
if pk.PasswordFlag && !pk.UsernameFlag {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if Username or Password is too long.
if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if client id isn't set and clean session is false.
if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
return CodeConnectBadClientID, ErrProtocolViolation
}
return Accepted, nil
}
// ConnackEncode encodes a Connack packet.
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.WriteByte(encodeBool(pk.SessionPresent))
buf.WriteByte(pk.ReturnCode)
return nil
}
// ConnackDecode decodes a Connack packet.
func (pk *Packet) ConnackDecode(buf []byte) error {
var offset int
var err error
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, _, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}
return nil
}
// DisconnectEncode encodes a Disconnect packet.
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PingreqEncode encodes a Pingreq packet.
func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PingrespEncode encodes a Pingresp packet.
func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PubackEncode encodes a Puback packet.
func (pk *Packet) PubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubackDecode decodes a Puback packet.
func (pk *Packet) PubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// PubcompEncode encodes a Pubcomp packet.
func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubcompDecode decodes a Pubcomp packet.
func (pk *Packet) PubcompDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// PublishEncode encodes a Publish packet.
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
topicName := encodeString(pk.TopicName)
var packetID []byte
// Add PacketID if QOS is set.
// [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
if pk.FixedHeader.Qos > 0 {
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID = encodeUint16(pk.PacketID)
}
pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload)
pk.FixedHeader.Encode(buf)
buf.Write(topicName)
buf.Write(packetID)
buf.Write(pk.Payload)
return nil
}
// PublishDecode extracts the data values from the packet.
func (pk *Packet) PublishDecode(buf []byte) error {
var offset int
var err error
pk.TopicName, offset, err = decodeString(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
// If QOS decode Packet ID.
if pk.FixedHeader.Qos > 0 {
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
}
pk.Payload = buf[offset:]
return nil
}
// PublishCopy creates a new instance of Publish packet bearing the
// same payload and destination topic, but with an empty header for
// inheriting new QoS flags, etc.
func (pk *Packet) PublishCopy() Packet {
return Packet{
FixedHeader: FixedHeader{
Type: Publish,
Retain: pk.FixedHeader.Retain,
},
TopicName: pk.TopicName,
Payload: pk.Payload,
}
}
// PublishValidate validates a publish packet.
func (pk *Packet) PublishValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1]
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
// @SPEC [MQTT-2.3.1-5]
// A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 {
return Failed, ErrSurplusPacketID
}
return Accepted, nil
}
// PubrecEncode encodes a Pubrec packet.
func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubrecDecode decodes a Pubrec packet.
func (pk *Packet) PubrecDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// PubrelEncode encodes a Pubrel packet.
func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubrelDecode decodes a Pubrel packet.
func (pk *Packet) PubrelDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// SubackEncode encodes a Suback packet.
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
packetID := encodeUint16(pk.PacketID)
pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length.
pk.FixedHeader.Encode(buf)
buf.Write(packetID) // Encode Packet ID.
buf.Write(pk.ReturnCodes) // Encode granted QOS flags.
return nil
}
// SubackDecode decodes a Suback packet.
func (pk *Packet) SubackDecode(buf []byte) error {
var offset int
var err error
// Get Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Get Granted QOS flags.
pk.ReturnCodes = buf[offset:]
return nil
}
// SubscribeEncode encodes a Subscribe packet.
func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
// Add the Packet ID.
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID := encodeUint16(pk.PacketID)
// Count topics lengths and associated QOS flags.
var topicsLen int
for _, topic := range pk.Topics {
topicsLen += len(encodeString(topic)) + 1
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names and associated QOS flags.
for i, topic := range pk.Topics {
buf.Write(encodeString(topic))
buf.WriteByte(pk.Qoss[i])
}
return nil
}
// SubscribeDecode decodes a Subscribe packet.
func (pk *Packet) SubscribeDecode(buf []byte) error {
var offset int
var err error
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
for offset < len(buf) {
// Decode Topic Name.
var topic string
topic, offset, err = decodeString(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
pk.Topics = append(pk.Topics, topic)
// Decode QOS flag.
var qos byte
qos, offset, err = decodeByte(buf, offset)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
}
// Ensure QoS byte is within range.
if !(qos >= 0 && qos <= 2) {
//if !validateQoS(qos) {
return ErrMalformedQoS
}
pk.Qoss = append(pk.Qoss, qos)
}
return nil
}
// SubscribeValidate ensures the packet is compliant.
func (pk *Packet) SubscribeValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1].
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
return Accepted, nil
}
// UnsubackEncode encodes an Unsuback packet.
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// UnsubackDecode decodes an Unsuback packet.
func (pk *Packet) UnsubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
// UnsubscribeEncode encodes an Unsubscribe packet.
func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
// Add the Packet ID.
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID := encodeUint16(pk.PacketID)
// Count topics lengths.
var topicsLen int
for _, topic := range pk.Topics {
topicsLen += len(encodeString(topic))
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names.
for _, topic := range pk.Topics {
buf.Write(encodeString(topic))
}
return nil
}
// UnsubscribeDecode decodes an Unsubscribe packet.
func (pk *Packet) UnsubscribeDecode(buf []byte) error {
var offset int
var err error
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
for offset < len(buf) {
var t string
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
if len(t) > 0 {
pk.Topics = append(pk.Topics, t)
}
}
return nil
}
// UnsubscribeValidate validates an Unsubscribe packet.
func (pk *Packet) UnsubscribeValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1].
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
return Accepted, nil
}
// FormatID returns the PacketID field as a decimal integer.
func (pk *Packet) FormatID() string {
return strconv.FormatUint(uint64(pk.PacketID), 10)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,344 @@
package topics
import (
"strings"
"sync"
"github.com/mochi-co/mqtt/server/internal/packets"
)
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions map[string]byte
// Index is a prefix/trie tree containing topic subscribers and retained messages.
type Index struct {
mu sync.RWMutex // a mutex for locking the whole index.
Root *Leaf // a leaf containing a message and more leaves.
}
// New returns a pointer to a new instance of Index.
func New() *Index {
return &Index{
Root: &Leaf{
Leaves: make(map[string]*Leaf),
Clients: make(map[string]byte),
},
}
}
// RetainMessage saves a message payload to the end of a topic branch. Returns
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *Index) RetainMessage(msg packets.Packet) int64 {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(msg.TopicName)
// If there is a payload, we can store it.
if len(msg.Payload) > 0 {
n.Message = msg
return 1
}
// Otherwise, we are unsetting it.
// If there was a previous retained message, return -1 instead of 0.
var r int64 = 0
if len(n.Message.Payload) > 0 && n.Message.FixedHeader.Retain == true {
r = -1
}
x.unpoperate(msg.TopicName, "", true)
return r
}
// Subscribe creates a subscription filter for a client. Returns true if the
// subscription was new.
func (x *Index) Subscribe(filter, client string, qos byte) bool {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(filter)
_, ok := n.Clients[client]
n.Clients[client] = qos
n.Filter = filter
return !ok
}
// Unsubscribe removes a subscription filter for a client. Returns true if an
// unsubscribe action successful and the subscription existed.
func (x *Index) Unsubscribe(filter, client string) bool {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(filter)
_, ok := n.Clients[client]
return x.unpoperate(filter, client, false) && ok
}
// unpoperate steps backward through a trie sequence and removes any orphaned
// nodes. If a client id is specified, it will unsubscribe a client. If message
// is true, it will delete a retained message.
func (x *Index) unpoperate(filter string, client string, message bool) bool {
var d int // Walk to end leaf.
var particle string
var hasNext = true
e := x.Root
for hasNext {
particle, hasNext = isolateParticle(filter, d)
d++
e, _ = e.Leaves[particle]
// If the topic part doesn't exist in the tree, there's nothing
// left to do.
if e == nil {
return false
}
}
// Step backward removing client and orphaned leaves.
var key string
var orphaned bool
var end = true
for e.Parent != nil {
key = e.Key
// Wipe the client from this leaf if it's the filter end.
if end {
if client != "" {
delete(e.Clients, client)
}
if message {
e.Message = packets.Packet{}
}
end = false
}
// If this leaf is empty, note it as orphaned.
orphaned = len(e.Clients) == 0 && len(e.Leaves) == 0 && !e.Message.FixedHeader.Retain
// Traverse up the branch.
e = e.Parent
// If the leaf we just came from was empty, delete it.
if orphaned {
delete(e.Leaves, key)
}
}
return true
}
// poperate iterates and populates through a topic/filter path, instantiating
// leaves as it goes and returning the final leaf in the branch.
// poperate is a more enjoyable word than iterpop.
func (x *Index) poperate(topic string) *Leaf {
var d int
var particle string
var hasNext = true
n := x.Root
for hasNext {
particle, hasNext = isolateParticle(topic, d)
d++
child, _ := n.Leaves[particle]
if child == nil {
child = &Leaf{
Key: particle,
Parent: n,
Leaves: make(map[string]*Leaf),
Clients: make(map[string]byte),
}
n.Leaves[particle] = child
}
n = child
}
return n
}
// Subscribers returns a map of clients who are subscribed to matching filters.
func (x *Index) Subscribers(topic string) Subscriptions {
x.mu.RLock()
defer x.mu.RUnlock()
return x.Root.scanSubscribers(topic, 0, make(Subscriptions))
}
// Messages returns a slice of retained topic messages which match a filter.
func (x *Index) Messages(filter string) []packets.Packet {
// ReLeaf("messages", x.Root, 0)
x.mu.RLock()
defer x.mu.RUnlock()
return x.Root.scanMessages(filter, 0, make([]packets.Packet, 0, 32))
}
// Leaf is a child node on the tree.
type Leaf struct {
Message packets.Packet // a message which has been retained for a specific topic.
Key string // the key that was used to create the leaf.
Filter string // the path of the topic filter being matched.
Parent *Leaf // a pointer to the parent node for the leaf.
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
Clients map[string]byte // a map of client ids subscribed to the topic.
}
// scanSubscribers recursively steps through a branch of leaves finding clients who
// have subscription filters matching a topic, and their highest QoS byte.
func (l *Leaf) scanSubscribers(topic string, d int, clients Subscriptions) Subscriptions {
part, hasNext := isolateParticle(topic, d)
// For either the topic part, a +, or a #, follow the branch.
for _, particle := range []string{part, "+", "#"} {
// Topics beginning with the reserved $ character are restricted from
// being returned for top level wildcards.
if d == 0 && len(part) > 0 && part[0] == '$' && (particle == "+" || particle == "#") {
continue
}
if child, ok := l.Leaves[particle]; ok {
// We're only interested in getting clients from the final
// element in the topic, or those with wildhashes.
if !hasNext || particle == "#" {
// Capture the highest QOS byte for any client with a filter
// matching the topic.
for client, qos := range child.Clients {
if ex, ok := clients[client]; !ok || ex < qos {
clients[client] = qos
}
}
// Make sure we also capture any client who are listening
// to this topic via path/#
if !hasNext {
if extra, ok := child.Leaves["#"]; ok {
for client, qos := range extra.Clients {
if ex, ok := clients[client]; !ok || ex < qos {
clients[client] = qos
}
}
}
}
}
// If this branch has hit a wildhash, just return immediately.
if particle == "#" {
return clients
} else if hasNext {
clients = child.scanSubscribers(topic, d+1, clients)
}
}
}
return clients
}
// scanMessages recursively steps through a branch of leaves finding retained messages
// that match a topic filter. Setting `d` to -1 will enable wildhash mode, and will
// recursively check ALL child leaves in every subsequent branch.
func (l *Leaf) scanMessages(filter string, d int, messages []packets.Packet) []packets.Packet {
// If a wildhash mode has been set, continue recursively checking through all
// child leaves regardless of their particle key.
if d == -1 {
for _, child := range l.Leaves {
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
messages = child.scanMessages(filter, -1, messages)
}
return messages
}
// Otherwise, we'll get the particle for d in the filter.
particle, hasNext := isolateParticle(filter, d)
// If there's no more particles after this one, then take the messages from
// these topics.
if !hasNext {
// Wildcards and Wildhashes must be checked first, otherwise they
// may be detected as standard particles, and not act properly.
if particle == "+" || particle == "#" {
// Otherwise, if it's a wildcard or wildhash, get messages from all
// the child leaves. This wildhash captures messages on the actual
// wildhash position, whereas the d == -1 block collects subsequent
// messages further down the branch.
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
}
} else if child, ok := l.Leaves[particle]; ok {
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
}
} else {
// If it's not the last particle, branch out to the next leaves, scanning
// all available if it's a wildcard, or just one if it's a specific particle.
if particle == "+" {
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
messages = child.scanMessages(filter, d+1, messages)
}
} else if child, ok := l.Leaves[particle]; ok {
messages = child.scanMessages(filter, d+1, messages)
}
}
// If the particle was a wildhash, scan all the child leaves setting the
// d value to wildhash mode.
if particle == "#" {
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
messages = child.scanMessages(filter, -1, messages)
}
}
return messages
}
// isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
for i := 0; end > -1 && i <= d; i++ {
end = strings.IndexRune(filter, '/')
if d > -1 && i == d && end > -1 {
hasNext = true
particle = filter[next:end]
} else if end > -1 {
hasNext = false
filter = filter[end+1:]
} else {
hasNext = false
particle = filter[next:]
}
}
return
}
// ReLeaf is a dev function for showing the trie leafs.
/*
func ReLeaf(m string, leaf *Leaf, d int) {
for k, v := range leaf.Leaves {
fmt.Println(m, d, strings.Repeat(" ", d), k)
ReLeaf(m, v, d+1)
}
}
*/

View File

@ -0,0 +1,494 @@
package topics
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/server/internal/packets"
)
func TestNew(t *testing.T) {
index := New()
require.NotNil(t, index)
require.NotNil(t, index.Root)
}
func BenchmarkNew(b *testing.B) {
for n := 0; n < b.N; n++ {
New()
}
}
func TestPoperate(t *testing.T) {
index := New()
child := index.poperate("path/to/my/mqtt")
require.Equal(t, "mqtt", child.Key)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
child = index.poperate("a/b/c/d/e")
require.Equal(t, "e", child.Key)
child = index.poperate("a/b/c/c/a")
require.Equal(t, "a", child.Key)
}
func BenchmarkPoperate(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.poperate("path/to/my/mqtt")
}
}
func TestUnpoperate(t *testing.T) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
pk := packets.Packet{TopicName: "path/to/retained/message", Payload: []byte{'h', 'e', 'l', 'l', 'o'}}
index.RetainMessage(pk)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"].Message)
pk2 := packets.Packet{TopicName: "path/to/my/mqtt", Payload: []byte{'s', 'h', 'a', 'r', 'e', 'd'}}
index.RetainMessage(pk2)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
index.unpoperate("path/to/my/mqtt", "", true) // delete retained
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message.FixedHeader.Retain)
index.unpoperate("path/to/my/mqtt", "client-1", false) // unsubscribe client
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
index.unpoperate("path/to/retained/message", "", true) // delete retained
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves, "my")
index.unpoperate("path/to/whatever", "client-1", false) // unsubscribe client
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
//require.Empty(t, index.Root.Leaves["path"])
}
func BenchmarkUnpoperate(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.poperate("path/to/my/mqtt")
}
}
func TestRetainMessage(t *testing.T) {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
TopicName: "path/to/my/mqtt",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
}
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
TopicName: "path/to/another/mqtt",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
}
index := New()
q := index.RetainMessage(pk)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients["client-1"])
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
// The same message already exists, but we're not doing a deep-copy check, so it's considered
// to be a new message.
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
// Delete retained
pk3 := packets.Packet{TopicName: "path/to/another/mqtt", Payload: []byte{}}
q = index.RetainMessage(pk3)
require.Equal(t, int64(-1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
// Second Delete retained
q = index.RetainMessage(pk3)
require.Equal(t, int64(0), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
}
func BenchmarkRetainMessage(b *testing.B) {
index := New()
pk := packets.Packet{TopicName: "path/to/another/mqtt"}
for n := 0; n < b.N; n++ {
index.RetainMessage(pk)
}
}
func TestSubscribeOK(t *testing.T) {
index := New()
q := index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Equal(t, false, q)
q = index.Subscribe("path/to/my/mqtt", "client-2", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/+", "client-2", 0)
require.Equal(t, true, q)
q = index.Subscribe("#", "client-3", 0)
require.Equal(t, true, q)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Equal(t, "path/to/my/mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Filter)
require.Equal(t, "mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Key)
require.Equal(t, index.Root.Leaves["path"], index.Root.Leaves["path"].Leaves["to"].Parent)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["+"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
}
func BenchmarkSubscribe(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.Subscribe("path/to/mqtt/basic", "client-1", 0)
}
}
func TestUnsubscribeA(t *testing.T) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Subscribe("path/to/+/mqtt", "client-1", 0)
index.Subscribe("path/to/stuff", "client-1", 0)
index.Subscribe("path/to/stuff", "client-2", 0)
index.Subscribe("#", "client-3", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
ok := index.Unsubscribe("path/to/my/mqtt", "client-1")
require.Equal(t, true, ok)
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
ok = index.Unsubscribe("path/to/stuff", "client-1")
require.Equal(t, true, ok)
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "client-1")
require.Equal(t, false, ok)
}
func TestUnsubscribeCascade(t *testing.T) {
index := New()
index.Subscribe("a/b/c", "client-1", 0)
index.Subscribe("a/b/c/e/e", "client-1", 0)
ok := index.Unsubscribe("a/b/c/e/e", "client-1")
require.Equal(t, true, ok)
require.NotEmpty(t, index.Root.Leaves)
require.Contains(t, index.Root.Leaves["a"].Leaves["b"].Leaves["c"].Clients, "client-1")
}
// This benchmark is Unsubscribe-Subscribe
func BenchmarkUnsubscribe(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Unsubscribe("path/to/mqtt/basic", "client-1")
}
}
func TestSubscribersFind(t *testing.T) {
tt := []struct {
filter string
topic string
len int
}{
{
filter: "a",
topic: "a",
len: 1,
},
{
filter: "a/",
topic: "a",
len: 0,
},
{
filter: "a/",
topic: "a/",
len: 1,
},
{
filter: "/a",
topic: "/a",
len: 1,
},
{
filter: "path/to/my/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "path/to/+/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/to/+/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/+/+/+",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/+/+/#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "zen/#",
topic: "zen",
len: 1,
},
{
filter: "+/+/#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "path/to/",
topic: "path/to/my/mqtt",
len: 0,
},
{
filter: "#/stuff",
topic: "path/to/my/mqtt",
len: 0,
},
{
filter: "$SYS/#",
topic: "$SYS/info",
len: 1,
},
{
filter: "#",
topic: "$SYS/info",
len: 0,
},
{
filter: "+/info",
topic: "$SYS/info",
len: 0,
},
}
for i, check := range tt {
index := New()
index.Subscribe(check.filter, "client-1", 0)
clients := index.Subscribers(check.topic)
//spew.Dump(clients)
require.Equal(t, check.len, len(clients), "Unexpected clients len at %d %s %s", i, check.filter, check.topic)
}
}
func BenchmarkSubscribers(b *testing.B) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Subscribe("path/to/+/mqtt", "client-1", 0)
index.Subscribe("something/things/stuff/+", "client-1", 0)
index.Subscribe("path/to/stuff", "client-2", 0)
index.Subscribe("#", "client-3", 0)
for n := 0; n < b.N; n++ {
index.Subscribers("path/to/testing/mqtt")
}
}
func TestIsolateParticle(t *testing.T) {
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
require.Equal(t, "to", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
require.Equal(t, "my", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
require.Equal(t, "mqtt", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("/path/", 0)
require.Equal(t, "", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 1)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 2)
require.Equal(t, "", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
require.Equal(t, "+", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
require.Equal(t, "+", particle)
require.Equal(t, false, hasNext)
}
func BenchmarkIsolateParticle(b *testing.B) {
for n := 0; n < b.N; n++ {
isolateParticle("path/to/my/mqtt", 3)
}
}
func TestMessagesPattern(t *testing.T) {
tt := []struct {
packet packets.Packet
filter string
len int
}{
{
packets.Packet{TopicName: "a/b/c/d", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"a/b/c/d",
1,
},
{
packets.Packet{TopicName: "a/b/c/e", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"a/+/c/+",
2,
},
{
packets.Packet{TopicName: "a/b/d/f", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"+/+/+/+",
3,
},
{
packets.Packet{TopicName: "q/w/e/r/t/y", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"q/w/e/#",
1,
},
{
packets.Packet{TopicName: "q/w/x/r/t/x", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"q/#",
2,
},
{
packets.Packet{TopicName: "asd", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"asd",
1,
},
{
packets.Packet{TopicName: "$SYS/testing", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"#",
8,
},
{
packets.Packet{TopicName: "$SYS/test", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"+/testing",
0,
},
{
packets.Packet{TopicName: "$SYS/info", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"$SYS/info",
1,
},
{
packets.Packet{TopicName: "$SYS/b", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"$SYS/#",
4,
},
{
packets.Packet{TopicName: "asd/fgh/jkl", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"#",
8,
},
{
packets.Packet{TopicName: "stuff/asdadsa/dsfdsafdsadfsa/dsfdsf/sdsadas", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"stuff/#/things", // indexer will ignore trailing /things
1,
},
}
index := New()
for _, check := range tt {
index.RetainMessage(check.packet)
}
for i, check := range tt {
messages := index.Messages(check.filter)
require.Equal(t, check.len, len(messages), "Unexpected messages len at %d %s %s", i, check.filter, check.packet.TopicName)
}
}
func TestMessagesFind(t *testing.T) {
index := New()
index.RetainMessage(packets.Packet{TopicName: "a/a", Payload: []byte{'a'}, FixedHeader: packets.FixedHeader{Retain: true}})
index.RetainMessage(packets.Packet{TopicName: "a/b", Payload: []byte{'b'}, FixedHeader: packets.FixedHeader{Retain: true}})
messages := index.Messages("a/a")
require.Equal(t, 1, len(messages))
messages = index.Messages("a/+")
require.Equal(t, 2, len(messages))
}
func BenchmarkMessages(b *testing.B) {
index := New()
index.RetainMessage(packets.Packet{TopicName: "path/to/my/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "path/to/another/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "path/a/some/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "what/is"})
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
for n := 0; n < b.N; n++ {
index.Messages("path/to/+/mqtt")
}
}

View File

@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}

View File

@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}

View File

@ -0,0 +1,12 @@
package auth
// Controller is an interface for authentication controllers.
type Controller interface {
// Authenticate authenticates a user on CONNECT and returns true if a user is
// allowed to join the server.
Authenticate(user, password []byte) bool
// ACL returns true if a user has read or write access to a given topic.
ACL(user []byte, topic string, write bool) bool
}

View File

@ -0,0 +1,31 @@
package auth
// Allow is an auth controller which allows access to all connections and topics.
type Allow struct{}
// Authenticate returns true if a username and password are acceptable. Allow always
// returns true.
func (a *Allow) Authenticate(user, password []byte) bool {
return true
}
// ACL returns true if a user has access permissions to read or write on a topic.
// Allow always returns true.
func (a *Allow) ACL(user []byte, topic string, write bool) bool {
return true
}
// Disallow is an auth controller which disallows access to all connections and topics.
type Disallow struct{}
// Authenticate returns true if a username and password are acceptable. Disallow always
// returns false.
func (d *Disallow) Authenticate(user, password []byte) bool {
return false
}
// ACL returns true if a user has access permissions to read or write on a topic.
// Disallow always returns false.
func (d *Disallow) ACL(user []byte, topic string, write bool) bool {
return false
}

View File

@ -0,0 +1,55 @@
package auth
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAllowAuth(t *testing.T) {
ac := new(Allow)
require.Equal(t, true, ac.Authenticate([]byte("user"), []byte("pass")))
}
func BenchmarkAllowAuth(b *testing.B) {
ac := new(Allow)
for n := 0; n < b.N; n++ {
ac.Authenticate([]byte("user"), []byte("pass"))
}
}
func TestAllowACL(t *testing.T) {
ac := new(Allow)
require.Equal(t, true, ac.ACL([]byte("user"), "topic", true))
}
func BenchmarkAllowACL(b *testing.B) {
ac := new(Allow)
for n := 0; n < b.N; n++ {
ac.ACL([]byte("user"), "pass", true)
}
}
func TestDisallowAuth(t *testing.T) {
ac := new(Disallow)
require.Equal(t, false, ac.Authenticate([]byte("user"), []byte("pass")))
}
func BenchmarkDisallowAuth(b *testing.B) {
ac := new(Disallow)
for n := 0; n < b.N; n++ {
ac.Authenticate([]byte("user"), []byte("pass"))
}
}
func TestDisallowACL(t *testing.T) {
ac := new(Disallow)
require.Equal(t, false, ac.ACL([]byte("user"), "topic", true))
}
func BenchmarkDisallowACL(b *testing.B) {
ac := new(Disallow)
for n := 0; n < b.N; n++ {
ac.ACL([]byte("user"), "pass", true)
}
}

View File

@ -0,0 +1,124 @@
package listeners
import (
"context"
"crypto/tls"
"encoding/json"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
system *system.Info // pointers to the server data.
listen *http.Server // the http server.
end uint32 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string) *HTTPStats {
return &HTTPStats{
id: id,
address: address,
config: &Config{
Auth: new(auth.Allow),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *HTTPStats) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *HTTPStats) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *HTTPStats) Listen(s *system.Info) error {
l.system = s
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFunc) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info, err := json.MarshalIndent(l.system, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
return
}
w.Write(info)
}

View File

@ -0,0 +1,198 @@
package listeners
import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
"github.com/stretchr/testify/require"
)
func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewHTTPStats(b *testing.B) {
for n := 0; n < b.N; n++ {
NewHTTPStats("t1", testPort)
}
}
func TestHTTPStatsSetConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkHTTPStatsSetConfig(b *testing.B) {
l := NewHTTPStats("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkHTTPStatsID(b *testing.B) {
l := NewHTTPStats("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestHTTPStatsListen(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.system)
require.NotNil(t, l.listen)
require.Equal(t, testPort, l.listen.Addr)
l.listen.Close()
}
func TestHTTPStatsListenTLSConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLS(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLSInvalid(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(new(system.Info))
require.Error(t, err)
}
func TestHTTPStatsServeAndClose(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
resp, err := http.Get("http://localhost" + testPort)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost" + testPort)
require.Error(t, err)
<-o
}
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestHTTPStatsJSONHandler(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
w := httptest.NewRecorder()
l.jsonHandler(w, nil)
resp := w.Result()
body, _ := ioutil.ReadAll(resp.Body)
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
}

View File

@ -0,0 +1,152 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// Config contains configuration values for a listener.
type Config struct {
// Auth controller containing auth and ACL logic for
// allowing or denying access to the server and topics.
Auth auth.Controller
// TLS certficates and settings for the connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
TLS *TLS
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// TLS contains the TLS certificates and settings for the listener connection.
//
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
// Please use TLSConfig instead.
type TLS struct {
Certificate []byte // the body of a public certificate.
PrivateKey []byte // the body of a private key.
}
// EstablishFunc is a callback function for establishing new clients.
type EstablishFunc func(id string, c net.Conn, ac auth.Controller) error
// CloseFunc is a callback function for closing all listener clients.
type CloseFunc func(id string)
// Listener is an interface for network listeners. A network listener listens
// for incoming client connections and adds them to the server.
type Listener interface {
SetConfig(*Config) // set the listener config.
Listen(s *system.Info) error // open the network address.
Serve(EstablishFunc) // starting actively listening for new connections.
ID() string // return the id of the listener.
Close(CloseFunc) // stop and close the listener.
}
// Listeners contains the network listeners for the broker.
type Listeners struct {
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
system *system.Info // pointers to system info.
sync.RWMutex
}
// New returns a new instance of Listeners.
func New(s *system.Info) *Listeners {
return &Listeners{
internal: map[string]Listener{},
system: s,
}
}
// Add adds a new listener to the listeners map, keyed on id.
func (l *Listeners) Add(val Listener) {
l.Lock()
l.internal[val.ID()] = val
l.Unlock()
}
// Get returns the value of a listener if it exists.
func (l *Listeners) Get(id string) (Listener, bool) {
l.RLock()
val, ok := l.internal[id]
l.RUnlock()
return val, ok
}
// Len returns the length of the listeners map.
func (l *Listeners) Len() int {
l.RLock()
val := len(l.internal)
l.RUnlock()
return val
}
// Delete removes a listener from the internal map.
func (l *Listeners) Delete(id string) {
l.Lock()
delete(l.internal, id)
l.Unlock()
}
// Serve starts a listener serving from the internal map.
func (l *Listeners) Serve(id string, establisher EstablishFunc) {
l.RLock()
listener := l.internal[id]
l.RUnlock()
go func(e EstablishFunc) {
defer l.wg.Done()
l.wg.Add(1)
listener.Serve(e)
}(establisher)
}
// ServeAll starts all listeners serving from the internal map.
func (l *Listeners) ServeAll(establisher EstablishFunc) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Serve(id, establisher)
}
}
// Close stops a listener from the internal map.
func (l *Listeners) Close(id string, closer CloseFunc) {
l.RLock()
listener := l.internal[id]
l.RUnlock()
listener.Close(closer)
}
// CloseAll iterates and closes all registered listeners.
func (l *Listeners) CloseAll(closer CloseFunc) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Close(id, closer)
}
l.wg.Wait()
}

View File

@ -0,0 +1,245 @@
package listeners
import (
"crypto/tls"
"log"
"testing"
"time"
"github.com/stretchr/testify/require"
)
var (
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New(nil)
require.NotNil(t, l.internal)
}
func BenchmarkNewListeners(b *testing.B) {
for n := 0; n < b.N; n++ {
New(nil)
}
}
func TestAddListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
require.Contains(t, l.internal, "t1")
}
func BenchmarkAddListener(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
for n := 0; n < b.N; n++ {
l.Add(mocked)
}
}
func TestGetListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
g, ok := l.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, g.ID(), "t1")
}
func BenchmarkGetListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Get("t1")
}
}
func TestLenListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
require.Equal(t, 2, l.Len())
}
func BenchmarkLenListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Len()
}
}
func TestDeleteListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
require.Contains(t, l.internal, "t1")
l.Delete("t1")
_, ok := l.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, l.internal["t1"])
}
func BenchmarkDeleteListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Delete("t1")
}
}
func TestServeListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing())
}
func BenchmarkServeListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Serve("t1", MockEstablisher)
}
}
func TestServeAllListeners(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882"))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
l.Close("t2", MockCloser)
l.Close("t3", MockCloser)
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing())
require.Equal(t, false, l.internal["t2"].(*MockListener).IsServing())
require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing())
}
func BenchmarkServeAllListeners(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1883"))
l.Add(NewMockListener("t3", ":1884"))
for n := 0; n < b.N; n++ {
l.ServeAll(MockEstablisher)
}
}
func TestCloseListener(t *testing.T) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
var closed bool
l.Close("t1", func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func BenchmarkCloseListener(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
for n := 0; n < b.N; n++ {
l.internal["t1"].(*MockListener).done = make(chan bool)
l.Close("t1", MockCloser)
}
}
func TestCloseAllListeners(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882"))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing())
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing())
closed := make(map[string]bool)
l.CloseAll(func(id string) {
closed[id] = true
})
require.Contains(t, closed, "t1")
require.Contains(t, closed, "t2")
require.Contains(t, closed, "t3")
require.Equal(t, true, closed["t1"])
require.Equal(t, true, closed["t2"])
require.Equal(t, true, closed["t3"])
}
func BenchmarkCloseAllListeners(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
for n := 0; n < b.N; n++ {
l.internal["t1"].(*MockListener).done = make(chan bool)
l.Close("t1", MockCloser)
}
}

100
server/listeners/mock.go Normal file
View File

@ -0,0 +1,100 @@
package listeners
import (
"fmt"
"net"
"sync"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// MockCloser is a function signature which can be used in testing.
func MockCloser(id string) {}
// MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
return nil
}
// MockListener is a mock listener for establishing client connections.
type MockListener struct {
sync.RWMutex
id string // the id of the listener.
address string // the network address the listener binds to.
Config *Config // configuration for the listener.
done chan bool // indicate the listener is done.
Serving bool // indicate the listener is serving.
Listening bool // indiciate the listener is listening.
ErrListen bool // throw an error on listen.
}
// NewMockListener returns a new instance of MockListener
func NewMockListener(id, address string) *MockListener {
return &MockListener{
id: id,
address: address,
done: make(chan bool),
}
}
// Serve serves the mock listener.
func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for range l.done {
return
}
}
// Listen begins listening for incoming traffic.
func (l *MockListener) Listen(s *system.Info) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}
l.Lock()
l.Listening = true
l.Unlock()
return nil
}
// SetConfig sets the configuration values of the mock listener.
func (l *MockListener) SetConfig(config *Config) {
l.Lock()
l.Config = config
l.Unlock()
}
// ID returns the id of the mock listener.
func (l *MockListener) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Close closes the mock listener.
func (l *MockListener) Close(closer CloseFunc) {
l.Lock()
defer l.Unlock()
l.Serving = false
closer(l.id)
close(l.done)
}
// IsServing indicates whether the mock listener is serving.
func (l *MockListener) IsServing() bool {
l.Lock()
defer l.Unlock()
return l.Serving
}
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()
return l.Listening
}

View File

@ -0,0 +1,89 @@
package listeners
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w, new(auth.Allow))
require.NoError(t, err)
w.Close()
}
func TestNewMockListener(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, "t1", mocked.id)
require.Equal(t, ":1882", mocked.address)
}
func TestNewMockListenerListen(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, "t1", mocked.id)
require.Equal(t, ":1882", mocked.address)
require.Equal(t, false, mocked.IsListening())
err := mocked.Listen(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening())
}
func TestNewMockListenerListenFailure(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
mocked.ErrListen = true
err := mocked.Listen(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsServing())
o := make(chan bool)
go func(o chan bool) {
mocked.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
require.Equal(t, true, mocked.IsServing())
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
mocked.Listen(nil)
}
func TestMockListenerSetConfig(t *testing.T) {
mocked := NewMockListener("t1", ":1883")
mocked.SetConfig(new(Config))
require.NotNil(t, mocked.Config)
}
func TestMockListenerClose(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestNewMockListenerIsListening(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsListening())
}
func TestNewMockListenerIsServing(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsServing())
}

127
server/listeners/tcp.go Normal file
View File

@ -0,0 +1,127 @@
package listeners
import (
"crypto/tls"
"net"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct {
sync.RWMutex
id string // the internal id of the listener.
protocol string // the TCP protocol to use.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
config *Config // configuration values for the listener.
end uint32 // ensure the close methods are only called once.
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
func NewTCP(id, address string) *TCP {
return &TCP{
id: id,
protocol: "tcp",
address: address,
config: &Config{ // default configuration.
Auth: new(auth.Allow),
TLS: new(TLS),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *TCP) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *TCP) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *TCP) Listen(s *system.Info) error {
var err error
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
} else if l.config.TLSConfig != nil {
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen(l.protocol, l.address)
}
if err != nil {
return err
}
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFunc) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
_ = establish(l.id, conn, l.config.Auth)
}()
}
}
}
// Close closes the listener and any client connections.
func (l *TCP) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@ -0,0 +1,228 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/stretchr/testify/require"
)
const (
testPort = ":22222"
)
func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewTCP(b *testing.B) {
for n := 0; n < b.N; n++ {
NewTCP("t1", testPort)
}
}
func TestTCPSetConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkTCPSetConfig(b *testing.B) {
l := NewTCP("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestTCPID(t *testing.T) {
l := NewTCP("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkTCPID(b *testing.B) {
l := NewTCP("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestTCPListen(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
l2 := NewTCP("t2", testPort)
err = l2.Listen(nil)
require.Error(t, err)
l.listen.Close()
}
func TestTCPListenTLSConfig(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLS(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
l.listen.Close()
}
func TestTCPListenTLSInvalid(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.Error(t, err)
}
func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPCloseError(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.listen.Close()
l.Close(MockCloser)
<-o
}
func TestTCPServeEnd(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
l.Close(MockCloser)
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
return nil
})
}
func TestTCPEstablishThenError(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
established <- true
return errors.New("testing") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial(l.protocol, l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}
func TestTCPEstablishButEnding(t *testing.T) {
l := NewTCP("t1", testPort)
err := l.Listen(nil)
require.NoError(t, err)
l.end = 1
o := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn, ac auth.Controller) error {
return nil
})
o <- true
}()
net.Dial(l.protocol, l.listen.Addr().String())
time.Sleep(time.Millisecond)
l.Close(MockCloser)
<-o
}

View File

@ -0,0 +1,178 @@
package listeners
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
var (
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
// Websocket is a listener for establishing websocket connections.
type Websocket struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
listen *http.Server // an http server for serving websocket connections.
establish EstablishFunc // the server's establish connection handler.
end uint32 // ensure the close methods are only called once.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
// Inspired by
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns
// the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
}
return r.Read(p)
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
}
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string) *Websocket {
return &Websocket{
id: id,
address: address,
config: &Config{
Auth: new(auth.Allow),
TLS: new(TLS),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *Websocket) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *Websocket) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *Websocket) Listen(s *system.Info) error {
mux := http.NewServeMux()
mux.HandleFunc("/", l.handler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
}
// The following logic is deprecated in favour of passing through the tls.Config
// value directly, however it remains in order to provide backwards compatibility.
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
} else {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := wsUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
l.establish(l.id, &wsConn{c.UnderlyingConn(), c}, l.config.Auth)
}
// Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFunc) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}

View File

@ -0,0 +1,181 @@
package listeners
import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/stretchr/testify/require"
)
func TestWsConnClose(t *testing.T) {
r, _ := net.Pipe()
ws := &wsConn{r, new(websocket.Conn)}
err := ws.Close()
require.NoError(t, err)
}
func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewWebsocket(b *testing.B) {
for n := 0; n < b.N; n++ {
NewWebsocket("t1", testPort)
}
}
func TestWebsocketSetConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkWebsocketSetConfig(b *testing.B) {
l := NewWebsocket("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkWebsocketID(b *testing.B) {
l := NewWebsocket("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestWebsocketListen(t *testing.T) {
l := NewWebsocket("t1", testPort)
require.Nil(t, l.listen)
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen)
}
func TestWebsocketListenTLSConfig(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLSConfig: tlsConfigBasic,
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLS(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestWebsocketListenTLSInvalid(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.Error(t, err)
}
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.Listen(nil)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testPort)
l.Listen(nil)
e := make(chan bool)
l.establish = func(id string, c net.Conn, ac auth.Controller) error {
e <- true
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
u := "ws" + strings.TrimPrefix(s.URL, "http")
ws, _, err := websocket.DefaultDialer.Dial(u, nil)
require.NoError(t, err)
require.Equal(t, true, <-e)
s.Close()
ws.Close()
}

View File

@ -0,0 +1,266 @@
package bolt
import (
"fmt"
"time"
sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3"
"go.etcd.io/bbolt"
"github.com/mochi-co/mqtt/server/persistence"
)
const (
// defaultPath is the default file to use to store the data.
defaultPath = "mochi.db"
// defaultTimeout is the default timeout of the file lock.
defaultTimeout = 250 * time.Millisecond
)
var (
// ErrDBNotOpen indicates the bolt db file is not open for reading.
ErrDBNotOpen = fmt.Errorf("boltdb not opened")
)
// Store is a backend for writing and reading to bolt persistent storage.
type Store struct {
path string // the path on which to store the db file.
opts *bbolt.Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
}
// New returns a configured instance of the boltdb store.
func New(path string, opts *bbolt.Options) *Store {
if path == "" || path == "." {
path = defaultPath
}
if opts == nil {
opts = &bbolt.Options{
Timeout: defaultTimeout,
}
}
return &Store{
path: path,
opts: opts,
}
}
// Open opens the boltdb instance.
func (s *Store) Open() error {
var err error
s.db, err = storm.Open(s.path, storm.BoltOptions(0600, s.opts), storm.Codec(sgob.Codec))
if err != nil {
return err
}
return nil
}
// Close closes the boltdb instance.
func (s *Store) Close() {
s.db.Close()
}
// WriteServerInfo writes the server info to the boltdb instance.
func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.Save(&v)
if err != nil {
return err
}
return nil
}
// WriteSubscription writes a single subscription to the boltdb instance.
func (s *Store) WriteSubscription(v persistence.Subscription) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.Save(&v)
if err != nil {
return err
}
return nil
}
// WriteInflight writes a single inflight message to the boltdb instance.
func (s *Store) WriteInflight(v persistence.Message) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.Save(&v)
if err != nil {
return err
}
return nil
}
// WriteRetained writes a single retained message to the boltdb instance.
func (s *Store) WriteRetained(v persistence.Message) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.Save(&v)
if err != nil {
return err
}
return nil
}
// WriteClient writes a single client to the boltdb instance.
func (s *Store) WriteClient(v persistence.Client) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.Save(&v)
if err != nil {
return err
}
return nil
}
// DeleteSubscription deletes a subscription from the boltdb instance.
func (s *Store) DeleteSubscription(id string) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Subscription{
ID: id,
})
if err != nil {
return err
}
return nil
}
// DeleteClient deletes a client from the boltdb instance.
func (s *Store) DeleteClient(id string) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Client{
ID: id,
})
if err != nil {
return err
}
return nil
}
// DeleteInflight deletes an inflight message from the boltdb instance.
func (s *Store) DeleteInflight(id string) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Message{
ID: id,
})
if err != nil {
return err
}
return nil
}
// DeleteRetained deletes a retained message from the boltdb instance.
func (s *Store) DeleteRetained(id string) error {
if s.db == nil {
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Message{
ID: id,
})
if err != nil {
return err
}
return nil
}
// ReadSubscriptions loads all the subscriptions from the boltdb instance.
func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
if s.db == nil {
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KSubscription, &v)
if err != nil && err != storm.ErrNotFound {
return
}
return v, nil
}
// ReadClients loads all the clients from the boltdb instance.
func (s *Store) ReadClients() (v []persistence.Client, err error) {
if s.db == nil {
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KClient, &v)
if err != nil && err != storm.ErrNotFound {
return
}
return v, nil
}
// ReadInflight loads all the inflight messages from the boltdb instance.
func (s *Store) ReadInflight() (v []persistence.Message, err error) {
if s.db == nil {
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KInflight, &v)
if err != nil && err != storm.ErrNotFound {
return
}
return v, nil
}
// ReadRetained loads all the retained messages from the boltdb instance.
func (s *Store) ReadRetained() (v []persistence.Message, err error) {
if s.db == nil {
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KRetained, &v)
if err != nil && err != storm.ErrNotFound {
return
}
return v, nil
}
//ReadServerInfo loads the server info from the boltdb instance.
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
if s.db == nil {
return v, ErrDBNotOpen
}
err = s.db.One("ID", persistence.KServerInfo, &v)
if err != nil && err != storm.ErrNotFound {
return
}
return v, nil
}

View File

@ -0,0 +1,486 @@
package bolt
import (
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.etcd.io/bbolt"
"github.com/mochi-co/mqtt/server/persistence"
"github.com/mochi-co/mqtt/server/system"
)
const tmpPath = "testbolt.db"
func teardown(s *Store, t *testing.T) {
s.Close()
err := os.Remove(tmpPath)
require.NoError(t, err)
}
func TestSatsifies(t *testing.T) {
var x persistence.Store
x = New(tmpPath, &bbolt.Options{
Timeout: 500 * time.Millisecond,
})
require.NotNil(t, x)
}
func TestNew(t *testing.T) {
s := New(tmpPath, &bbolt.Options{
Timeout: 500 * time.Millisecond,
})
require.NotNil(t, s)
require.Equal(t, tmpPath, s.path)
require.Equal(t, 500*time.Millisecond, s.opts.Timeout)
}
func TestNewNoPath(t *testing.T) {
s := New("", nil)
require.NotNil(t, s)
require.Equal(t, defaultPath, s.path)
}
func TestNewNoOpts(t *testing.T) {
s := New("", nil)
require.NotNil(t, s)
require.Equal(t, defaultTimeout, s.opts.Timeout)
}
func TestOpen(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
defer teardown(s, t)
require.NotNil(t, s.db)
}
func TestOpenFailure(t *testing.T) {
s := New("..", nil)
err := s.Open()
require.Error(t, err)
}
func TestWriteAndRetrieveServerInfo(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
defer teardown(s, t)
v := system.Info{
Version: "test",
Started: 100,
}
err = s.WriteServerInfo(persistence.ServerInfo{
Info: v,
ID: persistence.KServerInfo,
})
require.NoError(t, err)
r, err := s.ReadServerInfo()
require.NoError(t, err)
require.NotNil(t, r)
require.Equal(t, v.Version, r.Version)
require.Equal(t, v.Started, r.Started)
}
func TestWriteServerInfoNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.WriteServerInfo(persistence.ServerInfo{})
require.Error(t, err)
}
func TestWriteServerInfoFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.WriteServerInfo(persistence.ServerInfo{})
require.Error(t, err)
}
func TestReadServerInfoNoDB(t *testing.T) {
s := New(tmpPath, nil)
_, err := s.ReadServerInfo()
require.Error(t, err)
}
func TestReadServerInfoFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
_, err = s.ReadServerInfo()
require.Error(t, err)
}
func TestWriteRetrieveDeleteSubscription(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
defer teardown(s, t)
v := persistence.Subscription{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: persistence.KSubscription,
}
err = s.WriteSubscription(v)
require.NoError(t, err)
v2 := persistence.Subscription{
ID: "test:d/e/f",
Client: "test",
Filter: "d/e/f",
QoS: 2,
T: persistence.KSubscription,
}
err = s.WriteSubscription(v2)
require.NoError(t, err)
subs, err := s.ReadSubscriptions()
require.NoError(t, err)
require.Equal(t, persistence.KSubscription, subs[0].T)
require.Equal(t, 2, len(subs))
err = s.DeleteSubscription("test:d/e/f")
require.NoError(t, err)
subs, err = s.ReadSubscriptions()
require.NoError(t, err)
require.Equal(t, 1, len(subs))
}
func TestWriteSubscriptionNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.WriteSubscription(persistence.Subscription{})
require.Error(t, err)
}
func TestWriteSubscriptionFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.WriteSubscription(persistence.Subscription{})
require.Error(t, err)
}
func TestReadSubscriptionNoDB(t *testing.T) {
s := New(tmpPath, nil)
_, err := s.ReadSubscriptions()
require.Error(t, err)
}
func TestReadSubscriptionFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
_, err = s.ReadSubscriptions()
require.Error(t, err)
}
func TestWriteRetrieveDeleteInflight(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
defer teardown(s, t)
v := persistence.Message{
ID: "client1_if_0",
T: persistence.KInflight,
PacketID: 0,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
}
err = s.WriteInflight(v)
require.NoError(t, err)
v2 := persistence.Message{
ID: "client1_if_100",
T: persistence.KInflight,
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
}
err = s.WriteInflight(v2)
require.NoError(t, err)
msgs, err := s.ReadInflight()
require.NoError(t, err)
require.Equal(t, persistence.KInflight, msgs[0].T)
require.Equal(t, 2, len(msgs))
err = s.DeleteInflight("client1_if_100")
require.NoError(t, err)
msgs, err = s.ReadInflight()
require.NoError(t, err)
require.Equal(t, 1, len(msgs))
}
func TestWriteInflightNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.WriteInflight(persistence.Message{})
require.Error(t, err)
}
func TestWriteInflightFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.WriteInflight(persistence.Message{})
require.Error(t, err)
}
func TestReadInflightNoDB(t *testing.T) {
s := New(tmpPath, nil)
_, err := s.ReadInflight()
require.Error(t, err)
}
func TestReadInflightFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
_, err = s.ReadInflight()
require.Error(t, err)
}
func TestWriteRetrieveDeleteRetained(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
defer teardown(s, t)
v := persistence.Message{
ID: "client1_ret_200",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
Retain: true,
},
PacketID: 200,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
}
err = s.WriteRetained(v)
require.NoError(t, err)
v2 := persistence.Message{
ID: "client1_ret_300",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
Retain: true,
},
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
}
err = s.WriteRetained(v2)
require.NoError(t, err)
msgs, err := s.ReadRetained()
require.NoError(t, err)
require.Equal(t, persistence.KRetained, msgs[0].T)
require.Equal(t, true, msgs[0].FixedHeader.Retain)
require.Equal(t, 2, len(msgs))
err = s.DeleteRetained("client1_ret_300")
require.NoError(t, err)
msgs, err = s.ReadRetained()
require.NoError(t, err)
require.Equal(t, 1, len(msgs))
}
func TestWriteRetainedNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.WriteRetained(persistence.Message{})
require.Error(t, err)
}
func TestWriteRetainedFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
err = os.Remove(tmpPath)
require.NoError(t, err)
err = s.WriteRetained(persistence.Message{})
require.Error(t, err)
}
func TestReadRetainedNoDB(t *testing.T) {
s := New(tmpPath, nil)
_, err := s.ReadRetained()
require.Error(t, err)
}
func TestReadRetainedFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
_, err = s.ReadRetained()
require.Error(t, err)
}
func TestWriteRetrieveDeleteClients(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
defer teardown(s, t)
v := persistence.Client{
ID: "cl_client1",
ClientID: "client1",
T: persistence.KClient,
Listener: "tcp1",
Username: []byte{'m', 'o', 'c', 'h', 'i'},
LWT: persistence.LWT{
Topic: "a/b/c",
Message: []byte{'h', 'e', 'l', 'l', 'o'},
Qos: 1,
Retain: true,
},
}
err = s.WriteClient(v)
require.NoError(t, err)
clients, err := s.ReadClients()
require.NoError(t, err)
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, clients[0].Username)
require.Equal(t, "a/b/c", clients[0].LWT.Topic)
v2 := persistence.Client{
ID: "cl_client2",
ClientID: "client2",
T: persistence.KClient,
Listener: "tcp1",
}
err = s.WriteClient(v2)
require.NoError(t, err)
clients, err = s.ReadClients()
require.NoError(t, err)
require.Equal(t, persistence.KClient, clients[0].T)
require.Equal(t, 2, len(clients))
err = s.DeleteClient("cl_client2")
require.NoError(t, err)
clients, err = s.ReadClients()
require.NoError(t, err)
require.Equal(t, 1, len(clients))
}
func TestWriteClientNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.WriteClient(persistence.Client{})
require.Error(t, err)
}
func TestWriteClientFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.WriteClient(persistence.Client{})
require.Error(t, err)
}
func TestReadClientNoDB(t *testing.T) {
s := New(tmpPath, nil)
_, err := s.ReadClients()
require.Error(t, err)
}
func TestReadClientFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
_, err = s.ReadClients()
require.Error(t, err)
}
func TestDeleteSubscriptionNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.DeleteSubscription("a")
require.Error(t, err)
}
func TestDeleteSubscriptionFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.DeleteSubscription("a")
require.Error(t, err)
}
func TestDeleteClientNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.DeleteClient("a")
require.Error(t, err)
}
func TestDeleteClientFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.DeleteClient("a")
require.Error(t, err)
}
func TestDeleteInflightNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.DeleteInflight("a")
require.Error(t, err)
}
func TestDeleteInflightFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.DeleteInflight("a")
require.Error(t, err)
}
func TestDeleteRetainedNoDB(t *testing.T) {
s := New(tmpPath, nil)
err := s.DeleteRetained("a")
require.Error(t, err)
}
func TestDeleteRetainedFail(t *testing.T) {
s := New(tmpPath, nil)
err := s.Open()
require.NoError(t, err)
s.Close()
err = s.DeleteRetained("a")
require.Error(t, err)
}

View File

@ -0,0 +1,291 @@
package persistence
import (
"errors"
"github.com/mochi-co/mqtt/server/system"
)
const (
// KSubscription is the key for subscription data.
KSubscription = "sub"
// KServerInfo is the key for server info data.
KServerInfo = "srv"
// KRetained is the key for retained messages data.
KRetained = "ret"
// KInflight is the key for inflight messages data.
KInflight = "ifm"
// KClient is the key for client data.
KClient = "cl"
)
// Store is an interface which details a persistent storage connector.
type Store interface {
Open() error
Close()
WriteSubscription(v Subscription) error
WriteClient(v Client) error
WriteInflight(v Message) error
WriteServerInfo(v ServerInfo) error
WriteRetained(v Message) error
DeleteSubscription(id string) error
DeleteClient(id string) error
DeleteInflight(id string) error
DeleteRetained(id string) error
ReadSubscriptions() (v []Subscription, err error)
ReadInflight() (v []Message, err error)
ReadRetained() (v []Message, err error)
ReadClients() (v []Client, err error)
ReadServerInfo() (v ServerInfo, err error)
}
// ServerInfo contains information and statistics about the server.
type ServerInfo struct {
system.Info // embed the system info struct.
ID string // the storage key.
}
// Subscription contains the details of a topic filter subscription.
type Subscription struct {
ID string // the storage key.
T string // the type of the stored data.
Client string // the id of the client who the subscription belongs to.
Filter string // the topic filter being subscribed to.
QoS byte // the desired QoS byte.
}
// Message contains the details of a retained or inflight message.
type Message struct {
Payload []byte // the message payload (if retained).
FixedHeader FixedHeader // the header properties of the message.
T string // the type of the stored data.
ID string // the storage key.
Client string // the id of the client who sent the message (if inflight).
TopicName string // the topic the message was sent to (if retained).
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
Resends int // the number of times the message was attempted to be sent (if inflight).
PacketID uint16 // the unique id of the packet (if inflight).
}
// FixedHeader contains the fixed header properties of a message.
type FixedHeader struct {
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Client contains client data that can be persistently stored.
type Client struct {
LWT LWT // the last-will-and-testament message for the client.
Username []byte // the username the client authenticated with.
ID string // the storage key.
ClientID string // the id of the client.
T string // the type of the stored data.
Listener string // the last known listener id for the client
}
// LWT contains details about a clients LWT payload.
type LWT struct {
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}
// MockStore is a mock storage backend for testing.
type MockStore struct {
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
}
// Open opens the storage instance.
func (s *MockStore) Open() error {
if s.FailOpen {
return errors.New("test")
}
s.Opened = true
return nil
}
// Close closes the storage instance.
func (s *MockStore) Close() {
s.Closed = true
}
// WriteSubscription writes a single subscription to the storage instance.
func (s *MockStore) WriteSubscription(v Subscription) error {
if _, ok := s.Fail["write_subs"]; ok {
return errors.New("test")
}
return nil
}
// WriteClient writes a single client to the storage instance.
func (s *MockStore) WriteClient(v Client) error {
if _, ok := s.Fail["write_clients"]; ok {
return errors.New("test")
}
return nil
}
// WriteInFlight writes a single InFlight message to the storage instance.
func (s *MockStore) WriteInflight(v Message) error {
if _, ok := s.Fail["write_inflight"]; ok {
return errors.New("test")
}
return nil
}
// WriteRetained writes a single retained message to the storage instance.
func (s *MockStore) WriteRetained(v Message) error {
if _, ok := s.Fail["write_retained"]; ok {
return errors.New("test")
}
return nil
}
// WriteServerInfo writes server info to the storage instance.
func (s *MockStore) WriteServerInfo(v ServerInfo) error {
if _, ok := s.Fail["write_info"]; ok {
return errors.New("test")
}
return nil
}
// DeleteSubscription deletes a subscription from the persistent store.
func (s *MockStore) DeleteSubscription(id string) error {
if _, ok := s.Fail["delete_subs"]; ok {
return errors.New("test")
}
return nil
}
// DeleteClient deletes a client from the persistent store.
func (s *MockStore) DeleteClient(id string) error {
if _, ok := s.Fail["delete_clients"]; ok {
return errors.New("test")
}
return nil
}
// DeleteInflight deletes an inflight message from the persistent store.
func (s *MockStore) DeleteInflight(id string) error {
if _, ok := s.Fail["delete_inflight"]; ok {
return errors.New("test")
}
return nil
}
// DeleteRetained deletes a retained message from the persistent store.
func (s *MockStore) DeleteRetained(id string) error {
if _, ok := s.Fail["delete_retained"]; ok {
return errors.New("test")
}
return nil
}
// ReadSubscriptions loads the subscriptions from the storage instance.
func (s *MockStore) ReadSubscriptions() (v []Subscription, err error) {
if _, ok := s.Fail["read_subs"]; ok {
return v, errors.New("test_subs")
}
return []Subscription{
{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: KSubscription,
},
}, nil
}
// ReadClients loads the clients from the storage instance.
func (s *MockStore) ReadClients() (v []Client, err error) {
if _, ok := s.Fail["read_clients"]; ok {
return v, errors.New("test_clients")
}
return []Client{
{
ID: "cl_client1",
ClientID: "client1",
T: KClient,
Listener: "tcp1",
},
}, nil
}
// ReadInflight loads the inflight messages from the storage instance.
func (s *MockStore) ReadInflight() (v []Message, err error) {
if _, ok := s.Fail["read_inflight"]; ok {
return v, errors.New("test_inflight")
}
return []Message{
{
ID: "client1_if_100",
T: KInflight,
Client: "client1",
PacketID: 100,
TopicName: "d/e/f",
Payload: []byte{'y', 'e', 's'},
Sent: 200,
Resends: 1,
},
}, nil
}
// ReadRetained loads the retained messages from the storage instance.
func (s *MockStore) ReadRetained() (v []Message, err error) {
if _, ok := s.Fail["read_retained"]; ok {
return v, errors.New("test_retained")
}
return []Message{
{
ID: "client1_ret_200",
T: KRetained,
FixedHeader: FixedHeader{
Retain: true,
},
PacketID: 200,
TopicName: "a/b/c",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
Sent: 100,
Resends: 0,
},
}, nil
}
//ReadServerInfo loads the server info from the storage instance.
func (s *MockStore) ReadServerInfo() (v ServerInfo, err error) {
if _, ok := s.Fail["read_info"]; ok {
return v, errors.New("test_info")
}
return ServerInfo{
system.Info{
Version: "test",
Started: 100,
},
KServerInfo,
}, nil
}

View File

@ -0,0 +1,251 @@
package persistence
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMockStoreOpen(t *testing.T) {
s := new(MockStore)
err := s.Open()
require.NoError(t, err)
require.Equal(t, true, s.Opened)
}
func TestMockStoreOpenFail(t *testing.T) {
s := new(MockStore)
s.FailOpen = true
err := s.Open()
require.Error(t, err)
}
func TestMockStoreClose(t *testing.T) {
s := new(MockStore)
s.Close()
require.Equal(t, true, s.Closed)
}
func TestMockStoreWriteSubscription(t *testing.T) {
s := new(MockStore)
err := s.WriteSubscription(Subscription{})
require.NoError(t, err)
}
func TestMockStoreWriteSubscriptionFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"write_subs": true,
},
}
err := s.WriteSubscription(Subscription{})
require.Error(t, err)
}
func TestMockStoreWriteClient(t *testing.T) {
s := new(MockStore)
err := s.WriteClient(Client{})
require.NoError(t, err)
}
func TestMockStoreWriteClientFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"write_clients": true,
},
}
err := s.WriteClient(Client{})
require.Error(t, err)
}
func TestMockStoreWriteInflight(t *testing.T) {
s := new(MockStore)
err := s.WriteInflight(Message{})
require.NoError(t, err)
}
func TestMockStoreWriteInflightFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"write_inflight": true,
},
}
err := s.WriteInflight(Message{})
require.Error(t, err)
}
func TestMockStoreWriteRetained(t *testing.T) {
s := new(MockStore)
err := s.WriteRetained(Message{})
require.NoError(t, err)
}
func TestMockStoreWriteRetainedFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"write_retained": true,
},
}
err := s.WriteRetained(Message{})
require.Error(t, err)
}
func TestMockStoreWriteServerInfo(t *testing.T) {
s := new(MockStore)
err := s.WriteServerInfo(ServerInfo{})
require.NoError(t, err)
}
func TestMockStoreWriteServerInfoFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"write_info": true,
},
}
err := s.WriteServerInfo(ServerInfo{})
require.Error(t, err)
}
func TestMockStoreDeleteSubscription(t *testing.T) {
s := new(MockStore)
err := s.DeleteSubscription("client1:d/e/f")
require.NoError(t, err)
}
func TestMockStoreDeleteSubscriptionFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"delete_subs": true,
},
}
err := s.DeleteSubscription("client1:a/b/c")
require.Error(t, err)
}
func TestMockStoreDeleteClient(t *testing.T) {
s := new(MockStore)
err := s.DeleteClient("client1")
require.NoError(t, err)
}
func TestMockStoreDeleteClientFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"delete_clients": true,
},
}
err := s.DeleteClient("client1")
require.Error(t, err)
}
func TestMockStoreDeleteInflight(t *testing.T) {
s := new(MockStore)
err := s.DeleteInflight("client1-if-100")
require.NoError(t, err)
}
func TestMockStoreDeleteInflightFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"delete_inflight": true,
},
}
err := s.DeleteInflight("client1-if-100")
require.Error(t, err)
}
func TestMockStoreDeleteRetained(t *testing.T) {
s := new(MockStore)
err := s.DeleteRetained("client1-ret-100")
require.NoError(t, err)
}
func TestMockStoreDeleteRetainedFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"delete_retained": true,
},
}
err := s.DeleteRetained("client1-ret-100")
require.Error(t, err)
}
func TestMockStorReadServerInfo(t *testing.T) {
s := new(MockStore)
_, err := s.ReadServerInfo()
require.NoError(t, err)
}
func TestMockStorReadServerInfoFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"read_info": true,
},
}
_, err := s.ReadServerInfo()
require.Error(t, err)
}
func TestMockStoreReadSubscriptions(t *testing.T) {
s := new(MockStore)
_, err := s.ReadSubscriptions()
require.NoError(t, err)
}
func TestMockStoreReadSubscriptionsFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"read_subs": true,
},
}
_, err := s.ReadSubscriptions()
require.Error(t, err)
}
func TestMockStoreReadClients(t *testing.T) {
s := new(MockStore)
_, err := s.ReadClients()
require.NoError(t, err)
}
func TestMockStoreReadClientsFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"read_clients": true,
},
}
_, err := s.ReadClients()
require.Error(t, err)
}
func TestMockStoreReadInflight(t *testing.T) {
s := new(MockStore)
_, err := s.ReadInflight()
require.NoError(t, err)
}
func TestMockStoreReadInflightFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"read_inflight": true,
},
}
_, err := s.ReadInflight()
require.Error(t, err)
}
func TestMockStoreReadRetained(t *testing.T) {
s := new(MockStore)
_, err := s.ReadRetained()
require.NoError(t, err)
}
func TestMockStoreReadRetainedFail(t *testing.T) {
s := &MockStore{
Fail: map[string]bool{
"read_retained": true,
},
}
_, err := s.ReadRetained()
require.Error(t, err)
}

1066
server/server.go Normal file

File diff suppressed because it is too large Load Diff

2706
server/server_test.go Normal file

File diff suppressed because it is too large Load Diff

24
server/system/system.go Normal file
View File

@ -0,0 +1,24 @@
package system
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics.
type Info struct {
Version string `json:"version"` // the current version of the server.
Started int64 `json:"started"` // the time the server started in unix seconds.
Uptime int64 `json:"uptime"` // the number of seconds the server has been online.
BytesRecv int64 `json:"bytes_recv"` // the total number of bytes received in all packets.
BytesSent int64 `json:"bytes_sent"` // the total number of bytes sent to clients.
ClientsConnected int64 `json:"clients_connected"` // the number of currently connected clients.
ClientsDisconnected int64 `json:"clients_disconnected"` // the number of disconnected non-cleansession clients.
ClientsMax int64 `json:"clients_max"` // the maximum number of clients that have been concurrently connected.
ClientsTotal int64 `json:"clients_total"` // the sum of all clients, connected and disconnected.
ConnectionsTotal int64 `json:"connections_total"` // the sum number of clients which have ever connected.
MessagesRecv int64 `json:"messages_recv"` // the total number of packets received.
MessagesSent int64 `json:"messages_sent"` // the total number of packets sent.
PublishDropped int64 `json:"publish_dropped"` // the number of in-flight publish messages which were dropped.
PublishRecv int64 `json:"publish_recv"` // the total number of received publish packets.
PublishSent int64 `json:"publish_sent"` // the total number of sent publish packets.
Retained int64 `json:"retained"` // the number of messages currently retained.
Inflight int64 `json:"inflight"` // the number of messages currently in-flight.
Subscriptions int64 `json:"subscriptions"` // the total number of filter subscriptions.
}

View File

@ -0,0 +1,21 @@
package system
import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
func TestInfoAlignment(t *testing.T) {
typ := reflect.TypeOf(Info{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
switch f.Type.Kind() {
case reflect.Int64, reflect.Uint64:
require.Equalf(t, uintptr(0), f.Offset%8,
"%s requires 64-bit alignment for atomic: offset %d",
f.Name, f.Offset)
}
}
}

21
vendor/github.com/asdine/storm/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) [2017] [Asdine El Hrychy]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

35
vendor/github.com/asdine/storm/codec/gob/gob.go generated vendored Normal file
View File

@ -0,0 +1,35 @@
// Package gob contains a codec to encode and decode entities in Gob format
package gob
import (
"bytes"
"encoding/gob"
)
const name = "gob"
// Codec serializing objects using the gob package.
// See https://golang.org/pkg/encoding/gob/
var Codec = new(gobCodec)
type gobCodec int
func (c gobCodec) Marshal(v interface{}) ([]byte, error) {
var b bytes.Buffer
enc := gob.NewEncoder(&b)
err := enc.Encode(v)
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
func (c gobCodec) Unmarshal(b []byte, v interface{}) error {
r := bytes.NewReader(b)
dec := gob.NewDecoder(r)
return dec.Decode(v)
}
func (c gobCodec) Name() string {
return name
}

32
vendor/github.com/asdine/storm/v3/.gitignore generated vendored Normal file
View File

@ -0,0 +1,32 @@
# IDE
.idea/
.vscode/
*.iml
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
# Golang vendor folder
/vendor/

19
vendor/github.com/asdine/storm/v3/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,19 @@
language: go
before_install:
- go get github.com/stretchr/testify
env: GO111MODULE=on
go:
- "1.13.x"
- "1.14.x"
- tip
matrix:
allow_failures:
- go: tip
script:
- go mod vendor
- go test -mod vendor -race -v ./...

21
vendor/github.com/asdine/storm/v3/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) [2017] [Asdine El Hrychy]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

643
vendor/github.com/asdine/storm/v3/README.md generated vendored Normal file
View File

@ -0,0 +1,643 @@
# Storm
[![Build Status](https://travis-ci.org/asdine/storm.svg)](https://travis-ci.org/asdine/storm)
[![GoDoc](https://godoc.org/github.com/asdine/storm?status.svg)](https://godoc.org/github.com/asdine/storm)
Storm is a simple and powerful toolkit for [BoltDB](https://github.com/coreos/bbolt). Basically, Storm provides indexes, a wide range of methods to store and fetch data, an advanced query system, and much more.
In addition to the examples below, see also the [examples in the GoDoc](https://godoc.org/github.com/asdine/storm#pkg-examples).
_For extended queries and support for [Badger](https://github.com/dgraph-io/badger), see also [Genji](https://github.com/asdine/genji)_
## Table of Contents
- [Getting Started](#getting-started)
- [Import Storm](#import-storm)
- [Open a database](#open-a-database)
- [Simple CRUD system](#simple-crud-system)
- [Declare your structures](#declare-your-structures)
- [Save your object](#save-your-object)
- [Auto Increment](#auto-increment)
- [Simple queries](#simple-queries)
- [Fetch one object](#fetch-one-object)
- [Fetch multiple objects](#fetch-multiple-objects)
- [Fetch all objects](#fetch-all-objects)
- [Fetch all objects sorted by index](#fetch-all-objects-sorted-by-index)
- [Fetch a range of objects](#fetch-a-range-of-objects)
- [Fetch objects by prefix](#fetch-objects-by-prefix)
- [Skip, Limit and Reverse](#skip-limit-and-reverse)
- [Delete an object](#delete-an-object)
- [Update an object](#update-an-object)
- [Initialize buckets and indexes before saving an object](#initialize-buckets-and-indexes-before-saving-an-object)
- [Drop a bucket](#drop-a-bucket)
- [Re-index a bucket](#re-index-a-bucket)
- [Advanced queries](#advanced-queries)
- [Transactions](#transactions)
- [Options](#options)
- [BoltOptions](#boltoptions)
- [MarshalUnmarshaler](#marshalunmarshaler)
- [Provided Codecs](#provided-codecs)
- [Use existing Bolt connection](#use-existing-bolt-connection)
- [Batch mode](#batch-mode)
- [Nodes and nested buckets](#nodes-and-nested-buckets)
- [Node options](#node-options)
- [Simple Key/Value store](#simple-keyvalue-store)
- [BoltDB](#boltdb)
- [License](#license)
- [Credits](#credits)
## Getting Started
```bash
GO111MODULE=on go get -u github.com/asdine/storm/v3
```
## Import Storm
```go
import "github.com/asdine/storm/v3"
```
## Open a database
Quick way of opening a database
```go
db, err := storm.Open("my.db")
defer db.Close()
```
`Open` can receive multiple options to customize the way it behaves. See [Options](#options) below
## Simple CRUD system
### Declare your structures
```go
type User struct {
ID int // primary key
Group string `storm:"index"` // this field will be indexed
Email string `storm:"unique"` // this field will be indexed with a unique constraint
Name string // this field will not be indexed
Age int `storm:"index"`
}
```
The primary key can be of any type as long as it is not a zero value. Storm will search for the tag `id`, if not present Storm will search for a field named `ID`.
```go
type User struct {
ThePrimaryKey string `storm:"id"`// primary key
Group string `storm:"index"` // this field will be indexed
Email string `storm:"unique"` // this field will be indexed with a unique constraint
Name string // this field will not be indexed
}
```
Storm handles tags in nested structures with the `inline` tag
```go
type Base struct {
Ident bson.ObjectId `storm:"id"`
}
type User struct {
Base `storm:"inline"`
Group string `storm:"index"`
Email string `storm:"unique"`
Name string
CreatedAt time.Time `storm:"index"`
}
```
### Save your object
```go
user := User{
ID: 10,
Group: "staff",
Email: "john@provider.com",
Name: "John",
Age: 21,
CreatedAt: time.Now(),
}
err := db.Save(&user)
// err == nil
user.ID++
err = db.Save(&user)
// err == storm.ErrAlreadyExists
```
That's it.
`Save` creates or updates all the required indexes and buckets, checks the unique constraints and saves the object to the store.
#### Auto Increment
Storm can auto increment integer values so you don't have to worry about that when saving your objects. Also, the new value is automatically inserted in your field.
```go
type Product struct {
Pk int `storm:"id,increment"` // primary key with auto increment
Name string
IntegerField uint64 `storm:"increment"`
IndexedIntegerField uint32 `storm:"index,increment"`
UniqueIntegerField int16 `storm:"unique,increment=100"` // the starting value can be set
}
p := Product{Name: "Vaccum Cleaner"}
fmt.Println(p.Pk)
fmt.Println(p.IntegerField)
fmt.Println(p.IndexedIntegerField)
fmt.Println(p.UniqueIntegerField)
// 0
// 0
// 0
// 0
_ = db.Save(&p)
fmt.Println(p.Pk)
fmt.Println(p.IntegerField)
fmt.Println(p.IndexedIntegerField)
fmt.Println(p.UniqueIntegerField)
// 1
// 1
// 1
// 100
```
### Simple queries
Any object can be fetched, indexed or not. Storm uses indexes when available, otherwise it uses the [query system](#advanced-queries).
#### Fetch one object
```go
var user User
err := db.One("Email", "john@provider.com", &user)
// err == nil
err = db.One("Name", "John", &user)
// err == nil
err = db.One("Name", "Jack", &user)
// err == storm.ErrNotFound
```
#### Fetch multiple objects
```go
var users []User
err := db.Find("Group", "staff", &users)
```
#### Fetch all objects
```go
var users []User
err := db.All(&users)
```
#### Fetch all objects sorted by index
```go
var users []User
err := db.AllByIndex("CreatedAt", &users)
```
#### Fetch a range of objects
```go
var users []User
err := db.Range("Age", 10, 21, &users)
```
#### Fetch objects by prefix
```go
var users []User
err := db.Prefix("Name", "Jo", &users)
```
#### Skip, Limit and Reverse
```go
var users []User
err := db.Find("Group", "staff", &users, storm.Skip(10))
err = db.Find("Group", "staff", &users, storm.Limit(10))
err = db.Find("Group", "staff", &users, storm.Reverse())
err = db.Find("Group", "staff", &users, storm.Limit(10), storm.Skip(10), storm.Reverse())
err = db.All(&users, storm.Limit(10), storm.Skip(10), storm.Reverse())
err = db.AllByIndex("CreatedAt", &users, storm.Limit(10), storm.Skip(10), storm.Reverse())
err = db.Range("Age", 10, 21, &users, storm.Limit(10), storm.Skip(10), storm.Reverse())
```
#### Delete an object
```go
err := db.DeleteStruct(&user)
```
#### Update an object
```go
// Update multiple fields
err := db.Update(&User{ID: 10, Name: "Jack", Age: 45})
// Update a single field
err := db.UpdateField(&User{ID: 10}, "Age", 0)
```
#### Initialize buckets and indexes before saving an object
```go
err := db.Init(&User{})
```
Useful when starting your application
#### Drop a bucket
Using the struct
```go
err := db.Drop(&User)
```
Using the bucket name
```go
err := db.Drop("User")
```
#### Re-index a bucket
```go
err := db.ReIndex(&User{})
```
Useful when the structure has changed
### Advanced queries
For more complex queries, you can use the `Select` method.
`Select` takes any number of [`Matcher`](https://godoc.org/github.com/asdine/storm/q#Matcher) from the [`q`](https://godoc.org/github.com/asdine/storm/q) package.
Here are some common Matchers:
```go
// Equality
q.Eq("Name", John)
// Strictly greater than
q.Gt("Age", 7)
// Lesser than or equal to
q.Lte("Age", 77)
// Regex with name that starts with the letter D
q.Re("Name", "^D")
// In the given slice of values
q.In("Group", []string{"Staff", "Admin"})
// Comparing fields
q.EqF("FieldName", "SecondFieldName")
q.LtF("FieldName", "SecondFieldName")
q.GtF("FieldName", "SecondFieldName")
q.LteF("FieldName", "SecondFieldName")
q.GteF("FieldName", "SecondFieldName")
```
Matchers can also be combined with `And`, `Or` and `Not`:
```go
// Match if all match
q.And(
q.Gt("Age", 7),
q.Re("Name", "^D")
)
// Match if one matches
q.Or(
q.Re("Name", "^A"),
q.Not(
q.Re("Name", "^B")
),
q.Re("Name", "^C"),
q.In("Group", []string{"Staff", "Admin"}),
q.And(
q.StrictEq("Password", []byte(password)),
q.Eq("Registered", true)
)
)
```
You can find the complete list in the [documentation](https://godoc.org/github.com/asdine/storm/q#Matcher).
`Select` takes any number of matchers and wraps them into a `q.And()` so it's not necessary to specify it. It returns a [`Query`](https://godoc.org/github.com/asdine/storm#Query) type.
```go
query := db.Select(q.Gte("Age", 7), q.Lte("Age", 77))
```
The `Query` type contains methods to filter and order the records.
```go
// Limit
query = query.Limit(10)
// Skip
query = query.Skip(20)
// Calls can also be chained
query = query.Limit(10).Skip(20).OrderBy("Age").Reverse()
```
But also to specify how to fetch them.
```go
var users []User
err = query.Find(&users)
var user User
err = query.First(&user)
```
Examples with `Select`:
```go
// Find all users with an ID between 10 and 100
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Find(&users)
// Nested matchers
err = db.Select(q.Or(
q.Gt("ID", 50),
q.Lt("Age", 21),
q.And(
q.Eq("Group", "admin"),
q.Gte("Age", 21),
),
)).Find(&users)
query := db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name")
// Find multiple records
err = query.Find(&users)
// or
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").Find(&users)
// Find first record
err = query.First(&user)
// or
err = db.Select(q.Gte("ID", 10), q.Lte("ID", 100)).Limit(10).Skip(5).Reverse().OrderBy("Age", "Name").First(&user)
// Delete all matching records
err = query.Delete(new(User))
// Fetching records one by one (useful when the bucket contains a lot of records)
query = db.Select(q.Gte("ID", 10),q.Lte("ID", 100)).OrderBy("Age", "Name")
err = query.Each(new(User), func(record interface{}) error) {
u := record.(*User)
...
return nil
}
```
See the [documentation](https://godoc.org/github.com/asdine/storm#Query) for a complete list of methods.
### Transactions
```go
tx, err := db.Begin(true)
if err != nil {
return err
}
defer tx.Rollback()
accountA.Amount -= 100
accountB.Amount += 100
err = tx.Save(accountA)
if err != nil {
return err
}
err = tx.Save(accountB)
if err != nil {
return err
}
return tx.Commit()
```
### Options
Storm options are functions that can be passed when constructing you Storm instance. You can pass it any number of options.
#### BoltOptions
By default, Storm opens a database with the mode `0600` and a timeout of one second.
You can change this behavior by using `BoltOptions`
```go
db, err := storm.Open("my.db", storm.BoltOptions(0600, &bolt.Options{Timeout: 1 * time.Second}))
```
#### MarshalUnmarshaler
To store the data in BoltDB, Storm marshals it in JSON by default. If you wish to change this behavior you can pass a codec that implements [`codec.MarshalUnmarshaler`](https://godoc.org/github.com/asdine/storm/codec#MarshalUnmarshaler) via the [`storm.Codec`](https://godoc.org/github.com/asdine/storm#Codec) option:
```go
db := storm.Open("my.db", storm.Codec(myCodec))
```
##### Provided Codecs
You can easily implement your own `MarshalUnmarshaler`, but Storm comes with built-in support for [JSON](https://godoc.org/github.com/asdine/storm/codec/json) (default), [GOB](https://godoc.org/github.com/asdine/storm/codec/gob), [Sereal](https://godoc.org/github.com/asdine/storm/codec/sereal), [Protocol Buffers](https://godoc.org/github.com/asdine/storm/codec/protobuf) and [MessagePack](https://godoc.org/github.com/asdine/storm/codec/msgpack).
These can be used by importing the relevant package and use that codec to configure Storm. The example below shows all variants (without proper error handling):
```go
import (
"github.com/asdine/storm/v3"
"github.com/asdine/storm/v3/codec/gob"
"github.com/asdine/storm/v3/codec/json"
"github.com/asdine/storm/v3/codec/sereal"
"github.com/asdine/storm/v3/codec/protobuf"
"github.com/asdine/storm/v3/codec/msgpack"
)
var gobDb, _ = storm.Open("gob.db", storm.Codec(gob.Codec))
var jsonDb, _ = storm.Open("json.db", storm.Codec(json.Codec))
var serealDb, _ = storm.Open("sereal.db", storm.Codec(sereal.Codec))
var protobufDb, _ = storm.Open("protobuf.db", storm.Codec(protobuf.Codec))
var msgpackDb, _ = storm.Open("msgpack.db", storm.Codec(msgpack.Codec))
```
**Tip**: Adding Storm tags to generated Protobuf files can be tricky. A good solution is to use [this tool](https://github.com/favadi/protoc-go-inject-tag) to inject the tags during the compilation.
#### Use existing Bolt connection
You can use an existing connection and pass it to Storm
```go
bDB, _ := bolt.Open(filepath.Join(dir, "bolt.db"), 0600, &bolt.Options{Timeout: 10 * time.Second})
db := storm.Open("my.db", storm.UseDB(bDB))
```
#### Batch mode
Batch mode can be enabled to speed up concurrent writes (see [Batch read-write transactions](https://github.com/coreos/bbolt#batch-read-write-transactions))
```go
db := storm.Open("my.db", storm.Batch())
```
## Nodes and nested buckets
Storm takes advantage of BoltDB nested buckets feature by using `storm.Node`.
A `storm.Node` is the underlying object used by `storm.DB` to manipulate a bucket.
To create a nested bucket and use the same API as `storm.DB`, you can use the `DB.From` method.
```go
repo := db.From("repo")
err := repo.Save(&Issue{
Title: "I want more features",
Author: user.ID,
})
err = repo.Save(newRelease("0.10"))
var issues []Issue
err = repo.Find("Author", user.ID, &issues)
var release Release
err = repo.One("Tag", "0.10", &release)
```
You can also chain the nodes to create a hierarchy
```go
chars := db.From("characters")
heroes := chars.From("heroes")
enemies := chars.From("enemies")
items := db.From("items")
potions := items.From("consumables").From("medicine").From("potions")
```
You can even pass the entire hierarchy as arguments to `From`:
```go
privateNotes := db.From("notes", "private")
workNotes := db.From("notes", "work")
```
### Node options
A Node can also be configured. Activating an option on a Node creates a copy, so a Node is always thread-safe.
```go
n := db.From("my-node")
```
Give a bolt.Tx transaction to the Node
```go
n = n.WithTransaction(tx)
```
Enable batch mode
```go
n = n.WithBatch(true)
```
Use a Codec
```go
n = n.WithCodec(gob.Codec)
```
## Simple Key/Value store
Storm can be used as a simple, robust, key/value store that can store anything.
The key and the value can be of any type as long as the key is not a zero value.
Saving data :
```go
db.Set("logs", time.Now(), "I'm eating my breakfast man")
db.Set("sessions", bson.NewObjectId(), &someUser)
db.Set("weird storage", "754-3010", map[string]interface{}{
"hair": "blonde",
"likes": []string{"cheese", "star wars"},
})
```
Fetching data :
```go
user := User{}
db.Get("sessions", someObjectId, &user)
var details map[string]interface{}
db.Get("weird storage", "754-3010", &details)
db.Get("sessions", someObjectId, &details)
```
Deleting data :
```go
db.Delete("sessions", someObjectId)
db.Delete("weird storage", "754-3010")
```
You can find other useful methods in the [documentation](https://godoc.org/github.com/asdine/storm#KeyValueStore).
## BoltDB
BoltDB is still easily accessible and can be used as usual
```go
db.Bolt.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte("my bucket"))
val := bucket.Get([]byte("any id"))
fmt.Println(string(val))
return nil
})
```
A transaction can be also be passed to Storm
```go
db.Bolt.Update(func(tx *bolt.Tx) error {
...
dbx := db.WithTransaction(tx)
err = dbx.Save(&user)
...
return nil
})
```
## License
MIT
## Credits
- [Asdine El Hrychy](https://github.com/asdine)
- [Bjørn Erik Pedersen](https://github.com/bep)

47
vendor/github.com/asdine/storm/v3/bucket.go generated vendored Normal file
View File

@ -0,0 +1,47 @@
package storm
import bolt "go.etcd.io/bbolt"
// CreateBucketIfNotExists creates the bucket below the current node if it doesn't
// already exist.
func (n *node) CreateBucketIfNotExists(tx *bolt.Tx, bucket string) (*bolt.Bucket, error) {
var b *bolt.Bucket
var err error
bucketNames := append(n.rootBucket, bucket)
for _, bucketName := range bucketNames {
if b != nil {
if b, err = b.CreateBucketIfNotExists([]byte(bucketName)); err != nil {
return nil, err
}
} else {
if b, err = tx.CreateBucketIfNotExists([]byte(bucketName)); err != nil {
return nil, err
}
}
}
return b, nil
}
// GetBucket returns the given bucket below the current node.
func (n *node) GetBucket(tx *bolt.Tx, children ...string) *bolt.Bucket {
var b *bolt.Bucket
bucketNames := append(n.rootBucket, children...)
for _, bucketName := range bucketNames {
if b != nil {
if b = b.Bucket([]byte(bucketName)); b == nil {
return nil
}
} else {
if b = tx.Bucket([]byte(bucketName)); b == nil {
return nil
}
}
}
return b
}

1
vendor/github.com/asdine/storm/v3/codec/.gitignore generated vendored Normal file
View File

@ -0,0 +1 @@
*.db

11
vendor/github.com/asdine/storm/v3/codec/codec.go generated vendored Normal file
View File

@ -0,0 +1,11 @@
// Package codec contains sub-packages with different codecs that can be used
// to encode and decode entities in Storm.
package codec
// MarshalUnmarshaler represents a codec used to marshal and unmarshal entities.
type MarshalUnmarshaler interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(b []byte, v interface{}) error
// name of this codec
Name() string
}

25
vendor/github.com/asdine/storm/v3/codec/json/json.go generated vendored Normal file
View File

@ -0,0 +1,25 @@
// Package json contains a codec to encode and decode entities in JSON format
package json
import (
"encoding/json"
)
const name = "json"
// Codec that encodes to and decodes from JSON.
var Codec = new(jsonCodec)
type jsonCodec int
func (j jsonCodec) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
func (j jsonCodec) Unmarshal(b []byte, v interface{}) error {
return json.Unmarshal(b, v)
}
func (j jsonCodec) Name() string {
return name
}

51
vendor/github.com/asdine/storm/v3/errors.go generated vendored Normal file
View File

@ -0,0 +1,51 @@
package storm
import "errors"
// Errors
var (
// ErrNoID is returned when no ID field or id tag is found in the struct.
ErrNoID = errors.New("missing struct tag id or ID field")
// ErrZeroID is returned when the ID field is a zero value.
ErrZeroID = errors.New("id field must not be a zero value")
// ErrBadType is returned when a method receives an unexpected value type.
ErrBadType = errors.New("provided data must be a struct or a pointer to struct")
// ErrAlreadyExists is returned uses when trying to set an existing value on a field that has a unique index.
ErrAlreadyExists = errors.New("already exists")
// ErrNilParam is returned when the specified param is expected to be not nil.
ErrNilParam = errors.New("param must not be nil")
// ErrUnknownTag is returned when an unexpected tag is specified.
ErrUnknownTag = errors.New("unknown tag")
// ErrIdxNotFound is returned when the specified index is not found.
ErrIdxNotFound = errors.New("index not found")
// ErrSlicePtrNeeded is returned when an unexpected value is given, instead of a pointer to slice.
ErrSlicePtrNeeded = errors.New("provided target must be a pointer to slice")
// ErrStructPtrNeeded is returned when an unexpected value is given, instead of a pointer to struct.
ErrStructPtrNeeded = errors.New("provided target must be a pointer to struct")
// ErrPtrNeeded is returned when an unexpected value is given, instead of a pointer.
ErrPtrNeeded = errors.New("provided target must be a pointer to a valid variable")
// ErrNoName is returned when the specified struct has no name.
ErrNoName = errors.New("provided target must have a name")
// ErrNotFound is returned when the specified record is not saved in the bucket.
ErrNotFound = errors.New("not found")
// ErrNotInTransaction is returned when trying to rollback or commit when not in transaction.
ErrNotInTransaction = errors.New("not in transaction")
// ErrIncompatibleValue is returned when trying to set a value with a different type than the chosen field
ErrIncompatibleValue = errors.New("incompatible value")
// ErrDifferentCodec is returned when using a codec different than the first codec used with the bucket.
ErrDifferentCodec = errors.New("the selected codec is incompatible with this bucket")
)

226
vendor/github.com/asdine/storm/v3/extract.go generated vendored Normal file
View File

@ -0,0 +1,226 @@
package storm
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/asdine/storm/v3/index"
bolt "go.etcd.io/bbolt"
)
// Storm tags
const (
tagID = "id"
tagIdx = "index"
tagUniqueIdx = "unique"
tagInline = "inline"
tagIncrement = "increment"
indexPrefix = "__storm_index_"
)
type fieldConfig struct {
Name string
Index string
IsZero bool
IsID bool
Increment bool
IncrementStart int64
IsInteger bool
Value *reflect.Value
ForceUpdate bool
}
// structConfig is a structure gathering all the relevant informations about a model
type structConfig struct {
Name string
Fields map[string]*fieldConfig
ID *fieldConfig
}
func extract(s *reflect.Value, mi ...*structConfig) (*structConfig, error) {
if s.Kind() == reflect.Ptr {
e := s.Elem()
s = &e
}
if s.Kind() != reflect.Struct {
return nil, ErrBadType
}
typ := s.Type()
var child bool
var m *structConfig
if len(mi) > 0 {
m = mi[0]
child = true
} else {
m = &structConfig{}
m.Fields = make(map[string]*fieldConfig)
}
if m.Name == "" {
m.Name = typ.Name()
}
numFields := s.NumField()
for i := 0; i < numFields; i++ {
field := typ.Field(i)
value := s.Field(i)
if field.PkgPath != "" {
continue
}
err := extractField(&value, &field, m, child)
if err != nil {
return nil, err
}
}
if child {
return m, nil
}
if m.ID == nil {
return nil, ErrNoID
}
if m.Name == "" {
return nil, ErrNoName
}
return m, nil
}
func extractField(value *reflect.Value, field *reflect.StructField, m *structConfig, isChild bool) error {
var f *fieldConfig
var err error
tag := field.Tag.Get("storm")
if tag != "" {
f = &fieldConfig{
Name: field.Name,
IsZero: isZero(value),
IsInteger: isInteger(value),
Value: value,
IncrementStart: 1,
}
tags := strings.Split(tag, ",")
for _, tag := range tags {
switch tag {
case "id":
f.IsID = true
f.Index = tagUniqueIdx
case tagUniqueIdx, tagIdx:
f.Index = tag
case tagInline:
if value.Kind() == reflect.Ptr {
e := value.Elem()
value = &e
}
if value.Kind() == reflect.Struct {
a := value.Addr()
_, err := extract(&a, m)
if err != nil {
return err
}
}
// we don't need to save this field
return nil
default:
if strings.HasPrefix(tag, tagIncrement) {
f.Increment = true
parts := strings.Split(tag, "=")
if parts[0] != tagIncrement {
return ErrUnknownTag
}
if len(parts) > 1 {
f.IncrementStart, err = strconv.ParseInt(parts[1], 0, 64)
if err != nil {
return err
}
}
} else {
return ErrUnknownTag
}
}
}
if _, ok := m.Fields[f.Name]; !ok || !isChild {
m.Fields[f.Name] = f
}
}
if m.ID == nil && f != nil && f.IsID {
m.ID = f
}
// the field is named ID and no ID field has been detected before
if m.ID == nil && field.Name == "ID" {
if f == nil {
f = &fieldConfig{
Index: tagUniqueIdx,
Name: field.Name,
IsZero: isZero(value),
IsInteger: isInteger(value),
IsID: true,
Value: value,
IncrementStart: 1,
}
m.Fields[field.Name] = f
}
m.ID = f
}
return nil
}
func extractSingleField(ref *reflect.Value, fieldName string) (*structConfig, error) {
var cfg structConfig
cfg.Fields = make(map[string]*fieldConfig)
f, ok := ref.Type().FieldByName(fieldName)
if !ok || f.PkgPath != "" {
return nil, fmt.Errorf("field %s not found", fieldName)
}
v := ref.FieldByName(fieldName)
err := extractField(&v, &f, &cfg, false)
if err != nil {
return nil, err
}
return &cfg, nil
}
func getIndex(bucket *bolt.Bucket, idxKind string, fieldName string) (index.Index, error) {
var idx index.Index
var err error
switch idxKind {
case tagUniqueIdx:
idx, err = index.NewUniqueIndex(bucket, []byte(indexPrefix+fieldName))
case tagIdx:
idx, err = index.NewListIndex(bucket, []byte(indexPrefix+fieldName))
default:
err = ErrIdxNotFound
}
return idx, err
}
func isZero(v *reflect.Value) bool {
zero := reflect.Zero(v.Type()).Interface()
current := v.Interface()
return reflect.DeepEqual(current, zero)
}
func isInteger(v *reflect.Value) bool {
kind := v.Kind()
return v != nil && kind >= reflect.Int && kind <= reflect.Uint64
}

499
vendor/github.com/asdine/storm/v3/finder.go generated vendored Normal file
View File

@ -0,0 +1,499 @@
package storm
import (
"fmt"
"reflect"
"github.com/asdine/storm/v3/index"
"github.com/asdine/storm/v3/q"
bolt "go.etcd.io/bbolt"
)
// A Finder can fetch types from BoltDB.
type Finder interface {
// One returns one record by the specified index
One(fieldName string, value interface{}, to interface{}) error
// Find returns one or more records by the specified index
Find(fieldName string, value interface{}, to interface{}, options ...func(q *index.Options)) error
// AllByIndex gets all the records of a bucket that are indexed in the specified index
AllByIndex(fieldName string, to interface{}, options ...func(*index.Options)) error
// All gets all the records of a bucket.
// If there are no records it returns no error and the 'to' parameter is set to an empty slice.
All(to interface{}, options ...func(*index.Options)) error
// Select a list of records that match a list of matchers. Doesn't use indexes.
Select(matchers ...q.Matcher) Query
// Range returns one or more records by the specified index within the specified range
Range(fieldName string, min, max, to interface{}, options ...func(*index.Options)) error
// Prefix returns one or more records whose given field starts with the specified prefix.
Prefix(fieldName string, prefix string, to interface{}, options ...func(*index.Options)) error
// Count counts all the records of a bucket
Count(data interface{}) (int, error)
}
// One returns one record by the specified index
func (n *node) One(fieldName string, value interface{}, to interface{}) error {
sink, err := newFirstSink(n, to)
if err != nil {
return err
}
bucketName := sink.bucketName()
if bucketName == "" {
return ErrNoName
}
if fieldName == "" {
return ErrNotFound
}
ref := reflect.Indirect(sink.ref)
cfg, err := extractSingleField(&ref, fieldName)
if err != nil {
return err
}
field, ok := cfg.Fields[fieldName]
if !ok || (!field.IsID && field.Index == "") {
query := newQuery(n, q.StrictEq(fieldName, value))
query.Limit(1)
if n.tx != nil {
err = query.query(n.tx, sink)
} else {
err = n.s.Bolt.View(func(tx *bolt.Tx) error {
return query.query(tx, sink)
})
}
if err != nil {
return err
}
return sink.flush()
}
val, err := toBytes(value, n.codec)
if err != nil {
return err
}
return n.readTx(func(tx *bolt.Tx) error {
return n.one(tx, bucketName, fieldName, cfg, to, val, field.IsID)
})
}
func (n *node) one(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, to interface{}, val []byte, skipIndex bool) error {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
return ErrNotFound
}
var id []byte
if !skipIndex {
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
if err != nil {
if err == index.ErrNotFound {
return ErrNotFound
}
return err
}
id = idx.Get(val)
} else {
id = val
}
if id == nil {
return ErrNotFound
}
raw := bucket.Get(id)
if raw == nil {
return ErrNotFound
}
return n.codec.Unmarshal(raw, to)
}
// Find returns one or more records by the specified index
func (n *node) Find(fieldName string, value interface{}, to interface{}, options ...func(q *index.Options)) error {
sink, err := newListSink(n, to)
if err != nil {
return err
}
bucketName := sink.bucketName()
if bucketName == "" {
return ErrNoName
}
ref := reflect.Indirect(reflect.New(sink.elemType))
cfg, err := extractSingleField(&ref, fieldName)
if err != nil {
return err
}
opts := index.NewOptions()
for _, fn := range options {
fn(opts)
}
field, ok := cfg.Fields[fieldName]
if !ok || (!field.IsID && (field.Index == "" || value == nil)) {
query := newQuery(n, q.Eq(fieldName, value))
query.Skip(opts.Skip).Limit(opts.Limit)
if opts.Reverse {
query.Reverse()
}
err = n.readTx(func(tx *bolt.Tx) error {
return query.query(tx, sink)
})
if err != nil {
return err
}
return sink.flush()
}
val, err := toBytes(value, n.codec)
if err != nil {
return err
}
return n.readTx(func(tx *bolt.Tx) error {
return n.find(tx, bucketName, fieldName, cfg, sink, val, opts)
})
}
func (n *node) find(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, sink *listSink, val []byte, opts *index.Options) error {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
return ErrNotFound
}
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
if err != nil {
return err
}
list, err := idx.All(val, opts)
if err != nil {
if err == index.ErrNotFound {
return ErrNotFound
}
return err
}
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
sorter := newSorter(n, sink)
for i := range list {
raw := bucket.Get(list[i])
if raw == nil {
return ErrNotFound
}
if _, err := sorter.filter(nil, bucket, list[i], raw); err != nil {
return err
}
}
return sorter.flush()
}
// AllByIndex gets all the records of a bucket that are indexed in the specified index
func (n *node) AllByIndex(fieldName string, to interface{}, options ...func(*index.Options)) error {
if fieldName == "" {
return n.All(to, options...)
}
ref := reflect.ValueOf(to)
if ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Slice {
return ErrSlicePtrNeeded
}
typ := reflect.Indirect(ref).Type().Elem()
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
newElem := reflect.New(typ)
cfg, err := extract(&newElem)
if err != nil {
return err
}
if cfg.ID.Name == fieldName {
return n.All(to, options...)
}
opts := index.NewOptions()
for _, fn := range options {
fn(opts)
}
return n.readTx(func(tx *bolt.Tx) error {
return n.allByIndex(tx, fieldName, cfg, &ref, opts)
})
}
func (n *node) allByIndex(tx *bolt.Tx, fieldName string, cfg *structConfig, ref *reflect.Value, opts *index.Options) error {
bucket := n.GetBucket(tx, cfg.Name)
if bucket == nil {
return ErrNotFound
}
fieldCfg, ok := cfg.Fields[fieldName]
if !ok {
return ErrNotFound
}
idx, err := getIndex(bucket, fieldCfg.Index, fieldName)
if err != nil {
return err
}
list, err := idx.AllRecords(opts)
if err != nil {
if err == index.ErrNotFound {
return ErrNotFound
}
return err
}
results := reflect.MakeSlice(reflect.Indirect(*ref).Type(), len(list), len(list))
for i := range list {
raw := bucket.Get(list[i])
if raw == nil {
return ErrNotFound
}
err = n.codec.Unmarshal(raw, results.Index(i).Addr().Interface())
if err != nil {
return err
}
}
reflect.Indirect(*ref).Set(results)
return nil
}
// All gets all the records of a bucket.
// If there are no records it returns no error and the 'to' parameter is set to an empty slice.
func (n *node) All(to interface{}, options ...func(*index.Options)) error {
opts := index.NewOptions()
for _, fn := range options {
fn(opts)
}
query := newQuery(n, nil).Limit(opts.Limit).Skip(opts.Skip)
if opts.Reverse {
query.Reverse()
}
err := query.Find(to)
if err != nil && err != ErrNotFound {
return err
}
if err == ErrNotFound {
ref := reflect.ValueOf(to)
results := reflect.MakeSlice(reflect.Indirect(ref).Type(), 0, 0)
reflect.Indirect(ref).Set(results)
}
return nil
}
// Range returns one or more records by the specified index within the specified range
func (n *node) Range(fieldName string, min, max, to interface{}, options ...func(*index.Options)) error {
sink, err := newListSink(n, to)
if err != nil {
return err
}
bucketName := sink.bucketName()
if bucketName == "" {
return ErrNoName
}
ref := reflect.Indirect(reflect.New(sink.elemType))
cfg, err := extractSingleField(&ref, fieldName)
if err != nil {
return err
}
opts := index.NewOptions()
for _, fn := range options {
fn(opts)
}
field, ok := cfg.Fields[fieldName]
if !ok || (!field.IsID && field.Index == "") {
query := newQuery(n, q.And(q.Gte(fieldName, min), q.Lte(fieldName, max)))
query.Skip(opts.Skip).Limit(opts.Limit)
if opts.Reverse {
query.Reverse()
}
err = n.readTx(func(tx *bolt.Tx) error {
return query.query(tx, sink)
})
if err != nil {
return err
}
return sink.flush()
}
mn, err := toBytes(min, n.codec)
if err != nil {
return err
}
mx, err := toBytes(max, n.codec)
if err != nil {
return err
}
return n.readTx(func(tx *bolt.Tx) error {
return n.rnge(tx, bucketName, fieldName, cfg, sink, mn, mx, opts)
})
}
func (n *node) rnge(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, sink *listSink, min, max []byte, opts *index.Options) error {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
reflect.Indirect(sink.ref).SetLen(0)
return nil
}
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
if err != nil {
return err
}
list, err := idx.Range(min, max, opts)
if err != nil {
return err
}
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
sorter := newSorter(n, sink)
for i := range list {
raw := bucket.Get(list[i])
if raw == nil {
return ErrNotFound
}
if _, err := sorter.filter(nil, bucket, list[i], raw); err != nil {
return err
}
}
return sorter.flush()
}
// Prefix returns one or more records whose given field starts with the specified prefix.
func (n *node) Prefix(fieldName string, prefix string, to interface{}, options ...func(*index.Options)) error {
sink, err := newListSink(n, to)
if err != nil {
return err
}
bucketName := sink.bucketName()
if bucketName == "" {
return ErrNoName
}
ref := reflect.Indirect(reflect.New(sink.elemType))
cfg, err := extractSingleField(&ref, fieldName)
if err != nil {
return err
}
opts := index.NewOptions()
for _, fn := range options {
fn(opts)
}
field, ok := cfg.Fields[fieldName]
if !ok || (!field.IsID && field.Index == "") {
query := newQuery(n, q.Re(fieldName, fmt.Sprintf("^%s", prefix)))
query.Skip(opts.Skip).Limit(opts.Limit)
if opts.Reverse {
query.Reverse()
}
err = n.readTx(func(tx *bolt.Tx) error {
return query.query(tx, sink)
})
if err != nil {
return err
}
return sink.flush()
}
prfx, err := toBytes(prefix, n.codec)
if err != nil {
return err
}
return n.readTx(func(tx *bolt.Tx) error {
return n.prefix(tx, bucketName, fieldName, cfg, sink, prfx, opts)
})
}
func (n *node) prefix(tx *bolt.Tx, bucketName, fieldName string, cfg *structConfig, sink *listSink, prefix []byte, opts *index.Options) error {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
reflect.Indirect(sink.ref).SetLen(0)
return nil
}
idx, err := getIndex(bucket, cfg.Fields[fieldName].Index, fieldName)
if err != nil {
return err
}
list, err := idx.Prefix(prefix, opts)
if err != nil {
return err
}
sink.results = reflect.MakeSlice(reflect.Indirect(sink.ref).Type(), len(list), len(list))
sorter := newSorter(n, sink)
for i := range list {
raw := bucket.Get(list[i])
if raw == nil {
return ErrNotFound
}
if _, err := sorter.filter(nil, bucket, list[i], raw); err != nil {
return err
}
}
return sorter.flush()
}
// Count counts all the records of a bucket
func (n *node) Count(data interface{}) (int, error) {
return n.Select().Count(data)
}

14
vendor/github.com/asdine/storm/v3/index/errors.go generated vendored Normal file
View File

@ -0,0 +1,14 @@
package index
import "errors"
var (
// ErrNotFound is returned when the specified record is not saved in the bucket.
ErrNotFound = errors.New("not found")
// ErrAlreadyExists is returned uses when trying to set an existing value on a field that has a unique index.
ErrAlreadyExists = errors.New("already exists")
// ErrNilParam is returned when the specified param is expected to be not nil.
ErrNilParam = errors.New("param must not be nil")
)

14
vendor/github.com/asdine/storm/v3/index/indexes.go generated vendored Normal file
View File

@ -0,0 +1,14 @@
// Package index contains Index engines used to store values and their corresponding IDs
package index
// Index interface
type Index interface {
Add(value []byte, targetID []byte) error
Remove(value []byte) error
RemoveID(id []byte) error
Get(value []byte) []byte
All(value []byte, opts *Options) ([][]byte, error)
AllRecords(opts *Options) ([][]byte, error)
Range(min []byte, max []byte, opts *Options) ([][]byte, error)
Prefix(prefix []byte, opts *Options) ([][]byte, error)
}

283
vendor/github.com/asdine/storm/v3/index/list.go generated vendored Normal file
View File

@ -0,0 +1,283 @@
package index
import (
"bytes"
"github.com/asdine/storm/v3/internal"
bolt "go.etcd.io/bbolt"
)
// NewListIndex loads a ListIndex
func NewListIndex(parent *bolt.Bucket, indexName []byte) (*ListIndex, error) {
var err error
b := parent.Bucket(indexName)
if b == nil {
if !parent.Writable() {
return nil, ErrNotFound
}
b, err = parent.CreateBucket(indexName)
if err != nil {
return nil, err
}
}
ids, err := NewUniqueIndex(b, []byte("storm__ids"))
if err != nil {
return nil, err
}
return &ListIndex{
IndexBucket: b,
Parent: parent,
IDs: ids,
}, nil
}
// ListIndex is an index that references values and the corresponding IDs.
type ListIndex struct {
Parent *bolt.Bucket
IndexBucket *bolt.Bucket
IDs *UniqueIndex
}
// Add a value to the list index
func (idx *ListIndex) Add(newValue []byte, targetID []byte) error {
if newValue == nil || len(newValue) == 0 {
return ErrNilParam
}
if targetID == nil || len(targetID) == 0 {
return ErrNilParam
}
key := idx.IDs.Get(targetID)
if key != nil {
err := idx.IndexBucket.Delete(key)
if err != nil {
return err
}
err = idx.IDs.Remove(targetID)
if err != nil {
return err
}
key = key[:0]
}
key = append(key, newValue...)
key = append(key, '_')
key = append(key, '_')
key = append(key, targetID...)
err := idx.IDs.Add(targetID, key)
if err != nil {
return err
}
return idx.IndexBucket.Put(key, targetID)
}
// Remove a value from the unique index
func (idx *ListIndex) Remove(value []byte) error {
var err error
var keys [][]byte
c := idx.IndexBucket.Cursor()
prefix := generatePrefix(value)
for k, _ := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, _ = c.Next() {
keys = append(keys, k)
}
for _, k := range keys {
err = idx.IndexBucket.Delete(k)
if err != nil {
return err
}
}
return idx.IDs.RemoveID(value)
}
// RemoveID removes an ID from the list index
func (idx *ListIndex) RemoveID(targetID []byte) error {
value := idx.IDs.Get(targetID)
if value == nil {
return nil
}
err := idx.IndexBucket.Delete(value)
if err != nil {
return err
}
return idx.IDs.Remove(targetID)
}
// Get the first ID corresponding to the given value
func (idx *ListIndex) Get(value []byte) []byte {
c := idx.IndexBucket.Cursor()
prefix := generatePrefix(value)
for k, id := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, id = c.Next() {
return id
}
return nil
}
// All the IDs corresponding to the given value
func (idx *ListIndex) All(value []byte, opts *Options) ([][]byte, error) {
var list [][]byte
c := idx.IndexBucket.Cursor()
cur := internal.Cursor{C: c, Reverse: opts != nil && opts.Reverse}
prefix := generatePrefix(value)
k, id := c.Seek(prefix)
if cur.Reverse {
var count int
kc := k
idc := id
for ; kc != nil && bytes.HasPrefix(kc, prefix); kc, idc = c.Next() {
count++
k, id = kc, idc
}
if kc != nil {
k, id = c.Prev()
}
list = make([][]byte, 0, count)
}
for ; bytes.HasPrefix(k, prefix); k, id = cur.Next() {
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, id)
}
return list, nil
}
// AllRecords returns all the IDs of this index
func (idx *ListIndex) AllRecords(opts *Options) ([][]byte, error) {
var list [][]byte
c := internal.Cursor{C: idx.IndexBucket.Cursor(), Reverse: opts != nil && opts.Reverse}
for k, id := c.First(); k != nil; k, id = c.Next() {
if id == nil || bytes.Equal(k, []byte("storm__ids")) {
continue
}
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, id)
}
return list, nil
}
// Range returns the ids corresponding to the given range of values
func (idx *ListIndex) Range(min []byte, max []byte, opts *Options) ([][]byte, error) {
var list [][]byte
c := internal.RangeCursor{
C: idx.IndexBucket.Cursor(),
Reverse: opts != nil && opts.Reverse,
Min: min,
Max: max,
CompareFn: func(val, limit []byte) int {
pos := bytes.LastIndex(val, []byte("__"))
return bytes.Compare(val[:pos], limit)
},
}
for k, id := c.First(); c.Continue(k); k, id = c.Next() {
if id == nil || bytes.Equal(k, []byte("storm__ids")) {
continue
}
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, id)
}
return list, nil
}
// Prefix returns the ids whose values have the given prefix.
func (idx *ListIndex) Prefix(prefix []byte, opts *Options) ([][]byte, error) {
var list [][]byte
c := internal.PrefixCursor{
C: idx.IndexBucket.Cursor(),
Reverse: opts != nil && opts.Reverse,
Prefix: prefix,
}
for k, id := c.First(); k != nil && c.Continue(k); k, id = c.Next() {
if id == nil || bytes.Equal(k, []byte("storm__ids")) {
continue
}
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, id)
}
return list, nil
}
func generatePrefix(value []byte) []byte {
prefix := make([]byte, len(value)+2)
var i int
for i = range value {
prefix[i] = value[i]
}
prefix[i+1] = '_'
prefix[i+2] = '_'
return prefix
}

15
vendor/github.com/asdine/storm/v3/index/options.go generated vendored Normal file
View File

@ -0,0 +1,15 @@
package index
// NewOptions creates initialized Options
func NewOptions() *Options {
return &Options{
Limit: -1,
}
}
// Options are used to customize queries
type Options struct {
Limit int
Skip int
Reverse bool
}

183
vendor/github.com/asdine/storm/v3/index/unique.go generated vendored Normal file
View File

@ -0,0 +1,183 @@
package index
import (
"bytes"
"github.com/asdine/storm/v3/internal"
bolt "go.etcd.io/bbolt"
)
// NewUniqueIndex loads a UniqueIndex
func NewUniqueIndex(parent *bolt.Bucket, indexName []byte) (*UniqueIndex, error) {
var err error
b := parent.Bucket(indexName)
if b == nil {
if !parent.Writable() {
return nil, ErrNotFound
}
b, err = parent.CreateBucket(indexName)
if err != nil {
return nil, err
}
}
return &UniqueIndex{
IndexBucket: b,
Parent: parent,
}, nil
}
// UniqueIndex is an index that references unique values and the corresponding ID.
type UniqueIndex struct {
Parent *bolt.Bucket
IndexBucket *bolt.Bucket
}
// Add a value to the unique index
func (idx *UniqueIndex) Add(value []byte, targetID []byte) error {
if value == nil || len(value) == 0 {
return ErrNilParam
}
if targetID == nil || len(targetID) == 0 {
return ErrNilParam
}
exists := idx.IndexBucket.Get(value)
if exists != nil {
if bytes.Equal(exists, targetID) {
return nil
}
return ErrAlreadyExists
}
return idx.IndexBucket.Put(value, targetID)
}
// Remove a value from the unique index
func (idx *UniqueIndex) Remove(value []byte) error {
return idx.IndexBucket.Delete(value)
}
// RemoveID removes an ID from the unique index
func (idx *UniqueIndex) RemoveID(id []byte) error {
c := idx.IndexBucket.Cursor()
for val, ident := c.First(); val != nil; val, ident = c.Next() {
if bytes.Equal(ident, id) {
return idx.Remove(val)
}
}
return nil
}
// Get the id corresponding to the given value
func (idx *UniqueIndex) Get(value []byte) []byte {
return idx.IndexBucket.Get(value)
}
// All returns all the ids corresponding to the given value
func (idx *UniqueIndex) All(value []byte, opts *Options) ([][]byte, error) {
id := idx.IndexBucket.Get(value)
if id != nil {
return [][]byte{id}, nil
}
return nil, nil
}
// AllRecords returns all the IDs of this index
func (idx *UniqueIndex) AllRecords(opts *Options) ([][]byte, error) {
var list [][]byte
c := internal.Cursor{C: idx.IndexBucket.Cursor(), Reverse: opts != nil && opts.Reverse}
for val, ident := c.First(); val != nil; val, ident = c.Next() {
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, ident)
}
return list, nil
}
// Range returns the ids corresponding to the given range of values
func (idx *UniqueIndex) Range(min []byte, max []byte, opts *Options) ([][]byte, error) {
var list [][]byte
c := internal.RangeCursor{
C: idx.IndexBucket.Cursor(),
Reverse: opts != nil && opts.Reverse,
Min: min,
Max: max,
CompareFn: func(val, limit []byte) int {
return bytes.Compare(val, limit)
},
}
for val, ident := c.First(); val != nil && c.Continue(val); val, ident = c.Next() {
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, ident)
}
return list, nil
}
// Prefix returns the ids whose values have the given prefix.
func (idx *UniqueIndex) Prefix(prefix []byte, opts *Options) ([][]byte, error) {
var list [][]byte
c := internal.PrefixCursor{
C: idx.IndexBucket.Cursor(),
Reverse: opts != nil && opts.Reverse,
Prefix: prefix,
}
for val, ident := c.First(); val != nil && c.Continue(val); val, ident = c.Next() {
if opts != nil && opts.Skip > 0 {
opts.Skip--
continue
}
if opts != nil && opts.Limit == 0 {
break
}
if opts != nil && opts.Limit > 0 {
opts.Limit--
}
list = append(list, ident)
}
return list, nil
}
// first returns the first ID of this index
func (idx *UniqueIndex) first() []byte {
c := idx.IndexBucket.Cursor()
for val, ident := c.First(); val != nil; val, ident = c.Next() {
return ident
}
return nil
}

121
vendor/github.com/asdine/storm/v3/internal/boltdb.go generated vendored Normal file
View File

@ -0,0 +1,121 @@
package internal
import (
"bytes"
bolt "go.etcd.io/bbolt"
)
// Cursor that can be reversed
type Cursor struct {
C *bolt.Cursor
Reverse bool
}
// First element
func (c *Cursor) First() ([]byte, []byte) {
if c.Reverse {
return c.C.Last()
}
return c.C.First()
}
// Next element
func (c *Cursor) Next() ([]byte, []byte) {
if c.Reverse {
return c.C.Prev()
}
return c.C.Next()
}
// RangeCursor that can be reversed
type RangeCursor struct {
C *bolt.Cursor
Reverse bool
Min []byte
Max []byte
CompareFn func([]byte, []byte) int
}
// First element
func (c *RangeCursor) First() ([]byte, []byte) {
if c.Reverse {
k, v := c.C.Seek(c.Max)
// If Seek doesn't find a key it goes to the next.
// If so, we need to get the previous one to avoid
// including bigger values. #218
if !bytes.HasPrefix(k, c.Max) && k != nil {
k, v = c.C.Prev()
}
return k, v
}
return c.C.Seek(c.Min)
}
// Next element
func (c *RangeCursor) Next() ([]byte, []byte) {
if c.Reverse {
return c.C.Prev()
}
return c.C.Next()
}
// Continue tells if the loop needs to continue
func (c *RangeCursor) Continue(val []byte) bool {
if c.Reverse {
return val != nil && c.CompareFn(val, c.Min) >= 0
}
return val != nil && c.CompareFn(val, c.Max) <= 0
}
// PrefixCursor that can be reversed
type PrefixCursor struct {
C *bolt.Cursor
Reverse bool
Prefix []byte
}
// First element
func (c *PrefixCursor) First() ([]byte, []byte) {
var k, v []byte
for k, v = c.C.First(); k != nil && !bytes.HasPrefix(k, c.Prefix); k, v = c.C.Next() {
}
if k == nil {
return nil, nil
}
if c.Reverse {
kc, vc := k, v
for ; kc != nil && bytes.HasPrefix(kc, c.Prefix); kc, vc = c.C.Next() {
k, v = kc, vc
}
if kc != nil {
k, v = c.C.Prev()
}
}
return k, v
}
// Next element
func (c *PrefixCursor) Next() ([]byte, []byte) {
if c.Reverse {
return c.C.Prev()
}
return c.C.Next()
}
// Continue tells if the loop needs to continue
func (c *PrefixCursor) Continue(val []byte) bool {
return val != nil && bytes.HasPrefix(val, c.Prefix)
}

170
vendor/github.com/asdine/storm/v3/kv.go generated vendored Normal file
View File

@ -0,0 +1,170 @@
package storm
import (
"reflect"
bolt "go.etcd.io/bbolt"
)
// KeyValueStore can store and fetch values by key
type KeyValueStore interface {
// Get a value from a bucket
Get(bucketName string, key interface{}, to interface{}) error
// Set a key/value pair into a bucket
Set(bucketName string, key interface{}, value interface{}) error
// Delete deletes a key from a bucket
Delete(bucketName string, key interface{}) error
// GetBytes gets a raw value from a bucket.
GetBytes(bucketName string, key interface{}) ([]byte, error)
// SetBytes sets a raw value into a bucket.
SetBytes(bucketName string, key interface{}, value []byte) error
// KeyExists reports the presence of a key in a bucket.
KeyExists(bucketName string, key interface{}) (bool, error)
}
// GetBytes gets a raw value from a bucket.
func (n *node) GetBytes(bucketName string, key interface{}) ([]byte, error) {
id, err := toBytes(key, n.codec)
if err != nil {
return nil, err
}
var val []byte
return val, n.readTx(func(tx *bolt.Tx) error {
raw, err := n.getBytes(tx, bucketName, id)
if err != nil {
return err
}
val = make([]byte, len(raw))
copy(val, raw)
return nil
})
}
// GetBytes gets a raw value from a bucket.
func (n *node) getBytes(tx *bolt.Tx, bucketName string, id []byte) ([]byte, error) {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
return nil, ErrNotFound
}
raw := bucket.Get(id)
if raw == nil {
return nil, ErrNotFound
}
return raw, nil
}
// SetBytes sets a raw value into a bucket.
func (n *node) SetBytes(bucketName string, key interface{}, value []byte) error {
if key == nil {
return ErrNilParam
}
id, err := toBytes(key, n.codec)
if err != nil {
return err
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.setBytes(tx, bucketName, id, value)
})
}
func (n *node) setBytes(tx *bolt.Tx, bucketName string, id, data []byte) error {
bucket, err := n.CreateBucketIfNotExists(tx, bucketName)
if err != nil {
return err
}
// save node configuration in the bucket
_, err = newMeta(bucket, n)
if err != nil {
return err
}
return bucket.Put(id, data)
}
// Get a value from a bucket
func (n *node) Get(bucketName string, key interface{}, to interface{}) error {
ref := reflect.ValueOf(to)
if !ref.IsValid() || ref.Kind() != reflect.Ptr {
return ErrPtrNeeded
}
id, err := toBytes(key, n.codec)
if err != nil {
return err
}
return n.readTx(func(tx *bolt.Tx) error {
raw, err := n.getBytes(tx, bucketName, id)
if err != nil {
return err
}
return n.codec.Unmarshal(raw, to)
})
}
// Set a key/value pair into a bucket
func (n *node) Set(bucketName string, key interface{}, value interface{}) error {
var data []byte
var err error
if value != nil {
data, err = n.codec.Marshal(value)
if err != nil {
return err
}
}
return n.SetBytes(bucketName, key, data)
}
// Delete deletes a key from a bucket
func (n *node) Delete(bucketName string, key interface{}) error {
id, err := toBytes(key, n.codec)
if err != nil {
return err
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.delete(tx, bucketName, id)
})
}
func (n *node) delete(tx *bolt.Tx, bucketName string, id []byte) error {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
return ErrNotFound
}
return bucket.Delete(id)
}
// KeyExists reports the presence of a key in a bucket.
func (n *node) KeyExists(bucketName string, key interface{}) (bool, error) {
id, err := toBytes(key, n.codec)
if err != nil {
return false, err
}
var exists bool
return exists, n.readTx(func(tx *bolt.Tx) error {
bucket := n.GetBucket(tx, bucketName)
if bucket == nil {
return ErrNotFound
}
v := bucket.Get(id)
if v != nil {
exists = true
}
return nil
})
}

69
vendor/github.com/asdine/storm/v3/metadata.go generated vendored Normal file
View File

@ -0,0 +1,69 @@
package storm
import (
"reflect"
bolt "go.etcd.io/bbolt"
)
const (
metaCodec = "codec"
)
func newMeta(b *bolt.Bucket, n Node) (*meta, error) {
m := b.Bucket([]byte(metadataBucket))
if m != nil {
name := m.Get([]byte(metaCodec))
if string(name) != n.Codec().Name() {
return nil, ErrDifferentCodec
}
return &meta{
node: n,
bucket: m,
}, nil
}
m, err := b.CreateBucket([]byte(metadataBucket))
if err != nil {
return nil, err
}
m.Put([]byte(metaCodec), []byte(n.Codec().Name()))
return &meta{
node: n,
bucket: m,
}, nil
}
type meta struct {
node Node
bucket *bolt.Bucket
}
func (m *meta) increment(field *fieldConfig) error {
var err error
counter := field.IncrementStart
raw := m.bucket.Get([]byte(field.Name + "counter"))
if raw != nil {
counter, err = numberfromb(raw)
if err != nil {
return err
}
counter++
}
raw, err = numbertob(counter)
if err != nil {
return err
}
err = m.bucket.Put([]byte(field.Name+"counter"), raw)
if err != nil {
return err
}
field.Value.Set(reflect.ValueOf(counter).Convert(field.Value.Type()))
field.IsZero = false
return nil
}

126
vendor/github.com/asdine/storm/v3/node.go generated vendored Normal file
View File

@ -0,0 +1,126 @@
package storm
import (
"github.com/asdine/storm/v3/codec"
bolt "go.etcd.io/bbolt"
)
// A Node in Storm represents the API to a BoltDB bucket.
type Node interface {
Tx
TypeStore
KeyValueStore
BucketScanner
// From returns a new Storm node with a new bucket root below the current.
// All DB operations on the new node will be executed relative to this bucket.
From(addend ...string) Node
// Bucket returns the bucket name as a slice from the root.
// In the normal, simple case this will be empty.
Bucket() []string
// GetBucket returns the given bucket below the current node.
GetBucket(tx *bolt.Tx, children ...string) *bolt.Bucket
// CreateBucketIfNotExists creates the bucket below the current node if it doesn't
// already exist.
CreateBucketIfNotExists(tx *bolt.Tx, bucket string) (*bolt.Bucket, error)
// WithTransaction returns a New Storm node that will use the given transaction.
WithTransaction(tx *bolt.Tx) Node
// Begin starts a new transaction.
Begin(writable bool) (Node, error)
// Codec used by this instance of Storm
Codec() codec.MarshalUnmarshaler
// WithCodec returns a New Storm Node that will use the given Codec.
WithCodec(codec codec.MarshalUnmarshaler) Node
// WithBatch returns a new Storm Node with the batch mode enabled.
WithBatch(enabled bool) Node
}
// A Node in Storm represents the API to a BoltDB bucket.
type node struct {
s *DB
// The root bucket. In the normal, simple case this will be empty.
rootBucket []string
// Transaction object. Nil if not in transaction
tx *bolt.Tx
// Codec of this node
codec codec.MarshalUnmarshaler
// Enable batch mode for read-write transaction, instead of update mode
batchMode bool
}
// From returns a new Storm Node with a new bucket root below the current.
// All DB operations on the new node will be executed relative to this bucket.
func (n node) From(addend ...string) Node {
n.rootBucket = append(n.rootBucket, addend...)
return &n
}
// WithTransaction returns a new Storm Node that will use the given transaction.
func (n node) WithTransaction(tx *bolt.Tx) Node {
n.tx = tx
return &n
}
// WithCodec returns a new Storm Node that will use the given Codec.
func (n node) WithCodec(codec codec.MarshalUnmarshaler) Node {
n.codec = codec
return &n
}
// WithBatch returns a new Storm Node with the batch mode enabled.
func (n node) WithBatch(enabled bool) Node {
n.batchMode = enabled
return &n
}
// Bucket returns the bucket name as a slice from the root.
// In the normal, simple case this will be empty.
func (n *node) Bucket() []string {
return n.rootBucket
}
// Codec returns the EncodeDecoder used by this instance of Storm
func (n *node) Codec() codec.MarshalUnmarshaler {
return n.codec
}
// Detects if already in transaction or runs a read write transaction.
// Uses batch mode if enabled.
func (n *node) readWriteTx(fn func(tx *bolt.Tx) error) error {
if n.tx != nil {
return fn(n.tx)
}
if n.batchMode {
return n.s.Bolt.Batch(func(tx *bolt.Tx) error {
return fn(tx)
})
}
return n.s.Bolt.Update(func(tx *bolt.Tx) error {
return fn(tx)
})
}
// Detects if already in transaction or runs a read transaction.
func (n *node) readTx(fn func(tx *bolt.Tx) error) error {
if n.tx != nil {
return fn(n.tx)
}
return n.s.Bolt.View(func(tx *bolt.Tx) error {
return fn(tx)
})
}

97
vendor/github.com/asdine/storm/v3/options.go generated vendored Normal file
View File

@ -0,0 +1,97 @@
package storm
import (
"os"
"github.com/asdine/storm/v3/codec"
"github.com/asdine/storm/v3/index"
bolt "go.etcd.io/bbolt"
)
// BoltOptions used to pass options to BoltDB.
func BoltOptions(mode os.FileMode, options *bolt.Options) func(*Options) error {
return func(opts *Options) error {
opts.boltMode = mode
opts.boltOptions = options
return nil
}
}
// Codec used to set a custom encoder and decoder. The default is JSON.
func Codec(c codec.MarshalUnmarshaler) func(*Options) error {
return func(opts *Options) error {
opts.codec = c
return nil
}
}
// Batch enables the use of batch instead of update for read-write transactions.
func Batch() func(*Options) error {
return func(opts *Options) error {
opts.batchMode = true
return nil
}
}
// Root used to set the root bucket. See also the From method.
func Root(root ...string) func(*Options) error {
return func(opts *Options) error {
opts.rootBucket = root
return nil
}
}
// UseDB allows Storm to use an existing open Bolt.DB.
// Warning: storm.DB.Close() will close the bolt.DB instance.
func UseDB(b *bolt.DB) func(*Options) error {
return func(opts *Options) error {
opts.path = b.Path()
opts.bolt = b
return nil
}
}
// Limit sets the maximum number of records to return
func Limit(limit int) func(*index.Options) {
return func(opts *index.Options) {
opts.Limit = limit
}
}
// Skip sets the number of records to skip
func Skip(offset int) func(*index.Options) {
return func(opts *index.Options) {
opts.Skip = offset
}
}
// Reverse will return the results in descending order
func Reverse() func(*index.Options) {
return func(opts *index.Options) {
opts.Reverse = true
}
}
// Options are used to customize the way Storm opens a database.
type Options struct {
// Handles encoding and decoding of objects
codec codec.MarshalUnmarshaler
// Bolt file mode
boltMode os.FileMode
// Bolt options
boltOptions *bolt.Options
// Enable batch mode for read-write transaction, instead of update mode
batchMode bool
// The root bucket name
rootBucket []string
// Path of the database file
path string
// Bolt is still easily accessible
bolt *bolt.DB
}

122
vendor/github.com/asdine/storm/v3/q/compare.go generated vendored Normal file
View File

@ -0,0 +1,122 @@
package q
import (
"go/constant"
"go/token"
"reflect"
"strconv"
)
func compare(a, b interface{}, tok token.Token) bool {
vala := reflect.ValueOf(a)
valb := reflect.ValueOf(b)
ak := vala.Kind()
bk := valb.Kind()
switch {
// comparing nil values
case (ak == reflect.Ptr || ak == reflect.Slice || ak == reflect.Interface || ak == reflect.Invalid) &&
(bk == reflect.Ptr || ak == reflect.Slice || bk == reflect.Interface || bk == reflect.Invalid) &&
(!vala.IsValid() || vala.IsNil()) && (!valb.IsValid() || valb.IsNil()):
return true
case ak >= reflect.Int && ak <= reflect.Int64:
if bk >= reflect.Int && bk <= reflect.Int64 {
return constant.Compare(constant.MakeInt64(vala.Int()), tok, constant.MakeInt64(valb.Int()))
}
if bk >= reflect.Uint && bk <= reflect.Uint64 {
return constant.Compare(constant.MakeInt64(vala.Int()), tok, constant.MakeInt64(int64(valb.Uint())))
}
if bk == reflect.Float32 || bk == reflect.Float64 {
return constant.Compare(constant.MakeFloat64(float64(vala.Int())), tok, constant.MakeFloat64(valb.Float()))
}
if bk == reflect.String {
bla, err := strconv.ParseFloat(valb.String(), 64)
if err != nil {
return false
}
return constant.Compare(constant.MakeFloat64(float64(vala.Int())), tok, constant.MakeFloat64(bla))
}
case ak >= reflect.Uint && ak <= reflect.Uint64:
if bk >= reflect.Uint && bk <= reflect.Uint64 {
return constant.Compare(constant.MakeUint64(vala.Uint()), tok, constant.MakeUint64(valb.Uint()))
}
if bk >= reflect.Int && bk <= reflect.Int64 {
return constant.Compare(constant.MakeUint64(vala.Uint()), tok, constant.MakeUint64(uint64(valb.Int())))
}
if bk == reflect.Float32 || bk == reflect.Float64 {
return constant.Compare(constant.MakeFloat64(float64(vala.Uint())), tok, constant.MakeFloat64(valb.Float()))
}
if bk == reflect.String {
bla, err := strconv.ParseFloat(valb.String(), 64)
if err != nil {
return false
}
return constant.Compare(constant.MakeFloat64(float64(vala.Uint())), tok, constant.MakeFloat64(bla))
}
case ak == reflect.Float32 || ak == reflect.Float64:
if bk == reflect.Float32 || bk == reflect.Float64 {
return constant.Compare(constant.MakeFloat64(vala.Float()), tok, constant.MakeFloat64(valb.Float()))
}
if bk >= reflect.Int && bk <= reflect.Int64 {
return constant.Compare(constant.MakeFloat64(vala.Float()), tok, constant.MakeFloat64(float64(valb.Int())))
}
if bk >= reflect.Uint && bk <= reflect.Uint64 {
return constant.Compare(constant.MakeFloat64(vala.Float()), tok, constant.MakeFloat64(float64(valb.Uint())))
}
if bk == reflect.String {
bla, err := strconv.ParseFloat(valb.String(), 64)
if err != nil {
return false
}
return constant.Compare(constant.MakeFloat64(vala.Float()), tok, constant.MakeFloat64(bla))
}
case ak == reflect.String:
if bk == reflect.String {
return constant.Compare(constant.MakeString(vala.String()), tok, constant.MakeString(valb.String()))
}
}
typea, typeb := reflect.TypeOf(a), reflect.TypeOf(b)
if typea != nil && (typea.String() == "time.Time" || typea.String() == "*time.Time") &&
typeb != nil && (typeb.String() == "time.Time" || typeb.String() == "*time.Time") {
if typea.String() == "*time.Time" && vala.IsNil() {
return true
}
if typeb.String() == "*time.Time" {
if valb.IsNil() {
return true
}
valb = valb.Elem()
}
var x, y int64
x = 1
if vala.MethodByName("Equal").Call([]reflect.Value{valb})[0].Bool() {
y = 1
} else if vala.MethodByName("Before").Call([]reflect.Value{valb})[0].Bool() {
y = 2
}
return constant.Compare(constant.MakeInt64(x), tok, constant.MakeInt64(y))
}
if tok == token.EQL {
return reflect.DeepEqual(a, b)
}
return false
}

67
vendor/github.com/asdine/storm/v3/q/fieldmatcher.go generated vendored Normal file
View File

@ -0,0 +1,67 @@
package q
import (
"errors"
"go/token"
"reflect"
)
// ErrUnknownField is returned when an unknown field is passed.
var ErrUnknownField = errors.New("unknown field")
type fieldMatcherDelegate struct {
FieldMatcher
Field string
}
// NewFieldMatcher creates a Matcher for a given field.
func NewFieldMatcher(field string, fm FieldMatcher) Matcher {
return fieldMatcherDelegate{Field: field, FieldMatcher: fm}
}
// FieldMatcher can be used in NewFieldMatcher as a simple way to create the
// most common Matcher: A Matcher that evaluates one field's value.
// For more complex scenarios, implement the Matcher interface directly.
type FieldMatcher interface {
MatchField(v interface{}) (bool, error)
}
func (r fieldMatcherDelegate) Match(i interface{}) (bool, error) {
v := reflect.Indirect(reflect.ValueOf(i))
return r.MatchValue(&v)
}
func (r fieldMatcherDelegate) MatchValue(v *reflect.Value) (bool, error) {
field := v.FieldByName(r.Field)
if !field.IsValid() {
return false, ErrUnknownField
}
return r.MatchField(field.Interface())
}
// NewField2FieldMatcher creates a Matcher for a given field1 and field2.
func NewField2FieldMatcher(field1, field2 string, tok token.Token) Matcher {
return field2fieldMatcherDelegate{Field1: field1, Field2: field2, Tok: tok}
}
type field2fieldMatcherDelegate struct {
Field1, Field2 string
Tok token.Token
}
func (r field2fieldMatcherDelegate) Match(i interface{}) (bool, error) {
v := reflect.Indirect(reflect.ValueOf(i))
return r.MatchValue(&v)
}
func (r field2fieldMatcherDelegate) MatchValue(v *reflect.Value) (bool, error) {
field1 := v.FieldByName(r.Field1)
if !field1.IsValid() {
return false, ErrUnknownField
}
field2 := v.FieldByName(r.Field2)
if !field2.IsValid() {
return false, ErrUnknownField
}
return compare(field1.Interface(), field2.Interface(), r.Tok), nil
}

51
vendor/github.com/asdine/storm/v3/q/regexp.go generated vendored Normal file
View File

@ -0,0 +1,51 @@
package q
import (
"fmt"
"regexp"
"sync"
)
// Re creates a regexp matcher. It checks if the given field matches the given regexp.
// Note that this only supports fields of type string or []byte.
func Re(field string, re string) Matcher {
regexpCache.RLock()
if r, ok := regexpCache.m[re]; ok {
regexpCache.RUnlock()
return NewFieldMatcher(field, &regexpMatcher{r: r})
}
regexpCache.RUnlock()
regexpCache.Lock()
r, err := regexp.Compile(re)
if err == nil {
regexpCache.m[re] = r
}
regexpCache.Unlock()
return NewFieldMatcher(field, &regexpMatcher{r: r, err: err})
}
var regexpCache = struct {
sync.RWMutex
m map[string]*regexp.Regexp
}{m: make(map[string]*regexp.Regexp)}
type regexpMatcher struct {
r *regexp.Regexp
err error
}
func (r *regexpMatcher) MatchField(v interface{}) (bool, error) {
if r.err != nil {
return false, r.err
}
switch fieldValue := v.(type) {
case string:
return r.r.MatchString(fieldValue), nil
case []byte:
return r.r.Match(fieldValue), nil
default:
return false, fmt.Errorf("Only string and []byte supported for regexp matcher, got %T", fieldValue)
}
}

247
vendor/github.com/asdine/storm/v3/q/tree.go generated vendored Normal file
View File

@ -0,0 +1,247 @@
// Package q contains a list of Matchers used to compare struct fields with values
package q
import (
"go/token"
"reflect"
)
// A Matcher is used to test against a record to see if it matches.
type Matcher interface {
// Match is used to test the criteria against a structure.
Match(interface{}) (bool, error)
}
// A ValueMatcher is used to test against a reflect.Value.
type ValueMatcher interface {
// MatchValue tests if the given reflect.Value matches.
// It is useful when the reflect.Value of an object already exists.
MatchValue(*reflect.Value) (bool, error)
}
type cmp struct {
value interface{}
token token.Token
}
func (c *cmp) MatchField(v interface{}) (bool, error) {
return compare(v, c.value, c.token), nil
}
type trueMatcher struct{}
func (*trueMatcher) Match(i interface{}) (bool, error) {
return true, nil
}
func (*trueMatcher) MatchValue(v *reflect.Value) (bool, error) {
return true, nil
}
type or struct {
children []Matcher
}
func (c *or) Match(i interface{}) (bool, error) {
v := reflect.Indirect(reflect.ValueOf(i))
return c.MatchValue(&v)
}
func (c *or) MatchValue(v *reflect.Value) (bool, error) {
for _, matcher := range c.children {
if vm, ok := matcher.(ValueMatcher); ok {
ok, err := vm.MatchValue(v)
if err != nil {
return false, err
}
if ok {
return true, nil
}
continue
}
ok, err := matcher.Match(v.Interface())
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
return false, nil
}
type and struct {
children []Matcher
}
func (c *and) Match(i interface{}) (bool, error) {
v := reflect.Indirect(reflect.ValueOf(i))
return c.MatchValue(&v)
}
func (c *and) MatchValue(v *reflect.Value) (bool, error) {
for _, matcher := range c.children {
if vm, ok := matcher.(ValueMatcher); ok {
ok, err := vm.MatchValue(v)
if err != nil {
return false, err
}
if !ok {
return false, nil
}
continue
}
ok, err := matcher.Match(v.Interface())
if err != nil {
return false, err
}
if !ok {
return false, nil
}
}
return true, nil
}
type strictEq struct {
field string
value interface{}
}
func (s *strictEq) MatchField(v interface{}) (bool, error) {
return reflect.DeepEqual(v, s.value), nil
}
type in struct {
list interface{}
}
func (i *in) MatchField(v interface{}) (bool, error) {
ref := reflect.ValueOf(i.list)
if ref.Kind() != reflect.Slice {
return false, nil
}
c := cmp{
token: token.EQL,
}
for i := 0; i < ref.Len(); i++ {
c.value = ref.Index(i).Interface()
ok, err := c.MatchField(v)
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
return false, nil
}
type not struct {
children []Matcher
}
func (n *not) Match(i interface{}) (bool, error) {
v := reflect.Indirect(reflect.ValueOf(i))
return n.MatchValue(&v)
}
func (n *not) MatchValue(v *reflect.Value) (bool, error) {
var err error
for _, matcher := range n.children {
vm, ok := matcher.(ValueMatcher)
if ok {
ok, err = vm.MatchValue(v)
} else {
ok, err = matcher.Match(v.Interface())
}
if err != nil {
return false, err
}
if ok {
return false, nil
}
}
return true, nil
}
// Eq matcher, checks if the given field is equal to the given value
func Eq(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &cmp{value: v, token: token.EQL})
}
// EqF matcher, checks if the given field is equal to the given field
func EqF(field1, field2 string) Matcher {
return NewField2FieldMatcher(field1, field2, token.EQL)
}
// StrictEq matcher, checks if the given field is deeply equal to the given value
func StrictEq(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &strictEq{value: v})
}
// Gt matcher, checks if the given field is greater than the given value
func Gt(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &cmp{value: v, token: token.GTR})
}
// GtF matcher, checks if the given field is greater than the given field
func GtF(field1, field2 string) Matcher {
return NewField2FieldMatcher(field1, field2, token.GTR)
}
// Gte matcher, checks if the given field is greater than or equal to the given value
func Gte(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &cmp{value: v, token: token.GEQ})
}
// GteF matcher, checks if the given field is greater than or equal to the given field
func GteF(field1, field2 string) Matcher {
return NewField2FieldMatcher(field1, field2, token.GEQ)
}
// Lt matcher, checks if the given field is lesser than the given value
func Lt(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &cmp{value: v, token: token.LSS})
}
// LtF matcher, checks if the given field is lesser than the given field
func LtF(field1, field2 string) Matcher {
return NewField2FieldMatcher(field1, field2, token.LSS)
}
// Lte matcher, checks if the given field is lesser than or equal to the given value
func Lte(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &cmp{value: v, token: token.LEQ})
}
// LteF matcher, checks if the given field is lesser than or equal to the given field
func LteF(field1, field2 string) Matcher {
return NewField2FieldMatcher(field1, field2, token.LEQ)
}
// In matcher, checks if the given field matches one of the value of the given slice.
// v must be a slice.
func In(field string, v interface{}) Matcher {
return NewFieldMatcher(field, &in{list: v})
}
// True matcher, always returns true
func True() Matcher { return &trueMatcher{} }
// Or matcher, checks if at least one of the given matchers matches the record
func Or(matchers ...Matcher) Matcher { return &or{children: matchers} }
// And matcher, checks if all of the given matchers matches the record
func And(matchers ...Matcher) Matcher { return &and{children: matchers} }
// Not matcher, checks if all of the given matchers return false
func Not(matchers ...Matcher) Matcher { return &not{children: matchers} }

219
vendor/github.com/asdine/storm/v3/query.go generated vendored Normal file
View File

@ -0,0 +1,219 @@
package storm
import (
"github.com/asdine/storm/v3/internal"
"github.com/asdine/storm/v3/q"
bolt "go.etcd.io/bbolt"
)
// Select a list of records that match a list of matchers. Doesn't use indexes.
func (n *node) Select(matchers ...q.Matcher) Query {
tree := q.And(matchers...)
return newQuery(n, tree)
}
// Query is the low level query engine used by Storm. It allows to operate searches through an entire bucket.
type Query interface {
// Skip matching records by the given number
Skip(int) Query
// Limit the results by the given number
Limit(int) Query
// Order by the given fields, in descending precedence, left-to-right.
OrderBy(...string) Query
// Reverse the order of the results
Reverse() Query
// Bucket specifies the bucket name
Bucket(string) Query
// Find a list of matching records
Find(interface{}) error
// First gets the first matching record
First(interface{}) error
// Delete all matching records
Delete(interface{}) error
// Count all the matching records
Count(interface{}) (int, error)
// Returns all the records without decoding them
Raw() ([][]byte, error)
// Execute the given function for each raw element
RawEach(func([]byte, []byte) error) error
// Execute the given function for each element
Each(interface{}, func(interface{}) error) error
}
func newQuery(n *node, tree q.Matcher) *query {
return &query{
skip: 0,
limit: -1,
node: n,
tree: tree,
}
}
type query struct {
limit int
skip int
reverse bool
tree q.Matcher
node *node
bucket string
orderBy []string
}
func (q *query) Skip(nb int) Query {
q.skip = nb
return q
}
func (q *query) Limit(nb int) Query {
q.limit = nb
return q
}
func (q *query) OrderBy(field ...string) Query {
q.orderBy = field
return q
}
func (q *query) Reverse() Query {
q.reverse = true
return q
}
func (q *query) Bucket(bucketName string) Query {
q.bucket = bucketName
return q
}
func (q *query) Find(to interface{}) error {
sink, err := newListSink(q.node, to)
if err != nil {
return err
}
return q.runQuery(sink)
}
func (q *query) First(to interface{}) error {
sink, err := newFirstSink(q.node, to)
if err != nil {
return err
}
q.limit = 1
return q.runQuery(sink)
}
func (q *query) Delete(kind interface{}) error {
sink, err := newDeleteSink(q.node, kind)
if err != nil {
return err
}
return q.runQuery(sink)
}
func (q *query) Count(kind interface{}) (int, error) {
sink, err := newCountSink(q.node, kind)
if err != nil {
return 0, err
}
err = q.runQuery(sink)
if err != nil {
return 0, err
}
return sink.counter, nil
}
func (q *query) Raw() ([][]byte, error) {
sink := newRawSink()
err := q.runQuery(sink)
if err != nil {
return nil, err
}
return sink.results, nil
}
func (q *query) RawEach(fn func([]byte, []byte) error) error {
sink := newRawSink()
sink.execFn = fn
return q.runQuery(sink)
}
func (q *query) Each(kind interface{}, fn func(interface{}) error) error {
sink, err := newEachSink(kind)
if err != nil {
return err
}
sink.execFn = fn
return q.runQuery(sink)
}
func (q *query) runQuery(sink sink) error {
if q.node.tx != nil {
return q.query(q.node.tx, sink)
}
if sink.readOnly() {
return q.node.s.Bolt.View(func(tx *bolt.Tx) error {
return q.query(tx, sink)
})
}
return q.node.s.Bolt.Update(func(tx *bolt.Tx) error {
return q.query(tx, sink)
})
}
func (q *query) query(tx *bolt.Tx, sink sink) error {
bucketName := q.bucket
if bucketName == "" {
bucketName = sink.bucketName()
}
bucket := q.node.GetBucket(tx, bucketName)
if q.limit == 0 {
return sink.flush()
}
sorter := newSorter(q.node, sink)
sorter.orderBy = q.orderBy
sorter.reverse = q.reverse
sorter.skip = q.skip
sorter.limit = q.limit
if bucket != nil {
c := internal.Cursor{C: bucket.Cursor(), Reverse: q.reverse}
for k, v := c.First(); k != nil; k, v = c.Next() {
if v == nil {
continue
}
stop, err := sorter.filter(q.tree, bucket, k, v)
if err != nil {
return err
}
if stop {
break
}
}
}
return sorter.flush()
}

105
vendor/github.com/asdine/storm/v3/scan.go generated vendored Normal file
View File

@ -0,0 +1,105 @@
package storm
import (
"bytes"
bolt "go.etcd.io/bbolt"
)
// A BucketScanner scans a Node for a list of buckets
type BucketScanner interface {
// PrefixScan scans the root buckets for keys matching the given prefix.
PrefixScan(prefix string) []Node
// PrefixScan scans the buckets in this node for keys matching the given prefix.
RangeScan(min, max string) []Node
}
// PrefixScan scans the buckets in this node for keys matching the given prefix.
func (n *node) PrefixScan(prefix string) []Node {
if n.tx != nil {
return n.prefixScan(n.tx, prefix)
}
var nodes []Node
n.readTx(func(tx *bolt.Tx) error {
nodes = n.prefixScan(tx, prefix)
return nil
})
return nodes
}
func (n *node) prefixScan(tx *bolt.Tx, prefix string) []Node {
var (
prefixBytes = []byte(prefix)
nodes []Node
c = n.cursor(tx)
)
if c == nil {
return nil
}
for k, v := c.Seek(prefixBytes); k != nil && bytes.HasPrefix(k, prefixBytes); k, v = c.Next() {
if v != nil {
continue
}
nodes = append(nodes, n.From(string(k)))
}
return nodes
}
// RangeScan scans the buckets in this node over a range such as a sortable time range.
func (n *node) RangeScan(min, max string) []Node {
if n.tx != nil {
return n.rangeScan(n.tx, min, max)
}
var nodes []Node
n.readTx(func(tx *bolt.Tx) error {
nodes = n.rangeScan(tx, min, max)
return nil
})
return nodes
}
func (n *node) rangeScan(tx *bolt.Tx, min, max string) []Node {
var (
minBytes = []byte(min)
maxBytes = []byte(max)
nodes []Node
c = n.cursor(tx)
)
for k, v := c.Seek(minBytes); k != nil && bytes.Compare(k, maxBytes) <= 0; k, v = c.Next() {
if v != nil {
continue
}
nodes = append(nodes, n.From(string(k)))
}
return nodes
}
func (n *node) cursor(tx *bolt.Tx) *bolt.Cursor {
var c *bolt.Cursor
if len(n.rootBucket) > 0 {
b := n.GetBucket(tx)
if b == nil {
return nil
}
c = b.Cursor()
} else {
c = tx.Cursor()
}
return c
}

620
vendor/github.com/asdine/storm/v3/sink.go generated vendored Normal file
View File

@ -0,0 +1,620 @@
package storm
import (
"reflect"
"sort"
"time"
"github.com/asdine/storm/v3/index"
"github.com/asdine/storm/v3/q"
bolt "go.etcd.io/bbolt"
)
type item struct {
value *reflect.Value
bucket *bolt.Bucket
k []byte
v []byte
}
func newSorter(n Node, snk sink) *sorter {
return &sorter{
node: n,
sink: snk,
skip: 0,
limit: -1,
list: make([]*item, 0),
err: make(chan error),
done: make(chan struct{}),
}
}
type sorter struct {
node Node
sink sink
list []*item
skip int
limit int
orderBy []string
reverse bool
err chan error
done chan struct{}
}
func (s *sorter) filter(tree q.Matcher, bucket *bolt.Bucket, k, v []byte) (bool, error) {
itm := &item{
bucket: bucket,
k: k,
v: v,
}
rsink, ok := s.sink.(reflectSink)
if !ok {
return s.add(itm)
}
newElem := rsink.elem()
if err := s.node.Codec().Unmarshal(v, newElem.Interface()); err != nil {
return false, err
}
itm.value = &newElem
if tree != nil {
ok, err := tree.Match(newElem.Interface())
if err != nil {
return false, err
}
if !ok {
return false, nil
}
}
if len(s.orderBy) == 0 {
return s.add(itm)
}
if _, ok := s.sink.(sliceSink); ok {
// add directly to sink, we'll apply skip/limits after sorting
return false, s.sink.add(itm)
}
s.list = append(s.list, itm)
return false, nil
}
func (s *sorter) add(itm *item) (stop bool, err error) {
if s.limit == 0 {
return true, nil
}
if s.skip > 0 {
s.skip--
return false, nil
}
if s.limit > 0 {
s.limit--
}
err = s.sink.add(itm)
return s.limit == 0, err
}
func (s *sorter) compareValue(left reflect.Value, right reflect.Value) int {
if !left.IsValid() || !right.IsValid() {
if left.IsValid() {
return 1
}
return -1
}
switch left.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
l, r := left.Int(), right.Int()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
l, r := left.Uint(), right.Uint()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.Float32, reflect.Float64:
l, r := left.Float(), right.Float()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.String:
l, r := left.String(), right.String()
if l < r {
return -1
}
if l > r {
return 1
}
case reflect.Struct:
if lt, lok := left.Interface().(time.Time); lok {
if rt, rok := right.Interface().(time.Time); rok {
if lok && rok {
if lt.Before(rt) {
return -1
} else {
return 1
}
}
}
}
default:
rawLeft, err := toBytes(left.Interface(), s.node.Codec())
if err != nil {
return -1
}
rawRight, err := toBytes(right.Interface(), s.node.Codec())
if err != nil {
return 1
}
l, r := string(rawLeft), string(rawRight)
if l < r {
return -1
}
if l > r {
return 1
}
}
return 0
}
func (s *sorter) less(leftElem reflect.Value, rightElem reflect.Value) bool {
for _, orderBy := range s.orderBy {
leftField := reflect.Indirect(leftElem).FieldByName(orderBy)
if !leftField.IsValid() {
s.err <- ErrNotFound
return false
}
rightField := reflect.Indirect(rightElem).FieldByName(orderBy)
if !rightField.IsValid() {
s.err <- ErrNotFound
return false
}
direction := 1
if s.reverse {
direction = -1
}
switch s.compareValue(leftField, rightField) * direction {
case -1:
return true
case 1:
return false
default:
continue
}
}
return false
}
func (s *sorter) flush() error {
if len(s.orderBy) == 0 {
return s.sink.flush()
}
go func() {
sort.Sort(s)
close(s.err)
}()
err := <-s.err
close(s.done)
if err != nil {
return err
}
if ssink, ok := s.sink.(sliceSink); ok {
if !ssink.slice().IsValid() {
return s.sink.flush()
}
if s.skip >= ssink.slice().Len() {
ssink.reset()
return s.sink.flush()
}
leftBound := s.skip
if leftBound < 0 {
leftBound = 0
}
limit := s.limit
if s.limit < 0 {
limit = 0
}
rightBound := leftBound + limit
if rightBound > ssink.slice().Len() || rightBound == leftBound {
rightBound = ssink.slice().Len()
}
ssink.setSlice(ssink.slice().Slice(leftBound, rightBound))
return s.sink.flush()
}
for _, itm := range s.list {
if itm == nil {
break
}
stop, err := s.add(itm)
if err != nil {
return err
}
if stop {
break
}
}
return s.sink.flush()
}
func (s *sorter) Len() int {
// skip if we encountered an earlier error
select {
case <-s.done:
return 0
default:
}
if ssink, ok := s.sink.(sliceSink); ok {
return ssink.slice().Len()
}
return len(s.list)
}
func (s *sorter) Less(i, j int) bool {
// skip if we encountered an earlier error
select {
case <-s.done:
return false
default:
}
if ssink, ok := s.sink.(sliceSink); ok {
return s.less(ssink.slice().Index(i), ssink.slice().Index(j))
}
return s.less(*s.list[i].value, *s.list[j].value)
}
type sink interface {
bucketName() string
flush() error
add(*item) error
readOnly() bool
}
type reflectSink interface {
elem() reflect.Value
}
type sliceSink interface {
slice() reflect.Value
setSlice(reflect.Value)
reset()
}
func newListSink(node Node, to interface{}) (*listSink, error) {
ref := reflect.ValueOf(to)
if ref.Kind() != reflect.Ptr || reflect.Indirect(ref).Kind() != reflect.Slice {
return nil, ErrSlicePtrNeeded
}
sliceType := reflect.Indirect(ref).Type()
elemType := sliceType.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Name() == "" {
return nil, ErrNoName
}
return &listSink{
node: node,
ref: ref,
isPtr: sliceType.Elem().Kind() == reflect.Ptr,
elemType: elemType,
name: elemType.Name(),
results: reflect.MakeSlice(reflect.Indirect(ref).Type(), 0, 0),
}, nil
}
type listSink struct {
node Node
ref reflect.Value
results reflect.Value
elemType reflect.Type
name string
isPtr bool
idx int
}
func (l *listSink) slice() reflect.Value {
return l.results
}
func (l *listSink) setSlice(s reflect.Value) {
l.results = s
}
func (l *listSink) reset() {
l.results = reflect.MakeSlice(reflect.Indirect(l.ref).Type(), 0, 0)
}
func (l *listSink) elem() reflect.Value {
if l.results.IsValid() && l.idx < l.results.Len() {
return l.results.Index(l.idx).Addr()
}
return reflect.New(l.elemType)
}
func (l *listSink) bucketName() string {
return l.name
}
func (l *listSink) add(i *item) error {
if l.idx == l.results.Len() {
if l.isPtr {
l.results = reflect.Append(l.results, *i.value)
} else {
l.results = reflect.Append(l.results, reflect.Indirect(*i.value))
}
}
l.idx++
return nil
}
func (l *listSink) flush() error {
if l.results.IsValid() && l.results.Len() > 0 {
reflect.Indirect(l.ref).Set(l.results)
return nil
}
return ErrNotFound
}
func (l *listSink) readOnly() bool {
return true
}
func newFirstSink(node Node, to interface{}) (*firstSink, error) {
ref := reflect.ValueOf(to)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return nil, ErrStructPtrNeeded
}
return &firstSink{
node: node,
ref: ref,
}, nil
}
type firstSink struct {
node Node
ref reflect.Value
found bool
}
func (f *firstSink) elem() reflect.Value {
return reflect.New(reflect.Indirect(f.ref).Type())
}
func (f *firstSink) bucketName() string {
return reflect.Indirect(f.ref).Type().Name()
}
func (f *firstSink) add(i *item) error {
reflect.Indirect(f.ref).Set(i.value.Elem())
f.found = true
return nil
}
func (f *firstSink) flush() error {
if !f.found {
return ErrNotFound
}
return nil
}
func (f *firstSink) readOnly() bool {
return true
}
func newDeleteSink(node Node, kind interface{}) (*deleteSink, error) {
ref := reflect.ValueOf(kind)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return nil, ErrStructPtrNeeded
}
return &deleteSink{
node: node,
ref: ref,
}, nil
}
type deleteSink struct {
node Node
ref reflect.Value
removed int
}
func (d *deleteSink) elem() reflect.Value {
return reflect.New(reflect.Indirect(d.ref).Type())
}
func (d *deleteSink) bucketName() string {
return reflect.Indirect(d.ref).Type().Name()
}
func (d *deleteSink) add(i *item) error {
info, err := extract(&d.ref)
if err != nil {
return err
}
for fieldName, fieldCfg := range info.Fields {
if fieldCfg.Index == "" {
continue
}
idx, err := getIndex(i.bucket, fieldCfg.Index, fieldName)
if err != nil {
return err
}
err = idx.RemoveID(i.k)
if err != nil {
if err == index.ErrNotFound {
return ErrNotFound
}
return err
}
}
d.removed++
return i.bucket.Delete(i.k)
}
func (d *deleteSink) flush() error {
if d.removed == 0 {
return ErrNotFound
}
return nil
}
func (d *deleteSink) readOnly() bool {
return false
}
func newCountSink(node Node, kind interface{}) (*countSink, error) {
ref := reflect.ValueOf(kind)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return nil, ErrStructPtrNeeded
}
return &countSink{
node: node,
ref: ref,
}, nil
}
type countSink struct {
node Node
ref reflect.Value
counter int
}
func (c *countSink) elem() reflect.Value {
return reflect.New(reflect.Indirect(c.ref).Type())
}
func (c *countSink) bucketName() string {
return reflect.Indirect(c.ref).Type().Name()
}
func (c *countSink) add(i *item) error {
c.counter++
return nil
}
func (c *countSink) flush() error {
return nil
}
func (c *countSink) readOnly() bool {
return true
}
func newRawSink() *rawSink {
return &rawSink{}
}
type rawSink struct {
results [][]byte
execFn func([]byte, []byte) error
}
func (r *rawSink) add(i *item) error {
if r.execFn != nil {
err := r.execFn(i.k, i.v)
if err != nil {
return err
}
} else {
r.results = append(r.results, i.v)
}
return nil
}
func (r *rawSink) bucketName() string {
return ""
}
func (r *rawSink) flush() error {
return nil
}
func (r *rawSink) readOnly() bool {
return true
}
func newEachSink(to interface{}) (*eachSink, error) {
ref := reflect.ValueOf(to)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return nil, ErrStructPtrNeeded
}
return &eachSink{
ref: ref,
}, nil
}
type eachSink struct {
ref reflect.Value
execFn func(interface{}) error
}
func (e *eachSink) elem() reflect.Value {
return reflect.New(reflect.Indirect(e.ref).Type())
}
func (e *eachSink) bucketName() string {
return reflect.Indirect(e.ref).Type().Name()
}
func (e *eachSink) add(i *item) error {
return e.execFn(i.value.Interface())
}
func (e *eachSink) flush() error {
return nil
}
func (e *eachSink) readOnly() bool {
return true
}

22
vendor/github.com/asdine/storm/v3/sink_sorter_swap.go generated vendored Normal file
View File

@ -0,0 +1,22 @@
// +build !go1.8
package storm
import "reflect"
func (s *sorter) Swap(i, j int) {
// skip if we encountered an earlier error
select {
case <-s.done:
return
default:
}
if ssink, ok := s.sink.(sliceSink); ok {
x, y := ssink.slice().Index(i).Interface(), ssink.slice().Index(j).Interface()
ssink.slice().Index(i).Set(reflect.ValueOf(y))
ssink.slice().Index(j).Set(reflect.ValueOf(x))
} else {
s.list[i], s.list[j] = s.list[j], s.list[i]
}
}

View File

@ -0,0 +1,20 @@
// +build go1.8
package storm
import "reflect"
func (s *sorter) Swap(i, j int) {
// skip if we encountered an earlier error
select {
case <-s.done:
return
default:
}
if ssink, ok := s.sink.(sliceSink); ok {
reflect.Swapper(ssink.slice().Interface())(i, j)
} else {
s.list[i], s.list[j] = s.list[j], s.list[i]
}
}

425
vendor/github.com/asdine/storm/v3/store.go generated vendored Normal file
View File

@ -0,0 +1,425 @@
package storm
import (
"bytes"
"reflect"
"github.com/asdine/storm/v3/index"
"github.com/asdine/storm/v3/q"
bolt "go.etcd.io/bbolt"
)
// TypeStore stores user defined types in BoltDB.
type TypeStore interface {
Finder
// Init creates the indexes and buckets for a given structure
Init(data interface{}) error
// ReIndex rebuilds all the indexes of a bucket
ReIndex(data interface{}) error
// Save a structure
Save(data interface{}) error
// Update a structure
Update(data interface{}) error
// UpdateField updates a single field
UpdateField(data interface{}, fieldName string, value interface{}) error
// Drop a bucket
Drop(data interface{}) error
// DeleteStruct deletes a structure from the associated bucket
DeleteStruct(data interface{}) error
}
// Init creates the indexes and buckets for a given structure
func (n *node) Init(data interface{}) error {
v := reflect.ValueOf(data)
cfg, err := extract(&v)
if err != nil {
return err
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.init(tx, cfg)
})
}
func (n *node) init(tx *bolt.Tx, cfg *structConfig) error {
bucket, err := n.CreateBucketIfNotExists(tx, cfg.Name)
if err != nil {
return err
}
// save node configuration in the bucket
_, err = newMeta(bucket, n)
if err != nil {
return err
}
for fieldName, fieldCfg := range cfg.Fields {
if fieldCfg.Index == "" {
continue
}
switch fieldCfg.Index {
case tagUniqueIdx:
_, err = index.NewUniqueIndex(bucket, []byte(indexPrefix+fieldName))
case tagIdx:
_, err = index.NewListIndex(bucket, []byte(indexPrefix+fieldName))
default:
err = ErrIdxNotFound
}
if err != nil {
return err
}
}
return nil
}
func (n *node) ReIndex(data interface{}) error {
ref := reflect.ValueOf(data)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return ErrStructPtrNeeded
}
cfg, err := extract(&ref)
if err != nil {
return err
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.reIndex(tx, data, cfg)
})
}
func (n *node) reIndex(tx *bolt.Tx, data interface{}, cfg *structConfig) error {
root := n.WithTransaction(tx)
nodes := root.From(cfg.Name).PrefixScan(indexPrefix)
bucket := root.GetBucket(tx, cfg.Name)
if bucket == nil {
return ErrNotFound
}
for _, node := range nodes {
buckets := node.Bucket()
name := buckets[len(buckets)-1]
err := bucket.DeleteBucket([]byte(name))
if err != nil {
return err
}
}
total, err := root.Count(data)
if err != nil {
return err
}
for i := 0; i < total; i++ {
err = root.Select(q.True()).Skip(i).First(data)
if err != nil {
return err
}
err = root.Update(data)
if err != nil {
return err
}
}
return nil
}
// Save a structure
func (n *node) Save(data interface{}) error {
ref := reflect.ValueOf(data)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return ErrStructPtrNeeded
}
cfg, err := extract(&ref)
if err != nil {
return err
}
if cfg.ID.IsZero {
if !cfg.ID.IsInteger || !cfg.ID.Increment {
return ErrZeroID
}
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.save(tx, cfg, data, false)
})
}
func (n *node) save(tx *bolt.Tx, cfg *structConfig, data interface{}, update bool) error {
bucket, err := n.CreateBucketIfNotExists(tx, cfg.Name)
if err != nil {
return err
}
// save node configuration in the bucket
meta, err := newMeta(bucket, n)
if err != nil {
return err
}
if cfg.ID.IsZero {
err = meta.increment(cfg.ID)
if err != nil {
return err
}
}
id, err := toBytes(cfg.ID.Value.Interface(), n.codec)
if err != nil {
return err
}
for fieldName, fieldCfg := range cfg.Fields {
if !update && !fieldCfg.IsID && fieldCfg.Increment && fieldCfg.IsInteger && fieldCfg.IsZero {
err = meta.increment(fieldCfg)
if err != nil {
return err
}
}
if fieldCfg.Index == "" {
continue
}
idx, err := getIndex(bucket, fieldCfg.Index, fieldName)
if err != nil {
return err
}
if update && fieldCfg.IsZero && !fieldCfg.ForceUpdate {
continue
}
if fieldCfg.IsZero {
err = idx.RemoveID(id)
if err != nil {
return err
}
continue
}
value, err := toBytes(fieldCfg.Value.Interface(), n.codec)
if err != nil {
return err
}
var found bool
idsSaved, err := idx.All(value, nil)
if err != nil {
return err
}
for _, idSaved := range idsSaved {
if bytes.Compare(idSaved, id) == 0 {
found = true
break
}
}
if found {
continue
}
err = idx.RemoveID(id)
if err != nil {
return err
}
err = idx.Add(value, id)
if err != nil {
if err == index.ErrAlreadyExists {
return ErrAlreadyExists
}
return err
}
}
raw, err := n.codec.Marshal(data)
if err != nil {
return err
}
return bucket.Put(id, raw)
}
// Update a structure
func (n *node) Update(data interface{}) error {
return n.update(data, func(ref *reflect.Value, current *reflect.Value, cfg *structConfig) error {
numfield := ref.NumField()
for i := 0; i < numfield; i++ {
f := ref.Field(i)
if ref.Type().Field(i).PkgPath != "" {
continue
}
zero := reflect.Zero(f.Type()).Interface()
actual := f.Interface()
if !reflect.DeepEqual(actual, zero) {
cf := current.Field(i)
cf.Set(f)
idxInfo, ok := cfg.Fields[ref.Type().Field(i).Name]
if ok {
idxInfo.Value = &cf
}
}
}
return nil
})
}
// UpdateField updates a single field
func (n *node) UpdateField(data interface{}, fieldName string, value interface{}) error {
return n.update(data, func(ref *reflect.Value, current *reflect.Value, cfg *structConfig) error {
f := current.FieldByName(fieldName)
if !f.IsValid() {
return ErrNotFound
}
tf, _ := current.Type().FieldByName(fieldName)
if tf.PkgPath != "" {
return ErrNotFound
}
v := reflect.ValueOf(value)
if v.Kind() != f.Kind() {
return ErrIncompatibleValue
}
f.Set(v)
idxInfo, ok := cfg.Fields[fieldName]
if ok {
idxInfo.Value = &f
idxInfo.IsZero = isZero(idxInfo.Value)
idxInfo.ForceUpdate = true
}
return nil
})
}
func (n *node) update(data interface{}, fn func(*reflect.Value, *reflect.Value, *structConfig) error) error {
ref := reflect.ValueOf(data)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return ErrStructPtrNeeded
}
cfg, err := extract(&ref)
if err != nil {
return err
}
if cfg.ID.IsZero {
return ErrNoID
}
current := reflect.New(reflect.Indirect(ref).Type())
return n.readWriteTx(func(tx *bolt.Tx) error {
err = n.WithTransaction(tx).One(cfg.ID.Name, cfg.ID.Value.Interface(), current.Interface())
if err != nil {
return err
}
ref := reflect.ValueOf(data).Elem()
cref := current.Elem()
err = fn(&ref, &cref, cfg)
if err != nil {
return err
}
return n.save(tx, cfg, current.Interface(), true)
})
}
// Drop a bucket
func (n *node) Drop(data interface{}) error {
var bucketName string
v := reflect.ValueOf(data)
if v.Kind() != reflect.String {
info, err := extract(&v)
if err != nil {
return err
}
bucketName = info.Name
} else {
bucketName = v.Interface().(string)
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.drop(tx, bucketName)
})
}
func (n *node) drop(tx *bolt.Tx, bucketName string) error {
bucket := n.GetBucket(tx)
if bucket == nil {
return tx.DeleteBucket([]byte(bucketName))
}
return bucket.DeleteBucket([]byte(bucketName))
}
// DeleteStruct deletes a structure from the associated bucket
func (n *node) DeleteStruct(data interface{}) error {
ref := reflect.ValueOf(data)
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
return ErrStructPtrNeeded
}
cfg, err := extract(&ref)
if err != nil {
return err
}
id, err := toBytes(cfg.ID.Value.Interface(), n.codec)
if err != nil {
return err
}
return n.readWriteTx(func(tx *bolt.Tx) error {
return n.deleteStruct(tx, cfg, id)
})
}
func (n *node) deleteStruct(tx *bolt.Tx, cfg *structConfig, id []byte) error {
bucket := n.GetBucket(tx, cfg.Name)
if bucket == nil {
return ErrNotFound
}
for fieldName, fieldCfg := range cfg.Fields {
if fieldCfg.Index == "" {
continue
}
idx, err := getIndex(bucket, fieldCfg.Index, fieldName)
if err != nil {
return err
}
err = idx.RemoveID(id)
if err != nil {
if err == index.ErrNotFound {
return ErrNotFound
}
return err
}
}
raw := bucket.Get(id)
if raw == nil {
return ErrNotFound
}
return bucket.Delete(id)
}

142
vendor/github.com/asdine/storm/v3/storm.go generated vendored Normal file
View File

@ -0,0 +1,142 @@
package storm
import (
"bytes"
"encoding/binary"
"time"
"github.com/asdine/storm/v3/codec"
"github.com/asdine/storm/v3/codec/json"
bolt "go.etcd.io/bbolt"
)
const (
dbinfo = "__storm_db"
metadataBucket = "__storm_metadata"
)
// Defaults to json
var defaultCodec = json.Codec
// Open opens a database at the given path with optional Storm options.
func Open(path string, stormOptions ...func(*Options) error) (*DB, error) {
var err error
var opts Options
for _, option := range stormOptions {
if err = option(&opts); err != nil {
return nil, err
}
}
s := DB{
Bolt: opts.bolt,
}
n := node{
s: &s,
codec: opts.codec,
batchMode: opts.batchMode,
rootBucket: opts.rootBucket,
}
if n.codec == nil {
n.codec = defaultCodec
}
if opts.boltMode == 0 {
opts.boltMode = 0600
}
if opts.boltOptions == nil {
opts.boltOptions = &bolt.Options{Timeout: 1 * time.Second}
}
s.Node = &n
// skip if UseDB option is used
if s.Bolt == nil {
s.Bolt, err = bolt.Open(path, opts.boltMode, opts.boltOptions)
if err != nil {
return nil, err
}
}
err = s.checkVersion()
if err != nil {
return nil, err
}
return &s, nil
}
// DB is the wrapper around BoltDB. It contains an instance of BoltDB and uses it to perform all the
// needed operations
type DB struct {
// The root node that points to the root bucket.
Node
// Bolt is still easily accessible
Bolt *bolt.DB
}
// Close the database
func (s *DB) Close() error {
return s.Bolt.Close()
}
func (s *DB) checkVersion() error {
var v string
err := s.Get(dbinfo, "version", &v)
if err != nil && err != ErrNotFound {
return err
}
// for now, we only set the current version if it doesn't exist.
// v1 and v2 database files are compatible.
if v == "" {
return s.Set(dbinfo, "version", Version)
}
return nil
}
// toBytes turns an interface into a slice of bytes
func toBytes(key interface{}, codec codec.MarshalUnmarshaler) ([]byte, error) {
if key == nil {
return nil, nil
}
switch t := key.(type) {
case []byte:
return t, nil
case string:
return []byte(t), nil
case int:
return numbertob(int64(t))
case uint:
return numbertob(uint64(t))
case int8, int16, int32, int64, uint8, uint16, uint32, uint64:
return numbertob(t)
default:
return codec.Marshal(key)
}
}
func numbertob(v interface{}) ([]byte, error) {
var buf bytes.Buffer
err := binary.Write(&buf, binary.BigEndian, v)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func numberfromb(raw []byte) (int64, error) {
r := bytes.NewReader(raw)
var to int64
err := binary.Read(r, binary.BigEndian, &to)
if err != nil {
return 0, err
}
return to, nil
}

52
vendor/github.com/asdine/storm/v3/transaction.go generated vendored Normal file
View File

@ -0,0 +1,52 @@
package storm
import bolt "go.etcd.io/bbolt"
// Tx is a transaction.
type Tx interface {
// Commit writes all changes to disk.
Commit() error
// Rollback closes the transaction and ignores all previous updates.
Rollback() error
}
// Begin starts a new transaction.
func (n node) Begin(writable bool) (Node, error) {
var err error
n.tx, err = n.s.Bolt.Begin(writable)
if err != nil {
return nil, err
}
return &n, nil
}
// Rollback closes the transaction and ignores all previous updates.
func (n *node) Rollback() error {
if n.tx == nil {
return ErrNotInTransaction
}
err := n.tx.Rollback()
if err == bolt.ErrTxClosed {
return ErrNotInTransaction
}
return err
}
// Commit writes all changes to disk.
func (n *node) Commit() error {
if n.tx == nil {
return ErrNotInTransaction
}
err := n.tx.Commit()
if err == bolt.ErrTxClosed {
return ErrNotInTransaction
}
return err
}

4
vendor/github.com/asdine/storm/v3/version.go generated vendored Normal file
View File

@ -0,0 +1,4 @@
package storm
// Version of Storm
const Version = "2.0.0"

15
vendor/github.com/davecgh/go-spew/LICENSE generated vendored Normal file
View File

@ -0,0 +1,15 @@
ISC License
Copyright (c) 2012-2016 Dave Collins <dave@davec.name>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

145
vendor/github.com/davecgh/go-spew/spew/bypass.go generated vendored Normal file
View File

@ -0,0 +1,145 @@
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when the code is not running on Google App Engine, compiled by GopherJS, and
// "-tags safe" is not added to the go build command line. The "disableunsafe"
// tag is deprecated and thus should not be used.
// Go versions prior to 1.4 are disabled because they use a different layout
// for interfaces which make the implementation of unsafeReflectValue more complex.
// +build !js,!appengine,!safe,!disableunsafe,go1.4
package spew
import (
"reflect"
"unsafe"
)
const (
// UnsafeDisabled is a build-time constant which specifies whether or
// not access to the unsafe package is available.
UnsafeDisabled = false
// ptrSize is the size of a pointer on the current arch.
ptrSize = unsafe.Sizeof((*byte)(nil))
)
type flag uintptr
var (
// flagRO indicates whether the value field of a reflect.Value
// is read-only.
flagRO flag
// flagAddr indicates whether the address of the reflect.Value's
// value may be taken.
flagAddr flag
)
// flagKindMask holds the bits that make up the kind
// part of the flags field. In all the supported versions,
// it is in the lower 5 bits.
const flagKindMask = flag(0x1f)
// Different versions of Go have used different
// bit layouts for the flags type. This table
// records the known combinations.
var okFlags = []struct {
ro, addr flag
}{{
// From Go 1.4 to 1.5
ro: 1 << 5,
addr: 1 << 7,
}, {
// Up to Go tip.
ro: 1<<5 | 1<<6,
addr: 1 << 8,
}}
var flagValOffset = func() uintptr {
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
return field.Offset
}()
// flagField returns a pointer to the flag field of a reflect.Value.
func flagField(v *reflect.Value) *flag {
return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset))
}
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
// the typical safety restrictions preventing access to unaddressable and
// unexported data. It works by digging the raw pointer to the underlying
// value out of the protected value and generating a new unprotected (unsafe)
// reflect.Value to it.
//
// This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) reflect.Value {
if !v.IsValid() || (v.CanInterface() && v.CanAddr()) {
return v
}
flagFieldPtr := flagField(&v)
*flagFieldPtr &^= flagRO
*flagFieldPtr |= flagAddr
return v
}
// Sanity checks against future reflect package changes
// to the type or semantics of the Value.flag field.
func init() {
field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag")
if !ok {
panic("reflect.Value has no flag field")
}
if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() {
panic("reflect.Value flag field has changed kind")
}
type t0 int
var t struct {
A t0
// t0 will have flagEmbedRO set.
t0
// a will have flagStickyRO set
a t0
}
vA := reflect.ValueOf(t).FieldByName("A")
va := reflect.ValueOf(t).FieldByName("a")
vt0 := reflect.ValueOf(t).FieldByName("t0")
// Infer flagRO from the difference between the flags
// for the (otherwise identical) fields in t.
flagPublic := *flagField(&vA)
flagWithRO := *flagField(&va) | *flagField(&vt0)
flagRO = flagPublic ^ flagWithRO
// Infer flagAddr from the difference between a value
// taken from a pointer and not.
vPtrA := reflect.ValueOf(&t).Elem().FieldByName("A")
flagNoPtr := *flagField(&vA)
flagPtr := *flagField(&vPtrA)
flagAddr = flagNoPtr ^ flagPtr
// Check that the inferred flags tally with one of the known versions.
for _, f := range okFlags {
if flagRO == f.ro && flagAddr == f.addr {
return
}
}
panic("reflect.Value read-only flag has changed semantics")
}

Some files were not shown because too many files have changed in this diff Show More