aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
authorLukas Vogel <vogel@anapaya.net>2022-01-31 11:33:37 +0100
committerLukas Vogel <vogel@anapaya.net>2022-01-31 11:53:39 +0100
commitb9aa34d5f77ac5969ad273bfd187ce5f9594b25e (patch)
treefd87cf2331ce8008128e1f29c565e9381f2fb440 /core
parent000215c229d6df2c1a68b50847d8c7abf3842ce5 (diff)
connection: prevent channel ID overlap
When creating a new channel and the channel ID wraps around, make sure to not re-use a channel ID that is still in use. Re-using the channel ID usually means that the connection health check will stop working and other things might break as well. Also rename maxChannelID to nextChannelID and use a lock to guard access instead of using an atomic. The lock does anyway need to be acquired because to put the entry in the map. This commit was inspired by the following PR on Github: https://github.com/FDio/govpp/pull/14. Change-Id: I8c1a4ca63a53d07a6482b6047a3005065168c0b4 Signed-off-by: Lukas Vogel <vogel@anapaya.net>
Diffstat (limited to 'core')
-rw-r--r--core/channel.go21
-rw-r--r--core/connection.go11
-rw-r--r--core/stream.go6
3 files changed, 27 insertions, 11 deletions
diff --git a/core/channel.go b/core/channel.go
index 1086c36..112c14e 100644
--- a/core/channel.go
+++ b/core/channel.go
@@ -19,7 +19,6 @@ import (
"fmt"
"reflect"
"strings"
- "sync/atomic"
"time"
"github.com/sirupsen/logrus"
@@ -110,11 +109,9 @@ type Channel struct {
receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
}
-func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) *Channel {
+func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) (*Channel, error) {
// create new channel
- chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff)
channel := &Channel{
- id: chID,
conn: c,
msgCodec: c.codec,
msgIdentifier: c,
@@ -126,10 +123,22 @@ func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) *Channel {
// store API channel within the client
c.channelsLock.Lock()
- c.channels[chID] = channel
+ if len(c.channels) >= 0x7fff {
+ return nil, errors.New("all channel IDs are used")
+ }
+ for {
+ c.nextChannelID++
+ chID := c.nextChannelID & 0x7fff
+ _, ok := c.channels[chID]
+ if !ok {
+ channel.id = chID
+ c.channels[chID] = channel
+ break
+ }
+ }
c.channelsLock.Unlock()
- return channel
+ return channel, nil
}
func (ch *Channel) GetID() uint16 {
diff --git a/core/connection.go b/core/connection.go
index 442eb51..1bfcae5 100644
--- a/core/connection.go
+++ b/core/connection.go
@@ -109,9 +109,9 @@ type Connection struct {
msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC
msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path
- maxChannelID uint32 // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations)
- channelsLock sync.RWMutex // lock for the channels map
- channels map[uint16]*Channel // map of all API channels indexed by the channel ID
+ channelsLock sync.RWMutex // lock for the channels map and the channel ID
+ nextChannelID uint16 // next potential channel ID (the real limit is 2^15)
+ channels map[uint16]*Channel // map of all API channels indexed by the channel ID
subscriptionsLock sync.RWMutex // lock for the subscriptions map
subscriptions map[uint16][]*subscriptionCtx // map od all notification subscriptions indexed by message ID
@@ -248,7 +248,10 @@ func (c *Connection) newAPIChannel(reqChanBufSize, replyChanBufSize int) (*Chann
return nil, errors.New("nil connection passed in")
}
- channel := c.newChannel(reqChanBufSize, replyChanBufSize)
+ channel, err := c.newChannel(reqChanBufSize, replyChanBufSize)
+ if err != nil {
+ return nil, err
+ }
// start watching on the request channel
go c.watchRequests(channel)
diff --git a/core/stream.go b/core/stream.go
index 2f639b0..67236f1 100644
--- a/core/stream.go
+++ b/core/stream.go
@@ -56,7 +56,11 @@ func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption)
option(s)
}
- s.channel = c.newChannel(s.requestSize, s.replySize)
+ ch, err := c.newChannel(s.requestSize, s.replySize)
+ if err != nil {
+ return nil, err
+ }
+ s.channel = ch
s.channel.SetReplyTimeout(s.replyTimeout)
// Channel.watchRequests are not started here intentionally, because