//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//
//  Copyright 2019 Sandia Corporation.
//  Copyright 2019 UT-Battelle, LLC.
//  Copyright 2019 Los Alamos National Security.
//
//  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
//  the U.S. Government retains certain rights in this software.
//
//  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
//  Laboratory (LANL), the U.S. Government retains certain rights in
//  this software.
//============================================================================

#include "vtkmLookupTable.h"

#include "vtkmlib/DataArrayConverters.h"

#include <vtkm/cont/ArrayHandleBasic.h>
#include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/cont/ColorTableMap.h>
#include <vtkm/cont/ErrorBadType.h>
#include <vtkm/cont/ErrorFilterExecution.h>
#include <vtkm/cont/Invoker.h>
#include <vtkm/worklet/WorkletMapField.h>
#include <vtkm/worklet/colorconversion/ConvertToRGBA.h>
#include <vtkm/worklet/colorconversion/ShiftScaleToRGB.h>
#include <vtkm/worklet/colorconversion/ShiftScaleToRGBA.h>

#include "vtkLogger.h"

#include <sstream>

//------------------------------------------------------------------------------------------------
namespace internal
{
VTK_ABI_NAMESPACE_BEGIN

// vtk-m copies of the lookup table's colors, kept in both RGB and RGBA form so that either
// output format can be produced directly. Both members are (re)filled by vtkmLookupTable::Build().
struct VtkmTables
{
  // 3-component samples, used when mapping to VTK_RGB output.
  vtkm::cont::ColorTableSamplesRGB SamplesRGB;
  // 4-component samples, used when mapping to VTK_RGBA output.
  vtkm::cont::ColorTableSamplesRGBA SamplesRGBA;
};

VTK_ABI_NAMESPACE_END
} // internal

//------------------------------------------------------------------------------------------------
namespace
{

// Writes the sampling parameters of a vtk-m color-table sample set, followed by a summary of
// the sample array itself, in the usual VTK PrintSelf format.
template <typename ColorTableSamples>
void PrintVtkmColorTableSamples(const ColorTableSamples& samples, ostream& os, vtkIndent indent)
{
  os << indent << "SampleRange: " << samples.SampleRange << "\n"
     << indent << "NumberOfSamples: " << samples.NumberOfSamples << "\n"
     << indent << "Samples: ";
  vtkm::cont::printSummary_ArrayHandle(samples.Samples, os);
}

} // namespace

//------------------------------------------------------------------------------------------------
VTK_ABI_NAMESPACE_BEGIN
// Standard VTK factory: provides vtkmLookupTable::New().
vtkStandardNewMacro(vtkmLookupTable);

// Private implementation state held through vtkmLookupTable::Detail (presumably kept out of
// the public header so vtk-m types do not leak into it).
struct vtkmLookupTable::InternalMembers
{
  // vtk-m RGB/RGBA samples mirroring the VTK color table; filled by Build().
  internal::VtkmTables VtkmColorTables;
};

vtkmLookupTable::vtkmLookupTable()
  : Detail(std::make_unique<InternalMembers>())
{
}

// Defaulted here rather than in the header so that destroying Detail (created with
// std::make_unique<InternalMembers>) happens where InternalMembers is a complete type.
vtkmLookupTable::~vtkmLookupTable() = default;

// Read-only access to the vtk-m RGB/RGBA samples built by Build().
const internal::VtkmTables& vtkmLookupTable::GetVtkmTables() const
{
  const InternalMembers& members = *this->Detail;
  return members.VtkmColorTables;
}

// Prints the superclass state followed by both vtk-m sample tables, one indent level deeper.
void vtkmLookupTable::PrintSelf(ostream& os, vtkIndent indent)
{
  this->Superclass::PrintSelf(os, indent);

  const internal::VtkmTables& tables = this->GetVtkmTables();
  const vtkIndent nested = indent.GetNextIndent();

  os << indent << "ColorTableSamplesRGB:\n";
  PrintVtkmColorTableSamples(tables.SamplesRGB, os, nested);
  os << indent << "ColorTableSamplesRGBA:\n";
  PrintVtkmColorTableSamples(tables.SamplesRGBA, os, nested);
}
VTK_ABI_NAMESPACE_END

//------------------------------------------------------------------------------------------------
namespace
{

// Copies a VTK lookup table (4 bytes per entry, `numColors` regular entries followed by the
// special colors) into a vtk-m sample array of N-channel colors laid out as:
//   [0]              below-range color
//   [1 .. numColors] the regular colors
//   [numColors + 1]  repeated last color
//   [numColors + 2]  above-range color
//   [numColors + 3]  NaN color
// Only the first N channels of each 4-byte VTK entry are copied.
template <vtkm::IdComponent N>
void CopyTable(
  const vtkm::UInt8* tableIn,
  vtkm::Id numColors,
  vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::UInt8, N>>& tableOut)
{
  // Reads VTK table entry `entry` as an N-channel color.
  auto colorAt = [tableIn](vtkm::Id entry) {
    vtkm::Vec<vtkm::UInt8, N> color;
    vtkm::make_VecC(tableIn + (entry * 4), N).CopyInto(color);
    return color;
  };

  tableOut.Allocate(numColors + 4);
  auto out = tableOut.WritePortal();

  out.Set(0, colorAt(numColors + vtkLookupTable::BELOW_RANGE_COLOR_INDEX));
  for (vtkm::Id idx = 0; idx < numColors; ++idx)
  {
    out.Set(idx + 1, colorAt(idx));
  }
  out.Set(numColors + 1, colorAt(numColors + vtkLookupTable::REPEATED_LAST_COLOR_INDEX));
  out.Set(numColors + 2, colorAt(numColors + vtkLookupTable::ABOVE_RANGE_COLOR_INDEX));
  out.Set(numColors + 3, colorAt(numColors + vtkLookupTable::NAN_COLOR_INDEX));
}

// Fills a vtk-m sample set (RGB or RGBA) from the raw VTK table: the sample range is the
// lookup table's range and the sample count is its number of regular colors.
template <typename ColorTableSamples>
void BuildColorTableSamples(
  ColorTableSamples& samples,
  const double tableRange[2],
  vtkm::Id numberOfColors,
  const vtkm::UInt8* table)
{
  samples.SampleRange = vtkm::Range{ tableRange[0], tableRange[1] };
  samples.NumberOfSamples = numberOfColors;
  CopyTable(table, numberOfColors, samples.Samples);
}

} // namespace

//------------------------------------------------------------------------------------------------
VTK_ABI_NAMESPACE_BEGIN
void vtkmLookupTable::Build()
{
  this->Superclass::Build();

  const auto* table = this->GetTable()->GetPointer(0);
  BuildColorTableSamples(this->Detail->VtkmColorTables.SamplesRGB, this->GetTableRange(), this->NumberOfColors, table);
  BuildColorTableSamples(this->Detail->VtkmColorTables.SamplesRGBA, this->GetTableRange(), this->NumberOfColors, table);
}
VTK_ABI_NAMESPACE_END

//------------------------------------------------------------------------------------------------
namespace
{

// Replicates VTK's vtkApplyLogScale so that values can be moved to log space on the device
// before being looked up in color samples whose range has been set to LogRange.
//   Range    : the linear table range of the lookup table.
//   LogRange : the log-space range computed from it by vtkLookupTable::GetLogRange.
struct LogScaleTransform
{
  vtkm::Range Range;
  vtkm::Range LogRange;

  // log10(value) for a non-negative range, -log10(-value) for a negative range. Values on the
  // "wrong" side of zero for the range are pinned to one end of LogRange, chosen by whether
  // the range is increasing or decreasing.
  VTKM_EXEC vtkm::Float64 ApplyLogScale(vtkm::Float64 value) const
  {
    const bool negativeRange = this->Range.Min < 0;
    if (negativeRange && value < 0)
    {
      return -vtkm::Log10(-value);
    }
    if (!negativeRange && value > 0)
    {
      return vtkm::Log10(value);
    }
    const bool useMin =
      negativeRange ? (this->Range.Min > this->Range.Max) : (this->Range.Min <= this->Range.Max);
    return useMin ? this->LogRange.Min : this->LogRange.Max;
  }

  // Other scalar types are widened to Float64 before the transform.
  template <typename T>
  VTKM_EXEC vtkm::Float64 operator()(const T& value) const
  {
    return this->ApplyLogScale(static_cast<vtkm::Float64>(value));
  }

  // Floating-point NaN is passed through untouched so the lookup can still recognize it as NaN.
  VTKM_EXEC vtkm::Float64 operator()(const vtkm::Float32& value) const
  {
    return vtkm::IsNan(value) ? static_cast<vtkm::Float64>(value) : this->ApplyLogScale(value);
  }

  VTKM_EXEC vtkm::Float64 operator()(const vtkm::Float64& value) const
  {
    return vtkm::IsNan(value) ? value : this->ApplyLogScale(value);
  }
};

// ArrayHandleTransform functor reducing each vector value to its magnitude (vtkm::Magnitude),
// used for the MAGNITUDE vector mode.
struct MagnitudeTransform
{
  template <typename VecType>
  VTKM_EXEC auto operator()(const VecType& vec) const
  {
    const auto magnitude = vtkm::Magnitude(vec);
    return magnitude;
  }
};

// TODO: Remove and use ArrayHandleRuntimeVec::AsArrayHandleBasic when it's available in VTK.
// Wraps caller-owned memory as a basic array handle of `size` values of T without copying
// (CopyFlag::Off); `data` must therefore stay valid for as long as the handle is used.
template <typename T>
vtkm::cont::ArrayHandleBasic<T> AsArrayHandleBasic(void* data, vtkm::Id size)
{
  T* typedData = static_cast<T*>(data);
  return vtkm::cont::make_ArrayHandle(typedData, size, vtkm::CopyFlag::Off);
}

// Maps `input` through the precomputed vtk-m color samples into the caller-owned `output`
// buffer, which must hold input.GetNumberOfValues() tuples of 3 (VTK_RGB) or 4 (VTK_RGBA)
// bytes according to `outputFormat`.
//
// Throws vtkm::cont::ErrorFilterExecution when vtkm::cont::ColorTableMap reports failure (it
// returns false, e.g. when the samples are empty because Build() has not populated them),
// instead of silently leaving `output` unwritten. Callers catch vtkm::cont::Error and fall
// back to vtkLookupTable.
template <typename T, typename S>
void MapScalarsLookupTableImpl(const internal::VtkmTables& vtkmTables, const vtkm::cont::ArrayHandle<T, S>& input, vtkm::UInt8* output, int outputFormat)
{
  switch (outputFormat)
  {
  case VTK_RGB:
    {
      auto op = AsArrayHandleBasic<vtkm::Vec<vtkm::UInt8, 3>>(output, input.GetNumberOfValues());
      if (!vtkm::cont::ColorTableMap(input, vtkmTables.SamplesRGB, op))
      {
        throw vtkm::cont::ErrorFilterExecution(
          "vtkm::cont::ColorTableMap failed (empty RGB color table samples?)");
      }
      op.SyncControlArray();
      break;
    }
  case VTK_RGBA:
    {
      auto op = AsArrayHandleBasic<vtkm::Vec<vtkm::UInt8, 4>>(output, input.GetNumberOfValues());
      if (!vtkm::cont::ColorTableMap(input, vtkmTables.SamplesRGBA, op))
      {
        throw vtkm::cont::ErrorFilterExecution(
          "vtkm::cont::ColorTableMap failed (empty RGBA color table samples?)");
      }
      op.SyncControlArray();
      break;
    }
  default: // unsupported format, should be checked for and reported earlier
    break;
  }
}

// Maps scalar values through the lookup table into `output` (RGB or RGBA bytes). Honors
// VTK_SCALE_LOG10 by log-transforming the values and re-ranging a copy of the samples, and
// applies the table-wide Alpha to RGBA output when it is below 1.
template <typename T, typename S>
void MapScalarsLookupTable(vtkmLookupTable* self, const vtkm::cont::ArrayHandle<T, S>& input, vtkm::UInt8* output, int outputFormat)
{
  if (self->GetScale() != VTK_SCALE_LOG10)
  {
    MapScalarsLookupTableImpl(self->GetVtkmTables(), input, output, outputFormat);
  }
  else
  {
    const double* tableRange = self->GetTableRange();
    double logRange[2];
    self->GetLogRange(tableRange, logRange);

    // Copy the tables so only the sample ranges change; ArrayHandle copies are shallow, so the
    // sample data itself is shared rather than duplicated.
    const vtkm::Range logSampleRange(logRange[0], logRange[1]);
    internal::VtkmTables logTables = self->GetVtkmTables();
    logTables.SamplesRGB.SampleRange = logSampleRange;
    logTables.SamplesRGBA.SampleRange = logSampleRange;

    const LogScaleTransform toLog{ vtkm::Range(tableRange[0], tableRange[1]), logSampleRange };
    MapScalarsLookupTableImpl(
      logTables, vtkm::cont::make_ArrayHandleTransform(input, toLog), output, outputFormat);
  }

  // Blend the table-wide alpha into the RGBA result in place.
  const auto alpha = vtkm::Clamp(self->GetAlpha(), 0.0, 1.0);
  const bool needsBlending = (alpha < 1.0) && (outputFormat == VTK_RGBA);
  if (needsBlending)
  {
    auto rgba = AsArrayHandleBasic<vtkm::Vec<vtkm::UInt8, 4>>(output, input.GetNumberOfValues());
    vtkm::cont::Invoker invoke(vtkm::cont::DeviceAdapterTagAny{});
    invoke(vtkm::worklet::colorconversion::ConvertToRGBA(alpha), rgba, rgba);
    rgba.SyncControlArray();
  }
}

// Converts values that already represent colors (RGBCOLORS vector mode) into unsigned char
// RGB/RGBA output by shifting by -range[0] and scaling by 255 / (range[1] - range[0]).
template <typename T, typename S>
void MapScalarsColorsToColors(vtkmLookupTable* self, const vtkm::cont::ArrayHandle<T, S>& input, vtkm::UInt8* output, int outputFormat)
{
  const auto* range = self->GetRange();
  const auto shift = static_cast<vtkm::Float32>(-range[0]);
  auto scale = static_cast<vtkm::Float32>(range[1] - range[0]);
  // A (near-)zero range width would blow up the reciprocal; substitute a huge signed factor.
  if (!(scale * scale > 1e-30))
  {
    scale = (scale < 0.0) ? -1e17 : 1e17;
  }
  else
  {
    scale = 1.0 / scale;
  }
  scale *= 255;

  vtkm::cont::Invoker invoke(vtkm::cont::DeviceAdapterTagAny{});
  auto alpha = vtkm::Clamp(self->GetAlpha(), 0.0, 1.0);

  if (outputFormat == VTK_RGB)
  {
    auto rgb = AsArrayHandleBasic<vtkm::Vec<vtkm::UInt8, 3>>(output, input.GetNumberOfValues());
    invoke(vtkm::worklet::colorconversion::ShiftScaleToRGB(shift, scale), input, rgb);
    rgb.SyncControlArray();
  }
  else if (outputFormat == VTK_RGBA)
  {
    auto rgba = AsArrayHandleBasic<vtkm::Vec<vtkm::UInt8, 4>>(output, input.GetNumberOfValues());
    invoke(vtkm::worklet::colorconversion::ShiftScaleToRGBA(shift, scale, alpha), input, rgba);
    rgba.SyncControlArray();
  }
  // Other formats are unsupported and expected to have been rejected earlier.
}

// ArrayHandleTransform functor that copies a runtime-sized RecombineVec into a fixed-size
// vtkm::Vec<T, N>, for worklets that need a compile-time vector type.
template <typename T, vtkm::IdComponent N>
struct TransformToVec
{
  template <typename PortalType>
  VTKM_EXEC_CONT vtkm::Vec<T, N> operator()(
    const vtkm::internal::RecombineVec<PortalType>& recombined) const
  {
    vtkm::Vec<T, N> result;
    recombined.CopyInto(result);
    return result;
  }
};

// RecombineVec overload: ShiftScaleToRGB and ShiftScaleToRGBA do not work with `RecombineVec`
// values, so the input is re-presented as a plain component array (1 component) or as a
// transform producing fixed-size Vecs (2-4 components) and forwarded to the generic overload.
// Inputs with more than 4 components are not expected (asserted) and are otherwise ignored.
template <typename T>
void MapScalarsColorsToColors(vtkmLookupTable* self, const vtkm::cont::ArrayHandleRecombineVec<T>& input, vtkm::UInt8* output, int outputFormat)
{
  const vtkm::IdComponent numComps = input.GetNumberOfComponents();
  assert(numComps <= 4);

  if (numComps == 1)
  {
    MapScalarsColorsToColors(self, input.GetComponentArray(0), output, outputFormat);
  }
  else if (numComps == 2)
  {
    MapScalarsColorsToColors(
      self, vtkm::cont::make_ArrayHandleTransform(input, TransformToVec<T, 2>{}), output, outputFormat);
  }
  else if (numComps == 3)
  {
    MapScalarsColorsToColors(
      self, vtkm::cont::make_ArrayHandleTransform(input, TransformToVec<T, 3>{}), output, outputFormat);
  }
  else if (numComps == 4)
  {
    MapScalarsColorsToColors(
      self, vtkm::cont::make_ArrayHandleTransform(input, TransformToVec<T, 4>{}), output, outputFormat);
  }
}

// Calls `f(array, args...)` with component `component` of `input` as a concrete vtk-m array.
// A single-component input is cast directly; otherwise the component is extracted (copying if
// required) using whichever VTK scalar type matches the input's base component type.
// Throws vtkm::cont::ErrorBadType if the base component type is not in tovtkm::VTKScalarTypes.
template <typename Functor, typename... Args>
void ExtractComponentAndCall(const vtkm::cont::UnknownArrayHandle& input,
                             vtkm::IdComponent component,
                             Functor&& f,
                             Args&&... args)
{
  const bool wholeArray = (component == 0) && (input.GetNumberOfComponents() == 1);
  if (wholeArray)
  {
    input.CastAndCallForTypes<tovtkm::VTKScalarTypes, VTKM_DEFAULT_STORAGE_LIST>(f, std::forward<Args>(args)...);
    return;
  }

  bool handled = false;
  vtkm::ListForEach(
    [&](auto tag) {
      using ComponentType = decltype(tag);
      if (handled || !input.IsBaseComponentType<ComponentType>())
      {
        return;
      }
      f(input.ExtractComponent<ComponentType>(component, vtkm::CopyFlag::On), std::forward<Args>(args)...);
      handled = true;
    },
    tovtkm::VTKScalarTypes{});

  if (handled)
  {
    return;
  }

  std::ostringstream msg;
  msg << "BaseComponentType of input not found in type list";
  input.PrintSummary(msg);
  msg << "TypeList: " << vtkm::cont::TypeToString(typeid(tovtkm::VTKScalarTypes)) << "\n";
  throw vtkm::cont::ErrorBadType(msg.str());
}

// Calls `f(array, args...)` with components [componentStart, componentStart + numComponents)
// of `input`. When that span is the whole input it is cast directly to a default vtk-m type;
// otherwise the selected components are extracted and recombined into an
// ArrayHandleRecombineVec of the matching VTK scalar type.
// Throws vtkm::cont::ErrorBadType if the base component type is not in tovtkm::VTKScalarTypes.
template <typename Functor, typename... Args>
void ExtractArrayFromComponentsAndCall(const vtkm::cont::UnknownArrayHandle& input,
                                       vtkm::IdComponent componentStart,
                                       vtkm::IdComponent numComponents,
                                       Functor&& f,
                                       Args&&... args)
{
  const bool usesAllComponents =
    (componentStart == 0) && (numComponents == input.GetNumberOfComponents());
  if (usesAllComponents)
  {
    input.CastAndCallForTypes<VTKM_DEFAULT_TYPE_LIST, VTKM_DEFAULT_STORAGE_LIST>(f, std::forward<Args>(args)...);
    return;
  }

  bool handled = false;
  vtkm::ListForEach(
    [&](auto tag) {
      using ComponentType = decltype(tag);
      if (handled || !input.IsBaseComponentType<ComponentType>())
      {
        return;
      }
      vtkm::cont::ArrayHandleRecombineVec<ComponentType> recombined;
      for (vtkm::IdComponent offset = 0; offset < numComponents; ++offset)
      {
        recombined.AppendComponentArray(
          input.ExtractComponent<ComponentType>(componentStart + offset, vtkm::CopyFlag::On));
      }
      f(recombined, std::forward<Args>(args)...);
      handled = true;
    },
    tovtkm::VTKScalarTypes{});

  if (!handled)
  {
    std::ostringstream msg;
    msg << "BaseComponentType of input not found in type list";
    input.PrintSummary(msg);
    msg << "TypeList: " << vtkm::cont::TypeToString(typeid(tovtkm::VTKScalarTypes)) << "\n";
    throw vtkm::cont::ErrorBadType(msg.str());
  }
}

} // namespace

//------------------------------------------------------------------------------------------------
VTK_ABI_NAMESPACE_BEGIN
// Maps `scalars` to a newly allocated unsigned char color array (the caller owns the returned
// reference). The vtk-m path handles:
//   - direct colors (converted to RGBA with the table's Alpha applied), and
//   - lookup-table mapping in component, magnitude and RGB-colors vector modes.
// Everything else is delegated to vtkLookupTable::MapScalars. That includes non-vtkDataArray
// input, pass-through RGBA unsigned char input, IndexedLookup, luminance output formats, and
// any vtk-m error raised while mapping.
vtkUnsignedCharArray* vtkmLookupTable::MapScalars(
  vtkAbstractArray* scalars, int colorMode, int component, int outputFormat)
{
  auto* dataArray = vtkArrayDownCast<vtkDataArray>(scalars);
  if (!dataArray)
  {
    vtkLogF(INFO, "scalars is not type vtkDataArray");
    return this->Superclass::MapScalars(scalars, colorMode, component, outputFormat);
  }

  // The following case results in the input array itself being returned as the output.
  // Let the Superclass handle it.
  if ((colorMode == VTK_COLOR_MODE_DEFAULT || colorMode == VTK_COLOR_MODE_DIRECT_SCALARS) &&
      vtkArrayDownCast<vtkUnsignedCharArray>(dataArray) != nullptr &&
      dataArray->GetNumberOfComponents() == 4 &&
      this->Alpha >= 1.0)
  {
    return this->Superclass::MapScalars(scalars, colorMode, component, outputFormat);
  }

  if (dataArray->GetNumberOfTuples() == 0)
  {
    return vtkUnsignedCharArray::New();
  }

  auto* colorArray = vtkUnsignedCharArray::New();
  try
  {
    // we don't currently support IndexedLookup
    if (this->IndexedLookup)
    {
      throw vtkm::cont::ErrorFilterExecution("`IndexedLookup` is not supported");
    }

    // these output formats are currently unsupported
    if (outputFormat == VTK_LUMINANCE || outputFormat == VTK_LUMINANCE_ALPHA)
    {
      throw vtkm::cont::ErrorFilterExecution(
        "only VTK_RGB and VTK_RGBA are supported for outputFormat");
    }

    auto input = tovtkm::Convert(dataArray);

    if ((colorMode == VTK_COLOR_MODE_DEFAULT && input.IsBaseComponentType<vtkm::UInt8>()) ||
        colorMode == VTK_COLOR_MODE_DIRECT_SCALARS)
    {
      // Treat the array values as colors: they are not mapped through the table, only
      // converted to RGBA with the table's Alpha applied.
      colorArray->SetNumberOfComponents(4);
      colorArray->SetNumberOfTuples(dataArray->GetNumberOfTuples());
      auto output = AsArrayHandleBasic<vtkm::Vec<vtkm::UInt8, 4>>(colorArray->GetPointer(0), colorArray->GetNumberOfTuples());

      input.CastAndCallForTypesWithFloatFallback<VTKM_DEFAULT_TYPE_LIST, VTKM_DEFAULT_STORAGE_LIST>(
        [&](const auto& array){
          vtkm::cont::Invoker invoke(vtkm::cont::DeviceAdapterTagAny{});
          auto worklet = vtkm::worklet::colorconversion::ConvertToRGBA(vtkm::Clamp(this->Alpha, 0.0, 1.0));
          invoke(worklet, array, output);});
      output.SyncControlArray();
    }
    else
    {
      colorArray->SetNumberOfComponents(outputFormat);
      colorArray->SetNumberOfTuples(dataArray->GetNumberOfTuples());

      auto mapScalarsLookupTableFunctor = [&](const auto& array) {
        MapScalarsLookupTable(this, array, colorArray->GetPointer(0), outputFormat); };

      if (component < 0 && input.GetNumberOfComponents() > 1) // vector mode
      {
        int vectorComponent = std::max(0, std::min(this->GetVectorComponent(), input.GetNumberOfComponents() - 1));

        auto vectorMode = this->GetVectorMode();
        if (vectorMode == vtkScalarsToColors::COMPONENT)
        {
          ExtractComponentAndCall(input, vectorComponent, mapScalarsLookupTableFunctor);
        }
        else
        {
          int vectorSize = this->GetVectorSize();
          if (vectorSize == -1) // unspecified: use the whole vector
          {
            vectorComponent = 0;
            vectorSize = input.GetNumberOfComponents();
          }
          vectorSize = std::max(1, std::min(vectorSize, input.GetNumberOfComponents() - vectorComponent));

          if (vectorMode == vtkScalarsToColors::MAGNITUDE)
          {
            if (vectorSize == 1)
            {
              ExtractComponentAndCall(input, vectorComponent, mapScalarsLookupTableFunctor);
            }
            else
            {
              ExtractArrayFromComponentsAndCall(
                input, vectorComponent, vectorSize, [&](const auto& array) {
                  auto magArray = vtkm::cont::make_ArrayHandleTransform(array, MagnitudeTransform{});
                  MapScalarsLookupTable(this, magArray, colorArray->GetPointer(0), outputFormat);});
            }
          }
          else if (vectorMode == vtkScalarsToColors::RGBCOLORS)
          {
            vectorSize = std::min(4, vectorSize);
            ExtractArrayFromComponentsAndCall(
              input, vectorComponent, vectorSize, [&](const auto& array) {
                MapScalarsColorsToColors(this, array, colorArray->GetPointer(0), outputFormat);});
          }
        }
      }
      else
      {
        // Clamp to a valid component index in [0, numberOfComponents - 1], as the vector-mode
        // branch does for VectorComponent; clamping to numberOfComponents itself would
        // request a component that does not exist.
        component = std::max(0, std::min(component, input.GetNumberOfComponents() - 1));
        ExtractComponentAndCall(input, component, mapScalarsLookupTableFunctor);
      }
    }
  }
  catch (const vtkm::cont::Error& e)
  {
    vtkWarningMacro(<< "vtkmLookupTable encountered an error: " << e.GetMessage() << "\n"
                    << "Falling back to vtkLookupTable.");
    // The superclass allocates its own result; release ours so it does not leak.
    colorArray->Delete();
    return this->Superclass::MapScalars(scalars, colorMode, component, outputFormat);
  }

  return colorArray;
}

void vtkmLookupTable::MapScalarsThroughTable2(void* input, unsigned char* output, int inputDataType,
    int numberOfValues, int inputIncrement, int outputFormat)
{
  vtkLogF(INFO, "vtkmLookupTable::MapScalarsThroughTable2\n");

  // unsupported cases
  if ((this->IndexedLookup) || (outputFormat == VTK_LUMINANCE || outputFormat == VTK_LUMINANCE_ALPHA))
  {
    vtkLogF(INFO, "unsupported case, fallback to vtkLookupTable::MapScalarsThroughTable2\n");
    this->Superclass::MapScalarsThroughTable2(input, output, inputDataType, numberOfValues, inputIncrement, outputFormat);
  }

  switch (inputDataType)
  {
    vtkTemplateMacro(
      auto basicArray = vtkm::cont::make_ArrayHandle(static_cast<VTK_TT*>(input), numberOfValues * inputIncrement, vtkm::CopyFlag::Off);
      vtkm::cont::ArrayHandleStride<VTK_TT> strideArray(basicArray, numberOfValues, inputIncrement, 0);
      MapScalarsLookupTable(this, strideArray, output, outputFormat);
      basicArray.SyncControlArray();
    );
  }
}
VTK_ABI_NAMESPACE_END
