from __main__ import qt, ctk, slicer
import workflowFunctions as workflowFunctions

class GuidedInterventionPETCTRegistrationStep( ctk.ctkWorkflowWidgetStep ) :
  """Workflow step '2. Registration': register the CT onto the CBCT.

  The "CTVolume" and "CBCTVolume" scene volumes (looked up by name in
  onEntry) are resampled to a 3,3,4 pixel spacing and cropped with
  BRAINSROIAuto into temporary volumes "tempCT" / "tempCBCT". BRAINSFit then
  computes a rigid transform and, when "BSplineGroupBox" is checked, a
  B-spline transform initialized from the rigid one. The final transform
  node is named "imageRegistrationTransform" and is applied to CTVolume and
  PETVolume.
  """

  def __init__( self, stepid ):
    self.initialize( stepid )
    self.setName( '2. Registration' )
    self.setDescription('Register Motion compensated CT-CBCT')
    # CLI node of the rigid pass. During the B-spline pass its status mirrors
    # the B-spline CLI node, so validate() only passes once everything is done.
    self.CLINode = None
    self.Observations = []
    # Input volumes; looked up by name in onEntry().
    self.PETVolume = None
    self.CTVolume = None
    self.CBCTVolume = None
    # Nodes created by the current registration run. Kept as direct references
    # so a re-run never resolves to a stale node that has the same name.
    self._tempCT = None
    self._tempCBCT = None
    self._registrationResult = None
    self._rigidTransform = None
    self._bsplineTransform = None
    qt.QTimer.singleShot(0, self.killButton)


  def killButton(self):
    """Hide the first widget whose text is 'Tracker', if there is one."""
    # hide useless button
    buttons = slicer.util.findChildren(text='Tracker')
    if buttons:
      buttons[0].hide()


  def createUserInterface( self ):
    """Build the step UI from the .ui file plus a CLI progress bar."""
    layout = qt.QVBoxLayout(self)

    #Get widget from ui file and set to layout
    self.widget = workflowFunctions.loadUI('GuidedInterventionPETCTRegistration.ui')
    self.CLIProgressBar = slicer.qSlicerCLIProgressBar()
    layout.addWidget(self.widget)
    layout.addWidget(self.CLIProgressBar)
    workflowFunctions.setScene(self.widget)
    workflowFunctions.get(self.widget,"RegistrationButton").connect('clicked()',self.rigidRegistration)
    workflowFunctions.get(self.widget,"SpatialSamplesSliderWidget").setDecimals(4)


  def validate( self, desiredBranchId ):
    """Let the workflow leave this step only after a completed registration.

    Shows a warning if no registration was started, or if the registration
    CLI node is not in the 'Completed' state.
    """
    validation = False
    if self.CLINode is None:
      qt.QMessageBox.warning( self, 'Error', 'Please start registration process' )
    elif self.CLINode.GetStatusString() != "Completed":
      qt.QMessageBox.warning( self, 'Error', 'Please wait end of registration' )
    else:
      validation = True
    super( GuidedInterventionPETCTRegistrationStep, self ).validate(validation, desiredBranchId)
    self.setCBCTVolumeAsBackground()


  def onEntry(self, comingFrom, transitionType):
    """Fetch the PET, CT and CBCT volumes from the scene by name."""
    self.PETVolume = slicer.mrmlScene.GetFirstNodeByName("PETVolume")
    self.CTVolume = slicer.mrmlScene.GetFirstNodeByName("CTVolume")
    self.CBCTVolume = slicer.mrmlScene.GetFirstNodeByName("CBCTVolume")
    super(GuidedInterventionPETCTRegistrationStep, self).onEntry(comingFrom, transitionType)
    qt.QTimer.singleShot(0, self.killButton)


  def onExit(self, goingTo, transitionType):
    # execute the transition
    super(GuidedInterventionPETCTRegistrationStep, self).onExit(goingTo, transitionType)


  def _uiValue(self, widgetName):
    """Return the ``value`` property of the named child widget of the step UI."""
    return workflowFunctions.get(self.widget, widgetName).value


  def _addScalarVolume(self, name):
    """Create an empty scalar volume node named *name* and add it to the scene."""
    node = slicer.mrmlScene.CreateNodeByClass("vtkMRMLScalarVolumeNode")
    node.SetName(name)
    slicer.mrmlScene.AddNode(node)
    return node


  def _removePreviousRegistration(self):
    """Remove nodes left in the scene by an earlier registration run.

    Without this, each re-run adds another node with the same name. Any
    GetFirstNodeByName lookup would then keep returning the stale first one.
    CTVolume/PETVolume are detached from a transform before it is removed.
    """
    for name in ("tempCT", "tempCBCT", "RegistrationResult",
                 "imageRegistrationTransform", "initialRegistrationTransform"):
      node = slicer.mrmlScene.GetFirstNodeByName(name)
      while node is not None:
        for volume in (self.CTVolume, self.PETVolume):
          if volume is not None and volume.GetTransformNodeID() == node.GetID():
            volume.SetAndObserveTransformNodeID(None)
        slicer.mrmlScene.RemoveNode(node)
        node = slicer.mrmlScene.GetFirstNodeByName(name)


  def rigidRegistration(self):
    """Preprocess CT/CBCT, then start the asynchronous rigid BRAINSFit run.

    Resampling and cropping run as blocking CLI calls. The rigid registration
    runs in the background; onRigidRegistrationCompleted() handles its end.
    """
    if self.CLINode is not None and self.CLINode.IsBusy():
      # A registration (rigid or mirrored B-spline pass) is still running.
      qt.QMessageBox.warning( self, 'Error', 'Please wait end of registration' )
      return
    if self.CTVolume is None or self.CBCTVolume is None:
      qt.QMessageBox.warning( self, 'Error', 'CTVolume and CBCTVolume are required for registration' )
      return
    self._removePreviousRegistration()

    self._tempCT = self._addScalarVolume("tempCT")
    self._tempCBCT = self._addScalarVolume("tempCBCT")

    # Downsample both volumes (blocking CLI runs).
    rsv = slicer.modules.resamplescalarvolume
    for source, target in ((self.CTVolume, self._tempCT), (self.CBCTVolume, self._tempCBCT)):
      rsvPar = {
        "interpolation": "lanczos",
        "outputPixelSpacing": "3,3,4",
        "InputVolume": source,
        "OutputVolume": target,
      }
      slicer.cli.run(rsv, None, rsvPar, True)

    # Crop each temporary volume in place with BRAINSROIAuto (blocking).
    fgm = slicer.modules.brainsroiauto
    for volume in (self._tempCT, self._tempCBCT):
      fgmPar = {
        "cropOutput": True,
        "ROIAutoDilateSize": 10,
        "inputVolume": volume,
        "outputVolume": volume,
      }
      slicer.cli.run(fgm, None, fgmPar, True)
    self.setCBCTVolumeAsBackground()

    self._registrationResult = self._addScalarVolume("RegistrationResult")
    self._rigidTransform = slicer.mrmlScene.CreateNodeByClass("vtkMRMLLinearTransformNode")
    self._rigidTransform.SetName("imageRegistrationTransform")
    slicer.mrmlScene.AddNode(self._rigidTransform)

    reg1params = {}
    reg1params["initializeTransformMode"] = "useMomentsAlign"
    reg1params["numberOfIterations"] = int(self._uiValue("IterationSpinBox"))
    # The rigid pass uses half of the user-selected sampling percentage.
    reg1params["samplingPercentage"] = self._uiValue("SpatialSamplesSliderWidget") / 100.0 / 2
    reg1params["backgroundFillValue"] = self._uiValue("DefaultPixelValueSpinBox")
    reg1params["numberOfHistogramBins"] = int(self._uiValue("HistogramBinsSliderWidget"))
    reg1params["translationScale"] = 5000
    reg1params["fixedVolume"] = self._tempCBCT
    reg1params["movingVolume"] = self._tempCT
    reg1params["outputVolume"] = self._registrationResult
    reg1params["transformType"] = "Rigid"
    reg1params["linearTransform"] = self._rigidTransform

    self.CLINode = slicer.cli.createNode(slicer.modules.brainsfit)
    self.CLINode.SetName("RigidRegistrationNode")
    self.CLIProgressBar.setCommandLineModuleNode(self.CLINode)
    # on success, start the second phase
    workflowFunctions.addObserver(self, self.CLINode, self.CLINode.StatusModifiedEvent, self.onRigidRegistrationCompleted)
    slicer.cli.run(slicer.modules.brainsfit, self.CLINode, reg1params, wait_for_completion = False)


  def onRigidRegistrationCompleted(self, cliNode, event):
    """Status observer of the rigid pass.

    When the node is no longer busy, stop observing it. If the status is
    'Completed', either start the B-spline pass or apply the rigid transform.
    """
    if cliNode.IsBusy():
      return
    workflowFunctions.removeObservers(self, self.onRigidRegistrationCompleted)
    if cliNode.GetStatusString() != 'Completed':
      return
    if workflowFunctions.get(self.widget,"BSplineGroupBox").checked:
      self.bSplineRegistration()
    else:
      self._finishRegistration(self._rigidTransform)


  def bSplineRegistration(self):
    """Start the asynchronous B-spline BRAINSFit pass from the rigid result.

    The rigid transform is renamed "initialRegistrationTransform", so that
    the new B-spline transform can take the name "imageRegistrationTransform".
    """
    initialTransform = self._rigidTransform
    initialTransform.SetName("initialRegistrationTransform")

    BSplineParameters = {}
    BSplineParameters["initialTransform"] = initialTransform
    # A tenth of the rigid iteration count plus one; floor division keeps it
    # an integer under both Python 2 and 3 (BRAINSFit expects integers).
    BSplineParameters["numberOfIterations"] = int(self._uiValue("IterationSpinBox")) // 10 + 1
    # Five times the user sampling percentage, capped at 100%.
    samplingPercentage = self._uiValue("SpatialSamplesSliderWidget") / 100.0 * 5
    BSplineParameters["samplingPercentage"] = min(samplingPercentage, 1.0)
    BSplineParameters["backgroundFillValue"] = self._uiValue("DefaultPixelValueSpinBox")
    BSplineParameters["fixedVolume"] = self._tempCBCT
    BSplineParameters["movingVolume"] = self._tempCT
    BSplineParameters["outputVolume"] = self._registrationResult

    BSplineParameters["transformType"] = "BSpline"
    self._bsplineTransform = slicer.mrmlScene.CreateNodeByClass("vtkMRMLBSplineTransformNode")
    self._bsplineTransform.SetName("imageRegistrationTransform")
    slicer.mrmlScene.AddNode(self._bsplineTransform)
    BSplineParameters["bsplineTransform"] = self._bsplineTransform

    gridSize = str(int(self._uiValue("GridSizeSliderWidget")))
    BSplineParameters["splineGridSize"] = ",".join([gridSize] * 3)
    BSplineParameters["numberOfHistogramBins"] = int(self._uiValue("HistogramBinsSliderWidget"))
    BSplineParameters["translationScale"] = 2000
    if workflowFunctions.get(self.widget,"ConstrainDeformationCheckBox").checked:
      BSplineParameters["maxBSplineDisplacement"] = self._uiValue("MaximumDeformationSpinBox")

    bsplineCLINode = slicer.cli.createNode(slicer.modules.brainsfit)
    bsplineCLINode.SetName("DeformationRegistrationNode")
    self.CLIProgressBar.setCommandLineModuleNode(bsplineCLINode)
    # on success, apply the same transform to the PET volume
    workflowFunctions.addObserver(self, bsplineCLINode, bsplineCLINode.StatusModifiedEvent, self.deformableRegistrationCompleted)
    slicer.cli.run(slicer.modules.brainsfit, bsplineCLINode, BSplineParameters, wait_for_completion = False)


  def deformableRegistrationCompleted(self, cliNode, event):
    """Status observer of the B-spline pass.

    On 'Completing'/'Completed', apply the B-spline transform and mark
    self.CLINode as Completed. Otherwise, mirror the B-spline status onto
    self.CLINode. If the pass ended without success, stop observing it.
    """
    if cliNode.GetStatusString() in ('Completed', 'Completing'):
      # 'Completed' event is sometimes not fired, only 'Completing'
      workflowFunctions.removeObservers(self, self.deformableRegistrationCompleted)
      rigidTransform = self._rigidTransform
      self._rigidTransform = None
      self._finishRegistration(self._bsplineTransform, rigidTransform)
      self.CLINode.SetStatus(self.CLINode.Completed)
      self.CLIProgressBar.setCommandLineModuleNode(self.CLINode)
    else:
      # Mirror progress so validate() and the busy guard see the B-spline pass.
      self.CLINode.SetStatus(cliNode.GetStatus())
      if not cliNode.IsBusy():
        # Finished without success (errors or cancelled): stop observing.
        workflowFunctions.removeObservers(self, self.deformableRegistrationCompleted)


  def _finishRegistration(self, transformNode, *nodesToRemove):
    """Apply *transformNode* to CT and PET, then remove temporary nodes.

    The temporary tempCBCT/tempCT volumes are always removed; any extra nodes
    in *nodesToRemove* (None entries ignored) are removed as well.
    """
    transformNodeID = transformNode.GetID()
    for volume in (self.CTVolume, self.PETVolume):
      if volume is not None:
        volume.SetAndObserveTransformNodeID(transformNodeID)
    # temporary downsampled images are no longer needed
    for node in (self._tempCBCT, self._tempCT) + nodesToRemove:
      if node is not None:
        slicer.mrmlScene.RemoveNode(node)
    self._tempCBCT = None
    self._tempCT = None
    self.setCBCTVolumeAsBackground()


  def setCBCTVolumeAsBackground(self):
    """Make the CBCT volume the active volume, if it is known."""
    if self.CBCTVolume is not None:
      workflowFunctions.setActiveVolume(self.CBCTVolume)
