summaryrefslogtreecommitdiffstats
path: root/test/framework.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/framework.py')
-rw-r--r--test/framework.py27
1 files changed, 21 insertions, 6 deletions
diff --git a/test/framework.py b/test/framework.py
index 2b197326532..a7a3e2fdea7 100644
--- a/test/framework.py
+++ b/test/framework.py
@@ -1325,7 +1325,7 @@ class VppTestCase(CPUInterface, unittest.TestCase):
if 0 == len(checksums):
return
temp = temp.__class__(scapy.compat.raw(temp))
- for layer, cf in checksums:
+ for layer, cf in reversed(checksums):
calc_sum = getattr(temp[layer], cf)
self.assert_equal(
getattr(received[layer], cf),
@@ -1338,9 +1338,24 @@ class VppTestCase(CPUInterface, unittest.TestCase):
)
def assert_checksum_valid(
- self, received_packet, layer, field_name="chksum", ignore_zero_checksum=False
+ self,
+ received_packet,
+ layer,
+ checksum_field_names=["chksum", "cksum"],
+ ignore_zero_checksum=False,
):
"""Check checksum of received packet on given layer"""
+ layer_copy = received_packet[layer].copy()
+ layer_copy.remove_payload()
+ field_name = None
+ for f in checksum_field_names:
+ if hasattr(layer_copy, f):
+ field_name = f
+ break
+ if field_name is None:
+ raise Exception(
+ f"Layer `{layer}` has none of checksum fields: `{checksum_field_names}`."
+ )
received_packet_checksum = getattr(received_packet[layer], field_name)
if ignore_zero_checksum and 0 == received_packet_checksum:
return
@@ -1350,7 +1365,7 @@ class VppTestCase(CPUInterface, unittest.TestCase):
self.assert_equal(
received_packet_checksum,
getattr(recalculated[layer], field_name),
- "packet checksum on layer: %s" % layer,
+ f"packet checksum (field: {field_name}) on layer: %s" % layer,
)
def assert_ip_checksum_valid(self, received_packet, ignore_zero_checksum=False):
@@ -1386,12 +1401,12 @@ class VppTestCase(CPUInterface, unittest.TestCase):
def assert_icmpv6_checksum_valid(self, pkt):
if pkt.haslayer(ICMPv6DestUnreach):
- self.assert_checksum_valid(pkt, "ICMPv6DestUnreach", "cksum")
+ self.assert_checksum_valid(pkt, "ICMPv6DestUnreach")
self.assert_embedded_icmp_checksum_valid(pkt)
if pkt.haslayer(ICMPv6EchoRequest):
- self.assert_checksum_valid(pkt, "ICMPv6EchoRequest", "cksum")
+ self.assert_checksum_valid(pkt, "ICMPv6EchoRequest")
if pkt.haslayer(ICMPv6EchoReply):
- self.assert_checksum_valid(pkt, "ICMPv6EchoReply", "cksum")
+ self.assert_checksum_valid(pkt, "ICMPv6EchoReply")
def get_counter(self, counter):
if counter.startswith("/"):