import xarray as xr
import numpy as np
import tempfile
import os
from vtkmodules.vtkIONetCDF import (
    vtkNetCDFCFReader,
)
from vtkmodules.vtkCommonExecutionModel import (
    vtkStreamingDemandDrivenPipeline
)

@xr.register_dataset_accessor("vtk")
class VtkAccessor:
    """xarray Dataset accessor, available as ``ds.vtk``, for converting to VTK.

    The conversion round-trips through a temporary NetCDF file that is read
    back with ``vtkNetCDFCFReader``.
    """

    def __init__(self, xarray_obj):
        # The Dataset this accessor is attached to.
        self._obj = xarray_obj

    @staticmethod
    def _select_time_step(reader, time_index):
        """Request time step ``time_index`` on output port 0 of ``reader``.

        Does nothing if the reader reports no TIME_STEPS information.

        Raises:
            ValueError: If ``time_index`` is outside ``[0, len(time_steps))``.
        """
        reader.UpdateInformation()
        information = reader.GetOutputInformation(0)
        time_steps_key = vtkStreamingDemandDrivenPipeline.TIME_STEPS()
        if not information.Has(time_steps_key):
            # No time information available; keep the reader's default time.
            return
        times = information.Get(time_steps_key)
        if not 0 <= time_index < len(times):
            raise ValueError(
                f"Invalid time index: {time_index} "
                f"(expected 0 <= time_index < {len(times)})")
        information.Set(
            vtkStreamingDemandDrivenPipeline.UPDATE_TIME_STEP(),
            times[time_index])
        reader.PropagateUpdateExtent()

    def dataset(self, spherical_coordinates=True,
                vertical_scale=1.0, vertical_bias=0.0, time_index=None,
                active_scalars=None, encoding=None):
        """Creates a new VTK dataset from the current xarray and returns it.

        The wrapped Dataset is written to a NetCDF file inside a temporary
        directory, which is removed before this method returns. The file is
        then read back with ``vtkNetCDFCFReader``.

        Args:
            spherical_coordinates: Forwarded to the reader's
                ``spherical_coordinates`` option. Also selects whether
                ``active_scalars`` is applied to the output's cell data
                (True) or point data (False).
            vertical_scale: Forwarded to the reader's ``vertical_scale``.
            vertical_bias: Forwarded to the reader's ``vertical_bias``.
            time_index: Optional index into the time steps reported by the
                reader. ``None`` keeps the reader's default time. It is
                ignored when the reader reports no time steps.
            active_scalars: Optional array name to mark as the active scalars
                on the output.
            encoding: Forwarded to ``xarray.Dataset.to_netcdf``.

        Returns:
            The reader's output data object (``reader.update().output``).

        Raises:
            ValueError: If ``time_index`` is out of range for the reported
                time steps.
        """
        with tempfile.TemporaryDirectory() as tmp:
            file_name = os.path.join(tmp, 'xarray_file.nc')
            self._obj.to_netcdf(file_name, encoding=encoding)
            reader = vtkNetCDFCFReader(
                file_name=file_name,
                spherical_coordinates=spherical_coordinates,
                vertical_scale=vertical_scale,
                vertical_bias=vertical_bias)
            # Compare against None explicitly: time_index=0 is a valid
            # request and must not be treated as "not given".
            if time_index is not None:
                self._select_time_step(reader, time_index)
            # Update while the temporary file still exists.
            data = reader.update().output
            if active_scalars:
                # NOTE(review): assumes the reader produces cell-centred arrays
                # in spherical mode and point-centred arrays otherwise — confirm.
                if spherical_coordinates:
                    attributes = data.GetCellData()
                else:
                    attributes = data.GetPointData()
                attributes.SetActiveScalars(active_scalars)
            return data
