from __main__ import qt, ctk, vtk, slicer
import workflowFunctions as workflowFunctions

class GuidedInterventionPETCTRegistrationStep( ctk.ctkWorkflowWidgetStep ) :
  """Workflow step '2. Registration': register the CT (and the PET sharing its
  transform) to the CBCT.

  The registration is a chain of CLI runs, each started from the status
  callback of the previous one:
    rigidRegistration               -> resample CT/CBCT (blocking), BRAINSROIAuto on CT
    onRigid1Modified                -> rigidStep2: BRAINSROIAuto on CBCT
    onRigid2Modified                -> rigidStep3: Gaussian blur (blocking), rigid BRAINSFit
    onRigidRegistrationCompleted    -> bSplineRegistration: BSpline BRAINSFit, or a second
                                       rigid BRAINSFit on the unblurred volumes
    deformableRegistrationCompleted -> remove the temporary nodes
  A cancelled or failed CLI run stops the chain and removes the temporary nodes.
  """

  # Attributes holding the intermediate nodes of one registration run, in the
  # order they are removed from the scene when the run ends.
  _TEMPORARY_NODE_ATTRIBUTES = ('tempCBCT', 'tempCT', 'maskCBCT', 'maskCT',
                                'initReg', 'smoothCBCT', 'smoothCT')

  def __init__( self, stepid ):
    self.initialize( stepid )
    self.setName( '2. Registration' )
    self.setDescription('Register Motion compensated CT-CBCT')
    self.CLINode = None
    self.Observations = []
    # Filled in by onEntry(); None until the step has been entered once.
    self.PETVolume = None
    self.CTVolume = None
    self.CBCTVolume = None
    self.reg = None
    self.cbctCenter = None
    self.ctCenter = None
    self.extentHalves = None
    for attributeName in self._TEMPORARY_NODE_ATTRIBUTES:
      setattr(self, attributeName, None)
    qt.QTimer.singleShot(0, self.killButton)


  def killButton(self):
    """Hide the first widget whose text is 'Tracker' (not used by this step)."""
    buttons = slicer.util.findChildren(text='Tracker')
    if buttons:
      buttons[0].hide()


  def createUserInterface( self ):
    """Build the panel from GuidedInterventionPETCTRegistration.ui plus a CLI progress bar."""
    layout = qt.QVBoxLayout(self)

    # Get widget from ui file and set to layout
    self.widget = workflowFunctions.loadUI('GuidedInterventionPETCTRegistration.ui')
    self.CLIProgressBar = slicer.qSlicerCLIProgressBar()
    layout.addWidget(self.widget)
    layout.addWidget(self.CLIProgressBar)
    workflowFunctions.setScene(self.widget)
    workflowFunctions.get(self.widget,"RegistrationButton").connect('clicked()',self.rigidRegistration)
    # manual initial-translation sliders, one per axis
    for sliderName in ("xTranslation", "yTranslation", "zTranslation"):
      workflowFunctions.get(self.widget, sliderName).connect('valueChanged(int)', self.buildInitialTransform)
    workflowFunctions.get(self.widget,"SpatialSamplesSliderWidget").setDecimals(4)
    self.setSliderMaximums()

  def validate( self, desiredBranchId ):
    """Allow leaving the step only when the last registration CLI node reports 'Completed'."""
    if self.CLINode is None:
      qt.QMessageBox.warning( self, 'Error', 'Please start registration process' )
      validation = False
    elif self.CLINode.GetStatusString() != "Completed":
      qt.QMessageBox.warning( self, 'Error', 'Please wait end of registration' )
      validation = False
    else:
      validation = True
    super( GuidedInterventionPETCTRegistrationStep, self ).validate(validation,desiredBranchId)
    self.setCBCTVolumeAsBackground()


  def onEntry(self, comingFrom, transitionType):
    """Attach a new 'imageRegistrationTransform' to the CT and PET volumes.

    The transform is initialized to the translation that moves the CT
    bounding-box centre onto the CBCT bounding-box centre. The manual
    translation sliders are sized from the larger extent of the two volumes.
    """
    self.PETVolume = slicer.mrmlScene.GetFirstNodeByName("PETVolume")
    self.CTVolume = slicer.mrmlScene.GetFirstNodeByName("CTVolume")
    self.CBCTVolume = slicer.mrmlScene.GetFirstNodeByName("CBCTVolume")

    self.reg = slicer.mrmlScene.CreateNodeByClass("vtkMRMLLinearTransformNode")
    self.reg.SetName("imageRegistrationTransform")
    slicer.mrmlScene.AddNode(self.reg)
    self.CTVolume.SetAndObserveTransformNodeID(self.reg.GetID())
    self.PETVolume.SetAndObserveTransformNodeID(self.reg.GetID())
    workflowFunctions.setActiveVolume(self.CTVolume, 1, 0.5)

    # Bounding boxes give the centres for the initial translation and the
    # extents for the manual-initialization slider ranges.
    cbctMin, cbctMax = self._rasCorners(self.CBCTVolume)
    ctMin, ctMax = self._rasCorners(self.CTVolume)
    # centres keep a trailing 0.0 as a 4th component, as before
    self.cbctCenter = [(lo + hi) / 2 for lo, hi in zip(cbctMin, cbctMax)] + [0.0]
    self.ctCenter = [(lo + hi) / 2 for lo, hi in zip(ctMin, ctMax)] + [0.0]
    self._applyInitialTranslation([0, 0, 0])

    self.extentHalves = [max(ctMax[axis] - ctMin[axis], cbctMax[axis] - cbctMin[axis]) / 2
                         for axis in range(3)]
    self.setSliderMaximums()

    super(GuidedInterventionPETCTRegistrationStep, self).onEntry(comingFrom, transitionType)
    qt.QTimer.singleShot(0, self.killButton)

  @staticmethod
  def _rasCorners(volumeNode):
    """Return the (minimum, maximum) corners of volumeNode's RAS bounds as 3-element lists."""
    bounds = [0.0] * 6
    volumeNode.GetRASBounds(bounds)
    return [bounds[0], bounds[2], bounds[4]], [bounds[1], bounds[3], bounds[5]]

  def _applyInitialTranslation(self, offsets):
    """Set self.reg to a pure translation: CBCT centre - CT centre + offsets (per axis)."""
    translation = vtk.vtkMatrix4x4()
    translation.Identity()
    for axis in range(3):
      translation.SetElement(axis, 3, self.cbctCenter[axis] - self.ctCenter[axis] + offsets[axis])
    self.reg.SetMatrixTransformToParent(translation)

  def buildInitialTransform(self, sliderValue):
    """Slot for the x/y/z translation sliders: add their values to the centre-aligning translation."""
    if self.reg is None or self.cbctCenter is None:
      return  # step not entered yet: no transform or centres to work with
    offsets = [workflowFunctions.get(self.widget, sliderName).value
               for sliderName in ("xTranslation", "yTranslation", "zTranslation")]
    self._applyInitialTranslation(offsets)

  def setSliderMaximums(self):
    """Set each translation slider to [-half extent, +half extent] and reset it to 0.

    Does nothing until both the UI exists and onEntry() has computed the extents.
    """
    if not hasattr(self, 'widget') or self.extentHalves is None:
      return
    for sliderName, half in zip(("xTranslation", "yTranslation", "zTranslation"), self.extentHalves):
      slider = workflowFunctions.get(self.widget, sliderName)
      slider.minimum = -half
      slider.maximum = half
      slider.value = 0

  def onExit(self, goingTo, transitionType):
    # execute the transition
    super(GuidedInterventionPETCTRegistrationStep, self).onExit(goingTo, transitionType)


  def _addScalarVolume(self, name):
    """Create an empty scalar volume node called name, add it to the scene and return it."""
    node = slicer.mrmlScene.CreateNodeByClass("vtkMRMLScalarVolumeNode")
    node.SetName(name)
    slicer.mrmlScene.AddNode(node)
    return node

  def _removeTemporaryNodes(self):
    """Remove this run's intermediate nodes from the scene (by reference) and forget them."""
    for attributeName in self._TEMPORARY_NODE_ATTRIBUTES:
      node = getattr(self, attributeName, None)
      if node is not None and slicer.mrmlScene.IsNodePresent(node):
        slicer.mrmlScene.RemoveNode(node)
      setattr(self, attributeName, None)

  @staticmethod
  def _cliFailed(cliNode):
    """Return True when cliNode has ended without success (cancelled or completed with errors)."""
    return cliNode.GetStatusString() in ('Cancelled', 'CompletedWithErrors')

  def _abortRegistration(self, observerMethod):
    """Stop the callback chain after a failed CLI run and discard the intermediate nodes.

    No dialog is shown here; self.CLIProgressBar is attached to the failed node
    (presumably it reports the status to the user).
    """
    workflowFunctions.removeObservers(self, observerMethod)
    self._removeTemporaryNodes()

  def rigidRegistration(self):
    """Start the registration chain (RegistrationButton slot).

    Downsamples the CT and CBCT with blocking runs, then launches BRAINSROIAuto
    on the CT asynchronously; onRigid1Modified continues the chain.
    """
    self.tempCT = self._addScalarVolume("tempCT")
    self.tempCBCT = self._addScalarVolume("tempCBCT")
    self.maskCT = self._addScalarVolume("maskCT")
    self.maskCBCT = self._addScalarVolume("maskCBCT")

    # downsample volumes (blocking runs)
    rsv = slicer.modules.resamplescalarvolume
    rsvPar = {
      "interpolation": "lanczos",
      "outputPixelSpacing": "3,3,4",
    }
    for inputVolume, outputVolume in ((self.CTVolume, self.tempCT), (self.CBCTVolume, self.tempCBCT)):
      rsvPar["InputVolume"] = inputVolume
      rsvPar["OutputVolume"] = outputVolume
      slicer.cli.run(rsv, None, rsvPar, True)

    # Foreground masking: BRAINSROIAuto crops the downsampled volume in place
    # and writes its ROI mask. CT first; the CBCT follows in rigidStep2.
    self.fgm = slicer.modules.brainsroiauto
    self.fgmPar = {
      "cropOutput": True,
      "ROIAutoDilateSize": 25,
      "thresholdCorrectionFactor": 1.5,
      "inputVolume": self.tempCT,
      "outputVolume": self.tempCT,
      "outputROIMaskVolume": self.maskCT,
    }
    self.CLINode = slicer.cli.createNode(self.fgm)
    self.CLIProgressBar.setCommandLineModuleNode(self.CLINode)
    # on success, start the next phase
    workflowFunctions.addObserver(self, self.CLINode, self.CLINode.StatusModifiedEvent, self.onRigid1Modified)
    slicer.cli.run(self.fgm, self.CLINode, self.fgmPar, wait_for_completion = False)

  def onRigid1Modified(self, cliNode, event):
    """CT masking status callback: abort on failure, otherwise mask the CBCT once idle."""
    if self._cliFailed(cliNode):
      self._abortRegistration(self.onRigid1Modified)
    elif not cliNode.IsBusy():
      workflowFunctions.removeObservers(self, self.onRigid1Modified)
      self.rigidStep2()

  def rigidStep2(self):
    """Run BRAINSROIAuto on the downsampled CBCT, reusing the same CLI node and parameters."""
    self.fgmPar["inputVolume"] = self.tempCBCT
    self.fgmPar["outputVolume"] = self.tempCBCT
    self.fgmPar["outputROIMaskVolume"] = self.maskCBCT
    # on success, start the next phase
    workflowFunctions.addObserver(self, self.CLINode, self.CLINode.StatusModifiedEvent, self.onRigid2Modified)
    slicer.cli.run(self.fgm, self.CLINode, self.fgmPar, wait_for_completion = False)

  def onRigid2Modified(self, cliNode, event):
    """CBCT masking status callback: abort on failure, continue at 'Completing'/'Completed'."""
    if self._cliFailed(cliNode):
      self._abortRegistration(self.onRigid2Modified)
    elif cliNode.GetStatusString() in ('Completed', 'Completing'):
      workflowFunctions.removeObservers(self, self.onRigid2Modified)
      self.rigidStep3()

  def rigidStep3(self):
    """Blur both masked volumes (blocking), then start a rigid BRAINSFit on the blurred pair."""
    self.smoothCT = self._addScalarVolume("smoothCT")
    self.smoothCBCT = self._addScalarVolume("smoothCBCT")
    # smooth the volumes to enable better convergence
    gbif = slicer.modules.gaussianblurimagefilter
    gbifPar = {"sigma": 6}
    for inputVolume, outputVolume in ((self.tempCT, self.smoothCT), (self.tempCBCT, self.smoothCBCT)):
      gbifPar["inputVolume"] = inputVolume
      gbifPar["outputVolume"] = outputVolume
      slicer.cli.run(gbif, None, gbifPar, True)

    # this coarse pass samples five times the UI percentage, capped at 1.0
    samplingPercentage = min(
      workflowFunctions.get(self.widget,"SpatialSamplesSliderWidget").value / 100.0 * 5, 1.0)

    # maximumStepLength/minimumStepLength are left at BRAINSFit defaults
    self.rigidPar = {
      "initializeTransformMode": "Off",
      # start from this step's centre-aligned (and manually adjusted) transform
      "initialTransform": self.reg,
      "numberOfIterations": 10000,
      "samplingPercentage": samplingPercentage,
      "backgroundFillValue": -1000, # air
      "numberOfHistogramBins": 20,
      "translationScale": 15000,
      "relaxationFactor": 0.7,
      "maskProcessingMode": "ROI",
      "fixedBinaryVolume": self.maskCBCT,
      "movingBinaryVolume": self.maskCT,
      "histogramMatch": True,
      "fixedVolume": self.smoothCBCT,
      "movingVolume": self.smoothCT,
      "outputVolume": self._addScalarVolume("RegistrationResult"),
      "transformType": "Rigid",
      # result goes into the transform the CT/PET volumes observe
      "linearTransform": self.reg,
    }
    self.CLINode = slicer.cli.createNode(slicer.modules.brainsfit)
    self.CLINode.SetName("RigidRegistrationNode")
    self.CLIProgressBar.setCommandLineModuleNode(self.CLINode)
    # on success start the final phase
    workflowFunctions.addObserver(self, self.CLINode, self.CLINode.StatusModifiedEvent, self.onRigidRegistrationCompleted)
    slicer.cli.run(slicer.modules.brainsfit, self.CLINode, self.rigidPar, wait_for_completion = False)


  def onRigidRegistrationCompleted(self, cliNode, event):
    """Rigid BRAINSFit status callback: abort on failure, otherwise refine once idle."""
    if self._cliFailed(cliNode):
      self._abortRegistration(self.onRigidRegistrationCompleted)
    elif not cliNode.IsBusy():
      workflowFunctions.removeObservers(self, self.onRigidRegistrationCompleted)
      self.setCBCTVolumeAsBackground()
      self.bSplineRegistration()


  def bSplineRegistration(self):
    """Refine the rigid result with BRAINSFit on the unblurred downsampled volumes.

    If BSplineGroupBox is checked, run a BSpline registration starting from a
    copy of the rigid transform. Otherwise repeat the rigid registration with a
    smaller translation scale and relaxation factor. deformableRegistrationCompleted
    cleans up when this run ends.
    """
    useBSpline = workflowFunctions.get(self.widget,"BSplineGroupBox").checked
    if useBSpline:
      # start from a snapshot of the rigid result
      self.initReg = slicer.mrmlScene.CreateNodeByClass("vtkMRMLLinearTransformNode")
      self.initReg.Copy(self.reg)
      self.initReg.SetName("initialRegistrationTransform")
      slicer.mrmlScene.AddNode(self.initReg)

      bsplineTransform = slicer.mrmlScene.CreateNodeByClass("vtkMRMLBSplineTransformNode")
      bsplineTransform.SetName("imageRegistrationTransform")
      slicer.mrmlScene.AddNode(bsplineTransform)
      # NOTE(review): this BSpline output node is not attached to CTVolume/PETVolume,
      # which keep observing self.reg; confirm whether a later step applies it.

      scalarGS = int(workflowFunctions.get(self.widget,"GridSizeSliderWidget").value)
      # parameters are not very suitable for the data, so BSpline terminates after 1 iteration
      parameters = {
        "initialTransform": self.initReg,
        "numberOfIterations": 1000,
        "samplingPercentage": workflowFunctions.get(self.widget,"SpatialSamplesSliderWidget").value / 100.0,
        "backgroundFillValue": -1000, # air
        "fixedVolume": self.tempCBCT,
        "movingVolume": self.tempCT,
        # reuse the output volume created by this run's rigid pass
        "outputVolume": self.rigidPar["outputVolume"],
        "transformType": "BSpline",
        "bsplineTransform": bsplineTransform,
        "splineGridSize": ",".join([str(scalarGS)] * 3),
        "numberOfHistogramBins": int(workflowFunctions.get(self.widget,"HistogramBinsSliderWidget").value),
        "translationScale": 5000,
        "maskProcessingMode": "ROI",
        "fixedBinaryVolume": self.maskCBCT,
        "movingBinaryVolume": self.maskCT,
        "histogramMatch": True,
        "maxBSplineDisplacement": 25,
      }
    else: # repeat rigid registration with sharp (unblurred) images
      self.rigidPar["initializeTransformMode"] = "Off"
      self.rigidPar["initialTransform"] = self.reg
      self.rigidPar["fixedVolume"] = self.tempCBCT
      self.rigidPar["movingVolume"] = self.tempCT
      self.rigidPar["translationScale"] = 5000
      self.rigidPar["relaxationFactor"] = 0.3
      parameters = self.rigidPar

    self.CLINode = slicer.cli.createNode(slicer.modules.brainsfit)
    self.CLIProgressBar.setCommandLineModuleNode(self.CLINode)
    # when this run ends, deformableRegistrationCompleted removes the temporary nodes
    workflowFunctions.addObserver(self, self.CLINode, self.CLINode.StatusModifiedEvent, self.deformableRegistrationCompleted)
    slicer.cli.run(slicer.modules.brainsfit, self.CLINode, parameters, wait_for_completion = False)


  def deformableRegistrationCompleted(self, cliNode, event):
    """Final BRAINSFit status callback: abort on failure, otherwise clean up and mark Completed."""
    if self._cliFailed(cliNode):
      self._abortRegistration(self.deformableRegistrationCompleted)
    elif cliNode.GetStatusString() in ('Completed', 'Completing'):
      # 'Completed' event is sometimes not fired, only 'Completing'
      workflowFunctions.removeObservers(self, self.deformableRegistrationCompleted)
      # temporary downsampled/masked/smoothed images are no longer needed
      self._removeTemporaryNodes()
      self.setCBCTVolumeAsBackground()
      self.CLINode.SetStatus(0x20) #Completed


  def setCBCTVolumeAsBackground(self):
    """Make the CBCT the active volume via workflowFunctions, once onEntry has found it."""
    if self.CBCTVolume is not None:
      workflowFunctions.setActiveVolume(self.CBCTVolume)