diff options
author | Jakub Grajciar <jgrajcia@cisco.com> | 2020-04-02 10:02:17 +0200 |
---|---|---|
committer | Damjan Marion <dmarion@me.com> | 2020-04-28 21:18:37 +0000 |
commit | 07363a45fe4a7fe693acf438f0b56f927bdd3fbd (patch) | |
tree | 6d53728ac594de1b86e85c7d4ea1d9f8d145a993 /extras/gomemif/memif/control_channel.go | |
parent | c458c493667bde30c22760e3a1839f2cac6e6447 (diff) |
gomemif: introduce gomemif
golang native memif driver
Type: feature
Signed-off-by: Jakub Grajciar <jgrajcia@cisco.com>
Change-Id: I693156a44011c80025245d25134f5bf5db6eba82
Signed-off-by: Jakub Grajciar <jgrajcia@cisco.com>
Diffstat (limited to 'extras/gomemif/memif/control_channel.go')
-rw-r--r-- | extras/gomemif/memif/control_channel.go | 965 |
1 files changed, 965 insertions, 0 deletions
diff --git a/extras/gomemif/memif/control_channel.go b/extras/gomemif/memif/control_channel.go new file mode 100644 index 00000000000..32e34933ab4 --- /dev/null +++ b/extras/gomemif/memif/control_channel.go @@ -0,0 +1,965 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *------------------------------------------------------------------ + */ + +package memif + +import ( + "bytes" + "container/list" + "encoding/binary" + "fmt" + "os" + "sync" + "syscall" +) + +const maxEpollEvents = 1 +const maxControlLen = 256 + +const errorFdNotFound = "fd not found" + +// controlMsg represents a message used in communication between memif peers +type controlMsg struct { + Buffer *bytes.Buffer + Fd int +} + +// listener represents a listener functionality of UNIX domain socket +type listener struct { + socket *Socket + event syscall.EpollEvent +} + +// controlChannel represents a communication channel between memif peers +// backed by UNIX domain socket +type controlChannel struct { + listRef *list.Element + socket *Socket + i *Interface + event syscall.EpollEvent + data [msgSize]byte + control [maxControlLen]byte + controlLen int + msgQueue []controlMsg + isConnected bool +} + +// Socket represents a UNIX domain socket used for communication +// between memif peers +type Socket struct { + appName string + filename string + listener *listener + interfaceList *list.List + ccList *list.List + epfd int + wakeEvent syscall.EpollEvent + stopPollChan chan struct{} + wg sync.WaitGroup +} + +// StopPolling stops polling events on the socket +func (socket *Socket) StopPolling() error { + if socket.stopPollChan != nil { + // stop polling msg + close(socket.stopPollChan) + // wake epoll + buf := make([]byte, 8) + binary.PutUvarint(buf, 1) + n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:]) + if err != nil { + return err + } + if n != 8 { + return fmt.Errorf("Faild to write to eventfd") + } + // wait until polling is stopped + socket.wg.Wait() + } + + return nil +} + +// StartPolling starts polling and handling events on the socket, +// enabling communication between memif peers +func (socket *Socket) StartPolling(errChan chan<- error) { + socket.stopPollChan = make(chan struct{}) + socket.wg.Add(1) + go func() { + var events [maxEpollEvents]syscall.EpollEvent + defer socket.wg.Done() + + for { + select { + case <-socket.stopPollChan: + return + default: + num, err := syscall.EpollWait(socket.epfd, events[:], -1) + if err != nil { + errChan <- fmt.Errorf("EpollWait: ", err) + return + } + + for ev := 0; ev < num; ev++ { + if events[0].Fd == socket.wakeEvent.Fd { + continue + } + err = socket.handleEvent(&events[0]) + if err != nil { + errChan <- fmt.Errorf("handleEvent: ", err) + } + } + } + } + }() +} + +// addEvent adds event to epoll instance associated with the socket +func (socket *Socket) addEvent(event *syscall.EpollEvent) error { + err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event) + if err != nil { + return fmt.Errorf("EpollCtl: %s", err) + } + return nil +} + +// addEvent deletes event to epoll instance associated with the socket +func (socket *Socket) delEvent(event *syscall.EpollEvent) error { + err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event) + if err != nil { + return fmt.Errorf("EpollCtl: %s", err) + } + return nil +} + +// Delete deletes the socket +func (socket *Socket) Delete() (err error) { + for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() { + cc, ok := elt.Value.(*controlChannel) + if ok { + err = cc.close(true, "Socket deleted") + if err != nil { + return nil + } + } + } + for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() { + i, ok := elt.Value.(*Interface) + if ok { + err = i.Delete() + if err != nil { + return err + } + } + } + + if socket.listener != nil { + err = socket.listener.close() + if err != nil { + return err + } + err = os.Remove(socket.filename) + if err != nil { + return nil + } + } + + err = socket.delEvent(&socket.wakeEvent) + if err != nil { + return fmt.Errorf("Failed to delete event: ", err) + } + + syscall.Close(socket.epfd) + + return nil +} + +// NewSocket returns a new Socket +func NewSocket(appName string, filename string) (socket *Socket, err error) { + socket = &Socket{ + appName: appName, + filename: filename, + interfaceList: list.New(), + ccList: list.New(), + } + if socket.filename == "" { + socket.filename = DefaultSocketFilename + } + + socket.epfd, _ = syscall.EpollCreate1(0) + + efd, err := eventFd() + socket.wakeEvent = syscall.EpollEvent{ + Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(efd), + } + err = socket.addEvent(&socket.wakeEvent) + if err != nil { + return nil, fmt.Errorf("Failed to add event: ", err) + } + + return socket, nil +} + +// handleEvent handles epoll event +func (socket *Socket) handleEvent(event *syscall.EpollEvent) error { + if socket.listener != nil && socket.listener.event.Fd == event.Fd { + return socket.listener.handleEvent(event) + } + + for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() { + cc, ok := elt.Value.(*controlChannel) + if ok { + if cc.event.Fd == event.Fd { + return cc.handleEvent(event) + } + } + } + + return fmt.Errorf(errorFdNotFound) +} + +// handleEvent handles epoll event for listener +func (l *listener) handleEvent(event *syscall.EpollEvent) error { + // hang up + if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP { + err := l.close() + if err != nil { + return fmt.Errorf("Failed to close listener after hang up event: ", err) + } + return fmt.Errorf("Hang up: ", l.socket.filename) + } + + // error + if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR { + err := l.close() + if err != nil { + return fmt.Errorf("Failed to close listener after receiving an error event: ", err) + } + return fmt.Errorf("Received error event on listener ", l.socket.filename) + } + + // read message + if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN { + newFd, _, err := syscall.Accept(int(l.event.Fd)) + if err != nil { + return fmt.Errorf("Accept: %s", err) + } + + cc, err := l.socket.addControlChannel(newFd, nil) + if err != nil { + return fmt.Errorf("Failed to add control channel: %s", err) + } + + err = cc.msgEnqHello() + if err != nil { + return fmt.Errorf("msgEnqHello: %s", err) + } + + err = cc.sendMsg() + if err != nil { + return err + } + + return nil + } + + return fmt.Errorf("Unexpected event: ", event.Events) +} + +// handleEvent handles epoll event for control channel +func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error { + var size int + var err error + + // hang up + if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP { + // close cc, don't send msg + err := cc.close(false, "") + if err != nil { + return fmt.Errorf("Failed to close control channel after hang up event: ", err) + } + return fmt.Errorf("Hang up: ", cc.i.GetName()) + } + + if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR { + // close cc, don't send msg + err := cc.close(false, "") + if err != nil { + return fmt.Errorf("Failed to close control channel after receiving an error event: ", err) + } + return fmt.Errorf("Received error event on control channel ", cc.i.GetName()) + } + + if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN { + size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0) + if err != nil { + return fmt.Errorf("recvmsg: %s", err) + } + if size != msgSize { + return fmt.Errorf("invalid message size %d", size) + } + + err = cc.parseMsg() + if err != nil { + return err + } + + err = cc.sendMsg() + if err != nil { + return err + } + + return nil + } + + return fmt.Errorf("Unexpected event: ", event.Events) +} + +// close closes the listener +func (l *listener) close() error { + err := l.socket.delEvent(&l.event) + if err != nil { + return fmt.Errorf("Failed to del event: ", err) + } + err = syscall.Close(int(l.event.Fd)) + if err != nil { + return fmt.Errorf("Failed to close socket: ", err) + } + return nil +} + +// AddListener adds a lisntener to the socket. The fd must describe a +// UNIX domain socket already bound to a UNIX domain filename and +// marked as listener +func (socket *Socket) AddListener(fd int) (err error) { + l := &listener{ + // we will need this to look up master interface by id + socket: socket, + } + + l.event = syscall.EpollEvent{ + Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(fd), + } + err = socket.addEvent(&l.event) + if err != nil { + return fmt.Errorf("Failed to add event: ", err) + } + + socket.listener = l + + return nil +} + +// addListener creates new UNIX domain socket, binds it to the address +// and marks it as listener +func (socket *Socket) addListener() (err error) { + // create socket + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + return fmt.Errorf("Failed to create UNIX domain socket") + } + usa := &syscall.SockaddrUnix{Name: socket.filename} + + // Bind to address and start listening + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1) + if err != nil { + return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err) + } + err = syscall.Bind(fd, usa) + if err != nil { + return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err) + } + err = syscall.Listen(fd, syscall.SOMAXCONN) + if err != nil { + return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err) + } + + return socket.AddListener(fd) +} + +// close closes a control channel, if the control channel is assigned an +// interface, the interface is disconnected +func (cc *controlChannel) close(sendMsg bool, str string) (err error) { + if sendMsg == true { + // first clear message queue so that the disconnect + // message is the only message in queue + cc.msgQueue = []controlMsg{} + cc.msgEnqDisconnect(str) + + err = cc.sendMsg() + if err != nil { + return err + } + } + + err = cc.socket.delEvent(&cc.event) + if err != nil { + return fmt.Errorf("Failed to del event: ", err) + } + + // remove referance form socket + cc.socket.ccList.Remove(cc.listRef) + + if cc.i != nil { + err = cc.i.disconnect() + if err != nil { + return fmt.Errorf("Interface Disconnect: ", err) + } + } + + return nil +} + +//addControlChannel returns a new controlChannel and adds it to the socket +func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) { + cc := &controlChannel{ + socket: socket, + i: i, + isConnected: false, + } + + var err error + + cc.event = syscall.EpollEvent{ + Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(fd), + } + err = socket.addEvent(&cc.event) + if err != nil { + return nil, fmt.Errorf("Failed to add event: ", err) + } + + cc.listRef = socket.ccList.PushBack(cc) + + return cc, nil +} + +func (cc *controlChannel) msgEnqAck() (err error) { + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeAck) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) msgEnqHello() (err error) { + hello := MsgHello{ + VersionMin: Version, + VersionMax: Version, + MaxRegion: 255, + MaxRingM2S: 255, + MaxRingS2M: 255, + MaxLog2RingSize: 14, + } + + copy(hello.Name[:], []byte(cc.socket.appName)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeHello) + err = binary.Write(buf, binary.LittleEndian, hello) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseHello() (err error) { + var hello MsgHello + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &hello) + if err != nil { + return + } + + if hello.VersionMin > Version || hello.VersionMax < Version { + return fmt.Errorf("Incompatible memif version") + } + + cc.i.run = cc.i.args.MemoryConfig + + cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M) + cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S) + cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize) + + cc.i.remoteName = string(hello.Name[:]) + + return nil +} + +func (cc *controlChannel) msgEnqInit() (err error) { + init := MsgInit{ + Version: Version, + Id: cc.i.args.Id, + Mode: interfaceModeEthernet, + } + + copy(init.Name[:], []byte(cc.socket.appName)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeInit) + err = binary.Write(buf, binary.LittleEndian, init) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseInit() (err error) { + var init MsgInit + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &init) + if err != nil { + return + } + + if init.Version != Version { + return fmt.Errorf("Incompatible memif driver version") + } + + // find peer interface + for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() { + i, ok := elt.Value.(*Interface) + if ok { + if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil { + // verify secret + if i.args.Secret != init.Secret { + return fmt.Errorf("Invalid secret") + } + // interface is assigned to control channel + i.cc = cc + cc.i = i + cc.i.run = cc.i.args.MemoryConfig + cc.i.remoteName = string(init.Name[:]) + + return nil + } + } + } + + return fmt.Errorf("Invalid interface id") +} + +func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) { + if len(cc.i.regions) <= int(regionIndex) { + return fmt.Errorf("Invalid region index") + } + + addRegion := MsgAddRegion{ + Index: regionIndex, + Size: cc.i.regions[regionIndex].size, + } + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion) + err = binary.Write(buf, binary.LittleEndian, addRegion) + + msg := controlMsg{ + Buffer: buf, + Fd: cc.i.regions[regionIndex].fd, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseAddRegion() (err error) { + var addRegion MsgAddRegion + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &addRegion) + if err != nil { + return + } + + fd, err := cc.parseControlMsg() + if err != nil { + return fmt.Errorf("parseControlMsg: %s", err) + } + + if addRegion.Index > 255 { + return fmt.Errorf("Invalid memory region index") + } + + region := memoryRegion{ + size: addRegion.Size, + fd: fd, + } + + cc.i.regions = append(cc.i.regions, region) + + return nil +} + +func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) { + var q Queue + var flags uint16 = 0 + + if ringType == ringTypeS2M { + q = cc.i.txQueues[ringIndex] + flags = msgAddRingFlagS2M + } else { + q = cc.i.rxQueues[ringIndex] + } + + addRing := MsgAddRing{ + Index: ringIndex, + Offset: uint32(q.ring.offset), + Region: uint16(q.ring.region), + RingSizeLog2: uint8(q.ring.log2Size), + Flags: flags, + PrivateHdrSize: 0, + } + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing) + err = binary.Write(buf, binary.LittleEndian, addRing) + + msg := controlMsg{ + Buffer: buf, + Fd: q.interruptFd, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseAddRing() (err error) { + var addRing MsgAddRing + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &addRing) + if err != nil { + return + } + + fd, err := cc.parseControlMsg() + if err != nil { + return err + } + + if addRing.Index >= cc.i.run.NumQueuePairs { + return fmt.Errorf("invalid ring index") + } + + q := Queue{ + i: cc.i, + interruptFd: fd, + } + + if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M { + q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2)) + cc.i.rxQueues = append(cc.i.rxQueues, q) + } else { + q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2)) + cc.i.txQueues = append(cc.i.txQueues, q) + } + + return nil +} + +func (cc *controlChannel) msgEnqConnect() (err error) { + var connect MsgConnect + copy(connect.Name[:], []byte(cc.i.args.Name)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeConnect) + err = binary.Write(buf, binary.LittleEndian, connect) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseConnect() (err error) { + var connect MsgConnect + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &connect) + if err != nil { + return + } + + cc.i.peerName = string(connect.Name[:]) + + err = cc.i.connect() + if err != nil { + return err + } + + cc.isConnected = true + + return nil +} + +func (cc *controlChannel) msgEnqConnected() (err error) { + var connected MsgConnected + copy(connected.Name[:], []byte(cc.i.args.Name)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeConnected) + err = binary.Write(buf, binary.LittleEndian, connected) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseConnected() (err error) { + var conn MsgConnected + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &conn) + if err != nil { + return + } + + cc.i.peerName = string(conn.Name[:]) + + err = cc.i.connect() + if err != nil { + return err + } + + cc.isConnected = true + + return nil +} + +func (cc *controlChannel) msgEnqDisconnect(str string) (err error) { + dc := MsgDisconnect{ + // not implemented + Code: 0, + } + copy(dc.String[:], str) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect) + err = binary.Write(buf, binary.LittleEndian, dc) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseDisconnect() (err error) { + var dc MsgDisconnect + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &dc) + if err != nil { + return + } + + err = cc.close(false, string(dc.String[:])) + if err != nil { + return fmt.Errorf("Failed to disconnect control channel: ", err) + } + + return nil +} + +func (cc *controlChannel) parseMsg() error { + var msgType msgType + var err error + + buf := bytes.NewReader(cc.data[:]) + err = binary.Read(buf, binary.LittleEndian, &msgType) + + if msgType == msgTypeAck { + return nil + } else if msgType == msgTypeHello { + // Configure + err = cc.parseHello() + if err != nil { + goto error + } + // Initialize slave memif + err = cc.i.initializeRegions() + if err != nil { + goto error + } + err = cc.i.initializeQueues() + if err != nil { + goto error + } + // Enqueue messages + err = cc.msgEnqInit() + if err != nil { + goto error + } + for i := 0; i < len(cc.i.regions); i++ { + err = cc.msgEnqAddRegion(uint16(i)) + if err != nil { + goto error + } + } + for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ { + err = cc.msgEnqAddRing(ringTypeS2M, uint16(i)) + if err != nil { + goto error + } + } + for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ { + err = cc.msgEnqAddRing(ringTypeM2S, uint16(i)) + if err != nil { + goto error + } + } + err = cc.msgEnqConnect() + if err != nil { + goto error + } + } else if msgType == msgTypeInit { + err = cc.parseInit() + if err != nil { + goto error + } + + err = cc.msgEnqAck() + if err != nil { + goto error + } + } else if msgType == msgTypeAddRegion { + err = cc.parseAddRegion() + if err != nil { + goto error + } + + err = cc.msgEnqAck() + if err != nil { + goto error + } + } else if msgType == msgTypeAddRing { + err = cc.parseAddRing() + if err != nil { + goto error + } + + err = cc.msgEnqAck() + if err != nil { + goto error + } + } else if msgType == msgTypeConnect { + err = cc.parseConnect() + if err != nil { + goto error + } + + err = cc.msgEnqConnected() + if err != nil { + goto error + } + } else if msgType == msgTypeConnected { + err = cc.parseConnected() + if err != nil { + goto error + } + } else if msgType == msgTypeDisconnect { + err = cc.parseDisconnect() + if err != nil { + goto error + } + } else { + err = fmt.Errorf("unknown message %d", msgType) + goto error + } + + return nil + +error: + err1 := cc.close(true, err.Error()) + if err1 != nil { + return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1) + } + + return err +} + +// parseControlMsg parses control message and returns file descriptor +// if any +func (cc *controlChannel) parseControlMsg() (fd int, err error) { + // Assert only called when we require FD + fd = -1 + + controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen]) + if err != nil { + return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err) + } + + if len(controlMsgs) == 0 { + return -1, fmt.Errorf("Missing control message") + } + + for _, cmsg := range controlMsgs { + if cmsg.Header.Level == syscall.SOL_SOCKET { + if cmsg.Header.Type == syscall.SCM_RIGHTS { + FDs, err := syscall.ParseUnixRights(&cmsg) + if err != nil { + return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err) + } + if len(FDs) == 0 { + continue + } + // Only expect single FD + fd = FDs[0] + } + } + } + + if fd == -1 { + return -1, fmt.Errorf("Missing file descriptor") + } + + return fd, nil +} |