summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/snat/in2out.c68
-rw-r--r--test/test_snat.py166
2 files changed, 231 insertions, 3 deletions
diff --git a/src/plugins/snat/in2out.c b/src/plugins/snat/in2out.c
index bc86a7a4d23..ddde702a50c 100644
--- a/src/plugins/snat/in2out.c
+++ b/src/plugins/snat/in2out.c
@@ -515,7 +515,7 @@ u32 icmp_match_in2out_slow(snat_main_t *sm, vlib_node_runtime_t *node,
goto out;
}
- if (PREDICT_FALSE(icmp0->type != ICMP4_echo_request))
+ if (PREDICT_FALSE(icmp_is_error_message (icmp0)))
{
b0->error = node->errors[SNAT_IN2OUT_ERROR_BAD_ICMP_TYPE];
next0 = SNAT_IN2OUT_NEXT_DROP;
@@ -869,11 +869,77 @@ static inline u32 icmp_in2out_slow_path (snat_main_t *sm,
u32 thread_index,
snat_session_t ** p_s0)
{
+ snat_session_key_t key0, sm0;
+ clib_bihash_kv_8_8_t kv0, value0;
+ snat_worker_key_t k0;
+ u32 new_dst_addr0 = 0, old_dst_addr0, si, ti = 0;
+ ip_csum_t sum0;
+
next0 = icmp_in2out(sm, b0, ip0, icmp0, sw_if_index0, rx_fib_index0, node,
next0, thread_index, p_s0, 0);
snat_session_t * s0 = *p_s0;
if (PREDICT_TRUE(next0 != SNAT_IN2OUT_NEXT_DROP && s0))
{
+ /* Hairpinning */
+ if (!icmp_is_error_message (icmp0))
+ {
+ icmp_echo_header_t *echo0 = (icmp_echo_header_t *)(icmp0+1);
+ u16 icmp_id0 = echo0->identifier;
+ key0.addr = ip0->dst_address;
+ key0.port = icmp_id0;
+ key0.protocol = SNAT_PROTOCOL_ICMP;
+ key0.fib_index = sm->outside_fib_index;
+ kv0.key = key0.as_u64;
+
+ /* Check if destination is in active sessions */
+ if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
+ {
+ /* or static mappings */
+ if (!snat_static_mapping_match(sm, key0, &sm0, 1, 0))
+ {
+ new_dst_addr0 = sm0.addr.as_u32;
+ vnet_buffer(b0)->sw_if_index[VLIB_TX] = sm0.fib_index;
+ }
+ }
+ else
+ {
+ si = value0.value;
+ if (sm->num_workers > 1)
+ {
+ k0.addr = ip0->dst_address;
+ k0.port = icmp_id0;
+ k0.fib_index = sm->outside_fib_index;
+ kv0.key = k0.as_u64;
+ if (clib_bihash_search_8_8 (&sm->worker_by_out, &kv0, &value0))
+ ASSERT(0);
+ else
+ ti = value0.value;
+ }
+ else
+ ti = sm->num_workers;
+
+ s0 = pool_elt_at_index (sm->per_thread_data[ti].sessions, si);
+ new_dst_addr0 = s0->in2out.addr.as_u32;
+ vnet_buffer(b0)->sw_if_index[VLIB_TX] = s0->in2out.fib_index;
+ echo0->identifier = s0->in2out.port;
+ sum0 = icmp0->checksum;
+ sum0 = ip_csum_update (sum0, icmp_id0, s0->in2out.port,
+ icmp_echo_header_t, identifier);
+ icmp0->checksum = ip_csum_fold (sum0);
+ }
+
+ /* Destination is behind the same NAT, use internal address and port */
+ if (new_dst_addr0)
+ {
+ old_dst_addr0 = ip0->dst_address.as_u32;
+ ip0->dst_address.as_u32 = new_dst_addr0;
+ sum0 = ip0->checksum;
+ sum0 = ip_csum_update (sum0, old_dst_addr0, new_dst_addr0,
+ ip4_header_t, dst_address);
+ ip0->checksum = ip_csum_fold (sum0);
+ }
+ }
+
/* Accounting */
s0->last_heard = now;
s0->total_pkts++;
diff --git a/test/test_snat.py b/test/test_snat.py
index 8d384384222..0eceaab256b 100644
--- a/test/test_snat.py
+++ b/test/test_snat.py
@@ -324,7 +324,7 @@ class TestSNAT(MethodHolder):
i.config_ip4()
i.resolve_arp()
- cls.pg0.generate_remote_hosts(2)
+ cls.pg0.generate_remote_hosts(3)
cls.pg0.configure_ipv4_neighbors()
cls.overlapping_interfaces = list(list(cls.pg_interfaces[4:7]))
@@ -1016,7 +1016,7 @@ class TestSNAT(MethodHolder):
self.icmp_id_in])
def test_hairpinning(self):
- """ SNAT hairpinning """
+ """ SNAT hairpinning - 1:1 NAT with port"""
host = self.pg0.remote_hosts[0]
server = self.pg0.remote_hosts[1]
@@ -1075,6 +1075,168 @@ class TestSNAT(MethodHolder):
self.logger.error(ppp("Unexpected or invalid packet:"), p)
raise
+ def test_hairpinning2(self):
+ """ SNAT hairpinning - 1:1 NAT"""
+
+ server1_nat_ip = "10.0.0.10"
+ server2_nat_ip = "10.0.0.11"
+ host = self.pg0.remote_hosts[0]
+ server1 = self.pg0.remote_hosts[1]
+ server2 = self.pg0.remote_hosts[2]
+ server_tcp_port = 22
+ server_udp_port = 20
+
+ self.snat_add_address(self.snat_addr)
+ self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index)
+ self.vapi.snat_interface_add_del_feature(self.pg1.sw_if_index,
+ is_inside=0)
+
+ # add static mapping for servers
+ self.snat_add_static_mapping(server1.ip4, server1_nat_ip)
+ self.snat_add_static_mapping(server2.ip4, server2_nat_ip)
+
+ # host to server1
+ pkts = []
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=host.ip4, dst=server1_nat_ip) /
+ TCP(sport=self.tcp_port_in, dport=server_tcp_port))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=host.ip4, dst=server1_nat_ip) /
+ UDP(sport=self.udp_port_in, dport=server_udp_port))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=host.ip4, dst=server1_nat_ip) /
+ ICMP(id=self.icmp_id_in, type='echo-request'))
+ pkts.append(p)
+ self.pg0.add_stream(pkts)
+ self.pg_enable_capture(self.pg_interfaces)
+ self.pg_start()
+ capture = self.pg0.get_capture(len(pkts))
+ for packet in capture:
+ try:
+ self.assertEqual(packet[IP].src, self.snat_addr)
+ self.assertEqual(packet[IP].dst, server1.ip4)
+ if packet.haslayer(TCP):
+ self.assertNotEqual(packet[TCP].sport, self.tcp_port_in)
+ self.assertEqual(packet[TCP].dport, server_tcp_port)
+ self.tcp_port_out = packet[TCP].sport
+ elif packet.haslayer(UDP):
+ self.assertNotEqual(packet[UDP].sport, self.udp_port_in)
+ self.assertEqual(packet[UDP].dport, server_udp_port)
+ self.udp_port_out = packet[UDP].sport
+ else:
+ self.assertNotEqual(packet[ICMP].id, self.icmp_id_in)
+ self.icmp_id_out = packet[ICMP].id
+ except:
+ self.logger.error(ppp("Unexpected or invalid packet:", packet))
+ raise
+
+ # server1 to host
+ pkts = []
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server1.ip4, dst=self.snat_addr) /
+ TCP(sport=server_tcp_port, dport=self.tcp_port_out))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server1.ip4, dst=self.snat_addr) /
+ UDP(sport=server_udp_port, dport=self.udp_port_out))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server1.ip4, dst=self.snat_addr) /
+ ICMP(id=self.icmp_id_out, type='echo-reply'))
+ pkts.append(p)
+ self.pg0.add_stream(pkts)
+ self.pg_enable_capture(self.pg_interfaces)
+ self.pg_start()
+ capture = self.pg0.get_capture(len(pkts))
+ for packet in capture:
+ try:
+ self.assertEqual(packet[IP].src, server1_nat_ip)
+ self.assertEqual(packet[IP].dst, host.ip4)
+ if packet.haslayer(TCP):
+ self.assertEqual(packet[TCP].dport, self.tcp_port_in)
+ self.assertEqual(packet[TCP].sport, server_tcp_port)
+ elif packet.haslayer(UDP):
+ self.assertEqual(packet[UDP].dport, self.udp_port_in)
+ self.assertEqual(packet[UDP].sport, server_udp_port)
+ else:
+ self.assertEqual(packet[ICMP].id, self.icmp_id_in)
+ except:
+ self.logger.error(ppp("Unexpected or invalid packet:", packet))
+ raise
+
+ # server2 to server1
+ pkts = []
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server2.ip4, dst=server1_nat_ip) /
+ TCP(sport=self.tcp_port_in, dport=server_tcp_port))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server2.ip4, dst=server1_nat_ip) /
+ UDP(sport=self.udp_port_in, dport=server_udp_port))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server2.ip4, dst=server1_nat_ip) /
+ ICMP(id=self.icmp_id_in, type='echo-request'))
+ pkts.append(p)
+ self.pg0.add_stream(pkts)
+ self.pg_enable_capture(self.pg_interfaces)
+ self.pg_start()
+ capture = self.pg0.get_capture(len(pkts))
+ for packet in capture:
+ try:
+ self.assertEqual(packet[IP].src, server2_nat_ip)
+ self.assertEqual(packet[IP].dst, server1.ip4)
+ if packet.haslayer(TCP):
+ self.assertEqual(packet[TCP].sport, self.tcp_port_in)
+ self.assertEqual(packet[TCP].dport, server_tcp_port)
+ self.tcp_port_out = packet[TCP].sport
+ elif packet.haslayer(UDP):
+ self.assertEqual(packet[UDP].sport, self.udp_port_in)
+ self.assertEqual(packet[UDP].dport, server_udp_port)
+ self.udp_port_out = packet[UDP].sport
+ else:
+ self.assertEqual(packet[ICMP].id, self.icmp_id_in)
+ self.icmp_id_out = packet[ICMP].id
+ except:
+ self.logger.error(ppp("Unexpected or invalid packet:", packet))
+ raise
+
+ # server1 to server2
+ pkts = []
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server1.ip4, dst=server2_nat_ip) /
+ TCP(sport=server_tcp_port, dport=self.tcp_port_out))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server1.ip4, dst=server2_nat_ip) /
+ UDP(sport=server_udp_port, dport=self.udp_port_out))
+ pkts.append(p)
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=server1.ip4, dst=server2_nat_ip) /
+ ICMP(id=self.icmp_id_out, type='echo-reply'))
+ pkts.append(p)
+ self.pg0.add_stream(pkts)
+ self.pg_enable_capture(self.pg_interfaces)
+ self.pg_start()
+ capture = self.pg0.get_capture(len(pkts))
+ for packet in capture:
+ try:
+ self.assertEqual(packet[IP].src, server1_nat_ip)
+ self.assertEqual(packet[IP].dst, server2.ip4)
+ if packet.haslayer(TCP):
+ self.assertEqual(packet[TCP].dport, self.tcp_port_in)
+ self.assertEqual(packet[TCP].sport, server_tcp_port)
+ elif packet.haslayer(UDP):
+ self.assertEqual(packet[UDP].dport, self.udp_port_in)
+ self.assertEqual(packet[UDP].sport, server_udp_port)
+ else:
+ self.assertEqual(packet[ICMP].id, self.icmp_id_in)
+ except:
+ self.logger.error(ppp("Unexpected or invalid packet:", packet))
+ raise
+
def test_max_translations_per_user(self):
""" MAX translations per user - recycle the least recently used """