summaryrefslogtreecommitdiffstats
path: root/src/vpp-api/python/vpp_papi/vpp_transport_socket.py
blob: 174ab74d0b83dbd497252008449187d03fb505fe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#
# VPP Unix Domain Socket Transport.
#
import socket
import struct
import threading
import select
import multiprocessing
import queue
import logging

# Module-level logger for the transport.  A NullHandler is attached, as the
# logging HOWTO recommends for libraries, so that applications which never
# configure logging do not get this module's records on stderr.
logger = logging.getLogger(name="vpp_papi.transport")
logger.addHandler(logging.NullHandler())


class VppTransportSocketIOError(IOError):
    """I/O error raised by the VPP Unix-socket transport.

    Normally constructed as ``VppTransportSocketIOError(errno, message)`` so
    that ``errno`` and ``strerror`` are populated.  Numbers used by
    ``VppTransport`` in this module: 1 for connection-state, send and
    receive failures; 2 for an unexpected ``select()`` result.  One
    handshake failure is raised with only a message (``errno`` is None).
    """
    # TODO: Formalise the meaning of the error numbers (first numeric argument).


class VppTransport:
    """Unix domain socket transport for the VPP binary API.

    Every message on the wire is preceded by a 16-byte big-endian header
    (``>QII``) whose middle field is the payload length.  After
    ``connect()``, a daemon thread running ``msg_thread_func`` reads
    messages off the socket.  It either queues them for ``read()`` or hands
    them to ``parent.msg_handler_async``.
    """

    VppTransportSocketIOError = VppTransportSocketIOError

    def __init__(self, parent, read_timeout, server_address):
        """Create a disconnected transport.

        :param parent: owning API object.  This class uses its ``messages``,
            ``header``, ``api``, ``VPPApiError``, ``has_context`` and
            ``msg_handler_async`` attributes.
        :param read_timeout: timeout in seconds for socket and queue reads.
            Values ``<= 0`` mean "block forever" and are stored as ``None``.
        :param server_address: address passed to ``socket.connect`` on an
            ``AF_UNIX`` stream socket.
        """
        self.connected = False
        self.read_timeout = read_timeout if read_timeout > 0 else None
        self.parent = parent
        self.server_address = server_address
        # Framing header: (unused, payload length, unused).
        self.header = struct.Struct(">QII")
        self.message_table = {}
        # These queues can be accessed async.
        # They are always up, but replaced on connect.
        # TODO: Use multiprocessing.Pipe instead of multiprocessing.Queue
        # if possible.
        self.sque = multiprocessing.Queue()
        self.q = multiprocessing.Queue()
        # The following fields are set in connect().
        self.message_thread = None
        self.socket = None
        # Server-side client index from sockclnt_create_reply.  None until
        # the handshake succeeds, so disconnect() can tell whether there is
        # anything to delete.
        self.socket_index = None

    def msg_thread_func(self):
        """Receive loop executed by ``self.message_thread``.

        Waits on the socket and on the termination queue ``self.sque``.  A
        complete message read from the socket is put on ``self.q`` when not
        in async mode and ``parent.has_context(msg)`` is true.  Otherwise it
        is passed to ``parent.msg_handler_async``.  Before the thread exits,
        ``None`` is put on ``self.q`` so that a waiting ``read()`` returns.
        """
        while True:
            try:
                # sque._reader is the private read end of the
                # multiprocessing.Queue.  It is used here to make the queue
                # selectable.
                rlist, _, _ = select.select([self.socket, self.sque._reader], [], [])
            except (socket.error, ValueError):
                # Typically the socket was closed underneath us; terminate.
                logger.error("select failed")
                self.q.put(None)
                return

            for r in rlist:
                if r == self.sque._reader:
                    # Any item on sque is a termination request.
                    self.q.put(None)
                    return

                elif r == self.socket:
                    try:
                        msg = self._read()
                        if not msg:
                            self.q.put(None)
                            return
                    except socket.error:
                        self.q.put(None)
                        return
                    # Put either to local queue or if context == 0
                    # callback queue
                    if not self.do_async and self.parent.has_context(msg):
                        self.q.put(msg)
                    else:
                        self.parent.msg_handler_async(msg)
                else:
                    # Wake any reader before the thread dies.
                    self.q.put(None)
                    raise VppTransportSocketIOError(2, "Unknown response from select")

    def connect(self, name, pfx, msg_handler, rx_qlen, do_async=False):
        """Connect to VPP, perform the sockclnt_create handshake and start
        the receive thread.

        ``pfx``, ``msg_handler`` and ``rx_qlen`` are accepted for interface
        compatibility.  This transport does not use them.

        :param name: client name sent in ``sockclnt_create``.
        :param do_async: if true, every received message is dispatched to
            ``parent.msg_handler_async``.  If false, only messages for which
            ``parent.has_context`` is true are queued for ``read()``.
        :returns: 0 on success.
        :raises VppTransportSocketIOError: if a previous connection was not
            disconnected, if the connection closes during the handshake, or
            if the handshake reply has an unexpected message id.
        :raises OSError: if the socket cannot connect to ``server_address``.
        """
        # TODO: Reorder the actions and add "roll-backs",
        # to restore clean disconnect state when failure happens during connect.

        if self.message_thread is not None:
            raise VppTransportSocketIOError(
                1, "PAPI socket transport connect: Need to disconnect first."
            )

        # Never carry a server-side index over from a previous session.
        self.socket_index = None

        # Create a UDS socket
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.settimeout(self.read_timeout)

        # Connect the socket to the port where the server is listening
        try:
            self.socket.connect(self.server_address)
        except socket.error:
            # Do not leak the descriptor of a socket that never connected.
            self.socket.close()
            self.socket = None
            raise

        self.connected = True

        # Queues' feeder threads from previous connect may still be sending.
        # Close and join to avoid any errors.
        self.sque.close()
        self.q.close()
        self.sque.join_thread()
        self.q.join_thread()
        # Finally safe to replace.
        self.sque = multiprocessing.Queue()
        self.q = multiprocessing.Queue()
        self.message_thread = threading.Thread(target=self.msg_thread_func)

        # Initialise sockclnt_create
        sockclnt_create = self.parent.messages["sockclnt_create"]
        sockclnt_create_reply = self.parent.messages["sockclnt_create_reply"]

        # The message table is only known after this reply is parsed, so the
        # handshake uses hard-coded message ids (15 request, 16 reply).
        args = {"_vl_msg_id": 15, "name": name, "context": 124}
        self.write(sockclnt_create.pack(args))
        msg = self._read()
        if not msg:
            raise VppTransportSocketIOError(
                1, "Connection closed during sockclnt_create"
            )
        hdr, _ = self.parent.header.unpack(msg, 0)
        if hdr.msgid != 16:
            # TODO: Add first numeric argument.
            raise VppTransportSocketIOError("Invalid reply message")

        r, _ = sockclnt_create_reply.unpack(msg)
        self.socket_index = r.index
        for m in r.message_table:
            self.message_table[m.name] = m.index

        self.message_thread.daemon = True
        self.do_async = do_async
        self.message_thread.start()

        return 0

    def disconnect(self):
        """Tear down the connection and stop the receive thread.

        Sends ``sockclnt_delete`` on a best-effort basis and ignores
        failures.  Then it closes the socket, signals and joins the receive
        thread, and clears the message table so that ``connect()`` can be
        called again.

        :returns: the result of ``sockclnt_delete``, or 0 if it was not sent
            or failed.
        """
        # TODO: Support repeated disconnect calls, recommend users to call
        # disconnect when they are not sure what the state is after failures.
        rv = 0
        # socket_index is None if the handshake never completed.  There is no
        # server-side client to delete then.
        if self.socket_index is not None:
            try:
                # Might fail, if VPP closes socket before packet makes it out,
                # or if there was a failure during connect().
                # TODO: manually build message so that .disconnect releases
                # server-side resources
                rv = self.parent.api.sockclnt_delete(index=self.socket_index)
            except (IOError, self.parent.VPPApiError):
                pass
        self.connected = False
        if self.socket is not None:
            self.socket.close()
        if self.sque is not None:
            self.sque.put(True)  # Terminate listening thread
        if self.message_thread is not None and self.message_thread.is_alive():
            # Allow additional connect() calls.
            self.message_thread.join()
        # Wipe message table, VPP can be restarted with different plugins.
        self.message_table = {}
        # Collect garbage.
        self.message_thread = None
        self.socket = None
        # Queues will be collected after connect replaces them.
        return rv

    def suspend(self):
        """No-op for this transport."""
        pass

    def resume(self):
        """No-op for this transport."""
        pass

    def callback(self):
        """Not supported by the socket transport."""
        raise NotImplementedError

    def get_callback(self, do_async):
        """Return ``self.callback``.  ``do_async`` is ignored."""
        return self.callback

    def get_msg_index(self, name):
        """Return the message index for ``name``, or 0 if it is unknown."""
        return self.message_table.get(name, 0)

    def msg_table_max_index(self):
        """Return the number of entries in the message table."""
        return len(self.message_table)

    def write(self, buf):
        """Send a binary-packed message to VPP, prefixed with its header.

        :raises VppTransportSocketIOError: if not connected or sending fails.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, "Not connected")

        # Send header
        header = self.header.pack(0, len(buf), 0)
        try:
            self.socket.sendall(header)
            self.socket.sendall(buf)
        except socket.error as err:
            raise VppTransportSocketIOError(
                1, "Sendall error: {err!r}".format(err=err)
            ) from err

    def _read_fixed(self, size):
        """Receive exactly ``size`` bytes into a new ``bytearray``.

        Returns ``b""`` if the peer closes the connection first.  Socket
        errors, including timeouts, propagate to the caller.
        """
        buf = bytearray(size)
        view = memoryview(buf)
        left = size
        while True:
            got = self.socket.recv_into(view, left)
            if got <= 0:
                # Connection closed (or nothing could be read).
                return b""
            if got >= left:
                # TODO: Raise if got > left?
                break
            left -= got
            view = view[got:]
        return buf

    def _read(self):
        """Read one complete framed message.

        :returns: the payload, or None if the connection closed before a
            full header was received.
        :raises VppTransportSocketIOError: if the payload is truncated.
        """
        hdr = self._read_fixed(16)
        if not hdr:
            return
        (_, hdrlen, _) = self.header.unpack(hdr)  # If at head of message

        # Read rest of message
        msg = self._read_fixed(hdrlen)
        if hdrlen == len(msg):
            return msg
        raise VppTransportSocketIOError(1, "Unknown socket read error")

    def read(self, timeout=None):
        """Return the next queued reply.

        :param timeout: seconds to wait.  Defaults to ``self.read_timeout``.
        :returns: the message, or None on timeout or after the receive
            thread has stopped.
        :raises VppTransportSocketIOError: if not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, "Not connected")
        if timeout is None:
            timeout = self.read_timeout
        try:
            return self.q.get(True, timeout)
        except queue.Empty:
            return None