diff options
Diffstat (limited to 'core')
-rw-r--r-- | core/channel.go | 60 | ||||
-rw-r--r-- | core/channel_test.go | 12 | ||||
-rw-r--r-- | core/connection.go | 258 | ||||
-rw-r--r-- | core/connection_test.go | 111 | ||||
-rw-r--r-- | core/control_ping.go | 2 | ||||
-rw-r--r-- | core/request_handler.go | 88 | ||||
-rw-r--r-- | core/stats.go | 198 | ||||
-rw-r--r-- | core/stream.go | 118 | ||||
-rw-r--r-- | core/trace.go | 70 | ||||
-rw-r--r-- | core/trace_test.go | 265 |
10 files changed, 879 insertions, 303 deletions
diff --git a/core/channel.go b/core/channel.go index 1b5e77e..eef59d0 100644 --- a/core/channel.go +++ b/core/channel.go @@ -23,8 +23,8 @@ import ( "github.com/sirupsen/logrus" - "git.fd.io/govpp.git/adapter" - "git.fd.io/govpp.git/api" + "go.fd.io/govpp/adapter" + "go.fd.io/govpp/api" ) var ( @@ -37,16 +37,18 @@ type MessageCodec interface { EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) // DecodeMsg decodes binary-encoded data of a message into provided Message structure. DecodeMsg(data []byte, msg api.Message) error - // DecodeMsgContext decodes context from message data. - DecodeMsgContext(data []byte, msg api.Message) (context uint32, err error) + // DecodeMsgContext decodes context from message data and type. + DecodeMsgContext(data []byte, msgType api.MessageType) (context uint32, err error) } // MessageIdentifier provides identification of generated API messages. type MessageIdentifier interface { // GetMessageID returns message identifier of given API message. GetMessageID(msg api.Message) (uint16, error) + // GetMessagePath returns path for the given message + GetMessagePath(msg api.Message) string // LookupByID looks up message name and crc by ID - LookupByID(msgID uint16) (api.Message, error) + LookupByID(path string, msgID uint16) (api.Message, error) } // vppRequest is a request that will be sent to VPP. @@ -107,17 +109,36 @@ type Channel struct { receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply } -func newChannel(id uint16, conn *Connection, codec MessageCodec, identifier MessageIdentifier, reqSize, replySize int) *Channel { - return &Channel{ - id: id, - conn: conn, - msgCodec: codec, - msgIdentifier: identifier, - reqChan: make(chan *vppRequest, reqSize), - replyChan: make(chan *vppReply, replySize), +func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) (*Channel, error) { + // create new channel + channel := &Channel{ + conn: c, + msgCodec: c.codec, + msgIdentifier: c, + reqChan: make(chan *vppRequest, reqChanBufSize), + replyChan: make(chan *vppReply, replyChanBufSize), replyTimeout: DefaultReplyTimeout, receiveReplyTimeout: ReplyChannelTimeout, } + + // store API channel within the client + c.channelsLock.Lock() + if len(c.channels) >= 0x7fff { + return nil, errors.New("all channel IDs are used") + } + for { + c.nextChannelID++ + chID := c.nextChannelID & 0x7fff + _, ok := c.channels[chID] + if !ok { + channel.id = chID + c.channels[chID] = channel + break + } + } + c.channelsLock.Unlock() + + return channel, nil } func (ch *Channel) GetID() uint16 { @@ -329,7 +350,8 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa if reply.msgID != expMsgID { var msgNameCrc string - if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil { + pkgPath := ch.msgIdentifier.GetMessagePath(msg) + if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil { msgNameCrc = err.Error() } else { msgNameCrc = getMsgNameWithCrc(replyMsg) @@ -350,7 +372,15 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa if strings.HasSuffix(msg.GetMessageName(), "_reply") { // TODO: use categories for messages to avoid checking message name if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() { - retval := int32(f.Int()) + var retval int32 + switch f.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + retval = int32(f.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + retval = int32(f.Uint()) + default: + logrus.Warnf("invalid kind (%v) for Retval field of message %v", f.Kind(), msg.GetMessageName()) + } err = api.RetvalToVPPApiError(retval) } } diff --git a/core/channel_test.go b/core/channel_test.go index fa3e58d..6b12d68 100644 --- a/core/channel_test.go +++ b/core/channel_test.go @@ -20,12 +20,12 @@ import ( . "github.com/onsi/gomega" - "git.fd.io/govpp.git/adapter/mock" - "git.fd.io/govpp.git/api" - interfaces "git.fd.io/govpp.git/binapi/interface" - "git.fd.io/govpp.git/binapi/interface_types" - "git.fd.io/govpp.git/binapi/memif" - "git.fd.io/govpp.git/binapi/vpe" + "go.fd.io/govpp/adapter/mock" + "go.fd.io/govpp/api" + interfaces "go.fd.io/govpp/binapi/interface" + "go.fd.io/govpp/binapi/interface_types" + "go.fd.io/govpp/binapi/memif" + "go.fd.io/govpp/binapi/vpe" ) type testCtx struct { diff --git a/core/connection.go b/core/connection.go index 53a9acf..2c05333 100644 --- a/core/connection.go +++ b/core/connection.go @@ -17,6 +17,7 @@ package core import ( "errors" "fmt" + "path" "reflect" "sync" "sync/atomic" @@ -24,9 +25,9 @@ import ( logger "github.com/sirupsen/logrus" - "git.fd.io/govpp.git/adapter" - "git.fd.io/govpp.git/api" - "git.fd.io/govpp.git/codec" + "go.fd.io/govpp/adapter" + "go.fd.io/govpp/api" + "go.fd.io/govpp/codec" ) const ( @@ -42,8 +43,8 @@ var ( var ( HealthCheckProbeInterval = time.Second // default health check probe interval - HealthCheckReplyTimeout = time.Millisecond * 100 // timeout for reply to a health check probe - HealthCheckThreshold = 1 // number of failed health checks until the error is reported + HealthCheckReplyTimeout = time.Millisecond * 250 // timeout for reply to a health check probe + HealthCheckThreshold = 2 // number of failed health checks until the error is reported DefaultReplyTimeout = time.Second // default timeout for replies from VPP ) @@ -101,15 +102,16 @@ type Connection struct { vppConnected uint32 // non-zero if the adapter is connected to VPP - connChan chan ConnectionEvent // connection status events are sent to this channel + connChan chan ConnectionEvent // connection status events are sent to this channel + healthCheckDone chan struct{} // used to terminate health check loop - codec MessageCodec // message codec - msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC - msgMap map[uint16]api.Message // map of messages indexed by message ID + codec MessageCodec // message codec + msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC + msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path - maxChannelID uint32 // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations) - channelsLock sync.RWMutex // lock for the channels map - channels map[uint16]*Channel // map of all API channels indexed by the channel ID + channelsLock sync.RWMutex // lock for the channels map and the channel ID + nextChannelID uint16 // next potential channel ID (the real limit is 2^15) + channels map[uint16]*Channel // map of all API channels indexed by the channel ID subscriptionsLock sync.RWMutex // lock for the subscriptions map subscriptions map[uint16][]*subscriptionCtx // map od all notification subscriptions indexed by message ID @@ -122,6 +124,8 @@ type Connection struct { msgControlPing api.Message msgControlPingReply api.Message + + apiTrace *trace // API tracer (disabled by default) } func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration) *Connection { @@ -137,13 +141,18 @@ func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration) maxAttempts: attempts, recInterval: interval, connChan: make(chan ConnectionEvent, NotificationChanBufSize), + healthCheckDone: make(chan struct{}), codec: codec.DefaultCodec, msgIDs: make(map[string]uint16), - msgMap: make(map[uint16]api.Message), + msgMapByPath: make(map[string]map[uint16]api.Message), channels: make(map[uint16]*Channel), subscriptions: make(map[uint16][]*subscriptionCtx), msgControlPing: msgControlPing, msgControlPingReply: msgControlPingReply, + apiTrace: &trace{ + list: make([]*api.Record, 0), + mux: &sync.Mutex{}, + }, } binapi.SetMsgCallback(c.msgCallback) return c @@ -207,13 +216,18 @@ func (c *Connection) Disconnect() { return } if c.vppClient != nil { - c.disconnectVPP() + c.disconnectVPP(true) } } -// disconnectVPP disconnects from VPP in case it is connected. -func (c *Connection) disconnectVPP() { +// disconnectVPP disconnects from VPP in case it is connected. terminate tells +// that disconnectVPP() was called from Close(), so healthCheckLoop() can be +// terminated. +func (c *Connection) disconnectVPP(terminate bool) { if atomic.CompareAndSwapUint32(&c.vppConnected, 1, 0) { + if terminate { + close(c.healthCheckDone) + } log.Debug("Disconnecting from VPP..") if err := c.vppClient.Disconnect(); err != nil { @@ -238,14 +252,10 @@ func (c *Connection) newAPIChannel(reqChanBufSize, replyChanBufSize int) (*Chann return nil, errors.New("nil connection passed in") } - // create new channel - chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff) - channel := newChannel(chID, c, c.codec, c, reqChanBufSize, replyChanBufSize) - - // store API channel within the client - c.channelsLock.Lock() - c.channels[chID] = channel - c.channelsLock.Unlock() + channel, err := c.newChannel(reqChanBufSize, replyChanBufSize) + if err != nil { + return nil, err + } // start watching on the request channel go c.watchRequests(channel) @@ -302,6 +312,7 @@ func (c *Connection) healthCheckLoop() { log.Error("Failed to create health check API channel, health check will be disabled:", err) return } + defer ch.Close() var ( sinceLastReply time.Duration @@ -309,73 +320,74 @@ func (c *Connection) healthCheckLoop() { ) // send health check probes until an error or timeout occurs - for { - // sleep until next health check probe period - time.Sleep(HealthCheckProbeInterval) + probeInterval := time.NewTicker(HealthCheckProbeInterval) + defer probeInterval.Stop() - if atomic.LoadUint32(&c.vppConnected) == 0 { - // Disconnect has been called in the meantime, return the healthcheck - reconnect loop +HealthCheck: + for { + select { + case <-c.healthCheckDone: + // Terminate the health check loop on connection disconnect log.Debug("Disconnected on request, exiting health check loop.") return - } - - // try draining probe replies from previous request before sending next one - select { - case <-ch.replyChan: - log.Debug("drained old probe reply from reply channel") - default: - } + case <-probeInterval.C: + // try draining probe replies from previous request before sending next one + select { + case <-ch.replyChan: + log.Debug("drained old probe reply from reply channel") + default: + } - // send the control ping request - ch.reqChan <- &vppRequest{msg: c.msgControlPing} + // send the control ping request + ch.reqChan <- &vppRequest{msg: c.msgControlPing} - for { - // expect response within timeout period - select { - case vppReply := <-ch.replyChan: - err = vppReply.err + for { + // expect response within timeout period + select { + case vppReply := <-ch.replyChan: + err = vppReply.err - case <-time.After(HealthCheckReplyTimeout): - err = ErrProbeTimeout + case <-time.After(HealthCheckReplyTimeout): + err = ErrProbeTimeout - // check if time since last reply from any other - // channel is less than health check reply timeout - c.lastReplyLock.Lock() - sinceLastReply = time.Since(c.lastReply) - c.lastReplyLock.Unlock() + // check if time since last reply from any other + // channel is less than health check reply timeout + c.lastReplyLock.Lock() + sinceLastReply = time.Since(c.lastReply) + c.lastReplyLock.Unlock() - if sinceLastReply < HealthCheckReplyTimeout { - log.Warnf("VPP health check probe timing out, but some request on other channel was received %v ago, continue waiting!", sinceLastReply) - continue + if sinceLastReply < HealthCheckReplyTimeout { + log.Warnf("VPP health check probe timing out, but some request on other channel was received %v ago, continue waiting!", sinceLastReply) + continue + } } + break } - break - } - if err == ErrProbeTimeout { - failedChecks++ - log.Warnf("VPP health check probe timed out after %v (%d. timeout)", HealthCheckReplyTimeout, failedChecks) - if failedChecks > HealthCheckThreshold { - // in case of exceeded failed check threshold, assume VPP unresponsive - log.Errorf("VPP does not responding, the health check exceeded threshold for timeouts (>%d)", HealthCheckThreshold) - c.sendConnEvent(ConnectionEvent{Timestamp: time.Now(), State: NotResponding}) - break + if err == ErrProbeTimeout { + failedChecks++ + log.Warnf("VPP health check probe timed out after %v (%d. timeout)", HealthCheckReplyTimeout, failedChecks) + if failedChecks > HealthCheckThreshold { + // in case of exceeded failed check threshold, assume VPP unresponsive + log.Errorf("VPP does not responding, the health check exceeded threshold for timeouts (>%d)", HealthCheckThreshold) + c.sendConnEvent(ConnectionEvent{Timestamp: time.Now(), State: NotResponding}) + break HealthCheck + } + } else if err != nil { + // in case of error, assume VPP disconnected + log.Errorf("VPP health check probe failed: %v", err) + c.sendConnEvent(ConnectionEvent{Timestamp: time.Now(), State: Disconnected, Error: err}) + break HealthCheck + } else if failedChecks > 0 { + // in case of success after failed checks, clear failed check counter + failedChecks = 0 + log.Infof("VPP health check probe OK") } - } else if err != nil { - // in case of error, assume VPP disconnected - log.Errorf("VPP health check probe failed: %v", err) - c.sendConnEvent(ConnectionEvent{Timestamp: time.Now(), State: Disconnected, Error: err}) - break - } else if failedChecks > 0 { - // in case of success after failed checks, clear failed check counter - failedChecks = 0 - log.Infof("VPP health check probe OK") } } // cleanup - ch.Close() - c.disconnectVPP() + c.disconnectVPP(false) // we are now disconnected, start connect loop c.connectLoop() @@ -400,69 +412,74 @@ func (c *Connection) GetMessageID(msg api.Message) (uint16, error) { if c == nil { return 0, errors.New("nil connection passed in") } - - if msgID, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok { - return msgID, nil - } - + pkgPath := c.GetMessagePath(msg) msgID, err := c.vppClient.GetMsgID(msg.GetMessageName(), msg.GetCrcString()) if err != nil { return 0, err } - + if pathMsgs, pathOk := c.msgMapByPath[pkgPath]; !pathOk { + c.msgMapByPath[pkgPath] = make(map[uint16]api.Message) + c.msgMapByPath[pkgPath][msgID] = msg + } else if _, msgOk := pathMsgs[msgID]; !msgOk { + c.msgMapByPath[pkgPath][msgID] = msg + } + if _, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok { + return msgID, nil + } c.msgIDs[getMsgNameWithCrc(msg)] = msgID - c.msgMap[msgID] = msg - return msgID, nil } // LookupByID looks up message name and crc by ID. -func (c *Connection) LookupByID(msgID uint16) (api.Message, error) { +func (c *Connection) LookupByID(path string, msgID uint16) (api.Message, error) { if c == nil { return nil, errors.New("nil connection passed in") } - - if msg, ok := c.msgMap[msgID]; ok { + if msg, ok := c.msgMapByPath[path][msgID]; ok { return msg, nil } + return nil, fmt.Errorf("unknown message ID %d for path '%s'", msgID, path) +} - return nil, fmt.Errorf("unknown message ID: %d", msgID) +// GetMessagePath returns path for the given message +func (c *Connection) GetMessagePath(msg api.Message) string { + return path.Dir(reflect.TypeOf(msg).Elem().PkgPath()) } // retrieveMessageIDs retrieves IDs for all registered messages and stores them in map func (c *Connection) retrieveMessageIDs() (err error) { t := time.Now() - msgs := api.GetRegisteredMessages() + msgsByPath := api.GetRegisteredMessages() var n int - for name, msg := range msgs { - typ := reflect.TypeOf(msg).Elem() - path := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name()) + for pkgPath, msgs := range msgsByPath { + for _, msg := range msgs { + msgID, err := c.GetMessageID(msg) + if err != nil { + if debugMsgIDs { + log.Debugf("retrieving message ID for %s.%s failed: %v", + pkgPath, msg.GetMessageName(), err) + } + continue + } + n++ + + if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() { + c.pingReqID = msgID + c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() { + c.pingReplyID = msgID + c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + } - msgID, err := c.GetMessageID(msg) - if err != nil { if debugMsgIDs { - log.Debugf("retrieving message ID for %s failed: %v", path, err) + log.Debugf("message %q (%s) has ID: %d", msg.GetMessageName(), getMsgNameWithCrc(msg), msgID) } - continue - } - n++ - - if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() { - c.pingReqID = msgID - c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() { - c.pingReplyID = msgID - c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - } - - if debugMsgIDs { - log.Debugf("message %q (%s) has ID: %d", name, getMsgNameWithCrc(msg), msgID) } + log.WithField("took", time.Since(t)). + Debugf("retrieved IDs for %d messages (registered %d) from path %s", n, len(msgs), pkgPath) } - log.WithField("took", time.Since(t)). - Debugf("retrieved IDs for %d messages (registered %d)", n, len(msgs)) return nil } @@ -474,3 +491,24 @@ func (c *Connection) sendConnEvent(event ConnectionEvent) { log.Warn("Connection state channel is full, discarding value.") } } + +// Trace gives access to the API trace interface +func (c *Connection) Trace() api.Trace { + return c.apiTrace +} + +// trace records api message +func (c *Connection) trace(msg api.Message, chId uint16, t time.Time, isReceived bool) { + if atomic.LoadInt32(&c.apiTrace.isEnabled) == 0 { + return + } + entry := &api.Record{ + Message: msg, + Timestamp: t, + IsReceived: isReceived, + ChannelID: chId, + } + c.apiTrace.mux.Lock() + c.apiTrace.list = append(c.apiTrace.list, entry) + c.apiTrace.mux.Unlock() +} diff --git a/core/connection_test.go b/core/connection_test.go index 230eea5..fe2f191 100644 --- a/core/connection_test.go +++ b/core/connection_test.go @@ -16,17 +16,18 @@ package core_test import ( "testing" + "time" . "github.com/onsi/gomega" - "git.fd.io/govpp.git/adapter/mock" - "git.fd.io/govpp.git/api" - "git.fd.io/govpp.git/binapi/ethernet_types" - interfaces "git.fd.io/govpp.git/binapi/interface" - "git.fd.io/govpp.git/binapi/interface_types" - "git.fd.io/govpp.git/binapi/vpe" - "git.fd.io/govpp.git/codec" - "git.fd.io/govpp.git/core" + "go.fd.io/govpp/adapter/mock" + "go.fd.io/govpp/api" + "go.fd.io/govpp/binapi/ethernet_types" + interfaces "go.fd.io/govpp/binapi/interface" + "go.fd.io/govpp/binapi/interface_types" + memclnt "go.fd.io/govpp/binapi/memclnt" + "go.fd.io/govpp/codec" + "go.fd.io/govpp/core" ) type testCtx struct { @@ -91,6 +92,28 @@ func TestAsyncConnection(t *testing.T) { Expect(ev.State).Should(BeEquivalentTo(core.Connected)) } +func TestAsyncConnectionProcessesVppTimeout(t *testing.T) { + ctx := setupTest(t, false) + defer ctx.teardownTest() + + ctx.conn.Disconnect() + conn, statusChan, err := core.AsyncConnect(ctx.mockVpp, core.DefaultMaxReconnectAttempts, core.DefaultReconnectInterval) + ctx.conn = conn + + Expect(err).ShouldNot(HaveOccurred()) + Expect(conn).ShouldNot(BeNil()) + + ev := <-statusChan + Expect(ev.State).Should(BeEquivalentTo(core.Connected)) + + // make control ping reply fail so that connection.healthCheckLoop() + // initiates reconnection. + ctx.mockVpp.MockReply(&memclnt.ControlPingReply{ + Retval: -1, + }) + time.Sleep(time.Duration(1+core.HealthCheckThreshold) * (core.HealthCheckInterval + 2*core.HealthCheckReplyTimeout)) +} + func TestCodec(t *testing.T) { RegisterTestingT(t) @@ -107,11 +130,11 @@ func TestCodec(t *testing.T) { Expect(msg1.MacAddress).To(BeEquivalentTo(ethernet_types.MacAddress{1, 2, 3, 4, 5, 6})) // reply - data, err = msgCodec.EncodeMsg(&vpe.ControlPingReply{Retval: 55}, 22) + data, err = msgCodec.EncodeMsg(&memclnt.ControlPingReply{Retval: 55}, 22) Expect(err).ShouldNot(HaveOccurred()) Expect(data).ShouldNot(BeEmpty()) - msg2 := &vpe.ControlPingReply{} + msg2 := &memclnt.ControlPingReply{} err = msgCodec.DecodeMsg(data, msg2) Expect(err).ShouldNot(HaveOccurred()) Expect(msg2.Retval).To(BeEquivalentTo(55)) @@ -134,7 +157,7 @@ func TestCodecNegative(t *testing.T) { Expect(err.Error()).To(ContainSubstring("nil message")) // nil data for decoding - err = msgCodec.DecodeMsg(nil, &vpe.ControlPingReply{}) + err = msgCodec.DecodeMsg(nil, &memclnt.ControlPingReply{}) Expect(err).Should(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("panic")) } @@ -145,13 +168,13 @@ func TestSimpleRequestsWithSequenceNumbers(t *testing.T) { var reqCtx []api.RequestCtx for i := 0; i < 10; i++ { - ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) - req := &vpe.ControlPing{} + ctx.mockVpp.MockReply(&memclnt.ControlPingReply{}) + req := &memclnt.ControlPing{} reqCtx = append(reqCtx, ctx.ch.SendRequest(req)) } for i := 0; i < 10; i++ { - reply := &vpe.ControlPingReply{} + reply := &memclnt.ControlPingReply{} err := reqCtx[i].ReceiveReply(reply) Expect(err).ShouldNot(HaveOccurred()) } @@ -166,7 +189,7 @@ func TestMultiRequestsWithSequenceNumbers(t *testing.T) { msgs = append(msgs, &interfaces.SwInterfaceDetails{SwIfIndex: interface_types.InterfaceIndex(i)}) } ctx.mockVpp.MockReply(msgs...) - ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) + ctx.mockVpp.MockReply(&memclnt.ControlPingReply{}) // send multipart request reqCtx := ctx.ch.SendMultiRequest(&interfaces.SwInterfaceDump{}) @@ -198,15 +221,15 @@ func TestSimpleRequestWithTimeout(t *testing.T) { // reply for a previous timeouted requests to be ignored ctx.mockVpp.MockReplyWithContext(mock.MsgWithContext{ - Msg: &vpe.ControlPingReply{}, + Msg: &memclnt.ControlPingReply{}, SeqNum: 0, }) // send reply later - req1 := &vpe.ControlPing{} + req1 := &memclnt.ControlPing{} reqCtx1 := ctx.ch.SendRequest(req1) - reply := &vpe.ControlPingReply{} + reply := &memclnt.ControlPingReply{} err := reqCtx1.ReceiveReply(reply) Expect(err).ToNot(BeNil()) Expect(err.Error()).To(HavePrefix("no reply received within the timeout period")) @@ -214,21 +237,21 @@ func TestSimpleRequestWithTimeout(t *testing.T) { ctx.mockVpp.MockReplyWithContext( // reply for the previous request mock.MsgWithContext{ - Msg: &vpe.ControlPingReply{}, + Msg: &memclnt.ControlPingReply{}, SeqNum: 1, }, // reply for the next request mock.MsgWithContext{ - Msg: &vpe.ControlPingReply{}, + Msg: &memclnt.ControlPingReply{}, SeqNum: 2, }) // next request - req2 := &vpe.ControlPing{} + req2 := &memclnt.ControlPing{} reqCtx2 := ctx.ch.SendRequest(req2) // second request should ignore the first reply and return the second one - reply = &vpe.ControlPingReply{} + reply = &memclnt.ControlPingReply{} err = reqCtx2.ReceiveReply(reply) Expect(err).To(BeNil()) } @@ -238,34 +261,34 @@ func TestSimpleRequestsWithMissingReply(t *testing.T) { defer ctx.teardownTest() // request without reply - req1 := &vpe.ControlPing{} + req1 := &memclnt.ControlPing{} reqCtx1 := ctx.ch.SendRequest(req1) // another request without reply - req2 := &vpe.ControlPing{} + req2 := &memclnt.ControlPing{} reqCtx2 := ctx.ch.SendRequest(req2) // third request with reply ctx.mockVpp.MockReplyWithContext(mock.MsgWithContext{ - Msg: &vpe.ControlPingReply{}, + Msg: &memclnt.ControlPingReply{}, SeqNum: 3, }) - req3 := &vpe.ControlPing{} + req3 := &memclnt.ControlPing{} reqCtx3 := ctx.ch.SendRequest(req3) // the first two should fail, but not consume reply for the 3rd - reply := &vpe.ControlPingReply{} + reply := &memclnt.ControlPingReply{} err := reqCtx1.ReceiveReply(reply) Expect(err).ToNot(BeNil()) Expect(err.Error()).To(Equal("missing binary API reply with sequence number: 1")) - reply = &vpe.ControlPingReply{} + reply = &memclnt.ControlPingReply{} err = reqCtx2.ReceiveReply(reply) Expect(err).ToNot(BeNil()) Expect(err.Error()).To(Equal("missing binary API reply with sequence number: 2")) // the second request should succeed - reply = &vpe.ControlPingReply{} + reply = &memclnt.ControlPingReply{} err = reqCtx3.ReceiveReply(reply) Expect(err).To(BeNil()) } @@ -276,9 +299,9 @@ func TestMultiRequestsWithErrors(t *testing.T) { // replies for a previous timeouted requests to be ignored msgs := []mock.MsgWithContext{ - {Msg: &vpe.ControlPingReply{}, SeqNum: 0xffff - 1}, - {Msg: &vpe.ControlPingReply{}, SeqNum: 0xffff}, - {Msg: &vpe.ControlPingReply{}, SeqNum: 0}, + {Msg: &memclnt.ControlPingReply{}, SeqNum: 0xffff - 1}, + {Msg: &memclnt.ControlPingReply{}, SeqNum: 0xffff}, + {Msg: &memclnt.ControlPingReply{}, SeqNum: 0}, } for i := 0; i < 10; i++ { msgs = append(msgs, mock.MsgWithContext{ @@ -291,7 +314,7 @@ func TestMultiRequestsWithErrors(t *testing.T) { // reply for a next request msgs = append(msgs, mock.MsgWithContext{ - Msg: &vpe.ControlPingReply{}, + Msg: &memclnt.ControlPingReply{}, SeqNum: 2, }) @@ -323,8 +346,8 @@ func TestMultiRequestsWithErrors(t *testing.T) { Expect(err.Error()).To(Equal("missing binary API reply with sequence number: 1")) // reply for the second request has not been consumed - reqCtx2 := ctx.ch.SendRequest(&vpe.ControlPing{}) - reply2 := &vpe.ControlPingReply{} + reqCtx2 := ctx.ch.SendRequest(&memclnt.ControlPing{}) + reply2 := &memclnt.ControlPingReply{} err = reqCtx2.ReceiveReply(reply2) Expect(err).To(BeNil()) } @@ -337,23 +360,23 @@ func TestRequestsOrdering(t *testing.T) { // some replies will get thrown away // first request - ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) - req1 := &vpe.ControlPing{} + ctx.mockVpp.MockReply(&memclnt.ControlPingReply{}) + req1 := &memclnt.ControlPing{} reqCtx1 := ctx.ch.SendRequest(req1) // second request - ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) - req2 := &vpe.ControlPing{} + ctx.mockVpp.MockReply(&memclnt.ControlPingReply{}) + req2 := &memclnt.ControlPing{} reqCtx2 := ctx.ch.SendRequest(req2) // if reply for the second request is read first, the reply for the first // request gets thrown away. - reply2 := &vpe.ControlPingReply{} + reply2 := &memclnt.ControlPingReply{} err := reqCtx2.ReceiveReply(reply2) Expect(err).To(BeNil()) // first request has already been considered closed - reply1 := &vpe.ControlPingReply{} + reply1 := &memclnt.ControlPingReply{} err = reqCtx1.ReceiveReply(reply1) Expect(err).ToNot(BeNil()) Expect(err.Error()).To(HavePrefix("no reply received within the timeout period")) @@ -368,12 +391,12 @@ func TestCycleOverSetOfSequenceNumbers(t *testing.T) { for i := 0; i < numIters+30; i++ { if i < numIters { - ctx.mockVpp.MockReply(&vpe.ControlPingReply{}) - req := &vpe.ControlPing{} + ctx.mockVpp.MockReply(&memclnt.ControlPingReply{}) + req := &memclnt.ControlPing{} reqCtx[i] = ctx.ch.SendRequest(req) } if i > 30 { - reply := &vpe.ControlPingReply{} + reply := &memclnt.ControlPingReply{} err := reqCtx[i-30].ReceiveReply(reply) Expect(err).ShouldNot(HaveOccurred()) } diff --git a/core/control_ping.go b/core/control_ping.go index ed8d274..31ed327 100644 --- a/core/control_ping.go +++ b/core/control_ping.go @@ -1,7 +1,7 @@ package core import ( - "git.fd.io/govpp.git/api" + "go.fd.io/govpp/api" ) var ( diff --git a/core/request_handler.go b/core/request_handler.go index fc704cb..851ac64 100644 --- a/core/request_handler.go +++ b/core/request_handler.go @@ -23,7 +23,7 @@ import ( logger "github.com/sirupsen/logrus" - "git.fd.io/govpp.git/api" + "go.fd.io/govpp/api" ) var ReplyChannelTimeout = time.Millisecond * 100 @@ -55,51 +55,6 @@ func (c *Connection) watchRequests(ch *Channel) { } // processRequest processes a single request received on the request channel. -func (c *Connection) sendMessage(context uint32, msg api.Message) error { - // check whether we are connected to VPP - if atomic.LoadUint32(&c.vppConnected) == 0 { - return ErrNotConnected - } - - /*log := log.WithFields(logger.Fields{ - "context": context, - "msg_name": msg.GetMessageName(), - "msg_crc": msg.GetCrcString(), - })*/ - - // retrieve message ID - msgID, err := c.GetMessageID(msg) - if err != nil { - //log.WithError(err).Debugf("unable to retrieve message ID: %#v", msg) - return err - } - - //log = log.WithField("msg_id", msgID) - - // encode the message - data, err := c.codec.EncodeMsg(msg, msgID) - if err != nil { - log.WithError(err).Debugf("unable to encode message: %#v", msg) - return err - } - - //log = log.WithField("msg_length", len(data)) - - if log.Level >= logger.DebugLevel { - log.Debugf("--> SEND: MSG %T %+v", msg, msg) - } - - // send message to VPP - err = c.vppClient.SendMsg(context, data) - if err != nil { - log.WithError(err).Debugf("unable to send message: %#v", msg) - return err - } - - return nil -} - -// processRequest processes a single request received on the request channel. func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { // check whether we are connected to VPP if atomic.LoadUint32(&c.vppConnected) == 0 { @@ -157,6 +112,7 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { } // send the request to VPP + t := time.Now() err = c.vppClient.SendMsg(context, data) if err != nil { log.WithFields(logger.Fields{ @@ -172,6 +128,7 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { }).Warnf("Unable to send message") return err } + c.trace(req.msg, ch.id, t, false) if req.multi { // send a control ping to determine end of the multipart response @@ -189,6 +146,7 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { }).Debugf(" -> SEND MSG: %T", c.msgControlPing) } + t = time.Now() if err := c.vppClient.SendMsg(context, pingData); err != nil { log.WithFields(logger.Fields{ "context": context, @@ -196,6 +154,7 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { "error": err, }).Warnf("unable to send control ping") } + c.trace(c.msgControlPing, ch.id, t, false) } return nil @@ -210,9 +169,9 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) { return } - msg, ok := c.msgMap[msgID] - if !ok { - log.Warnf("Unknown message received, ID: %d", msgID) + msg, err := c.getMessageByID(msgID) + if err != nil { + log.Warnln(err) return } @@ -221,7 +180,7 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) { // - replies that don't have context as first field (comes as zero) // - events that don't have context at all (comes as non zero) // - context, err := c.codec.DecodeMsgContext(data, msg) + context, err := c.codec.DecodeMsgContext(data, msg.GetMessageType()) if err != nil { log.WithField("msg_id", msgID).Warnf("Unable to decode message context: %v", err) return @@ -229,15 +188,15 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) { chanID, isMulti, seqNum := unpackRequestContext(context) - if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled - msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - - // decode the message - if err = c.codec.DecodeMsg(data, msg); err != nil { - err = fmt.Errorf("decoding message failed: %w", err) - return - } + // decode and trace the message + msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + if err = c.codec.DecodeMsg(data, msg); err != nil { + log.WithField("msg", msg).Warnf("Unable to decode message: %v", err) + return + } + c.trace(msg, chanID, time.Now(), true) + if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled log.WithFields(logger.Fields{ "context": context, "msg_id": msgID, @@ -246,7 +205,7 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) { "is_multi": isMulti, "seq_num": seqNum, "msg_crc": msg.GetCrcString(), - }).Debugf("<-- govpp RECEIVE: %s %+v", msg.GetMessageName(), msg) + }).Debugf("<-- govpp RECEIVE: %s", msg.GetMessageName()) } if context == 0 || c.isNotificationMessage(msgID) { @@ -419,3 +378,14 @@ func compareSeqNumbers(seqNum1, seqNum2 uint16) int { } return 1 } + +// Returns message based on the message ID not depending on message path +func (c *Connection) getMessageByID(msgID uint16) (msg api.Message, err error) { + var ok bool + for _, messages := range c.msgMapByPath { + if msg, ok = messages[msgID]; ok { + return msg, nil + } + } + return nil, fmt.Errorf("unknown message received, ID: %d", msgID) +} diff --git a/core/stats.go b/core/stats.go index 0717be6..897a475 100644 --- a/core/stats.go +++ b/core/stats.go @@ -3,16 +3,16 @@ package core import ( "path" "strings" - "sync/atomic" "time" - "git.fd.io/govpp.git/adapter" - "git.fd.io/govpp.git/api" + "go.fd.io/govpp/adapter" + "go.fd.io/govpp/api" ) var ( - RetryUpdateCount = 10 - RetryUpdateDelay = time.Millisecond * 10 + RetryUpdateCount = 10 + RetryUpdateDelay = time.Millisecond * 10 + HealthCheckInterval = time.Second // default health check probe interval ) const ( @@ -39,6 +39,12 @@ const ( CounterStatsPrefix = "/err/" + MemoryStatSegPrefix = "/mem/statseg" + MemoryStatSegment = "/mem/stat segment" + MemoryMainHeap = "/mem/main heap" + MemoryStats_Total = "total" + MemoryStats_Used = "used" + InterfaceStatsPrefix = "/if/" InterfaceStats_Names = InterfaceStatsPrefix + "names" InterfaceStats_Drops = InterfaceStatsPrefix + "drops" @@ -72,19 +78,34 @@ const ( type StatsConnection struct { statsClient adapter.StatsAPI - // connected is true if the adapter is connected to VPP - connected uint32 + maxAttempts int // interval for reconnect attempts + recInterval time.Duration // maximum number of reconnect attempts + + connChan chan ConnectionEvent // connection event channel + done chan struct{} // to terminate stats connection watcher errorStatsData *adapter.StatDir nodeStatsData *adapter.StatDir ifaceStatsData *adapter.StatDir sysStatsData *adapter.StatDir bufStatsData *adapter.StatDir + memStatsData *adapter.StatDir } -func newStatsConnection(stats adapter.StatsAPI) *StatsConnection { +func newStatsConnection(stats adapter.StatsAPI, attempts int, interval time.Duration) *StatsConnection { + if attempts == 0 { + attempts = DefaultMaxReconnectAttempts + } + if interval == 0 { + interval = DefaultReconnectInterval + } + return &StatsConnection{ statsClient: stats, + maxAttempts: attempts, + recInterval: interval, + connChan: make(chan ConnectionEvent, NotificationChanBufSize), + done: make(chan struct{}), } } @@ -92,28 +113,50 @@ func newStatsConnection(stats adapter.StatsAPI) *StatsConnection { // This call blocks until it is either connected, or an error occurs. // Only one connection attempt will be performed. func ConnectStats(stats adapter.StatsAPI) (*StatsConnection, error) { - c := newStatsConnection(stats) + log.Debug("Connecting to stats..") + c := newStatsConnection(stats, DefaultMaxReconnectAttempts, DefaultReconnectInterval) - if err := c.connectClient(); err != nil { + if err := c.statsClient.Connect(); err != nil { return nil, err } + log.Debugf("Connected to stats.") return c, nil } -func (c *StatsConnection) connectClient() error { - log.Debug("Connecting to stats..") +// AsyncConnectStats connects to the VPP stats socket asynchronously and returns the connection +// handle with state channel. The call is non-blocking and the caller is expected to watch ConnectionEvent +// values from the channel and wait for connect/disconnect events. Connection loop tries to reconnect the +// socket in case the session was disconnected. +func AsyncConnectStats(stats adapter.StatsAPI, attempts int, interval time.Duration) (*StatsConnection, chan ConnectionEvent, error) { + log.Debug("Connecting to stats asynchronously..") + c := newStatsConnection(stats, attempts, interval) - if err := c.statsClient.Connect(); err != nil { - return err - } + go c.connectLoop() - log.Debugf("Connected to stats.") + return c, c.connChan, nil +} - // store connected state - atomic.StoreUint32(&c.connected, 1) +func (c *StatsConnection) connectLoop() { + log.Debug("Asynchronously connecting to stats..") + var reconnectAttempts int - return nil + // loop until connected + for { + if err := c.statsClient.Connect(); err == nil { + c.sendStatsConnEvent(ConnectionEvent{Timestamp: time.Now(), State: Connected}) + break + } else if reconnectAttempts < c.maxAttempts { + reconnectAttempts++ + log.Warnf("connecting stats failed (attempt %d/%d): %v", reconnectAttempts, c.maxAttempts, err) + time.Sleep(c.recInterval) + } else { + c.sendStatsConnEvent(ConnectionEvent{Timestamp: time.Now(), State: Failed, Error: err}) + return + } + } + // start monitoring stats connection state + go c.monitorSocket() } // Disconnect disconnects from Stats API and releases all connection-related resources. @@ -122,14 +165,41 @@ func (c *StatsConnection) Disconnect() { return } if c.statsClient != nil { - c.disconnectClient() + if err := c.statsClient.Disconnect(); err != nil { + log.Debugf("disconnecting stats client failed: %v", err) + } } + close(c.connChan) + close(c.done) } -func (c *StatsConnection) disconnectClient() { - if atomic.CompareAndSwapUint32(&c.connected, 1, 0) { - if err := c.statsClient.Disconnect(); err != nil { - log.Debugf("disconnecting stats client failed: %v", err) +func (c *StatsConnection) monitorSocket() { + var state, lastState ConnectionState + ticker := time.NewTicker(HealthCheckInterval) + + for { + select { + case <-ticker.C: + _, err := c.statsClient.ListStats(SystemStats_Heartbeat) + state = Connected + if err == adapter.ErrStatsDataBusy { + state = NotResponding + } + if err == adapter.ErrStatsDisconnected { + state = Disconnected + } + if err == adapter.ErrStatsAccessFailed { + state = Failed + } + if state == lastState { + continue + } + lastState = state + c.sendStatsConnEvent(ConnectionEvent{Timestamp: time.Now(), State: state, Error: err}) + case <-c.done: + log.Debugf("health check watcher closed") + c.sendStatsConnEvent(ConnectionEvent{Timestamp: time.Now(), State: Disconnected, Error: nil}) + break } } } @@ -198,6 +268,9 @@ func (c *StatsConnection) GetSystemStats(sysStats *api.SystemStats) (err error) if ss, ok := stat.Data.(adapter.SimpleCounterStat); ok { vals = make([]uint64, len(ss)) for w := range ss { + if ss[w] == nil { + continue + } vals[w] = uint64(ss[w][0]) } } @@ -230,14 +303,23 @@ func (c *StatsConnection) GetErrorStats(errorStats *api.ErrorStats) (err error) } for i, stat := range c.errorStatsData.Entries { - if stat.Type != adapter.ErrorIndex { - continue - } if errStat, ok := stat.Data.(adapter.ErrorStat); ok { - errorStats.Errors[i].Value = uint64(errStat) + values := make([]uint64, len(errStat)) + for j, errStatW := range errStat { + values[j] = uint64(errStatW) + } + errorStats.Errors[i].Values = values + } + if errStat, ok := stat.Data.(adapter.SimpleCounterStat); ok { + values := make([]uint64, len(errStat)) + for j, errStatW := range errStat { + for _, val := range errStatW { + values[j] += uint64(val) + } + } + errorStats.Errors[i].Values = values } } - return nil } @@ -468,3 +550,61 @@ func (c *StatsConnection) GetBufferStats(bufStats *api.BufferStats) (err error) return nil } + +func (c *StatsConnection) GetMemoryStats(memStats *api.MemoryStats) (err error) { + if err := c.updateStats(&c.memStatsData, MemoryStatSegPrefix, MemoryStatSegment, MemoryMainHeap); err != nil { + return err + } + convertStats := func(stats []adapter.Counter) api.MemoryCounters { + memUsg := make([]adapter.Counter, 7) + copy(memUsg, stats) + return api.MemoryCounters{ + Total: uint64(memUsg[0]), Used: uint64(memUsg[1]), Free: uint64(memUsg[2]), UsedMMap: uint64(memUsg[3]), + TotalAlloc: uint64(memUsg[4]), FreeChunks: uint64(memUsg[5]), Releasable: uint64(memUsg[6]), + } + } + + for _, stat := range c.memStatsData.Entries { + if strings.Contains(string(stat.Name), MemoryStatSegPrefix) { + _, f := path.Split(string(stat.Name)) + var val float64 + m, ok := stat.Data.(adapter.ScalarStat) + if ok { + val = float64(m) + } + switch f { + case MemoryStats_Total: + memStats.Total = val + case MemoryStats_Used: + memStats.Used = val + } + } else if string(stat.Name) == MemoryStatSegment { + if perHeapStats, ok := stat.Data.(adapter.SimpleCounterStat); ok { + if memStats.Stat == nil { + memStats.Stat = make(map[int]api.MemoryCounters) + } + for heap, stats := range perHeapStats { + memStats.Stat[heap] = convertStats(stats) + } + } + } else if string(stat.Name) == MemoryMainHeap { + if perHeapStats, ok := stat.Data.(adapter.SimpleCounterStat); ok { + if memStats.Main == nil { + memStats.Main = make(map[int]api.MemoryCounters) + } + for heap, stats := range perHeapStats { + memStats.Main[heap] = convertStats(stats) + } + } + } + } + return nil +} + +func (c *StatsConnection) sendStatsConnEvent(event ConnectionEvent) { + select { + case c.connChan <- event: + default: + log.Warn("Stats connection state channel is full, discarding value.") + } +} diff --git a/core/stream.go b/core/stream.go index 171b201..86bb99e 100644 --- a/core/stream.go +++ b/core/stream.go @@ -19,46 +19,54 @@ import ( "errors" "fmt" "reflect" - "sync/atomic" + "sync" + "time" - "git.fd.io/govpp.git/api" + "go.fd.io/govpp/api" ) type Stream struct { - id uint32 conn *Connection ctx context.Context channel *Channel + // available options + requestSize int + replySize int + replyTimeout time.Duration + // per-request context + pkgPath string + sync.Mutex } -func (c *Connection) NewStream(ctx context.Context) (api.Stream, error) { +func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) { if c == nil { return nil, errors.New("nil connection passed in") } - // TODO: add stream options as variadic parameters for customizing: - // - request/reply channel size - // - reply timeout - // - retries - // - ??? + s := &Stream{ + conn: c, + ctx: ctx, + // default options + requestSize: RequestChanBufSize, + replySize: ReplyChanBufSize, + replyTimeout: DefaultReplyTimeout, + } - // create new channel - chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff) - channel := newChannel(chID, c, c.codec, c, 10, 10) + // parse custom options + for _, option := range options { + option(s) + } - // store API channel within the client - c.channelsLock.Lock() - c.channels[chID] = channel - c.channelsLock.Unlock() + ch, err := c.newChannel(s.requestSize, s.replySize) + if err != nil { + return nil, err + } + s.channel = ch + s.channel.SetReplyTimeout(s.replyTimeout) // Channel.watchRequests are not started here intentionally, because // requests are sent directly by SendMsg. - return &Stream{ - id: uint32(chID), - conn: c, - ctx: ctx, - channel: channel, - }, nil + return s, nil } func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error { @@ -66,18 +74,18 @@ func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Mess if err != nil { return err } + defer func() { _ = stream.Close() }() if err := stream.SendMsg(req); err != nil { return err } - msg, err := stream.RecvMsg() + s := stream.(*Stream) + rep, err := s.recvReply() if err != nil { return err } - if msg.GetMessageName() != reply.GetMessageName() || - msg.GetCrcString() != reply.GetCrcString() { - return fmt.Errorf("unexpected reply: %T %+v", msg, msg) + if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil { + return err } - reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(msg).Elem()) return nil } @@ -102,10 +110,53 @@ func (s *Stream) SendMsg(msg api.Message) error { if err := s.conn.processRequest(s.channel, req); err != nil { return err } + s.Lock() + s.pkgPath = s.conn.GetMessagePath(msg) + s.Unlock() return nil } func (s *Stream) RecvMsg() (api.Message, error) { + reply, err := s.recvReply() + if err != nil { + return nil, err + } + // resolve message type + s.Lock() + path := s.pkgPath + s.Unlock() + msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID) + if err != nil { + return nil, err + } + // allocate message instance + msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + // decode message data + if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil { + return nil, err + } + return msg, nil +} + +func WithRequestSize(size int) api.StreamOption { + return func(stream api.Stream) { + stream.(*Stream).requestSize = size + } +} + +func WithReplySize(size int) api.StreamOption { + return func(stream api.Stream) { + stream.(*Stream).replySize = size + } +} + +func WithReplyTimeout(timeout time.Duration) api.StreamOption { + return func(stream api.Stream) { + stream.(*Stream).replyTimeout = timeout + } +} + +func (s *Stream) recvReply() (*vppReply, error) { if s.conn == nil { return nil, errors.New("stream closed") } @@ -120,18 +171,7 @@ func (s *Stream) RecvMsg() (api.Message, error) { // and stream does not use it return nil, reply.err } - // resolve message type - msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID) - if err != nil { - return nil, err - } - // allocate message instance - msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - // decode message data - if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil { - return nil, err - } - return msg, nil + return reply, nil case <-s.ctx.Done(): return nil, s.ctx.Err() diff --git a/core/trace.go b/core/trace.go new file mode 100644 index 0000000..b818657 --- /dev/null +++ b/core/trace.go @@ -0,0 +1,70 @@ +// Copyright (c) 2021 Cisco and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "go.fd.io/govpp/api" + "sort" + "sync" + "sync/atomic" +) + +// trace is the API tracer object synchronizing and keeping recoded messages. +type trace struct { + list []*api.Record + mux *sync.Mutex + + isEnabled int32 +} + +func (c *trace) Enable(enable bool) { + if enable && atomic.CompareAndSwapInt32(&c.isEnabled, 0, 1) { + log.Debugf("API trace enabled") + } else if atomic.CompareAndSwapInt32(&c.isEnabled, 1, 0) { + log.Debugf("API trace disabled") + } +} + +func (c *trace) GetRecords() (list []*api.Record) { + c.mux.Lock() + for _, entry := range c.list { + list = append(list, entry) + } + c.mux.Unlock() + sort.Slice(list, func(i, j int) bool { + return list[i].Timestamp.Before(list[j].Timestamp) + }) + return list +} + +func (c *trace) GetRecordsForChannel(chId uint16) (list []*api.Record) { + c.mux.Lock() + for _, entry := range c.list { + if entry.ChannelID == chId { + list = append(list, entry) + } + } + c.mux.Unlock() + sort.Slice(list, func(i, j int) bool { + return list[i].Timestamp.Before(list[j].Timestamp) + }) + return list +} + +func (c *trace) Clear() { + c.mux.Lock() + c.list = make([]*api.Record, 0) + c.mux.Unlock() +} diff --git a/core/trace_test.go b/core/trace_test.go new file mode 100644 index 0000000..6d1d5ba --- /dev/null +++ b/core/trace_test.go @@ -0,0 +1,265 @@ +package core_test + +import ( + "go.fd.io/govpp/api" + interfaces "go.fd.io/govpp/binapi/interface" + "go.fd.io/govpp/binapi/ip" + "go.fd.io/govpp/binapi/l2" + memclnt "go.fd.io/govpp/binapi/memclnt" + "go.fd.io/govpp/binapi/memif" + "go.fd.io/govpp/core" + . "github.com/onsi/gomega" + "strings" + "testing" +) + +func TestTraceEnabled(t *testing.T) { + ctx := setupTest(t, false) + defer ctx.teardownTest() + + Expect(ctx.conn.Trace()).ToNot(BeNil()) + ctx.conn.Trace().Enable(true) + + request := []api.Message{ + &interfaces.CreateLoopback{}, + &memif.MemifCreate{}, + &l2.BridgeDomainAddDel{}, + &ip.IPTableAddDel{}, + } + reply := []api.Message{ + &interfaces.CreateLoopbackReply{}, + &memif.MemifCreateReply{}, + &l2.BridgeDomainAddDelReply{}, + &ip.IPTableAddDelReply{}, + } + + for i := 0; i < len(request); i++ { + ctx.mockVpp.MockReply(reply[i]) + err := ctx.ch.SendRequest(request[i]).ReceiveReply(reply[i]) + Expect(err).To(BeNil()) + } + + traced := ctx.conn.Trace().GetRecords() + Expect(traced).ToNot(BeNil()) + Expect(traced).To(HaveLen(8)) + for i, entry := range traced { + Expect(entry.Timestamp).ToNot(BeNil()) + Expect(entry.Message.GetMessageName()).ToNot(Equal("")) + if strings.HasSuffix(entry.Message.GetMessageName(), "_reply") || + strings.HasSuffix(entry.Message.GetMessageName(), "_details") { + Expect(entry.IsReceived).To(BeTrue()) + } else { + Expect(entry.IsReceived).To(BeFalse()) + } + if i%2 == 0 { + Expect(request[i/2].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } else { + Expect(reply[i/2].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } + } +} + +func TestMultiRequestTraceEnabled(t *testing.T) { + ctx := setupTest(t, false) + defer ctx.teardownTest() + + ctx.conn.Trace().Enable(true) + + request := []api.Message{ + &interfaces.SwInterfaceDump{}, + } + reply := []api.Message{ + &interfaces.SwInterfaceDetails{ + SwIfIndex: 1, + }, + &interfaces.SwInterfaceDetails{ + SwIfIndex: 2, + }, + &interfaces.SwInterfaceDetails{ + SwIfIndex: 3, + }, + &memclnt.ControlPingReply{}, + } + + ctx.mockVpp.MockReply(reply...) + multiCtx := ctx.ch.SendMultiRequest(request[0]) + + i := 0 + for { + last, err := multiCtx.ReceiveReply(reply[i]) + Expect(err).ToNot(HaveOccurred()) + if last { + break + } + i++ + } + + traced := ctx.conn.Trace().GetRecords() + Expect(traced).ToNot(BeNil()) + Expect(traced).To(HaveLen(6)) + for i, entry := range traced { + Expect(entry.Timestamp).ToNot(BeNil()) + Expect(entry.Message.GetMessageName()).ToNot(Equal("")) + if strings.HasSuffix(entry.Message.GetMessageName(), "_reply") || + strings.HasSuffix(entry.Message.GetMessageName(), "_details") { + Expect(entry.IsReceived).To(BeTrue()) + } else { + Expect(entry.IsReceived).To(BeFalse()) + } + if i == 0 { + Expect(request[0].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } else if i == len(traced)-1 { + msg := memclnt.ControlPing{} + Expect(msg.GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } else { + Expect(reply[i-1].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } + } +} + +func TestTraceDisabled(t *testing.T) { + ctx := setupTest(t, false) + defer ctx.teardownTest() + + ctx.conn.Trace().Enable(false) + + request := []api.Message{ + &interfaces.CreateLoopback{}, + &memif.MemifCreate{}, + &l2.BridgeDomainAddDel{}, + &ip.IPTableAddDel{}, + } + reply := []api.Message{ + &interfaces.CreateLoopbackReply{}, + &memif.MemifCreateReply{}, + &l2.BridgeDomainAddDelReply{}, + &ip.IPTableAddDelReply{}, + } + + for i := 0; i < len(request); i++ { + ctx.mockVpp.MockReply(reply[i]) + err := ctx.ch.SendRequest(request[i]).ReceiveReply(reply[i]) + Expect(err).To(BeNil()) + } + + traced := ctx.conn.Trace().GetRecords() + Expect(traced).To(BeNil()) +} + +func TestTracePerChannel(t *testing.T) { + ctx := setupTest(t, false) + defer ctx.teardownTest() + + ctx.conn.Trace().Enable(true) + + ch1 := ctx.ch + ch2, err := ctx.conn.NewAPIChannel() + Expect(err).ToNot(HaveOccurred()) + + requestCh1 := []api.Message{ + &interfaces.CreateLoopback{}, + &memif.MemifCreate{}, + &l2.BridgeDomainAddDel{}, + } + replyCh1 := []api.Message{ + &interfaces.CreateLoopbackReply{}, + &memif.MemifCreateReply{}, + &l2.BridgeDomainAddDelReply{}, + } + requestCh2 := []api.Message{ + &ip.IPTableAddDel{}, + } + replyCh2 := []api.Message{ + &ip.IPTableAddDelReply{}, + } + + for i := 0; i < len(requestCh1); i++ { + ctx.mockVpp.MockReply(replyCh1[i]) + err := ch1.SendRequest(requestCh1[i]).ReceiveReply(replyCh1[i]) + Expect(err).To(BeNil()) + } + for i := 0; i < len(requestCh2); i++ { + ctx.mockVpp.MockReply(replyCh2[i]) + err := ch2.SendRequest(requestCh2[i]).ReceiveReply(replyCh2[i]) + Expect(err).To(BeNil()) + } + + trace := ctx.conn.Trace().GetRecords() + Expect(trace).ToNot(BeNil()) + Expect(trace).To(HaveLen(8)) + + // per channel + channel1, ok := ch1.(*core.Channel) + Expect(ok).To(BeTrue()) + channel2, ok := ch2.(*core.Channel) + Expect(ok).To(BeTrue()) + + tracedCh1 := ctx.conn.Trace().GetRecordsForChannel(channel1.GetID()) + Expect(tracedCh1).ToNot(BeNil()) + Expect(tracedCh1).To(HaveLen(6)) + for i, entry := range tracedCh1 { + Expect(entry.Timestamp).ToNot(BeNil()) + Expect(entry.Message.GetMessageName()).ToNot(Equal("")) + if strings.HasSuffix(entry.Message.GetMessageName(), "_reply") || + strings.HasSuffix(entry.Message.GetMessageName(), "_details") { + Expect(entry.IsReceived).To(BeTrue()) + } else { + Expect(entry.IsReceived).To(BeFalse()) + } + if i%2 == 0 { + Expect(requestCh1[i/2].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } else { + Expect(replyCh1[i/2].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } + } + + tracedCh2 := ctx.conn.Trace().GetRecordsForChannel(channel2.GetID()) + Expect(tracedCh2).ToNot(BeNil()) + Expect(tracedCh2).To(HaveLen(2)) + for i, entry := range tracedCh2 { + Expect(entry.Timestamp).ToNot(BeNil()) + Expect(entry.Message.GetMessageName()).ToNot(Equal("")) + if strings.HasSuffix(entry.Message.GetMessageName(), "_reply") || + strings.HasSuffix(entry.Message.GetMessageName(), "_details") { + Expect(entry.IsReceived).To(BeTrue()) + } else { + Expect(entry.IsReceived).To(BeFalse()) + } + if i%2 == 0 { + Expect(requestCh2[i/2].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } else { + Expect(replyCh2[i/2].GetMessageName()).To(Equal(entry.Message.GetMessageName())) + } + } +} + +func TestTraceClear(t *testing.T) { + ctx := setupTest(t, false) + defer ctx.teardownTest() + + ctx.conn.Trace().Enable(true) + + request := []api.Message{ + &interfaces.CreateLoopback{}, + &memif.MemifCreate{}, + } + reply := []api.Message{ + &interfaces.CreateLoopbackReply{}, + &memif.MemifCreateReply{}, + } + + for i := 0; i < len(request); i++ { + ctx.mockVpp.MockReply(reply[i]) + err := ctx.ch.SendRequest(request[i]).ReceiveReply(reply[i]) + Expect(err).To(BeNil()) + } + + traced := ctx.conn.Trace().GetRecords() + Expect(traced).ToNot(BeNil()) + Expect(traced).To(HaveLen(4)) + + ctx.conn.Trace().Clear() + traced = ctx.conn.Trace().GetRecords() + Expect(traced).To(BeNil()) + Expect(traced).To(BeEmpty()) +} |