diff options
Diffstat (limited to 'extras/gomemif/memif')
-rw-r--r-- | extras/gomemif/memif/BUILD.bazel | 17 | ||||
-rw-r--r-- | extras/gomemif/memif/control_channel.go | 965 | ||||
-rw-r--r-- | extras/gomemif/memif/control_channel_unsafe.go | 60 | ||||
-rw-r--r-- | extras/gomemif/memif/interface.go | 507 | ||||
-rw-r--r-- | extras/gomemif/memif/interface_unsafe.go | 40 | ||||
-rw-r--r-- | extras/gomemif/memif/memif.go | 345 | ||||
-rw-r--r-- | extras/gomemif/memif/memif_unsafe.go | 55 | ||||
-rw-r--r-- | extras/gomemif/memif/packet_reader.go | 91 | ||||
-rw-r--r-- | extras/gomemif/memif/packet_writer.go | 95 |
9 files changed, 2175 insertions, 0 deletions
diff --git a/extras/gomemif/memif/BUILD.bazel b/extras/gomemif/memif/BUILD.bazel new file mode 100644 index 00000000000..e6539ff59bd --- /dev/null +++ b/extras/gomemif/memif/BUILD.bazel @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "memif", + srcs = [ + "interface.go", + "interface_unsafe.go", + "control_channel.go", + "control_channel_unsafe.go", + "memif.go", + "memif_unsafe.go", + "packet_writer.go", + "packet_reader.go", + ], + importpath = "memif", + visibility = ["//visibility:public",], +) diff --git a/extras/gomemif/memif/control_channel.go b/extras/gomemif/memif/control_channel.go new file mode 100644 index 00000000000..32e34933ab4 --- /dev/null +++ b/extras/gomemif/memif/control_channel.go @@ -0,0 +1,965 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *------------------------------------------------------------------ + */ + +package memif + +import ( + "bytes" + "container/list" + "encoding/binary" + "fmt" + "os" + "sync" + "syscall" +) + +const maxEpollEvents = 1 +const maxControlLen = 256 + +const errorFdNotFound = "fd not found" + +// controlMsg represents a message used in communication between memif peers +type controlMsg struct { + Buffer *bytes.Buffer + Fd int +} + +// listener represents a listener functionality of UNIX domain socket +type listener struct { + socket *Socket + event syscall.EpollEvent +} + +// controlChannel represents a communication channel between memif peers +// backed by UNIX domain socket +type controlChannel struct { + listRef *list.Element + socket *Socket + i *Interface + event syscall.EpollEvent + data [msgSize]byte + control [maxControlLen]byte + controlLen int + msgQueue []controlMsg + isConnected bool +} + +// Socket represents a UNIX domain socket used for communication +// between memif peers +type Socket struct { + appName string + filename string + listener *listener + interfaceList *list.List + ccList *list.List + epfd int + wakeEvent syscall.EpollEvent + stopPollChan chan struct{} + wg sync.WaitGroup +} + +// StopPolling stops polling events on the socket +func (socket *Socket) StopPolling() error { + if socket.stopPollChan != nil { + // stop polling msg + close(socket.stopPollChan) + // wake epoll + buf := make([]byte, 8) + binary.PutUvarint(buf, 1) + n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:]) + if err != nil { + return err + } + if n != 8 { + return fmt.Errorf("Faild to write to eventfd") + } + // wait until polling is stopped + socket.wg.Wait() + } + + return nil +} + +// StartPolling starts polling and handling events on the socket, +// enabling communication between memif peers +func (socket *Socket) StartPolling(errChan chan<- error) { + socket.stopPollChan = make(chan struct{}) + socket.wg.Add(1) + go func() { + var events 
[maxEpollEvents]syscall.EpollEvent + defer socket.wg.Done() + + for { + select { + case <-socket.stopPollChan: + return + default: + num, err := syscall.EpollWait(socket.epfd, events[:], -1) + if err != nil { + errChan <- fmt.Errorf("EpollWait: ", err) + return + } + + for ev := 0; ev < num; ev++ { + if events[0].Fd == socket.wakeEvent.Fd { + continue + } + err = socket.handleEvent(&events[0]) + if err != nil { + errChan <- fmt.Errorf("handleEvent: ", err) + } + } + } + } + }() +} + +// addEvent adds event to epoll instance associated with the socket +func (socket *Socket) addEvent(event *syscall.EpollEvent) error { + err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event) + if err != nil { + return fmt.Errorf("EpollCtl: %s", err) + } + return nil +} + +// addEvent deletes event to epoll instance associated with the socket +func (socket *Socket) delEvent(event *syscall.EpollEvent) error { + err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event) + if err != nil { + return fmt.Errorf("EpollCtl: %s", err) + } + return nil +} + +// Delete deletes the socket +func (socket *Socket) Delete() (err error) { + for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() { + cc, ok := elt.Value.(*controlChannel) + if ok { + err = cc.close(true, "Socket deleted") + if err != nil { + return nil + } + } + } + for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() { + i, ok := elt.Value.(*Interface) + if ok { + err = i.Delete() + if err != nil { + return err + } + } + } + + if socket.listener != nil { + err = socket.listener.close() + if err != nil { + return err + } + err = os.Remove(socket.filename) + if err != nil { + return nil + } + } + + err = socket.delEvent(&socket.wakeEvent) + if err != nil { + return fmt.Errorf("Failed to delete event: ", err) + } + + syscall.Close(socket.epfd) + + return nil +} + +// NewSocket returns a new Socket +func NewSocket(appName string, filename string) (socket *Socket, err 
error) { + socket = &Socket{ + appName: appName, + filename: filename, + interfaceList: list.New(), + ccList: list.New(), + } + if socket.filename == "" { + socket.filename = DefaultSocketFilename + } + + socket.epfd, _ = syscall.EpollCreate1(0) + + efd, err := eventFd() + socket.wakeEvent = syscall.EpollEvent{ + Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(efd), + } + err = socket.addEvent(&socket.wakeEvent) + if err != nil { + return nil, fmt.Errorf("Failed to add event: ", err) + } + + return socket, nil +} + +// handleEvent handles epoll event +func (socket *Socket) handleEvent(event *syscall.EpollEvent) error { + if socket.listener != nil && socket.listener.event.Fd == event.Fd { + return socket.listener.handleEvent(event) + } + + for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() { + cc, ok := elt.Value.(*controlChannel) + if ok { + if cc.event.Fd == event.Fd { + return cc.handleEvent(event) + } + } + } + + return fmt.Errorf(errorFdNotFound) +} + +// handleEvent handles epoll event for listener +func (l *listener) handleEvent(event *syscall.EpollEvent) error { + // hang up + if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP { + err := l.close() + if err != nil { + return fmt.Errorf("Failed to close listener after hang up event: ", err) + } + return fmt.Errorf("Hang up: ", l.socket.filename) + } + + // error + if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR { + err := l.close() + if err != nil { + return fmt.Errorf("Failed to close listener after receiving an error event: ", err) + } + return fmt.Errorf("Received error event on listener ", l.socket.filename) + } + + // read message + if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN { + newFd, _, err := syscall.Accept(int(l.event.Fd)) + if err != nil { + return fmt.Errorf("Accept: %s", err) + } + + cc, err := l.socket.addControlChannel(newFd, nil) + if err != nil { + return fmt.Errorf("Failed to add control channel: %s", err) + } + + err = 
cc.msgEnqHello() + if err != nil { + return fmt.Errorf("msgEnqHello: %s", err) + } + + err = cc.sendMsg() + if err != nil { + return err + } + + return nil + } + + return fmt.Errorf("Unexpected event: ", event.Events) +} + +// handleEvent handles epoll event for control channel +func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error { + var size int + var err error + + // hang up + if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP { + // close cc, don't send msg + err := cc.close(false, "") + if err != nil { + return fmt.Errorf("Failed to close control channel after hang up event: ", err) + } + return fmt.Errorf("Hang up: ", cc.i.GetName()) + } + + if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR { + // close cc, don't send msg + err := cc.close(false, "") + if err != nil { + return fmt.Errorf("Failed to close control channel after receiving an error event: ", err) + } + return fmt.Errorf("Received error event on control channel ", cc.i.GetName()) + } + + if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN { + size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0) + if err != nil { + return fmt.Errorf("recvmsg: %s", err) + } + if size != msgSize { + return fmt.Errorf("invalid message size %d", size) + } + + err = cc.parseMsg() + if err != nil { + return err + } + + err = cc.sendMsg() + if err != nil { + return err + } + + return nil + } + + return fmt.Errorf("Unexpected event: ", event.Events) +} + +// close closes the listener +func (l *listener) close() error { + err := l.socket.delEvent(&l.event) + if err != nil { + return fmt.Errorf("Failed to del event: ", err) + } + err = syscall.Close(int(l.event.Fd)) + if err != nil { + return fmt.Errorf("Failed to close socket: ", err) + } + return nil +} + +// AddListener adds a lisntener to the socket. 
The fd must describe a +// UNIX domain socket already bound to a UNIX domain filename and +// marked as listener +func (socket *Socket) AddListener(fd int) (err error) { + l := &listener{ + // we will need this to look up master interface by id + socket: socket, + } + + l.event = syscall.EpollEvent{ + Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(fd), + } + err = socket.addEvent(&l.event) + if err != nil { + return fmt.Errorf("Failed to add event: ", err) + } + + socket.listener = l + + return nil +} + +// addListener creates new UNIX domain socket, binds it to the address +// and marks it as listener +func (socket *Socket) addListener() (err error) { + // create socket + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + return fmt.Errorf("Failed to create UNIX domain socket") + } + usa := &syscall.SockaddrUnix{Name: socket.filename} + + // Bind to address and start listening + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1) + if err != nil { + return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err) + } + err = syscall.Bind(fd, usa) + if err != nil { + return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err) + } + err = syscall.Listen(fd, syscall.SOMAXCONN) + if err != nil { + return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err) + } + + return socket.AddListener(fd) +} + +// close closes a control channel, if the control channel is assigned an +// interface, the interface is disconnected +func (cc *controlChannel) close(sendMsg bool, str string) (err error) { + if sendMsg == true { + // first clear message queue so that the disconnect + // message is the only message in queue + cc.msgQueue = []controlMsg{} + cc.msgEnqDisconnect(str) + + err = cc.sendMsg() + if err != nil { + return err + } + } + + err = cc.socket.delEvent(&cc.event) + if err != nil { + return fmt.Errorf("Failed to del event: ", err) + } + + 
// remove referance form socket + cc.socket.ccList.Remove(cc.listRef) + + if cc.i != nil { + err = cc.i.disconnect() + if err != nil { + return fmt.Errorf("Interface Disconnect: ", err) + } + } + + return nil +} + +//addControlChannel returns a new controlChannel and adds it to the socket +func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) { + cc := &controlChannel{ + socket: socket, + i: i, + isConnected: false, + } + + var err error + + cc.event = syscall.EpollEvent{ + Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(fd), + } + err = socket.addEvent(&cc.event) + if err != nil { + return nil, fmt.Errorf("Failed to add event: ", err) + } + + cc.listRef = socket.ccList.PushBack(cc) + + return cc, nil +} + +func (cc *controlChannel) msgEnqAck() (err error) { + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeAck) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) msgEnqHello() (err error) { + hello := MsgHello{ + VersionMin: Version, + VersionMax: Version, + MaxRegion: 255, + MaxRingM2S: 255, + MaxRingS2M: 255, + MaxLog2RingSize: 14, + } + + copy(hello.Name[:], []byte(cc.socket.appName)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeHello) + err = binary.Write(buf, binary.LittleEndian, hello) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseHello() (err error) { + var hello MsgHello + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &hello) + if err != nil { + return + } + + if hello.VersionMin > Version || hello.VersionMax < Version { + return fmt.Errorf("Incompatible memif version") + } + + cc.i.run = cc.i.args.MemoryConfig + + cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, 
hello.MaxRingS2M) + cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S) + cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize) + + cc.i.remoteName = string(hello.Name[:]) + + return nil +} + +func (cc *controlChannel) msgEnqInit() (err error) { + init := MsgInit{ + Version: Version, + Id: cc.i.args.Id, + Mode: interfaceModeEthernet, + } + + copy(init.Name[:], []byte(cc.socket.appName)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeInit) + err = binary.Write(buf, binary.LittleEndian, init) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseInit() (err error) { + var init MsgInit + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &init) + if err != nil { + return + } + + if init.Version != Version { + return fmt.Errorf("Incompatible memif driver version") + } + + // find peer interface + for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() { + i, ok := elt.Value.(*Interface) + if ok { + if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil { + // verify secret + if i.args.Secret != init.Secret { + return fmt.Errorf("Invalid secret") + } + // interface is assigned to control channel + i.cc = cc + cc.i = i + cc.i.run = cc.i.args.MemoryConfig + cc.i.remoteName = string(init.Name[:]) + + return nil + } + } + } + + return fmt.Errorf("Invalid interface id") +} + +func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) { + if len(cc.i.regions) <= int(regionIndex) { + return fmt.Errorf("Invalid region index") + } + + addRegion := MsgAddRegion{ + Index: regionIndex, + Size: cc.i.regions[regionIndex].size, + } + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion) + err = binary.Write(buf, binary.LittleEndian, addRegion) + + msg := controlMsg{ + Buffer: 
buf, + Fd: cc.i.regions[regionIndex].fd, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseAddRegion() (err error) { + var addRegion MsgAddRegion + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &addRegion) + if err != nil { + return + } + + fd, err := cc.parseControlMsg() + if err != nil { + return fmt.Errorf("parseControlMsg: %s", err) + } + + if addRegion.Index > 255 { + return fmt.Errorf("Invalid memory region index") + } + + region := memoryRegion{ + size: addRegion.Size, + fd: fd, + } + + cc.i.regions = append(cc.i.regions, region) + + return nil +} + +func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) { + var q Queue + var flags uint16 = 0 + + if ringType == ringTypeS2M { + q = cc.i.txQueues[ringIndex] + flags = msgAddRingFlagS2M + } else { + q = cc.i.rxQueues[ringIndex] + } + + addRing := MsgAddRing{ + Index: ringIndex, + Offset: uint32(q.ring.offset), + Region: uint16(q.ring.region), + RingSizeLog2: uint8(q.ring.log2Size), + Flags: flags, + PrivateHdrSize: 0, + } + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing) + err = binary.Write(buf, binary.LittleEndian, addRing) + + msg := controlMsg{ + Buffer: buf, + Fd: q.interruptFd, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseAddRing() (err error) { + var addRing MsgAddRing + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &addRing) + if err != nil { + return + } + + fd, err := cc.parseControlMsg() + if err != nil { + return err + } + + if addRing.Index >= cc.i.run.NumQueuePairs { + return fmt.Errorf("invalid ring index") + } + + q := Queue{ + i: cc.i, + interruptFd: fd, + } + + if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M { + q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2)) + 
cc.i.rxQueues = append(cc.i.rxQueues, q) + } else { + q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2)) + cc.i.txQueues = append(cc.i.txQueues, q) + } + + return nil +} + +func (cc *controlChannel) msgEnqConnect() (err error) { + var connect MsgConnect + copy(connect.Name[:], []byte(cc.i.args.Name)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeConnect) + err = binary.Write(buf, binary.LittleEndian, connect) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseConnect() (err error) { + var connect MsgConnect + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &connect) + if err != nil { + return + } + + cc.i.peerName = string(connect.Name[:]) + + err = cc.i.connect() + if err != nil { + return err + } + + cc.isConnected = true + + return nil +} + +func (cc *controlChannel) msgEnqConnected() (err error) { + var connected MsgConnected + copy(connected.Name[:], []byte(cc.i.args.Name)) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeConnected) + err = binary.Write(buf, binary.LittleEndian, connected) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseConnected() (err error) { + var conn MsgConnected + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &conn) + if err != nil { + return + } + + cc.i.peerName = string(conn.Name[:]) + + err = cc.i.connect() + if err != nil { + return err + } + + cc.isConnected = true + + return nil +} + +func (cc *controlChannel) msgEnqDisconnect(str string) (err error) { + dc := MsgDisconnect{ + // not implemented + Code: 0, + } + copy(dc.String[:], str) + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect) 
+ err = binary.Write(buf, binary.LittleEndian, dc) + + msg := controlMsg{ + Buffer: buf, + Fd: -1, + } + + cc.msgQueue = append(cc.msgQueue, msg) + + return nil +} + +func (cc *controlChannel) parseDisconnect() (err error) { + var dc MsgDisconnect + + buf := bytes.NewReader(cc.data[msgTypeSize:]) + err = binary.Read(buf, binary.LittleEndian, &dc) + if err != nil { + return + } + + err = cc.close(false, string(dc.String[:])) + if err != nil { + return fmt.Errorf("Failed to disconnect control channel: ", err) + } + + return nil +} + +func (cc *controlChannel) parseMsg() error { + var msgType msgType + var err error + + buf := bytes.NewReader(cc.data[:]) + err = binary.Read(buf, binary.LittleEndian, &msgType) + + if msgType == msgTypeAck { + return nil + } else if msgType == msgTypeHello { + // Configure + err = cc.parseHello() + if err != nil { + goto error + } + // Initialize slave memif + err = cc.i.initializeRegions() + if err != nil { + goto error + } + err = cc.i.initializeQueues() + if err != nil { + goto error + } + // Enqueue messages + err = cc.msgEnqInit() + if err != nil { + goto error + } + for i := 0; i < len(cc.i.regions); i++ { + err = cc.msgEnqAddRegion(uint16(i)) + if err != nil { + goto error + } + } + for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ { + err = cc.msgEnqAddRing(ringTypeS2M, uint16(i)) + if err != nil { + goto error + } + } + for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ { + err = cc.msgEnqAddRing(ringTypeM2S, uint16(i)) + if err != nil { + goto error + } + } + err = cc.msgEnqConnect() + if err != nil { + goto error + } + } else if msgType == msgTypeInit { + err = cc.parseInit() + if err != nil { + goto error + } + + err = cc.msgEnqAck() + if err != nil { + goto error + } + } else if msgType == msgTypeAddRegion { + err = cc.parseAddRegion() + if err != nil { + goto error + } + + err = cc.msgEnqAck() + if err != nil { + goto error + } + } else if msgType == msgTypeAddRing { + err = cc.parseAddRing() + if err != nil { + goto 
error + } + + err = cc.msgEnqAck() + if err != nil { + goto error + } + } else if msgType == msgTypeConnect { + err = cc.parseConnect() + if err != nil { + goto error + } + + err = cc.msgEnqConnected() + if err != nil { + goto error + } + } else if msgType == msgTypeConnected { + err = cc.parseConnected() + if err != nil { + goto error + } + } else if msgType == msgTypeDisconnect { + err = cc.parseDisconnect() + if err != nil { + goto error + } + } else { + err = fmt.Errorf("unknown message %d", msgType) + goto error + } + + return nil + +error: + err1 := cc.close(true, err.Error()) + if err1 != nil { + return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1) + } + + return err +} + +// parseControlMsg parses control message and returns file descriptor +// if any +func (cc *controlChannel) parseControlMsg() (fd int, err error) { + // Assert only called when we require FD + fd = -1 + + controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen]) + if err != nil { + return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err) + } + + if len(controlMsgs) == 0 { + return -1, fmt.Errorf("Missing control message") + } + + for _, cmsg := range controlMsgs { + if cmsg.Header.Level == syscall.SOL_SOCKET { + if cmsg.Header.Type == syscall.SCM_RIGHTS { + FDs, err := syscall.ParseUnixRights(&cmsg) + if err != nil { + return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err) + } + if len(FDs) == 0 { + continue + } + // Only expect single FD + fd = FDs[0] + } + } + } + + if fd == -1 { + return -1, fmt.Errorf("Missing file descriptor") + } + + return fd, nil +} diff --git a/extras/gomemif/memif/control_channel_unsafe.go b/extras/gomemif/memif/control_channel_unsafe.go new file mode 100644 index 00000000000..9e91297b160 --- /dev/null +++ b/extras/gomemif/memif/control_channel_unsafe.go @@ -0,0 +1,60 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *------------------------------------------------------------------ + */ + +package memif + +import ( + "fmt" + "os" + "syscall" + "unsafe" +) + +// sendMsg sends a control message from contorl channels message queue +func (cc *controlChannel) sendMsg() (err error) { + if len(cc.msgQueue) < 1 { + return nil + } + // Get message buffer + msg := cc.msgQueue[0] + // Dequeue + cc.msgQueue = cc.msgQueue[1:] + + iov := &syscall.Iovec{ + Base: &msg.Buffer.Bytes()[0], + Len: msgSize, + } + + msgh := syscall.Msghdr{ + Iov: iov, + Iovlen: 1, + } + + if msg.Fd > 0 { + oob := syscall.UnixRights(msg.Fd) + msgh.Control = &oob[0] + msgh.Controllen = uint64(syscall.CmsgSpace(4)) + } + + _, _, errno := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(cc.event.Fd), uintptr(unsafe.Pointer(&msgh)), uintptr(0)) + if errno != 0 { + err = os.NewSyscallError("sendmsg", errno) + return fmt.Errorf("SYS_SENDMSG: %s", errno) + } + + return nil +} diff --git a/extras/gomemif/memif/interface.go b/extras/gomemif/memif/interface.go new file mode 100644 index 00000000000..a571deb43c9 --- /dev/null +++ b/extras/gomemif/memif/interface.go @@ -0,0 +1,507 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *------------------------------------------------------------------ + */ + +// Package memif provides the implementation of shared memory interface (memif). +// +// Memif network interfaces communicate using a UNIX domain socket. This socket +// must be first created using NewSocket(). Then interfaces can be added +// to this socket using NewInterface(). To start communication on each socket +// socket.StartPolling() must be called. socket.StopPolling() will stop +// the communication. When the interface changes link status, the Connected and +// Disconnected callbacks set in Arguments for each interface are called +// respectively. Once the interface is connected, rx and tx queues can be +// acquired using interface.GetRxQueue() and interface.GetTxQueue(). +// Packets can be received by calling queue.ReadPacket() on rx queues and +// transmitted by calling queue.WritePacket() on tx queues. If the interface is +// disconnected, queue.ReadPacket() and queue.WritePacket() MUST NOT be called. +// +// Data transmission is backed by shared memory. The driver works in +// promiscuous mode only. 
+package memif + +import ( + "container/list" + "fmt" + "os" + "syscall" +) + +const ( + DefaultSocketFilename = "/run/vpp/memif.sock" + DefaultNumQueuePairs = 1 + DefaultLog2RingSize = 10 + DefaultPacketBufferSize = 2048 +) + +const mfd_allow_sealing = 2 +const sys_memfd_create = 319 +const f_add_seals = 1033 +const f_seal_shrink = 0x0002 + +const efd_nonblock = 04000 + +// ConnectedFunc is a callback called when an interface is connected +type ConnectedFunc func(i *Interface) error + +// DisconnectedFunc is a callback called when an interface is disconnected +type DisconnectedFunc func(i *Interface) error + +// MemoryConfig represents shared memory configuration +type MemoryConfig struct { + NumQueuePairs uint16 // number of queue pairs + Log2RingSize uint8 // ring size as log2 + PacketBufferSize uint32 // size of single packet buffer +} + +// Arguments represent interface configuration +type Arguments struct { + Id uint32 // Interface identifier unique across socket. Used to identify peer interface when connecting + IsMaster bool // Interface role master/slave + Name string + Secret [24]byte // optional parameter, secrets of the interfaces must match if they are to connect + MemoryConfig MemoryConfig + ConnectedFunc ConnectedFunc // callback called when interface changes status to connected + DisconnectedFunc DisconnectedFunc // callback called when interface changes status to disconnected + PrivateData interface{} // private data used by client program +} + +// memoryRegion represents a shared memory mapped file +type memoryRegion struct { + data []byte + size uint64 + fd int + packetBufferOffset uint32 +} + +// Queue represents rx or tx queue +type Queue struct { + ring *ring + i *Interface + lastHead uint16 + lastTail uint16 + interruptFd int +} + +// Interface represents memif network interface +type Interface struct { + args Arguments + run MemoryConfig + privateData interface{} + listRef *list.Element + socket *Socket + cc *controlChannel + remoteName 
string + peerName string + regions []memoryRegion + txQueues []Queue + rxQueues []Queue +} + +// IsMaster returns true if the interfaces role is master, else returns false +func (i *Interface) IsMaster() bool { + return i.args.IsMaster +} + +// GetRemoteName returns the name of the application on which the peer +// interface exists +func (i *Interface) GetRemoteName() string { + return i.remoteName +} + +// GetPeerName returns peer interfaces name +func (i *Interface) GetPeerName() string { + return i.peerName +} + +// GetName returens interfaces name +func (i *Interface) GetName() string { + return i.args.Name +} + +// GetMemoryConfig returns interfaces active memory config. +// If interface is not connected the config is invalid. +func (i *Interface) GetMemoryConfig() MemoryConfig { + return i.run +} + +// GetRxQueue returns an rx queue specified by queue index +func (i *Interface) GetRxQueue(qid int) (*Queue, error) { + if qid >= len(i.rxQueues) { + return nil, fmt.Errorf("Invalid Queue index") + } + return &i.rxQueues[qid], nil +} + +// GetRxQueue returns a tx queue specified by queue index +func (i *Interface) GetTxQueue(qid int) (*Queue, error) { + if qid >= len(i.txQueues) { + return nil, fmt.Errorf("Invalid Queue index") + } + return &i.txQueues[qid], nil +} + +// GetEventFd returns queues interrupt event fd +func (q *Queue) GetEventFd() (int, error) { + return q.interruptFd, nil +} + +// GetFilename returns sockets filename +func (socket *Socket) GetFilename() string { + return socket.filename +} + +// close closes the queue +func (q *Queue) close() { + syscall.Close(q.interruptFd) +} + +// IsConnecting returns true if the interface is connecting +func (i *Interface) IsConnecting() bool { + if i.cc != nil { + return true + } + return false +} + +// IsConnected returns true if the interface is connected +func (i *Interface) IsConnected() bool { + if i.cc != nil && i.cc.isConnected { + return true + } + return false +} + +// Disconnect disconnects the 
interface +func (i *Interface) Disconnect() (err error) { + if i.cc != nil { + // close control and disconenct interface + return i.cc.close(true, "Interface disconnected") + } + return nil +} + +// disconnect finalizes interface disconnection +func (i *Interface) disconnect() (err error) { + if i.cc == nil { // disconnected + return nil + } + + err = i.args.DisconnectedFunc(i) + if err != nil { + return fmt.Errorf("DisconnectedFunc: ", err) + } + + for _, q := range i.txQueues { + q.close() + } + i.txQueues = []Queue{} + + for _, q := range i.rxQueues { + q.close() + } + i.rxQueues = []Queue{} + + // unmap regions + for _, r := range i.regions { + err = syscall.Munmap(r.data) + if err != nil { + return err + } + err = syscall.Close(r.fd) + if err != nil { + return err + } + } + i.regions = nil + i.cc = nil + + i.peerName = "" + i.remoteName = "" + + return nil +} + +// Delete deletes the interface +func (i *Interface) Delete() (err error) { + i.Disconnect() + + // remove referance on socket + i.socket.interfaceList.Remove(i.listRef) + i = nil + + return nil +} + +// GetSocket returns the socket the interface belongs to +func (i *Interface) GetSocket() *Socket { + return i.socket +} + +// GetPrivateDate returns interfaces private data +func (i *Interface) GetPrivateData() interface{} { + return i.args.PrivateData +} + +// GetId returns interfaces id +func (i *Interface) GetId() uint32 { + return i.args.Id +} + +// RoleToString returns 'Master' if isMaster os true, else returns 'Slave' +func RoleToString(isMaster bool) string { + if isMaster { + return "Master" + } + return "Slave" +} + +// RequestConnection is used by slave interface to connect to a socket and +// create a control channel +func (i *Interface) RequestConnection() error { + if i.IsMaster() { + return fmt.Errorf("Only slave can request connection") + } + // create socket + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + return fmt.Errorf("Failed to create UNIX 
domain socket: %v", err) + } + usa := &syscall.SockaddrUnix{Name: i.socket.filename} + + // Connect to listener socket + err = syscall.Connect(fd, usa) + if err != nil { + return fmt.Errorf("Failed to connect socket %s : %v", i.socket.filename, err) + } + + // Create control channel + i.cc, err = i.socket.addControlChannel(fd, i) + if err != nil { + return fmt.Errorf("Failed to create control channel: %v", err) + } + + return nil +} + +// NewInterface returns a new memif network interface. When creating an interface +// it's id must be unique across socket with the exception of loopback interface +// in which case the id is the same but role differs +func (socket *Socket) NewInterface(args *Arguments) (*Interface, error) { + var err error + // make sure the ID is unique on this socket + for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() { + i, ok := elt.Value.(*Interface) + if ok { + if i.args.Id == args.Id && i.args.IsMaster == args.IsMaster { + return nil, fmt.Errorf("Interface with id %u role %s already exists on this socket", args.Id, RoleToString(args.IsMaster)) + } + } + } + + // copy interface configuration + i := Interface{ + args: *args, + } + // set default values + if i.args.MemoryConfig.NumQueuePairs == 0 { + i.args.MemoryConfig.NumQueuePairs = DefaultNumQueuePairs + } + if i.args.MemoryConfig.Log2RingSize == 0 { + i.args.MemoryConfig.Log2RingSize = DefaultLog2RingSize + } + if i.args.MemoryConfig.PacketBufferSize == 0 { + i.args.MemoryConfig.PacketBufferSize = DefaultPacketBufferSize + } + + i.socket = socket + + // append interface to the list + i.listRef = socket.interfaceList.PushBack(&i) + + if i.args.IsMaster { + if socket.listener == nil { + err = socket.addListener() + if err != nil { + return nil, fmt.Errorf("Failed to create listener channel: %s", err) + } + } + } + + return &i, nil +} + +// eventFd returns an eventfd (SYS_EVENTFD2) +func eventFd() (efd int, err error) { + u_efd, _, errno := 
syscall.Syscall(syscall.SYS_EVENTFD2, uintptr(0), uintptr(efd_nonblock), 0) + if errno != 0 { + return -1, os.NewSyscallError("eventfd", errno) + } + return int(u_efd), nil +} + +// addRegions creates and adds a new memory region to the interface (slave only) +func (i *Interface) addRegion(hasPacketBuffers bool, hasRings bool) (err error) { + var r memoryRegion + + if hasRings { + r.packetBufferOffset = uint32((i.run.NumQueuePairs + i.run.NumQueuePairs) * (ringSize + descSize*(1<<i.run.Log2RingSize))) + } else { + r.packetBufferOffset = 0 + } + + if hasPacketBuffers { + r.size = uint64(r.packetBufferOffset + i.run.PacketBufferSize*uint32(1<<i.run.Log2RingSize)*uint32(i.run.NumQueuePairs+i.run.NumQueuePairs)) + } else { + r.size = uint64(r.packetBufferOffset) + } + + r.fd, err = memfdCreate() + if err != nil { + return err + } + + _, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(r.fd), uintptr(f_add_seals), uintptr(f_seal_shrink)) + if errno != 0 { + syscall.Close(r.fd) + return fmt.Errorf("memfdCreate: %s", os.NewSyscallError("fcntl", errno)) + } + + err = syscall.Ftruncate(r.fd, int64(r.size)) + if err != nil { + syscall.Close(r.fd) + r.fd = -1 + return fmt.Errorf("memfdCreate: %s", err) + } + + r.data, err = syscall.Mmap(r.fd, 0, int(r.size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + return fmt.Errorf("addRegion: %s", err) + } + + i.regions = append(i.regions, r) + + return nil +} + +// initializeRegions initializes interfaces regions (slave only) +func (i *Interface) initializeRegions() (err error) { + + err = i.addRegion(true, true) + if err != nil { + return fmt.Errorf("initializeRegions: %s", err) + } + + return nil +} + +// initializeQueues initializes interfaces queues (slave only) +func (i *Interface) initializeQueues() (err error) { + var q *Queue + var desc descBuf + var slot int + + desc = newDescBuf() + desc.setFlags(0) + desc.setRegion(0) + desc.setLength(int(i.run.PacketBufferSize)) + + for qid := 0; qid < 
int(i.run.NumQueuePairs); qid++ { + /* TX */ + q = &Queue{ + ring: i.newRing(0, ringTypeS2M, qid), + lastHead: 0, + lastTail: 0, + i: i, + } + q.ring.setCookie(cookie) + q.ring.setFlags(1) + q.interruptFd, err = eventFd() + if err != nil { + return err + } + q.putRing() + i.txQueues = append(i.txQueues, *q) + + for j := 0; j < q.ring.size; j++ { + slot = qid*q.ring.size + j + desc.setOffset(int(i.regions[0].packetBufferOffset + uint32(slot)*i.run.PacketBufferSize)) + q.putDescBuf(slot, desc) + } + } + for qid := 0; qid < int(i.run.NumQueuePairs); qid++ { + /* RX */ + q = &Queue{ + ring: i.newRing(0, ringTypeM2S, qid), + lastHead: 0, + lastTail: 0, + i: i, + } + q.ring.setCookie(cookie) + q.ring.setFlags(1) + q.interruptFd, err = eventFd() + if err != nil { + return err + } + q.putRing() + i.rxQueues = append(i.rxQueues, *q) + + for j := 0; j < q.ring.size; j++ { + slot = qid*q.ring.size + j + desc.setOffset(int(i.regions[0].packetBufferOffset + uint32(slot)*i.run.PacketBufferSize)) + q.putDescBuf(slot, desc) + } + } + + return nil +} + +// connect finalizes interface connection +func (i *Interface) connect() (err error) { + for rid, _ := range i.regions { + r := &i.regions[rid] + if r.data == nil { + r.data, err = syscall.Mmap(r.fd, 0, int(r.size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED) + if err != nil { + return fmt.Errorf("Mmap: %s", err) + } + } + } + + for _, q := range i.txQueues { + q.updateRing() + + if q.ring.getCookie() != cookie { + return fmt.Errorf("Wrong cookie") + } + + q.lastHead = 0 + q.lastTail = 0 + } + + for _, q := range i.rxQueues { + q.updateRing() + + if q.ring.getCookie() != cookie { + return fmt.Errorf("Wrong cookie") + } + + q.lastHead = 0 + q.lastTail = 0 + } + + return i.args.ConnectedFunc(i) +} diff --git a/extras/gomemif/memif/interface_unsafe.go b/extras/gomemif/memif/interface_unsafe.go new file mode 100644 index 00000000000..f5cbc2ed207 --- /dev/null +++ b/extras/gomemif/memif/interface_unsafe.go @@ -0,0 +1,40 @@ 
+/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *------------------------------------------------------------------ + */ + +package memif + +import ( + "fmt" + "os" + "syscall" + "unsafe" +) + +// memfdCreate returns memory file file descriptor (memif.sys_memfd_create) +func memfdCreate() (mfd int, err error) { + p0, err := syscall.BytePtrFromString("memif_region_0") + if err != nil { + return -1, fmt.Errorf("memfdCreate: %s", err) + } + + u_mfd, _, errno := syscall.Syscall(sys_memfd_create, uintptr(unsafe.Pointer(p0)), uintptr(mfd_allow_sealing), uintptr(0)) + if errno != 0 { + return -1, fmt.Errorf("memfdCreate: %s", os.NewSyscallError("memfd_create", errno)) + } + + return int(u_mfd), nil +} diff --git a/extras/gomemif/memif/memif.go b/extras/gomemif/memif/memif.go new file mode 100644 index 00000000000..1a7857de03e --- /dev/null +++ b/extras/gomemif/memif/memif.go @@ -0,0 +1,345 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
// cookie is the value stored in every ring header; connect() compares it
// against the value read back from shared memory.
const cookie = 0x3E31F20

const (
	// VersionMajor is memif protocols major version
	VersionMajor = 2

	// VersionMinor is memif protocols minor version
	VersionMinor = 0

	// Version is memif protocols version as uint16
	// (M-Major m-minor: MMMMMMMMmmmmmmmm)
	Version = (VersionMajor << 8) | VersionMinor
)

// msgType enumerates the control channel message types.
type msgType uint16

const (
	msgTypeNone msgType = iota
	msgTypeAck
	msgTypeHello
	msgTypeInit
	msgTypeAddRegion
	msgTypeAddRing
	msgTypeConnect
	msgTypeConnected
	msgTypeDisconnect
)

// interfaceMode enumerates the interface modes carried in MsgInit.
type interfaceMode uint8

const (
	interfaceModeEthernet interfaceMode = iota
	interfaceModeIp
	interfaceModePuntInject
)

const (
	msgSize     = 128
	msgTypeSize = 2

	msgAddRingFlagS2M = 1 << 0

	// Descriptor flags
	//
	// next buffer present
	descFlagNext = 1 << 0

	// Ring flags
	//
	// Interrupt
	ringFlagInterrupt = 1
)

// min16 returns the smaller of two uint16 values.
func min16(a uint16, b uint16) uint16 {
	if b <= a {
		return b
	}
	return a
}

// min8 returns the smaller of two uint8 values.
func min8(a uint8, b uint8) uint8 {
	if b <= a {
		return b
	}
	return a
}

// MsgHello is the body of a hello message.
type MsgHello struct {
	// app name
	Name            [32]byte
	VersionMin      uint16
	VersionMax      uint16
	MaxRegion       uint16
	MaxRingM2S      uint16
	MaxRingS2M      uint16
	MaxLog2RingSize uint8
}

// MsgInit is the body of an init message.
type MsgInit struct {
	Version uint16
	Id      uint32
	Mode    interfaceMode
	Secret  [24]byte
	// app name
	Name [32]byte
}

// MsgAddRegion is the body of an add-region message.
type MsgAddRegion struct {
	Index uint16
	Size  uint64
}
+type MsgAddRing struct { + Flags uint16 + Index uint16 + Region uint16 + Offset uint32 + RingSizeLog2 uint8 + PrivateHdrSize uint16 +} + +type MsgConnect struct { + // interface name + Name [32]byte +} + +type MsgConnected struct { + // interface name + Name [32]byte +} + +type MsgDisconnect struct { + Code uint32 + String [96]byte +} + +/* DESCRIPTOR BEGIN */ + +const descSize = 16 + +// desc field offsets +const descFlagsOffset = 0 +const descRegionOffset = 2 +const descLengthOffset = 4 +const descOffsetOffset = 8 +const descMetadataOffset = 12 + +// descBuf represents a memif descriptor as array of bytes +type descBuf []byte + +// newDescBuf returns new descriptor buffer +func newDescBuf() descBuf { + return make(descBuf, descSize) +} + +// getDescBuff copies descriptor from shared memory to descBuf +func (q *Queue) getDescBuf(slot int, db descBuf) { + copy(db, q.i.regions[q.ring.region].data[q.ring.offset+ringSize+slot*descSize:]) +} + +// putDescBuf copies contents of descriptor buffer into shared memory +func (q *Queue) putDescBuf(slot int, db descBuf) { + copy(q.i.regions[q.ring.region].data[q.ring.offset+ringSize+slot*descSize:], db) +} + +func (db descBuf) getFlags() int { + return (int)(binary.LittleEndian.Uint16((db)[descFlagsOffset:])) +} + +func (db descBuf) getRegion() int { + return (int)(binary.LittleEndian.Uint16((db)[descRegionOffset:])) +} + +func (db descBuf) getLength() int { + return (int)(binary.LittleEndian.Uint32((db)[descLengthOffset:])) +} + +func (db descBuf) getOffset() int { + return (int)(binary.LittleEndian.Uint32((db)[descOffsetOffset:])) +} + +func (db descBuf) getMetadata() int { + return (int)(binary.LittleEndian.Uint32((db)[descMetadataOffset:])) +} + +func (db descBuf) setFlags(val int) { + binary.LittleEndian.PutUint16((db)[descFlagsOffset:], uint16(val)) +} + +func (db descBuf) setRegion(val int) { + binary.LittleEndian.PutUint16((db)[descRegionOffset:], uint16(val)) +} + +func (db descBuf) setLength(val int) { + 
binary.LittleEndian.PutUint32((db)[descLengthOffset:], uint32(val)) +} + +func (db descBuf) setOffset(val int) { + binary.LittleEndian.PutUint32((db)[descOffsetOffset:], uint32(val)) +} + +func (db descBuf) setMetadata(val int) { + binary.LittleEndian.PutUint32((db)[descMetadataOffset:], uint32(val)) +} + +/* DESCRIPTOR END */ + +/* RING BEGIN */ + +type ringType uint8 + +const ( + ringTypeS2M ringType = iota + ringTypeM2S +) + +const ringSize = 128 + +// ring field offsets +const ringCookieOffset = 0 +const ringFlagsOffset = 4 +const ringHeadOffset = 6 +const ringTailOffset = 64 + +// ringBuf represents a memif ring as array of bytes +type ringBuf []byte + +type ring struct { + ringType ringType + size int + log2Size int + region int + rb ringBuf + offset int +} + +// newRing returns new memif ring based on data received in msgAddRing (master only) +func newRing(regionIndex int, ringType ringType, ringOffset int, log2RingSize int) *ring { + r := &ring{ + ringType: ringType, + size: (1 << log2RingSize), + log2Size: log2RingSize, + rb: make(ringBuf, ringSize), + offset: ringOffset, + } + + return r +} + +// newRing returns a new memif ring +func (i *Interface) newRing(regionIndex int, ringType ringType, ringIndex int) *ring { + r := &ring{ + ringType: ringType, + size: (1 << i.run.Log2RingSize), + log2Size: int(i.run.Log2RingSize), + rb: make(ringBuf, ringSize), + } + + rSize := ringSize + descSize*r.size + if r.ringType == ringTypeS2M { + r.offset = 0 + } else { + r.offset = int(i.run.NumQueuePairs) * rSize + } + r.offset += ringIndex * rSize + + return r +} + +// putRing put the ring to the shared memory +func (q *Queue) putRing() { + copy(q.i.regions[q.ring.region].data[q.ring.offset:], q.ring.rb) +} + +// updateRing updates ring with data from shared memory +func (q *Queue) updateRing() { + copy(q.ring.rb, q.i.regions[q.ring.region].data[q.ring.offset:]) +} + +func (r *ring) getCookie() int { + return (int)(binary.LittleEndian.Uint32((r.rb)[ringCookieOffset:])) 
+} + +// getFlags returns the flags value from ring buffer +// Use Queue.getFlags in fast-path to avoid updating the whole ring. +func (r *ring) getFlags() int { + return (int)(binary.LittleEndian.Uint16((r.rb)[ringFlagsOffset:])) +} + +// getHead returns the head pointer value from ring buffer. +// Use readHead in fast-path to avoid updating the whole ring. +func (r *ring) getHead() int { + return (int)(binary.LittleEndian.Uint16((r.rb)[ringHeadOffset:])) +} + +// getTail returns the tail pointer value from ring buffer. +// Use readTail in fast-path to avoid updating the whole ring. +func (r *ring) getTail() int { + return (int)(binary.LittleEndian.Uint16((r.rb)[ringTailOffset:])) +} + +func (r *ring) setCookie(val int) { + binary.LittleEndian.PutUint32((r.rb)[ringCookieOffset:], uint32(val)) +} + +func (r *ring) setFlags(val int) { + binary.LittleEndian.PutUint16((r.rb)[ringFlagsOffset:], uint16(val)) +} + +// setHead set the head pointer value int the ring buffer. +// Use writeHead in fast-path to avoid putting the whole ring into shared memory. +func (r *ring) setHead(val int) { + binary.LittleEndian.PutUint16((r.rb)[ringHeadOffset:], uint16(val)) +} + +// setTail set the tail pointer value int the ring buffer. +// Use writeTail in fast-path to avoid putting the whole ring into shared memory. 
+func (r *ring) setTail(val int) { + binary.LittleEndian.PutUint16((r.rb)[ringTailOffset:], uint16(val)) +} + +/* RING END */ + +// isInterrupt returns true if the queue is in interrupt mode +func (q *Queue) isInterrupt() bool { + return (q.getFlags() & ringFlagInterrupt) == 0 +} + +// interrupt performs an interrupt if the queue is in interrupt mode +func (q *Queue) interrupt() error { + if q.isInterrupt() { + buf := make([]byte, 8) + binary.PutUvarint(buf, 1) + n, err := syscall.Write(q.interruptFd, buf[:]) + if err != nil { + return err + } + if n != 8 { + return fmt.Errorf("Faild to write to eventfd") + } + } + + return nil +} diff --git a/extras/gomemif/memif/memif_unsafe.go b/extras/gomemif/memif/memif_unsafe.go new file mode 100644 index 00000000000..4469d26e982 --- /dev/null +++ b/extras/gomemif/memif/memif_unsafe.go @@ -0,0 +1,55 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *------------------------------------------------------------------ + */ + +package memif + +import ( + "unsafe" +) + +// readHead reads ring head directly form the shared memory +func (q *Queue) readHead() (head int) { + return (int)(*(*uint16)(unsafe.Pointer(&q.i.regions[q.ring.region].data[q.ring.offset+ringHeadOffset]))) + // return atomicload16(&q.i.regions[q.region].data[q.offset + descHeadOffset]) +} + +// readTail reads ring tail directly form the shared memory +func (q *Queue) readTail() (tail int) { + return (int)(*(*uint16)(unsafe.Pointer(&q.i.regions[q.ring.region].data[q.ring.offset+ringTailOffset]))) + // return atomicload16(&q.i.regions[q.region].data[q.offset + descTailOffset]) +} + +// writeHead writes ring head directly to the shared memory +func (q *Queue) writeHead(value int) { + *(*uint16)(unsafe.Pointer(&q.i.regions[q.ring.region].data[q.ring.offset+ringHeadOffset])) = *(*uint16)(unsafe.Pointer(&value)) + //atomicstore16(&q.i.regions[q.region].data[q.offset + descHeadOffset], value) +} + +// writeTail writes ring tail directly to the shared memory +func (q *Queue) writeTail(value int) { + *(*uint16)(unsafe.Pointer(&q.i.regions[q.ring.region].data[q.ring.offset+ringTailOffset])) = *(*uint16)(unsafe.Pointer(&value)) + //atomicstore16(&q.i.regions[q.region].data[q.offset + descTailOffset], value) +} + +func (q *Queue) setDescLength(slot int, length int) { + *(*uint16)(unsafe.Pointer(&q.i.regions[q.ring.region].data[q.ring.offset+ringSize+slot*descSize+descLengthOffset])) = *(*uint16)(unsafe.Pointer(&length)) +} + +// getFlags reads ring flags directly from the shared memory +func (q *Queue) getFlags() int { + return (int)(*(*uint16)(unsafe.Pointer(&q.i.regions[q.ring.region].data[q.ring.offset+ringFlagsOffset]))) +} diff --git a/extras/gomemif/memif/packet_reader.go b/extras/gomemif/memif/packet_reader.go new file mode 100644 index 00000000000..58338f6f2ab --- /dev/null +++ b/extras/gomemif/memif/packet_reader.go @@ -0,0 +1,91 @@ +/* + 
*------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *------------------------------------------------------------------ + */ + +package memif + +import "fmt" + +// ReadPacket reads one packet form the shared memory and +// returns the number of bytes read +func (q *Queue) ReadPacket(pkt []byte) (int, error) { + var mask int = q.ring.size - 1 + var slot int + var lastSlot int + var length int + var offset int + var pktOffset int = 0 + var nSlots uint16 + var desc descBuf = newDescBuf() + + if q.i.args.IsMaster { + slot = int(q.lastHead) + lastSlot = q.readHead() + } else { + slot = int(q.lastTail) + lastSlot = q.readTail() + } + + nSlots = uint16(lastSlot - slot) + if nSlots == 0 { + goto refill + } + + // copy descriptor from shm + q.getDescBuf(slot&mask, desc) + length = desc.getLength() + offset = desc.getOffset() + + copy(pkt[:], q.i.regions[desc.getRegion()].data[offset:offset+length]) + pktOffset += length + + slot++ + nSlots-- + + for (desc.getFlags() & descFlagNext) == descFlagNext { + if nSlots == 0 { + return 0, fmt.Errorf("Incomplete chained buffer, may suggest peer error.") + } + + q.getDescBuf(slot&mask, desc) + length = desc.getLength() + offset = desc.getOffset() + + copy(pkt[pktOffset:], q.i.regions[desc.getRegion()].data[offset:offset+length]) + pktOffset += length + + slot++ + nSlots-- + } + +refill: + if q.i.args.IsMaster { + 
q.lastHead = uint16(slot) + q.writeTail(slot) + } else { + q.lastTail = uint16(slot) + + head := q.readHead() + + for nSlots := uint16(q.ring.size - head + int(q.lastTail)); nSlots > 0; nSlots-- { + q.setDescLength(head&mask, int(q.i.run.PacketBufferSize)) + head++ + } + q.writeHead(head) + } + + return pktOffset, nil +} diff --git a/extras/gomemif/memif/packet_writer.go b/extras/gomemif/memif/packet_writer.go new file mode 100644 index 00000000000..702044f4b49 --- /dev/null +++ b/extras/gomemif/memif/packet_writer.go @@ -0,0 +1,95 @@ +/* + *------------------------------------------------------------------ + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *------------------------------------------------------------------ + */ + +package memif + +// WritePacket writes one packet to the shared memory and +// returns the number of bytes written +func (q *Queue) WritePacket(pkt []byte) int { + var mask int = q.ring.size - 1 + var slot int + var nFree uint16 + var packetBufferSize int = int(q.i.run.PacketBufferSize) + + if q.i.args.IsMaster { + slot = q.readTail() + nFree = uint16(q.readHead() - slot) + } else { + slot = q.readHead() + nFree = uint16(q.ring.size - slot + q.readTail()) + } + + if nFree == 0 { + q.interrupt() + return 0 + } + + // copy descriptor from shm + desc := newDescBuf() + q.getDescBuf(slot&mask, desc) + // reset flags + desc.setFlags(0) + // reset length + if q.i.args.IsMaster { + packetBufferSize = desc.getLength() + } + desc.setLength(0) + offset := desc.getOffset() + + // write packet into memif buffer + n := copy(q.i.regions[desc.getRegion()].data[offset:offset+packetBufferSize], pkt[:]) + desc.setLength(n) + for n < len(pkt) { + nFree-- + if nFree == 0 { + q.interrupt() + return 0 + } + desc.setFlags(descFlagNext) + q.putDescBuf(slot&mask, desc) + slot++ + + // copy descriptor from shm + q.getDescBuf(slot&mask, desc) + // reset flags + desc.setFlags(0) + // reset length + if q.i.args.IsMaster { + packetBufferSize = desc.getLength() + } + desc.setLength(0) + offset := desc.getOffset() + + tmp := copy(q.i.regions[desc.getRegion()].data[offset:offset+packetBufferSize], pkt[:]) + desc.setLength(tmp) + n += tmp + } + + // copy descriptor to shm + q.putDescBuf(slot&mask, desc) + slot++ + + if q.i.args.IsMaster { + q.writeTail(slot) + } else { + q.writeHead(slot) + } + + q.interrupt() + + return n +} |