From 6dd658f579ae4e1d9704e14380b69fe906c57ad2 Mon Sep 17 00:00:00 2001 From: Ondrej Fabry Date: Fri, 8 Mar 2019 11:18:22 +0100 Subject: Add socketclient implementation Change-Id: Ibf9edc0e5911d08229ac590b37c5afbc27f424a0 Signed-off-by: Ondrej Fabry --- adapter/socketclient/socketclient.go | 430 +++++++++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 adapter/socketclient/socketclient.go (limited to 'adapter') diff --git a/adapter/socketclient/socketclient.go b/adapter/socketclient/socketclient.go new file mode 100644 index 0000000..eec4fd0 --- /dev/null +++ b/adapter/socketclient/socketclient.go @@ -0,0 +1,430 @@ +package socketclient + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "os" + "strings" + "sync" + "time" + + "github.com/lunixbochs/struc" + logger "github.com/sirupsen/logrus" + + "git.fd.io/govpp.git/adapter" + "git.fd.io/govpp.git/codec" + "git.fd.io/govpp.git/examples/bin_api/memclnt" +) + +const ( + // DefaultSocketName is default VPP API socket file name + DefaultSocketName = "/run/vpp-api.sock" + + sockCreateMsgId = 15 // hard-coded id for sockclnt_create message + govppClientName = "govppsock" // client name used for socket registration +) + +var ( + ConnectTimeout = time.Second * 3 + DisconnectTimeout = time.Second + + Debug = os.Getenv("DEBUG_GOVPP_SOCK") != "" + DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != "" + + Log = logger.New() // global logger +) + +// init initializes global logger, which logs debug level messages to stdout. +func init() { + Log.Out = os.Stdout + if Debug { + Log.Level = logger.DebugLevel + } +} + +type vppClient struct { + sockAddr string + conn *net.UnixConn + reader *bufio.Reader + cb adapter.MsgCallback + clientIndex uint32 + msgTable map[string]uint16 + sockDelMsgId uint16 + writeMu sync.Mutex + quit chan struct{} + wg sync.WaitGroup +} + +func NewVppClient(sockAddr string) *vppClient { + if sockAddr == "" { + sockAddr = DefaultSocketName + } + return &vppClient{ + sockAddr: sockAddr, + cb: nilCallback, + } +} + +func nilCallback(msgID uint16, data []byte) { + Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data)) +} + +func (*vppClient) WaitReady() error { + // TODO: add watcher for socket file? + return nil +} + +func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) { + Log.Debug("SetMsgCallback") + c.cb = cb +} + +func (c *vppClient) Connect() error { + Log.Debugf("Connecting to: %v", c.sockAddr) + + if err := c.connect(c.sockAddr); err != nil { + return err + } + + if err := c.open(); err != nil { + return err + } + + c.quit = make(chan struct{}) + c.wg.Add(1) + go c.readerLoop() + + return nil +} + +func (c *vppClient) connect(sockAddr string) error { + addr, err := net.ResolveUnixAddr("unixpacket", sockAddr) + if err != nil { + Log.Debugln("ResolveUnixAddr error:", err) + return err + } + + conn, err := net.DialUnix("unixpacket", nil, addr) + if err != nil { + Log.Debugln("Dial error:", err) + return err + } + + c.conn = conn + c.reader = bufio.NewReader(c.conn) + + Log.Debugf("Connected to socket: %v", addr) + + return nil +} + +func (c *vppClient) open() error { + msgCodec := new(codec.MsgCodec) + + req := &memclnt.SockclntCreate{ + Name: []byte(govppClientName), + } + msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId) + if err != nil { + Log.Debugln("Encode error:", err) + return err + } + // set non-0 context + msg[5] = 123 + + if err := c.write(msg); err != nil { + Log.Debugln("Write error: ", err) + return err + } + + readDeadline := time.Now().Add(ConnectTimeout) + if err := c.conn.SetReadDeadline(readDeadline); err != nil { + return err + } + msgReply, err := c.read() + if err != nil { + Log.Println("Read error:", err) + return err + } + // reset read deadline + if err := c.conn.SetReadDeadline(time.Time{}); err != nil { + return err + } + + //log.Printf("Client got (%d): % 0X", len(msgReply), msgReply) + + reply := new(memclnt.SockclntCreateReply) + if err := msgCodec.DecodeMsg(msgReply, reply); err != nil { + Log.Println("Decode error:", err) + return err + } + + Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v", + reply.Response, reply.Index, reply.Count) + + c.clientIndex = reply.Index + c.msgTable = make(map[string]uint16, reply.Count) + for _, x := range reply.MessageTable { + name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13})) + c.msgTable[name] = x.Index + if strings.HasPrefix(name, "sockclnt_delete_") { + c.sockDelMsgId = x.Index + } + if DebugMsgIds { + Log.Debugf(" - %4d: %q", x.Index, name) + } + } + + return nil +} + +func (c *vppClient) Disconnect() error { + if c.conn == nil { + return nil + } + Log.Debugf("Disconnecting..") + + close(c.quit) + + // force readerLoop to timeout + if err := c.conn.SetReadDeadline(time.Now()); err != nil { + return err + } + + // wait for readerLoop to return + c.wg.Wait() + + if err := c.close(); err != nil { + return err + } + + if err := c.conn.Close(); err != nil { + Log.Debugln("Close socket conn failed:", err) + return err + } + + return nil +} + +func (c *vppClient) close() error { + msgCodec := new(codec.MsgCodec) + + req := &memclnt.SockclntDelete{ + Index: c.clientIndex, + } + msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId) + if err != nil { + Log.Debugln("Encode error:", err) + return err + } + // set non-0 context + msg[5] = 124 + + Log.Debugf("sending socklntDel (%d byes): % 0X\n", len(msg), msg) + if err := c.write(msg); err != nil { + Log.Debugln("Write error: ", err) + return err + } + + readDeadline := time.Now().Add(DisconnectTimeout) + if err := c.conn.SetReadDeadline(readDeadline); err != nil { + return err + } + msgReply, err := c.read() + if err != nil { + Log.Debugln("Read error:", err) + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + // we accept timeout for reply + return nil + } + return err + } + // reset read deadline + if err := c.conn.SetReadDeadline(time.Time{}); err != nil { + return err + } + + reply := new(memclnt.SockclntDeleteReply) + if err := msgCodec.DecodeMsg(msgReply, reply); err != nil { + Log.Debugln("Decode error:", err) + return err + } + + Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response) + + return nil +} + +func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) { + msg := msgName + "_" + msgCrc + msgID, ok := c.msgTable[msg] + if !ok { + return 0, fmt.Errorf("unknown message: %q", msg) + } + return msgID, nil +} + +type reqHeader struct { + //MsgID uint16 + ClientIndex uint32 + Context uint32 +} + +func (c *vppClient) SendMsg(context uint32, data []byte) error { + h := &reqHeader{ + ClientIndex: c.clientIndex, + Context: context, + } + buf := new(bytes.Buffer) + if err := struc.Pack(buf, h); err != nil { + return err + } + copy(data[2:], buf.Bytes()) + + Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data) + + if err := c.write(data); err != nil { + Log.Debugln("write error: ", err) + return err + } + + return nil +} + +func (c *vppClient) write(msg []byte) error { + h := &msgheader{ + Data_len: uint32(len(msg)), + } + buf := new(bytes.Buffer) + if err := struc.Pack(buf, h); err != nil { + return err + } + header := buf.Bytes() + + // we lock to prevent mixing multiple message sends + c.writeMu.Lock() + defer c.writeMu.Unlock() + + var w io.Writer = c.conn + + if n, err := w.Write(header); err != nil { + return err + } else { + Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header) + } + if n, err := w.Write(msg); err != nil { + return err + } else { + Log.Debugf(" - msg sent (%d/%d): % 0X", n, len(msg), msg) + } + + return nil +} + +type msgHeader struct { + MsgID uint16 + Context uint32 +} + +func (c *vppClient) readerLoop() { + defer c.wg.Done() + for { + select { + case <-c.quit: + Log.Debugf("reader quit") + return + default: + } + + msg, err := c.read() + if err != nil { + if isClosedError(err) { + return + } + Log.Debugf("READ FAILED: %v", err) + continue + } + h := new(msgHeader) + if err := struc.Unpack(bytes.NewReader(msg), h); err != nil { + Log.Debugf("unpacking header failed: %v", err) + continue + } + + Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context) + c.cb(h.MsgID, msg) + } +} + +type msgheader struct { + Q int `struc:"uint64"` + Data_len uint32 `struc:"uint32"` + Gc_mark_timestamp uint32 `struc:"uint32"` + //data [0]uint8 +} + +func (c *vppClient) read() ([]byte, error) { + Log.Debug("reading next msg..") + + header := make([]byte, 16) + + n, err := io.ReadAtLeast(c.reader, header, 16) + if err != nil { + return nil, err + } else if n == 0 { + Log.Debugln("zero bytes header") + return nil, nil + } + if n != 16 { + Log.Debug("invalid header data (%d): % 0X", n, header[:n]) + return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n) + } + Log.Debugf(" - read header %d bytes: % 0X", n, header) + + h := &msgheader{} + if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil { + return nil, err + } + Log.Debugf(" - decoded header: %+v", h) + + msgLen := int(h.Data_len) + msg := make([]byte, msgLen) + + n, err = c.reader.Read(msg) + if err != nil { + return nil, err + } + Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered()) + + if msgLen > n { + remain := msgLen - n + Log.Debugf("continue read for another %d bytes", remain) + view := msg[n:] + + for remain > 0 { + + nbytes, err := c.reader.Read(view) + if err != nil { + return nil, err + } else if nbytes == 0 { + return nil, fmt.Errorf("zero nbytes") + } + + remain -= nbytes + Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain) + + view = view[nbytes:] + } + } + + return msg, nil +} + +func isClosedError(err error) bool { + if err == io.EOF { + return true + } + return strings.HasSuffix(err.Error(), "use of closed network connection") +} -- cgit 1.2.3-korg