diff options
Diffstat (limited to 'src/plugins/wireguard/test/test_wireguard.py')
-rwxr-xr-x | src/plugins/wireguard/test/test_wireguard.py | 292 |
1 files changed, 292 insertions, 0 deletions
diff --git a/src/plugins/wireguard/test/test_wireguard.py b/src/plugins/wireguard/test/test_wireguard.py new file mode 100755 index 00000000000..cd124f3e246 --- /dev/null +++ b/src/plugins/wireguard/test/test_wireguard.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +""" Wg tests """ + +from scapy.packet import Packet +from scapy.packet import Raw +from scapy.layers.l2 import Ether +from scapy.layers.inet import IP, UDP +from scapy.contrib.wireguard import Wireguard, WireguardResponse, \ + WireguardInitiation +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey +from cryptography.hazmat.primitives.serialization import Encoding, \ + PrivateFormat, PublicFormat, NoEncryption + +from vpp_ipip_tun_interface import VppIpIpTunInterface +from vpp_interface import VppInterface +from vpp_object import VppObject +from framework import VppTestCase +from re import compile +import unittest + +""" TestWg is a subclass of VPPTestCase classes. + +Wg test. + +""" + + +class VppWgInterface(VppInterface): + """ + VPP WireGuard interface + """ + + def __init__(self, test, src, port, key=None): + super(VppWgInterface, self).__init__(test) + + self.key = key + if not self.key: + self.generate = True + else: + self.generate = False + self.port = port + self.src = src + + def add_vpp_config(self): + r = self.test.vapi.wireguard_interface_create(interface={ + 'user_instance': 0xffffffff, + 'port': self.port, + 'src_ip': self.src, + 'private_key': self.key_bytes() + }) + self.set_sw_if_index(r.sw_if_index) + self.test.registry.register(self, self.test.logger) + return self + + def key_bytes(self): + if self.key: + return self.key.private_bytes(Encoding.Raw, + PrivateFormat.Raw, + NoEncryption()) + else: + return bytearray(32) + + def remove_vpp_config(self): + self.test.vapi.wireguard_interface_delete( + sw_if_index=self._sw_if_index) + + def query_vpp_config(self): + ts = self.test.vapi.wireguard_interface_dump(sw_if_index=0xffffffff) + for t in ts: + if t.interface.sw_if_index == self._sw_if_index and \ + str(t.interface.src_ip) == self.src and \ + t.interface.port == self.port and \ + t.interface.private_key == self.key_bytes(): + return True + return False + + def __str__(self): + return self.object_id() + + def object_id(self): + return "wireguard-%d" % self._sw_if_index + + +def find_route(test, prefix, table_id=0): + routes = test.vapi.ip_route_dump(table_id, False) + + for e in routes: + if table_id == e.route.table_id \ + and str(e.route.prefix) == str(prefix): + return True + return False + + +class VppWgPeer(VppObject): + + def __init__(self, + test, + itf, + endpoint, + port, + allowed_ips, + persistent_keepalive=15): + self._test = test + self.itf = itf + self.endpoint = endpoint + self.port = port + self.allowed_ips = allowed_ips + self.persistent_keepalive = persistent_keepalive + self.private_key = X25519PrivateKey.generate() + self.public_key = self.private_key.public_key() + self.hash = bytearray(16) + + def validate_routing(self): + for a in self.allowed_ips: + self._test.assertTrue(find_route(self._test, a)) + + def validate_no_routing(self): + for a in self.allowed_ips: + self._test.assertFalse(find_route(self._test, a)) + + def add_vpp_config(self): + rv = self._test.vapi.wireguard_peer_add( + peer={ + 'public_key': self.public_key_bytes(), + 'port': self.port, + 'endpoint': self.endpoint, + 'n_allowed_ips': len(self.allowed_ips), + 'allowed_ips': self.allowed_ips, + 'sw_if_index': self.itf.sw_if_index, + 'persistent_keepalive': self.persistent_keepalive}) + self.index = rv.peer_index + self._test.registry.register(self, self._test.logger) + self.validate_routing() + return self + + def remove_vpp_config(self): + self._test.vapi.wireguard_peer_remove(peer_index=self.index) + self.validate_no_routing() + + def object_id(self): + return ("wireguard-peer-%s" % self.index) + + def public_key_bytes(self): + return self.public_key.public_bytes(Encoding.Raw, + PublicFormat.Raw) + + def private_key_bytes(self): + return self.private_key.private_bytes(Encoding.Raw, + PrivateFormat.Raw, + NoEncryption()) + + def query_vpp_config(self): + peers = self._test.vapi.wireguard_peers_dump() + + for p in peers: + if p.peer.public_key == self.public_key_bytes() and \ + p.peer.port == self.port and \ + str(p.peer.endpoint) == self.endpoint and \ + p.peer.sw_if_index == self.itf.sw_if_index and \ + len(self.allowed_ips) == p.peer.n_allowed_ips: + self.allowed_ips.sort() + p.peer.allowed_ips.sort() + + for (a1, a2) in zip(self.allowed_ips, p.peer.allowed_ips): + if str(a1) != str(a2): + return False + return True + return False + + +class TestWg(VppTestCase): + """ Wireguard Test Case """ + + error_str = compile(r"Error") + + @classmethod + def setUpClass(cls): + super(TestWg, cls).setUpClass() + try: + cls.create_pg_interfaces(range(3)) + for i in cls.pg_interfaces: + i.admin_up() + i.config_ip4() + i.resolve_arp() + + except Exception: + super(TestWg, cls).tearDownClass() + raise + + @classmethod + def tearDownClass(cls): + super(TestWg, cls).tearDownClass() + + def test_wg_interface(self): + port = 12312 + + # Create interface + wg0 = VppWgInterface(self, + self.pg1.local_ip4, + port).add_vpp_config() + + self.logger.info(self.vapi.cli("sh int")) + + # delete interface + wg0.remove_vpp_config() + + def test_wg_peer(self): + wg_output_node_name = '/err/wg-output-tun/' + wg_input_node_name = '/err/wg-input/' + + port = 12323 + + # Create interfaces + wg0 = VppWgInterface(self, + self.pg1.local_ip4, + port, + key=X25519PrivateKey.generate()).add_vpp_config() + wg1 = VppWgInterface(self, + self.pg2.local_ip4, + port+1).add_vpp_config() + wg0.admin_up() + wg1.admin_up() + + # Check peer counter + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0) + + self.pg_enable_capture(self.pg_interfaces) + self.pg_start() + + peer_1 = VppWgPeer(self, + wg0, + self.pg1.remote_ip4, + port+1, + ["10.11.2.0/24", + "10.11.3.0/24"]).add_vpp_config() + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) + + # wait for the peer to send a handshake + capture = self.pg1.get_capture(1, timeout=2) + handshake = capture[0] + + self.assertEqual(handshake[IP].src, wg0.src) + self.assertEqual(handshake[IP].dst, peer_1.endpoint) + self.assertEqual(handshake[UDP].sport, wg0.port) + self.assertEqual(handshake[UDP].dport, peer_1.port) + handshake = Wireguard(handshake[Raw]) + self.assertEqual(handshake.message_type, 1) # "initiate") + init = handshake[WireguardInitiation] + + # route a packet into the wg interface + # use the allowed-ip prefix + p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) / + IP(src=self.pg0.remote_ip4, dst="10.11.3.2") / + UDP(sport=555, dport=556) / + Raw()) + # rx = self.send_and_expect(self.pg0, [p], self.pg1) + rx = self.send_and_assert_no_replies(self.pg0, [p]) + + self.logger.info(self.vapi.cli("sh error")) + init_sent = wg_output_node_name + "Keypair error" + self.assertEqual(1, self.statistics.get_err_counter(init_sent)) + + # Create many peers on sencond interface + NUM_PEERS = 16 + self.pg2.generate_remote_hosts(NUM_PEERS) + self.pg2.configure_ipv4_neighbors() + + peers = [] + for i in range(NUM_PEERS): + peers.append(VppWgPeer(self, + wg1, + self.pg2.remote_hosts[i].ip4, + port+1+i, + ["10.10.%d.4/32" % i]).add_vpp_config()) + self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2) + + self.logger.info(self.vapi.cli("show wireguard peer")) + self.logger.info(self.vapi.cli("show wireguard interface")) + self.logger.info(self.vapi.cli("show adj 37")) + self.logger.info(self.vapi.cli("sh ip fib 172.16.3.17")) + self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0")) + + # remove peers + for p in peers: + self.assertTrue(p.query_vpp_config()) + p.remove_vpp_config() + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1) + peer_1.remove_vpp_config() + self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0) + + wg0.remove_vpp_config() + # wg1.remove_vpp_config() |