#!/usr/bin/env python3

import sys
import re
import argparse
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

class Point:
    """One sample of a cc trace: an x coordinate (time) and a y value."""

    def __init__(self, x, y):
        self.x, self.y = x, y

def listx(points):
  """Return the x attribute of every element of *points*, in order, as a list."""
  return [pt.x for pt in points]

def listy(points):
  """Return the y attribute of every element of *points*, in order, as a list."""
  return [pt.y for pt in points]

def plot_data(d):
  """Plot parsed cc traces as a three-row figure and show it.

  Rows, top to bottom: cwnd/ssthresh with cc events, cc variables, and
  rtt estimates.

  Args:
    d: dict of traces as built by run(). "cwnd", "congestion", "recovered"
      and "rxtTimeout" hold Point objects; every other key used here holds
      plain numbers plotted against the cwnd sample times, so each is
      expected to have one entry per cwnd sample -- TODO confirm for the
      cc-snd/rtt lists, which run() fills from separate log lines.
  """

  plt.figure(1)

  cwndx = listx(d["cwnd"])
  cwndy = listy(d["cwnd"])
  congx = listx(d["congestion"])
  congy = listy(d["congestion"])
  rcvrdx = listx(d["recovered"])
  rcvrdy = listy(d["recovered"])
  rxttx = listx(d["rxtTimeout"])
  rxtty = listy(d["rxtTimeout"])

  # cwnd/ssthresh/cc events
  plt.subplot(311)
  plt.title("cwnd/ssthresh")
  plt.plot(cwndx, cwndy, 'r')
  plt.plot(cwndx, d["ssthresh"], 'y-')
  plt.plot(congx, congy, 'yo')
  plt.plot(rcvrdx, rcvrdy, 'co')
  plt.plot(rxttx, rxtty, 'mo')

  # Proxy artists for the legend: a colored line for each trace, and for
  # each event type a colored marker on a white line (so, on a white
  # background, only the marker is visible).
  marker1 = Line2D(range(1), range(1), color="r")
  marker2 = Line2D(range(1), range(1), color="y")
  marker3 = Line2D(range(1), range(1), color="w", marker="o", markerfacecolor="y")
  marker4 = Line2D(range(1), range(1), color="w", marker="o", markerfacecolor="c")
  marker5 = Line2D(range(1), range(1), color="w", marker="o", markerfacecolor="m")
  plt.legend((marker1, marker2, marker3, marker4, marker5),
             ('cwnd', 'ssthresh', 'congestion', 'recovered', 'rxt-timeout'),
             loc=4)
  # Leave 20e4 of headroom below zero and above the largest cwnd.
  # default=0 keeps a session without cwnd samples from raising ValueError.
  plt.gca().set_ylim([-20e4, max(cwndy, default=0) + 20e4])

  # snd variables
  plt.subplot(312)
  plt.title("cc variables")
  plt.plot(cwndx, d["space"], 'g-', markersize=1)
  plt.plot(cwndx, d["flight"], 'b-', markersize=1)
  plt.plot(cwndx, d["sacked"], 'm:', markersize=1)
  plt.plot(cwndx, d["lost"], 'y:', markersize=1)
  plt.plot(cwndx, d["cc-space"], 'k:', markersize=1)
  plt.plot(cwndx, cwndy, 'ro', markersize=2)

  plt.plot(congx, congy, 'y^', markersize=10, markerfacecolor="y")
  plt.plot(rcvrdx, rcvrdy, 'c^', markersize=10, markerfacecolor="c")
  plt.plot(rxttx, rxtty, 'm^', markersize=10, markerfacecolor="m")

  # d["snd_wnd"] and d["rto"] are collected by run() but not plotted.
  plt.legend(("snd-space", "flight", "sacked", "lost", "cc-space", "cwnd",
              "congestion", "recovered", "rxt-timeout"),
             loc=1)

  # rto/srrt/rttvar
  plt.subplot(313)
  plt.title("rtt")
  plt.plot(cwndx, d["srtt"], 'g-')
  # mrtt-us is divided by 1000, presumably to bring microseconds onto the
  # same scale as srtt/rttvar -- TODO confirm units.
  plt.plot(cwndx, [x/1000 for x in d["mrtt-us"]], 'r-')
  plt.plot(cwndx, d["rttvar"], 'b-')
  plt.legend(["srtt", "mrtt-us", "rttvar"])

  # show
  plt.show()

def find_pattern(file_path, session_idx, is_active_open=True):
    """Build a regex that selects the log lines of one connection.

    Scans *file_path*, skipping listener (``l[N]``) lines, and takes the
    ``session_idx``-th (zero-based) line matching the initial pattern. The
    connection id on that line becomes the filter.

    Args:
      file_path: path to the elog file in txt format.
      session_idx: zero-based index of the matching line to use.
      is_active_open: if true (the previously hard-coded behavior), only
        lines containing ``open:`` right after the connection id are
        counted. Otherwise any line carrying a connection id is counted.

    Returns:
      A regex string that matches lines of that connection. Its group 1
      captures the text following the connection id.

    Raises:
      Exception: if fewer than ``session_idx + 1`` matching lines exist.
    """
    listener_re = re.compile(r"l\[\d\]")
    if is_active_open:
        initial_re = re.compile(r"\[\d\](\.\d+:\d+\->\.\d+:\d+)\s+open:\s")
    else:
        initial_re = re.compile(r"\[\d\](\.\d+:\d+\->\.\d+:\d+)\s")
    idx = 0
    with open(file_path, 'r') as f:
        for line in f:
            # skip listener lines (server)
            if listener_re.search(line) is not None:
                continue
            match = initial_re.search(line)
            if match is None:
                continue
            if idx < session_idx:
                idx += 1
                continue
            # Escape the connection id. Unescaped, its '.' matches any
            # character, so ".1:1000" would also select ".11:1000".
            filter_pattern = re.escape(match.group(1)) + r"\s+(.+)"
            print("pattern is %s" % filter_pattern)
            return filter_pattern
    raise Exception("Could not find initial pattern")

def compute_time(min, sec, msec):
  """Convert minute, second and millisecond fields into seconds (float).

  Each field may be an int or a numeric string; it is passed through int().
  """
  minutes, seconds, millis = int(min), int(sec), int(msec)
  return minutes * 60 + seconds + millis / 1000.0

def run(file_path, session_idx):
    """Parse one connection's cc trace from an elog text dump and plot it.

    Listener lines and lines of other connections (see find_pattern) are
    skipped. Times are relative to the first timestamped line kept. Event
    lines (rxt-timeout, congestion, recovered) are recorded only after at
    least one cwnd sample. They are placed at the latest cwnd value offset
    by +1e4, -1e4 and 0 respectively.

    Args:
      file_path: path to the elog file in txt format.
      session_idx: zero-based session index, as interpreted by find_pattern.
    """
    filter_sessions = 1
    filter_pattern = ""

    patterns = {
      # NOTE(review): the leading timestamp field is not captured, so times
      # presumably wrap if a trace spans that field's boundary -- confirm
      # the elog time format.
      "time"      : r"^\d+:(\d+):(\d+):(\d+):\d+",
      "listener"  : r"l\[\d\]",
      "cc"        : r"cwnd (\d+) flight (\d+) space (\d+) ssthresh (\d+) snd_wnd (\d+)",
      "cc-snd"    : r"cc_space (\d+) sacked (\d+) lost (\d+)",
      "rtt"       : r"rto (\d+) srtt (\d+) mrtt-us (\d+) rttvar (\d+)",
      "rxtt"      : r"rxt-timeout",
      "congestion": r"congestion",
      "recovered" : r"recovered",
    }
    d = {
      "cwnd"        : [],
      "space"       : [],
      "flight"      : [],
      "ssthresh"    : [],
      "snd_wnd"     : [],
      "cc-space"    : [],
      "lost"        : [],
      "sacked"      : [],
      "rto"         : [],
      "srtt"        : [],
      "mrtt-us"     : [],
      "rttvar"      : [],
      "rxtTimeout"  : [],
      "congestion"  : [],
      "recovered"   : [],
    }

    if filter_sessions:
        filter_pattern = find_pattern(file_path, session_idx)

    # Compile once up front; every pattern is applied to every log line.
    regex = {name: re.compile(p) for name, p in patterns.items()}
    session_re = re.compile(filter_pattern) if filter_sessions else None

    # None (not 0) means "not set yet": a first timestamp of 0 is valid.
    start_time = None

    with open(file_path, 'r') as f:
        for original_line in f:
            # skip listener lines (server)
            if regex["listener"].search(original_line) is not None:
                continue
            # filter sessions: the cc patterns below are matched only
            # against the text following the connection id
            if session_re is not None:
                match = session_re.search(original_line)
                if match is None:
                    continue
                line = match.group(1)
            else:
                line = original_line

            match = regex["time"].search(original_line)
            if match is None:
                print("something went wrong! no time!")
                continue
            time = compute_time(match.group(1), match.group(2), match.group(3))
            if start_time is None:
                start_time = time
            time = time - start_time

            match = regex["cc"].search(line)
            if match is not None:
                d["cwnd"].append(Point(time, int(match.group(1))))
                d["flight"].append(int(match.group(2)))
                d["space"].append(int(match.group(3)))
                d["ssthresh"].append(int(match.group(4)))
                d["snd_wnd"].append(int(match.group(5)))
                continue
            match = regex["cc-snd"].search(line)
            if match is not None:
                d["cc-space"].append(int(match.group(1)))
                d["sacked"].append(int(match.group(2)))
                d["lost"].append(int(match.group(3)))
            match = regex["rtt"].search(line)
            if match is not None:
                d["rto"].append(int(match.group(1)))
                d["srtt"].append(int(match.group(2)))
                d["mrtt-us"].append(int(match.group(3)))
                d["rttvar"].append(int(match.group(4)))

            # cc events are drawn relative to the latest cwnd sample, so
            # they are ignored until one exists
            if not d["cwnd"]:
                continue
            last_cwnd = d["cwnd"][-1].y
            if regex["rxtt"].search(line) is not None:
                d["rxtTimeout"].append(Point(time, last_cwnd + 1e4))
                continue
            if regex["congestion"].search(line) is not None:
                d["congestion"].append(Point(time, last_cwnd - 1e4))
                continue
            if regex["recovered"].search(line) is not None:
                d["recovered"].append(Point(time, last_cwnd))

    plot_data(d)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Plot tcp cc logs")
    parser.add_argument('-f', action='store', dest='file', required=True,
                        help="elog file in txt format")
    # type=int lets argparse reject a non-numeric index with a usage error
    # instead of an uncaught ValueError traceback from int().
    parser.add_argument('-s', action='store', dest='session_index', type=int,
                        default=0,
                        help="session index for which to plot cc logs")
    results = parser.parse_args()
    run(results.file, results.session_index)