import numpy as np


def Locate(rank, nproc, datasize):
    """Return the (start, size) of ``rank``'s share of a 1-D extent.

    ``datasize`` elements are divided into ``nproc`` contiguous chunks of
    ``datasize // nproc`` elements each; the last rank (``nproc - 1``)
    additionally takes the ``datasize % nproc`` leftover elements.

    Parameters
    ----------
    rank : int
        Index of the calling process along this dimension.
    nproc : int
        Number of processes sharing this dimension.
    datasize : int
        Global number of elements along this dimension.

    Returns
    -------
    tuple
        ``(start, size)`` — offset of the first owned element and the
        number of owned elements.
    """
    chunk, leftover = divmod(datasize, nproc)
    owns_leftover = rank == nproc - 1
    return chunk * rank, chunk + (leftover if owns_leftover else 0)


class MPISetup(object):
    """Split MPI_COMM_WORLD into per-application and per-axis communicators.

    The ranks of the application communicator are laid out on an
    ``nx * ny * nz`` process grid with x varying fastest::

        app_rank = x + nx * (y + ny * z)

    ``self.rank`` holds this process's application rank (``'app'``) and
    its coordinate along each axis (``'x'``, ``'y'``, ``'z'``).
    """

    # Class-level defaults, kept so code reading them on the class still
    # works. __init__ gives every instance its own copies so that instances
    # never share (and mutate) the same list/dict.
    readargs = []
    size = 1
    rank = {'app': 0,
            'x': 0,
            'y': 0,
            'z': 0}

    def __init__(self, args, appID):
        """Create the application and per-axis communicators.

        Parameters
        ----------
        args : object
            Must expose ``nx``, ``ny`` and ``nz``, the process-grid extents.
            If all three are 1, the whole application communicator is laid
            out along x (``nx`` becomes the communicator size).
        appID : int
            Colour passed to ``MPI.COMM_WORLD.Split``; processes passing the
            same value share one application communicator.

        Raises
        ------
        ValueError
            If ``nx * ny * nz`` differs from the application communicator
            size.
        """
        # Per-instance state; see the note on the class-level defaults.
        self.readargs = []
        self.rank = {'app': 0, 'x': 0, 'y': 0, 'z': 0}

        self.nx = args.nx
        self.ny = args.ny
        self.nz = args.nz

        # Imported here so the module can be imported without mpi4py.
        from mpi4py import MPI

        world = MPI.COMM_WORLD
        self.comm_app = world.Split(appID, world.Get_rank())
        self.size = self.comm_app.Get_size()
        app_rank = self.comm_app.Get_rank()
        self.rank['app'] = app_rank

        if self.nx * self.ny * self.nz == 1:
            self.nx = self.size
        if self.size != (self.nx * self.ny * self.nz):
            raise ValueError("nx * ny * nz != num processes")

        # Each axis communicator groups the ranks that share the other two
        # grid coordinates. Colours must be integers, hence floor division.
        # Using app_rank as the key makes the rank inside each axis
        # communicator equal to this process's coordinate on that axis.
        nxy = self.nx * self.ny
        x_coord = app_rank % self.nx
        z_coord = app_rank // nxy
        comm_x = self.comm_app.Split(app_rank // self.nx, app_rank)       # same (y, z)
        comm_y = self.comm_app.Split(x_coord + self.nx * z_coord, app_rank)  # same (x, z)
        comm_z = self.comm_app.Split(app_rank % nxy, app_rank)            # same (x, y)

        self.rank['x'] = comm_x.Get_rank()
        self.rank['y'] = comm_y.Get_rank()
        self.rank['z'] = comm_z.Get_rank()

        self.readargs.append(self.comm_app)

    def Partition_3D_3D(self, fp, args):
        """Compute this process's hyperslab of a variable with up to 3 dims.

        Parameters
        ----------
        fp : object
            Opened reader; ``fp.available_variables()`` must return a mapping
            of variable name -> dict whose ``'Shape'`` entry is a
            comma-separated string of extents (e.g. ``"10, 20, 30"``).
        args : object
            Must expose ``varname1``, the name of the variable to partition.

        Returns
        -------
        start, size, datashape : numpy.ndarray (int64, length 3)
            Dimension 0 is split across the y process axis, dimension 1
            across x and dimension 2 across z. Dimensions absent from the
            shape string keep extent 0 (and hence start/size 0). More than
            three dimensions raise IndexError.
        """
        datashape = np.zeros(3, dtype=np.int64)
        start = np.zeros(3, dtype=np.int64)
        size = np.zeros(3, dtype=np.int64)

        var = fp.available_variables()
        dshape = var[args.varname1]['Shape'].split(',')
        for axis, extent in enumerate(dshape):
            datashape[axis] = int(extent)

        start[0], size[0] = Locate(self.rank['y'], self.ny, datashape[0])
        start[1], size[1] = Locate(self.rank['x'], self.nx, datashape[1])
        start[2], size[2] = Locate(self.rank['z'], self.nz, datashape[2])

        return start, size, datashape
