diff options
-rw-r--r-- | src/plugins/nat/in2out_ed.c | 60 | ||||
-rw-r--r-- | src/plugins/nat/nat.c | 8 | ||||
-rw-r--r-- | src/plugins/nat/nat_inlines.h | 15 | ||||
-rw-r--r-- | src/plugins/nat/test/test_nat.py | 114 |
4 files changed, 110 insertions, 87 deletions
diff --git a/src/plugins/nat/in2out_ed.c b/src/plugins/nat/in2out_ed.c index 49e3812441e..19b128864f0 100644 --- a/src/plugins/nat/in2out_ed.c +++ b/src/plugins/nat/in2out_ed.c @@ -191,27 +191,18 @@ icmp_in2out_ed_slow_path (snat_main_t * sm, vlib_buffer_t * b0, return next0; } -static_always_inline u16 -snat_random_port (u16 min, u16 max) -{ - snat_main_t *sm = &snat_main; - return min + random_u32 (&sm->random_seed) / - (random_u32_max () / (max - min + 1) + 1); -} - static int nat_ed_alloc_addr_and_port (snat_main_t * sm, u32 rx_fib_index, u32 nat_proto, u32 thread_index, ip4_address_t r_addr, u16 r_port, u8 proto, u16 port_per_thread, u32 snat_thread_index, snat_session_t * s, - ip4_address_t * allocated_addr, - u16 * allocated_port, + ip4_address_t * outside_addr, + u16 * outside_port, clib_bihash_kv_16_8_t * out2in_ed_kv) { int i; snat_address_t *a, *ga = 0; - u32 portnum; snat_main_per_thread_data_t *tsm = &sm->per_thread_data[thread_index]; const u16 port_thread_offset = (port_per_thread * snat_thread_index) + 1024; @@ -225,29 +216,39 @@ nat_ed_alloc_addr_and_port (snat_main_t * sm, u32 rx_fib_index, case NAT_PROTOCOL_##N: \ if (a->fib_index == rx_fib_index) \ { \ - u16 port = snat_random_port (1, port_per_thread); \ + /* first try port suggested by caller */ \ + u16 port = clib_net_to_host_u16 (*outside_port); \ + u16 port_offset = port - port_thread_offset; \ + if (port <= port_thread_offset || \ + port > port_thread_offset + port_per_thread) \ + { \ + /* need to pick a different port, suggested port doesn't fit in \ + * this thread's port range */ \ + port_offset = snat_random_port (1, port_per_thread); \ + port = port_thread_offset + port_offset; \ + } \ u16 attempts = port_per_thread; \ - while (attempts > 0) \ + do \ { \ - --attempts; \ - portnum = port_thread_offset + port; \ - init_ed_kv (out2in_ed_kv, a->addr, \ - clib_host_to_net_u16 (portnum), r_addr, r_port, \ - s->out2in.fib_index, proto, thread_index, \ - s - tsm->sessions); \ + init_ed_kv (out2in_ed_kv, a->addr, clib_host_to_net_u16 (port), \ + r_addr, r_port, s->out2in.fib_index, proto, \ + thread_index, s - tsm->sessions); \ int rv = clib_bihash_add_del_16_8 (&sm->out2in_ed, out2in_ed_kv, \ 2 /* is_add */); \ if (0 == rv) \ { \ - ++a->busy_##n##_port_refcounts[portnum]; \ + ++a->busy_##n##_port_refcounts[port]; \ a->busy_##n##_ports_per_thread[thread_index]++; \ a->busy_##n##_ports++; \ - *allocated_addr = a->addr; \ - *allocated_port = clib_host_to_net_u16 (portnum); \ + *outside_addr = a->addr; \ + *outside_port = clib_host_to_net_u16 (port); \ return 0; \ } \ - port = (port + 1) % port_per_thread; \ + port_offset = (port_offset + 1) % port_per_thread; \ + port = port_thread_offset + port_offset; \ + --attempts; \ } \ + while (attempts > 0); \ } \ else if (a->fib_index == ~0) \ { \ @@ -326,8 +327,8 @@ slow_path_ed (snat_main_t * sm, snat_main_per_thread_data_t *tsm = &sm->per_thread_data[thread_index]; clib_bihash_kv_16_8_t out2in_ed_kv; nat44_is_idle_session_ctx_t ctx; - ip4_address_t allocated_addr; - u16 allocated_port; + ip4_address_t outside_addr; + u16 outside_port; u8 identity_nat; u32 nat_proto = ip_proto_to_nat_proto (proto); @@ -393,20 +394,21 @@ slow_path_ed (snat_main_t * sm, } /* Try to create dynamic translation */ + outside_port = l_port; // suggest using local port to allocation function if (nat_ed_alloc_addr_and_port (sm, rx_fib_index, nat_proto, thread_index, r_addr, r_port, proto, sm->port_per_thread, tsm->snat_thread_index, s, - &allocated_addr, - &allocated_port, &out2in_ed_kv)) + &outside_addr, + &outside_port, &out2in_ed_kv)) { nat_elog_notice ("addresses exhausted"); b->error = node->errors[NAT_IN2OUT_ED_ERROR_OUT_OF_PORTS]; nat_ed_session_delete (sm, s, thread_index, 1); return NAT_NEXT_DROP; } - s->out2in.addr = allocated_addr; - s->out2in.port = allocated_port; + s->out2in.addr = outside_addr; + s->out2in.port = outside_port; } else { diff --git a/src/plugins/nat/nat.c b/src/plugins/nat/nat.c index e4fed18371e..60ef22f05cc 100644 --- a/src/plugins/nat/nat.c +++ b/src/plugins/nat/nat.c @@ -2851,14 +2851,6 @@ end: return 0; } -static_always_inline u16 -snat_random_port (u16 min, u16 max) -{ - snat_main_t *sm = &snat_main; - return min + random_u32 (&sm->random_seed) / - (random_u32_max () / (max - min + 1) + 1); -} - int snat_alloc_outside_address_and_port (snat_address_t * addresses, u32 fib_index, diff --git a/src/plugins/nat/nat_inlines.h b/src/plugins/nat/nat_inlines.h index 40ac3d307a8..67411750c95 100644 --- a/src/plugins/nat/nat_inlines.h +++ b/src/plugins/nat/nat_inlines.h @@ -813,6 +813,21 @@ increment_v4_address (ip4_address_t * a) a->as_u32 = clib_host_to_net_u32 (v); } +static_always_inline u16 +snat_random_port (u16 min, u16 max) +{ + snat_main_t *sm = &snat_main; + u32 rwide; + u16 r; + + rwide = random_u32 (&sm->random_seed); + r = rwide & 0xFFFF; + if (r >= min && r <= max) + return r; + + return min + (rwide % (max - min + 1)); +} + #endif /* __included_nat_inlines_h__ */ /* diff --git a/src/plugins/nat/test/test_nat.py b/src/plugins/nat/test/test_nat.py index 6dee818d4bb..e996373342a 100644 --- a/src/plugins/nat/test/test_nat.py +++ b/src/plugins/nat/test/test_nat.py @@ -448,7 +448,7 @@ class MethodHolder(VppTestCase): return pkts def verify_capture_out(self, capture, nat_ip=None, same_port=False, - dst_ip=None, is_ip6=False): + dst_ip=None, is_ip6=False, ignore_port=False): """ Verify captured packets on outside network @@ -474,25 +474,32 @@ class MethodHolder(VppTestCase): if dst_ip is not None: self.assertEqual(packet[IP46].dst, dst_ip) if packet.haslayer(TCP): - if same_port: - self.assertEqual(packet[TCP].sport, self.tcp_port_in) - else: - self.assertNotEqual( - packet[TCP].sport, self.tcp_port_in) + if not ignore_port: + if same_port: + self.assertEqual( + packet[TCP].sport, self.tcp_port_in) + else: + self.assertNotEqual( + packet[TCP].sport, self.tcp_port_in) self.tcp_port_out = packet[TCP].sport self.assert_packet_checksums_valid(packet) elif packet.haslayer(UDP): - if same_port: - self.assertEqual(packet[UDP].sport, self.udp_port_in) - else: - self.assertNotEqual( - packet[UDP].sport, self.udp_port_in) + if not ignore_port: + if same_port: + self.assertEqual( + packet[UDP].sport, self.udp_port_in) + else: + self.assertNotEqual( + packet[UDP].sport, self.udp_port_in) self.udp_port_out = packet[UDP].sport else: - if same_port: - self.assertEqual(packet[ICMP46].id, self.icmp_id_in) - else: - self.assertNotEqual(packet[ICMP46].id, self.icmp_id_in) + if not ignore_port: + if same_port: + self.assertEqual( + packet[ICMP46].id, self.icmp_id_in) + else: + self.assertNotEqual( + packet[ICMP46].id, self.icmp_id_in) self.icmp_id_out = packet[ICMP46].id self.assert_packet_checksums_valid(packet) except: @@ -1105,7 +1112,8 @@ class MethodHolder(VppTestCase): else: raise Exception("Unsupported protocol") - def frag_in_order(self, proto=IP_PROTOS.tcp, dont_translate=False): + def frag_in_order(self, proto=IP_PROTOS.tcp, dont_translate=False, + ignore_port=False): layer = self.proto2layer(proto) if proto == IP_PROTOS.tcp: @@ -1132,14 +1140,16 @@ class MethodHolder(VppTestCase): if proto != IP_PROTOS.icmp: if not dont_translate: self.assertEqual(p[layer].dport, 20) - self.assertNotEqual(p[layer].sport, self.port_in) + if not ignore_port: + self.assertNotEqual(p[layer].sport, self.port_in) else: self.assertEqual(p[layer].sport, self.port_in) else: - if not dont_translate: - self.assertNotEqual(p[layer].id, self.port_in) - else: - self.assertEqual(p[layer].id, self.port_in) + if not ignore_port: + if not dont_translate: + self.assertNotEqual(p[layer].id, self.port_in) + else: + self.assertEqual(p[layer].id, self.port_in) self.assertEqual(data, p[Raw].load) # out2in @@ -1220,7 +1230,7 @@ class MethodHolder(VppTestCase): self.assertEqual(p[layer].id, self.port_in) self.assertEqual(data, p[Raw].load) - def reass_hairpinning(self, proto=IP_PROTOS.tcp): + def reass_hairpinning(self, proto=IP_PROTOS.tcp, ignore_port=False): layer = self.proto2layer(proto) if proto == IP_PROTOS.tcp: @@ -1243,13 +1253,16 @@ class MethodHolder(VppTestCase): self.nat_addr, self.server.ip4) if proto != IP_PROTOS.icmp: - self.assertNotEqual(p[layer].sport, self.host_in_port) + if not ignore_port: + self.assertNotEqual(p[layer].sport, self.host_in_port) self.assertEqual(p[layer].dport, self.server_in_port) else: - self.assertNotEqual(p[layer].id, self.host_in_port) + if not ignore_port: + self.assertNotEqual(p[layer].id, self.host_in_port) self.assertEqual(data, p[Raw].load) - def frag_out_of_order(self, proto=IP_PROTOS.tcp, dont_translate=False): + def frag_out_of_order(self, proto=IP_PROTOS.tcp, dont_translate=False, + ignore_port=False): layer = self.proto2layer(proto) if proto == IP_PROTOS.tcp: @@ -1278,14 +1291,16 @@ class MethodHolder(VppTestCase): if proto != IP_PROTOS.icmp: if not dont_translate: self.assertEqual(p[layer].dport, 20) - self.assertNotEqual(p[layer].sport, self.port_in) + if not ignore_port: + self.assertNotEqual(p[layer].sport, self.port_in) else: self.assertEqual(p[layer].sport, self.port_in) else: - if not dont_translate: - self.assertNotEqual(p[layer].id, self.port_in) - else: - self.assertEqual(p[layer].id, self.port_in) + if not ignore_port: + if not dont_translate: + self.assertNotEqual(p[layer].id, self.port_in) + else: + self.assertEqual(p[layer].id, self.port_in) self.assertEqual(data, p[Raw].load) # out2in @@ -4437,9 +4452,9 @@ class TestNAT44EndpointDependent(MethodHolder): self.vapi.nat44_interface_add_del_feature( sw_if_index=self.pg1.sw_if_index, is_add=1) - self.frag_in_order(proto=IP_PROTOS.tcp) - self.frag_in_order(proto=IP_PROTOS.udp) - self.frag_in_order(proto=IP_PROTOS.icmp) + self.frag_in_order(proto=IP_PROTOS.tcp, ignore_port=True) + self.frag_in_order(proto=IP_PROTOS.udp, ignore_port=True) + self.frag_in_order(proto=IP_PROTOS.icmp, ignore_port=True) def test_frag_in_order_dont_translate(self): """ NAT44 don't translate fragments arriving in order """ @@ -4463,9 +4478,9 @@ class TestNAT44EndpointDependent(MethodHolder): self.vapi.nat44_interface_add_del_feature( sw_if_index=self.pg1.sw_if_index, is_add=1) - self.frag_out_of_order(proto=IP_PROTOS.tcp) - self.frag_out_of_order(proto=IP_PROTOS.udp) - self.frag_out_of_order(proto=IP_PROTOS.icmp) + self.frag_out_of_order(proto=IP_PROTOS.tcp, ignore_port=True) + self.frag_out_of_order(proto=IP_PROTOS.udp, ignore_port=True) + self.frag_out_of_order(proto=IP_PROTOS.icmp, ignore_port=True) def test_frag_out_of_order_dont_translate(self): """ NAT44 don't translate fragments arriving out of order """ @@ -4593,9 +4608,9 @@ class TestNAT44EndpointDependent(MethodHolder): proto=IP_PROTOS.udp) self.nat44_add_static_mapping(self.server.ip4, self.nat_addr) - self.reass_hairpinning(proto=IP_PROTOS.tcp) - self.reass_hairpinning(proto=IP_PROTOS.udp) - self.reass_hairpinning(proto=IP_PROTOS.icmp) + self.reass_hairpinning(proto=IP_PROTOS.tcp, ignore_port=True) + self.reass_hairpinning(proto=IP_PROTOS.udp, ignore_port=True) + self.reass_hairpinning(proto=IP_PROTOS.icmp, ignore_port=True) def test_clear_sessions(self): """ NAT44 ED session clearing test """ @@ -4617,7 +4632,7 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) sessions = self.statistics.get_counter('/nat44/total-sessions') self.assertTrue(sessions[0][0] > 0) @@ -4664,7 +4679,7 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) err = self.statistics.get_err_counter( '/err/nat44-ed-in2out-slowpath/TCP packets') @@ -4752,7 +4767,7 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) err_new = self.statistics.get_err_counter( '/err/nat44-ed-in2out-slowpath/out of ports') @@ -4806,7 +4821,7 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg8.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) err = self.statistics.get_err_counter( '/err/nat44-ed-in2out-slowpath/TCP packets') @@ -5555,13 +5570,13 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) pkts = self.create_stream_in(self.pg0, self.pg1) self.pg0.add_stream(pkts) self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) # from external network back to local network host pkts = self.create_stream_out(self.pg1) @@ -5585,7 +5600,7 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) pkts = self.create_stream_out(self.pg1) self.pg1.add_stream(pkts) @@ -6595,7 +6610,7 @@ class TestNAT44EndpointDependent(MethodHolder): self.pg_enable_capture(self.pg_interfaces) self.pg_start() capture = self.pg1.get_capture(len(pkts)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) # out2in pkts = self.create_stream_out(self.pg1) @@ -6627,7 +6642,7 @@ class TestNAT44EndpointDependent(MethodHolder): pkts_in2out = self.create_stream_in(self.pg0, self.pg1) capture = self.send_and_expect(self.pg0, pkts_in2out, self.pg1, len(pkts_in2out)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) # send out2in again, with sessions created it should work now pkts_out2in = self.create_stream_out(self.pg1) @@ -6657,7 +6672,7 @@ class TestNAT44EndpointDependent(MethodHolder): # send in2out to generate ACL state (NAT state was created earlier) capture = self.send_and_expect(self.pg0, pkts_in2out, self.pg1, len(pkts_in2out)) - self.verify_capture_out(capture) + self.verify_capture_out(capture, ignore_port=True) # send out2in again. ACL state exists so it should work now. # TCP packets with the syn flag set also need the ack flag @@ -6762,7 +6777,6 @@ class TestNAT44EndpointDependent(MethodHolder): ip = p[IP] tcp = p[TCP] self.assertEqual(ip.src, self.nat_addr) - self.assertNotEqual(tcp.sport, 2345) self.assert_packet_checksums_valid(p) port = tcp.sport except: |