#!/usr/bin/python

# Copyright (c) 2016 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Contact e-mail address of this script's author.
__author__ = 'fangyinx.hu@intel.com'

import sys
import requests
import re
import os
import argparse
import tempfile
import shutil
import time
import paramiko
import netifaces

#
# Helper function to indent a text string
#
def indent(lines, amount, fillchar=' '):
    """Return *lines* with every line prefixed by *amount* fill characters.

    :param lines: Text to indent; individual lines are separated by '\\n'.
    :param amount: Number of times *fillchar* is repeated in front of each
        line.
    :param fillchar: String used to build the prefix (a single space by
        default).
    :returns: The indented text, with lines still joined by '\\n'.
    """
    prefix = fillchar * amount
    return '\n'.join(prefix + row for row in lines.split('\n'))

#
# Main function.
# FIXME: Right now, this is really coded like a shell script, as one big
# function executed in sequence. This should be broken down into multiple
# functions.
#
def main():
    """Start a VIRL simulation for a TLDK test run and print its session ID.

    The steps, in order:

    1. Parse the command line (see ``_parse_args``).
    2. Verify that the topology ``.virl``/``.yaml`` files and all TLDK
       packages exist.
    3. Fill in the topology template and POST it to the VIRL launch
       endpoint; the response body is used as the session ID.
    4. Create the per-session scratch directory and copy or move the
       packages into its ``tldktest`` sub-directory.
    5. Poll VIRL until all nodes report ``ACTIVE``.
    6. Write the ``ansible-hosts`` and ``topology.yaml`` files.
    7. Wait until every node has created its entry in the scratch
       directory (i.e. has mounted the NFS share).
    8. Unpack the TLDK dependency tarball on every non-``tg`` node via SSH.
    9. Append a start timestamp to ``start_time`` and print the session ID
       on stdout.

    Exits with status 1 (``sys.exit``) if an input file is missing, the
    launch request fails or returns a status other than 200, the nodes never
    become ACTIVE, or the nodes never mount their NFS directory. In the last
    two cases the scratch directory is removed and the simulation stopped
    first, unless ``--keep`` was given.
    """
    #
    # Get our default interface IP address. This will become the default
    # value for the "NFS Server IP" option.
    #
    gws = netifaces.gateways()
    addrs = netifaces.ifaddresses(gws['default'][netifaces.AF_INET][1])
    default_addr = addrs[netifaces.AF_INET][0]['addr']

    args = _parse_args(default_addr)
    auth = (args.username, args.password)

    #
    # Check if topology and template exist
    #
    if args.verbosity >= 2:
        print("DEBUG: Running with topology {}".format(args.topology))

    topology_virl_filename = os.path.join(args.topology_directory,
                                          args.topology + ".virl")
    topology_yaml_filename = os.path.join(args.topology_directory,
                                          args.topology + ".yaml")

    if not os.path.isfile(topology_virl_filename):
        print("ERROR: Topology VIRL file {} does not exist".format(
            topology_virl_filename))
        sys.exit(1)
    if not os.path.isfile(topology_yaml_filename):
        print("ERROR: Topology YAML file {} does not exist".format(
            topology_yaml_filename))
        sys.exit(1)

    #
    # Check if TLDK package exists
    #
    for package in args.packages:
        if args.verbosity >= 2:
            print("DEBUG: Checking if file {} exists".format(package))
        if not os.path.isfile(package):
            print("ERROR: TLDK package {} does not exist.".format(package))
            sys.exit(1)

    #
    # Start VIRL topology
    #
    if args.verbosity >= 1:
        print("DEBUG: Starting VIRL topology")

    # Read the public key before creating the temporary file so that a
    # missing key file does not leave a stray descriptor/file behind.
    with open(args.ssh_pubkey, 'r') as pubkey_file:
        pub_key = pubkey_file.read().replace('\n', '')

    temp_handle, temp_topology = tempfile.mkstemp()
    # Wrap the descriptor returned by mkstemp() instead of opening the path
    # a second time; the "with" block closes it.
    with os.fdopen(temp_handle, 'w') as new_file, \
            open(topology_virl_filename, 'r') as old_file:
        for line in old_file:
            line = line.replace("  - VIRL-USER-SSH-PUBLIC-KEY", "  - "+pub_key)
            line = line.replace(
                "$$NFS_SERVER_SCRATCH$$",
                args.nfs_server_ip+":"+args.nfs_scratch_directory)
            line = line.replace(
                "$$NFS_SERVER_COMMON$$",
                args.nfs_server_ip+":"+args.nfs_common_directory)
            line = line.replace("$$VM_IMAGE$$", "server-"+args.release)
            new_file.write(line)

    try:
        with open(temp_topology, 'rb') as topology_file:
            req = requests.post(
                'http://' + args.virl_ip + '/simengine/rest/launch',
                headers={'Content-Type': 'text/xml'},
                auth=auth, data=topology_file)
    except (requests.exceptions.RequestException, IOError) as ex:
        # No response object exists on this path, so report the exception
        # itself rather than the (unbound) response.
        print("ERROR: Launching VIRL simulation - received invalid response")
        print(ex)
        os.remove(temp_topology)
        sys.exit(1)

    if args.verbosity >= 2:
        print("DEBUG: - Response Code {}".format(req.status_code))

    if req.status_code != 200:
        print("ERROR: Launching VIRL simulation - received status other " +
              "than 200 HTTP OK")
        print("Status was: {} \n".format(req.status_code))
        print("Response content was: ")
        print(req.content)
        os.remove(temp_topology)
        sys.exit(1)

    # If we got here, we had a good response. The response content is the
    # session ID.
    session_id = req.content

    #
    # Create simulation scratch directory. Move topology file into that
    # directory. Copy or move TLDK packages into that directory.
    #
    scratch_directory = os.path.join(args.nfs_scratch_directory, session_id)
    os.mkdir(scratch_directory)
    shutil.move(temp_topology, os.path.join(scratch_directory,
                                            "virl_topology.virl"))
    package_directory = os.path.join(scratch_directory, "tldktest")
    os.mkdir(package_directory)
    transfer = shutil.copy if args.copy else shutil.move
    for package in args.packages:
        transfer(package, os.path.join(package_directory,
                                       os.path.basename(package)))

    #
    # Wait for simulation to become active
    #
    if args.verbosity >= 1:
        print("DEBUG: Waiting for simulation to become active")

    sim_is_started = False
    nodelist = []
    # Pre-set so the error report below works even if --wait-count is 0 and
    # the loop body never runs.
    data = None

    count = args.wait_count
    while (count > 0) and not sim_is_started:
        time.sleep(args.wait_time)
        count -= 1

        req = requests.get('http://' + args.virl_ip + '/simengine/rest/nodes/' +
                           session_id, auth=auth)
        data = req.json()

        # Flush the node list every time, keep the last one
        nodelist = []
        active = 0

        # Hosts are the keys of the inner dictionary
        for key, node in data[session_id].items():
            # Entries whose management proxy is "self" are not counted.
            if node['management-proxy'] == "self":
                continue
            nodelist.append(key)
            if node['state'] == "ACTIVE":
                active += 1
        total = len(nodelist)
        if args.verbosity >= 2:
            print("DEBUG: - Attempt {} out of {}, total {} hosts, {} active".
                  format(args.wait_count-count, args.wait_count, total,
                         active))
        if active == total:
            sim_is_started = True

    if not sim_is_started:
        print("ERROR: Simulation started OK but devices never changed to " +
              "ACTIVE state")
        print("Last VIRL response:")
        print(data)
        # Does not return: continuing would operate on a removed scratch
        # directory and report a broken simulation as a success.
        _abort_simulation(args, session_id, scratch_directory)

    if args.verbosity >= 2:
        print("DEBUG: Nodes: " + ", ".join(nodelist))

    #
    # Fetch simulation's IPs and create files
    # (ansible hosts file, topology YAML file)
    #
    req = requests.get('http://' + args.virl_ip +
                       '/simengine/rest/interfaces/' + session_id +
                       '?fetch-state=1', auth=auth)
    data = req.json()

    # Populate node addresses
    nodeaddrs = {}
    topology = {}
    for key in nodelist:
        # The node type is the node name up to its first digit.
        nodetype = re.split('[0-9]', key)[0]
        # Keep only the part of the management address before the first '/'.
        mgmt_ip = data[session_id][key]['management']['ip-address'].\
            split('/')[0]
        nodeaddrs.setdefault(nodetype, {})[key] = mgmt_ip
        if args.verbosity >= 2:
            print("DEBUG: Node {} is of type {} and has management IP {}".
                  format(key, nodetype, mgmt_ip))

        topology[key] = {}
        for key2, nic in data[session_id][key].items():
            topology[key]["nic-"+key2] = nic
            if 'ip-address' in nic and nic['ip-address'] is not None:
                nic['ip-addr'] = nic['ip-address'].split('/')[0]

    # Write ansible file
    with open(os.path.join(scratch_directory, 'ansible-hosts'), 'w') as \
            ansiblehosts:
        for key1 in nodeaddrs:
            ansiblehosts.write("[{}]\n".format(key1))
            for key2 in nodeaddrs[key1]:
                ansiblehosts.write("{} hostname={}\n".format(
                    nodeaddrs[key1][key2], key2))

    # Process topology YAML template
    with open(args.ssh_privkey, 'r') as privkey_file:
        priv_key = indent(privkey_file.read(), 6)

    with open(os.path.join(scratch_directory, "topology.yaml"), 'w') as \
            new_file, open(topology_yaml_filename, 'r') as old_file:
        for line in old_file:
            new_file.write(line.format(priv_key=priv_key, topology=topology))

    #
    # Wait for hosts to become reachable over SSH
    #
    if args.verbosity >= 1:
        print("DEBUG: Waiting for hosts to become reachable using SSH")

    missing = -1
    count = args.wait_count
    while (count > 0) and missing != 0:
        time.sleep(args.wait_time)
        count -= 1

        # A node counts as present once an entry named after it exists in
        # the scratch directory.
        missing = 0
        for key in nodelist:
            if not os.path.exists(os.path.join(scratch_directory, key)):
                missing += 1
        if args.verbosity >= 2:
            print("DEBUG: - Attempt {} out of {}, waiting for {} hosts".
                  format(args.wait_count-count, args.wait_count, missing))

    if missing != 0:
        print("ERROR: Simulation started OK but {} hosts ".format(missing) +
              "never mounted their NFS directory")
        _abort_simulation(args, session_id, scratch_directory)

    #
    # just decompress the TLDK tar packages
    #
    if args.verbosity >= 1:
        print("DEBUG: Uprading TLDK")

    for key1 in nodeaddrs:
        # Nodes of type 'tg' are skipped.
        if key1 == 'tg':
            continue
        for key2 in nodeaddrs[key1]:
            ipaddr = nodeaddrs[key1][key2]
            if args.verbosity >= 2:
                print("DEBUG: Upgrading TLDK on node {}".format(ipaddr))
            c_stdout, c_stderr = _unpack_tldk_on_node(args, ipaddr,
                                                      scratch_directory)
            if args.verbosity >= 2:
                print("DEBUG: Command output was:")
                print(c_stdout)
                print("DEBUG: Command stderr was:")
                print(c_stderr)

    #
    # Write a file with timestamp to scratch directory. We can use this to track
    # how long a simulation has been running.
    #
    with open(os.path.join(scratch_directory, 'start_time'), 'a') as \
            timestampfile:
        timestampfile.write('{}\n'.format(int(time.time())))

    #
    # Declare victory
    #
    if args.verbosity >= 1:
        print("SESSION ID: {}".format(session_id))

    print("{}".format(session_id))


def _parse_args(default_addr):
    """Build the command-line parser and return the parsed arguments.

    :param default_addr: Address used as the default for --nfs-server-ip.
    :returns: The argparse namespace produced from sys.argv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("topology", help="the base topology to be started")
    parser.add_argument("packages", help="Path to the tldk package file that " +
                        "is/are to be installed", nargs='+')
    parser.add_argument("-c", "--copy", help="Copy the tldk packages, " +
                        "leaving the originals in place. Default is to " +
                        "move them.", action='store_true')
    parser.add_argument("-k", "--keep", help="Keep (do not delete) the " +
                        "simulation in case of error", action='store_true')
    parser.add_argument("-v", "--verbosity", action="count", default=0)
    parser.add_argument("-nip", "--nfs-server-ip", help="NFS server (our) IP " +
                        "default is derived from routing table: " +
                        "{}".format(default_addr), default=default_addr)
    parser.add_argument("-ns", "--nfs-scratch-directory",
                        help="Server location for NFS scratch diretory",
                        default="/nfs/scratch")
    parser.add_argument("-nc", "--nfs-common-directory",
                        help="Server location for NFS common (read-only) " +
                        "directory", default="/nfs/common")
    parser.add_argument("-wc", "--wait-count",
                        help="number of intervals to wait for simulation to " +
                        "be ready", type=int, default=24)
    parser.add_argument("-wt", "--wait-time",
                        help="length of a single interval to wait for " +
                        "simulation to be ready", type=int, default=5)
    parser.add_argument("-vip", "--virl-ip",
                        help="VIRL IP and Port (e.g. 127.0.0.1:19399)",
                        default="127.0.0.1:19399")
    parser.add_argument("-u", "--username", help="VIRL username",
                        default="tb4-virl")
    parser.add_argument("-p", "--password", help="VIRL password",
                        default="Cisco1234")
    parser.add_argument("-su", "--ssh-user", help="SSH username",
                        default="cisco")
    parser.add_argument("-spr", "--ssh-privkey", help="SSH private keyfile",
                        default="/home/jenkins-in/.ssh/id_rsa_virl")
    parser.add_argument("-spu", "--ssh-pubkey", help="SSH public keyfile",
                        default="/home/jenkins-in/.ssh/id_rsa_virl.pub")
    parser.add_argument("-r", "--release", help="VM disk image/release " +
                        "(ex. \"csit-ubuntu-14.04.4_2016-05-25_1.0\")",
                        default="csit-ubuntu-14.04.4_2016-05-25_1.0")
    parser.add_argument("--topology-directory", help="Topology directory",
                        default="/home/jenkins-in/testcase-infra/topologies")
    return parser.parse_args()


def _abort_simulation(args, session_id, scratch_directory):
    """Clean up after a failed start-up and exit with status 1.

    Unless --keep was given, the scratch directory is removed and the VIRL
    stop endpoint is called for the session. This function never returns.

    :param args: Parsed command-line arguments.
    :param session_id: Session ID of the simulation to stop.
    :param scratch_directory: Per-session scratch directory to remove.
    """
    if not args.keep:
        shutil.rmtree(scratch_directory)
        requests.get('http://' + args.virl_ip +
                     '/simengine/rest/stop/' + session_id,
                     auth=(args.username, args.password))
    sys.exit(1)


def _unpack_tldk_on_node(args, ipaddr, scratch_directory):
    """Unpack the TLDK dependency tarball on one node over SSH.

    :param args: Parsed command-line arguments (SSH user and private key).
    :param ipaddr: Management IP address of the node.
    :param scratch_directory: Directory that receives paramiko's "ssh.log".
    :returns: Tuple (stdout, stderr) with the remote command's output.
    """
    paramiko.util.log_to_file(os.path.join(scratch_directory, "ssh.log"))
    client = paramiko.SSHClient()
    client.load_system_host_keys()
    client.load_host_keys("/dev/null")
    # Unknown host keys are accepted automatically.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(ipaddr, username=args.ssh_user,
                       key_filename=args.ssh_privkey)
        _, stdout, stderr = client.exec_command(
            'cd /scratch/tldktest/ && sudo tar zxf tldk_depends.tar.gz')
        c_stdout = stdout.read()
        c_stderr = stderr.read()
        return c_stdout, c_stderr
    finally:
        # Always release the connection, even if connect/exec failed.
        client.close()

if __name__ == "__main__":
    # Run the script and hand main()'s return value to the interpreter as
    # the process exit status.
    exit_status = main()
    sys.exit(exit_status)