aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/wireguard/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/wireguard/test')
-rwxr-xr-xsrc/plugins/wireguard/test/test_wireguard.py292
1 files changed, 292 insertions, 0 deletions
diff --git a/src/plugins/wireguard/test/test_wireguard.py b/src/plugins/wireguard/test/test_wireguard.py
new file mode 100755
index 00000000000..cd124f3e246
--- /dev/null
+++ b/src/plugins/wireguard/test/test_wireguard.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+""" Wg tests """
+
+from scapy.packet import Packet
+from scapy.packet import Raw
+from scapy.layers.l2 import Ether
+from scapy.layers.inet import IP, UDP
+from scapy.contrib.wireguard import Wireguard, WireguardResponse, \
+ WireguardInitiation
+from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
+from cryptography.hazmat.primitives.serialization import Encoding, \
+ PrivateFormat, PublicFormat, NoEncryption
+
+from vpp_ipip_tun_interface import VppIpIpTunInterface
+from vpp_interface import VppInterface
+from vpp_object import VppObject
+from framework import VppTestCase
+from re import compile
+import unittest
+
+""" TestWg is a subclass of VPPTestCase classes.
+
+Wg test.
+
+"""
+
+
+class VppWgInterface(VppInterface):
+ """
+ VPP WireGuard interface
+ """
+
+ def __init__(self, test, src, port, key=None):
+ super(VppWgInterface, self).__init__(test)
+
+ self.key = key
+ if not self.key:
+ self.generate = True
+ else:
+ self.generate = False
+ self.port = port
+ self.src = src
+
+ def add_vpp_config(self):
+ r = self.test.vapi.wireguard_interface_create(interface={
+ 'user_instance': 0xffffffff,
+ 'port': self.port,
+ 'src_ip': self.src,
+ 'private_key': self.key_bytes()
+ })
+ self.set_sw_if_index(r.sw_if_index)
+ self.test.registry.register(self, self.test.logger)
+ return self
+
+ def key_bytes(self):
+ if self.key:
+ return self.key.private_bytes(Encoding.Raw,
+ PrivateFormat.Raw,
+ NoEncryption())
+ else:
+ return bytearray(32)
+
+ def remove_vpp_config(self):
+ self.test.vapi.wireguard_interface_delete(
+ sw_if_index=self._sw_if_index)
+
+ def query_vpp_config(self):
+ ts = self.test.vapi.wireguard_interface_dump(sw_if_index=0xffffffff)
+ for t in ts:
+ if t.interface.sw_if_index == self._sw_if_index and \
+ str(t.interface.src_ip) == self.src and \
+ t.interface.port == self.port and \
+ t.interface.private_key == self.key_bytes():
+ return True
+ return False
+
+ def __str__(self):
+ return self.object_id()
+
+ def object_id(self):
+ return "wireguard-%d" % self._sw_if_index
+
+
+def find_route(test, prefix, table_id=0):
+ routes = test.vapi.ip_route_dump(table_id, False)
+
+ for e in routes:
+ if table_id == e.route.table_id \
+ and str(e.route.prefix) == str(prefix):
+ return True
+ return False
+
+
+class VppWgPeer(VppObject):
+
+ def __init__(self,
+ test,
+ itf,
+ endpoint,
+ port,
+ allowed_ips,
+ persistent_keepalive=15):
+ self._test = test
+ self.itf = itf
+ self.endpoint = endpoint
+ self.port = port
+ self.allowed_ips = allowed_ips
+ self.persistent_keepalive = persistent_keepalive
+ self.private_key = X25519PrivateKey.generate()
+ self.public_key = self.private_key.public_key()
+ self.hash = bytearray(16)
+
+ def validate_routing(self):
+ for a in self.allowed_ips:
+ self._test.assertTrue(find_route(self._test, a))
+
+ def validate_no_routing(self):
+ for a in self.allowed_ips:
+ self._test.assertFalse(find_route(self._test, a))
+
+ def add_vpp_config(self):
+ rv = self._test.vapi.wireguard_peer_add(
+ peer={
+ 'public_key': self.public_key_bytes(),
+ 'port': self.port,
+ 'endpoint': self.endpoint,
+ 'n_allowed_ips': len(self.allowed_ips),
+ 'allowed_ips': self.allowed_ips,
+ 'sw_if_index': self.itf.sw_if_index,
+ 'persistent_keepalive': self.persistent_keepalive})
+ self.index = rv.peer_index
+ self._test.registry.register(self, self._test.logger)
+ self.validate_routing()
+ return self
+
+ def remove_vpp_config(self):
+ self._test.vapi.wireguard_peer_remove(peer_index=self.index)
+ self.validate_no_routing()
+
+ def object_id(self):
+ return ("wireguard-peer-%s" % self.index)
+
+ def public_key_bytes(self):
+ return self.public_key.public_bytes(Encoding.Raw,
+ PublicFormat.Raw)
+
+ def private_key_bytes(self):
+ return self.private_key.private_bytes(Encoding.Raw,
+ PrivateFormat.Raw,
+ NoEncryption())
+
+ def query_vpp_config(self):
+ peers = self._test.vapi.wireguard_peers_dump()
+
+ for p in peers:
+ if p.peer.public_key == self.public_key_bytes() and \
+ p.peer.port == self.port and \
+ str(p.peer.endpoint) == self.endpoint and \
+ p.peer.sw_if_index == self.itf.sw_if_index and \
+ len(self.allowed_ips) == p.peer.n_allowed_ips:
+ self.allowed_ips.sort()
+ p.peer.allowed_ips.sort()
+
+ for (a1, a2) in zip(self.allowed_ips, p.peer.allowed_ips):
+ if str(a1) != str(a2):
+ return False
+ return True
+ return False
+
+
+class TestWg(VppTestCase):
+ """ Wireguard Test Case """
+
+ error_str = compile(r"Error")
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestWg, cls).setUpClass()
+ try:
+ cls.create_pg_interfaces(range(3))
+ for i in cls.pg_interfaces:
+ i.admin_up()
+ i.config_ip4()
+ i.resolve_arp()
+
+ except Exception:
+ super(TestWg, cls).tearDownClass()
+ raise
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestWg, cls).tearDownClass()
+
+ def test_wg_interface(self):
+ port = 12312
+
+ # Create interface
+ wg0 = VppWgInterface(self,
+ self.pg1.local_ip4,
+ port).add_vpp_config()
+
+ self.logger.info(self.vapi.cli("sh int"))
+
+ # delete interface
+ wg0.remove_vpp_config()
+
+ def test_wg_peer(self):
+ wg_output_node_name = '/err/wg-output-tun/'
+ wg_input_node_name = '/err/wg-input/'
+
+ port = 12323
+
+ # Create interfaces
+ wg0 = VppWgInterface(self,
+ self.pg1.local_ip4,
+ port,
+ key=X25519PrivateKey.generate()).add_vpp_config()
+ wg1 = VppWgInterface(self,
+ self.pg2.local_ip4,
+ port+1).add_vpp_config()
+ wg0.admin_up()
+ wg1.admin_up()
+
+ # Check peer counter
+ self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
+
+ self.pg_enable_capture(self.pg_interfaces)
+ self.pg_start()
+
+ peer_1 = VppWgPeer(self,
+ wg0,
+ self.pg1.remote_ip4,
+ port+1,
+ ["10.11.2.0/24",
+ "10.11.3.0/24"]).add_vpp_config()
+ self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+ # wait for the peer to send a handshake
+ capture = self.pg1.get_capture(1, timeout=2)
+ handshake = capture[0]
+
+ self.assertEqual(handshake[IP].src, wg0.src)
+ self.assertEqual(handshake[IP].dst, peer_1.endpoint)
+ self.assertEqual(handshake[UDP].sport, wg0.port)
+ self.assertEqual(handshake[UDP].dport, peer_1.port)
+ handshake = Wireguard(handshake[Raw])
+ self.assertEqual(handshake.message_type, 1) # "initiate")
+ init = handshake[WireguardInitiation]
+
+ # route a packet into the wg interface
+ # use the allowed-ip prefix
+ p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+ IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
+ UDP(sport=555, dport=556) /
+ Raw())
+ # rx = self.send_and_expect(self.pg0, [p], self.pg1)
+ rx = self.send_and_assert_no_replies(self.pg0, [p])
+
+ self.logger.info(self.vapi.cli("sh error"))
+ init_sent = wg_output_node_name + "Keypair error"
+ self.assertEqual(1, self.statistics.get_err_counter(init_sent))
+
+ # Create many peers on sencond interface
+ NUM_PEERS = 16
+ self.pg2.generate_remote_hosts(NUM_PEERS)
+ self.pg2.configure_ipv4_neighbors()
+
+ peers = []
+ for i in range(NUM_PEERS):
+ peers.append(VppWgPeer(self,
+ wg1,
+ self.pg2.remote_hosts[i].ip4,
+ port+1+i,
+ ["10.10.%d.4/32" % i]).add_vpp_config())
+ self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2)
+
+ self.logger.info(self.vapi.cli("show wireguard peer"))
+ self.logger.info(self.vapi.cli("show wireguard interface"))
+ self.logger.info(self.vapi.cli("show adj 37"))
+ self.logger.info(self.vapi.cli("sh ip fib 172.16.3.17"))
+ self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0"))
+
+ # remove peers
+ for p in peers:
+ self.assertTrue(p.query_vpp_config())
+ p.remove_vpp_config()
+ self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+ peer_1.remove_vpp_config()
+ self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
+
+ wg0.remove_vpp_config()
+ # wg1.remove_vpp_config()