#!/usr/bin/env python

import unittest
import socket
from logging import *

from framework import VppTestCase, VppTestRunner
from vpp_sub_interface import VppSubInterface, VppDot1QSubint, VppDot1ADSubint
from vpp_gre_interface import VppGreInterface
from vpp_ip_route import IpRoute, RoutePath
from vpp_papi_provider import L2_VTR_OP

from scapy.packet import Raw
from scapy.layers.l2 import Ether, Dot1Q, ARP, GRE
from scapy.layers.inet import IP, UDP
from scapy.layers.inet6 import ICMPv6ND_NS, ICMPv6ND_RA, IPv6, UDP
from scapy.contrib.mpls import MPLS
from scapy.volatile import RandMAC, RandIP


class TestGRE(VppTestCase):
    """ GRE Test Case """

    @classmethod
    def setUpClass(cls):
        """Class-level setup: nothing GRE-specific, defer to the base class."""
        super(TestGRE, cls).setUpClass()

    def setUp(self):
        """Per-test setup: create and bring up two pg interfaces.

        pg1 is bound to IPv4 table 1 before any address is configured; the
        other interface is left as created.  Each interface is then set
        admin-up, given its IPv4 configuration and has ARP resolved.
        """
        super(TestGRE, self).setUp()

        # two pg interfaces, one of which lives in a non-default table
        self.create_pg_interfaces(range(2))
        self.pg1.set_table_ip4(1)

        for pg_if in self.pg_interfaces:
            pg_if.admin_up()
            pg_if.config_ip4()
            pg_if.resolve_arp()

    def tearDown(self):
        """Per-test teardown: nothing GRE-specific, defer to the base class."""
        super(TestGRE, self).tearDown()

    def create_stream_ip4(self, src_if, src_ip, dst_ip):
        """Build a stream of 257 un-tunnelled IPv4/UDP packets.

        Each packet is Ether / IP / UDP / Raw, addressed at L2 from
        src_if's remote MAC to its local MAC, with the given IP source and
        destination.  The payload is derived from a fresh packet-info
        record, and a copy of the packet is stored on that record.

        :param src_if: interface the stream will be injected on
        :param src_ip: IPv4 source address
        :param dst_ip: IPv4 destination address
        :returns: list of the generated packets
        """
        stream = []
        for _ in range(0, 257):
            pkt_info = self.create_packet_info(src_if.sw_if_index,
                                               src_if.sw_if_index)
            data = self.info_to_payload(pkt_info)

            l2_hdr = Ether(dst=src_if.local_mac, src=src_if.remote_mac)
            l3_hdr = IP(src=src_ip, dst=dst_ip)
            l4_hdr = UDP(sport=1234, dport=1234)
            pkt = l2_hdr / l3_hdr / l4_hdr / Raw(data)

            pkt_info.data = pkt.copy()
            stream.append(pkt)
        return stream

    def create_tunnel_stream_4o4(self, src_if,
                                 tunnel_src, tunnel_dst,
                                 src_ip, dst_ip):
        """Build a stream of 257 IPv4-in-GRE-in-IPv4 packets.

        Layering is Ether / IP(outer) / GRE / IP(inner) / UDP / Raw.  The
        payload is derived from a fresh packet-info record, and a copy of
        the packet is stored on that record.

        :param src_if: interface the stream will be injected on
        :param tunnel_src: outer IPv4 source address
        :param tunnel_dst: outer IPv4 destination address
        :param src_ip: inner IPv4 source address
        :param dst_ip: inner IPv4 destination address
        :returns: list of the generated packets
        """
        stream = []
        for _ in range(0, 257):
            pkt_info = self.create_packet_info(src_if.sw_if_index,
                                               src_if.sw_if_index)
            data = self.info_to_payload(pkt_info)

            # outer (transport) headers
            underlay = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                        IP(src=tunnel_src, dst=tunnel_dst) /
                        GRE())
            # inner (passenger) packet
            overlay = (IP(src=src_ip, dst=dst_ip) /
                       UDP(sport=1234, dport=1234) /
                       Raw(data))
            pkt = underlay / overlay

            pkt_info.data = pkt.copy()
            stream.append(pkt)
        return stream

    def create_tunnel_stream_6o4(self, src_if,
                                 tunnel_src, tunnel_dst,
                                 src_ip, dst_ip):
        """Build a stream of 257 IPv6-in-GRE-in-IPv4 packets.

        Layering is Ether / IP(outer) / GRE / IPv6(inner) / UDP / Raw.  The
        payload is derived from a fresh packet-info record, and a copy of
        the packet is stored on that record.

        :param src_if: interface the stream will be injected on
        :param tunnel_src: outer IPv4 source address
        :param tunnel_dst: outer IPv4 destination address
        :param src_ip: inner IPv6 source address
        :param dst_ip: inner IPv6 destination address
        :returns: list of the generated packets
        """
        stream = []
        for _ in range(0, 257):
            pkt_info = self.create_packet_info(src_if.sw_if_index,
                                               src_if.sw_if_index)
            data = self.info_to_payload(pkt_info)

            # outer (transport) headers
            underlay = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                        IP(src=tunnel_src, dst=tunnel_dst) /
                        GRE())
            # inner (passenger) packet
            overlay = (IPv6(src=src_ip, dst=dst_ip) /
                       UDP(sport=1234, dport=1234) /
                       Raw(data))
            pkt = underlay / overlay

            pkt_info.data = pkt.copy()
            stream.append(pkt)
        return stream

    def create_tunnel_stream_l2o4(self, src_if,
                                  tunnel_src, tunnel_dst):
        """Build a stream of 257 Ethernet-in-GRE-in-IPv4 packets.

        Layering is Ether / IP(outer) / GRE / Ether(inner) / IP / UDP / Raw.
        The inner Ethernet addresses are random MACs and the inner IP
        addresses are random IPs.  The payload is derived from a fresh
        packet-info record, and a copy of the packet is stored on that
        record.

        :param src_if: interface the stream will be injected on
        :param tunnel_src: outer IPv4 source address
        :param tunnel_dst: outer IPv4 destination address
        :returns: list of the generated packets
        """
        stream = []
        for _ in range(0, 257):
            pkt_info = self.create_packet_info(src_if.sw_if_index,
                                               src_if.sw_if_index)
            data = self.info_to_payload(pkt_info)

            # outer (transport) headers
            underlay = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                        IP(src=tunnel_src, dst=tunnel_dst) /
                        GRE())
            # inner L2 frame with randomised addressing
            inner_l2 = Ether(dst=RandMAC('*:*:*:*:*:*'),
                             src=RandMAC('*:*:*:*:*:*'))
            inner_l3 = IP(src=str(RandIP()), dst=str(RandIP()))
            pkt = (underlay / inner_l2 / inner_l3 /
                   UDP(sport=1234, dport=1234) /
                   Raw(data))

            pkt_info.data = pkt.copy()
            stream.append(pkt)
        return stream

    def create_tunnel_stream_vlano4(self, src_if,
                                    tunnel_src, tunnel_dst, vlan):
        """Build a stream of 257 VLAN-tagged Ethernet-in-GRE-in-IPv4 packets.

        Layering is
        Ether / IP(outer) / GRE / Ether(inner) / Dot1Q / IP / UDP / Raw.
        The inner Ethernet addresses are random MACs and the inner IP
        addresses are random IPs.  The payload is derived from a fresh
        packet-info record, and a copy of the packet is stored on that
        record.

        :param src_if: interface the stream will be injected on
        :param tunnel_src: outer IPv4 source address
        :param tunnel_dst: outer IPv4 destination address
        :param vlan: VLAN id placed in the inner Dot1Q tag
        :returns: list of the generated packets
        """
        stream = []
        for _ in range(0, 257):
            pkt_info = self.create_packet_info(src_if.sw_if_index,
                                               src_if.sw_if_index)
            data = self.info_to_payload(pkt_info)

            # outer (transport) headers
            underlay = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                        IP(src=tunnel_src, dst=tunnel_dst) /
                        GRE())
            # inner tagged L2 frame with randomised addressing
            inner_l2 = Ether(dst=RandMAC('*:*:*:*:*:*'),
                             src=RandMAC('*:*:*:*:*:*'))
            inner_tag = Dot1Q(vlan=vlan)
            inner_l3 = IP(src=str(RandIP()), dst=str(RandIP()))
            pkt = (underlay / inner_l2 / inner_tag / inner_l3 /
                   UDP(sport=1234, dport=1234) /
                   Raw(data))

            pkt_info.data = pkt.copy()
            stream.append(pkt)
        return stream

    def verify_filter(self, capture, sent):
        """Strip IPv6 Router Advertisements from a capture of unexpected size.

        When the capture holds a different number of packets than were
        sent, every packet carrying an ICMPv6ND_RA layer is removed from
        ``capture`` in place.  A capture whose length already matches is
        left untouched.

        :param capture: captured packets; mutated in place when filtering
        :param sent: the packets that were sent (only its length is used)
        :returns: the same ``capture`` object
        """
        if len(capture) != len(sent):
            # filter out any IPv6 RAs from the capture.
            # Collect the RAs first and remove them afterwards: removing
            # while iterating makes the iterator skip the element that
            # follows each removal, so back-to-back RAs would survive.
            router_adverts = [p for p in capture if p.haslayer(ICMPv6ND_RA)]
            for ra in router_adverts:
                capture.remove(ra)
        return capture

    def verify_tunneled_4o4(self, src_if, capture, sent,
                            tunnel_src, tunnel_dst):
        """Verify captured packets are the GRE-encapsulated sent packets.

        After filtering the capture with verify_filter, the capture and
        sent lengths must match.  For each sent/captured pair the outer IP
        header must carry the tunnel addresses, and the IP header inside
        GRE must have the sent packet's addresses and a TTL one lower than
        sent.  On any failure both packets are displayed before the error
        propagates.

        :param src_if: unused; kept for call-site compatibility
        :param capture: packets captured
        :param sent: packets that were sent
        :param tunnel_src: expected outer IPv4 source address
        :param tunnel_dst: expected outer IPv4 destination address
        """
        capture = self.verify_filter(capture, sent)
        self.assertEqual(len(capture), len(sent))

        # tx/rx are bound by the loop itself, so the diagnostic handler
        # below can always display the pair that actually failed.
        for tx, rx in zip(sent, capture):
            try:
                tx_ip = tx[IP]
                outer_ip = rx[IP]

                self.assertEqual(outer_ip.src, tunnel_src)
                self.assertEqual(outer_ip.dst, tunnel_dst)

                inner_ip = rx[GRE][IP]

                self.assertEqual(inner_ip.src, tx_ip.src)
                self.assertEqual(inner_ip.dst, tx_ip.dst)
                # IP processing post pop has decremented the TTL
                self.assertEqual(inner_ip.ttl + 1, tx_ip.ttl)

            except Exception:
                rx.show()
                tx.show()
                raise

    def verify_tunneled_l2o4(self, src_if, capture, sent,
                             tunnel_src, tunnel_dst):
        """Verify captured packets are re-encapsulated L2-over-GRE frames.

        After filtering the capture with verify_filter, the capture and
        sent lengths must match.  For each sent/captured pair the outer IP
        header must carry the tunnel addresses, and the IP header found
        under the inner Ethernet frame must match the sent inner IP header
        in source, destination and TTL.  On any failure both packets are
        displayed before the error propagates.

        :param src_if: unused; kept for call-site compatibility
        :param capture: packets captured
        :param sent: packets that were sent (themselves GRE encapsulated)
        :param tunnel_src: expected outer IPv4 source address
        :param tunnel_dst: expected outer IPv4 destination address
        """
        capture = self.verify_filter(capture, sent)
        self.assertEqual(len(capture), len(sent))

        # tx/rx are bound by the loop itself, so the diagnostic handler
        # below can always display the pair that actually failed.
        for tx, rx in zip(sent, capture):
            try:
                outer_ip = rx[IP]

                self.assertEqual(outer_ip.src, tunnel_src)
                self.assertEqual(outer_ip.dst, tunnel_dst)

                # compare the passenger IP headers on both sides
                rx_ip = rx[GRE][Ether][IP]
                tx_ip = tx[GRE][Ether][IP]

                self.assertEqual(rx_ip.src, tx_ip.src)
                self.assertEqual(rx_ip.dst, tx_ip.dst)
                # bridged, not L3 forwarded, so no TTL decrement
                self.assertEqual(rx_ip.ttl, tx_ip.ttl)

            except Exception:
                rx.show()
                tx.show()
                raise

    def verify_tunneled_vlano4(self, src_if, capture, sent,
                               tunnel_src, tunnel_dst, vlan):
        """Verify captured packets are re-encapsulated VLAN-over-GRE frames.

        After filtering the capture with verify_filter, the capture and
        sent lengths must match; if that fails the whole capture is
        displayed.  For each sent/captured pair the outer IP header must
        carry the tunnel addresses, the inner Dot1Q tag must carry
        ``vlan``, and the inner IP header must match the sent inner IP
        header in source, destination and TTL.  On a per-packet failure
        both packets are displayed before the error propagates.

        :param src_if: unused; kept for call-site compatibility
        :param capture: packets captured
        :param sent: packets that were sent (themselves GRE encapsulated)
        :param tunnel_src: expected outer IPv4 source address
        :param tunnel_dst: expected outer IPv4 destination address
        :param vlan: VLAN id expected in the captured inner Dot1Q tag
        """
        try:
            capture = self.verify_filter(capture, sent)
            self.assertEqual(len(capture), len(sent))
        except Exception:
            capture.show()
            raise

        # tx/rx are bound by the loop itself, so the diagnostic handler
        # below can always display the pair that actually failed.
        for tx, rx in zip(sent, capture):
            try:
                outer_ip = rx[IP]

                self.assertEqual(outer_ip.src, tunnel_src)
                self.assertEqual(outer_ip.dst, tunnel_dst)

                rx_l2 = rx[GRE][Ether]
                rx_vlan = rx_l2[Dot1Q]
                rx_ip = rx_l2[IP]

                self.assertEqual(rx_vlan.vlan, vlan)

                tx_ip = tx[GRE][Ether][IP]

                self.assertEqual(rx_ip.src, tx_ip.src)
                self.assertEqual(rx_ip.dst, tx_ip.dst)
                # bridged, not L3 forwarded, so no TTL decrement
                self.assertEqual(rx_ip.ttl, tx_ip.ttl)

            except Exception:
                rx.show()
                tx.show()
                raise

    def verify_decapped_4o4(self, src_if, capture, sent):
        """Verify captured packets are the decapsulated inner IPv4 packets.

        After filtering the capture with verify_filter, the capture and
        sent lengths must match.  For each sent/captured pair the captured
        IP header must have the same addresses as the IP header the sent
        packet carried inside GRE, with a TTL one lower.  On any failure
        both packets are displayed before the error propagates.

        :param src_if: unused; kept for call-site compatibility
        :param capture: packets captured
        :param sent: GRE-encapsulated packets that were sent
        """
        capture = self.verify_filter(capture, sent)
        self.assertEqual(len(capture), len(sent))

        # tx/rx are bound by the loop itself, so the diagnostic handler
        # below can always display the pair that actually failed.
        for tx, rx in zip(sent, capture):
            try:
                rx_ip = rx[IP]
                # the passenger header, i.e. the IP layer beneath GRE
                tx_ip = tx[GRE][IP]

                self.assertEqual(rx_ip.src, tx_ip.src)
                self.assertEqual(rx_ip.dst, tx_ip.dst)
                # IP processing post pop has decremented the TTL
                self.assertEqual(rx_ip.ttl + 1, tx_ip.ttl)

            except Exception:
                rx.show()
                tx.show()
                raise

    def verify_decapped_6o4(self, src_if, capture, sent):
        """Verify captured packets are the decapsulated inner IPv6 packets.

        After filtering the capture with verify_filter, the capture and
        sent lengths must match.  For each sent/captured pair the captured
        IPv6 header must have the same addresses as the IPv6 header the
        sent packet carried inside GRE, with a hop limit one lower.  On any
        failure both packets are displayed before the error propagates.

        :param src_if: unused; kept for call-site compatibility
        :param capture: packets captured
        :param sent: GRE-encapsulated packets that were sent
        """
        capture = self.verify_filter(capture, sent)
        self.assertEqual(len(capture), len(sent))

        # tx/rx are bound by the loop itself, so the diagnostic handler
        # below can always display the pair that actually failed.
        for tx, rx in zip(sent, capture):
            try:
                rx_ip = rx[IPv6]
                # the passenger header, i.e. the IPv6 layer beneath GRE
                tx_ip = tx[GRE][IPv6]

                self.assertEqual(rx_ip.src, tx_ip.src)
                self.assertEqual(rx_ip.dst, tx_ip.dst)
                # the hop limit is expected to be one lower than sent
                self.assertEqual(rx_ip.hlim + 1, tx_ip.hlim)

            except Exception:
                rx.show()
                tx.show()
                raise

    def test_gre(self):
        """ GRE tunnel Tests """

        # Scenario outline (all traffic is sent and captured on pg0):
        #  1. create an L3 GRE tunnel; a duplicate create must fail
        #  2. traffic routed into the tunnel is dropped while the tunnel
        #     destination has no resolving route
        #  3. with the destination resolved, traffic is GRE encapsulated
        #  4. tunnelled traffic matching the tunnel is decapsulated
        #  5. tunnelled traffic from a non-matching source is dropped
        #  6. IPv6-in-GRE is dropped until IPv6 is configured on the tunnel
        #  7. cleanup

        #
        # Create an L3 GRE tunnel.
        #  - set it admin up
        #  - assign an IP Address
        #  - Add a route via the tunnel
        #
        gre_if = VppGreInterface(self,
                                 self.pg0.local_ip4,
                                 "1.1.1.2")
        gre_if.add_vpp_config()

        #
        # The double create (create the same tunnel twice) should fail,
        # and we should still be able to use the original
        #
        try:
            gre_if.add_vpp_config()
        except Exception:
            # expected: the second add is rejected
            pass
        else:
            self.fail("Double GRE tunnel add does not fail")

        gre_if.admin_up()
        gre_if.config_ip4()

        route_via_tun = IpRoute(self, "4.4.4.4", 32,
                                [RoutePath("0.0.0.0", gre_if.sw_if_index)])

        route_via_tun.add_vpp_config()

        #
        # Send a packet stream that is routed into the tunnel
        #  - they are all dropped since the tunnel's destination IP
        #    is unresolved - or resolves via the default route - which
        #    is a drop.
        #
        tx = self.create_stream_ip4(self.pg0, "5.5.5.5", "4.4.4.4")
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()

        try:
            self.assertEqual(0, len(rx))
        except:
            error("GRE packets forwarded without DIP resolved")
            # NOTE(review): what gets logged here is the return value of
            # rx.show(), not the capture itself - confirm this is intended.
            error(rx.show())
            raise

        #
        # Add a route that resolves the tunnel's destination
        #
        route_tun_dst = IpRoute(self, "1.1.1.2", 32,
                                [RoutePath(self.pg0.remote_ip4,
                                           self.pg0.sw_if_index)])
        route_tun_dst.add_vpp_config()

        #
        # Send a packet stream that is routed into the tunnel
        #  - packets are GRE encapped
        #
        self.vapi.cli("clear trace")
        tx = self.create_stream_ip4(self.pg0, "5.5.5.5", "4.4.4.4")
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_tunneled_4o4(self.pg0, rx, tx,
                                 self.pg0.local_ip4, "1.1.1.2")

        #
        # Send tunneled packets that match the created tunnel and
        # are decapped and forwarded
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_4o4(self.pg0,
                                           "1.1.1.2",
                                           self.pg0.local_ip4,
                                           self.pg0.local_ip4,
                                           self.pg0.remote_ip4)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_decapped_4o4(self.pg0, rx, tx)

        #
        # Send tunneled packets that do not match the tunnel's src
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_4o4(self.pg0,
                                           "1.1.1.3",
                                           self.pg0.local_ip4,
                                           self.pg0.local_ip4,
                                           self.pg0.remote_ip4)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        try:
            self.assertEqual(0, len(rx))
        except:
            error("GRE packets forwarded despite no SRC address match")
            error(rx.show())
            raise

        #
        # Configure IPv6 on the PG interface so we can route IPv6
        # packets
        #
        self.pg0.config_ip6()
        self.pg0.resolve_ndp()

        #
        # Send IPv6 tunnel encapsulated packets
        #  - dropped since IPv6 is not enabled on the tunnel
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_6o4(self.pg0,
                                           "1.1.1.2",
                                           self.pg0.local_ip4,
                                           self.pg0.local_ip6,
                                           self.pg0.remote_ip6)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        try:
            self.assertEqual(0, len(rx))
        except:
            error("IPv6 GRE packets forwarded despite IPv6 not enabled on tunnel")
            error(rx.show())
            raise

        #
        # Enable IPv6 on the tunnel
        #
        gre_if.config_ip6()

        #
        # Send IPv6 tunnel encapsulated packets
        #  - forwarded since IPv6 is enabled on the tunnel
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_6o4(self.pg0,
                                           "1.1.1.2",
                                           self.pg0.local_ip4,
                                           self.pg0.local_ip6,
                                           self.pg0.remote_ip6)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_decapped_6o4(self.pg0, rx, tx)

        #
        # test case cleanup
        #
        route_tun_dst.remove_vpp_config()
        route_via_tun.remove_vpp_config()
        gre_if.remove_vpp_config()

        self.pg0.unconfig_ip6()

    def test_gre_vrf(self):
        """ GRE tunnel VRF Tests """

        # Scenario outline:
        #  1. create an L3 GRE tunnel whose underlay is table 1 (pg1)
        #  2. traffic sent on pg0 towards a route via the tunnel comes out
        #     GRE encapsulated on pg1
        #  3. tunnelled traffic received on pg1 is decapsulated and
        #     forwarded out of pg0
        #  4. cleanup

        #
        # Create an L3 GRE tunnel whose destination is in the non-default
        # table. The underlay is thus non-default - the overlay is still
        # the default.
        #  - set it admin up
        #  - assign an IP Address
        #
        gre_if = VppGreInterface(self, self.pg1.local_ip4,
                                 "2.2.2.2",
                                 outer_fib_id=1)
        gre_if.add_vpp_config()
        gre_if.admin_up()
        gre_if.config_ip4()

        #
        # Add a route via the tunnel - in the overlay
        #
        route_via_tun = IpRoute(self, "9.9.9.9", 32,
                                [RoutePath("0.0.0.0", gre_if.sw_if_index)])
        route_via_tun.add_vpp_config()

        #
        # Add a route that resolves the tunnel's destination - in the
        # underlay table
        #
        route_tun_dst = IpRoute(self, "2.2.2.2", 32, table_id=1,
                                paths=[RoutePath(self.pg1.remote_ip4,
                                                 self.pg1.sw_if_index)])
        route_tun_dst.add_vpp_config()

        #
        # Send a packet stream that is routed into the tunnel
        # packets are sent in on pg0 which is in the default table
        #  - packets are GRE encapped
        #
        self.vapi.cli("clear trace")
        tx = self.create_stream_ip4(self.pg0, "5.5.5.5", "9.9.9.9")
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        # encapsulated traffic is expected on pg1, the underlay interface
        rx = self.pg1.get_capture()
        self.verify_tunneled_4o4(self.pg1, rx, tx,
                                 self.pg1.local_ip4, "2.2.2.2")

        #
        # Send tunneled packets that match the created tunnel and
        # are decapped and forwarded. This tests the decap lookup
        # does not happen in the encap table
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_4o4(self.pg1,
                                           "2.2.2.2",
                                           self.pg1.local_ip4,
                                           self.pg0.local_ip4,
                                           self.pg0.remote_ip4)
        self.pg1.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        # decapsulated traffic is expected on pg0
        rx = self.pg0.get_capture()
        self.verify_decapped_4o4(self.pg0, rx, tx)

        #
        # test case cleanup
        #
        route_tun_dst.remove_vpp_config()
        route_via_tun.remove_vpp_config()
        gre_if.remove_vpp_config()

    def test_gre_l2(self):
        """ GRE tunnel L2 Tests """

        # Scenario outline (all traffic is sent and captured on pg0):
        #  1. add routes resolving both tunnel destinations
        #  2. create two L2 (TEB) GRE tunnels and cross-connect them;
        #     frames arriving on one tunnel leave on the other
        #  3. replace the cross-connect with one between VLAN
        #     sub-interfaces of the tunnels; VLAN tags are swapped
        #  4. cleanup, including removal of the routes added in step 1

        #
        # Add routes to resolve the tunnel destinations
        #
        route_tun1_dst = IpRoute(self, "2.2.2.2", 32,
                                 [RoutePath(self.pg0.remote_ip4,
                                            self.pg0.sw_if_index)])
        route_tun2_dst = IpRoute(self, "2.2.2.3", 32,
                                 [RoutePath(self.pg0.remote_ip4,
                                            self.pg0.sw_if_index)])

        route_tun1_dst.add_vpp_config()
        route_tun2_dst.add_vpp_config()

        #
        # Create 2 L2 GRE tunnels and x-connect them
        #
        gre_if1 = VppGreInterface(self, self.pg0.local_ip4,
                                  "2.2.2.2",
                                  is_teb=1)
        gre_if2 = VppGreInterface(self, self.pg0.local_ip4,
                                  "2.2.2.3",
                                  is_teb=1)
        gre_if1.add_vpp_config()
        gre_if2.add_vpp_config()

        gre_if1.admin_up()
        gre_if2.admin_up()

        self.vapi.sw_interface_set_l2_xconnect(gre_if1.sw_if_index,
                                               gre_if2.sw_if_index,
                                               enable=1)
        self.vapi.sw_interface_set_l2_xconnect(gre_if2.sw_if_index,
                                               gre_if1.sw_if_index,
                                               enable=1)

        #
        # Send in tunnel encapped L2. expect out tunnel encapped L2
        # in both directions
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_l2o4(self.pg0,
                                            "2.2.2.2",
                                            self.pg0.local_ip4)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_tunneled_l2o4(self.pg0, rx, tx,
                                  self.pg0.local_ip4,
                                  "2.2.2.3")

        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_l2o4(self.pg0,
                                            "2.2.2.3",
                                            self.pg0.local_ip4)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_tunneled_l2o4(self.pg0, rx, tx,
                                  self.pg0.local_ip4,
                                  "2.2.2.2")

        # undo the tunnel-level x-connect before moving to sub-interfaces
        self.vapi.sw_interface_set_l2_xconnect(gre_if1.sw_if_index,
                                               gre_if2.sw_if_index,
                                               enable=0)
        self.vapi.sw_interface_set_l2_xconnect(gre_if2.sw_if_index,
                                               gre_if1.sw_if_index,
                                               enable=0)

        #
        # Create VLAN sub-interfaces on the GRE TEB interfaces
        # then x-connect them
        #
        # NOTE(review): no add_vpp_config() is called on the
        # sub-interfaces - presumably constructing a VppDot1QSubint
        # already creates it in VPP; confirm against vpp_sub_interface.
        #
        gre_if_11 = VppDot1QSubint(self, gre_if1, 11)
        gre_if_12 = VppDot1QSubint(self, gre_if2, 12)

        gre_if_11.admin_up()
        gre_if_12.admin_up()

        self.vapi.sw_interface_set_l2_xconnect(gre_if_11.sw_if_index,
                                               gre_if_12.sw_if_index,
                                               enable=1)
        self.vapi.sw_interface_set_l2_xconnect(gre_if_12.sw_if_index,
                                               gre_if_11.sw_if_index,
                                               enable=1)

        #
        # Configure both to pop their respective VLAN tags,
        # so that during the x-connect they will subsequently push
        #
        self.vapi.sw_interface_set_l2_tag_rewrite(gre_if_12.sw_if_index,
                                                  L2_VTR_OP.L2_POP_1,
                                                  12)
        self.vapi.sw_interface_set_l2_tag_rewrite(gre_if_11.sw_if_index,
                                                  L2_VTR_OP.L2_POP_1,
                                                  11)

        #
        # Send traffic in both directions - expect the VLAN tags to
        # be swapped.
        #
        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_vlano4(self.pg0,
                                              "2.2.2.2",
                                              self.pg0.local_ip4,
                                              11)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_tunneled_vlano4(self.pg0, rx, tx,
                                    self.pg0.local_ip4,
                                    "2.2.2.3",
                                    12)

        self.vapi.cli("clear trace")
        tx = self.create_tunnel_stream_vlano4(self.pg0,
                                              "2.2.2.3",
                                              self.pg0.local_ip4,
                                              12)
        self.pg0.add_stream(tx)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        rx = self.pg0.get_capture()
        self.verify_tunneled_vlano4(self.pg0, rx, tx,
                                    self.pg0.local_ip4,
                                    "2.2.2.2",
                                    11)

        #
        # Cleanup Test resources
        #
        gre_if_11.remove_vpp_config()
        gre_if_12.remove_vpp_config()
        gre_if1.remove_vpp_config()
        gre_if2.remove_vpp_config()
        # the routes added at the start of the test must be removed here,
        # not added a second time
        route_tun1_dst.remove_vpp_config()
        route_tun2_dst.remove_vpp_config()


# When executed directly, run the tests in this module with VppTestRunner.
if __name__ == '__main__':
    unittest.main(testRunner=VppTestRunner)