summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/snat/in2out.c10
-rw-r--r--test/test_snat.py17
2 files changed, 27 insertions, 0 deletions
diff --git a/src/plugins/snat/in2out.c b/src/plugins/snat/in2out.c
index ddde702a50c..685cdca8045 100644
--- a/src/plugins/snat/in2out.c
+++ b/src/plugins/snat/in2out.c
@@ -854,6 +854,16 @@ snat_hairpinning (snat_main_t *sm,
udp0->checksum = 0;
}
}
+ else
+ {
+ if (PREDICT_TRUE(proto0 == SNAT_PROTOCOL_TCP))
+ {
+ sum0 = tcp0->checksum;
+ sum0 = ip_csum_update (sum0, old_dst_addr0, new_dst_addr0,
+ ip4_header_t, dst_address);
+ tcp0->checksum = ip_csum_fold(sum0);
+ }
+ }
}
}
diff --git a/test/test_snat.py b/test/test_snat.py
index c6344a9ee74..c2f9280d62a 100644
--- a/test/test_snat.py
+++ b/test/test_snat.py
@@ -27,6 +27,17 @@ class MethodHolder(VppTestCase):
def tearDown(self):
super(MethodHolder, self).tearDown()
+ def check_tcp_checksum(self, pkt):
+ """
+ Check TCP checksum in IP packet
+
+ :param pkt: Packet to check TCP checksum
+ """
+ new = pkt.__class__(str(pkt))
+ del new['TCP'].chksum
+ new = new.__class__(str(new))
+ self.assertEqual(new['TCP'].chksum, pkt['TCP'].chksum)
+
def create_stream_in(self, in_if, out_if, ttl=64):
"""
Create packet stream for inside network
@@ -1111,6 +1122,7 @@ class TestSNAT(MethodHolder):
self.assertEqual(ip.dst, server.ip4)
self.assertNotEqual(tcp.sport, host_in_port)
self.assertEqual(tcp.dport, server_in_port)
+ self.check_tcp_checksum(p)
host_out_port = tcp.sport
except:
self.logger.error(ppp("Unexpected or invalid packet:", p))
@@ -1132,6 +1144,7 @@ class TestSNAT(MethodHolder):
self.assertEqual(ip.dst, host.ip4)
self.assertEqual(tcp.sport, server_out_port)
self.assertEqual(tcp.dport, host_in_port)
+ self.check_tcp_checksum(p)
except:
self.logger.error(ppp("Unexpected or invalid packet:"), p)
raise
@@ -1182,6 +1195,7 @@ class TestSNAT(MethodHolder):
self.assertNotEqual(packet[TCP].sport, self.tcp_port_in)
self.assertEqual(packet[TCP].dport, server_tcp_port)
self.tcp_port_out = packet[TCP].sport
+ self.check_tcp_checksum(packet)
elif packet.haslayer(UDP):
self.assertNotEqual(packet[UDP].sport, self.udp_port_in)
self.assertEqual(packet[UDP].dport, server_udp_port)
@@ -1218,6 +1232,7 @@ class TestSNAT(MethodHolder):
if packet.haslayer(TCP):
self.assertEqual(packet[TCP].dport, self.tcp_port_in)
self.assertEqual(packet[TCP].sport, server_tcp_port)
+ self.check_tcp_checksum(packet)
elif packet.haslayer(UDP):
self.assertEqual(packet[UDP].dport, self.udp_port_in)
self.assertEqual(packet[UDP].sport, server_udp_port)
@@ -1253,6 +1268,7 @@ class TestSNAT(MethodHolder):
self.assertEqual(packet[TCP].sport, self.tcp_port_in)
self.assertEqual(packet[TCP].dport, server_tcp_port)
self.tcp_port_out = packet[TCP].sport
+ self.check_tcp_checksum(packet)
elif packet.haslayer(UDP):
self.assertEqual(packet[UDP].sport, self.udp_port_in)
self.assertEqual(packet[UDP].dport, server_udp_port)
@@ -1289,6 +1305,7 @@ class TestSNAT(MethodHolder):
if packet.haslayer(TCP):
self.assertEqual(packet[TCP].dport, self.tcp_port_in)
self.assertEqual(packet[TCP].sport, server_tcp_port)
+ self.check_tcp_checksum(packet)
elif packet.haslayer(UDP):
self.assertEqual(packet[UDP].dport, self.udp_port_in)
self.assertEqual(packet[UDP].sport, server_udp_port)