from __main__ import qt, ctk, slicer
import workflowFunctions as workflowFunctions

class GuidedInterventionPETCTRegistrationAltStep( ctk.ctkWorkflowWidgetStep ) :
  '''
    Workflow step "2. Registration".

    Registers the CT volume (moving) onto the CBCT volume (fixed) with the
    BRAINSFit CLI module, using parameters taken from this step's widgets.
    When the CLI run completes, the PET volume is placed under the transform
    estimated by that run.
  '''

  def __init__( self, stepid ):
    self.initialize( stepid )
    self.setName( '2. Registration' )
    self.setDescription('Register Motion compensated CT-CBCT (Manual)')
    # CLI node of the most recent BRAINSFit run (None until one is started).
    self.CLINode = None
    # NOTE(review): presumably filled by workflowFunctions.addObserver — confirm.
    self.Observations = []
    # Volumes are looked up by name in onEntry(); initialised here so that
    # methods reached before the step is entered do not raise AttributeError.
    self.PETVolume = None
    self.CTVolume = None
    self.CBCTVolume = None
    # Output transform node of the most recent BRAINSFit run.
    self.registrationTransformNode = None
    qt.QTimer.singleShot(0, self.killButton)


  def killButton(self):
    '''Hide the first widget whose text is 'Tracker' (not used by this step).'''
    buttons = slicer.util.findChildren(text='Tracker')
    if buttons:
      buttons[0].hide()


  def createUserInterface( self ):
    '''Load the step's .ui form and append a CLI progress bar below it.'''
    layout = qt.QVBoxLayout(self)

    # Get widget from ui file and set to layout
    self.widget = workflowFunctions.loadUI('GuidedInterventionPETCTRegistrationAlt.ui')
    self.CLIProgressBar = slicer.qSlicerCLIProgressBar()
    layout.addWidget(self.widget)
    layout.addWidget(self.CLIProgressBar)
    workflowFunctions.setScene(self.widget)
    workflowFunctions.get(self.widget,"RegistrationButton").connect('clicked()',self.startRegistration)
    workflowFunctions.get(self.widget,"SpatialSamplesSliderWidget").setDecimals(4)


  def validate( self, desiredBranchId ):
    '''
      Allow leaving the step only once a registration run has completed.

      Shows a warning dialog when no run was started or the last run has not
      reached the "Completed" status, reports the result to the workflow, then
      makes the CBCT volume the active volume.
    '''
    if self.CLINode is None:
      qt.QMessageBox.warning( self, 'Error', 'Please start registration process' )
      validation = False
    elif self.CLINode.GetStatusString() != "Completed":
      qt.QMessageBox.warning( self, 'Error', 'Please wait end of registration' )
      validation = False
    else:
      validation = True
    super( GuidedInterventionPETCTRegistrationAltStep, self ).validate(validation,desiredBranchId)
    self.setCBCTVolumeAsBackground()


  def onEntry(self, comingFrom, transitionType):
    '''Fetch the PET, CT and CBCT volumes from the scene by node name.'''
    self.PETVolume = slicer.mrmlScene.GetFirstNodeByName("PETVolume")
    self.CTVolume = slicer.mrmlScene.GetFirstNodeByName("CTVolume")
    self.CBCTVolume = slicer.mrmlScene.GetFirstNodeByName("CBCTVolume")

    super(GuidedInterventionPETCTRegistrationAltStep, self).onEntry(comingFrom, transitionType)
    # Re-hide the 'Tracker' button on the next event-loop iteration.
    qt.QTimer.singleShot(0, self.killButton)


  def onExit(self, goingTo, transitionType):
    '''Execute the transition; this step has no extra exit work.'''
    super(GuidedInterventionPETCTRegistrationAltStep, self).onExit(goingTo, transitionType)


  def startRegistration(self):
    '''
      Launch an asynchronous BRAINSFit registration of CTVolume (moving) onto
      CBCTVolume (fixed).

      A rigid transform is estimated. It is followed by a BSpline deformation
      when the BSpline group box is checked. The estimated transform is written
      to a new node named "imageRegistrationTransform". The resampled CT is
      written to a new volume named "RegistrationResult". When the run
      completes, onCLINodeModified() places PETVolume under the estimated
      transform.
    '''
    if self.CTVolume is None or self.CBCTVolume is None:
      qt.QMessageBox.warning( self, 'Error', 'Please load CT and CBCT volumes before registration' )
      return

    def child(name):
      return workflowFunctions.get(self.widget, name)

    parameters = {}
    parameters["initializeTransformMode"] = "useMomentsAlign"
    parameters["numberOfIterations"] = child("IterationSpinBox").value
    # Converted from the slider's percent scale to a fraction.
    parameters["samplingPercentage"] = child("SpatialSamplesSliderWidget").value / 100.0
    parameters["backgroundFillValue"] = child("DefaultPixelValueSpinBox").value
    parameters["fixedVolume"] = self.CBCTVolume
    parameters["movingVolume"] = self.CTVolume
    outputVolume = slicer.mrmlScene.CreateNodeByClass("vtkMRMLScalarVolumeNode")
    parameters["outputVolume"] = outputVolume

    if child("BSplineGroupBox").checked:
      parameters["transformType"] = "Rigid,BSpline"
      transformNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLBSplineTransformNode")
      parameters["bsplineTransform"] = transformNode

      gridSize = int(child("GridSizeSliderWidget").value)
      parameters["splineGridSize"] = ",".join([str(gridSize)] * 3)
      parameters["numberOfHistogramBins"] = child("HistogramBinsSliderWidget").value
      if child("ConstrainDeformationCheckBox").checked:
        parameters["maxBSplineDisplacement"] = child("MaximumDeformationSpinBox").value
    else:
      parameters["transformType"] = "Rigid"
      transformNode = slicer.mrmlScene.CreateNodeByClass("vtkMRMLLinearTransformNode")
      parameters["linearTransform"] = transformNode

    transformNode.SetName("imageRegistrationTransform")
    slicer.mrmlScene.AddNode(transformNode)
    outputVolume.SetName("RegistrationResult")
    slicer.mrmlScene.AddNode(outputVolume)
    # BRAINSFit only fills this node; it does not re-parent the moving volume,
    # so keep a reference to apply it to PET once the run completes.
    self.registrationTransformNode = transformNode

    self.CLINode = slicer.cli.createNode(slicer.modules.brainsfit)
    self.CLINode.SetName("DeformationRegistrationNode")
    self.CLIProgressBar.setCommandLineModuleNode(self.CLINode)
    # On success, apply the estimated transform to the PET volume.
    workflowFunctions.addObserver(self, self.CLINode, self.CLINode.StatusModifiedEvent, self.onCLINodeModified)
    slicer.cli.run(slicer.modules.brainsfit, self.CLINode, parameters, wait_for_completion = False)

  def onCLINodeModified(self, cliNode, event):
    '''
      Status observer for the BRAINSFit CLI node.

      When the current run reaches "Completed", place PETVolume under the
      transform estimated by that run. Status updates from a superseded run
      (registration restarted before the previous one finished) are ignored.
      NOTE(review): if CTVolume already had a parent transform from an earlier
      step, confirm whether it should be composed with the registration result.
    '''
    if self.CLINode is None or cliNode.GetID() != self.CLINode.GetID():
      return
    if cliNode.GetStatusString() != 'Completed':
      return
    if self.PETVolume is None or self.registrationTransformNode is None:
      return
    self.PETVolume.SetAndObserveTransformNodeID(self.registrationTransformNode.GetID())

  def setCBCTVolumeAsBackground(self):
    '''Make the CBCT volume the active volume, if it has been loaded.'''
    if self.CBCTVolume is not None:
      workflowFunctions.setActiveVolume(self.CBCTVolume)
