/* Distributed under the Apache License, Version 2.0.
   See accompanying NOTICE file for details.*/

#include "LBMITK.h"

bool RunLBMITK(std::string const& filename, LBM &lbm)
{
  try
  {
    itk::ImageIOBase::Pointer image_io =
      itk::ImageIOFactory::CreateImageIO(filename.c_str(), itk::ImageIOFactory::FileModeEnum::ReadMode);
    if (image_io == nullptr)
    {
      return false;
    }
    image_io->SetFileName(filename);
    image_io->ReadImageInformation();
    // Get the pixel type
    using IOPixelType = itk::ImageIOBase::IOPixelType;
    const IOPixelType pixel_type = image_io->GetPixelType();
    std::cout << "Pixel Type is " << itk::ImageIOBase::GetPixelTypeAsString(pixel_type) << std::endl;
    // Get the component type
    using IOComponentType = itk::ImageIOBase::IOComponentType;
    const IOComponentType component_type = image_io->GetComponentType();
    std::cout << "Component Type is " << image_io->GetComponentTypeAsString(component_type) << std::endl;
    // Print out some more information
    const unsigned int image_dimension = image_io->GetNumberOfDimensions();
    std::cout << "Image Dimension is " << image_dimension << std::endl;
    if (image_dimension != 3)
    {
      std::cerr << "Currently, Only 3D images are supported\n";
      return false;
    }

    // This filter handles all types on input, but only produces
    // signed types
    switch (component_type)
    {
    case itk::ImageIOBase::UCHAR:
    {
      using PixelType = unsigned char;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::CHAR:
    {
      using PixelType = char;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::USHORT:
    {
      using PixelType = unsigned short;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::SHORT:
    {
      using PixelType = short;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::UINT:
    {
      using PixelType = unsigned int;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::INT:
    {
      using PixelType = int;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::ULONG:
    {
      using PixelType = unsigned long;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::LONG:
    {
      using PixelType = long;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::FLOAT:
    {
      using PixelType = float;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::DOUBLE:
    {
      using PixelType = double;
      using ImageType = itk::Image<PixelType, 3>;
      LBMITK<PixelType, ImageType> lbm_itk(lbm);
      return lbm_itk.Run(filename);
    }
    case itk::ImageIOBase::UNKNOWNCOMPONENTTYPE:
    default:
      std::cerr << "Unknown input image pixel component type: ";
      std::cerr << itk::ImageIOBase::GetComponentTypeAsString(component_type);
      std::cerr << std::endl;
      return false;
      break;
    }
  }
  catch (itk::ExceptionObject& excep)
  {
    std::cerr << "Unable to read file : " << filename << " : with exception :" << std::endl;
    std::cerr << excep << std::endl;
  }
  return false;
}

/// Construct a pipeline that will configure and run @p lbm.
/// @param lbm Solver stored in m_lbm. NOTE(review): whether m_lbm is a
///            reference (caller must keep lbm alive) is declared in LBMITK.h.
template<LBMITK_TEMPLATE>
LBMITK<LBMITK_TYPES>::LBMITK(LBM& lbm)
  : m_lbm(lbm)
{}

/// Destructor: nothing beyond the members' own cleanup.
template<LBMITK_TEMPLATE>
LBMITK<LBMITK_TYPES>::~LBMITK() = default;

/// Full pipeline: load and preprocess the image, copy its geometry and voxel
/// values into the solver's input, then run the solver.
/// @param filename Image file to load.
/// @return false as soon as a preprocessing stage or the solver fails.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::Run(std::string const& filename)
{
  // Stages run in this order; && stops at the first one that fails, and
  // each stage prints its own diagnostics.
  const bool preprocessed = ReadImage(filename)
                         && NormalizeImage()
                         && CropImage()
                         && DownsampleImage()
                         && OrientImageToX();
  if (!preprocessed)
    return false;

  std::cout << "Processing boundary labels...\n";
  const auto origin = m_image->GetOrigin();
  const auto spacing = m_image->GetSpacing();
  const auto size = m_image->GetLargestPossibleRegion().GetSize();
  std::cout << "  Origin  : " << origin << "\n";
  std::cout << "  Size    : " << size << "\n";
  std::cout << "  Spacing : " << spacing << "\n";

  // Describe the LBM grid from the final image geometry.
  for (unsigned int axis = 0; axis < 3; ++axis)
    m_lbm.in.dimensions[axis] = size[axis];
  // Only spacing[0] is used. NormalizeImage equalizes spacing, but
  // DownsampleImage's per-axis rounding can leave small differences.
  // NOTE(review): the /1000 presumably converts mm to m — confirm.
  m_lbm.in.grid_spacing = spacing[0] / 1000;

  // Append every voxel value as a source label, in the iterator's traversal
  // order (ITK region iterators advance the first index fastest).
  itk::ImageRegionIteratorWithIndex<ImageType> voxel(m_image, m_image->GetLargestPossibleRegion());
  while (!voxel.IsAtEnd())
  {
    m_lbm.in.source_labels.push_back(voxel.Get());
    ++voxel;
  }

  return static_cast<bool>(m_lbm.Run());
}

/// Read @p filename into m_image and print its origin, size and spacing.
/// @return false (after printing the ITK error) if the read throws. m_image
///         has already been replaced by a fresh, empty image in that case.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::ReadImage(std::string const& filename)
{
  std::cout << "Reading image...\n";
  m_image = ImageType::New();

  auto file_reader = itk::ImageFileReader<ImageType>::New();
  file_reader->SetFileName(filename);
  try
  {
    file_reader->Update();
  }
  catch (itk::ExceptionObject& err)
  {
    std::cerr << err.what() << std::endl;
    return false;
  }
  // Graft: take over the reader output's metadata and pixel buffer without copying pixels.
  m_image->Graft(file_reader->GetOutput());

  std::cout << "  Origin  : " << m_image->GetOrigin() << "\n";
  std::cout << "  Size    : " << m_image->GetLargestPossibleRegion().GetSize() << "\n";
  std::cout << "  Spacing : " << m_image->GetSpacing() << "\n";
  return true;
}

/// Write the current m_image to @p filename.
/// @return false (after printing the ITK error) if the write throws.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::WriteImage(std::string const& filename)
{
  try
  {
    auto file_writer = itk::ImageFileWriter<ImageType>::New();
    file_writer->SetFileName(filename);
    file_writer->SetInput(m_image);
    file_writer->Update();
  }
  catch (itk::ExceptionObject& err)
  {
    std::cerr << err.what() << std::endl;
    return false;
  }
  return true;
}

/// Resample m_image so all three axes use the finest input spacing.
/// Origin and direction are kept. The size of each rescaled axis becomes
/// size * old_spacing / finest_spacing, truncated to an integer.
/// @return false (after printing the exception) if the resample throws.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::NormalizeImage()
{
  try
  {
    std::cout << "Normalizing image spacing...\n";
    auto interp = InterpType::New();
    auto resample_filter = ResampleFilterType::New();

    auto origin = m_image->GetOrigin();
    auto spacing = m_image->GetSpacing();
    auto size = m_image->GetLargestPossibleRegion().GetSize();

    resample_filter->SetInput(m_image);
    resample_filter->SetInterpolator(interp);
    resample_filter->SetDefaultPixelValue(0);
    resample_filter->SetOutputOrigin(origin);
    resample_filter->SetOutputDirection(m_image->GetDirection());

    // Find the axis with the smallest spacing; ties go to the lowest index.
    unsigned int finest = 0;
    for (unsigned int d = 1; d < 3; ++d)
    {
      if (spacing[d] < spacing[finest])
        finest = d;
    }

    // Bring the other axes down to that spacing. The finest axis itself is
    // skipped so its size is not perturbed by a round-trip through doubles.
    const double target_spacing = spacing[finest];
    for (unsigned int d = 0; d < 3; ++d)
    {
      if (d == finest)
        continue;
      size[d] = size[d] * spacing[d] / target_spacing;
      spacing[d] = target_spacing;
    }

    resample_filter->SetSize(size);
    resample_filter->SetOutputSpacing(spacing);
    resample_filter->Update();
    m_image = resample_filter->GetOutput();

    // Print out what we did
    std::cout << "  Origin  : " << m_image->GetOrigin() << "\n";
    std::cout << "  Size    : " << m_image->GetLargestPossibleRegion().GetSize() << "\n";
    std::cout << "  Spacing : " << m_image->GetSpacing() << "\n";
  }
  catch (itk::ExceptionObject & excep)
  {
    std::cerr << "Unable to normalize image spacing : with exception :" << std::endl;
    std::cerr << excep << std::endl;
    return false;
  }
  return true;
}

/// Crop m_image around its labelled (non-zero) voxels with a 5-voxel border,
/// then resample it onto an equivalent grid whose origin is the first
/// cropped voxel.
/// NOTE(review): CropType is presumably an auto-crop label-map filter (it
/// exposes SetCropBorder) — confirm in LBMITK.h.
/// @return false (after printing the exception) if any ITK filter throws.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::CropImage()
{
  try
  {
    std::cout << "Cropping boundary...\n";
    // Label image -> label map, treating 0 as background.
    auto i2m_filter = LabelImageToMapFilterType::New();
    i2m_filter->SetInput(m_image);
    i2m_filter->SetBackgroundValue(0);
    i2m_filter->Update();

    // Crop the label map, keeping a 5-voxel border on every axis.
    auto crop = CropType::New();
    crop->SetInput(i2m_filter->GetOutput());
    CropSize border;
    border[0] = 5;
    border[1] = 5;
    border[2] = 5;
    crop->SetCropBorder(border);
    crop->Update();

    // Label map -> label image.
    auto m2i_filter = LabelMapToImageFilterType::New();
    m2i_filter->SetInput(crop->GetOutput());
    m2i_filter->Update();
    m_image = m2i_filter->GetOutput();

    // The cropped region may start at a non-zero index. Take the physical
    // position of its first voxel as the new origin. Region start and size
    // both come from the largest possible region, so they stay consistent.
    const auto region = m_image->GetLargestPossibleRegion();
    auto origin = m_image->GetOrigin();
    m_image->TransformIndexToPhysicalPoint(region.GetIndex(), origin);

    // Resample onto the same voxels, re-expressed with that origin so the
    // resulting region starts at index 0.
    auto interp = InterpType::New();
    auto resample_filter = ResampleFilterType::New();
    auto spacing = m_image->GetSpacing();
    auto size = region.GetSize();

    resample_filter->SetInput(m_image);
    resample_filter->SetInterpolator(interp);
    resample_filter->SetDefaultPixelValue(0);
    resample_filter->SetOutputOrigin(origin);
    resample_filter->SetOutputDirection(m_image->GetDirection());
    resample_filter->SetSize(size);
    resample_filter->SetOutputSpacing(spacing);
    resample_filter->Update();
    m_image = resample_filter->GetOutput();

    // Print out what we did
    std::cout << "  Origin  : " << m_image->GetOrigin() << "\n";
    std::cout << "  Size    : " << m_image->GetLargestPossibleRegion().GetSize() << "\n";
    std::cout << "  Spacing : " << m_image->GetSpacing() << "\n";
  }
  catch (itk::ExceptionObject & excep)
  {
    std::cerr << "Unable to crop image : with exception :" << std::endl;
    std::cerr << excep << std::endl;
    return false;
  }
  return true;
}

/// Optionally shrink m_image according to m_lbm.cfg.
/// In Length mode the scale factor is spacing[0] / downsample_to_length
/// (skipped when the requested length is not coarser than the current
/// spacing). Otherwise it is cfg.downsample_fraction, and 0 means "no
/// downsampling". Each scaled axis is rounded down to a multiple of 16
/// voxels, but never below 16, and the spacing is rescaled so each axis
/// keeps its physical extent (size * spacing).
/// @return true when nothing needed doing or the resample succeeded; false
///         for a negative/NaN fraction or when ITK throws.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::DownsampleImage()
{
  try
  {
    auto origin = m_image->GetOrigin();
    auto spacing = m_image->GetSpacing();
    auto size = m_image->GetLargestPossibleRegion().GetSize();

    double downsample_fraction = 0;
    if (m_lbm.cfg.downsample_type == LBM::DownsampleType::Length)
    {
      if (m_lbm.cfg.downsample_to_length <= spacing[0])
      {
        // Nothing to do
        std::cout << "Requested downsample length smaller that image spacing\n";
        std::cout << "Running at resolution of provided image (" << spacing[0] << ")\n";
        return true;
      }
      // Only spacing[0] is consulted. When called from Run(), NormalizeImage
      // has already made the spacing equal on all axes.
      downsample_fraction = spacing[0] / m_lbm.cfg.downsample_to_length;
    }
    else
      downsample_fraction = m_lbm.cfg.downsample_fraction;
    if (downsample_fraction == 0)
    {
      std::cout << "Downsampling not requested\n";
      return true;
    }
    // A negative (or NaN) factor would be converted to an unsigned size below (UB).
    if (!(downsample_fraction > 0))
    {
      std::cerr << "Invalid downsample fraction (" << downsample_fraction << "), it must be positive\n";
      return false;
    }

    std::cout << "Downsampling image...\n";
    auto interp = InterpType::New();
    auto resample_filter = ResampleFilterType::New();

    resample_filter->SetInput(m_image);
    resample_filter->SetInterpolator(interp);
    resample_filter->SetOutputOrigin(origin);
    resample_filter->SetOutputDirection(m_image->GetDirection());

    // NOTE(review): the multiple-of-16 rounding is presumably a solver tiling
    // requirement — confirm. An axis shorter than one block would round to 0,
    // giving an empty grid and a division by zero in the spacing, so such
    // axes are clamped to one block.
    const unsigned int block_size = 16;
    Size temp_size;
    Size new_size;
    Spacing new_spacing;
    for (unsigned int d = 0; d < 3; ++d)
    {
      temp_size[d] = size[d] * downsample_fraction;
      new_size[d] = temp_size[d] - (temp_size[d] % block_size);
      if (new_size[d] == 0)
        new_size[d] = block_size;
      // Keep the physical extent of the axis; use double, not float, for the ratio.
      new_spacing[d] = spacing[d] * (static_cast<double>(size[d]) / static_cast<double>(new_size[d]));
    }

    resample_filter->SetSize(new_size);
    resample_filter->SetOutputSpacing(new_spacing);
    resample_filter->Update();
    m_image = resample_filter->GetOutput();

    // Print out what we did
    std::cout << "  Origin  : " << m_image->GetOrigin() << "\n";
    std::cout << "  Size    : " << m_image->GetLargestPossibleRegion().GetSize() << "\n";
    std::cout << "  Spacing : " << m_image->GetSpacing() << "\n";
  }
  catch (itk::ExceptionObject & excep)
  {
    std::cerr << "Unable to downsample image : with exception :" << std::endl;
    std::cerr << excep << std::endl;
    return false;
  }
  return true;
}

/// Re-orient m_image with a fixed pair of 90-degree rotations about the
/// image centre. This does not measure which axis is longest. It always
/// applies the same rotations and the matching size/spacing permutation:
/// new x <- old y, new y <- old z, new z <- old x.
/// @return false (after printing the exception) if the resample throws.
template<LBMITK_TEMPLATE>
bool LBMITK<LBMITK_TYPES>::OrientImageToX()
{
  try
  {
    std::cout << "Orienting image...\n";
    auto interp = InterpType::New();
    auto resample_filter = ResampleFilterType::New();
    auto transform = TransformType::New();

    auto origin = m_image->GetOrigin();
    auto spacing = m_image->GetSpacing();
    auto size = m_image->GetLargestPossibleRegion().GetSize();

    resample_filter->SetInput(m_image);
    resample_filter->SetInterpolator(interp);
    resample_filter->SetDefaultPixelValue(0);

    // Translate the image centre (origin + extent / 2) to the physical origin.
    TransformType::OutputVectorType trans1;
    for (unsigned int d = 0; d < 3; ++d)
      trans1[d] = -1 * (origin[d] + (spacing[d] * size[d] * 0.5));
    transform->Translate(trans1);

    // Rotate 90 degrees around Y, then 90 degrees around the axis (1,0,0).
    // Use exact pi/2: the size/spacing permutation below assumes both
    // rotations are exact axis swaps, which the former 1.5708 was not.
    constexpr double quarter_turn = 1.57079632679489661923;
    TransformType::OutputVectorType yAxis, zAxis;
    yAxis[0] = 0; yAxis[1] = 1; yAxis[2] = 0;
    zAxis[0] = 1; zAxis[1] = 0; zAxis[2] = 0;
    transform->Rotate3D(yAxis, quarter_turn, false);
    transform->Rotate3D(zAxis, quarter_turn, false);

    // Adjust our size and spacing for this new orientation.
    Size new_size;
    new_size[0] = size[1];
    new_size[1] = size[2];
    new_size[2] = size[0];
    Spacing new_spacing;
    new_spacing[0] = spacing[1];
    new_spacing[1] = spacing[2];
    new_spacing[2] = spacing[0];

    // Translate back so the rotation is about the original image centre.
    TransformType::OutputVectorType trans2;
    for (unsigned int d = 0; d < 3; ++d)
      trans2[d] = -1 * trans1[d];
    transform->Translate(trans2);
    resample_filter->SetTransform(transform);

    // The output grid is centred on that same centre point.
    Origin new_origin;
    for (unsigned int d = 0; d < 3; ++d)
      new_origin[d] = trans2[d] - (new_spacing[d] * new_size[d] * 0.5);

    resample_filter->SetSize(new_size);
    resample_filter->SetOutputOrigin(new_origin);
    resample_filter->SetOutputSpacing(new_spacing);
    resample_filter->SetOutputDirection(m_image->GetDirection());
    // Execute our operations
    resample_filter->Update();
    m_image = resample_filter->GetOutput();

    // Print out what we did
    std::cout << "  Origin  : " << m_image->GetOrigin() << "\n";
    std::cout << "  Size    : " << m_image->GetLargestPossibleRegion().GetSize() << "\n";
    std::cout << "  Spacing : " << m_image->GetSpacing() << "\n";
  }
  catch (itk::ExceptionObject& excep)
  {
    std::cerr << "Unable to rotate image : with exception :" << std::endl;
    std::cerr << excep << std::endl;
    return false;
  }
  return true;
}

