From 6955595a577e1b7d316b5b69267bf1d1d951a4ab Mon Sep 17 00:00:00 2001 From: Paul Vinciguerra Date: Fri, 1 Mar 2019 08:46:29 -0800 Subject: Tests: Example duplicate code refactoring. Refactor of duplicate methods in test/test_reassembly.py. Change-Id: I46f880da6a0ced2acae1fa33c6892d0148b26139 Signed-off-by: Paul Vinciguerra --- test/test_reassembly.py | 516 +++++++++++++++++++++++------------------------- 1 file changed, 243 insertions(+), 273 deletions(-) diff --git a/test/test_reassembly.py b/test/test_reassembly.py index 7bca794c5c0..aee67b185cc 100644 --- a/test/test_reassembly.py +++ b/test/test_reassembly.py @@ -1,17 +1,19 @@ #!/usr/bin/env python +from random import shuffle import six import unittest -from random import shuffle - -from framework import VppTestCase, VppTestRunner +from parameterized import parameterized from scapy.packet import Raw from scapy.layers.l2 import Ether, GRE from scapy.layers.inet import IP, UDP, ICMP -from util import ppp, fragment_rfc791, fragment_rfc8200 + from scapy.layers.inet6 import IPv6, IPv6ExtHdrFragment, ICMPv6ParamProblem,\ ICMPv6TimeExceeded + +from framework import VppTestCase, VppTestRunner +from util import ppp, fragment_rfc791, fragment_rfc8200 from vpp_gre_interface import VppGreInterface, VppGre6Interface from vpp_ip import DpoProto from vpp_ip_route import VppIpRoute, VppRoutePath @@ -19,8 +21,181 @@ from vpp_ip_route import VppIpRoute, VppRoutePath # 35 is enough to have >257 400-byte fragments test_packet_count = 35 +# +# +_scapy_ip_family_types = (IP, IPv6) + + +def validate_scapy_ip_family(scapy_ip_family): + + if scapy_ip_family not in _scapy_ip_family_types: + raise ValueError("'scapy_ip_family' must be of type: %s. Got %s" % + (_scapy_ip_family_types, scapy_ip_family)) + + +class TestIPReassemblyMixin(object): + + def verify_capture(self, scapy_ip_family, capture, + dropped_packet_indexes=None): + """Verify captured packet stream. + + :param list capture: Captured packet stream. + """ + validate_scapy_ip_family(scapy_ip_family) + + if dropped_packet_indexes is None: + dropped_packet_indexes = [] + info = None + seen = set() + for packet in capture: + try: + self.logger.debug(ppp("Got packet:", packet)) + ip = packet[scapy_ip_family] + udp = packet[UDP] + payload_info = self.payload_to_info(str(packet[Raw])) + packet_index = payload_info.index + self.assertTrue( + packet_index not in dropped_packet_indexes, + ppp("Packet received, but should be dropped:", packet)) + if packet_index in seen: + raise Exception(ppp("Duplicate packet received", packet)) + seen.add(packet_index) + self.assertEqual(payload_info.dst, self.src_if.sw_if_index) + info = self._packet_infos[packet_index] + self.assertTrue(info is not None) + self.assertEqual(packet_index, info.index) + saved_packet = info.data + self.assertEqual(ip.src, saved_packet[scapy_ip_family].src) + self.assertEqual(ip.dst, saved_packet[scapy_ip_family].dst) + self.assertEqual(udp.payload, saved_packet[UDP].payload) + except Exception: + self.logger.error(ppp("Unexpected or invalid packet:", packet)) + raise + for index in self._packet_infos: + self.assertTrue(index in seen or index in dropped_packet_indexes, + "Packet with packet_index %d not received" % index) + + def test_disabled(self, scapy_ip_family, stream, + dropped_packet_indexes): + """ reassembly disabled """ + validate_scapy_ip_family(scapy_ip_family) + is_ip6 = 1 if scapy_ip_family == IPv6 else 0 + + self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0, + expire_walk_interval_ms=10000, + is_ip6=is_ip6) + + self.pg_enable_capture() + self.src_if.add_stream(stream) + self.pg_start() + + packets = self.dst_if.get_capture( + len(self.pkt_infos) - len(dropped_packet_indexes)) + self.verify_capture(scapy_ip_family, packets, dropped_packet_indexes) + self.src_if.assert_nothing_captured() -class TestIPv4Reassembly(VppTestCase): + def test_duplicates(self, scapy_ip_family, stream): + """ duplicate fragments """ + validate_scapy_ip_family(scapy_ip_family) + + self.pg_enable_capture() + self.src_if.add_stream(stream) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.pkt_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + def test_random(self, scapy_ip_family, stream): + """ random order reassembly """ + validate_scapy_ip_family(scapy_ip_family) + + fragments = list(stream) + shuffle(fragments) + + self.pg_enable_capture() + self.src_if.add_stream(fragments) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.packet_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + # run it all again to verify correctness + self.pg_enable_capture() + self.src_if.add_stream(fragments) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.packet_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + def test_reassembly(self, scapy_ip_family, stream): + """ basic reassembly """ + validate_scapy_ip_family(scapy_ip_family) + + self.pg_enable_capture() + self.src_if.add_stream(stream) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.pkt_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + # run it all again to verify correctness + self.pg_enable_capture() + self.src_if.add_stream(stream) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.pkt_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + def test_reversed(self, scapy_ip_family, stream): + """ reverse order reassembly """ + validate_scapy_ip_family(scapy_ip_family) + + fragments = list(stream) + fragments.reverse() + + self.pg_enable_capture() + self.src_if.add_stream(fragments) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.packet_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + # run it all again to verify correctness + self.pg_enable_capture() + self.src_if.add_stream(fragments) + self.pg_start() + + packets = self.dst_if.get_capture(len(self.packet_infos)) + self.verify_capture(scapy_ip_family, packets) + self.src_if.assert_nothing_captured() + + def test_timeout_inline(self, scapy_ip_family, stream, + dropped_packet_indexes): + """ timeout (inline) """ + validate_scapy_ip_family(scapy_ip_family) + is_ip6 = 1 if scapy_ip_family == IPv6 else 0 + + self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000, + expire_walk_interval_ms=10000, + is_ip6=is_ip6) + + self.pg_enable_capture() + self.src_if.add_stream(stream) + self.pg_start() + + packets = self.dst_if.get_capture( + len(self.pkt_infos) - len(dropped_packet_indexes)) + self.verify_capture(scapy_ip_family, packets, + dropped_packet_indexes) + + +class TestIPv4Reassembly(TestIPReassemblyMixin, VppTestCase): """ IPv4 Reassembly """ @classmethod @@ -101,83 +276,22 @@ class TestIPv4Reassembly(VppTestCase): (len(infos), len(cls.fragments_400), len(cls.fragments_300), len(cls.fragments_200))) - def verify_capture(self, capture, dropped_packet_indexes=[]): - """Verify captured packet stream. - - :param list capture: Captured packet stream. - """ - info = None - seen = set() - for packet in capture: - try: - self.logger.debug(ppp("Got packet:", packet)) - ip = packet[IP] - udp = packet[UDP] - payload_info = self.payload_to_info(str(packet[Raw])) - packet_index = payload_info.index - self.assertTrue( - packet_index not in dropped_packet_indexes, - ppp("Packet received, but should be dropped:", packet)) - if packet_index in seen: - raise Exception(ppp("Duplicate packet received", packet)) - seen.add(packet_index) - self.assertEqual(payload_info.dst, self.src_if.sw_if_index) - info = self._packet_infos[packet_index] - self.assertTrue(info is not None) - self.assertEqual(packet_index, info.index) - saved_packet = info.data - self.assertEqual(ip.src, saved_packet[IP].src) - self.assertEqual(ip.dst, saved_packet[IP].dst) - self.assertEqual(udp.payload, saved_packet[UDP].payload) - except Exception: - self.logger.error(ppp("Unexpected or invalid packet:", packet)) - raise - for index in self._packet_infos: - self.assertTrue(index in seen or index in dropped_packet_indexes, - "Packet with packet_index %d not received" % index) - - def test_reassembly(self): + @parameterized.expand([(IP, None)]) + def test_reassembly(self, family, stream): """ basic reassembly """ + stream = self.__class__.fragments_200 + super(TestIPv4Reassembly, self).test_reassembly(family, stream) - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_200) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - # run it all again to verify correctness - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_200) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - def test_reversed(self): + @parameterized.expand([(IP, None)]) + def test_reversed(self, family, stream): """ reverse order reassembly """ + stream = self.__class__.fragments_200 + super(TestIPv4Reassembly, self).test_reversed(family, stream) - fragments = list(self.fragments_200) - fragments.reverse() - - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.packet_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - # run it all again to verify correctness - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.packet_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() + @parameterized.expand([(IP, None)]) + def test_random(self, family, stream): + stream = self.__class__.fragments_200 + super(TestIPv4Reassembly, self).test_random(family, stream) def test_5737(self): """ fragment length + ip header size > 65535 """ @@ -275,45 +389,16 @@ class TestIPv4Reassembly(VppTestCase): # self.assert_packet_counter_equal( # "/err/ip4-reassembly-feature/malformed packets", 1) - def test_random(self): - """ random order reassembly """ - - fragments = list(self.fragments_200) - shuffle(fragments) - - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.packet_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - # run it all again to verify correctness - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.packet_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - def test_duplicates(self): + @parameterized.expand([(IP, None)]) + def test_duplicates(self, family, stream): """ duplicate fragments """ - fragments = [ + # IPv4 uses 4 fields in pkt_infos, IPv6 uses 3. x for (_, frags, _, _) in self.pkt_infos for x in frags for _ in range(0, min(2, len(frags))) ] - - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() + super(TestIPv4Reassembly, self).test_duplicates(family, fragments) def test_overlap1(self): """ overlapping fragments case #1 """ @@ -332,7 +417,7 @@ class TestIPv4Reassembly(VppTestCase): self.pg_start() packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) + self.verify_capture(IP, packets) self.src_if.assert_nothing_captured() # run it all to verify correctness @@ -341,7 +426,7 @@ class TestIPv4Reassembly(VppTestCase): self.pg_start() packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) + self.verify_capture(IP, packets) self.src_if.assert_nothing_captured() def test_overlap2(self): @@ -367,7 +452,7 @@ class TestIPv4Reassembly(VppTestCase): self.pg_start() packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) + self.verify_capture(IP, packets) self.src_if.assert_nothing_captured() # run it all to verify correctness @@ -376,26 +461,20 @@ class TestIPv4Reassembly(VppTestCase): self.pg_start() packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) + self.verify_capture(IP, packets) self.src_if.assert_nothing_captured() - def test_timeout_inline(self): + @parameterized.expand([(IP, None, None)]) + def test_timeout_inline(self, family, stream, dropped_packet_indexes): """ timeout (inline) """ + stream = self.fragments_400 dropped_packet_indexes = set( index for (index, frags, _, _) in self.pkt_infos if len(frags) > 1 ) + super(TestIPv4Reassembly, self).test_timeout_inline( + family, stream, dropped_packet_indexes) - self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000, - expire_walk_interval_ms=10000) - - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_400) - self.pg_start() - - packets = self.dst_if.get_capture( - len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) self.src_if.assert_nothing_captured() def test_timeout_cleanup(self): @@ -430,30 +509,22 @@ class TestIPv4Reassembly(VppTestCase): packets = self.dst_if.get_capture( len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) + self.verify_capture(IP, packets, dropped_packet_indexes) self.src_if.assert_nothing_captured() - def test_disabled(self): + @parameterized.expand([(IP, None, None)]) + def test_disabled(self, family, stream, dropped_packet_indexes): """ reassembly disabled """ + stream = self.__class__.fragments_400 dropped_packet_indexes = set( index for (index, frags_400, _, _) in self.pkt_infos if len(frags_400) > 1) - - self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0, - expire_walk_interval_ms=10000) - - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_400) - self.pg_start() - - packets = self.dst_if.get_capture( - len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) - self.src_if.assert_nothing_captured() + super(TestIPv4Reassembly, self).test_disabled( + family, stream, dropped_packet_indexes) -class TestIPv6Reassembly(VppTestCase): +class TestIPv6Reassembly(TestIPReassemblyMixin, VppTestCase): """ IPv6 Reassembly """ @classmethod @@ -531,126 +602,38 @@ class TestIPv6Reassembly(VppTestCase): (len(infos), len(cls.fragments_400), len(cls.fragments_300))) - def verify_capture(self, capture, dropped_packet_indexes=[]): - """Verify captured packet strea . - - :param list capture: Captured packet stream. - """ - info = None - seen = set() - for packet in capture: - try: - self.logger.debug(ppp("Got packet:", packet)) - ip = packet[IPv6] - udp = packet[UDP] - payload_info = self.payload_to_info(str(packet[Raw])) - packet_index = payload_info.index - self.assertTrue( - packet_index not in dropped_packet_indexes, - ppp("Packet received, but should be dropped:", packet)) - if packet_index in seen: - raise Exception(ppp("Duplicate packet received", packet)) - seen.add(packet_index) - self.assertEqual(payload_info.dst, self.src_if.sw_if_index) - info = self._packet_infos[packet_index] - self.assertTrue(info is not None) - self.assertEqual(packet_index, info.index) - saved_packet = info.data - self.assertEqual(ip.src, saved_packet[IPv6].src) - self.assertEqual(ip.dst, saved_packet[IPv6].dst) - self.assertEqual(udp.payload, saved_packet[UDP].payload) - except Exception: - self.logger.error(ppp("Unexpected or invalid packet:", packet)) - raise - for index in self._packet_infos: - self.assertTrue(index in seen or index in dropped_packet_indexes, - "Packet with packet_index %d not received" % index) - - def test_reassembly(self): + @parameterized.expand([(IPv6, None)]) + def test_reassembly(self, family, stream): """ basic reassembly """ + stream = self.__class__.fragments_400 + super(TestIPv6Reassembly, self).test_reassembly(family, stream) - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_400) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - # run it all again to verify correctness - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_400) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - def test_reversed(self): + @parameterized.expand([(IPv6, None)]) + def test_reversed(self, family, stream): """ reverse order reassembly """ + stream = self.__class__.fragments_400 + super(TestIPv6Reassembly, self).test_reversed(family, stream) - fragments = list(self.fragments_400) - fragments.reverse() - - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - # run it all again to verify correctness - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - def test_random(self): + @parameterized.expand([(IPv6, None)]) + def test_random(self, family, stream): """ random order reassembly """ + stream = self.__class__.fragments_400 + super(TestIPv6Reassembly, self).test_random(family, stream) - fragments = list(self.fragments_400) - shuffle(fragments) - - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - # run it all again to verify correctness - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() - - def test_duplicates(self): + @parameterized.expand([(IPv6, None)]) + def test_duplicates(self, family, stream): """ duplicate fragments """ fragments = [ + # IPv4 uses 4 fields in pkt_infos, IPv6 uses 3. x for (_, frags, _) in self.pkt_infos for x in frags for _ in range(0, min(2, len(frags))) ] - - self.pg_enable_capture() - self.src_if.add_stream(fragments) - self.pg_start() - - packets = self.dst_if.get_capture(len(self.pkt_infos)) - self.verify_capture(packets) - self.src_if.assert_nothing_captured() + super(TestIPv6Reassembly, self).test_duplicates(family, fragments) def test_overlap1(self): - """ overlapping fragments case #1 """ + """ overlapping fragments case #1 (differs from IP test case)""" fragments = [] for _, frags_400, frags_300 in self.pkt_infos: @@ -671,11 +654,11 @@ class TestIPv6Reassembly(VppTestCase): packets = self.dst_if.get_capture( len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) + self.verify_capture(IPv6, packets, dropped_packet_indexes) self.src_if.assert_nothing_captured() def test_overlap2(self): - """ overlapping fragments case #2 """ + """ overlapping fragments case #2 (differs from IP test case)""" fragments = [] for _, frags_400, frags_300 in self.pkt_infos: @@ -702,26 +685,20 @@ class TestIPv6Reassembly(VppTestCase): packets = self.dst_if.get_capture( len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) + self.verify_capture(IPv6, packets, dropped_packet_indexes) self.src_if.assert_nothing_captured() - def test_timeout_inline(self): + @parameterized.expand([(IPv6, None, None)]) + def test_timeout_inline(self, family, stream, dropped_packets_index): """ timeout (inline) """ + stream = self.__class__.fragments_400 dropped_packet_indexes = set( index for (index, frags, _) in self.pkt_infos if len(frags) > 1 ) + super(TestIPv6Reassembly, self).test_timeout_inline( + family, stream, dropped_packet_indexes) - self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000, - expire_walk_interval_ms=10000, is_ip6=1) - - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_400) - self.pg_start() - - packets = self.dst_if.get_capture( - len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) pkts = self.src_if.get_capture( expected_count=len(dropped_packet_indexes)) for icmp in pkts: @@ -765,7 +742,7 @@ class TestIPv6Reassembly(VppTestCase): packets = self.dst_if.get_capture( len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) + self.verify_capture(IPv6, packets, dropped_packet_indexes) pkts = self.src_if.get_capture( expected_count=len(dropped_packet_indexes)) for icmp in pkts: @@ -774,23 +751,16 @@ class TestIPv6Reassembly(VppTestCase): self.assertIn(icmp[IPv6ExtHdrFragment].id, dropped_packet_indexes) dropped_packet_indexes.remove(icmp[IPv6ExtHdrFragment].id) - def test_disabled(self): + @parameterized.expand([(IPv6, None, None)]) + def test_disabled(self, family, stream, dropped_packet_indexes): """ reassembly disabled """ + stream = self.__class__.fragments_400 dropped_packet_indexes = set( index for (index, frags_400, _) in self.pkt_infos if len(frags_400) > 1) - - self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0, - expire_walk_interval_ms=10000, is_ip6=1) - - self.pg_enable_capture() - self.src_if.add_stream(self.fragments_400) - self.pg_start() - - packets = self.dst_if.get_capture( - len(self.pkt_infos) - len(dropped_packet_indexes)) - self.verify_capture(packets, dropped_packet_indexes) + super(TestIPv6Reassembly, self).test_disabled( + family, stream, dropped_packet_indexes) self.src_if.assert_nothing_captured() def test_missing_upper(self): -- cgit 1.2.3-korg