# -----------------------------------------------------------------------------
# export PYTHONPATH=/home/seb/Documents/code/Async/build/lib/python3.10/site-packages
# export LD_LIBRARY_PATH=/home/seb/Documents/code/Async/build/lib
#
# python3.10 -m venv pv-venv
# source ./pv-venv/bin/activate
# pip install -U pip
# pip install trame
#
# python ./paraview/Remoting/Microservices/Testing/Trame/TestSCDemo.py
# or using pvpython:
# pvpython ./paraview/Remoting/Microservices/Testing/Trame/TestSCDemo.py --venv ./pv-venv
# -----------------------------------------------------------------------------
from parat import venv
import asyncio
import base64
import json
from parat.services import (
    ParaT,
    PipelineBuilder,
    DefinitionManager,
    ActiveObjects,
    PropertyManager,
    PipelineViewer,
    ProgressObserver,
)
from parat.trame.simput import ParaViewSimput
from paraview.modules.vtkRemotingServerManager import vtkSMProxySelectionModel
from paraview.modules.vtkRemotingPythonAsyncCore import (
    vtkPythonObservableWrapperUtilities,
)

# -----------------------------------------------------------------------------

from trame.app import get_server, asynchronous
from trame.ui.vuetify import SinglePageWithDrawerLayout
from trame.widgets import vuetify, simput, trame, html

# -----------------------------------------------------------------------------


def apply_exec(**kwargs):
    """Placeholder ``on_apply`` controller handler.

    Prints a fixed marker line; any keyword arguments supplied by the caller
    are accepted and ignored. Returns ``None``.
    """
    marker = "Apply"
    print(marker)


class App:
    """Trame front-end for a ParaT (asynchronous ParaView) session.

    Builds a demo pipeline (wavelet -> contour -> render view), streams the
    rendered view to the browser as base64 data URLs, and mirrors the
    pipeline / active-source state into trame state for the GitTree and
    Simput widgets.
    """

    def __init__(self, server=None):
        """Bind to *server* (default: ``get_server()``) and wire controllers.

        The ParaT session itself is created asynchronously in ``initialize``
        once the server is ready.
        """
        if server is None:
            server = get_server()

        self.server = server
        self.state = server.state
        self.ctrl = server.controller

        # ParaView Async
        self._app = ParaT()
        # Loop flag for monitor_server_status; cleared in finalize()
        self._running = True
        self._pv_simput = ParaViewSimput(server, name="async")

        # internal state
        self.active_proxy = None
        self.active_representation = None
        self.active_view = None

        # state
        self.state.simput_active_source = 0
        self.state.spin_wait = 1 / 30  # target 30 fps

        # controller
        self.ctrl.on_server_ready.add_task(self.initialize)
        self.ctrl.on_server_exited.add_task(self.finalize)
        self.ctrl.apply = self._pv_simput.apply
        self.ctrl.reset = self._pv_simput.reset
        self.ctrl.on_apply = apply_exec  # FIXME bug on trame-server
        # The toolbar's reset-camera button triggers this controller entry.
        self.ctrl.view_reset_camera = self.reset_camera

    async def initialize(self, **kwargs):
        """Create the ParaT session and its services, build the demo pipeline,
        then start the background monitoring tasks."""
        self._session = await self._app.initialize()
        self._pipeline = PipelineViewer(self._session)
        self._builder = PipelineBuilder(self._session)
        self._def_mgt = DefinitionManager(self._session)
        self._active = ActiveObjects(self._session, "ActiveSources")
        self._prop_mgr = PropertyManager()
        self._progress = ProgressObserver(self._session)

        # Bind definition service to simput helper
        self._pv_simput.set_definition_manager(self._def_mgt)

        asynchronous.create_task(self.monitor_progress_ds())

        # default demo pipeline (sets self.active_view, needed by the tasks below)
        await self.setup_demo()

        # Tasks to monitor state change
        asynchronous.create_task(self.on_active_change())
        asynchronous.create_task(self.on_pipeline_change())
        asynchronous.create_task(self.monitor_server_status())
        asynchronous.create_task(self.monitor_view_stream())

        # State listener
        self.state.change("view_size")(self.on_view_size_change)

    async def finalize(self, **kwargs):
        """Stop the status/spin loop, then shut down the ParaT application."""
        # Clear the flag first so monitor_server_status stops rendering into a
        # session that is being torn down.
        self._running = False
        await self._app.finalize()

    async def monitor_server_status(self):
        """Periodically spin the camera (when enabled) and tick the server ring.

        Runs every ``state.spin_wait`` seconds until ``finalize`` clears
        ``self._running``.
        """
        while self._running:
            # Sleep outside the state context so no state batch stays open
            # while idle.
            await asyncio.sleep(self.state.spin_wait)
            if not self._running:
                break

            with self.state as state:
                # Spinning
                if state.spinning and self.active_view:
                    self.active_view.GetCamera().Azimuth(1)
                    self.active_view.StillRender()

                # Update client
                state.status_server += 5
                if state.status_server > 360:
                    state.status_server = 0

    async def on_active_change(self):
        """Mirror the active source into the GitTree selection and Simput panel."""
        async for proxy in self._active.GetCurrentObservable():
            self.active_proxy = proxy
            active_ids = []
            active_id = "0"

            if proxy:
                active_ids = [str(proxy.GetGlobalID())]
                simput_proxy = self._pv_simput.to_sinput(proxy)
                if simput_proxy is None:
                    # First selection of this proxy: register it with Simput
                    simput_proxy = await self._pv_simput.create(proxy)
                active_id = simput_proxy.id
                await self._prop_mgr.UpdatePipeline(proxy)
                # NOTE(review): the result was never used; the call is kept in
                # case it refreshes cached data information. Confirm and drop.
                proxy.GetDataInformation(0)

            with self.state as state:
                state.git_tree_actives = active_ids
                state.simput_active_source = active_id

    async def on_pipeline_change(self):
        """Publish the pipeline as GitTree ``sources`` each time it changes."""
        async for pipeline_state in self._pipeline.GetObservable():
            sources = []
            for item in pipeline_state:
                parent_ids = item.GetParentIDs()
                sources.append(
                    {
                        "name": item.GetName(),
                        # Items without parents hang under "0"; only the
                        # first parent is shown in the tree.
                        "parent": str(parent_ids[0] if len(parent_ids) > 0 else 0),
                        "id": str(item.GetID()),
                        # GitTree reads "visible" (previously misspelled "visibile")
                        "visible": 1,
                    }
                )

            with self.state as state:
                state.git_tree_sources = sources

    def ui_active_change(self, active):
        """Make the proxy selected in the GitTree the active source.

        :param active: selected ids as emitted by the GitTree (presumably
            string global ids); only the first is used. An empty or missing
            selection is ignored.
        """
        if not active:
            return
        proxy = self._session.GetProxyManager().FindProxy(int(active[0]))
        self._active.SetCurrentProxy(proxy, vtkSMProxySelectionModel.CLEAR)

    async def create_proxy(self):
        """Create a proxy described by ``state.xml_group`` / ``state.xml_name``.

        Filters use the active proxy as input; representations connect the
        active proxy (port 0) to the active view. New representations/views
        become the active ones. The created proxy is registered with Simput.
        """
        with self.state as state:
            xml_group = state.xml_group
            xml_name = state.xml_name
            proxy = None
            if xml_group == "sources":
                proxy = await self._builder.CreateProxy(xml_group, xml_name)
            elif xml_group == "filters":
                proxy = await self._builder.CreateProxy(
                    xml_group, xml_name, Input=self.active_proxy
                )
            elif xml_group == "representations":
                proxy = await self._builder.CreateRepresentation(
                    self.active_proxy, 0, self.active_view, xml_name
                )
                self.active_representation = proxy
            elif xml_group == "views":
                proxy = await self._builder.CreateProxy(xml_group, xml_name)
                self.active_view = proxy
            else:
                print(f"Not sure what to create with {xml_group}::{xml_name}")

            # No proxy just skip work...
            if proxy is None:
                print("!!! No proxy created !!!")
                return

            # Load proxy definition
            simput_proxy = await self._pv_simput.create(proxy)

            if proxy == self.active_representation:
                state.simput_active_representation = simput_proxy.id

    async def setup_demo(self):
        """Build the demo pipeline: RTAnalyticSource -> Contour -> RenderView.

        Also configures the view to stream its output (JPEG codec) and stores
        the view/representation as the active ones.
        """
        view = await self._builder.CreateProxy("views", "RenderView")
        wavelet = await self._builder.CreateProxy("sources", "RTAnalyticSource")
        # Hard-coded RTData range of the default wavelet; 10 evenly spaced isovalues
        scalarRange = [37.35310363769531, 276.8288269042969]
        delta = (scalarRange[1] - scalarRange[0]) / 9
        values = [scalarRange[0] + (delta * float(i)) for i in range(10)]
        contour = await self._builder.CreateProxy(
            "filters",
            "Contour",
            Input=wavelet,
            ContourValues=values,
            SelectInputScalars=["", "", "", "", "RTData"],
        )

        self._prop_mgr.Push(contour)
        await self._prop_mgr.UpdatePipeline(wavelet)
        representation = await self._builder.CreateRepresentation(
            contour, 0, view, "GeometryRepresentation"
        )
        await self._prop_mgr.Update(view)
        view.ResetCameraUsingVisiblePropBounds()

        # View encoding setup
        self._prop_mgr.SetValues(
            view,
            force_push=True,
            CodecType=4,  # Lossless=-1, VP9=0, AV1=1, H264=2, H265=3, JPEG=4
            UseHardwareAcceleration=False,
            Display=True,
            Quality=80,
            StreamOutput=True,
        )

        # Keep track of view + rep
        self.active_view = view
        self.active_representation = representation

    def on_view_size_change(self, view_size, **kwargs):
        """Resize the active view to the client-side container and re-render.

        :param view_size: SizeObserver payload, expected as
            ``{"size": {"width": w, "height": h}}``; ``None`` or a payload
            without a size is ignored, as is a call before a view exists.
        """
        size = (view_size or {}).get("size")
        if not size or self.active_view is None:
            return
        self._prop_mgr.SetValues(
            self.active_view,
            ViewSize=(size.get("width"), size.get("height")),
            force_push=True,
        )
        self.active_view.StillRender()

    def reset_camera(self, **kwargs):
        """Fit the active view's camera to its visible props and re-render.

        No-op until an active view exists.
        """
        if self.active_view is None:
            return
        self.active_view.ResetCameraUsingVisiblePropBounds()
        self.active_view.StillRender()

    async def monitor_view_stream(self):
        """Push each encoded frame of the active view to ``state.img_url``."""
        async for package in vtkPythonObservableWrapperUtilities.GetIterator(
            self.active_view.GetViewOutputObservable()
        ):
            if package:
                chunk = base64.b64encode(package.GetData()).decode("ascii")
                mime = package.GetMimeType()
                # NOTE(review): original author observed the package may report
                # an invalid height.
                with self.state as state:
                    state.img_url = f"data:{mime};base64,{chunk}"

    async def monitor_progress_ds(self):
        """Log JSON progress messages from the "ds" service.

        ``ds.Progress`` is divided by 100 (presumably a percentage); pushing
        it to ``status_ds`` is currently disabled.
        """
        async for message in self._progress.GetServiceProgressObservable("ds"):
            data = json.loads(message)
            print(data)
            with self.state as state:
                progress = data["ds"]["Progress"] / 100.0
                # state.status_ds = int(progress * 360)


# -----------------------------------------------------------------------------
# Trame App
# -----------------------------------------------------------------------------

# Module-level trame server and the App instance bound to it; both are
# referenced by the GUI definition and the CLI entry point below.
server = get_server()
app = App(server=server)

# -----------------------------------------------------------------------------
# GUI
# -----------------------------------------------------------------------------

with SinglePageWithDrawerLayout(server) as layout:
    # Use the Simput helper's root widget as the layout root
    layout.root = app._pv_simput.root_widget
    layout.title.set_text("Parat")

    # --- Toolbar: status rings, spin toggle, camera reset -------------------
    with layout.toolbar as toolbar:
        toolbar.dense = True
        vuetify.VSpacer()

        # Status rings whose rotation angle follows a state variable
        for ring_label, ring_color, ring_angle_key in (
            ("D", "amber", "status_ds"),
            ("R", "purple", "status_rs"),
            ("S", "red", "status_server"),
        ):
            vuetify.VProgressCircular(
                ring_label,
                color=ring_color,
                size=35,
                width=5,
                rotate=(ring_angle_key, 0),
                value=("20",),
                classes="mx-2",
            )

        # Indeterminate (continuously animating) ring labelled "C"
        vuetify.VProgressCircular(
            "C", color="teal", size=35, width=5, indeterminate=True, classes="mx-2"
        )
        vuetify.VDivider(vertical=True, classes="mx-2")

        # Toggles the "spinning" flag polled by App.monitor_server_status
        vuetify.VCheckbox(
            small=True,
            v_model=("spinning", False),
            dense=True,
            classes="mx-2",
            hide_details=True,
            on_icon="mdi-axis-z-rotate-counterclockwise",
            off_icon="mdi-axis-z-rotate-counterclockwise",
        )
        with vuetify.VBtn(
            icon=True, small=True, click=app.ctrl.view_reset_camera, classes="mx-2"
        ):
            vuetify.VIcon("mdi-crop-free")

    # --- Drawer: pipeline tree + property panel of the active source --------
    with layout.drawer as drawer:
        drawer.width = 300
        trame.GitTree(
            sources=("git_tree_sources", []),
            actives=("git_tree_actives", []),
            actives_change=(app.ui_active_change, "[$event]"),
        )
        simput.SimputItem(item_id=("simput_active_source", None))

    # --- Content: streamed image, size tracked in "view_size" ---------------
    with layout.content, vuetify.VContainer(
        fluid=True, classes="pa-0 fill-height"
    ), trame.SizeObserver("view_size"):
        html.Img(src=("img_url", ""))

# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # Launch the trame server when this file is run as a script.
    server.start()
