#!/usr/bin/env python3

import unittest

from framework import VppTestCase
from asfframework import VppTestRunner
from scapy.layers.inet import IP, TCP, UDP
from scapy.layers.inet6 import (
    IPv6,
    ICMPv6EchoRequest,
    ICMPv6EchoReply,
)
from scapy.layers.l2 import Ether, GRE
from util import ppp
from vpp_papi import VppEnum


class TestNAT66(VppTestCase):
    """NAT66 Test Cases"""

    # Each test enables the NAT66 plugin in setUp() and disables it again in
    # tearDown(), then configures its own NAT66 interfaces and static mapping.

    @classmethod
    def setUpClass(cls):
        # Create two packet-generator interfaces (pg0, pg1) once for the whole
        # class and bring them up with IPv6 addressing and neighbors.
        super(TestNAT66, cls).setUpClass()

        # External address used by the static mappings in the tests below.
        cls.nat_addr = "fd01:ff::2"
        cls.create_pg_interfaces(range(2))
        cls.interfaces = list(cls.pg_interfaces)

        for i in cls.interfaces:
            i.admin_up()
            i.config_ip6()
            i.configure_ipv6_neighbors()

    @property
    def config_flags(self):
        # Shorthand for the NAT configuration flag enum exposed by the API.
        return VppEnum.vl_api_nat_config_flags_t

    def plugin_enable(self):
        # Turn the NAT66 plugin on through the binary API.
        self.vapi.nat66_plugin_enable_disable(enable=1)

    def plugin_disable(self):
        # Turn the NAT66 plugin off through the binary API.
        self.vapi.nat66_plugin_enable_disable(enable=0)

    def setUp(self):
        super(TestNAT66, self).setUp()
        self.plugin_enable()

    def tearDown(self):
        super(TestNAT66, self).tearDown()
        # Skip the API call when VPP is flagged as dead.
        if not self.vpp_dead:
            self.plugin_disable()

    def _nat66_stream(self, src_if, src_ip, dst_ip, icmp_cls):
        """Build the four-packet stream used by the static mapping test.

        The packets share the same Ethernet/IPv6 headers and differ only in
        payload, in this order: TCP, UDP, ICMPv6 (``icmp_cls``), GRE/IP/TCP.

        :param src_if: interface the packets will be injected on; its
            local/remote MAC addresses fill the Ethernet header.
        :param src_ip: IPv6 source address.
        :param dst_ip: IPv6 destination address.
        :param icmp_cls: scapy ICMPv6 layer class instantiated for the third
            packet (e.g. ICMPv6EchoRequest or ICMPv6EchoReply).
        :returns: list of four scapy packets.
        """

        def headers():
            # A fresh header stack per packet, exactly as if written inline.
            return Ether(dst=src_if.local_mac, src=src_if.remote_mac) / IPv6(
                src=src_ip, dst=dst_ip
            )

        return [
            headers() / TCP(),
            headers() / UDP(),
            headers() / icmp_cls(),
            headers() / GRE() / IP() / TCP(),
        ]

    def _nat66_send(self, tx_if, pkts, rx_if):
        """Inject ``pkts`` on ``tx_if`` and return what ``rx_if`` captured.

        Capture is enabled on all pg interfaces before the packet generator
        is started; the capture is requested with ``len(pkts)`` as the
        expected packet count.
        """
        tx_if.add_stream(pkts)
        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()
        return rx_if.get_capture(len(pkts))

    def _nat66_verify(self, packets, src_ip, dst_ip, check_checksums=True):
        """Assert every packet carries the expected IPv6 source/destination.

        :param packets: iterable of captured scapy packets.
        :param src_ip: expected IPv6 source address.
        :param dst_ip: expected IPv6 destination address.
        :param check_checksums: when true, also run
            ``assert_packet_checksums_valid`` on each packet.

        On any failure the offending packet is logged and the original
        exception is re-raised so the test still fails with its real cause.
        """
        for packet in packets:
            try:
                self.assertEqual(packet[IPv6].src, src_ip)
                self.assertEqual(packet[IPv6].dst, dst_ip)
                if check_checksums:
                    self.assert_packet_checksums_valid(packet)
            except Exception:
                # Narrower than a bare except: interpreter-exit signals such
                # as KeyboardInterrupt pass through without being logged as a
                # bad packet.
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise

    def test_static(self):
        """1:1 NAT66 test"""
        # pg0 is registered with the NAT_IS_INSIDE flag, pg1 without flags.
        self.vapi.nat66_add_del_interface(
            is_add=1,
            flags=self.config_flags.NAT_IS_INSIDE,
            sw_if_index=self.pg0.sw_if_index,
        )
        self.vapi.nat66_add_del_interface(is_add=1, sw_if_index=self.pg1.sw_if_index)
        self.vapi.nat66_add_del_static_mapping(
            local_ip_address=self.pg0.remote_ip6,
            external_ip_address=self.nat_addr,
            is_add=1,
        )

        # in2out: the source address is expected to be rewritten to nat_addr.
        in2out = self._nat66_stream(
            self.pg0, self.pg0.remote_ip6, self.pg1.remote_ip6, ICMPv6EchoRequest
        )
        capture = self._nat66_send(self.pg0, in2out, self.pg1)
        self._nat66_verify(capture, self.nat_addr, self.pg1.remote_ip6)

        # out2in: packets addressed to nat_addr are expected to arrive with
        # the destination rewritten to pg0's remote address.
        out2in = self._nat66_stream(
            self.pg1, self.pg1.remote_ip6, self.nat_addr, ICMPv6EchoReply
        )
        capture = self._nat66_send(self.pg1, out2in, self.pg0)
        self._nat66_verify(capture, self.pg1.remote_ip6, self.pg0.remote_ip6)

        # The single mapping must have counted every packet of both streams.
        sm = self.vapi.nat66_static_mapping_dump()
        self.assertEqual(len(sm), 1)
        self.assertEqual(sm[0].total_pkts, len(in2out) + len(out2in))

    def test_check_no_translate(self):
        """NAT66 translate only when egress interface is outside interface"""
        # Both interfaces are registered with NAT_IS_INSIDE here, so the
        # packet below is expected to come out with its source untouched.
        flags = self.config_flags.NAT_IS_INSIDE
        for pg_if in (self.pg0, self.pg1):
            self.vapi.nat66_add_del_interface(
                is_add=1, flags=flags, sw_if_index=pg_if.sw_if_index
            )
        self.vapi.nat66_add_del_static_mapping(
            local_ip_address=self.pg0.remote_ip6,
            external_ip_address=self.nat_addr,
            is_add=1,
        )

        # in2out
        p = (
            Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac)
            / IPv6(src=self.pg0.remote_ip6, dst=self.pg1.remote_ip6)
            / UDP()
        )
        capture = self._nat66_send(self.pg0, [p], self.pg1)
        # Only the first captured packet is inspected; checksums are not
        # validated in this test.
        self._nat66_verify(
            [capture[0]],
            self.pg0.remote_ip6,
            self.pg1.remote_ip6,
            check_checksums=False,
        )

# When executed directly as a script, run this module's tests through
# unittest using the VppTestRunner imported from asfframework.
if __name__ == "__main__":
    unittest.main(testRunner=VppTestRunner)