// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause

#ifndef vtkWebGPUComputePipeline_h
#define vtkWebGPUComputePipeline_h

#include "vtkObject.h"
#include "vtkWGPUContext.h"                    // for requesting device / adapter
#include "vtkWebGPUComputeBuffer.h"            // for compute buffers used by the pipeline
#include "vtkWebGPUComputeRenderBuffer.h"      // for compute render buffers used by the pipeline
#include "vtkWebGPUInternalsBindGroup.h"       // for bind group utilitary methods
#include "vtkWebGPUInternalsBindGroupLayout.h" // for bind group layouts utilitary methods
#include "vtkWebGPUInternalsBuffer.h"          // for internal buffer utils
#include "vtkWebGPUInternalsShaderModule.h"    // for wgpu::ShaderModule

#include <functional> // for std::function (MapAsyncCallback)
#include <list>
#include <string>        // for std::string members and parameters
#include <unordered_map> // for the bind group (layout) entries maps
#include <vector>        // for std::vector members and parameters

VTK_ABI_NAMESPACE_BEGIN

class vtkWebGPURenderWindow;
class vtkWebGPURenderer;

/**
 * This class is an abstraction for offloading computation from the CPU onto the GPU using WebGPU
 * compute shaders.
 *
 * The basic usage of a pipeline outside a rendering pipeline is:
 *  - Create a pipeline
 *  - Set its shader source code
 *  - Set its shader entry point
 *  - Create the vtkWebGPUComputeBuffers that contain the data manipulated by the compute shader
 *  - Add the buffers to the pipeline
 *  - Set the number of workgroups
 *  - Dispatch the compute shader
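 *  - Update()
 *  - ReadBufferFromGPU() to make results from the GPU available to the CPU. The read-back is only
 *    executed by a call to Update() made after it: either call Update() again or make the single
 *    Update() call after ReadBufferFromGPU()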
 *
 * Integrated into a rendering pipeline, the only difference in the usage of the class is going to
 * be the creation of the buffers. You will not create the vtkWebGPUComputeBuffer yourself but
 * rather acquire one (or many) by calling AcquirePointAttributeComputeRenderBuffer() on a
 * vtkWebGPURenderer. The returned buffers can then be added to the pipeline with AddRenderBuffer().
 * Other steps are identical.
 */
class VTKRENDERINGWEBGPU_EXPORT vtkWebGPUComputePipeline : public vtkObject
{
public:
  vtkTypeMacro(vtkWebGPUComputePipeline, vtkObject);
  static vtkWebGPUComputePipeline* New();
  void PrintSelf(ostream& os, vtkIndent indent) override;

  /**
   * Callback called when the asynchronous mapping of a buffer is done
   * and data ready to be copied.
   * This callback takes two parameters:
   * - A first void pointer to the data mapped from the GPU ready to be copied
   * - A second void pointer pointing to user data, which can essentially be anything
   *      needed by the callback to copy the data to the CPU
   */
  using MapAsyncCallback = std::function<void(const void*, void*)>;

  ///@{
  /**
   * Set/get the WGSL source of the shader
   */
  vtkGetMacro(ShaderSource, std::string);
  vtkSetMacro(ShaderSource, std::string);
  ///@}

  /**
   * Sets the shader source from a file path.
   *
   * NOTE(review): the implementation is not visible in this header. Presumably the content of the
   * file at shaderFilePath is read and used as the WGSL shader source -- confirm the behavior for
   * an unreadable path in the implementation file.
   */
  void SetShaderSourceFromPath(const char* shaderFilePath);

  ///@{
  /**
   * Set/get the entry (name of the function) of the WGSL compute shader
   */
  vtkGetMacro(ShaderEntryPoint, std::string);
  vtkSetMacro(ShaderEntryPoint, std::string);
  ///@}

  ///@{
  /**
   * Set/get the label of the compute pipeline.
   * This label will be printed along with error/warning logs
   * to help with debugging
   */
  vtkGetMacro(Label, std::string);
  void SetLabel(const std::string& label);
  ///@}

  /**
   * Adds a buffer to the pipeline and uploads its data to the device.
   *
   * Returns the index of the buffer that can for example be used as input to the
   * ReadBufferFromGPU function
   */
  int AddBuffer(vtkSmartPointer<vtkWebGPUComputeBuffer> buffer);

  /**
   * Adds a render buffer to the pipeline. A render buffer can be obtained from
   * vtkWebGPUPolyDataMapper::AcquirePointXXXXRenderBuffer() or
   * vtkWebGPUPolyDataMapper::AcquireCellXXXXRenderBuffer()
   */
  void AddRenderBuffer(vtkSmartPointer<vtkWebGPUComputeRenderBuffer> renderBuffer);

  /**
   * Resizes a buffer of the pipeline.
   *
   * @warning: After the resize, the data of the buffer is undefined and should be updated by a call
   * to UpdateBufferData()
   */
  void ResizeBuffer(int bufferIndex, vtkIdType newByteSize);

  /**
   * This function maps the buffer, making it accessible to the CPU. This is
   * an asynchronous operation, meaning that the given callback will be called
   * when the mapping is done.
   *
   * The buffer data can then be read from the callback and stored
   * in a buffer passed in via the userdata pointer for example
   */
  void ReadBufferFromGPU(
    int bufferIndex, vtkWebGPUComputePipeline::MapAsyncCallback callback, void* userdata);

  /**
   * Updates the data of a buffer.
   * The given data is expected to be at most the size of the buffer.
   * If N bytes are given to update but the buffer size is > N, only the first N bytes
   * will be updated, the rest will remain unchanged.
   * The data is immediately available to the GPU (no call to Update() is necessary for this call to
   * take effect)
   *
   * If bufferIndex is invalid or if the given data is bigger than the buffer, an error is logged
   * and nothing is uploaded.
   *
   * @note: This method can be used even if the buffer was initially configured with std::vector
   * data and the given data can safely be destroyed directly after calling this function.
   *
   * @note: NOTE(review): the WebGPU specification requires the byte size given to writeBuffer to
   * be a multiple of 4. This is not checked here and is left to the WebGPU validation.
   */
  template <typename T>
  void UpdateBufferData(int bufferIndex, const std::vector<T>& newData)
  {
    if (!this->CheckBufferIndex(bufferIndex, std::string("UpdateBufferData")))
    {
      return;
    }

    const vtkIdType byteSize = this->Buffers[bufferIndex]->GetByteSize();
    const std::size_t givenByteCount = newData.size() * sizeof(T);
    const vtkIdType givenSize = static_cast<vtkIdType>(givenByteCount);

    if (givenSize > byteSize)
    {
      vtkLog(ERROR,
        "std::vector data given to UpdateBufferData with index "
          << bufferIndex << " is too big. " << givenSize
          << " bytes were given but the buffer is only " << byteSize
          << " bytes long. No data was updated by this call.");

      return;
    }

    const wgpu::Buffer& wgpuBuffer = this->WGPUBuffers[bufferIndex];
    this->Device.GetQueue().WriteBuffer(wgpuBuffer, 0, newData.data(), givenByteCount);
  }

  /**
   * Similar to the overload without offset of this function.
   * The offset is used to determine where in the buffer to reupload data.
   * Useful when only a portion of the buffer needs to be reuploaded.
   *
   * byteOffset is expressed in bytes from the start of the buffer and must not be negative.
   * If bufferIndex is invalid, if byteOffset is negative or if the given data does not fit in the
   * buffer starting at byteOffset, an error is logged and nothing is uploaded.
   *
   * @note: NOTE(review): the WebGPU specification requires both the offset and the byte size
   * given to writeBuffer to be multiples of 4. This is not checked here and is left to the WebGPU
   * validation.
   */
  template <typename T>
  void UpdateBufferData(int bufferIndex, vtkIdType byteOffset, const std::vector<T>& data)
  {
    if (!this->CheckBufferIndex(bufferIndex, std::string("UpdateBufferData with offset")))
    {
      return;
    }

    // A negative offset would otherwise pass the size check below and be converted to a huge
    // unsigned offset when handed over to WebGPU
    if (byteOffset < 0)
    {
      vtkLog(ERROR,
        "Negative offset " << byteOffset << " given to UpdateBufferData with index " << bufferIndex
                           << ". No data was updated by this call.");

      return;
    }

    const vtkIdType byteSize = this->Buffers[bufferIndex]->GetByteSize();
    const std::size_t givenByteCount = data.size() * sizeof(T);
    const vtkIdType givenSize = static_cast<vtkIdType>(givenByteCount);

    // Written as a subtraction so that the check cannot overflow for large sizes / offsets
    if (byteOffset > byteSize || givenSize > byteSize - byteOffset)
    {
      vtkLog(ERROR,
        "std::vector data given to UpdateBufferData with index "
          << bufferIndex << " and offset " << byteOffset << " is too big. " << givenSize
          << " bytes and offset " << byteOffset << " were given but the buffer is only "
          << byteSize << " bytes long. No data was updated by this call.");

      return;
    }

    const wgpu::Buffer& wgpuBuffer = this->WGPUBuffers[bufferIndex];
    this->Device.GetQueue().WriteBuffer(
      wgpuBuffer, static_cast<uint64_t>(byteOffset), data.data(), givenByteCount);
  }

  /**
   * Updates the data of a buffer with a vtkDataArray.
   * The given data is expected to be at most the size of the buffer.
   * If N bytes are given to update but the buffer size is > N, only the first N bytes
   * will be updated, the rest will remain unchanged.
   * The data is immediately available to the GPU (no call to Update() is necessary for this call
   * to take effect)
   *
   * @note: This method can be used even if the buffer was initially configured with std::vector
   * data and the given data can safely be destroyed directly after calling this function.
   */
  void UpdateBufferData(int bufferIndex, vtkDataArray* newData);

  /**
   * Similar to the overload without offset of this function.
   * The offset is used to determine where in the buffer to reupload data.
   * Useful when only a portion of the buffer needs to be reuploaded.
   */
  void UpdateBufferData(int bufferIndex, vtkIdType byteOffset, vtkDataArray* newData);

  /**
   * Set the number of workgroups in each dimension that are used by each Dispatch() call.
   */
  void SetWorkgroups(int groupsX, int groupsY, int groupsZ);

  /**
   * Dispatch the compute shader with (X, Y, Z) = (groupsX, groupsY, groupsZ) groups, as given to
   * SetWorkgroups()
   */
  void Dispatch();

  /**
   * Executes WebGPU commands and callbacks. This method needs to be called at some point to allow
   * for the execution of WebGPU commands that have been submitted so far. A call to Dispatch() or
   * ReadBufferFromGPU() without a call to Update() will have no effect. Calling Dispatch() and then
   * ReadBufferFromGPU() and then Update() is valid. You do not need to call Update() after every
   * pipeline call. It can be called only once "at the end".
   */
  void Update();

protected:
  /**
   * Constructor that initializes the device and adapter
   */
  vtkWebGPUComputePipeline();

private:
  friend class vtkWebGPURenderWindow;
  friend class vtkWebGPURenderer;

  // Compute pipelines are neither copyable nor assignable
  vtkWebGPUComputePipeline(const vtkWebGPUComputePipeline&) = delete;
  void operator=(const vtkWebGPUComputePipeline&) = delete;

  /**
   * Given a buffer, create the associated bind group layout entry
   * that will be used when creating the bind group layouts
   */
  void AddBindGroupLayoutEntry(
    uint32_t bindGroup, uint32_t binding, vtkWebGPUComputeBuffer::BufferMode mode);

  /**
   * Given a buffer, create the associated bind group entry
   * that will be used when creating the bind groups
   */
  void AddBindGroupEntry(wgpu::Buffer buffer, uint32_t bindGroup, uint32_t binding,
    vtkWebGPUComputeBuffer::BufferMode mode, uint32_t offset);

  /**
   * Initializes the adapter of the compute pipeline
   */
  void CreateAdapter();

  /**
   * Sets the adapter. Useful to reuse an already existing adapter.
   */
  void SetAdapter(wgpu::Adapter adapter) { this->Adapter = adapter; }

  /**
   * Sets the device. Useful to reuse an already existing device.
   */
  void SetDevice(wgpu::Device device) { this->Device = device; }

  /**
   * Initializes the device of the compute pipeline
   */
  void CreateDevice();

  /**
   * Compiles the shader source given into a WGPU shader module
   */
  void CreateShaderModule();

  /**
   * Creates all the bind groups and bind group layouts of this compute pipeline from the buffers
   * that have been added so far.
   */
  void CreateBindGroupsAndLayouts();

  /**
   * Creates the bind group layout of a given list of buffers (that must all belong to the same bind
   * group)
   */
  static wgpu::BindGroupLayout CreateBindGroupLayout(
    const wgpu::Device& device, const std::vector<wgpu::BindGroupLayoutEntry>& layoutEntries);

  /**
   * Creates the bind group entries given a list of buffers
   */
  std::vector<wgpu::BindGroupEntry> CreateBindGroupEntries(
    const std::vector<vtkWebGPUComputeBuffer*>& buffers);

  /**
   * Checks if a given index is suitable for indexing this->Buffers. Logs an error if the index is
   * negative or greater than the number of buffers of the pipeline. The callerFunctionName
   * parameter is used to give more information on what function used an invalid buffer index
   *
   * Returns true if the buffer index is valid, false if it's not.
   */
  bool CheckBufferIndex(int bufferIndex, const std::string& callerFunctionName);

  /**
   * Makes some various (and obvious) checks to ensure that the buffer is ready to be created.
   *
   * Returns true if the buffer is correct.
   * If the buffer is incorrect, returns false and logs the error with the ERROR verbosity
   */
  bool CheckBufferCorrectness(vtkWebGPUComputeBuffer* buffer, const char* bufferLabel);

  /**
   * Destroys and recreates a buffer with the given newByteSize
   * Only the wgpu::Buffer object is recreated so the binding/group of the group doesn't change
   */
  void RecreateBuffer(int bufferIndex, vtkIdType newByteSize);

  /**
   * After recreating a wgpu::Buffer, the bind group entry (and the bind group) will need to be
   * updated. This method is expected to do so for the buffer of the given index.
   */
  void RecreateBufferBindGroup(int bufferIndex);

  /**
   * Binds the buffer to the pipeline at the WebGPU level.
   * To use once the buffer has been properly set up.
   */
  void SetupRenderBuffer(vtkWebGPUComputeRenderBuffer* renderBuffer);

  /**
   * Creates the compute pipeline that will be used to dispatch the compute shader
   */
  void CreateComputePipeline();

  /**
   * Creates the compute pipeline layout associated with the bind group layouts of this compute
   * pipeline
   *
   * @warning: The bind group layouts must have been created by CreateBindGroupsAndLayouts() prior
   * to calling this function
   */
  wgpu::PipelineLayout CreateComputePipelineLayout();

  /**
   * Creates and returns a command encoder
   */
  wgpu::CommandEncoder CreateCommandEncoder();

  /**
   * Creates a compute pass encoder from a command encoder
   */
  wgpu::ComputePassEncoder CreateComputePassEncoder(const wgpu::CommandEncoder& commandEncoder);

  /**
   * Encodes the compute pass and dispatches the workgroups
   *
   * @warning: The bind groups and the compute pipeline must have been created prior to calling this
   * function
   */
  void DispatchComputePass(unsigned int groupsX, unsigned int groupsY, unsigned int groupsZ);

  /**
   * Finishes the encoding of a command encoder and submits the resulting command buffer
   * to the queue
   */
  void SubmitCommandEncoderToQueue(const wgpu::CommandEncoder& commandEncoder);

  /**
   * Internal method used to convert the user friendly BufferMode to the internal enum
   * wgpu::BufferUsage
   */
  static wgpu::BufferUsage ComputeBufferModeToBufferUsage(vtkWebGPUComputeBuffer::BufferMode mode);

  /**
   * Internal method used to convert the user friendly BufferMode to the internal enum
   * wgpu::BufferBindingType
   */
  static wgpu::BufferBindingType ComputeBufferModeToBufferBindingType(
    vtkWebGPUComputeBuffer::BufferMode mode);

  bool Initialized = false;

  wgpu::Adapter Adapter = nullptr;
  wgpu::Device Device = nullptr;
  wgpu::ShaderModule ShaderModule;
  // A list of the bind group index in which bind groups are stored in this->BindGroups. If
  // BindGroupsOrder[0] = 1, this means that this->BindGroups[0] correspond to the bind group of
  // index 1 (@group(1) in WGSL).
  // This list is going to be useful when we're going to want to update a bind group (after resizing
  // a buffer for example when we need to find which bind group in the list to recreate).
  std::vector<int> BindGroupsOrder;
  // List of the bind groups, used to set the bind groups of the pipeline at each dispatch
  std::vector<wgpu::BindGroup> BindGroups;
  // Maps a bind group index to the list of bind group entries for this group. These
  // entries will be used at the creation of the bind groups
  std::unordered_map<int, std::vector<wgpu::BindGroupEntry>> BindGroupEntries;
  std::vector<wgpu::BindGroupLayout> BindGroupLayouts;
  // Maps a bind group index to the list of bind group layout entries for this group.
  // These layout entries will be used at the creation of the bind group layouts
  std::unordered_map<int, std::vector<wgpu::BindGroupLayoutEntry>> BindGroupLayoutEntries;
  wgpu::ComputePipeline ComputePipeline;

  // Buffers added to the pipeline. The template UpdateBufferData() overloads index WGPUBuffers with
  // the same bufferIndex as Buffers.
  std::vector<vtkSmartPointer<vtkWebGPUComputeBuffer>> Buffers;
  std::vector<wgpu::Buffer> WGPUBuffers;

  /**
   * Render buffers use already existing wgpu buffers (those of poly data mappers for example) and
   * thus need to be handled differently
   */
  std::vector<vtkSmartPointer<vtkWebGPUComputeRenderBuffer>> RenderBuffers;

  std::string ShaderSource;
  std::string ShaderEntryPoint;

  // How many groups to launch when dispatching the compute
  unsigned int GroupsX = 0, GroupsY = 0, GroupsZ = 0;

  // Label used for debugging
  std::string Label = "VTK Compute pipeline";
  // Label used for the wgpu compute pipeline of this VTK compute pipeline
  std::string WGPUComputePipelineLabel = "WebGPU compute pipeline of \"VTK Compute pipeline\"";
  // Label used for the wgpu command encoders created and used by this VTK compute pipeline
  std::string WGPUCommandEncoderLabel = "WebGPU command encoder of \"VTK Compute pipeline\"";
};

VTK_ABI_NAMESPACE_END

#endif
