summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--codec/msg_codec.go27
-rw-r--r--codec/msg_codec_test.go63
-rw-r--r--core/request_handler.go22
3 files changed, 89 insertions, 23 deletions
diff --git a/codec/msg_codec.go b/codec/msg_codec.go
index 9d3f614..67628a4 100644
--- a/codec/msg_codec.go
+++ b/codec/msg_codec.go
@@ -53,24 +53,32 @@ type VppOtherHeader struct {
}
// EncodeMsg encodes provided `Message` structure into its binary-encoded data representation.
-func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) {
+func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) (data []byte, err error) {
if msg == nil {
return nil, errors.New("nil message passed in")
}
+ // try to recover panic which might possibly occur in struc.Pack call
+ defer func() {
+ if r := recover(); r != nil {
+ var ok bool
+ if err, ok = r.(error); !ok {
+ err = fmt.Errorf("%v", r)
+ }
+ err = fmt.Errorf("panic occurred: %v", err)
+ }
+ }()
+
var header interface{}
// encode message header
switch msg.GetMessageType() {
case api.RequestMessage:
header = &VppRequestHeader{VlMsgID: msgID}
-
case api.ReplyMessage:
header = &VppReplyHeader{VlMsgID: msgID}
-
case api.EventMessage:
header = &VppEventHeader{VlMsgID: msgID}
-
default:
header = &VppOtherHeader{VlMsgID: msgID}
}
@@ -79,13 +87,13 @@ func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) {
// encode message header
if err := struc.Pack(buf, header); err != nil {
- return nil, fmt.Errorf("unable to encode message header: %v, error %v", header, err)
+ return nil, fmt.Errorf("failed to encode message header: %+v, error: %v", header, err)
}
// encode message content
if reflect.TypeOf(msg).Elem().NumField() > 0 {
if err := struc.Pack(buf, msg); err != nil {
- return nil, fmt.Errorf("unable to encode message data: %v, error %v", header, err)
+ return nil, fmt.Errorf("failed to encode message data: %+v, error: %v", data, err)
}
}
@@ -104,13 +112,10 @@ func (*MsgCodec) DecodeMsg(data []byte, msg api.Message) error {
switch msg.GetMessageType() {
case api.RequestMessage:
header = new(VppRequestHeader)
-
case api.ReplyMessage:
header = new(VppReplyHeader)
-
case api.EventMessage:
header = new(VppEventHeader)
-
default:
header = new(VppOtherHeader)
}
@@ -119,12 +124,12 @@ func (*MsgCodec) DecodeMsg(data []byte, msg api.Message) error {
// decode message header
if err := struc.Unpack(buf, header); err != nil {
- return fmt.Errorf("unable to decode message header: %+v, error %v", data, err)
+ return fmt.Errorf("failed to decode message header: %+v, error: %v", header, err)
}
// decode message content
if err := struc.Unpack(buf, msg); err != nil {
- return fmt.Errorf("unable to decode message data: %+v, error %v", data, err)
+ return fmt.Errorf("failed to decode message data: %+v, error: %v", data, err)
}
return nil
diff --git a/codec/msg_codec_test.go b/codec/msg_codec_test.go
new file mode 100644
index 0000000..cd1240e
--- /dev/null
+++ b/codec/msg_codec_test.go
@@ -0,0 +1,63 @@
+package codec
+
+import (
+ "bytes"
+ "testing"
+
+ "git.fd.io/govpp.git/api"
+)
+
+type MyMsg struct {
+ Index uint16
+ Label []byte `struc:"[16]byte"`
+ Port uint16
+}
+
+func (*MyMsg) GetMessageName() string {
+ return "my_msg"
+}
+func (*MyMsg) GetCrcString() string {
+ return "xxxxx"
+}
+func (*MyMsg) GetMessageType() api.MessageType {
+ return api.OtherMessage
+}
+
+func TestEncode(t *testing.T) {
+ tests := []struct {
+ name string
+ msg api.Message
+ msgID uint16
+ expData []byte
+ }{
+ {name: "basic",
+ msg: &MyMsg{Index: 1, Label: []byte("Abcdef"), Port: 1000},
+ msgID: 100,
+ expData: []byte{0x00, 0x64, 0x00, 0x01, 0x41, 0x62, 0x63, 0x64, 0x65, 0x66, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xE8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := &MsgCodec{}
+
+ data, err := c.EncodeMsg(test.msg, test.msgID)
+ if err != nil {
+ t.Fatalf("expected nil error, got: %v", err)
+ }
+ if !bytes.Equal(data, test.expData) {
+ t.Fatalf("expected data: % 0X, got: % 0X", test.expData, data)
+ }
+ })
+ }
+}
+
+func TestEncodePanic(t *testing.T) {
+ c := &MsgCodec{}
+
+ msg := &MyMsg{Index: 1, Label: []byte("thisIsLongerThan16Bytes"), Port: 1000}
+
+ _, err := c.EncodeMsg(msg, 100)
+ if err == nil {
+ t.Fatalf("expected non-nil error, got: %v", err)
+ }
+}
diff --git a/core/request_handler.go b/core/request_handler.go
index c042948..e52e262 100644
--- a/core/request_handler.go
+++ b/core/request_handler.go
@@ -39,7 +39,9 @@ func (c *Connection) watchRequests(ch *Channel) {
c.releaseAPIChannel(ch)
return
}
- c.processRequest(ch, req)
+ if err := c.processRequest(ch, req); err != nil {
+ sendReplyError(ch, req, err)
+ }
}
}
}
@@ -50,39 +52,36 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
if atomic.LoadUint32(&c.connected) == 0 {
err := ErrNotConnected
log.Errorf("processing request failed: %v", err)
- sendReplyError(ch, req, err)
return err
}
// retrieve message ID
msgID, err := c.GetMessageID(req.msg)
if err != nil {
- err = fmt.Errorf("unable to retrieve message ID: %v", err)
log.WithFields(logger.Fields{
"msg_name": req.msg.GetMessageName(),
"msg_crc": req.msg.GetCrcString(),
"seq_num": req.seqNum,
- }).Error(err)
- sendReplyError(ch, req, err)
- return err
+ "error": err,
+ }).Errorf("failed to retrieve message ID")
+ return fmt.Errorf("unable to retrieve message ID: %v", err)
}
// encode the message into binary
data, err := c.codec.EncodeMsg(req.msg, msgID)
if err != nil {
- err = fmt.Errorf("unable to encode the messge: %v", err)
log.WithFields(logger.Fields{
"channel": ch.id,
"msg_id": msgID,
"msg_name": req.msg.GetMessageName(),
"seq_num": req.seqNum,
- }).Error(err)
- sendReplyError(ch, req, err)
- return err
+ "error": err,
+ }).Errorf("failed to encode message: %#v", req.msg)
+ return fmt.Errorf("unable to encode the message: %v", err)
}
- // get context
context := packRequestContext(ch.id, req.multi, req.seqNum)
+
if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
log.WithFields(logger.Fields{
"channel": ch.id,
@@ -104,7 +103,6 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
"msg_id": msgID,
"seq_num": req.seqNum,
}).Error(err)
- sendReplyError(ch, req, err)
return err
}