aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--adapter/mock/mock_vpp_adapter.go28
-rw-r--r--api/binapi.go22
-rw-r--r--cmd/vpp-proxy/main.go6
-rw-r--r--core/channel.go7
-rw-r--r--core/connection.go84
-rw-r--r--core/request_handler.go22
-rw-r--r--core/stream.go12
-rw-r--r--proxy/server.go42
8 files changed, 137 insertions, 86 deletions
diff --git a/adapter/mock/mock_vpp_adapter.go b/adapter/mock/mock_vpp_adapter.go
index f79bb8b..90195e7 100644
--- a/adapter/mock/mock_vpp_adapter.go
+++ b/adapter/mock/mock_vpp_adapter.go
@@ -44,7 +44,7 @@ type VppAdapter struct {
access sync.RWMutex
msgNameToIds map[string]uint16
msgIDsToName map[uint16]string
- binAPITypes map[string]reflect.Type
+ binAPITypes map[string]map[string]reflect.Type
repliesLock sync.Mutex // mutex for the queue
replies []reply // FIFO queue of messages
@@ -126,7 +126,7 @@ func NewVppAdapter() *VppAdapter {
msgIDSeq: 1000,
msgIDsToName: make(map[uint16]string),
msgNameToIds: make(map[string]uint16),
- binAPITypes: make(map[string]reflect.Type),
+ binAPITypes: make(map[string]map[string]reflect.Type),
}
a.registerBinAPITypes()
return a
@@ -165,19 +165,25 @@ func (a *VppAdapter) GetMsgNameByID(msgID uint16) (string, bool) {
func (a *VppAdapter) registerBinAPITypes() {
a.access.Lock()
defer a.access.Unlock()
- for _, msg := range api.GetRegisteredMessages() {
- a.binAPITypes[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+ for pkg, msgs := range api.GetRegisteredMessages() {
+ msgMap := make(map[string]reflect.Type)
+ for _, msg := range msgs {
+ msgMap[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+ }
+ a.binAPITypes[pkg] = msgMap
}
}
// ReplyTypeFor returns reply message type for given request message name.
-func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, bool) {
+func (a *VppAdapter) ReplyTypeFor(pkg, requestMsgName string) (reflect.Type, uint16, bool) {
replyName, foundName := binapi.ReplyNameFor(requestMsgName)
if foundName {
- if reply, found := a.binAPITypes[replyName]; found {
- msgID, err := a.GetMsgID(replyName, "")
- if err == nil {
- return reply, msgID, found
+ if messages, found := a.binAPITypes[pkg]; found {
+ if reply, found := messages[replyName]; found {
+ msgID, err := a.GetMsgID(replyName, "")
+ if err == nil {
+ return reply, msgID, found
+ }
}
}
}
@@ -186,8 +192,8 @@ func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16,
}
// ReplyFor returns reply message for given request message name.
-func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool) {
- replType, msgID, foundReplType := a.ReplyTypeFor(requestMsgName)
+func (a *VppAdapter) ReplyFor(pkg, requestMsgName string) (api.Message, uint16, bool) {
+ replType, msgID, foundReplType := a.ReplyTypeFor(pkg, requestMsgName)
if foundReplType {
msgVal := reflect.New(replType)
if msg, ok := msgVal.Interface().(api.Message); ok {
diff --git a/api/binapi.go b/api/binapi.go
index cb4ab85..1b07a7e 100644
--- a/api/binapi.go
+++ b/api/binapi.go
@@ -15,7 +15,7 @@
package api
import (
- "fmt"
+ "path"
"reflect"
)
@@ -59,27 +59,27 @@ type DataType interface {
}
var (
- registeredMessageTypes = make(map[reflect.Type]string)
- registeredMessages = make(map[string]Message)
+ registeredMessages = make(map[string]map[string]Message)
+ registeredMessageTypes = make(map[string]map[reflect.Type]string)
)
// RegisterMessage is called from generated code to register message.
func RegisterMessage(x Message, name string) {
- typ := reflect.TypeOf(x)
- namecrc := x.GetMessageName() + "_" + x.GetCrcString()
- if _, ok := registeredMessageTypes[typ]; ok {
- panic(fmt.Errorf("govpp: message type %v already registered as %s (%s)", typ, name, namecrc))
+ binapiPath := path.Dir(reflect.TypeOf(x).Elem().PkgPath())
+ if _, ok := registeredMessages[binapiPath]; !ok {
+ registeredMessages[binapiPath] = make(map[string]Message)
+ registeredMessageTypes[binapiPath] = make(map[reflect.Type]string)
}
- registeredMessages[namecrc] = x
- registeredMessageTypes[typ] = name
+ registeredMessages[binapiPath][x.GetMessageName()+"_"+x.GetCrcString()] = x
+ registeredMessageTypes[binapiPath][reflect.TypeOf(x)] = name
}
// GetRegisteredMessages returns list of all registered messages.
-func GetRegisteredMessages() map[string]Message {
+func GetRegisteredMessages() map[string]map[string]Message {
return registeredMessages
}
// GetRegisteredMessageTypes returns list of all registered message types.
-func GetRegisteredMessageTypes() map[reflect.Type]string {
+func GetRegisteredMessageTypes() map[string]map[reflect.Type]string {
return registeredMessageTypes
}
diff --git a/cmd/vpp-proxy/main.go b/cmd/vpp-proxy/main.go
index d1af5df..3c85bcf 100644
--- a/cmd/vpp-proxy/main.go
+++ b/cmd/vpp-proxy/main.go
@@ -35,8 +35,10 @@ var (
)
func init() {
- for _, msg := range api.GetRegisteredMessages() {
- gob.Register(msg)
+ for _, msgList := range api.GetRegisteredMessages() {
+ for _, msg := range msgList {
+ gob.Register(msg)
+ }
}
}
diff --git a/core/channel.go b/core/channel.go
index 28d0710..fbb3e59 100644
--- a/core/channel.go
+++ b/core/channel.go
@@ -45,8 +45,10 @@ type MessageCodec interface {
type MessageIdentifier interface {
// GetMessageID returns message identifier of given API message.
GetMessageID(msg api.Message) (uint16, error)
+ // GetMessagePath returns path for the given message
+ GetMessagePath(msg api.Message) string
// LookupByID looks up message name and crc by ID
- LookupByID(msgID uint16) (api.Message, error)
+ LookupByID(path string, msgID uint16) (api.Message, error)
}
// vppRequest is a request that will be sent to VPP.
@@ -329,7 +331,8 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa
if reply.msgID != expMsgID {
var msgNameCrc string
- if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
+ pkgPath := ch.msgIdentifier.GetMessagePath(msg)
+ if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
msgNameCrc = err.Error()
} else {
msgNameCrc = getMsgNameWithCrc(replyMsg)
diff --git a/core/connection.go b/core/connection.go
index 0f54f38..f3ff964 100644
--- a/core/connection.go
+++ b/core/connection.go
@@ -17,6 +17,7 @@ package core
import (
"errors"
"fmt"
+ "path"
"reflect"
"sync"
"sync/atomic"
@@ -103,9 +104,9 @@ type Connection struct {
connChan chan ConnectionEvent // connection status events are sent to this channel
- codec MessageCodec // message codec
- msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC
- msgMap map[uint16]api.Message // map of messages indexed by message ID
+ codec MessageCodec // message codec
+ msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC
+ msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path
maxChannelID uint32 // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations)
channelsLock sync.RWMutex // lock for the channels map
@@ -139,7 +140,7 @@ func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration)
connChan: make(chan ConnectionEvent, NotificationChanBufSize),
codec: codec.DefaultCodec,
msgIDs: make(map[string]uint16),
- msgMap: make(map[uint16]api.Message),
+ msgMapByPath: make(map[string]map[uint16]api.Message),
channels: make(map[uint16]*Channel),
subscriptions: make(map[uint16][]*subscriptionCtx),
msgControlPing: msgControlPing,
@@ -400,69 +401,74 @@ func (c *Connection) GetMessageID(msg api.Message) (uint16, error) {
if c == nil {
return 0, errors.New("nil connection passed in")
}
-
- if msgID, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok {
- return msgID, nil
- }
-
+ pkgPath := c.GetMessagePath(msg)
msgID, err := c.vppClient.GetMsgID(msg.GetMessageName(), msg.GetCrcString())
if err != nil {
return 0, err
}
-
+ if pathMsgs, pathOk := c.msgMapByPath[pkgPath]; !pathOk {
+ c.msgMapByPath[pkgPath] = make(map[uint16]api.Message)
+ c.msgMapByPath[pkgPath][msgID] = msg
+ } else if _, msgOk := pathMsgs[msgID]; !msgOk {
+ c.msgMapByPath[pkgPath][msgID] = msg
+ }
+ if _, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok {
+ return msgID, nil
+ }
c.msgIDs[getMsgNameWithCrc(msg)] = msgID
- c.msgMap[msgID] = msg
-
return msgID, nil
}
// LookupByID looks up message name and crc by ID.
-func (c *Connection) LookupByID(msgID uint16) (api.Message, error) {
+func (c *Connection) LookupByID(path string, msgID uint16) (api.Message, error) {
if c == nil {
return nil, errors.New("nil connection passed in")
}
-
- if msg, ok := c.msgMap[msgID]; ok {
+ if msg, ok := c.msgMapByPath[path][msgID]; ok {
return msg, nil
}
+ return nil, fmt.Errorf("unknown message ID %d for path '%s'", msgID, path)
+}
- return nil, fmt.Errorf("unknown message ID: %d", msgID)
+// GetMessagePath returns path for the given message
+func (c *Connection) GetMessagePath(msg api.Message) string {
+ return path.Dir(reflect.TypeOf(msg).Elem().PkgPath())
}
// retrieveMessageIDs retrieves IDs for all registered messages and stores them in map
func (c *Connection) retrieveMessageIDs() (err error) {
t := time.Now()
- msgs := api.GetRegisteredMessages()
+ msgsByPath := api.GetRegisteredMessages()
var n int
- for name, msg := range msgs {
- typ := reflect.TypeOf(msg).Elem()
- path := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name())
+ for pkgPath, msgs := range msgsByPath {
+ for _, msg := range msgs {
+ msgID, err := c.GetMessageID(msg)
+ if err != nil {
+ if debugMsgIDs {
+ log.Debugf("retrieving message ID for %s.%s failed: %v",
+ pkgPath, msg.GetMessageName(), err)
+ }
+ continue
+ }
+ n++
+
+ if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() {
+ c.pingReqID = msgID
+ c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+ } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() {
+ c.pingReplyID = msgID
+ c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+ }
- msgID, err := c.GetMessageID(msg)
- if err != nil {
if debugMsgIDs {
- log.Debugf("retrieving message ID for %s failed: %v", path, err)
+ log.Debugf("message %q (%s) has ID: %d", msg.GetMessageName(), getMsgNameWithCrc(msg), msgID)
}
- continue
- }
- n++
-
- if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() {
- c.pingReqID = msgID
- c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
- } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() {
- c.pingReplyID = msgID
- c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
- }
-
- if debugMsgIDs {
- log.Debugf("message %q (%s) has ID: %d", name, getMsgNameWithCrc(msg), msgID)
}
+ log.WithField("took", time.Since(t)).
+ Debugf("retrieved IDs for %d messages (registered %d) from path %s", n, len(msgs), pkgPath)
}
- log.WithField("took", time.Since(t)).
- Debugf("retrieved IDs for %d messages (registered %d)", n, len(msgs))
return nil
}
diff --git a/core/request_handler.go b/core/request_handler.go
index fc704cb..f9d972a 100644
--- a/core/request_handler.go
+++ b/core/request_handler.go
@@ -210,9 +210,9 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) {
return
}
- msg, ok := c.msgMap[msgID]
- if !ok {
- log.Warnf("Unknown message received, ID: %d", msgID)
+ msg, err := c.getMessageByID(msgID)
+ if err != nil {
+ log.Warnln(err)
return
}
@@ -419,3 +419,19 @@ func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
}
return 1
}
+
+// Returns first message from any package where the message ID matches
+// Note: the msg is further used only for its MessageType which is not
+// affected by the message's package
+func (c *Connection) getMessageByID(msgID uint16) (msg api.Message, err error) {
+ var ok bool
+ for _, msgs := range c.msgMapByPath {
+ if msg, ok = msgs[msgID]; ok {
+ break
+ }
+ }
+ if !ok {
+ return nil, fmt.Errorf("unknown message received, ID: %d", msgID)
+ }
+ return msg, nil
+}
diff --git a/core/stream.go b/core/stream.go
index abe9d55..3d417f1 100644
--- a/core/stream.go
+++ b/core/stream.go
@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"reflect"
+ "sync"
"sync/atomic"
"time"
@@ -34,6 +35,9 @@ type Stream struct {
requestSize int
replySize int
replyTimeout time.Duration
+ // per-request context
+ pkgPath string
+ sync.Mutex
}
func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
@@ -109,6 +113,9 @@ func (s *Stream) SendMsg(msg api.Message) error {
if err := s.conn.processRequest(s.channel, req); err != nil {
return err
}
+ s.Lock()
+ s.pkgPath = s.conn.GetMessagePath(msg)
+ s.Unlock()
return nil
}
@@ -118,7 +125,10 @@ func (s *Stream) RecvMsg() (api.Message, error) {
return nil, err
}
// resolve message type
- msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
+ s.Lock()
+ path := s.pkgPath
+ s.Unlock()
+ msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
if err != nil {
return nil, err
}
diff --git a/proxy/server.go b/proxy/server.go
index 21d8e1b..e395468 100644
--- a/proxy/server.go
+++ b/proxy/server.go
@@ -226,8 +226,8 @@ type BinapiCompatibilityRequest struct {
}
type BinapiCompatibilityResponse struct {
- CompatibleMsgs []string
- IncompatibleMsgs []string
+ CompatibleMsgs map[string][]string
+ IncompatibleMsgs map[string][]string
}
// BinapiRPC is a RPC server for proxying client request to api.Channel.
@@ -379,25 +379,33 @@ func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCo
}
defer ch.Close()
- resp.CompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
- resp.IncompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
+ resp.CompatibleMsgs = make(map[string][]string)
+ resp.IncompatibleMsgs = make(map[string][]string)
- for _, msg := range req.MsgNameCrcs {
- val, ok := api.GetRegisteredMessages()[msg]
- if !ok {
- resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
- continue
+ for path, messages := range api.GetRegisteredMessages() {
+ if resp.IncompatibleMsgs[path] == nil {
+ resp.IncompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
}
-
- if err = ch.CheckCompatiblity(val); err != nil {
- resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
- } else {
- resp.CompatibleMsgs = append(resp.CompatibleMsgs, msg)
+ if resp.CompatibleMsgs[path] == nil {
+ resp.CompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
+ }
+ for _, msg := range req.MsgNameCrcs {
+ val, ok := messages[msg]
+ if !ok {
+ resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+ continue
+ }
+ if err = ch.CheckCompatiblity(val); err != nil {
+ resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+ } else {
+ resp.CompatibleMsgs[path] = append(resp.CompatibleMsgs[path], msg)
+ }
}
}
-
- if len(resp.IncompatibleMsgs) > 0 {
- return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+ for _, messages := range resp.IncompatibleMsgs {
+ if len(messages) > 0 {
+ return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+ }
}
return nil