summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVladimir Lavor <vlavor@cisco.com>2020-12-01 13:57:29 +0100
committerVladimir Lavor <vlavor@cisco.com>2020-12-03 10:14:12 +0100
commitbcf3fbd21aa22d1546bc85ffb887ae5ba557808e (patch)
tree60668ecb3d0721bf33cfa1b37736a494f675ec4b
parent2b743eede78b6fed115421716888f80088edefdb (diff)
Fixed incorrect message error in the stream API
The message package is passed to the stream object and used to evaluate correct reply message type Change-Id: I2c9844d6447d024af1693205efd5721e2f89f22d Signed-off-by: Vladimir Lavor <vlavor@cisco.com>
-rw-r--r--adapter/mock/mock_vpp_adapter.go28
-rw-r--r--api/binapi.go22
-rw-r--r--cmd/vpp-proxy/main.go6
-rw-r--r--core/channel.go7
-rw-r--r--core/connection.go84
-rw-r--r--core/request_handler.go22
-rw-r--r--core/stream.go12
-rw-r--r--proxy/server.go42
8 files changed, 137 insertions, 86 deletions
diff --git a/adapter/mock/mock_vpp_adapter.go b/adapter/mock/mock_vpp_adapter.go
index f79bb8b..90195e7 100644
--- a/adapter/mock/mock_vpp_adapter.go
+++ b/adapter/mock/mock_vpp_adapter.go
@@ -44,7 +44,7 @@ type VppAdapter struct {
access sync.RWMutex
msgNameToIds map[string]uint16
msgIDsToName map[uint16]string
- binAPITypes map[string]reflect.Type
+ binAPITypes map[string]map[string]reflect.Type
repliesLock sync.Mutex // mutex for the queue
replies []reply // FIFO queue of messages
@@ -126,7 +126,7 @@ func NewVppAdapter() *VppAdapter {
msgIDSeq: 1000,
msgIDsToName: make(map[uint16]string),
msgNameToIds: make(map[string]uint16),
- binAPITypes: make(map[string]reflect.Type),
+ binAPITypes: make(map[string]map[string]reflect.Type),
}
a.registerBinAPITypes()
return a
@@ -165,19 +165,25 @@ func (a *VppAdapter) GetMsgNameByID(msgID uint16) (string, bool) {
func (a *VppAdapter) registerBinAPITypes() {
a.access.Lock()
defer a.access.Unlock()
- for _, msg := range api.GetRegisteredMessages() {
- a.binAPITypes[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+ for pkg, msgs := range api.GetRegisteredMessages() {
+ msgMap := make(map[string]reflect.Type)
+ for _, msg := range msgs {
+ msgMap[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+ }
+ a.binAPITypes[pkg] = msgMap
}
}
// ReplyTypeFor returns reply message type for given request message name.
-func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, bool) {
+func (a *VppAdapter) ReplyTypeFor(pkg, requestMsgName string) (reflect.Type, uint16, bool) {
replyName, foundName := binapi.ReplyNameFor(requestMsgName)
if foundName {
- if reply, found := a.binAPITypes[replyName]; found {
- msgID, err := a.GetMsgID(replyName, "")
- if err == nil {
- return reply, msgID, found
+ if messages, found := a.binAPITypes[pkg]; found {
+ if reply, found := messages[replyName]; found {
+ msgID, err := a.GetMsgID(replyName, "")
+ if err == nil {
+ return reply, msgID, found
+ }
}
}
}
@@ -186,8 +192,8 @@ func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16,
}
// ReplyFor returns reply message for given request message name.
-func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool) {
- replType, msgID, foundReplType := a.ReplyTypeFor(requestMsgName)
+func (a *VppAdapter) ReplyFor(pkg, requestMsgName string) (api.Message, uint16, bool) {
+ replType, msgID, foundReplType := a.ReplyTypeFor(pkg, requestMsgName)
if foundReplType {
msgVal := reflect.New(replType)
if msg, ok := msgVal.Interface().(api.Message); ok {
diff --git a/api/binapi.go b/api/binapi.go
index cb4ab85..1b07a7e 100644
--- a/api/binapi.go
+++ b/api/binapi.go
@@ -15,7 +15,7 @@
package api
import (
- "fmt"
+ "path"
"reflect"
)
@@ -59,27 +59,27 @@ type DataType interface {
}
var (
- registeredMessageTypes = make(map[reflect.Type]string)
- registeredMessages = make(map[string]Message)
+ registeredMessages = make(map[string]map[string]Message)
+ registeredMessageTypes = make(map[string]map[reflect.Type]string)
)
// RegisterMessage is called from generated code to register message.
func RegisterMessage(x Message, name string) {
- typ := reflect.TypeOf(x)
- namecrc := x.GetMessageName() + "_" + x.GetCrcString()
- if _, ok := registeredMessageTypes[typ]; ok {
- panic(fmt.Errorf("govpp: message type %v already registered as %s (%s)", typ, name, namecrc))
+ binapiPath := path.Dir(reflect.TypeOf(x).Elem().PkgPath())
+ if _, ok := registeredMessages[binapiPath]; !ok {
+ registeredMessages[binapiPath] = make(map[string]Message)
+ registeredMessageTypes[binapiPath] = make(map[reflect.Type]string)
}
- registeredMessages[namecrc] = x
- registeredMessageTypes[typ] = name
+ registeredMessages[binapiPath][x.GetMessageName()+"_"+x.GetCrcString()] = x
+ registeredMessageTypes[binapiPath][reflect.TypeOf(x)] = name
}
// GetRegisteredMessages returns list of all registered messages.
-func GetRegisteredMessages() map[string]Message {
+func GetRegisteredMessages() map[string]map[string]Message {
return registeredMessages
}
// GetRegisteredMessageTypes returns list of all registered message types.
-func GetRegisteredMessageTypes() map[reflect.Type]string {
+func GetRegisteredMessageTypes() map[string]map[reflect.Type]string {
return registeredMessageTypes
}
diff --git a/cmd/vpp-proxy/main.go b/cmd/vpp-proxy/main.go
index d1af5df..3c85bcf 100644
--- a/cmd/vpp-proxy/main.go
+++ b/cmd/vpp-proxy/main.go
@@ -35,8 +35,10 @@ var (
)
func init() {
- for _, msg := range api.GetRegisteredMessages() {
- gob.Register(msg)
+ for _, msgList := range api.GetRegisteredMessages() {
+ for _, msg := range msgList {
+ gob.Register(msg)
+ }
}
}
diff --git a/core/channel.go b/core/channel.go
index 28d0710..fbb3e59 100644
--- a/core/channel.go
+++ b/core/channel.go
@@ -45,8 +45,10 @@ type MessageCodec interface {
type MessageIdentifier interface {
// GetMessageID returns message identifier of given API message.
GetMessageID(msg api.Message) (uint16, error)
+ // GetMessagePath returns path for the given message
+ GetMessagePath(msg api.Message) string
// LookupByID looks up message name and crc by ID
- LookupByID(msgID uint16) (api.Message, error)
+ LookupByID(path string, msgID uint16) (api.Message, error)
}
// vppRequest is a request that will be sent to VPP.
@@ -329,7 +331,8 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa
if reply.msgID != expMsgID {
var msgNameCrc string
- if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
+ pkgPath := ch.msgIdentifier.GetMessagePath(msg)
+ if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
msgNameCrc = err.Error()
} else {
msgNameCrc = getMsgNameWithCrc(replyMsg)
diff --git a/core/connection.go b/core/connection.go
index 0f54f38..f3ff964 100644
--- a/core/connection.go
+++ b/core/connection.go
@@ -17,6 +17,7 @@ package core
import (
"errors"
"fmt"
+ "path"
"reflect"
"sync"
"sync/atomic"
@@ -103,9 +104,9 @@ type Connection struct {
connChan chan ConnectionEvent // connection status events are sent to this channel
- codec MessageCodec // message codec
- msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC
- msgMap map[uint16]api.Message // map of messages indexed by message ID
+ codec MessageCodec // message codec
+ msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC
+ msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path
maxChannelID uint32 // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations)
channelsLock sync.RWMutex // lock for the channels map
@@ -139,7 +140,7 @@ func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration)
connChan: make(chan ConnectionEvent, NotificationChanBufSize),
codec: codec.DefaultCodec,
msgIDs: make(map[string]uint16),
- msgMap: make(map[uint16]api.Message),
+ msgMapByPath: make(map[string]map[uint16]api.Message),
channels: make(map[uint16]*Channel),
subscriptions: make(map[uint16][]*subscriptionCtx),
msgControlPing: msgControlPing,
@@ -400,69 +401,74 @@ func (c *Connection) GetMessageID(msg api.Message) (uint16, error) {
if c == nil {
return 0, errors.New("nil connection passed in")
}
-
- if msgID, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok {
- return msgID, nil
- }
-
+ pkgPath := c.GetMessagePath(msg)
msgID, err := c.vppClient.GetMsgID(msg.GetMessageName(), msg.GetCrcString())
if err != nil {
return 0, err
}
-
+ if pathMsgs, pathOk := c.msgMapByPath[pkgPath]; !pathOk {
+ c.msgMapByPath[pkgPath] = make(map[uint16]api.Message)
+ c.msgMapByPath[pkgPath][msgID] = msg
+ } else if _, msgOk := pathMsgs[msgID]; !msgOk {
+ c.msgMapByPath[pkgPath][msgID] = msg
+ }
+ if _, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok {
+ return msgID, nil
+ }
c.msgIDs[getMsgNameWithCrc(msg)] = msgID
- c.msgMap[msgID] = msg
-
return msgID, nil
}
// LookupByID looks up message name and crc by ID.
-func (c *Connection) LookupByID(msgID uint16) (api.Message, error) {
+func (c *Connection) LookupByID(path string, msgID uint16) (api.Message, error) {
if c == nil {
return nil, errors.New("nil connection passed in")
}
-
- if msg, ok := c.msgMap[msgID]; ok {
+ if msg, ok := c.msgMapByPath[path][msgID]; ok {
return msg, nil
}
+ return nil, fmt.Errorf("unknown message ID %d for path '%s'", msgID, path)
+}
- return nil, fmt.Errorf("unknown message ID: %d", msgID)
+// GetMessagePath returns path for the given message
+func (c *Connection) GetMessagePath(msg api.Message) string {
+ return path.Dir(reflect.TypeOf(msg).Elem().PkgPath())
}
// retrieveMessageIDs retrieves IDs for all registered messages and stores them in map
func (c *Connection) retrieveMessageIDs() (err error) {
t := time.Now()
- msgs := api.GetRegisteredMessages()
+ msgsByPath := api.GetRegisteredMessages()
var n int
- for name, msg := range msgs {
- typ := reflect.TypeOf(msg).Elem()
- path := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name())
+ for pkgPath, msgs := range msgsByPath {
+ for _, msg := range msgs {
+ msgID, err := c.GetMessageID(msg)
+ if err != nil {
+ if debugMsgIDs {
+ log.Debugf("retrieving message ID for %s.%s failed: %v",
+ pkgPath, msg.GetMessageName(), err)
+ }
+ continue
+ }
+ n++
+
+ if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() {
+ c.pingReqID = msgID
+ c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+ } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() {
+ c.pingReplyID = msgID
+ c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+ }
- msgID, err := c.GetMessageID(msg)
- if err != nil {
if debugMsgIDs {
- log.Debugf("retrieving message ID for %s failed: %v", path, err)
+ log.Debugf("message %q (%s) has ID: %d", msg.GetMessageName(), getMsgNameWithCrc(msg), msgID)
}
- continue
- }
- n++
-
- if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() {
- c.pingReqID = msgID
- c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
- } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() {
- c.pingReplyID = msgID
- c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
- }
-
- if debugMsgIDs {
- log.Debugf("message %q (%s) has ID: %d", name, getMsgNameWithCrc(msg), msgID)
}
+ log.WithField("took", time.Since(t)).
+ Debugf("retrieved IDs for %d messages (registered %d) from path %s", n, len(msgs), pkgPath)
}
- log.WithField("took", time.Since(t)).
- Debugf("retrieved IDs for %d messages (registered %d)", n, len(msgs))
return nil
}
diff --git a/core/request_handler.go b/core/request_handler.go
index fc704cb..f9d972a 100644
--- a/core/request_handler.go
+++ b/core/request_handler.go
@@ -210,9 +210,9 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) {
return
}
- msg, ok := c.msgMap[msgID]
- if !ok {
- log.Warnf("Unknown message received, ID: %d", msgID)
+ msg, err := c.getMessageByID(msgID)
+ if err != nil {
+ log.Warnln(err)
return
}
@@ -419,3 +419,19 @@ func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
}
return 1
}
+
+// Returns first message from any package where the message ID matches
+// Note: the msg is further used only for its MessageType which is not
+// affected by the message's package
+func (c *Connection) getMessageByID(msgID uint16) (msg api.Message, err error) {
+ var ok bool
+ for _, msgs := range c.msgMapByPath {
+ if msg, ok = msgs[msgID]; ok {
+ break
+ }
+ }
+ if !ok {
+ return nil, fmt.Errorf("unknown message received, ID: %d", msgID)
+ }
+ return msg, nil
+}
diff --git a/core/stream.go b/core/stream.go
index abe9d55..3d417f1 100644
--- a/core/stream.go
+++ b/core/stream.go
@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"reflect"
+ "sync"
"sync/atomic"
"time"
@@ -34,6 +35,9 @@ type Stream struct {
requestSize int
replySize int
replyTimeout time.Duration
+ // per-request context
+ pkgPath string
+ sync.Mutex
}
func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
@@ -109,6 +113,9 @@ func (s *Stream) SendMsg(msg api.Message) error {
if err := s.conn.processRequest(s.channel, req); err != nil {
return err
}
+ s.Lock()
+ s.pkgPath = s.conn.GetMessagePath(msg)
+ s.Unlock()
return nil
}
@@ -118,7 +125,10 @@ func (s *Stream) RecvMsg() (api.Message, error) {
return nil, err
}
// resolve message type
- msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
+ s.Lock()
+ path := s.pkgPath
+ s.Unlock()
+ msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
if err != nil {
return nil, err
}
diff --git a/proxy/server.go b/proxy/server.go
index 21d8e1b..e395468 100644
--- a/proxy/server.go
+++ b/proxy/server.go
@@ -226,8 +226,8 @@ type BinapiCompatibilityRequest struct {
}
type BinapiCompatibilityResponse struct {
- CompatibleMsgs []string
- IncompatibleMsgs []string
+ CompatibleMsgs map[string][]string
+ IncompatibleMsgs map[string][]string
}
// BinapiRPC is a RPC server for proxying client request to api.Channel.
@@ -379,25 +379,33 @@ func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCo
}
defer ch.Close()
- resp.CompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
- resp.IncompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
+ resp.CompatibleMsgs = make(map[string][]string)
+ resp.IncompatibleMsgs = make(map[string][]string)
- for _, msg := range req.MsgNameCrcs {
- val, ok := api.GetRegisteredMessages()[msg]
- if !ok {
- resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
- continue
+ for path, messages := range api.GetRegisteredMessages() {
+ if resp.IncompatibleMsgs[path] == nil {
+ resp.IncompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
}
-
- if err = ch.CheckCompatiblity(val); err != nil {
- resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
- } else {
- resp.CompatibleMsgs = append(resp.CompatibleMsgs, msg)
+ if resp.CompatibleMsgs[path] == nil {
+ resp.CompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
+ }
+ for _, msg := range req.MsgNameCrcs {
+ val, ok := messages[msg]
+ if !ok {
+ resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+ continue
+ }
+ if err = ch.CheckCompatiblity(val); err != nil {
+ resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+ } else {
+ resp.CompatibleMsgs[path] = append(resp.CompatibleMsgs[path], msg)
+ }
}
}
-
- if len(resp.IncompatibleMsgs) > 0 {
- return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+ for _, messages := range resp.IncompatibleMsgs {
+ if len(messages) > 0 {
+ return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+ }
}
return nil