summaryrefslogtreecommitdiffstats
path: root/test/hook.py
blob: 97b05d05153941d6169e3dea3282aa31fcc6d09f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os
import sys
import traceback
import ipaddress
from subprocess import check_output, CalledProcessError

import scapy.compat
import framework
from log import RED, single_line_delim, double_line_delim
from util import check_core_path, get_core_path


class Hook(object):
    """
    Base hook invoked around the API/CLI calls issued by a test case

    The default implementation only logs what is about to be executed;
    the "after" callbacks are intentionally empty extension points.
    """

    def __init__(self, test):
        # Remember the owning test case and reuse its logger for output.
        self.test = test
        self.logger = test.logger

    def before_api(self, api_name, api_args):
        """
        Callback run right before an API call
        Logs the API name together with its arguments at debug level

        @param api_name: name of the API
        @param api_args: mapping of argument name to argument value
                         (it is walked with .items())
        """

        def _friendly_format(val):
            # Only str values get decorated; everything else is logged as-is.
            if isinstance(val, str):
                if len(val) == 6:
                    # Six elements: append a colon separated hex rendering
                    # (MAC address notation) next to the raw value.
                    hex_parts = ['{:02x}'.format(scapy.compat.orb(ch))
                                 for ch in val]
                    return '{!s} ({!s})'.format(val, ':'.join(hex_parts))
                try:
                    # we don't call test_type(val) because it is a packed
                    # value.
                    parsed = ipaddress.ip_address(val)
                except ValueError:
                    # Not an IP address - keep the value untouched.
                    return val
                return '{!s} ({!s})'.format(val, str(parsed))
            return val

        rendered = []
        for key, val in api_args.items():
            rendered.append("{!s}={!r}".format(key, _friendly_format(val)))
        message = "API: %s (%s)" % (api_name, ', '.join(rendered))
        self.logger.debug(message, extra={'color': RED})

    def after_api(self, api_name, api_args):
        """
        Callback run right after an API call (no-op in the base class)

        @param api_name: name of the API
        @param api_args: the arguments the API was called with
        """
        return None

    def before_cli(self, cli):
        """
        Callback run right before a CLI call
        Logs the CLI string at debug level

        @param cli: CLI string
        """
        message = "CLI: %s" % (cli)
        self.logger.debug(message, extra={'color': RED})

    def after_cli(self, cli):
        """
        Callback run right after a CLI call (no-op in the base class)

        @param cli: CLI string
        """
        return None


class PollHook(Hook):
    """ Hook which checks if the vpp subprocess is alive """

    def __init__(self, test):
        super(PollHook, self).__init__(test)

    def on_crash(self, core_path):
        """
        Log diagnostics for a core file left behind by VPP

        Emits a ready-to-use gdb command line, hands the path to
        check_core_path() and logs the output of the `file' utility run on
        the core. Failures of the `file' utility are logged, never raised.

        :param core_path: path of the core file
        """
        self.logger.error("Core file present, debug with: gdb %s %s",
                          self.test.vpp_bin, core_path)
        check_core_path(self.logger, core_path)
        self.logger.error("Running `file %s':", core_path)
        try:
            info = check_output(["file", core_path])
            self.logger.error(info)
        except CalledProcessError as e:
            # `file' ran but exited with a non-zero status
            self.logger.error(
                "Subprocess returned with error running `file' utility on "
                "core-file, "
                "rc=%s", e.returncode)
        except OSError as e:
            # `file' could not be executed at all
            self.logger.error(
                "Subprocess returned OS error running `file' utility on "
                "core-file, "
                "oserror=(%s) %s", e.errno, e.strerror)
        except Exception as e:
            # best effort only: diagnostics must not break the test run
            self.logger.error(
                "Subprocess returned unanticipated error running `file' "
                "utility on core-file, "
                "%s", e)

    def poll_vpp(self):
        """
        Poll the vpp status and throw an exception if it's not running

        When VPP is found dead, the test is flagged (vpp_dead) and, if a
        core file exists in the test's temporary directory, crash
        diagnostics are logged before the exception is raised.

        :raises VppDiedError: exception if VPP is not running anymore
        """
        if self.test.vpp_dead:
            # already dead, nothing to do
            return

        self.test.vpp.poll()
        returncode = self.test.vpp.returncode
        if returncode is None:
            # still running
            return

        # Flag the death first so later hooks do not poll again.
        self.test.vpp_dead = True
        # The core-file inspection has to happen before the raise; placed
        # after it (as it used to be) the code was unreachable.
        core_path = get_core_path(self.test.tempdir)
        if os.path.isfile(core_path):
            self.on_crash(core_path)
        raise framework.VppDiedError(rv=returncode)

    def before_api(self, api_name, api_args):
        """
        Check if VPP died before executing an API

        :param api_name: name of the API
        :param api_args: the API arguments (logged by the base class)
        :raises VppDiedError: exception if VPP is not running anymore

        """
        super(PollHook, self).before_api(api_name, api_args)
        self.poll_vpp()

    def before_cli(self, cli):
        """
        Check if VPP died before executing a CLI

        :param cli: CLI string
        :raises VppDiedError: exception if VPP is not running anymore

        """
        super(PollHook, self).before_cli(cli)
        self.poll_vpp()


class StepHook(PollHook):
    """ Hook which requires user to press ENTER before doing any API/CLI """

    def __init__(self, test):
        # Stack captured when the user picked a frame (None: no skipping)
        self.skip_stack = None
        # Index of the frame picked by the user within skip_stack
        self.skip_num = None
        # Number of calls skipped since the frame was picked
        self.skip_count = 0
        super(StepHook, self).__init__(test)

    def skip(self):
        """
        Decide whether the pause for the current call should be skipped

        The current stack is compared, frame by frame (file name and line
        number), with the stack recorded by user_input(), up to and
        including the chosen frame. While they agree the call is counted
        and skipped; on the first disagreement the skip state is reset.

        :returns: True if the pause should be skipped, False otherwise
        """
        if self.skip_stack is None:
            return False

        recorded = self.skip_stack[:self.skip_num + 1]
        current = traceback.extract_stack()
        unchanged = all(
            now[0] == then[0] and now[1] == then[1]
            for now, then in zip(current, recorded))

        if unchanged:
            self.skip_count += 1
            return True

        print("%d API/CLI calls skipped in specified stack "
              "frame" % self.skip_count)
        self.skip_count = 0
        self.skip_stack = None
        self.skip_num = None
        return False

    def user_input(self):
        """
        Show the current stack and wait for the user to press ENTER

        An empty line just continues. A valid frame number additionally
        records the stack and the chosen index so that skip() can suppress
        further pauses; invalid input is reported and asked for again.
        """
        print('number\tfunction\tfile\tcode')
        stack = traceback.extract_stack()
        for index, frame in enumerate(stack):
            print('%02d.\t%s\t%s:%d\t[%s]' % (
                index, frame[2], frame[0], frame[1], frame[3]))
        print(single_line_delim)
        print("You may enter a number of stack frame chosen from above")
        print("Calls in/below that stack frame will be not be stepped anymore")
        print(single_line_delim)

        while True:
            print("Enter your choice, if any, and press ENTER to continue "
                  "running the testcase...")
            answer = sys.stdin.readline().rstrip('\r\n')
            if answer == "":
                # plain ENTER: nothing selected, keep stepping
                return
            try:
                selected = int(answer)
            except ValueError:
                print("Invalid input")
                continue
            if 0 <= selected < len(stack):
                break
            print("Invalid choice")

        self.skip_stack = stack
        self.skip_num = selected

    def before_cli(self, cli):
        """ Wait for ENTER before executing CLI """
        if not self.skip():
            print(double_line_delim)
            print("Test paused before executing CLI: %s" % cli)
            print(single_line_delim)
            self.user_input()
        else:
            print("Skip pause before executing CLI: %s" % cli)
        super(StepHook, self).before_cli(cli)

    def before_api(self, api_name, api_args):
        """ Wait for ENTER before executing API """
        if not self.skip():
            print(double_line_delim)
            print("Test paused before executing API: %s (%s)"
                  % (api_name, api_args))
            print(single_line_delim)
            self.user_input()
        else:
            print("Skip pause before executing API: %s (%s)"
                  % (api_name, api_args))
        super(StepHook, self).before_api(api_name, api_args)
">net_template.format(self.count) class L4_Conn(): """ L4 'connection' tied to two VPP interfaces """ def __init__(self, testcase, if1, if2, af, l4proto, port1, port2): self.testcase = testcase self.ifs = [None, None] self.ifs[0] = if1 self.ifs[1] = if2 self.address_family = af self.l4proto = l4proto self.ports = [None, None] self.ports[0] = port1 self.ports[1] = port2 self def pkt(self, side, l4args={}, payload="x"): is_ip6 = 1 if self.address_family == AF_INET6 else 0 s0 = side s1 = 1 - side src_if = self.ifs[s0] dst_if = self.ifs[s1] layer_3 = [IP(src=src_if.remote_ip4, dst=dst_if.remote_ip4), IPv6(src=src_if.remote_ip6, dst=dst_if.remote_ip6)] merged_l4args = {'sport': self.ports[s0], 'dport': self.ports[s1]} merged_l4args.update(l4args) p = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) / layer_3[is_ip6] / self.l4proto(**merged_l4args) / Raw(payload)) return p def send(self, side, flags=None, payload=""): l4args = {} if flags is not None: l4args['flags'] = flags self.ifs[side].add_stream(self.pkt(side, l4args=l4args, payload=payload)) self.ifs[1 - side].enable_capture() self.testcase.pg_start() def recv(self, side): p = self.ifs[side].wait_for_packet(1) return p def send_through(self, side, flags=None, payload=""): self.send(side, flags, payload) p = self.recv(1 - side) return p def send_pingpong(self, side, flags1=None, flags2=None): p1 = self.send_through(side, flags1) p2 = self.send_through(1 - side, flags2) return [p1, p2] class L4_CONN_SIDE: L4_CONN_SIDE_ZERO = 0 L4_CONN_SIDE_ONE = 1 class LoggerWrapper(object): def __init__(self, logger=None): self._logger = logger def debug(self, *args, **kwargs): if self._logger: self._logger.debug(*args, **kwargs) def error(self, *args, **kwargs): if self._logger: self._logger.error(*args, **kwargs) def fragment_rfc791(packet, fragsize, _logger=None): """ Fragment an IPv4 packet per RFC 791 :param packet: packet to fragment :param fragsize: size at which to fragment :note: IP options are not supported :returns: 
list of fragments """ logger = LoggerWrapper(_logger) logger.debug(ppp("Fragmenting packet:", packet)) packet = packet.__class__(scapy.compat.raw(packet)) # recalc. all values if len(packet[IP].options) > 0: raise Exception("Not implemented") if len(packet) <= fragsize: return [packet] pre_ip_len = len(packet) - len(packet[IP]) ip_header_len = packet[IP].ihl * 4 hex_packet = scapy.compat.raw(packet) hex_headers = hex_packet[:(pre_ip_len + ip_header_len)] hex_payload = hex_packet[(pre_ip_len + ip_header_len):] pkts = [] ihl = packet[IP].ihl otl = len(packet[IP]) nfb = (fragsize - pre_ip_len - ihl * 4) / 8 fo = packet[IP].frag p = packet.__class__(hex_headers + hex_payload[:nfb * 8]) p[IP].flags = "MF" p[IP].frag = fo p[IP].len = ihl * 4 + nfb * 8 del p[IP].chksum pkts.append(p) p = packet.__class__(hex_headers + hex_payload[nfb * 8:]) p[IP].len = otl - nfb * 8 p[IP].frag = fo + nfb del p[IP].chksum more_fragments = fragment_rfc791(p, fragsize, _logger) pkts.extend(more_fragments) return pkts def fragment_rfc8200(packet, identification, fragsize, _logger=None): """ Fragment an IPv6 packet per RFC 8200 :param packet: packet to fragment :param fragsize: size at which to fragment :note: IP options are not supported :returns: list of fragments """ logger = LoggerWrapper(_logger) packet = packet.__class__(scapy.compat.raw(packet)) # recalc. all values if len(packet) <= fragsize: return [packet] logger.debug(ppp("Fragmenting packet:", packet)) pkts = [] counter = 0 routing_hdr = None hop_by_hop_hdr = None upper_layer = None seen_ipv6 = False ipv6_nr = -1 l = packet.getlayer(counter) while l is not None: if l.__class__ is IPv6: if seen_ipv6: # ignore 2nd IPv6 header and everything below.. 
break ipv6_nr = counter seen_ipv6 = True elif l.__class__ is IPv6ExtHdrFragment: raise Exception("Already fragmented") elif l.__class__ is IPv6ExtHdrRouting: routing_hdr = counter elif l.__class__ is IPv6ExtHdrHopByHop: hop_by_hop_hdr = counter elif seen_ipv6 and not upper_layer and \ not l.__class__.__name__.startswith('IPv6ExtHdr'): upper_layer = counter counter = counter + 1 l = packet.getlayer(counter) logger.debug( "Layers seen: IPv6(#%s), Routing(#%s), HopByHop(#%s), upper(#%s)" % (ipv6_nr, routing_hdr, hop_by_hop_hdr, upper_layer)) if upper_layer is None: raise Exception("Upper layer header not found in IPv6 packet") last_per_fragment_hdr = ipv6_nr if routing_hdr is None: if hop_by_hop_hdr is not None: last_per_fragment_hdr = hop_by_hop_hdr else: last_per_fragment_hdr = routing_hdr logger.debug("Last per-fragment hdr is #%s" % (last_per_fragment_hdr)) per_fragment_headers = packet.copy() per_fragment_headers[last_per_fragment_hdr].remove_payload() logger.debug(ppp("Per-fragment headers:", per_fragment_headers)) ext_and_upper_layer = packet.getlayer(last_per_fragment_hdr)[1] hex_payload = scapy.compat.raw(ext_and_upper_layer) logger.debug("Payload length is %s" % len(hex_payload)) logger.debug(ppp("Ext and upper layer:", ext_and_upper_layer)) fragment_ext_hdr = IPv6ExtHdrFragment() logger.debug(ppp("Fragment header:", fragment_ext_hdr)) if len(per_fragment_headers) + len(fragment_ext_hdr) +\ len(ext_and_upper_layer) - len(ext_and_upper_layer.payload)\ > fragsize: raise Exception("Cannot fragment this packet - MTU too small " "(%s, %s, %s, %s, %s)" % ( len(per_fragment_headers), len(fragment_ext_hdr), len(ext_and_upper_layer), len(ext_and_upper_layer.payload), fragsize)) orig_nh = packet[IPv6].nh p = per_fragment_headers del p[IPv6].plen del p[IPv6].nh p = p / fragment_ext_hdr del p[IPv6ExtHdrFragment].nh first_payload_len_nfb = (fragsize - len(p)) / 8 p = p / Raw(hex_payload[:first_payload_len_nfb * 8]) del p[IPv6].plen p[IPv6ExtHdrFragment].nh = orig_nh 
p[IPv6ExtHdrFragment].id = identification p[IPv6ExtHdrFragment].offset = 0 p[IPv6ExtHdrFragment].m = 1 p = p.__class__(scapy.compat.raw(p)) logger.debug(ppp("Fragment %s:" % len(pkts), p)) pkts.append(p) offset = first_payload_len_nfb * 8 logger.debug("Offset after first fragment: %s" % offset) while len(hex_payload) > offset: p = per_fragment_headers del p[IPv6].plen del p[IPv6].nh p = p / fragment_ext_hdr del p[IPv6ExtHdrFragment].nh l_nfb = (fragsize - len(p)) / 8 p = p / Raw(hex_payload[offset:offset + l_nfb * 8]) p[IPv6ExtHdrFragment].nh = orig_nh p[IPv6ExtHdrFragment].id = identification p[IPv6ExtHdrFragment].offset = offset / 8 p[IPv6ExtHdrFragment].m = 1 p = p.__class__(scapy.compat.raw(p)) logger.debug(ppp("Fragment %s:" % len(pkts), p)) pkts.append(p) offset = offset + l_nfb * 8 pkts[-1][IPv6ExtHdrFragment].m = 0 # reset more-flags in last fragment return pkts def reassemble4_core(listoffragments, return_ip): buffer = BytesIO() first = listoffragments[0] buffer.seek(20) for pkt in listoffragments: buffer.seek(pkt[IP].frag*8) buffer.write(bytes(pkt[IP].payload)) first.len = len(buffer.getvalue()) + 20 first.flags = 0 del(first.chksum) if return_ip: header = bytes(first[IP])[:20] return first[IP].__class__(header + buffer.getvalue()) else: header = bytes(first[Ether])[:34] return first[Ether].__class__(header + buffer.getvalue()) def reassemble4_ether(listoffragments): return reassemble4_core(listoffragments, False) def reassemble4(listoffragments): return reassemble4_core(listoffragments, True)