aboutsummaryrefslogtreecommitdiffstats
path: root/extras/gomemif/memif/control_channel.go
diff options
context:
space:
mode:
Diffstat (limited to 'extras/gomemif/memif/control_channel.go')
-rw-r--r--extras/gomemif/memif/control_channel.go965
1 files changed, 965 insertions, 0 deletions
diff --git a/extras/gomemif/memif/control_channel.go b/extras/gomemif/memif/control_channel.go
new file mode 100644
index 00000000000..32e34933ab4
--- /dev/null
+++ b/extras/gomemif/memif/control_channel.go
@@ -0,0 +1,965 @@
+/*
+ *------------------------------------------------------------------
+ * Copyright (c) 2020 Cisco and/or its affiliates.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *------------------------------------------------------------------
+ */
+
+package memif
+
+import (
+ "bytes"
+ "container/list"
+ "encoding/binary"
+ "fmt"
+ "os"
+ "sync"
+ "syscall"
+)
+
+const maxEpollEvents = 1
+const maxControlLen = 256
+
+const errorFdNotFound = "fd not found"
+
+// controlMsg represents a message used in communication between memif peers
+type controlMsg struct {
+ Buffer *bytes.Buffer
+ Fd int
+}
+
+// listener represents a listener functionality of UNIX domain socket
+type listener struct {
+ socket *Socket
+ event syscall.EpollEvent
+}
+
+// controlChannel represents a communication channel between memif peers
+// backed by UNIX domain socket
+type controlChannel struct {
+ listRef *list.Element
+ socket *Socket
+ i *Interface
+ event syscall.EpollEvent
+ data [msgSize]byte
+ control [maxControlLen]byte
+ controlLen int
+ msgQueue []controlMsg
+ isConnected bool
+}
+
+// Socket represents a UNIX domain socket used for communication
+// between memif peers
+type Socket struct {
+ appName string
+ filename string
+ listener *listener
+ interfaceList *list.List
+ ccList *list.List
+ epfd int
+ wakeEvent syscall.EpollEvent
+ stopPollChan chan struct{}
+ wg sync.WaitGroup
+}
+
+// StopPolling stops polling events on the socket
+func (socket *Socket) StopPolling() error {
+ if socket.stopPollChan != nil {
+ // stop polling msg
+ close(socket.stopPollChan)
+ // wake epoll
+ buf := make([]byte, 8)
+ binary.PutUvarint(buf, 1)
+ n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:])
+ if err != nil {
+ return err
+ }
+ if n != 8 {
+ return fmt.Errorf("Faild to write to eventfd")
+ }
+ // wait until polling is stopped
+ socket.wg.Wait()
+ }
+
+ return nil
+}
+
+// StartPolling starts polling and handling events on the socket,
+// enabling communication between memif peers
+func (socket *Socket) StartPolling(errChan chan<- error) {
+ socket.stopPollChan = make(chan struct{})
+ socket.wg.Add(1)
+ go func() {
+ var events [maxEpollEvents]syscall.EpollEvent
+ defer socket.wg.Done()
+
+ for {
+ select {
+ case <-socket.stopPollChan:
+ return
+ default:
+ num, err := syscall.EpollWait(socket.epfd, events[:], -1)
+ if err != nil {
+ errChan <- fmt.Errorf("EpollWait: ", err)
+ return
+ }
+
+ for ev := 0; ev < num; ev++ {
+ if events[0].Fd == socket.wakeEvent.Fd {
+ continue
+ }
+ err = socket.handleEvent(&events[0])
+ if err != nil {
+ errChan <- fmt.Errorf("handleEvent: ", err)
+ }
+ }
+ }
+ }
+ }()
+}
+
+// addEvent adds event to epoll instance associated with the socket
+func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
+ err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event)
+ if err != nil {
+ return fmt.Errorf("EpollCtl: %s", err)
+ }
+ return nil
+}
+
+// addEvent deletes event to epoll instance associated with the socket
+func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
+ err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event)
+ if err != nil {
+ return fmt.Errorf("EpollCtl: %s", err)
+ }
+ return nil
+}
+
+// Delete deletes the socket
+func (socket *Socket) Delete() (err error) {
+ for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
+ cc, ok := elt.Value.(*controlChannel)
+ if ok {
+ err = cc.close(true, "Socket deleted")
+ if err != nil {
+ return nil
+ }
+ }
+ }
+ for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
+ i, ok := elt.Value.(*Interface)
+ if ok {
+ err = i.Delete()
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ if socket.listener != nil {
+ err = socket.listener.close()
+ if err != nil {
+ return err
+ }
+ err = os.Remove(socket.filename)
+ if err != nil {
+ return nil
+ }
+ }
+
+ err = socket.delEvent(&socket.wakeEvent)
+ if err != nil {
+ return fmt.Errorf("Failed to delete event: ", err)
+ }
+
+ syscall.Close(socket.epfd)
+
+ return nil
+}
+
+// NewSocket returns a new Socket
+func NewSocket(appName string, filename string) (socket *Socket, err error) {
+ socket = &Socket{
+ appName: appName,
+ filename: filename,
+ interfaceList: list.New(),
+ ccList: list.New(),
+ }
+ if socket.filename == "" {
+ socket.filename = DefaultSocketFilename
+ }
+
+ socket.epfd, _ = syscall.EpollCreate1(0)
+
+ efd, err := eventFd()
+ socket.wakeEvent = syscall.EpollEvent{
+ Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
+ Fd: int32(efd),
+ }
+ err = socket.addEvent(&socket.wakeEvent)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to add event: ", err)
+ }
+
+ return socket, nil
+}
+
+// handleEvent handles epoll event
+func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
+ if socket.listener != nil && socket.listener.event.Fd == event.Fd {
+ return socket.listener.handleEvent(event)
+ }
+
+ for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
+ cc, ok := elt.Value.(*controlChannel)
+ if ok {
+ if cc.event.Fd == event.Fd {
+ return cc.handleEvent(event)
+ }
+ }
+ }
+
+ return fmt.Errorf(errorFdNotFound)
+}
+
+// handleEvent handles epoll event for listener
+func (l *listener) handleEvent(event *syscall.EpollEvent) error {
+ // hang up
+ if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
+ err := l.close()
+ if err != nil {
+ return fmt.Errorf("Failed to close listener after hang up event: ", err)
+ }
+ return fmt.Errorf("Hang up: ", l.socket.filename)
+ }
+
+ // error
+ if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
+ err := l.close()
+ if err != nil {
+ return fmt.Errorf("Failed to close listener after receiving an error event: ", err)
+ }
+ return fmt.Errorf("Received error event on listener ", l.socket.filename)
+ }
+
+ // read message
+ if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
+ newFd, _, err := syscall.Accept(int(l.event.Fd))
+ if err != nil {
+ return fmt.Errorf("Accept: %s", err)
+ }
+
+ cc, err := l.socket.addControlChannel(newFd, nil)
+ if err != nil {
+ return fmt.Errorf("Failed to add control channel: %s", err)
+ }
+
+ err = cc.msgEnqHello()
+ if err != nil {
+ return fmt.Errorf("msgEnqHello: %s", err)
+ }
+
+ err = cc.sendMsg()
+ if err != nil {
+ return err
+ }
+
+ return nil
+ }
+
+ return fmt.Errorf("Unexpected event: ", event.Events)
+}
+
+// handleEvent handles epoll event for control channel
+func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
+ var size int
+ var err error
+
+ // hang up
+ if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
+ // close cc, don't send msg
+ err := cc.close(false, "")
+ if err != nil {
+ return fmt.Errorf("Failed to close control channel after hang up event: ", err)
+ }
+ return fmt.Errorf("Hang up: ", cc.i.GetName())
+ }
+
+ if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
+ // close cc, don't send msg
+ err := cc.close(false, "")
+ if err != nil {
+ return fmt.Errorf("Failed to close control channel after receiving an error event: ", err)
+ }
+ return fmt.Errorf("Received error event on control channel ", cc.i.GetName())
+ }
+
+ if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
+ size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
+ if err != nil {
+ return fmt.Errorf("recvmsg: %s", err)
+ }
+ if size != msgSize {
+ return fmt.Errorf("invalid message size %d", size)
+ }
+
+ err = cc.parseMsg()
+ if err != nil {
+ return err
+ }
+
+ err = cc.sendMsg()
+ if err != nil {
+ return err
+ }
+
+ return nil
+ }
+
+ return fmt.Errorf("Unexpected event: ", event.Events)
+}
+
+// close closes the listener
+func (l *listener) close() error {
+ err := l.socket.delEvent(&l.event)
+ if err != nil {
+ return fmt.Errorf("Failed to del event: ", err)
+ }
+ err = syscall.Close(int(l.event.Fd))
+ if err != nil {
+ return fmt.Errorf("Failed to close socket: ", err)
+ }
+ return nil
+}
+
+// AddListener adds a lisntener to the socket. The fd must describe a
+// UNIX domain socket already bound to a UNIX domain filename and
+// marked as listener
+func (socket *Socket) AddListener(fd int) (err error) {
+ l := &listener{
+ // we will need this to look up master interface by id
+ socket: socket,
+ }
+
+ l.event = syscall.EpollEvent{
+ Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
+ Fd: int32(fd),
+ }
+ err = socket.addEvent(&l.event)
+ if err != nil {
+ return fmt.Errorf("Failed to add event: ", err)
+ }
+
+ socket.listener = l
+
+ return nil
+}
+
+// addListener creates new UNIX domain socket, binds it to the address
+// and marks it as listener
+func (socket *Socket) addListener() (err error) {
+ // create socket
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ return fmt.Errorf("Failed to create UNIX domain socket")
+ }
+ usa := &syscall.SockaddrUnix{Name: socket.filename}
+
+ // Bind to address and start listening
+ err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
+ if err != nil {
+ return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err)
+ }
+ err = syscall.Bind(fd, usa)
+ if err != nil {
+ return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err)
+ }
+ err = syscall.Listen(fd, syscall.SOMAXCONN)
+ if err != nil {
+ return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err)
+ }
+
+ return socket.AddListener(fd)
+}
+
+// close closes a control channel, if the control channel is assigned an
+// interface, the interface is disconnected
+func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
+ if sendMsg == true {
+ // first clear message queue so that the disconnect
+ // message is the only message in queue
+ cc.msgQueue = []controlMsg{}
+ cc.msgEnqDisconnect(str)
+
+ err = cc.sendMsg()
+ if err != nil {
+ return err
+ }
+ }
+
+ err = cc.socket.delEvent(&cc.event)
+ if err != nil {
+ return fmt.Errorf("Failed to del event: ", err)
+ }
+
+ // remove referance form socket
+ cc.socket.ccList.Remove(cc.listRef)
+
+ if cc.i != nil {
+ err = cc.i.disconnect()
+ if err != nil {
+ return fmt.Errorf("Interface Disconnect: ", err)
+ }
+ }
+
+ return nil
+}
+
+//addControlChannel returns a new controlChannel and adds it to the socket
+func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) {
+ cc := &controlChannel{
+ socket: socket,
+ i: i,
+ isConnected: false,
+ }
+
+ var err error
+
+ cc.event = syscall.EpollEvent{
+ Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
+ Fd: int32(fd),
+ }
+ err = socket.addEvent(&cc.event)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to add event: ", err)
+ }
+
+ cc.listRef = socket.ccList.PushBack(cc)
+
+ return cc, nil
+}
+
+func (cc *controlChannel) msgEnqAck() (err error) {
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeAck)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: -1,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) msgEnqHello() (err error) {
+ hello := MsgHello{
+ VersionMin: Version,
+ VersionMax: Version,
+ MaxRegion: 255,
+ MaxRingM2S: 255,
+ MaxRingS2M: 255,
+ MaxLog2RingSize: 14,
+ }
+
+ copy(hello.Name[:], []byte(cc.socket.appName))
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeHello)
+ err = binary.Write(buf, binary.LittleEndian, hello)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: -1,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseHello() (err error) {
+ var hello MsgHello
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &hello)
+ if err != nil {
+ return
+ }
+
+ if hello.VersionMin > Version || hello.VersionMax < Version {
+ return fmt.Errorf("Incompatible memif version")
+ }
+
+ cc.i.run = cc.i.args.MemoryConfig
+
+ cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M)
+ cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S)
+ cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)
+
+ cc.i.remoteName = string(hello.Name[:])
+
+ return nil
+}
+
+func (cc *controlChannel) msgEnqInit() (err error) {
+ init := MsgInit{
+ Version: Version,
+ Id: cc.i.args.Id,
+ Mode: interfaceModeEthernet,
+ }
+
+ copy(init.Name[:], []byte(cc.socket.appName))
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeInit)
+ err = binary.Write(buf, binary.LittleEndian, init)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: -1,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseInit() (err error) {
+ var init MsgInit
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &init)
+ if err != nil {
+ return
+ }
+
+ if init.Version != Version {
+ return fmt.Errorf("Incompatible memif driver version")
+ }
+
+ // find peer interface
+ for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
+ i, ok := elt.Value.(*Interface)
+ if ok {
+ if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil {
+ // verify secret
+ if i.args.Secret != init.Secret {
+ return fmt.Errorf("Invalid secret")
+ }
+ // interface is assigned to control channel
+ i.cc = cc
+ cc.i = i
+ cc.i.run = cc.i.args.MemoryConfig
+ cc.i.remoteName = string(init.Name[:])
+
+ return nil
+ }
+ }
+ }
+
+ return fmt.Errorf("Invalid interface id")
+}
+
+func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
+ if len(cc.i.regions) <= int(regionIndex) {
+ return fmt.Errorf("Invalid region index")
+ }
+
+ addRegion := MsgAddRegion{
+ Index: regionIndex,
+ Size: cc.i.regions[regionIndex].size,
+ }
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion)
+ err = binary.Write(buf, binary.LittleEndian, addRegion)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: cc.i.regions[regionIndex].fd,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseAddRegion() (err error) {
+ var addRegion MsgAddRegion
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &addRegion)
+ if err != nil {
+ return
+ }
+
+ fd, err := cc.parseControlMsg()
+ if err != nil {
+ return fmt.Errorf("parseControlMsg: %s", err)
+ }
+
+ if addRegion.Index > 255 {
+ return fmt.Errorf("Invalid memory region index")
+ }
+
+ region := memoryRegion{
+ size: addRegion.Size,
+ fd: fd,
+ }
+
+ cc.i.regions = append(cc.i.regions, region)
+
+ return nil
+}
+
+func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
+ var q Queue
+ var flags uint16 = 0
+
+ if ringType == ringTypeS2M {
+ q = cc.i.txQueues[ringIndex]
+ flags = msgAddRingFlagS2M
+ } else {
+ q = cc.i.rxQueues[ringIndex]
+ }
+
+ addRing := MsgAddRing{
+ Index: ringIndex,
+ Offset: uint32(q.ring.offset),
+ Region: uint16(q.ring.region),
+ RingSizeLog2: uint8(q.ring.log2Size),
+ Flags: flags,
+ PrivateHdrSize: 0,
+ }
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing)
+ err = binary.Write(buf, binary.LittleEndian, addRing)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: q.interruptFd,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseAddRing() (err error) {
+ var addRing MsgAddRing
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &addRing)
+ if err != nil {
+ return
+ }
+
+ fd, err := cc.parseControlMsg()
+ if err != nil {
+ return err
+ }
+
+ if addRing.Index >= cc.i.run.NumQueuePairs {
+ return fmt.Errorf("invalid ring index")
+ }
+
+ q := Queue{
+ i: cc.i,
+ interruptFd: fd,
+ }
+
+ if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
+ q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
+ cc.i.rxQueues = append(cc.i.rxQueues, q)
+ } else {
+ q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
+ cc.i.txQueues = append(cc.i.txQueues, q)
+ }
+
+ return nil
+}
+
+func (cc *controlChannel) msgEnqConnect() (err error) {
+ var connect MsgConnect
+ copy(connect.Name[:], []byte(cc.i.args.Name))
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeConnect)
+ err = binary.Write(buf, binary.LittleEndian, connect)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: -1,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseConnect() (err error) {
+ var connect MsgConnect
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &connect)
+ if err != nil {
+ return
+ }
+
+ cc.i.peerName = string(connect.Name[:])
+
+ err = cc.i.connect()
+ if err != nil {
+ return err
+ }
+
+ cc.isConnected = true
+
+ return nil
+}
+
+func (cc *controlChannel) msgEnqConnected() (err error) {
+ var connected MsgConnected
+ copy(connected.Name[:], []byte(cc.i.args.Name))
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeConnected)
+ err = binary.Write(buf, binary.LittleEndian, connected)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: -1,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseConnected() (err error) {
+ var conn MsgConnected
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &conn)
+ if err != nil {
+ return
+ }
+
+ cc.i.peerName = string(conn.Name[:])
+
+ err = cc.i.connect()
+ if err != nil {
+ return err
+ }
+
+ cc.isConnected = true
+
+ return nil
+}
+
+func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
+ dc := MsgDisconnect{
+ // not implemented
+ Code: 0,
+ }
+ copy(dc.String[:], str)
+
+ buf := new(bytes.Buffer)
+ err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect)
+ err = binary.Write(buf, binary.LittleEndian, dc)
+
+ msg := controlMsg{
+ Buffer: buf,
+ Fd: -1,
+ }
+
+ cc.msgQueue = append(cc.msgQueue, msg)
+
+ return nil
+}
+
+func (cc *controlChannel) parseDisconnect() (err error) {
+ var dc MsgDisconnect
+
+ buf := bytes.NewReader(cc.data[msgTypeSize:])
+ err = binary.Read(buf, binary.LittleEndian, &dc)
+ if err != nil {
+ return
+ }
+
+ err = cc.close(false, string(dc.String[:]))
+ if err != nil {
+ return fmt.Errorf("Failed to disconnect control channel: ", err)
+ }
+
+ return nil
+}
+
+func (cc *controlChannel) parseMsg() error {
+ var msgType msgType
+ var err error
+
+ buf := bytes.NewReader(cc.data[:])
+ err = binary.Read(buf, binary.LittleEndian, &msgType)
+
+ if msgType == msgTypeAck {
+ return nil
+ } else if msgType == msgTypeHello {
+ // Configure
+ err = cc.parseHello()
+ if err != nil {
+ goto error
+ }
+ // Initialize slave memif
+ err = cc.i.initializeRegions()
+ if err != nil {
+ goto error
+ }
+ err = cc.i.initializeQueues()
+ if err != nil {
+ goto error
+ }
+ // Enqueue messages
+ err = cc.msgEnqInit()
+ if err != nil {
+ goto error
+ }
+ for i := 0; i < len(cc.i.regions); i++ {
+ err = cc.msgEnqAddRegion(uint16(i))
+ if err != nil {
+ goto error
+ }
+ }
+ for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
+ err = cc.msgEnqAddRing(ringTypeS2M, uint16(i))
+ if err != nil {
+ goto error
+ }
+ }
+ for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
+ err = cc.msgEnqAddRing(ringTypeM2S, uint16(i))
+ if err != nil {
+ goto error
+ }
+ }
+ err = cc.msgEnqConnect()
+ if err != nil {
+ goto error
+ }
+ } else if msgType == msgTypeInit {
+ err = cc.parseInit()
+ if err != nil {
+ goto error
+ }
+
+ err = cc.msgEnqAck()
+ if err != nil {
+ goto error
+ }
+ } else if msgType == msgTypeAddRegion {
+ err = cc.parseAddRegion()
+ if err != nil {
+ goto error
+ }
+
+ err = cc.msgEnqAck()
+ if err != nil {
+ goto error
+ }
+ } else if msgType == msgTypeAddRing {
+ err = cc.parseAddRing()
+ if err != nil {
+ goto error
+ }
+
+ err = cc.msgEnqAck()
+ if err != nil {
+ goto error
+ }
+ } else if msgType == msgTypeConnect {
+ err = cc.parseConnect()
+ if err != nil {
+ goto error
+ }
+
+ err = cc.msgEnqConnected()
+ if err != nil {
+ goto error
+ }
+ } else if msgType == msgTypeConnected {
+ err = cc.parseConnected()
+ if err != nil {
+ goto error
+ }
+ } else if msgType == msgTypeDisconnect {
+ err = cc.parseDisconnect()
+ if err != nil {
+ goto error
+ }
+ } else {
+ err = fmt.Errorf("unknown message %d", msgType)
+ goto error
+ }
+
+ return nil
+
+error:
+ err1 := cc.close(true, err.Error())
+ if err1 != nil {
+ return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1)
+ }
+
+ return err
+}
+
+// parseControlMsg parses control message and returns file descriptor
+// if any
+func (cc *controlChannel) parseControlMsg() (fd int, err error) {
+ // Assert only called when we require FD
+ fd = -1
+
+ controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
+ if err != nil {
+ return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
+ }
+
+ if len(controlMsgs) == 0 {
+ return -1, fmt.Errorf("Missing control message")
+ }
+
+ for _, cmsg := range controlMsgs {
+ if cmsg.Header.Level == syscall.SOL_SOCKET {
+ if cmsg.Header.Type == syscall.SCM_RIGHTS {
+ FDs, err := syscall.ParseUnixRights(&cmsg)
+ if err != nil {
+ return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err)
+ }
+ if len(FDs) == 0 {
+ continue
+ }
+ // Only expect single FD
+ fd = FDs[0]
+ }
+ }
+ }
+
+ if fd == -1 {
+ return -1, fmt.Errorf("Missing file descriptor")
+ }
+
+ return fd, nil
+}