/*=========================================================================

  Program:   Visualization Toolkit

  Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
  All rights reserved.
  See Copyright.txt or http://www.kitware.com/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notice for more information.

=========================================================================*/
#include "vtkDirect3DShaderProgram.h"
#include "vtkObjectFactory.h"

#include "vtkDirect3DConstantBufferObject.h"
#include "vtkDirect3DShader.h"
#include "vtkMatrix3x3.h"
#include "vtkMatrix4x4.h"
#include "vtkDirect3DRenderWindow.h"
#include "vtkDirect3DShaderCache.h"
#include "vtkTypeTraits.h"

#include <d3dcompiler.h>

# include <sstream>

#pragma comment (lib, "D3DCompiler.lib")


// Iterator type for the C-string-keyed UniformLocs / AttributeLocs maps.
using IterT = std::map<const char *, int, vtkDirect3DShaderProgram::cmp_str>::iterator;

// Standard VTK New() implementation for this class.
vtkStandardNewMacro(vtkDirect3DShaderProgram)

// Generated setters for the vertex, fragment and geometry shader objects.
vtkCxxSetObjectMacro(vtkDirect3DShaderProgram,VertexShader,vtkDirect3DShader)
vtkCxxSetObjectMacro(vtkDirect3DShaderProgram,FragmentShader,vtkDirect3DShader)
vtkCxxSetObjectMacro(vtkDirect3DShaderProgram,GeometryShader,vtkDirect3DShader)

// Construct a program with one shader object per supported stage and no
// Direct3D objects; those are created later by CompileShader().
vtkDirect3DShaderProgram::vtkDirect3DShaderProgram()
{
  // Nothing has been compiled, linked or bound yet.
  this->Compiled = false;
  this->Linked = false;
  this->Bound = false;
  this->NumberOfOutputs = 0;

  // Direct3D shader objects and bytecode blobs start out empty.
  this->VertexShaderHandle = nullptr;
  this->VertexShaderBlob = nullptr;
  this->FragmentShaderHandle = nullptr;
  this->FragmentShaderBlob = nullptr;
  // No geometry shader handle is managed here (GeometryShaderHandle unused).

  // One shader object per pipeline stage, each tagged with its stage type.
  this->VertexShader = vtkDirect3DShader::New();
  this->VertexShader->SetType(vtkDirect3DShader::Vertex);
  this->FragmentShader = vtkDirect3DShader::New();
  this->FragmentShader->SetType(vtkDirect3DShader::Fragment);
  this->GeometryShader = vtkDirect3DShader::New();
  this->GeometryShader->SetType(vtkDirect3DShader::Geometry);
}

// Destroy the program: free the cached name maps, delete the per-stage
// shader objects and release every Direct3D object still owned here.
vtkDirect3DShaderProgram::~vtkDirect3DShaderProgram()
{
  this->ClearMaps();
  if (this->VertexShader)
  {
    this->VertexShader->Delete();
    this->VertexShader = nullptr;
  }
  if (this->FragmentShader)
  {
    this->FragmentShader->Delete();
    this->FragmentShader = nullptr;
  }
  if (this->GeometryShader)
  {
    this->GeometryShader->Delete();
    this->GeometryShader = nullptr;
  }

  // The shader objects must be released here as well as the blobs: if
  // ReleaseGraphicsResources() was never called they would otherwise leak.
  if (this->VertexShaderHandle)
  {
    this->VertexShaderHandle->Release();
    this->VertexShaderHandle = nullptr;
  }
  if (this->FragmentShaderHandle)
  {
    this->FragmentShaderHandle->Release();
    this->FragmentShaderHandle = nullptr;
  }
  if (this->VertexShaderBlob)
  {
    this->VertexShaderBlob->Release();
    this->VertexShaderBlob = nullptr;
  }
  if (this->FragmentShaderBlob)
  {
    this->FragmentShaderBlob->Release();
    this->FragmentShaderBlob = nullptr;
  }
}

// Replace occurrences of `search` in `source`, in place, with `replace`.
// If `all` is false only the first occurrence is replaced; otherwise every
// non-overlapping occurrence is, and scanning resumes after each inserted
// replacement so text introduced by `replace` is never re-matched.
// Returns true if at least one replacement was made.
// An empty `search` is treated as "nothing to find" and returns false: an
// empty pattern matches at every position, so with `all` set the loop below
// would otherwise never terminate.
bool vtkDirect3DShaderProgram::Substitute(std::string &source, const std::string &search,
             const std::string &replace, bool all)
{
  if (search.empty())
  {
    return false;
  }

  std::string::size_type pos = 0;
  bool replaced = false;
  while ((pos = source.find(search, pos)) != std::string::npos)
  {
    source.replace(pos, search.length(), replace);
    replaced = true;
    if (!all)
    {
      break;
    }
    pos += replace.length();
  }
  return replaced;
}

bool vtkDirect3DShaderProgram::AttachShader(const vtkDirect3DShader *shader)
{
  if (shader->GetType() == vtkDirect3DShader::Unknown)
  {
    this->Error = "Shader object is of type Unknown and cannot be used.";
    return false;
  }

  this->Linked = false;
  return true;
}

// Detaching is not implemented for this Direct3D program: every request is
// rejected and the shader argument is ignored.
bool vtkDirect3DShaderProgram::DetachShader(const vtkDirect3DShader * /*shader*/)
{
  const bool detached = false;
  return detached;
}

void vtkDirect3DShaderProgram::ClearMaps()
{
  for (IterT i = this->UniformLocs.begin(); i != this->UniformLocs.end(); i++)
  {
    free(const_cast<char *>(i->first));
  }
  this->UniformLocs.clear();
  for (IterT i = this->AttributeLocs.begin(); i != this->AttributeLocs.end(); i++)
  {
    free(const_cast<char *>(i->first));
  }
  this->AttributeLocs.clear();
}

// Mark the program as linked. The only work done here is dropping the cached
// uniform/attribute name maps; calling it while already linked does nothing.
// Always returns true.
bool vtkDirect3DShaderProgram::Link()
{
  if (!this->Linked)
  {
    // Forget names recorded against any previous link.
    this->ClearMaps();
    this->Linked = true;
  }
  return true;
}

bool vtkDirect3DShaderProgram::Bind()
{
  if (!this->Linked && !this->Link())
  {
    return false;
  }

  this->Context->GetImmediateContext()->VSSetShader( this->VertexShaderHandle, nullptr, 0 );
  this->Context->GetImmediateContext()->PSSetShader( this->FragmentShaderHandle, nullptr, 0 );
  this->Bound = true;
  return true;
}

void vtkDirect3DShaderProgram::BindConstantBuffers()
{
  this->VertexShader->BindConstantBuffer();
  this->FragmentShader->BindConstantBuffer();
}

// return 0 if there is an issue
int vtkDirect3DShaderProgram::CompileShader()
{
 HRESULT hr = S_OK;

  DWORD dwShaderFlags = D3DCOMPILE_ENABLE_STRICTNESS;
#ifdef _DEBUG
  // Set the D3DCOMPILE_DEBUG flag to embed debug information in the shaders.
  // Setting this flag improves the shader debugging experience, but still allows
  // the shaders to be optimized and to run exactly the way they will run in
  // the release configuration of this program.
  dwShaderFlags |= D3DCOMPILE_DEBUG;

  // Disable optimizations to further improve shader debugging
  dwShaderFlags |= D3DCOMPILE_SKIP_OPTIMIZATION;
#endif

  this->VertexShader->SetContext(this->Context);
  ID3DBlob* pErrorBlob = nullptr;
  hr = D3DCompile(this->VertexShader->GetSource().c_str(),
         this->VertexShader->GetSource().size(), NULL, NULL, NULL,
         "VS",
         "vs_4_0",
         dwShaderFlags, 0, &this->VertexShaderBlob, &pErrorBlob );

  if( FAILED(hr) )
  {
    if( pErrorBlob )
    {
      OutputDebugStringA( reinterpret_cast<const char*>( pErrorBlob->GetBufferPointer() ) );
      pErrorBlob->Release();
    }
    return 0;
  }

  // Create the vertex shader
  hr = this->Context->GetD3DDevice()->CreateVertexShader(
    this->VertexShaderBlob->GetBufferPointer(),
    this->VertexShaderBlob->GetBufferSize(),
    nullptr, &this->VertexShaderHandle );
  if( FAILED( hr ) )
  {
    this->VertexShaderBlob->Release();
    return 0;
  }

  this->FragmentShader->SetContext(this->Context);
  hr = D3DCompile(this->FragmentShader->GetSource().c_str(),
         this->FragmentShader->GetSource().size(), NULL, NULL, NULL,
         "FS",
         "ps_4_0",
         dwShaderFlags, 0, &this->FragmentShaderBlob, &pErrorBlob );

  if( FAILED(hr) )
  {
    if( pErrorBlob )
    {
      OutputDebugStringA( reinterpret_cast<const char*>( pErrorBlob->GetBufferPointer() ) );
      pErrorBlob->Release();
    }

    return 0;
  }

  // Create the fragment shader
  hr = this->Context->GetD3DDevice()->CreatePixelShader(
    this->FragmentShaderBlob->GetBufferPointer(),
    this->FragmentShaderBlob->GetBufferSize(),
    nullptr, &this->FragmentShaderHandle );
  if( FAILED( hr ) )
  {
    this->FragmentShaderBlob->Release();
    return 0;
  }

  if (!this->AttachShader(this->GetVertexShader()))
  {
    vtkErrorMacro(<< this->GetError());
    return 0;
  }
  if (!this->AttachShader(this->GetFragmentShader()))
  {
    vtkErrorMacro(<< this->GetError());
    return 0;
  }

  if (!this->Link())
  {
    vtkErrorMacro(<< "Links failed: " << this->GetError());
    return 0;
  }

  this->Compiled = true;
  return 1;
}

// Mark the program as no longer bound. This only clears the Bound flag; it
// does not unset the shaders on the Direct3D device context.
void vtkDirect3DShaderProgram::Release()
{
  this->Bound = false;
}

// Release every Direct3D object owned by this program, reset the compiled and
// linked state, and clear the window's shader cache if this program is
// recorded there as the last one bound.
void vtkDirect3DShaderProgram::ReleaseGraphicsResources(vtkWindow *win)
{
  this->Release();

  // Release whatever exists even if Compiled is false: a CompileShader() that
  // fails part-way can leave some objects allocated, and gating on Compiled
  // would leak them. Each pointer is null-checked and reset individually.
  if (this->VertexShaderHandle)
  {
    this->VertexShaderHandle->Release();
    this->VertexShaderHandle = nullptr;
  }
  if (this->FragmentShaderHandle)
  {
    this->FragmentShaderHandle->Release();
    this->FragmentShaderHandle = nullptr;
  }
  if (this->VertexShaderBlob)
  {
    this->VertexShaderBlob->Release();
    this->VertexShaderBlob = nullptr;
  }
  if (this->FragmentShaderBlob)
  {
    this->FragmentShaderBlob->Release();
    this->FragmentShaderBlob = nullptr;
  }
  this->Compiled = false;

  vtkDirect3DRenderWindow *renWin = vtkDirect3DRenderWindow::SafeDownCast(win);
  if (renWin && renWin->GetShaderCache()->GetLastShaderBound() == this)
  {
    renWin->GetShaderCache()->ClearLastShaderBound();
  }

  this->Linked = false;
}

// ----------------------------------------------------------------------------
void vtkDirect3DShaderProgram::PrintSelf(ostream& os, vtkIndent indent)
{
  // This class prints nothing of its own; delegate to the superclass.
  Superclass::PrintSelf(os, indent);
}
