diff options
-rw-r--r-- | codec/msg_codec.go | 27 | ||||
-rw-r--r-- | codec/msg_codec_test.go | 63 | ||||
-rw-r--r-- | core/request_handler.go | 22 |
3 files changed, 89 insertions, 23 deletions
diff --git a/codec/msg_codec.go b/codec/msg_codec.go index 9d3f614..67628a4 100644 --- a/codec/msg_codec.go +++ b/codec/msg_codec.go @@ -53,24 +53,32 @@ type VppOtherHeader struct { } // EncodeMsg encodes provided `Message` structure into its binary-encoded data representation. -func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) { +func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) (data []byte, err error) { if msg == nil { return nil, errors.New("nil message passed in") } + // try to recover panic which might possibly occur in struc.Pack call + defer func() { + if r := recover(); r != nil { + var ok bool + if err, ok = r.(error); !ok { + err = fmt.Errorf("%v", r) + } + err = fmt.Errorf("panic occurred: %v", err) + } + }() + var header interface{} // encode message header switch msg.GetMessageType() { case api.RequestMessage: header = &VppRequestHeader{VlMsgID: msgID} - case api.ReplyMessage: header = &VppReplyHeader{VlMsgID: msgID} - case api.EventMessage: header = &VppEventHeader{VlMsgID: msgID} - default: header = &VppOtherHeader{VlMsgID: msgID} } @@ -79,13 +87,13 @@ func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) { // encode message header if err := struc.Pack(buf, header); err != nil { - return nil, fmt.Errorf("unable to encode message header: %v, error %v", header, err) + return nil, fmt.Errorf("failed to encode message header: %+v, error: %v", header, err) } // encode message content if reflect.TypeOf(msg).Elem().NumField() > 0 { if err := struc.Pack(buf, msg); err != nil { - return nil, fmt.Errorf("unable to encode message data: %v, error %v", header, err) + return nil, fmt.Errorf("failed to encode message data: %+v, error: %v", data, err) } } @@ -104,13 +112,10 @@ func (*MsgCodec) DecodeMsg(data []byte, msg api.Message) error { switch msg.GetMessageType() { case api.RequestMessage: header = new(VppRequestHeader) - case api.ReplyMessage: header = new(VppReplyHeader) - case api.EventMessage: header = new(VppEventHeader) - default: header = new(VppOtherHeader) } @@ -119,12 +124,12 @@ func (*MsgCodec) DecodeMsg(data []byte, msg api.Message) error { // decode message header if err := struc.Unpack(buf, header); err != nil { - return fmt.Errorf("unable to decode message header: %+v, error %v", data, err) + return fmt.Errorf("failed to decode message header: %+v, error: %v", header, err) } // decode message content if err := struc.Unpack(buf, msg); err != nil { - return fmt.Errorf("unable to decode message data: %+v, error %v", data, err) + return fmt.Errorf("failed to decode message data: %+v, error: %v", data, err) } return nil diff --git a/codec/msg_codec_test.go b/codec/msg_codec_test.go new file mode 100644 index 0000000..cd1240e --- /dev/null +++ b/codec/msg_codec_test.go @@ -0,0 +1,63 @@ +package codec + +import ( + "bytes" + "testing" + + "git.fd.io/govpp.git/api" +) + +type MyMsg struct { + Index uint16 + Label []byte `struc:"[16]byte"` + Port uint16 +} + +func (*MyMsg) GetMessageName() string { + return "my_msg" +} +func (*MyMsg) GetCrcString() string { + return "xxxxx" +} +func (*MyMsg) GetMessageType() api.MessageType { + return api.OtherMessage +} + +func TestEncode(t *testing.T) { + tests := []struct { + name string + msg api.Message + msgID uint16 + expData []byte + }{ + {name: "basic", + msg: &MyMsg{Index: 1, Label: []byte("Abcdef"), Port: 1000}, + msgID: 100, + expData: []byte{0x00, 0x64, 0x00, 0x01, 0x41, 0x62, 0x63, 0x64, 0x65, 0x66, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xE8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := &MsgCodec{} + + data, err := c.EncodeMsg(test.msg, test.msgID) + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if !bytes.Equal(data, test.expData) { + t.Fatalf("expected data: % 0X, got: % 0X", test.expData, data) + } + }) + } +} + +func TestEncodePanic(t *testing.T) { + c := &MsgCodec{} + + msg := &MyMsg{Index: 1, Label: []byte("thisIsLongerThan16Bytes"), Port: 1000} + + _, err := c.EncodeMsg(msg, 100) + if err == nil { + t.Fatalf("expected non-nil error, got: %v", err) + } +} diff --git a/core/request_handler.go b/core/request_handler.go index c042948..e52e262 100644 --- a/core/request_handler.go +++ b/core/request_handler.go @@ -39,7 +39,9 @@ func (c *Connection) watchRequests(ch *Channel) { c.releaseAPIChannel(ch) return } - c.processRequest(ch, req) + if err := c.processRequest(ch, req); err != nil { + sendReplyError(ch, req, err) + } } } } @@ -50,39 +52,36 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { if atomic.LoadUint32(&c.connected) == 0 { err := ErrNotConnected log.Errorf("processing request failed: %v", err) - sendReplyError(ch, req, err) return err } // retrieve message ID msgID, err := c.GetMessageID(req.msg) if err != nil { - err = fmt.Errorf("unable to retrieve message ID: %v", err) log.WithFields(logger.Fields{ "msg_name": req.msg.GetMessageName(), "msg_crc": req.msg.GetCrcString(), "seq_num": req.seqNum, - }).Error(err) - sendReplyError(ch, req, err) - return err + "error": err, + }).Errorf("failed to retrieve message ID") + return fmt.Errorf("unable to retrieve message ID: %v", err) } // encode the message into binary data, err := c.codec.EncodeMsg(req.msg, msgID) if err != nil { - err = fmt.Errorf("unable to encode the messge: %v", err) log.WithFields(logger.Fields{ "channel": ch.id, "msg_id": msgID, "msg_name": req.msg.GetMessageName(), "seq_num": req.seqNum, - }).Error(err) - sendReplyError(ch, req, err) - return err + "error": err, + }).Errorf("failed to encode message: %#v", req.msg) + return fmt.Errorf("unable to encode the message: %v", err) } - // get context context := packRequestContext(ch.id, req.multi, req.seqNum) + if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled log.WithFields(logger.Fields{ "channel": ch.id, @@ -104,7 +103,6 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { "msg_id": msgID, "seq_num": req.seqNum, }).Error(err) - sendReplyError(ch, req, err) return err } |