#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
import adios2
import argparse
from mpi4py import MPI
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import decomp

def SetupArgs():
    """Parse the command line and attach the fixed plotting settings.

    Command-line options:
        --instream / -i   : name of the input stream (required).
        --varname1 / -v1  : first variable to read (default "velocity_x").
        --varname2 / -v2  : second variable to read (default "velocity_z").

    Returns:
        The argparse namespace, extended with four hard-coded attributes:
        ``displaysec`` (0.1) and ``nx``, ``ny``, ``nz`` (1 each).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--instream", "-i", required=True, help="Name of the input stream")
    cli.add_argument("--varname1", "-v1", default="velocity_x", help="Name of variable read")
    cli.add_argument("--varname2", "-v2", default="velocity_z", help="Name of variable read")
    parsed = cli.parse_args()

    # These settings are not exposed on the command line; they are attached
    # to the namespace so downstream code can read them like parsed options.
    fixed_settings = (
        ("displaysec", 0.1),
        ("nx", 1),
        ("ny", 1),
        ("nz", 1),
    )
    for attr_name, attr_value in fixed_settings:
        setattr(parsed, attr_name, attr_value)

    return parsed


def Plot2D(dataName1, dataName2, data1, data2, args, fullshape, step, fontsize):
    """Show the matrix product of two arrays as a colour image for one step.

    The image displayed is ``np.matmul(data1, data2.T)``. The figure is shown
    in interactive mode for ``args.displaysec`` seconds and then cleared.

    Args:
        dataName1: Text used as the x-axis label.
        dataName2: Text used as the y-axis label.
        data1: Left operand of the matrix product.
        data2: Array whose transpose is the right operand of the product.
        args: Object providing ``displaysec``, ``nx`` and ``ny``.
        fullshape: Indexable; element 0 is used as the vertical extent of the
            image and element 1 as the horizontal extent.
        step: Value inserted into the plot title.
        fontsize: Font size of title and axis labels; tick labels and the
            colour bar labels use ``fontsize - 4``.
    """
    pause_seconds = args.displaysec
    small_font = fontsize - 4

    # Single-panel layout on figure number 1 (reused across calls).
    layout = gridspec.GridSpec(1, 1)
    fig = plt.figure(1, figsize=(8,8))
    ax = fig.add_subplot(layout[0, 0])

    product = np.matmul(data1, data2.T)
    height, width = fullshape[0], fullshape[1]

    image = ax.imshow(
        product,
        origin='lower',
        extent=[0, width, 0, height],
        cmap=plt.get_cmap('gist_ncar'),
    )
    colorbar = fig.colorbar(image, orientation='horizontal')
    colorbar.ax.tick_params(labelsize=small_font)

    # Horizontal black lines at the start of each of the args.ny row bands.
    for row in range(args.ny):
        level = height / args.ny * row
        ax.plot([0, width], [level, level], color='black')

    # Vertical black lines at the start of each of the args.nx column bands.
    for col in range(args.nx):
        offset = width / args.nx * col
        ax.plot([offset, offset], [0, height], color='black')

    ax.set_title("dummy simulation - step {0}".format(step), fontsize=fontsize)
    ax.set_xlabel(dataName1, fontsize=fontsize)
    ax.set_ylabel(dataName2, fontsize=fontsize)
    plt.tick_params(labelsize=small_font)

    # Non-blocking display: show the frame, keep it up briefly, then wipe the
    # figure so the next call starts from an empty canvas.
    plt.ion()
    plt.show()
    plt.pause(pause_seconds)

    plt.clf()


def read_data(varname, fr, start_coord, size_dims):
    """Read a selection of one variable and drop its length-1 axes.

    Args:
        varname: Name of the variable, forwarded to ``fr.read``.
        fr: Reader object exposing ``read(name, start, count)``.
        start_coord: Start of the selection, forwarded unchanged.
        size_dims: Extent of the selection, forwarded unchanged.

    Returns:
        The result of ``fr.read`` passed through ``np.squeeze``.
    """
    return np.squeeze(fr.read(varname, start_coord, size_dims))


if __name__ == "__main__":
    # Font size used for the plot title and axis labels.
    fontsize = 14

    args = SetupArgs()

    # Set up the MPI communicators through the project's decomp helper.
    # NOTE(review): the literal 3 presumably selects a 3D decomposition --
    # confirm against decomp.MPISetup.
    mpi = decomp.MPISetup(args, 3)
    myrank = mpi.rank['app']

    # Open the input stream for reading on the application communicator.
    # NOTE(review): "../adios2.xml" / "Writer" look like the ADIOS2 config
    # file and IO name -- confirm against the writer side.
    fr = adios2.open(args.instream, "r", mpi.comm_app, "../adios2.xml", "Writer")

    # try/finally guarantees the stream is closed even if a read or the
    # plotting raises part-way through the run.
    try:
        # Read through the steps, one at a time.
        plot_step = 0
        for fr_step in fr:
            # Only the third value of the partition result is used here.
            _, _, fullshape = mpi.Partition_3D_3D(fr, args)
            cur_step = fr_step.current_step()

            # "Shape" is a comma-separated string; parse it into integers.
            vars_info = fr.available_variables()
            shape3 = [int(extent) for extent in vars_info[args.varname1]["Shape"].split(',')]
            sim_step = fr_step.read("step")

            if myrank == 0:
                print("GS Plot step {0} processing simulation output step {1} or computation step {2}".format(plot_step,cur_step, sim_step), flush=True)

            # Select a one-cell-thick slab at the middle index of the third
            # dimension, covering the full extent of the first two.
            slab_start = [0, 0, shape3[2] // 2]
            slab_count = [shape3[0], shape3[1], 1]
            data = read_data(args.varname1, fr_step, slab_start, slab_count)
            data2 = read_data(args.varname2, fr_step, slab_start, slab_count)

            Plot2D(args.varname1, args.varname2, data, data2, args, fullshape, sim_step, fontsize)
            plot_step += 1
    finally:
        fr.close()
