#!/usr/bin/env python
#
# Usage:
#  python plot_scaling_results.py input-file1-ext input-file2-ext ...
#
# Description:
# Plots the speed up, parallel efficiency and time to solution from the "timesteps" output files generated by SWIFT for runs using different numbers of threads.
# 
# Example:
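# python plot_scaling_results.py cosma/timesteps knl/timesteps
# 
# Each argument (input-fileN-ext above) is a file-name prefix that defines one data series (up to five series).
# Every file matching <prefix>* is one run of that series, and its thread count is read from the file name just
# after the prefix, e.g. cosma/timesteps_1.txt - cosma/timesteps_64.txt and knl/timesteps_1.txt - knl/timesteps_64.txt.
# The sixth column of each file is summed to give the wall clock time of that run.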

import sys
import glob
import re
import numpy as np
import matplotlib.pyplot as plt

# Module-level accumulators; the parsing functions append one entry per data series.
# Run metadata read from the header of each series' first file:
branch, revision = [], []
hydro_scheme, hydro_kernel = [], []
hydro_neighbours, hydro_eta = [], []
# Plot label and ascending thread counts of each series:
version, threadList = [], []

# Matplotlib format strings, one per data series (at most five series).
linestyle = ('ro-', 'bo-', 'go-', 'yo-', 'mo-')

# Details for the commented-out alternative figure title:
#cmdLine = './swift_fixdt -s -t 16 cosmoVolume.yml'
#platform = 'KNL'

# Work out how many data series there are: one per command-line argument, and no
# more than there are line styles to draw them with. Without this check an empty
# or over-long argument list left inputFileNames/numOfSeries undefined.
inputFileNames = tuple(sys.argv[1:])
numOfSeries = len(inputFileNames)
if not 1 <= numOfSeries <= len(linestyle):
  sys.exit("Usage: python plot_scaling_results.py input-file1-ext input-file2-ext ... "
           "(between 1 and {} inputs)".format(len(linestyle)))

# Read the branch, Git revision, hydro scheme, hydro kernel, neighbour number and eta from a file header
def parse_header(inputFile):
  """Read the run description from the header lines of a SWIFT output file.

  Every line of the file is scanned, and for each line containing one of the
  keys below one value is appended to the matching module-level list:
    'Branch:'              -> branch            (third word)
    'Revision:'            -> revision          (third word)
    'Hydrodynamic scheme:' -> hydro_scheme      (all words after the key)
    'Hydrodynamic kernel:' -> hydro_kernel      (three words after the key)
    'neighbours:'          -> hydro_neighbours  (fifth word)
    'Eta:'                 -> hydro_eta         (third word)
  A key missing from the file appends nothing, and a key found on several
  lines appends several values.

  Args:
    inputFile: path of the file to read.
  """
  with open(inputFile, 'r') as f:
    for line in f:
      if 'Branch:' in line:
        branch.append(line.split()[2])
      elif 'Revision:' in line:
        revision.append(line.split()[2])
      elif 'Hydrodynamic scheme:' in line:
        # Skip the first two characters (presumably a "# " comment prefix) and the
        # two key words. The line is not also cut at [-1]: that chopped a real
        # character off a final line without a trailing newline.
        words = line[2:].split()
        hydro_scheme.append(" ".join(words[2:]))
      elif 'Hydrodynamic kernel:' in line:
        words = line[2:].split()
        hydro_kernel.append(" ".join(words[2:5]))
      elif 'neighbours:' in line:
        hydro_neighbours.append(line.split()[4])
      elif 'Eta:' in line:
        hydro_eta.append(line.split()[2])

# Parse the files of every series and return the per-step times, total time, speed up and parallel efficiency
def _thread_count(fileName, prefix):
  """Return the thread count encoded in fileName after the series prefix.

  The part of the name after ``prefix`` is split on runs of '_' and '.', and
  the first purely numeric token is the thread count. For example, prefix
  ``my_run/timesteps`` and file ``my_run/timesteps_16.txt`` give 16; prefixes
  that contain '_' or '.' are therefore handled.

  Raises:
    ValueError: if no numeric token follows the prefix.
  """
  for token in re.split(r'[_.]+', fileName[len(prefix):]):
    if token.isdigit():
      return int(token)
  raise ValueError("Cannot find a thread count in file name '{}' after prefix '{}'".format(fileName, prefix))


def parse_files():
  """Load every run of every data series and derive the scaling metrics.

  For series ``i``, every file matching ``inputFileNames[i] + "*"`` is one
  run. The runs are ordered by thread count, and the sorted thread counts
  are appended to the module-level ``threadList``. The header of the
  lowest-thread run is parsed with parse_header, and a plot label built from
  it is appended to the module-level ``version``.

  Returns:
    (times, totalTime, speedUp, parallelEff), each indexed [series][run]:
    times[i][j] is a one-element list holding column 5 (zero-based) of run
    j's file, and totalTime[i][j] is the sum of that column. speedUp[i][j]
    is totalTime[i][0] / totalTime[i][j], and parallelEff[i][j] is
    speedUp[i][j] divided by the run's thread count. Both are relative to the
    lowest-thread run, so they are only absolute if that run used 1 thread.

  Raises:
    IOError: if no file matches a series prefix.
    ValueError: if a matching file name carries no thread count.
  """
  times = []
  totalTime = []
  speedUp = []
  parallelEff = []

  for i in range(numOfSeries): # Loop over each data series
    prefix = inputFileNames[i]

    # Get each file that starts with the cmd line arg
    file_list = glob.glob(prefix + "*")
    if not file_list:
      raise IOError("No files found matching '{}*'".format(prefix))

    # Order the runs by ascending thread count
    runs = sorted((_thread_count(fileName, prefix), fileName) for fileName in file_list)
    threads = [n for n, _ in runs]
    file_list = [fileName for _, fileName in runs]
    threadList.append(threads)

    parse_header(file_list[0])

    version.append(branch[i] + " " + revision[i] + "\n" + hydro_scheme[i] +
                   "\n" + hydro_kernel[i] + r", $N_{neigh}$=" + hydro_neighbours[i] +
                   r", $\eta$=" + hydro_eta[i] + "\n")

    # Load the per-step times of each run and sum them into its total time
    seriesTimes = [[np.loadtxt(fileName, usecols=(5,))] for fileName in file_list]
    seriesTotal = [np.sum(stepTimes) for stepTimes in seriesTimes]

    # Speed up and parallel efficiency relative to the lowest-thread run
    serialTime = seriesTotal[0]
    seriesSpeedUp = [serialTime / total for total in seriesTotal]

    times.append(seriesTimes)
    totalTime.append(seriesTotal)
    speedUp.append(seriesSpeedUp)
    parallelEff.append([s / n for s, n in zip(seriesSpeedUp, threads)])

  return (times, totalTime, speedUp, parallelEff)

def print_results(times,totalTime,parallelEff,version):
  """Print the wall clock time and parallel efficiency of every run.

  Args:
    times: per-series, per-run lists as returned by parse_files. The number
      of entries in times[i][0][0] is reported as series i's time-step count.
    totalTime: total time of each run, indexed [series][run].
    parallelEff: parallel efficiency of each run, indexed [series][run].
    version: one descriptive label per series.

  Reads the module-level ``numOfSeries`` and ``threadList``. Every print call
  takes a single argument, so it behaves the same under Python 2 and 3.
  """
  separator = "------------------------------------"

  for i in range(numOfSeries):
    # Count the steps of *this* series (not series 0). np.size also copes with a
    # single-step run, for which np.loadtxt returns a 0-d array with no len().
    numSteps = np.size(times[i][0][0])

    print(" ")
    print(separator)
    print(version[i])
    print(separator)
    print("Wall clock time for: {} time steps".format(numSteps))
    print(separator)

    for threads, total in zip(threadList[i], totalTime[i]):
      print(str(threads) + " threads: {}".format(total))

    print(" ")
    print(separator)
    print("Parallel Efficiency for: {} time steps".format(numSteps))
    print(separator)

    for threads, eff in zip(threadList[i], parallelEff[i]):
      print(str(threads) + " threads: {}".format(eff))

def plot_results(times,totalTime,speedUp,parallelEff):
  """Draw speed up, parallel efficiency and time to solution for each series.

  Builds a 2x2 figure with three panels:
    - speed up, with a dashed ideal y = x line;
    - parallel efficiency on a log thread axis;
    - time to solution on log-log axes.
  The legend (the ``version`` labels) is placed over the hidden fourth
  panel, and the points of the first series are annotated with their thread
  count. Reads the module-level ``numOfSeries``, ``threadList``, ``version``
  and ``linestyle``. The figure is not shown here.
  """
  fig, axarr = plt.subplots(2, 2, figsize=(15,15))
  speedUpPlot = axarr[0, 0]
  parallelEffPlot = axarr[0, 1]
  totalTimePlot = axarr[1, 0]
  emptyPlot = axarr[1, 1]

  # Plot speed up
  for i in range(numOfSeries):
    speedUpPlot.plot(threadList[i], speedUp[i], linestyle[i], label=version[i])

  # Ideal (linear) speed up over every thread count used by any series, rather
  # than only the thread counts of the last series plotted.
  allThreads = sorted(set(n for threads in threadList for n in threads))
  speedUpPlot.plot(allThreads, allThreads, color='k', linestyle='--')
  speedUpPlot.set_ylabel("Speed Up")
  speedUpPlot.set_xlabel("No. of Threads")

  # Plot parallel efficiency
  for i in range(numOfSeries):
    parallelEffPlot.plot(threadList[i], parallelEff[i], linestyle[i])

  parallelEffPlot.set_xscale('log')
  parallelEffPlot.set_ylabel("Parallel Efficiency")
  parallelEffPlot.set_xlabel("No. of Threads")
  parallelEffPlot.set_ylim([0,1.1])

  # Plot time to solution; loglog already puts both axes on a log scale
  for i in range(numOfSeries):
    totalTimePlot.loglog(threadList[i], totalTime[i], linestyle[i], label=version[i])

  totalTimePlot.set_xlabel("No. of Threads")
  totalTimePlot.set_ylabel("Time to Solution (ms)")

  # Anchor the legend's upper-left corner just right of this panel, i.e. over
  # the fourth panel, which is hidden.
  totalTimePlot.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,prop={'size':14})
  emptyPlot.axis('off')

  # Label each point of the first series with its thread count
  for n, s, e, t in zip(threadList[0], speedUp[0], parallelEff[0], totalTime[0]):
    speedUpPlot.annotate(n, (n, s))
    parallelEffPlot.annotate(n, (n, e))
    totalTimePlot.annotate(n, (n, t))

  # np.size also handles a single-step run, for which np.loadtxt returns a 0-d array
  numSteps = np.size(times[0][0][0])

  #fig.suptitle("Thread Speed Up, Parallel Efficiency and Time To Solution for {} Time Steps of Cosmo Volume\n Cmd Line: {}, Platform: {}".format(numSteps,cmdLine,platform))
  fig.suptitle("Thread Speed Up, Parallel Efficiency and Time To Solution for {} Time Steps".format(numSteps))

# Parse every series, draw the figure, echo the numbers to stdout, then display the figure.
results = parse_files()
(times, totalTime, speedUp, parallelEff) = results
plot_results(*results)
print_results(times, totalTime, parallelEff, version)
plt.show()
