diff options
-rw-r--r-- | src/plugins/snat/in2out.c | 70 | ||||
-rw-r--r-- | src/plugins/snat/out2in.c | 40 | ||||
-rw-r--r-- | test/test_snat.py | 50 |
3 files changed, 153 insertions, 7 deletions
diff --git a/src/plugins/snat/in2out.c b/src/plugins/snat/in2out.c index 685cdca8045..d396c79c1b2 100644 --- a/src/plugins/snat/in2out.c +++ b/src/plugins/snat/in2out.c @@ -967,6 +967,55 @@ static inline u32 icmp_in2out_slow_path (snat_main_t *sm, return next0; } +static void +snat_in2out_unknown_proto (snat_main_t *sm, + vlib_buffer_t * b, + ip4_header_t * ip, + u32 rx_fib_index) +{ + clib_bihash_kv_8_8_t kv, value; + snat_static_mapping_t *m; + snat_session_key_t m_key; + u32 old_addr, new_addr; + ip_csum_t sum; + + m_key.addr = ip->src_address; + m_key.port = 0; + m_key.protocol = 0; + m_key.fib_index = rx_fib_index; + kv.key = m_key.as_u64; + if (clib_bihash_search_8_8 (&sm->static_mapping_by_local, &kv, &value)) + return; + + m = pool_elt_at_index (sm->static_mappings, value.value); + + old_addr = ip->src_address.as_u32; + new_addr = ip->src_address.as_u32 = m->external_addr.as_u32; + sum = ip->checksum; + sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, src_address); + ip->checksum = ip_csum_fold (sum); + + /* Hairpinning */ + m_key.addr = ip->dst_address; + m_key.fib_index = sm->outside_fib_index; + kv.key = m_key.as_u64; + if (clib_bihash_search_8_8 (&sm->static_mapping_by_external, &kv, &value)) + { + vnet_buffer(b)->sw_if_index[VLIB_TX] = sm->outside_fib_index; + return; + } + + m = pool_elt_at_index (sm->static_mappings, value.value); + + old_addr = ip->dst_address.as_u32; + new_addr = ip->dst_address.as_u32 = m->local_addr.as_u32; + sum = ip->checksum; + sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, dst_address); + ip->checksum = ip_csum_fold (sum); + + vnet_buffer(b)->sw_if_index[VLIB_TX] = vnet_buffer(b)->sw_if_index[VLIB_RX]; +} + static inline uword snat_in2out_node_fn_inline (vlib_main_t * vm, vlib_node_runtime_t * node, @@ -1065,8 +1114,11 @@ snat_in2out_node_fn_inline (vlib_main_t * vm, if (is_slow_path) { if (PREDICT_FALSE (proto0 == ~0)) - goto trace00; - + { + snat_in2out_unknown_proto (sm, b0, ip0, rx_fib_index0); + goto trace00; + } + if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP)) { next0 = icmp_in2out_slow_path @@ -1205,8 +1257,11 @@ snat_in2out_node_fn_inline (vlib_main_t * vm, if (is_slow_path) { if (PREDICT_FALSE (proto1 == ~0)) - goto trace01; - + { + snat_in2out_unknown_proto (sm, b1, ip1, rx_fib_index1); + goto trace01; + } + if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP)) { next1 = icmp_in2out_slow_path @@ -1380,8 +1435,11 @@ snat_in2out_node_fn_inline (vlib_main_t * vm, if (is_slow_path) { if (PREDICT_FALSE (proto0 == ~0)) - goto trace0; - + { + snat_in2out_unknown_proto (sm, b0, ip0, rx_fib_index0); + goto trace0; + } + if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP)) { next0 = icmp_in2out_slow_path diff --git a/src/plugins/snat/out2in.c b/src/plugins/snat/out2in.c index e8ddcf1cfd7..5c12b473c63 100644 --- a/src/plugins/snat/out2in.c +++ b/src/plugins/snat/out2in.c @@ -611,6 +611,37 @@ static inline u32 icmp_out2in_slow_path (snat_main_t *sm, return next0; } +static void +snat_out2in_unknown_proto (snat_main_t *sm, + vlib_buffer_t * b, + ip4_header_t * ip, + u32 rx_fib_index) +{ + clib_bihash_kv_8_8_t kv, value; + snat_static_mapping_t *m; + snat_session_key_t m_key; + u32 old_addr, new_addr; + ip_csum_t sum; + + m_key.addr = ip->dst_address; + m_key.port = 0; + m_key.protocol = 0; + m_key.fib_index = rx_fib_index; + kv.key = m_key.as_u64; + if (clib_bihash_search_8_8 (&sm->static_mapping_by_external, &kv, &value)) + return; + + m = pool_elt_at_index (sm->static_mappings, value.value); + + old_addr = ip->dst_address.as_u32; + new_addr = ip->dst_address.as_u32 = m->local_addr.as_u32; + sum = ip->checksum; + sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, dst_address); + ip->checksum = ip_csum_fold (sum); + + vnet_buffer(b)->sw_if_index[VLIB_TX] = m->fib_index; +} + static uword snat_out2in_node_fn (vlib_main_t * vm, vlib_node_runtime_t * node, @@ -703,7 +734,10 @@ snat_out2in_node_fn (vlib_main_t * vm, proto0 = ip_proto_to_snat_proto (ip0->protocol); if (PREDICT_FALSE (proto0 == ~0)) + { + snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0); goto trace0; + } if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP)) { @@ -838,7 +872,10 @@ snat_out2in_node_fn (vlib_main_t * vm, proto1 = ip_proto_to_snat_proto (ip1->protocol); if (PREDICT_FALSE (proto1 == ~0)) + { + snat_out2in_unknown_proto(sm, b1, ip1, rx_fib_index1); goto trace1; + } if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP)) { @@ -997,7 +1034,10 @@ snat_out2in_node_fn (vlib_main_t * vm, proto0 = ip_proto_to_snat_proto (ip0->protocol); if (PREDICT_FALSE (proto0 == ~0)) + { + snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0); goto trace00; + } if (PREDICT_FALSE(ip0->ttl == 1)) { diff --git a/test/test_snat.py b/test/test_snat.py index ee689e6ac3e..e148fbae56c 100644 --- a/test/test_snat.py +++ b/test/test_snat.py @@ -9,7 +9,7 @@ from scapy.layers.inet import IP, TCP, UDP, ICMP from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest, ICMPv6EchoReply from scapy.layers.inet6 import ICMPv6DestUnreach, IPerror6 -from scapy.layers.l2 import Ether, ARP +from scapy.layers.l2 import Ether, ARP, GRE from scapy.data import IP_PROTOS from scapy.packet import bind_layers from util import ppp @@ -1835,6 +1835,54 @@ class TestSNAT(MethodHolder): capture = self.pg8.get_capture(len(pkts)) self.verify_capture_out(capture) + def test_static_unknown_proto(self): + """ 1:1 NAT translate packet with unknown protocol """ + nat_ip = "10.0.0.10" + self.snat_add_static_mapping(self.pg0.remote_ip4, nat_ip) + self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index) + self.vapi.snat_interface_add_del_feature(self.pg1.sw_if_index, + is_inside=0) + + # in2out + p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_ip4, dst=self.pg1.remote_ip4) / + GRE() / + IP(src=self.pg2.remote_ip4, dst=self.pg2.remote_ip4) / + TCP(sport=1234, dport=1234)) + self.pg0.add_stream(p) + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + p = self.pg1.get_capture(1) + packet = p[0] + try: + self.assertEqual(packet[IP].src, nat_ip) + self.assertEqual(packet[IP].dst, self.pg1.remote_ip4) + self.assertTrue(packet.haslayer(GRE)) + self.check_ip_checksum(packet) + except: + self.logger.error(ppp("Unexpected or invalid packet:", packet)) + raise + + # out2in + p = (Ether(dst=self.pg1.local_mac, src=self.pg1.remote_mac) / + IP(src=self.pg1.remote_ip4, dst=nat_ip) / + GRE() / + IP(src=self.pg2.remote_ip4, dst=self.pg2.remote_ip4) / + TCP(sport=1234, dport=1234)) + self.pg1.add_stream(p) + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + p = self.pg0.get_capture(1) + packet = p[0] + try: + self.assertEqual(packet[IP].src, self.pg1.remote_ip4) + self.assertEqual(packet[IP].dst, self.pg0.remote_ip4) + self.assertTrue(packet.haslayer(GRE)) + self.check_ip_checksum(packet) + except: + self.logger.error(ppp("Unexpected or invalid packet:", packet)) + raise + def tearDown(self): super(TestSNAT, self).tearDown() if not self.vpp_dead: |