Skip to content

Commit 5f22f6b

Browse files
committed
Add websocket connection mutex to avoid concurrent writes.
Fixes #191.
1 parent 2097398 commit 5f22f6b

File tree

2 files changed

+41
-26
lines changed

2 files changed

+41
-26
lines changed

internal/backend/basicstation/backend.go

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func NewBackend(conf config.Config) (*Backend, error) {
8484
scheme: "ws",
8585

8686
gateways: gateways{
87-
gateways: make(map[lorawan.EUI64]gateway),
87+
gateways: make(map[lorawan.EUI64]*connection),
8888
},
8989

9090
caCert: conf.Backend.BasicStation.CACert,
@@ -312,11 +312,11 @@ func (b *Backend) Stop() error {
312312
return b.ln.Close()
313313
}
314314

315-
func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) {
315+
func (b *Backend) handleRouterInfo(r *http.Request, conn *connection) {
316316
websocketReceiveCounter("router_info").Inc()
317317
var req structs.RouterInfoRequest
318318

319-
if err := c.ReadJSON(&req); err != nil {
319+
if err := conn.conn.ReadJSON(&req); err != nil {
320320
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
321321
log.WithError(err).Error("backend/basicstation: read message error")
322322
}
@@ -345,8 +345,11 @@ func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) {
345345
return
346346
}
347347

348-
c.SetWriteDeadline(time.Now().Add(b.writeTimeout))
349-
if err := c.WriteMessage(websocket.TextMessage, bb); err != nil {
348+
conn.Lock()
349+
defer conn.Unlock()
350+
351+
conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
352+
if err := conn.conn.WriteMessage(websocket.TextMessage, bb); err != nil {
350353
log.WithError(err).Error("backend/basicstation: websocket send message error")
351354
return
352355
}
@@ -358,7 +361,7 @@ func (b *Backend) handleRouterInfo(r *http.Request, c *websocket.Conn) {
358361
}).Info("backend/basicstation: router-info request received")
359362
}
360363

361-
func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
364+
func (b *Backend) handleGateway(r *http.Request, conn *connection) {
362365
// get the gateway id from the url
363366
urlParts := strings.Split(r.URL.Path, "/")
364367
if len(urlParts) < 2 {
@@ -391,7 +394,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
391394
}
392395

393396
// set the gateway connection
394-
if err := b.gateways.set(gatewayID, gateway{conn: c}); err != nil {
397+
if err := b.gateways.set(gatewayID, conn); err != nil {
395398
log.WithError(err).WithField("gateway_id", gatewayID).Error("backend/basicstation: set gateway error")
396399
}
397400
log.WithFields(log.Fields{
@@ -466,7 +469,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
466469

467470
// receive data
468471
for {
469-
mt, msg, err := c.ReadMessage()
472+
mt, msg, err := conn.conn.ReadMessage()
470473
if err != nil {
471474
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
472475
log.WithField("gateway_id", gatewayID).WithError(err).Error("backend/basicstation: read message error")
@@ -475,7 +478,7 @@ func (b *Backend) handleGateway(r *http.Request, c *websocket.Conn) {
475478
}
476479

477480
// reset the read deadline as the Basic Station doesn't respond to PONG messages (yet)
478-
c.SetReadDeadline(time.Now().Add(b.readTimeout))
481+
conn.conn.SetReadDeadline(time.Now().Add(b.readTimeout))
479482

480483
if mt == websocket.BinaryMessage {
481484
log.WithFields(log.Fields{
@@ -768,11 +771,14 @@ func (b *Backend) handleTimeSync(gatewayID lorawan.EUI64, v structs.TimeSyncRequ
768771
}
769772

770773
func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error {
771-
gw, err := b.gateways.get(gatewayID)
774+
conn, err := b.gateways.get(gatewayID)
772775
if err != nil {
773776
return errors.Wrap(err, "get gateway error")
774777
}
775778

779+
conn.Lock()
780+
defer conn.Unlock()
781+
776782
bb, err := json.Marshal(v)
777783
if err != nil {
778784
return errors.Wrap(err, "marshal json error")
@@ -783,29 +789,32 @@ func (b *Backend) sendToGateway(gatewayID lorawan.EUI64, v interface{}) error {
783789
"message": string(bb),
784790
}).Debug("sending message to gateway")
785791

786-
gw.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
787-
if err := gw.conn.WriteMessage(websocket.TextMessage, bb); err != nil {
792+
conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
793+
if err := conn.conn.WriteMessage(websocket.TextMessage, bb); err != nil {
788794
return errors.Wrap(err, "send message to gateway error")
789795
}
790796

791797
return nil
792798
}
793799

794800
func (b *Backend) sendRawToGateway(gatewayID lorawan.EUI64, messageType int, data []byte) error {
795-
gw, err := b.gateways.get(gatewayID)
801+
conn, err := b.gateways.get(gatewayID)
796802
if err != nil {
797803
return errors.Wrap(err, "get gateway error")
798804
}
799805

800-
gw.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
801-
if err := gw.conn.WriteMessage(messageType, data); err != nil {
806+
conn.Lock()
807+
defer conn.Unlock()
808+
809+
conn.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
810+
if err := conn.conn.WriteMessage(messageType, data); err != nil {
802811
return errors.Wrap(err, "send message to gateway error")
803812
}
804813

805814
return nil
806815
}
807816

808-
func (b *Backend) websocketWrap(handler func(*http.Request, *websocket.Conn), w http.ResponseWriter, r *http.Request) {
817+
func (b *Backend) websocketWrap(handler func(*http.Request, *connection), w http.ResponseWriter, r *http.Request) {
809818
conn, err := upgrader.Upgrade(w, r, nil)
810819
if err != nil {
811820
log.WithError(err).Error("backend/basicstation: websocket upgrade error")
@@ -824,23 +833,29 @@ func (b *Backend) websocketWrap(handler func(*http.Request, *websocket.Conn), w
824833
defer ticker.Stop()
825834
done := make(chan struct{})
826835

836+
// Wrap the conn inside a gateway struct, so that we can lock it when writing
837+
// data.
838+
c := connection{conn: conn}
839+
827840
go func() {
828841
for {
829842
select {
830843
case <-ticker.C:
844+
c.Lock()
831845
websocketPingPongCounter("ping").Inc()
832-
conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
846+
c.conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
833847
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
834848
log.WithError(err).Error("backend/basicstation: send ping message error")
835-
conn.Close()
849+
c.conn.Close()
836850
}
851+
c.Unlock()
837852
case <-done:
838853
return
839854
}
840855
}
841856
}()
842857

843-
handler(r, conn)
858+
handler(r, &c)
844859
done <- struct{}{}
845860
}
846861

internal/backend/basicstation/gateway.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@ var (
1414
errGatewayDoesNotExist = errors.New("gateway does not exist")
1515
)
1616

17-
type gateway struct {
18-
conn *websocket.Conn
19-
configVersion string
17+
type connection struct {
18+
sync.Mutex
19+
conn *websocket.Conn
2020
}
2121

2222
type gateways struct {
2323
sync.RWMutex
24-
gateways map[lorawan.EUI64]gateway
24+
gateways map[lorawan.EUI64]*connection
2525

2626
subscribeEventFunc func(events.Subscribe)
2727
}
2828

29-
func (g *gateways) get(id lorawan.EUI64) (gateway, error) {
29+
func (g *gateways) get(id lorawan.EUI64) (*connection, error) {
3030
g.RLock()
3131
defer g.RUnlock()
3232

@@ -37,11 +37,11 @@ func (g *gateways) get(id lorawan.EUI64) (gateway, error) {
3737
return gw, nil
3838
}
3939

40-
func (g *gateways) set(id lorawan.EUI64, gw gateway) error {
40+
func (g *gateways) set(id lorawan.EUI64, c *connection) error {
4141
g.Lock()
4242
defer g.Unlock()
4343

44-
g.gateways[id] = gw
44+
g.gateways[id] = c
4545

4646
if g.subscribeEventFunc != nil {
4747
g.subscribeEventFunc(events.Subscribe{Subscribe: true, GatewayID: id})

0 commit comments

Comments
 (0)