#!/usr/bin/env python3
#
# Copyright (c) 2016 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import absolute_import
import ipaddress
import logging
import functools
import json
import weakref
import atexit
import importlib.resources as resources
import struct
import asyncio

from .vpp_serializer import VPPType, VPPEnumType, VPPEnumFlagType, VPPUnionType
from .vpp_serializer import VPPMessage, vpp_get_type, VPPTypeAlias

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

__all__ = (
    "FuncWrapper",
    "VppApiDynamicMethodHolder",
    "VppEnum",
    "VppEnumType",
    "VppEnumFlag",
    "VPPIOError",
    "VPPRuntimeError",
    "VPPValueError",
    "VPPApiClient",
)


def metaclass(metaclass):
    """Return a class decorator that rebuilds a class through ``metaclass``.

    The decorated class is re-created by calling ``metaclass`` with the
    class's name, its bases and a copy of its namespace.
    """

    def rebuild(cls):
        namespace = cls.__dict__.copy()
        return metaclass(cls.__name__, cls.__bases__, namespace)

    # functools.wraps copies the metaclass's metadata onto the decorator.
    return functools.wraps(metaclass)(rebuild)


class VppEnumType(type):
    """Metaclass that resolves unknown class attributes via ``vpp_get_type``."""

    def __getattr__(cls, name):
        """Return the ``enum`` attribute of the type registered as ``name``."""
        return vpp_get_type(name).enum


@metaclass(VppEnumType)
class VppEnum:
    """Namespace whose attribute lookups are served by ``VppEnumType``."""


@metaclass(VppEnumType)
class VppEnumFlag:
    """Namespace whose attribute lookups are served by ``VppEnumType``."""


class QueueObject:
    """Record describing one packed request placed on the message queue."""

    def __init__(self, b, context, future=None, eof_stream=None, details=None):
        """Store the packed bytes, the context and optional reply bookkeeping.

        Note that the ``details`` argument is exposed as ``details_msg``.
        """
        self.b, self.context = b, context
        self.future = future
        self.details_msg = details
        self.eof_stream = eof_stream


def vpp_atexit(vpp_weakref):
    """Clean up VPP connection on shutdown.

    ``vpp_weakref`` is a callable (normally a ``weakref.ref``) returning the
    client or ``None``.  Nothing happens unless the client still exists and
    reports ``connected``.

    ``disconnect()`` is a coroutine function, so calling it only creates a
    coroutine; it has to be driven by the event loop that owns the
    connection.  That loop is taken from the client's ``queue_task``.  When
    the loop is gone, closed or busy, the coroutine is closed unrun so that
    no "never awaited" warning is emitted.
    """
    vpp_instance = vpp_weakref()
    if not vpp_instance or not vpp_instance.connected:
        return

    logger.debug("Cleaning up VPP on exit")
    pending = vpp_instance.disconnect()
    if not asyncio.iscoroutine(pending):
        # A synchronous disconnect() has already done its work.
        return

    task = getattr(vpp_instance, "queue_task", None)
    loop = task.get_loop() if task is not None else None
    if loop is None or loop.is_closed() or loop.is_running():
        pending.close()
        return

    try:
        # Bound the wait so interpreter shutdown cannot hang on a VPP that
        # never answers.
        loop.run_until_complete(
            asyncio.wait_for(pending, vpp_instance.read_timeout)
        )
    except Exception:
        # Best effort only: this runs from atexit, there is no caller to
        # report to.
        logger.debug("VPP disconnect at exit failed", exc_info=True)


def add_convenience_methods():
    """Attach ``vapi_af`` and ``vapi_af_name`` to ``ipaddress._IPAddressBase``.

    Both are installed as read-only properties that dispatch on the
    object's ``_version`` and raise ``ValueError`` for any other value.
    """

    def _family_value(addr):
        version = addr._version
        if version == 6:
            return VppEnum.vl_api_address_family_t.ADDRESS_IP6.value
        if version == 4:
            return VppEnum.vl_api_address_family_t.ADDRESS_IP4.value
        raise ValueError("Invalid _version.")

    def _family_name(addr):
        version = addr._version
        if version == 6:
            return "ip6"
        if version == 4:
            return "ip4"
        raise ValueError("Invalid _version.")

    base = ipaddress._IPAddressBase
    base.vapi_af = property(_family_value)
    base.vapi_af_name = property(_family_name)


class VppApiDynamicMethodHolder:
    """Empty attribute container; API functions are attached with setattr."""


class FuncWrapper:
    """Callable proxy around a function that accepts keyword arguments only."""

    def __init__(self, func):
        """Remember ``func`` and mirror its ``__name__`` and ``__doc__``."""
        self.__name__, self.__doc__ = func.__name__, func.__doc__
        self._func = func

    def __call__(self, **kwargs):
        """Forward the keyword arguments to the wrapped function."""
        target = self._func
        return target(**kwargs)

    def __repr__(self):
        described = (self.__name__, self.__doc__)
        return "<FuncWrapper(func=<%s(%s)>)>" % described


class VPPApiError(Exception):
    """Generic VPP API failure."""


class VPPNotImplementedError(NotImplementedError):
    """``NotImplementedError`` variant raised by the VPP API code."""


class VPPIOError(IOError):
    """``IOError`` variant raised by the VPP API code."""


class VPPRuntimeError(RuntimeError):
    """``RuntimeError`` variant raised by the VPP API code."""


class VPPValueError(ValueError):
    """``ValueError`` variant raised by the VPP API code."""


class VPPApiJSONFiles:
    """Turn VPP ``.api.json`` documents into message and service tables."""

    @classmethod
    def process_json_str(cls, json_str):
        """Parse one JSON API document.

        Returns a ``(messages, services)`` tuple of dictionaries.
        """
        return cls._process_json(json.loads(json_str))

    @classmethod
    def process_json_array_str(cls, json_str):
        """Parse a JSON array of API documents and merge their results.

        Later documents override earlier ones on name clashes.  Returns a
        ``(messages, services)`` tuple of dictionaries.
        """
        services = {}
        messages = {}

        for api in json.loads(json_str):
            m, s = cls._process_json(api)
            messages.update(m)
            services.update(s)
        return messages, services

    @staticmethod
    def _process_json(api):  # -> Tuple[Dict, Dict]
        """Register the types of one API document and build its messages.

        Every section of ``api`` is optional.  Type definitions may refer to
        each other, so registration is retried over the unresolved ones a
        bounded number of times.

        Returns a ``(messages, services)`` tuple.  Raises ``VPPValueError``
        when type definitions remain unresolved after the retries.
        """
        types = {}
        services = {}
        messages = {}

        # JSON section name -> kind tag used by the registration loop below.
        # "enumflags" get their own tag so they are built with
        # VPPEnumFlagType; they used to be tagged "enum", which left the
        # "enumflag" branch unreachable.
        sections = (
            ("enums", "enum"),
            ("enumflags", "enumflag"),
            ("unions", "union"),
            ("types", "type"),
        )
        for section, kind in sections:
            try:
                for t in api[section]:
                    # The definition is renamed in place to its wire name.
                    t[0] = "vl_api_" + t[0] + "_t"
                    types[t[0]] = {"type": kind, "data": t}
            except KeyError:
                pass

        try:
            for t, v in api["aliases"].items():
                types["vl_api_" + t + "_t"] = {"type": "alias", "data": v}
        except KeyError:
            pass

        try:
            services.update(api["services"])
        except KeyError:
            pass

        i = 0
        while True:
            unresolved = {}
            for k, v in types.items():
                if vpp_get_type(k):
                    # Already registered.
                    continue
                t = v["data"]
                kind = v["type"]
                try:
                    if kind == "enum":
                        VPPEnumType(t[0], t[1:])
                    elif kind == "enumflag":
                        VPPEnumFlagType(t[0], t[1:])
                    elif kind == "union":
                        VPPUnionType(t[0], t[1:])
                    elif kind == "type":
                        VPPType(t[0], t[1:])
                    elif kind == "alias":
                        VPPTypeAlias(k, t)
                except ValueError:
                    # Try again on the next pass, once more types exist.
                    unresolved[k] = v
            if not unresolved:
                break
            if i > 3:
                raise VPPValueError("Unresolved type definitions {}".format(unresolved))
            types = unresolved
            i += 1

        try:
            for m in api["messages"]:
                try:
                    messages[m[0]] = VPPMessage(m[0], m[1:])
                except VPPNotImplementedError:
                    logger.error("Not implemented error for %s", m[0])
        except KeyError:
            pass
        return messages, services


class VPPApiClient:
    """VPP interface.

    This class provides the APIs to VPP.  The APIs are loaded
    from .api.json definitions, and a function is created for each
    service they describe.  These functions are documented in the
    VPP .api files, as they are dynamically created.

    Additionally, VPP can send messages that do not answer a pending
    request; these are placed on the event queue that is passed to
    connect().
    """

    VPPApiError = VPPApiError
    VPPRuntimeError = VPPRuntimeError
    VPPValueError = VPPValueError
    VPPNotImplementedError = VPPNotImplementedError
    VPPIOError = VPPIOError

    def __init__(
        self,
        *,
        testmode=False,
        logger=None,
        loglevel=None,
        read_timeout=5,
        server_address="/run/vpp/api.sock",
    ):
        """Create a VPP API object.

        The memclnt.api.json file bundled in the vpp_papi.data package is
        loaded here to bootstrap the client.

        testmode, if true, skips the check that message definitions
        were loaded.
        logger, if supplied, is the logging logger object to log to.
        loglevel, if supplied, is the log level set on the logger
        created here; it is not applied to a logger passed in.
        read_timeout is stored on the instance.
        server_address is the path of the VPP API unix socket.

        Raises VPPValueError when no message definitions were loaded
        and testmode is false.
        """
        if logger is None:
            logger = logging.getLogger(
                "{}.{}".format(__name__, self.__class__.__name__)
            )
            if loglevel is not None:
                logger.setLevel(loglevel)
        self.logger = logger

        self.messages = {}
        self.services = {}
        self.id_names = []
        self.id_msgdef = []
        self.header = VPPType("header", [["u16", "msgid"], ["u32", "client_index"]])
        self.message_queue = asyncio.Queue()
        self.read_timeout = read_timeout
        self.testmode = testmode
        self.server_address = server_address
        self.stats = {}
        self.connected = False
        self.message_table = {}
        self.header_struct = struct.Struct(">QII")
        # Set by connect(); initialised so that disconnect() can test
        # "self.writer is not None" before any connection was made.
        self.reader = None
        self.writer = None

        # Bootstrap the API (memclnt.api bundled with VPP PAPI)
        with resources.open_text("vpp_papi.data", "memclnt.api.json") as f:
            resource_content = f.read()
        self.messages, self.services = VPPApiJSONFiles.process_json_str(
            resource_content
        )

        # Basic sanity check
        if len(self.messages) == 0 and not testmode:
            raise VPPValueError(1, "Missing JSON message definitions")

        # Make sure we allow VPP to clean up the message rings.
        atexit.register(vpp_atexit, weakref.ref(self))

        add_convenience_methods()

    def get_function(self, name):
        """Look up the generated API function called ``name``."""
        holder = self._api
        return getattr(holder, name)

    class ContextId:
        """Callable counter handing out context IDs.

        NOTE(review): the counter is a plain attribute without locking, and
        the instance created below is a class attribute shared by every
        client.
        """

        def __init__(self):
            self.context = 0

        def __call__(self):
            """Advance the counter and return its new value."""
            following = self.context + 1
            self.context = following
            return following

    get_context = ContextId()

    def get_type(self, name):
        """Return whatever ``vpp_get_type`` reports for ``name``."""
        registered = vpp_get_type(name)
        return registered

    @property
    def api(self):
        """Holder of the generated API functions.

        Raises VPPApiError while the functions have not been registered.
        """
        if hasattr(self, "_api"):
            return self._api
        raise VPPApiError("Not connected, api definitions not available")

    def make_function(self, msg, i, multipart):
        """Build the function that sends message ``msg`` with id ``i``.

        ``multipart`` is handed to ``_call_vpp_async`` as the service
        description.  The returned function carries the message name as
        ``__name__``, a "type name" listing of the fields as ``__doc__`` and
        the message definition as ``msg``.
        """

        def f(**kwargs):
            return self._call_vpp_async(i, msg, multipart, **kwargs)

        described = []
        for position, field in enumerate(msg.fields):
            described.append("%s %s" % (msg.fieldtypes[position], field))

        f.__name__ = str(msg.name)
        f.__doc__ = ", ".join(described)
        f.msg = msg
        return f

    def make_pack_function(self, msg, i, multipart):
        """Build the function that packs message ``msg`` with id ``i``.

        ``multipart`` is accepted but not used.  The returned function
        carries the message definition as ``msg``.
        """

        def f(**kwargs):
            packed = self._call_vpp_pack(i, msg, **kwargs)
            return packed

        setattr(f, "msg", msg)
        return f

    def _register_functions(self):
        """Rebuild the id lookup tables and the generated API functions.

        Each known message is looked up in VPP's message table under
        "<name>_<crc>".  Messages that VPP does not advertise (or whose CRC
        differs) are logged and skipped.  For every service a function and
        a "<name>_pack" function are attached to a fresh method holder.
        """
        table_size = self.vpp_dictionary_maxid + 1
        self.id_names = [None] * table_size
        self.id_msgdef = [None] * table_size
        self._api = VppApiDynamicMethodHolder()
        for name, msg in self.messages.items():
            n = name + "_" + msg.crc[2:]
            # get_msg_index() returns 0 for unknown names; indexing
            # message_table directly raised KeyError instead of reaching
            # the log line below.
            i = self.get_msg_index(n)
            if i <= 0:
                self.logger.debug("No such message type or failed CRC checksum: %s", n)
                continue

            self.id_msgdef[i] = msg
            self.id_names[i] = name

            # Create function for client side messages.
            if name in self.services:
                f = self.make_function(msg, i, self.services[name])
                f_pack = self.make_pack_function(msg, i, self.services[name])
                setattr(self._api, name, FuncWrapper(f))
                setattr(self._api, name + "_pack", FuncWrapper(f_pack))

    async def get_api_definitions(self):
        """Fetch the API definitions from VPP and merge them in.

        The functions for the embedded memclnt.api.json definitions are
        registered first so that get_api_json can be called.

        Raises VPPApiError when the reply carries a non-zero retval.
        """
        self._register_functions()

        (reply,) = await asyncio.gather(self.api.get_api_json())
        if reply.retval != 0:
            raise VPPApiError("Failed to load API definitions from VPP")

        # Process JSON
        new_messages, new_services = VPPApiJSONFiles.process_json_array_str(
            reply.json
        )
        self.messages.update(new_messages)
        self.services.update(new_services)

    def get_msg_index(self, name):
        """Return the message table index for ``name``, or 0 when unknown."""
        return self.message_table.get(name, 0)

    async def connect(self, name, event_queue):
        """Attach to VPP.

        Opens the unix socket at ``self.server_address``, performs the
        sockclnt_create exchange, starts the queue worker and socket reader
        tasks, then loads the API definitions from VPP.

        name is sent to VPP in the sockclnt_create message.
        event_queue receives messages that match no pending request.

        Returns 0 on success, or -1 when the socket cannot be opened
        (PermissionError or FileNotFoundError).  Raises IOError when the
        reply to sockclnt_create carries an unexpected message id.
        """
        try:
            reader, writer = await asyncio.open_unix_connection(self.server_address)
        except (PermissionError, FileNotFoundError):
            return -1
        self.reader = reader
        self.writer = writer

        try:
            # Initialise sockclnt_create
            sockclnt_create = self.messages["sockclnt_create"]
            sockclnt_create_reply = self.messages["sockclnt_create_reply"]

            # Message ids 15 (request) and 16 (reply) are the values this
            # exchange uses before the message table is known.
            args = {"_vl_msg_id": 15, "name": name, "context": 124}
            b = sockclnt_create.pack(args)
            # Send header
            writer.write(self.header_struct.pack(0, len(b), 0))
            writer.write(b)
            await writer.drain()

            hdr = await reader.readexactly(self.header_struct.size)
            (_, hdrlen, _) = self.header_struct.unpack(hdr)  # If at head of message
            msg = await reader.readexactly(hdrlen)
            hdr, _ = self.header.unpack(msg, 0)
            if hdr.msgid != 16:
                # TODO: Add first numeric argument.
                raise IOError("Invalid reply message")

            r, length = sockclnt_create_reply.unpack(msg)
        except BaseException:
            # Do not leak the socket when the handshake fails or is cancelled.
            writer.close()
            raise

        self.socket_index = r.index
        for m in r.message_table:
            self.message_table[m.name] = m.index
        self.vpp_dictionary_maxid = len(self.message_table)

        requests = {}
        self.queue_task = asyncio.create_task(self.queue_worker(requests))
        self.socket_task = asyncio.create_task(
            self.socket_reader(requests, event_queue)
        )
        # The attribute was reset by disconnect() but never set here, so it
        # always read False.
        self.connected = True

        # Register the functions we have (memclnt.json)
        await self.get_api_definitions()

        self._register_functions()

        # Initialise control ping
        crc = self.messages["control_ping"].crc
        self.control_ping_index = self.get_msg_index(("control_ping" + "_" + crc[2:]))
        self.control_ping_msgdef = self.messages["control_ping"]

        return 0

    async def disconnect(self):
        """Detach from VPP.

        Sends sockclnt_delete, closes the socket, stops the queue worker
        and cancels the socket reader.  Every step is skipped when the
        matching state was never created, so this can be called after a
        failed connect().

        Returns the sockclnt_delete reply, or 0 when it was not obtained.
        """
        rv = 0
        try:
            # Might fail, if VPP closes socket before packet makes it out,
            # or if there was a failure during connect() (in which case the
            # API functions do not exist and VPPApiError is raised).
            rv = await self.api.sockclnt_delete(index=self.socket_index)
        except (IOError, VPPApiError):
            pass
        self.connected = False

        writer = getattr(self, "writer", None)
        if writer is not None:
            writer.close()
            await writer.wait_closed()

        queue_task = getattr(self, "queue_task", None)
        if queue_task is not None:
            if not queue_task.done():
                # Send sentinel to stop the queue worker.  It is not sent
                # to a finished worker: it would stay queued and stop the
                # worker of the next connection.
                await self.message_queue.put(None)
            await asyncio.gather(queue_task)  # Wait for it to finish

        socket_task = getattr(self, "socket_task", None)
        if socket_task is not None:
            socket_task.cancel()

        # Wipe message table, VPP can be restarted with different plugins.
        self.message_table = {}
        return rv

    def has_context(self, msg):
        """Tell whether the raw message ``msg`` carries a context field.

        Returns False for messages shorter than 10 bytes, None for
        "rx_thread_exit", and otherwise whether the message definition has
        a "context" field.
        """
        if len(msg) < 10:
            return False

        layout = [["u16", "msgid"], ["u32", "client_index"], ["u32", "context"]]
        header = VPPType("header_with_context", layout)
        (i, ci, context), size = header.unpack(msg, 0)

        if self.id_names[i] == "rx_thread_exit":
            return

        # Look the definition up by message id.
        msgobj = self.id_msgdef[i]
        return bool("context" in msgobj.field_by_name and context >= 0)

    def decode_incoming_msg(self, msg, no_type_conversion=False):
        """Decode the raw message ``msg`` using its registered definition.

        Returns None for an empty message (after logging a warning) and for
        "rx_thread_exit".  Raises VPPIOError when no definition is
        registered for the message id.  ``no_type_conversion`` is passed to
        the definition's ``unpack`` as ``ntc``.
        """
        if not msg:
            logger.warning("vpp_api.read failed")
            return None

        (msg_id, _client), _ = self.header.unpack(msg, 0)
        if self.id_names[msg_id] == "rx_thread_exit":
            return None

        definition = self.id_msgdef[msg_id]
        if not definition:
            raise VPPIOError(2, "Reply message undefined")

        decoded, _ = definition.unpack(msg, ntc=no_type_conversion)
        return decoded

    def _control_ping(self, context):
        """Queue a control_ping request carrying ``context``."""
        # TODO: Cache packed version.
        packed = self.control_ping_msgdef.pack(
            {
                "_vl_msg_id": self.control_ping_index,
                "client_index": self.socket_index,
                "context": context,
            }
        )
        self.message_queue.put_nowait(QueueObject(packed, context))

    def validate_args(self, msg, kwargs):
        """Raise VPPValueError when ``kwargs`` names a field ``msg`` lacks."""
        unknown = set(kwargs.keys()) - set(msg.field_by_name.keys())
        if not unknown:
            return
        raise VPPValueError("Invalid argument {} to {}".format(list(unknown), msg.name))

    def _add_stat(self, name, ms):
        """Fold the sample ``ms`` into the max/count/avg entry for ``name``."""
        entry = self.stats.get(name)
        if entry is None:
            self.stats[name] = {"max": ms, "count": 1, "avg": ms}
            return

        if ms > entry["max"]:
            entry["max"] = ms
        entry["count"] += 1
        n = entry["count"]
        # Running mean updated in place.
        entry["avg"] = entry["avg"] * (n - 1) / n + ms / n

    def get_stats(self):
        """Return the statistics as a text table, highest average first."""
        lines = [
            "\n=== API PAPI STATISTICS ===\n",
            "{:<30} {:>4} {:>6} {:>6}\n".format("message", "cnt", "avg", "max"),
        ]
        ranked = sorted(self.stats.items(), key=lambda v: v[1]["avg"], reverse=True)
        for name, stat in ranked:
            lines.append(
                "{:<30} {:>4} {:>6.2f} {:>6.2f}\n".format(
                    name, stat["count"], stat["avg"], stat["max"]
                )
            )
        return "".join(lines)

    def get_field_options(self, msg, fld_name):
        """Return the options of field ``fld_name`` of message ``msg``.

        A field with options is a three element list in the msgdef, e.g.
        ['u32', 'ring_size', {'default': 1024}].  Returns None when the
        field has no such entry.
        """
        for _def in self.messages[msg].msgdef:
            if not isinstance(_def, list):
                continue
            if len(_def) == 3 and _def[1] == fld_name:
                return _def[2]
        return None

    async def queue_worker(self, requests):
        """Write queued requests to VPP until a ``None`` item arrives.

        The first item seen for a context is recorded in ``requests`` as
        ``(future, details_msg, eof_stream, [])`` before it is written.
        """
        pending = self.message_queue
        while True:
            item = await pending.get()
            if item is None:  # Stop signal
                logger.debug("Stopping queue worker...")
                break
            requests.setdefault(
                item.context,
                (item.future, item.details_msg, item.eof_stream, []),
            )
            await self._write(item.b)
            pending.task_done()

    async def socket_reader(self, requests, event_queue):
        """Read messages from the socket and match them to requests.

        ``requests`` maps a context to the tuple stored by queue_worker.
        A message is put on ``event_queue`` when it has no context, matches
        no pending request, is unexpected for its stream, or when the
        request's future can no longer take a result.

        The loop ends on cancellation, when the socket is closed, or when
        ``_read`` returns nothing.
        """
        while True:
            try:
                item = await self._read()
            except asyncio.CancelledError:
                break
            except (asyncio.IncompleteReadError, ConnectionError):
                # readexactly() raises at end of stream instead of
                # returning empty data.
                logger.error("Socket closed.")
                break
            if not item:
                logger.error("Socket closed.")
                break

            msgname = type(item).__name__
            # Not every message has a context field.
            context = getattr(item, "context", None)
            logger.debug("socket reader: %s %s", msgname, context)

            req = requests.get(context)
            if req is None:
                logger.debug("Adding %s to event queue", msgname)
                event_queue.put_nowait(item)
                continue

            # Slot 1 names the message that ends a stream (empty for a
            # plain request/reply); slot 2 names the per-item messages, or
            # is None when anything may be collected.
            future, end_marker, detail_name, collected = req

            if end_marker and msgname != end_marker:
                logger.debug(
                    "Streaming message %s: %s %s", msgname, end_marker, detail_name
                )
                if detail_name is None or msgname == detail_name:
                    collected.append(item)
                else:
                    logger.debug("Adding %s to event queue", msgname)
                    event_queue.put_nowait(item)
                continue

            # Final message for this request.
            del requests[context]
            if future is None or future.done():
                # Nobody is waiting for the result (any more).
                logger.debug("Adding %s to event queue", msgname)
                event_queue.put_nowait(item)
                continue
            future.set_result((item, collected) if end_marker else item)

    def _call_vpp_async(self, i, msgdef, service, **kwargs):
        """Pack a request, queue it and return the future for its reply.

        i is the message id, msgdef the message definition and service the
        service description (keys used: "stream", "stream_msg", "reply").
        A context is allocated unless the caller supplies one.

        For a stream without "stream_msg" a control_ping is queued after
        the request and its reply ends the stream; the future then resolves
        to ``(final_message, collected_messages)``.

        Raises VPPValueError for keyword arguments the message lacks.
        """
        if "context" not in kwargs:
            context = self.get_context()
            kwargs["context"] = context
        else:
            context = kwargs["context"]

        try:
            socket_index = self.socket_index
        except AttributeError:
            # Not attached yet.
            kwargs["client_index"] = 0
        else:
            if socket_index:
                kwargs["client_index"] = socket_index
        kwargs["_vl_msg_id"] = i

        self.validate_args(msgdef, kwargs)
        b = msgdef.pack(kwargs)
        response_future = asyncio.Future()

        # Defaults describe a plain request/reply.  Previously a service
        # whose "stream" value was present but false left eof_stream and
        # control_ping unbound.
        stream_message = None
        eof_stream = None
        control_ping = False
        if service.get("stream"):
            stream_message = service.get("stream_msg")
            if stream_message is None:
                eof_stream = "control_ping_reply"
                control_ping = True
            else:
                eof_stream = service.get("reply")
                if eof_stream is None:
                    # No terminating message known: handle as plain reply.
                    stream_message = None

        # QueueObject's parameter names are the reverse of what is passed:
        # its "eof_stream" slot carries the per-item message name and its
        # "details" slot the terminating one.  socket_reader relies on this.
        self.message_queue.put_nowait(
            QueueObject(
                b,
                context,
                future=response_future,
                eof_stream=stream_message,
                details=eof_stream,
            )
        )
        if control_ping:
            self._control_ping(context=context)

        return response_future

    def _call_vpp_pack(self, i, msg, **kwargs):
        """Return the packed bytes of ``msg`` built from ``kwargs``.

        The message id is set to ``i``; client index and context are 0.
        """
        kwargs.update(_vl_msg_id=i, client_index=0, context=0)
        return msg.pack(kwargs)

    async def _write(self, b):
        """Write the frame header followed by the packed message ``b``."""
        frame_header = self.header_struct.pack(0, len(b), 0)
        out = self.writer
        out.write(frame_header)
        out.write(b)
        await out.drain()

    async def _read(self, timeout=5, no_type_conversion=False):
        """Read single complete message, return it or None at end of stream.

        ``timeout`` is accepted but not used.  ``no_type_conversion`` is
        passed on to ``decode_incoming_msg``.

        Raises asyncio.IncompleteReadError when the stream ends inside a
        frame header.
        """
        try:
            hdr = await self.reader.readexactly(self.header_struct.size)
        except asyncio.IncompleteReadError as e:
            # readexactly() never returns short data; a stream that ends
            # cleanly between messages is reported as "nothing read".
            if e.partial:
                raise
            return None
        (_, hdrlen, _) = self.header_struct.unpack(hdr)  # If at head of message

        # Read the rest of the message
        msg = await self._read_exactly(hdrlen)
        if hdrlen == len(msg):
            return self.decode_incoming_msg(msg, no_type_conversion)
        raise IOError(1, f"Unknown socket read error, read {len(msg)} bytes")

    async def _read_exactly(self, n):
        """Return ``n`` bytes read from the reader as ``bytes``."""
        data = bytearray()
        missing = n
        while missing > 0:
            packet = await self.reader.readexactly(missing)
            if not packet:
                raise IOError(
                    1, f"Unexpected end of stream, read {len(data)} bytes out of {n}"
                )
            data += packet
            missing = n - len(data)
        return bytes(data)

    def validate_message_table(self, namecrctable):
        """Return the name_crc entries of ``namecrctable`` that have no
        positive index in the message table."""
        return [
            name_crc
            for name_crc in namecrctable
            if self.get_msg_index(name_crc) <= 0
        ]

    def dump_message_table(self):
        """Return the message table (name_crc to index); it is not copied."""
        table = self.message_table
        return table

    def dump_message_table_filtered(self, msglist):
        """Return the message table entries for ``msglist`` and their replies.

        For each message name and each service reply name, the first table
        key starting with that name is kept.
        """
        wanted = msglist + [self.services[n]["reply"] for n in msglist]
        filtered = {}
        for prefix in wanted:
            match = next(
                (
                    (k, v)
                    for k, v in self.message_table.items()
                    if k.startswith(prefix)
                ),
                None,
            )
            if match is not None:
                filtered[match[0]] = match[1]
        return filtered

    def __repr__(self):
        """Describe the client by its configuration.

        Only attributes set in __init__ are reported; the previous version
        read ``_apifiles`` and ``async_thread``, which this class never
        sets, so repr() raised AttributeError.
        """
        return (
            "<VPPApiClient testmode=%s, "
            "logger=%s, read_timeout=%s, "
            "server_address='%s'>"
            % (
                self.testmode,
                self.logger,
                self.read_timeout,
                self.server_address,
            )
        )

    def details_iter(self, f, **kwargs):
        """Yield the details returned by repeated, cursor-driven calls to ``f``.

        ``f`` is called with ``kwargs`` plus a ``cursor`` starting at 0 and
        must return an ``(rv, details)`` pair.  Calls continue with
        ``rv.cursor`` for as long as ``rv.retval`` is -165.

        NOTE(review): -165 is presumably VPP's "try again" code - confirm.
        NOTE(review): the functions built by make_function return a future
        rather than a pair - confirm which callables are meant to be passed.
        """
        cursor = 0
        while True:
            kwargs["cursor"] = cursor
            rv, details = f(**kwargs)
            for entry in details:
                yield entry
            if rv.retval != -165:
                break
            cursor = rv.cursor