diff options
-rw-r--r-- | plugins/snat-plugin/snat/in2out.c | 114 | ||||
-rw-r--r-- | test/test_snat.py | 67 |
2 files changed, 181 insertions, 0 deletions
diff --git a/plugins/snat-plugin/snat/in2out.c b/plugins/snat-plugin/snat/in2out.c index 43ee2d0e..d7b647ef 100644 --- a/plugins/snat-plugin/snat/in2out.c +++ b/plugins/snat-plugin/snat/in2out.c @@ -405,6 +405,108 @@ static inline u32 icmp_in2out_slow_path (snat_main_t *sm, return next0; } +/** + * @brief Hairpinning + * + * Hairpinning allows two endpoints on the internal side of the NAT to + * communicate even if they only use each other's external IP addresses + * and ports. + * + * @param sm SNAT main. + * @param b0 Vlib buffer. + * @param ip0 IP header. + * @param udp0 UDP header. + * @param tcp0 TCP header. + * @param proto0 SNAT protocol. + */ +static inline void +snat_hairpinning (snat_main_t *sm, + vlib_buffer_t * b0, + ip4_header_t * ip0, + udp_header_t * udp0, + tcp_header_t * tcp0, + u32 proto0) +{ + snat_session_key_t key0, sm0; + snat_static_mapping_key_t k0; + snat_session_t * s0; + clib_bihash_kv_8_8_t kv0, value0; + ip_csum_t sum0; + u32 new_dst_addr0 = 0, old_dst_addr0, ti = 0, si; + u16 new_dst_port0, old_dst_port0; + + key0.addr = ip0->dst_address; + key0.port = udp0->dst_port; + key0.protocol = proto0; + key0.fib_index = sm->outside_fib_index; + kv0.key = key0.as_u64; + + /* Check if destination is in active sessions */ + if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0)) + { + /* or static mappings */ + if (!snat_static_mapping_match(sm, key0, &sm0, 1)) + { + new_dst_addr0 = sm0.addr.as_u32; + new_dst_port0 = sm0.port; + vnet_buffer(b0)->sw_if_index[VLIB_TX] = sm0.fib_index; + } + } + else + { + si = value0.value; + if (sm->num_workers > 1) + { + k0.addr = ip0->dst_address; + k0.port = udp0->dst_port; + k0.fib_index = sm->outside_fib_index; + kv0.key = k0.as_u64; + if (clib_bihash_search_8_8 (&sm->worker_by_out, &kv0, &value0)) + ASSERT(0); + else + ti = value0.value; + } + else + ti = sm->num_workers; + + s0 = pool_elt_at_index (sm->per_thread_data[ti].sessions, si); + new_dst_addr0 = s0->in2out.addr.as_u32; + new_dst_port0 = s0->in2out.port; + vnet_buffer(b0)->sw_if_index[VLIB_TX] = s0->in2out.fib_index; + } + + /* Destination is behind the same NAT, use internal address and port */ + if (new_dst_addr0) + { + old_dst_addr0 = ip0->dst_address.as_u32; + ip0->dst_address.as_u32 = new_dst_addr0; + sum0 = ip0->checksum; + sum0 = ip_csum_update (sum0, old_dst_addr0, new_dst_addr0, + ip4_header_t, dst_address); + ip0->checksum = ip_csum_fold (sum0); + + old_dst_port0 = tcp0->ports.dst; + if (PREDICT_TRUE(new_dst_port0 != old_dst_port0)) + { + if (PREDICT_TRUE(proto0 == SNAT_PROTOCOL_TCP)) + { + tcp0->ports.dst = new_dst_port0; + sum0 = tcp0->checksum; + sum0 = ip_csum_update (sum0, old_dst_addr0, new_dst_addr0, + ip4_header_t, dst_address); + sum0 = ip_csum_update (sum0, old_dst_port0, new_dst_port0, + ip4_header_t /* cheat */, length); + tcp0->checksum = ip_csum_fold(sum0); + } + else + { + udp0->dst_port = new_dst_port0; + udp0->checksum = 0; + } + } + } +} + static inline uword snat_in2out_node_fn_inline (vlib_main_t * vm, vlib_node_runtime_t * node, @@ -594,6 +696,9 @@ snat_in2out_node_fn_inline (vlib_main_t * vm, udp0->checksum = 0; } + /* Hairpinning */ + snat_hairpinning (sm, b0, ip0, udp0, tcp0, proto0); + /* Accounting */ s0->last_heard = now; s0->total_pkts++; @@ -739,6 +844,9 @@ snat_in2out_node_fn_inline (vlib_main_t * vm, udp1->checksum = 0; } + /* Hairpinning */ + snat_hairpinning (sm, b1, ip1, udp1, tcp1, proto1); + /* Accounting */ s1->last_heard = now; s1->total_pkts++; @@ -919,6 +1027,9 @@ snat_in2out_node_fn_inline (vlib_main_t * vm, udp0->checksum = 0; } + /* Hairpinning */ + snat_hairpinning (sm, b0, ip0, udp0, tcp0, proto0); + /* Accounting */ s0->last_heard = now; s0->total_pkts++; @@ -1430,6 +1541,9 @@ snat_in2out_fast_static_map_fn (vlib_main_t * vm, } } + /* Hairpinning */ + snat_hairpinning (sm, b0, ip0, udp0, tcp0, proto0); + trace0: if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE) && (b0->flags & VLIB_BUFFER_IS_TRACED))) diff --git a/test/test_snat.py b/test/test_snat.py index e90d9c0b..5cc76f6c 100644 --- a/test/test_snat.py +++ b/test/test_snat.py @@ -34,6 +34,9 @@ class TestSNAT(VppTestCase): i.config_ip4() i.resolve_arp() + cls.pg0.generate_remote_hosts(2) + cls.pg0.configure_ipv4_neighbors() + cls.overlapping_interfaces = list(list(cls.pg_interfaces[4:7])) for i in cls.overlapping_interfaces: @@ -526,6 +529,70 @@ class TestSNAT(VppTestCase): capture = self.pg6.get_capture() self.verify_capture_in(capture, self.pg6) + def test_hairpinning(self): + """ SNAT hairpinning """ + + host = self.pg0.remote_hosts[0] + server = self.pg0.remote_hosts[1] + host_in_port = 1234 + host_out_port = 0 + server_in_port = 5678 + server_out_port = 8765 + + self.snat_add_address(self.snat_addr) + self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index) + self.vapi.snat_interface_add_del_feature(self.pg1.sw_if_index, + is_inside=0) + # add static mapping for server + self.snat_add_static_mapping(server.ip4, self.snat_addr, + server_in_port, server_out_port) + + # send packet from host to server + p = (Ether(src=host.mac, dst=self.pg0.local_mac) / + IP(src=host.ip4, dst=self.snat_addr) / + TCP(sport=host_in_port, dport=server_out_port)) + self.pg0.add_stream(p) + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + capture = self.pg0.get_capture() + self.assertEqual(1, len(capture)) + p = capture[0] + try: + ip = p[IP] + tcp = p[TCP] + self.assertEqual(ip.src, self.snat_addr) + self.assertEqual(ip.dst, server.ip4) + self.assertNotEqual(tcp.sport, host_in_port) + self.assertEqual(tcp.dport, server_in_port) + host_out_port = tcp.sport + except: + error("Unexpected or invalid packet:") + error(p.show()) + raise + + # send reply from server to host + p = (Ether(src=server.mac, dst=self.pg0.local_mac) / + IP(src=server.ip4, dst=self.snat_addr) / + TCP(sport=server_in_port, dport=host_out_port)) + self.pg0.add_stream(p) + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + capture = self.pg0.get_capture() + self.assertEqual(1, len(capture)) + p = capture[0] + try: + ip = p[IP] + tcp = p[TCP] + self.assertEqual(ip.src, self.snat_addr) + self.assertEqual(ip.dst, host.ip4) + self.assertEqual(tcp.sport, server_out_port) + self.assertEqual(tcp.dport, host_in_port) + except: + error("Unexpected or invalid packet:") + error(p.show()) + raise + + def tearDown(self): super(TestSNAT, self).tearDown() if not self.vpp_dead: |