diff options
-rw-r--r-- | adapter/mock/mock_adapter.go | 9 | ||||
-rw-r--r-- | api/api.go | 6 | ||||
-rw-r--r-- | api/api_test.go | 94 | ||||
-rw-r--r-- | core/msg_codec.go | 2 | ||||
-rw-r--r-- | core/request_handler.go | 31 |
5 files changed, 114 insertions, 28 deletions
diff --git a/adapter/mock/mock_adapter.go b/adapter/mock/mock_adapter.go index 0b2e8d5..dab51a6 100644 --- a/adapter/mock/mock_adapter.go +++ b/adapter/mock/mock_adapter.go @@ -30,10 +30,10 @@ import ( "github.com/lunixbochs/struc" ) -type ReplyMode int +type replyMode int const ( - _ ReplyMode = 0 + _ replyMode = 0 useRepliesQueue = 1 // use replies in the queue useReplyHandlers = 2 // use reply handler ) @@ -51,7 +51,7 @@ type VppAdapter struct { replies []api.Message // FIFO queue of messages replyHandlers []ReplyHandler // callbacks that are able to calculate mock responses repliesLock sync.Mutex // mutex for the queue - mode ReplyMode // mode in which the mock operates + mode replyMode // mode in which the mock operates } // defaultReply is a default reply message that mock adapter returns for a request. @@ -250,8 +250,7 @@ func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error { for i, reply := range a.replies { if i > 0 && reply.GetMessageName() == "control_ping_reply" { // hack - do not send control_ping_reply immediately, leave it for the the next callback - a.replies = []api.Message{} - a.replies = append(a.replies, reply) + a.replies = a.replies[i:] return nil } msgID, _ := a.GetMsgID(reply.GetMessageName(), reply.GetCrcString()) @@ -233,15 +233,15 @@ func (ch *Channel) receiveReplyInternal(msg Message) (LastReplyReceived bool, er return false, err } if vppReply.MessageID != expMsgID { - err = fmt.Errorf("invalid message ID %d, expected %d "+ - "(also check if multiple goroutines are not sharing one GoVPP channel)", vppReply.MessageID, expMsgID) + err = fmt.Errorf("received invalid message ID, expected %d (%s), but got %d (check if multiple goroutines are not sharing single GoVPP channel)", + expMsgID, msg.GetMessageName(), vppReply.MessageID) return false, err } // decode the message err = ch.MsgDecoder.DecodeMsg(vppReply.Data, msg) case <-time.After(ch.replyTimeout): - err = fmt.Errorf("no reply received within the timeout period %ds", ch.replyTimeout/time.Second) + err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout) } return } diff --git a/api/api_test.go b/api/api_test.go index 3e77f48..9af6e71 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -49,6 +49,8 @@ func setupTest(t *testing.T) *testCtx { ctx.ch, err = ctx.conn.NewAPIChannel() Expect(err).ShouldNot(HaveOccurred()) + ctx.ch.SetReplyTimeout(time.Millisecond) + return ctx } @@ -197,10 +199,9 @@ func TestMultiRequestReplySwInterfaceTapDump(t *testing.T) { // mock reply for i := 1; i <= 10; i++ { - byteName := []byte("dev-name-test") ctx.mockVpp.MockReply(&tap.SwInterfaceTapDetails{ SwIfIndex: uint32(i), - DevName: byteName, + DevName: []byte("dev-name-test"), }) } ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) @@ -327,8 +328,6 @@ func TestSetReplyTimeout(t *testing.T) { ctx := setupTest(t) defer ctx.teardownTest() - ctx.ch.SetReplyTimeout(time.Millisecond) - // first one request should work ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) err := ctx.ch.SendRequest(&vpe.ControlPing{}).ReceiveReply(&vpe.ControlPingReply{}) @@ -340,6 +339,47 @@ func TestSetReplyTimeout(t *testing.T) { Expect(err.Error()).To(ContainSubstring("timeout")) } +func TestSetReplyTimeoutMultiRequest(t *testing.T) { + ctx := setupTest(t) + defer ctx.teardownTest() + + for i := 1; i <= 3; i++ { + ctx.mockVpp.MockReply(&interfaces.SwInterfaceDetails{ + SwIfIndex: uint32(i), + InterfaceName: []byte("if-name-test"), + }) + } + ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) + + cnt := 0 + sendMultiRequest := func() error { + reqCtx := ctx.ch.SendMultiRequest(&interfaces.SwInterfaceDump{}) + for { + msg := &interfaces.SwInterfaceDetails{} + stop, err := reqCtx.ReceiveReply(msg) + if stop { + break // break out of the loop + } + if err != nil { + return err + } + cnt++ + } + return nil + } + + // first one request should work + err := sendMultiRequest() + Expect(err).ShouldNot(HaveOccurred()) + + // no other reply ready - expect timeout + err = sendMultiRequest() + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("timeout")) + + Expect(cnt).To(BeEquivalentTo(3)) +} + func TestReceiveReplyNegative(t *testing.T) { ctx := setupTest(t) defer ctx.teardownTest() @@ -362,3 +402,49 @@ func TestReceiveReplyNegative(t *testing.T) { Expect(err).Should(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("invalid request context")) } + +func TestMultiRequestDouble(t *testing.T) { + ctx := setupTest(t) + defer ctx.teardownTest() + + // mock reply + for i := 1; i <= 3; i++ { + ctx.mockVpp.MockReply(&interfaces.SwInterfaceDetails{ + SwIfIndex: uint32(i), + InterfaceName: []byte("if-name-test"), + }) + } + ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) + for i := 1; i <= 3; i++ { + ctx.mockVpp.MockReply(&interfaces.SwInterfaceDetails{ + SwIfIndex: uint32(i), + InterfaceName: []byte("if-name-test"), + }) + } + ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) + + cnt := 0 + sendMultiRequest := func() error { + reqCtx := ctx.ch.SendMultiRequest(&interfaces.SwInterfaceDump{}) + for { + msg := &interfaces.SwInterfaceDetails{} + stop, err := reqCtx.ReceiveReply(msg) + if stop { + break // break out of the loop + } + if err != nil { + return err + } + cnt++ + } + return nil + } + + err := sendMultiRequest() + Expect(err).ShouldNot(HaveOccurred()) + + err = sendMultiRequest() + Expect(err).ShouldNot(HaveOccurred()) + + Expect(cnt).To(BeEquivalentTo(6)) +} diff --git a/core/msg_codec.go b/core/msg_codec.go index 77fb9a9..e32916b 100644 --- a/core/msg_codec.go +++ b/core/msg_codec.go @@ -20,8 +20,8 @@ import ( "fmt" "reflect" - logger "github.com/sirupsen/logrus" "github.com/lunixbochs/struc" + logger "github.com/sirupsen/logrus" "git.fd.io/govpp.git/api" ) diff --git a/core/request_handler.go b/core/request_handler.go index dc02ee7..4a62754 100644 --- a/core/request_handler.go +++ b/core/request_handler.go @@ -48,34 +48,34 @@ func (c *Connection) watchRequests(ch *api.Channel, chMeta *channelMetadata) { func (c *Connection) processRequest(ch *api.Channel, chMeta *channelMetadata, req *api.VppRequest) error { // check whether we are connected to VPP if atomic.LoadUint32(&c.connected) == 0 { - error := errors.New("not connected to VPP, ignoring the request") - log.Error(error) - sendReply(ch, &api.VppReply{Error: error}) - return error + err := errors.New("not connected to VPP, ignoring the request") + log.Error(err) + sendReply(ch, &api.VppReply{Error: err}) + return err } // retrieve message ID msgID, err := c.GetMessageID(req.Message) if err != nil { - error := fmt.Errorf("unable to retrieve message ID: %v", err) + err = fmt.Errorf("unable to retrieve message ID: %v", err) log.WithFields(logger.Fields{ "msg_name": req.Message.GetMessageName(), "msg_crc": req.Message.GetCrcString(), - }).Error(error) - sendReply(ch, &api.VppReply{Error: error}) - return error + }).Error(err) + sendReply(ch, &api.VppReply{Error: err}) + return err } // encode the message into binary data, err := c.codec.EncodeMsg(req.Message, msgID) if err != nil { - error := fmt.Errorf("unable to encode the messge: %v", err) + err = fmt.Errorf("unable to encode the messge: %v", err) log.WithFields(logger.Fields{ "context": chMeta.id, "msg_id": msgID, - }).Error(error) - sendReply(ch, &api.VppReply{Error: error}) - return error + }).Error(err) + sendReply(ch, &api.VppReply{Error: err}) + return err } if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled @@ -83,6 +83,7 @@ func (c *Connection) processRequest(ch *api.Channel, chMeta *channelMetadata, re "context": chMeta.id, "msg_id": msgID, "msg_size": len(data), + "msg_name": req.Message.GetMessageName(), }).Debug("Sending a message to VPP.") } @@ -199,12 +200,12 @@ func (c *Connection) messageNameToID(msgName string, msgCrc string) (uint16, err // get the ID using VPP API id, err := c.vpp.GetMsgID(msgName, msgCrc) if err != nil { - error := fmt.Errorf("unable to retrieve message ID: %v", err) + err = fmt.Errorf("unable to retrieve message ID: %v", err) log.WithFields(logger.Fields{ "msg_name": msgName, "msg_crc": msgCrc, - }).Error(error) - return id, error + }).Error(err) + return id, err } c.msgIDsLock.Lock() |