// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause

#include "vtkWebGPUComputePipeline.h"
#include "vtkWebGPUInternalsCallbacks.h"
#include "vtkWebGPUInternalsComputeBuffer.h"
#include "vtksys/FStream.hxx"

#include <algorithm> // for std::find
#include <iterator>  // for std::distance, std::istreambuf_iterator
#include <memory>    // for std::unique_ptr, std::make_unique

VTK_ABI_NAMESPACE_BEGIN

// Standard VTK object factory macro for vtkWebGPUComputePipeline (provides New())
vtkStandardNewMacro(vtkWebGPUComputePipeline);

namespace
{
/**
 * Structure used to pass data to the asynchronous callback of wgpu::Buffer.MapAsync()
 *
 * This helper is only used by this translation unit, so it is given internal linkage (anonymous
 * namespace). This avoids any ODR clash with an identically named type defined in another
 * translation unit.
 */
struct InternalMapBufferAsyncData
{
  // Buffer currently being mapped
  wgpu::Buffer buffer = nullptr;
  // Label of the buffer currently being mapped. Used for printing errors
  std::string bufferLabel;
  // Size of the buffer being mapped in bytes
  vtkIdType byteSize = -1;

  // The callback given by the user that will be called once the buffer is mapped. The user will
  // usually use their callback to copy the data from the mapped buffer into a CPU-side buffer that
  // will use the result of the compute shader in the rest of the application
  vtkWebGPUComputePipeline::MapAsyncCallback userCallback;
  // Userdata passed to userCallback. This is typically the structure that contains the CPU-side
  // buffer into which the data of the mapped buffer will be copied
  void* userdata = nullptr;
};
} // anonymous namespace

//------------------------------------------------------------------------------
vtkWebGPUComputePipeline::vtkWebGPUComputePipeline()
{
  // Order matters: the device is requested from this->Adapter, so the adapter is set up first.
  // Both calls are no-ops for objects that already exist.
  this->CreateAdapter();
  this->CreateDevice();
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::PrintSelf(ostream& os, vtkIndent indent)
{
  this->Superclass::PrintSelf(os, indent);

  os << indent << "Label: " << this->Label << std::endl;
  os << indent << "WGPUComputePipelineLabel: " << this->WGPUComputePipelineLabel << std::endl;
  os << indent << "WGPUCommandEncoderLabel: " << this->WGPUCommandEncoderLabel << std::endl;
  os << indent << "Initialized: " << this->Initialized << std::endl;

  os << indent << "Adapter: " << this->Adapter.Get() << std::endl;
  os << indent << "Device: " << this->Device.Get() << std::endl;
  os << indent << "ShaderModule: " << this->ShaderModule.Get() << std::endl;
  os << indent << "ShaderEntryPoint: " << this->ShaderEntryPoint << std::endl;
  // The WGSL source may be long: only its size is reported
  os << indent << "ShaderSource: " << this->ShaderSource.size() << " characters" << std::endl;
  os << indent << "Workgroups (X, Y, Z): (" << this->GroupsX << ", " << this->GroupsY << ", "
     << this->GroupsZ << ")" << std::endl;

  os << indent << this->BindGroups.size() << " bind groups:" << std::endl;
  for (const wgpu::BindGroup& bindGroup : this->BindGroups)
  {
    os << indent << "\t- " << bindGroup.Get() << std::endl;
  }

  os << indent << this->BindGroupEntries.size() << " bind group entries:" << std::endl;
  for (const auto& bindGroupEntry : this->BindGroupEntries)
  {
    os << indent << "\t Bind group " << bindGroupEntry.first << std::endl;
    os << indent << "\t (binding/buffer/offset/size)" << std::endl;
    for (const wgpu::BindGroupEntry& entry : bindGroupEntry.second)
    {
      os << indent << "\t- " << entry.binding << " / " << entry.buffer.Get() << " / "
         << entry.offset << " / " << entry.size << std::endl;
    }
  }

  os << indent << this->BindGroupLayouts.size() << " bind group layouts:" << std::endl;
  for (const wgpu::BindGroupLayout& bindGroupLayout : this->BindGroupLayouts)
  {
    os << indent << "\t- " << bindGroupLayout.Get() << std::endl;
  }

  os << indent << this->BindGroupLayoutEntries.size()
     << " bind group layout entries:" << std::endl;
  for (const auto& bindLayoutGroupEntry : this->BindGroupLayoutEntries)
  {
    os << indent << "\t Bind group layout " << bindLayoutGroupEntry.first << std::endl;
    os << indent << "\t (binding/buffer type/visibility)" << std::endl;
    for (const wgpu::BindGroupLayoutEntry& entry : bindLayoutGroupEntry.second)
    {
      os << indent << "\t- " << entry.binding << " / " << static_cast<uint32_t>(entry.buffer.type)
         << " / " << static_cast<uint32_t>(entry.visibility) << std::endl;
    }
  }

  os << indent << "WGPU Compute pipeline: " << this->ComputePipeline.Get() << std::endl;

  os << indent << this->Buffers.size() << " buffers:" << std::endl;
  for (vtkWebGPUComputeBuffer* buffer : this->Buffers)
  {
    os << indent << "\t- " << buffer << std::endl;
  }

  os << indent << this->WGPUBuffers.size() << " WGPU buffers:" << std::endl;
  for (const wgpu::Buffer& buffer : this->WGPUBuffers)
  {
    os << indent << "\t- " << buffer.Get() << std::endl;
  }

  // Render buffers reuse already existing wgpu buffers (those of poly data mappers for example).
  // They are also part of Buffers/WGPUBuffers; listing them separately shows which ones they are.
  os << indent << this->RenderBuffers.size() << " render buffers:" << std::endl;
  for (vtkWebGPUComputeRenderBuffer* renderBuffer : this->RenderBuffers)
  {
    os << indent << "\t- " << renderBuffer << std::endl;
  }
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::SetShaderSourceFromPath(const char* shaderFilePath)
{
  if (!vtksys::SystemTools::FileExists(shaderFilePath))
  {
    vtkLogF(ERROR, "Given shader file path '%s' doesn't exist", shaderFilePath);

    return;
  }

  vtksys::ifstream inputFileStream(shaderFilePath);
  assert(inputFileStream);
  std::string source(
    (std::istreambuf_iterator<char>(inputFileStream)), std::istreambuf_iterator<char>());

  this->SetShaderSource(source);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::SetLabel(const std::string& label)
{
  this->Label = label;

  // The labels of the wgpu objects owned by this pipeline embed the quoted pipeline label
  const std::string quotedLabel = "\"" + this->Label + "\"";
  this->WGPUComputePipelineLabel = "Compute pipeline of " + quotedLabel;
  this->WGPUCommandEncoderLabel = "Command encoder of " + quotedLabel;
}

//------------------------------------------------------------------------------
int vtkWebGPUComputePipeline::AddBuffer(vtkSmartPointer<vtkWebGPUComputeBuffer> buffer)
{
  // Every statement below dereferences the buffer
  if (!buffer)
  {
    vtkLog(ERROR, "Null buffer given to AddBuffer. No buffer was added.");

    return -1;
  }

  // Giving the buffer a default label if it doesn't have one already
  if (buffer->GetLabel().empty())
  {
    buffer->SetLabel("Buffer " + std::to_string(this->Buffers.size()));
  }

  // Owning copy of the label so that c_str() stays valid for the whole function
  std::string bufferLabel = buffer->GetLabel();

  bool bufferCorrect = this->CheckBufferCorrectness(buffer, bufferLabel.c_str());
  if (!bufferCorrect)
  {
    return -1;
  }

  wgpu::Buffer wgpuBuffer =
    vtkWebGPUInternalsBuffer::CreateABuffer(this->Device, buffer->GetByteSize(),
      ComputeBufferModeToBufferUsage(buffer->GetMode()), false, bufferLabel.c_str());

  // Uploading from std::vector or vtkDataArray if one of the two is present
  if (buffer->GetDataPointer() != nullptr)
  {
    this->Device.GetQueue().WriteBuffer(
      wgpuBuffer, 0, buffer->GetDataPointer(), buffer->GetByteSize());
  }
  else if (buffer->GetDataArray() != nullptr)
  {
    vtkWebGPUInternalsComputeBuffer::UploadFromDataArray(
      this->Device, wgpuBuffer, buffer->GetDataArray());
  }

  // Adding the buffer to the lists. Buffers and WGPUBuffers are parallel: the same index refers
  // to the VTK buffer and to its wgpu buffer
  this->Buffers.push_back(buffer);
  this->WGPUBuffers.push_back(wgpuBuffer);

  // Creating the layout entry and the bind group entry for this buffer. These entries will be used
  // later when creating the bind groups / bind group layouts
  this->AddBindGroupLayoutEntry(buffer->GetGroup(), buffer->GetBinding(), buffer->GetMode());
  this->AddBindGroupEntry(
    wgpuBuffer, buffer->GetGroup(), buffer->GetBinding(), buffer->GetMode(), 0);

  // Returning the index of the buffer
  return this->Buffers.size() - 1;
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::AddRenderBuffer(
  vtkSmartPointer<vtkWebGPUComputeRenderBuffer> renderBuffer)
{
  // Same guard as AddBuffer(): the render buffer is dereferenced right below
  if (!renderBuffer)
  {
    vtkLog(ERROR, "Null render buffer given to AddRenderBuffer. No render buffer was added.");

    return;
  }

  renderBuffer->SetAssociatedPipeline(this);

  // A render buffer is also a regular buffer of the pipeline (parallel Buffers / WGPUBuffers
  // vectors). It reuses the wgpu buffer it already has instead of creating a new one.
  this->Buffers.push_back(renderBuffer);
  this->WGPUBuffers.push_back(renderBuffer->GetWGPUBuffer());
  this->RenderBuffers.push_back(renderBuffer);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::ResizeBuffer(int bufferIndex, vtkIdType newByteSize)
{
  if (!CheckBufferIndex(bufferIndex, std::string("ResizeBuffer")))
  {
    return;
  }

  vtkWebGPUComputeBuffer* buffer = this->Buffers[bufferIndex];

  RecreateBuffer(bufferIndex, newByteSize);
  RecreateBufferBindGroup(bufferIndex);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::RecreateBuffer(int bufferIndex, vtkIdType newByteSize)
{
  vtkWebGPUComputeBuffer* buffer = this->Buffers[bufferIndex];

  // Updating the byte size
  buffer->SetByteSize(newByteSize);
  const char* bufferLabel = buffer->GetLabel().c_str();
  wgpu::BufferUsage bufferUsage = ComputeBufferModeToBufferUsage(buffer->GetMode());

  // Recreating the buffer
  this->WGPUBuffers[bufferIndex] = vtkWebGPUInternalsBuffer::CreateABuffer(
    this->Device, newByteSize, bufferUsage, false, bufferLabel);
}

void vtkWebGPUComputePipeline::RecreateBufferBindGroup(int bufferIndex)
{
  vtkWebGPUComputeBuffer* buffer = this->Buffers[bufferIndex];

  // We also need to recreate the bind group entry (and the bind group below) that corresponded to
  // this buffer.
  // We first need to find the bind group entry that corresponded to this buffer
  std::vector<wgpu::BindGroupEntry>& bgEntries = this->BindGroupEntries[buffer->GetGroup()];
  for (wgpu::BindGroupEntry& entry : bgEntries)
  {
    // We only need to check the binding because we already retrieved all the entries that
    // correspond to the group of the buffer
    if (entry.binding == buffer->GetBinding())
    {
      // Replacing the buffer by the one we just recreated
      entry.buffer = this->WGPUBuffers[bufferIndex];

      break;
    }
  }

  // Finding which bind group is the one to recreate
  int bindGroupIndex = -1;
  for (int i = 0; i < this->BindGroupsOrder.size(); i++)
  {
    if (this->BindGroupsOrder[i] == buffer->GetGroup())
    {
      bindGroupIndex = i;

      break;
    }
  }

  if (bindGroupIndex == -1)
  {
    // We couldn't find the bind group, something went wrong
    vtkLog(ERROR,
      "Unable to find the bind group to which the buffer of index" << bufferIndex << " belongs.");

    return;
  }

  // Recreating the right bind group
  this->BindGroups[bindGroupIndex] = vtkWebGPUInternalsBindGroup::MakeBindGroup(
    this->Device, this->BindGroupLayouts[bindGroupIndex], bgEntries);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::AddBindGroupLayoutEntry(
  uint32_t bindGroup, uint32_t binding, vtkWebGPUComputeBuffer::BufferMode mode)
{
  wgpu::BufferBindingType bindingType =
    vtkWebGPUComputePipeline::ComputeBufferModeToBufferBindingType(mode);

  vtkWebGPUInternalsBindGroupLayout::LayoutEntryInitializationHelper bglEntry{ binding,
    wgpu::ShaderStage::Compute, bindingType };

  this->BindGroupLayoutEntries[bindGroup].push_back(bglEntry);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::AddBindGroupEntry(wgpu::Buffer wgpuBuffer, uint32_t bindGroup,
  uint32_t binding, vtkWebGPUComputeBuffer::BufferMode mode, uint32_t offset)
{
  // A bind group entry only references the buffer (binding, buffer, offset). How the buffer is
  // bound is described by the matching layout entry (AddBindGroupLayoutEntry()), so 'mode' is not
  // needed here. It stays in the signature so existing callers are unaffected.
  static_cast<void>(mode);

  vtkWebGPUInternalsBindGroup::BindingInitializationHelper bgEntry{ binding, wgpuBuffer, offset };

  this->BindGroupEntries[bindGroup].push_back(bgEntry.GetAsBinding());
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::ReadBufferFromGPU(
  int bufferIndex, vtkWebGPUComputePipeline::MapAsyncCallback callback, void* userdata)
{
  // We need a buffer that will hold the mapped data.
  // We cannot directly map the output buffer of the compute shader because
  // wgpu::BufferUsage::Storage is incompatible with wgpu::BufferUsage::MapRead. This is a
  // restriction of WebGPU. This means that we have to create a new buffer with the MapRead flag
  // that is not a Storage buffer, copy the storage buffer that we actually want to this new buffer
  // (that has the MapRead usage flag) and then map this buffer to the CPU.
  vtkIdType byteSize = this->Buffers[bufferIndex]->GetByteSize();
  wgpu::Buffer mappedBuffer = vtkWebGPUInternalsBuffer::CreateABuffer(this->Device, byteSize,
    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, false, nullptr);

  // If we were to allocate this callbackData locally on the stack, it would be destroyed when going
  // out of scope (at the end of this function). The callback, called asynchronously would then be
  // refering to data that has been destroyed (since it was allocated locally). This is why we're
  // allocating it dynamically with a new
  InternalMapBufferAsyncData* internalCallbackData = new InternalMapBufferAsyncData;
  internalCallbackData->buffer = mappedBuffer;
  internalCallbackData->bufferLabel = this->Label;
  internalCallbackData->byteSize = byteSize;
  internalCallbackData->userCallback = callback;
  internalCallbackData->userdata = userdata;

  wgpu::CommandEncoder commandEncoder = this->CreateCommandEncoder();
  commandEncoder.CopyBufferToBuffer(
    this->WGPUBuffers[bufferIndex], 0, internalCallbackData->buffer, 0, byteSize);
  this->SubmitCommandEncoderToQueue(commandEncoder);

  auto internalCallback = [](WGPUBufferMapAsyncStatus status, void* wgpuUserData)
  {
    InternalMapBufferAsyncData* callbackData =
      reinterpret_cast<InternalMapBufferAsyncData*>(wgpuUserData);

    if (status == WGPUBufferMapAsyncStatus::WGPUBufferMapAsyncStatus_Success)
    {
      const void* mappedRange = callbackData->buffer.GetConstMappedRange(0, callbackData->byteSize);
      callbackData->userCallback(mappedRange, callbackData->userdata);

      callbackData->buffer.Unmap();
      // Freeing the callbackData structure as it was dynamically allocated
      delete callbackData;
    }
    else
    {
      vtkLogF(WARNING, "Could not map buffer '%s' with error status: %d",
        callbackData->bufferLabel.length() > 0 ? callbackData->bufferLabel.c_str() : "(nolabel)",
        status);

      delete callbackData;
    }
  };

  internalCallbackData->buffer.MapAsync(
    wgpu::MapMode::Read, 0, byteSize, internalCallback, internalCallbackData);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::UpdateBufferData(int bufferIndex, vtkDataArray* newData)
{
  if (!CheckBufferIndex(bufferIndex, std::string("UpdateBufferData")))
  {
    return;
  }

  vtkWebGPUComputeBuffer* buffer = this->Buffers[bufferIndex];
  vtkIdType byteSize = buffer->GetByteSize();
  vtkIdType givenSize = newData->GetNumberOfValues() * newData->GetDataTypeSize();

  if (givenSize > byteSize)
  {
    vtkLog(ERROR,
      "std::vector data given to UpdateBufferData with index "
        << bufferIndex << " is too big. " << givenSize << "bytes were given but the buffer is only "
        << byteSize << " bytes long. No data was updated by this call.");

    return;
  }

  wgpu::Buffer wgpuBuffer = this->WGPUBuffers[bufferIndex];

  vtkWebGPUInternalsComputeBuffer::UploadFromDataArray(this->Device, wgpuBuffer, newData);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::UpdateBufferData(
  int bufferIndex, vtkIdType byteOffset, vtkDataArray* newData)
{
  if (!CheckBufferIndex(bufferIndex, std::string("UpdateBufferData with offset")))
  {
    return;
  }

  vtkWebGPUComputeBuffer* buffer = this->Buffers[bufferIndex];
  vtkIdType byteSize = buffer->GetByteSize();
  vtkIdType givenSize = newData->GetNumberOfValues() * newData->GetDataTypeSize();

  if (givenSize + byteOffset > byteSize)
  {
    vtkLog(ERROR,
      "vtkDataArray data given to UpdateBufferData with index "
        << bufferIndex << " and offset " << byteOffset << " is too big. " << givenSize
        << "bytes and offset " << byteOffset << " were given but the buffer is only " << byteSize
        << " bytes long. No data was updated by this call.");

    return;
  }

  wgpu::Buffer wgpuBuffer = this->WGPUBuffers[bufferIndex];

  vtkWebGPUInternalsComputeBuffer::UploadFromDataArray(
    this->Device, wgpuBuffer, byteOffset, newData);
}

//------------------------------------------------------------------------------
bool vtkWebGPUComputePipeline::CheckBufferIndex(
  int bufferIndex, const std::string& callerFunctionName)
{
  // Valid indices are exactly those returned by AddBuffer(): [0, number of buffers)
  const bool isNegative = bufferIndex < 0;
  if (isNegative)
  {
    vtkLog(ERROR,
      "Negative bufferIndex given to "
        << callerFunctionName << ". Make sure to use an index that was returned by AddBuffer().");
    return false;
  }

  // bufferIndex is known to be non-negative at this point
  if (bufferIndex >= this->Buffers.size())
  {
    vtkLog(ERROR,
      "Invalid bufferIndex given to "
        << callerFunctionName << ". Index was '" << bufferIndex << "' while there are "
        << this->Buffers.size()
        << " available buffers. Make sure to use an index that was returned by AddBuffer().");
    return false;
  }

  return true;
}

//------------------------------------------------------------------------------
bool vtkWebGPUComputePipeline::CheckBufferCorrectness(
  vtkWebGPUComputeBuffer* buffer, const char* bufferLabel)
{
  // A buffer needs a group, a binding and a non-zero size before it can be added
  if (buffer->GetGroup() == -1)
  {
    vtkLogF(
      ERROR, "The group of the buffer with label \"%s\" hasn't been initialized", bufferLabel);
    return false;
  }

  if (buffer->GetBinding() == -1)
  {
    vtkLogF(
      ERROR, "The binding of the buffer with label \"%s\" hasn't been initialized", bufferLabel);
    return false;
  }

  if (buffer->GetByteSize() == 0)
  {
    vtkLogF(ERROR, "The buffer with label \"%s\" has a size of 0. Did you forget to set its size?",
      bufferLabel);
    return false;
  }

  // Checking that the (group, binding) pair isn't already used by another buffer of the pipeline
  for (vtkWebGPUComputeBuffer* existingBuffer : this->Buffers)
  {
    if (buffer->GetBinding() == existingBuffer->GetBinding() &&
      buffer->GetGroup() == existingBuffer->GetGroup())
    {
      // Report the label of the buffer that already occupies the binding
      vtkLog(ERROR,
        "The buffer with label \"" << bufferLabel << "\" is bound to binding "
                                   << buffer->GetBinding()
                                   << " but that binding is already used by buffer with label \""
                                   << existingBuffer->GetLabel() << "\" in bind group "
                                   << buffer->GetGroup());

      return false;
    }
  }

  return true;
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::SetWorkgroups(int groupsX, int groupsY, int groupsZ)
{
  this->GroupsX = groupsX;
  this->GroupsY = groupsY;
  this->GroupsZ = groupsZ;
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::Dispatch()
{
  // The wgpu objects are built lazily on the first dispatch, from the shader source and the
  // bind group (layout) entries registered beforehand. The shader module and the bind group
  // layouts must exist before the compute pipeline that uses them.
  const bool needsInitialization = !this->Initialized;
  if (needsInitialization)
  {
    this->CreateShaderModule();
    this->CreateBindGroupsAndLayouts();
    this->CreateComputePipeline();

    this->Initialized = true;
  }

  this->DispatchComputePass(this->GroupsX, this->GroupsY, this->GroupsZ);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::Update()
{
  // Blocks until the queue reports that all the work submitted so far has completed

  // Flipped to true by the queue callback, whatever the reported status
  bool workDone = false;

  auto onSubmittedWorkDone = [](WGPUQueueWorkDoneStatus, void* userdata)
  { *static_cast<bool*>(userdata) = true; };
  this->Device.GetQueue().OnSubmittedWorkDone(onSubmittedWorkDone, &workDone);

  while (!workDone)
  {
    vtkWGPUContext::WaitABit();
  }
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::CreateAdapter()
{
  if (this->Adapter != nullptr)
  {
    // The adapter already exists, it must have been given by SetAdapter()
    return;
  }

#if defined(__APPLE__)
  wgpu::BackendType backendType = wgpu::BackendType::Metal;
#elif defined(_WIN32)
  wgpu::BackendType backendType = wgpu::BackendType::D3D12;
#else
  wgpu::BackendType backendType = wgpu::BackendType::Undefined;
#endif

  wgpu::RequestAdapterOptions adapterOptions;
  adapterOptions.backendType = backendType;
  adapterOptions.powerPreference = wgpu::PowerPreference::HighPerformance;
  this->Adapter = vtkWGPUContext::RequestAdapter(adapterOptions);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::CreateDevice()
{
  if (this->Device != nullptr)
  {
    // The device already exists, it must have been given by SetDevice()
    return;
  }

  wgpu::DeviceDescriptor deviceDescriptor;
  deviceDescriptor.nextInChain = nullptr;
  deviceDescriptor.deviceLostCallback = &vtkWebGPUInternalsCallbacks::DeviceLostCallback;
  deviceDescriptor.label = this->Label.c_str();
  this->Device = vtkWGPUContext::RequestDevice(this->Adapter, deviceDescriptor);
  this->Device.SetUncapturedErrorCallback(
    &vtkWebGPUInternalsCallbacks::UncapturedErrorCallback, nullptr);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::CreateShaderModule()
{
  // Build the shader module from the WGSL source currently held by the pipeline
  auto shaderModule =
    vtkWebGPUInternalsShaderModule::CreateFromWGSL(this->Device, this->ShaderSource);
  this->ShaderModule = shaderModule;
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::CreateBindGroupsAndLayouts()
{
  // BindGroupLayouts, BindGroups and BindGroupsOrder are parallel vectors: the same index refers
  // to the layout, the bind group and the group number. They are therefore rebuilt together.
  this->BindGroupLayouts.clear();
  this->BindGroups.clear();
  this->BindGroupsOrder.clear();

  // NOTE(review): DispatchComputePass() binds BindGroups[i] at group index i. This presumably
  // assumes the group numbers used by the buffers are contiguous from 0 and that the iteration
  // below visits them in increasing order. TODO confirm.
  for (const auto& mapEntry : this->BindGroupLayoutEntries)
  {
    const int bindGroup = mapEntry.first;

    const std::vector<wgpu::BindGroupLayoutEntry>& bglEntries = mapEntry.second;
    const std::vector<wgpu::BindGroupEntry>& bgEntries = this->BindGroupEntries[bindGroup];

    this->BindGroupLayouts.push_back(CreateBindGroupLayout(this->Device, bglEntries));
    this->BindGroupsOrder.push_back(bindGroup);
    this->BindGroups.push_back(vtkWebGPUInternalsBindGroup::MakeBindGroup(
      this->Device, this->BindGroupLayouts.back(), bgEntries));
  }
}

//------------------------------------------------------------------------------
wgpu::BindGroupLayout vtkWebGPUComputePipeline::CreateBindGroupLayout(
  const wgpu::Device& device, const std::vector<wgpu::BindGroupLayoutEntry>& layoutEntries)
{
  // Thin forwarding wrapper around the internal bind group layout helper
  return vtkWebGPUInternalsBindGroupLayout::MakeBindGroupLayout(device, layoutEntries);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::SetupRenderBuffer(vtkWebGPUComputeRenderBuffer* renderBuffer)
{
  const auto group = renderBuffer->GetGroup();
  const auto binding = renderBuffer->GetBinding();
  const auto mode = renderBuffer->GetMode();

  // The wgpu buffer of a render buffer already exists: only the entries pointing at it are needed
  this->AddBindGroupLayoutEntry(group, binding, mode);
  this->AddBindGroupEntry(renderBuffer->GetWGPUBuffer(), group, binding, mode, 0);

  // Uniform buffer exposing the offset and the element count of the render buffer to the shader,
  // bound at the group/binding designated by the render buffer. The data vector stays alive until
  // AddBuffer() has uploaded it.
  std::vector<unsigned int> offsetAndElementCount = { renderBuffer->GetRenderBufferOffset(),
    renderBuffer->GetRenderBufferElementCount() };

  vtkNew<vtkWebGPUComputeBuffer> offsetSizeUniform;
  offsetSizeUniform->SetMode(vtkWebGPUComputeBuffer::BufferMode::UNIFORM_BUFFER);
  offsetSizeUniform->SetGroup(renderBuffer->GetRenderUniformsGroup());
  offsetSizeUniform->SetBinding(renderBuffer->GetRenderUniformsBinding());
  offsetSizeUniform->SetData(offsetAndElementCount);

  this->AddBuffer(offsetSizeUniform);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::CreateComputePipeline()
{
  wgpu::ComputePipelineDescriptor computePipelineDescriptor;
  computePipelineDescriptor.compute.constantCount = 0;
  computePipelineDescriptor.compute.constants = nullptr;
  computePipelineDescriptor.compute.entryPoint = this->ShaderEntryPoint.c_str();
  computePipelineDescriptor.compute.module = this->ShaderModule;
  computePipelineDescriptor.compute.nextInChain = nullptr;
  computePipelineDescriptor.label = this->WGPUComputePipelineLabel.c_str();
  computePipelineDescriptor.layout = CreateComputePipelineLayout();

  this->ComputePipeline = this->Device.CreateComputePipeline(&computePipelineDescriptor);
}

//------------------------------------------------------------------------------
wgpu::PipelineLayout vtkWebGPUComputePipeline::CreateComputePipelineLayout()
{
  // The pipeline layout is the list of the bind group layouts, in the order they were created
  wgpu::PipelineLayoutDescriptor layoutDescriptor;
  layoutDescriptor.nextInChain = nullptr;
  layoutDescriptor.bindGroupLayouts = this->BindGroupLayouts.data();
  layoutDescriptor.bindGroupLayoutCount = this->BindGroupLayouts.size();

  return this->Device.CreatePipelineLayout(&layoutDescriptor);
}

//------------------------------------------------------------------------------
wgpu::CommandEncoder vtkWebGPUComputePipeline::CreateCommandEncoder()
{
  // Encoders are labeled after the pipeline to ease debugging
  wgpu::CommandEncoderDescriptor encoderDescriptor;
  encoderDescriptor.label = this->WGPUCommandEncoderLabel.c_str();
  return this->Device.CreateCommandEncoder(&encoderDescriptor);
}

//------------------------------------------------------------------------------
wgpu::ComputePassEncoder vtkWebGPUComputePipeline::CreateComputePassEncoder(
  const wgpu::CommandEncoder& commandEncoder)
{
  wgpu::ComputePassDescriptor computePassDescriptor;
  computePassDescriptor.nextInChain = nullptr;
  // No timestamp queries are written by this pass. This is a pointer field, so use nullptr.
  computePassDescriptor.timestampWrites = nullptr;
  return commandEncoder.BeginComputePass(&computePassDescriptor);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::DispatchComputePass(
  unsigned int groupsX, unsigned int groupsY, unsigned int groupsZ)
{
  // Each dimension is checked on its own. The product groupsX * groupsY * groupsZ can wrap
  // around in unsigned arithmetic (e.g. 65536 * 65536 * 1 == 0 with 32-bit unsigned) and would
  // wrongly reject a valid dispatch.
  if (groupsX == 0 || groupsY == 0 || groupsZ == 0)
  {
    vtkLogF(ERROR,
      "Invalid number of workgroups when dispatching compute pipeline \"%s\". Work groups sizes "
      "(X, Y, Z) were: (%u, %u, %u) but no dimensions can be 0.",
      this->Label.c_str(), groupsX, groupsY, groupsZ);

    return;
  }

  wgpu::CommandEncoder commandEncoder = this->CreateCommandEncoder();

  wgpu::ComputePassEncoder computePassEncoder = this->CreateComputePassEncoder(commandEncoder);
  computePassEncoder.SetPipeline(this->ComputePipeline);

  // BindGroups[i] is bound at group index i of the pipeline layout
  const uint32_t bindGroupCount = static_cast<uint32_t>(this->BindGroups.size());
  for (uint32_t bindGroupIndex = 0; bindGroupIndex < bindGroupCount; bindGroupIndex++)
  {
    computePassEncoder.SetBindGroup(bindGroupIndex, this->BindGroups[bindGroupIndex], 0, nullptr);
  }
  computePassEncoder.DispatchWorkgroups(groupsX, groupsY, groupsZ);
  computePassEncoder.End();

  this->SubmitCommandEncoderToQueue(commandEncoder);
}

//------------------------------------------------------------------------------
void vtkWebGPUComputePipeline::SubmitCommandEncoderToQueue(
  const wgpu::CommandEncoder& commandEncoder)
{
  // Finish the recorded commands and submit the resulting single command buffer to the queue
  const wgpu::CommandBuffer commandBuffers[] = { commandEncoder.Finish() };
  this->Device.GetQueue().Submit(1, commandBuffers);
}

//------------------------------------------------------------------------------
wgpu::BufferUsage vtkWebGPUComputePipeline::ComputeBufferModeToBufferUsage(
  vtkWebGPUComputeBuffer::BufferMode mode)
{
  using Mode = vtkWebGPUComputeBuffer::BufferMode;

  // Every known mode can be written to (CopyDst). Only the "MAP" storage mode can also be copied
  // from (CopySrc), which ReadBufferFromGPU() needs. Unknown modes map to None.
  wgpu::BufferUsage usage = wgpu::BufferUsage::None;
  switch (mode)
  {
    case Mode::UNIFORM_BUFFER:
      usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform;
      break;

    case Mode::READ_WRITE_MAP_COMPUTE_STORAGE:
      usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Storage;
      break;

    case Mode::READ_ONLY_COMPUTE_STORAGE:
    case Mode::READ_WRITE_COMPUTE_STORAGE:
      usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Storage;
      break;

    default:
      break;
  }

  return usage;
}

//------------------------------------------------------------------------------
wgpu::BufferBindingType vtkWebGPUComputePipeline::ComputeBufferModeToBufferBindingType(
  vtkWebGPUComputeBuffer::BufferMode mode)
{
  using Mode = vtkWebGPUComputeBuffer::BufferMode;

  // Both read-write storage modes (mappable or not) share the Storage binding type. Unknown modes
  // map to Undefined.
  wgpu::BufferBindingType bindingType = wgpu::BufferBindingType::Undefined;
  switch (mode)
  {
    case Mode::UNIFORM_BUFFER:
      bindingType = wgpu::BufferBindingType::Uniform;
      break;

    case Mode::READ_ONLY_COMPUTE_STORAGE:
      bindingType = wgpu::BufferBindingType::ReadOnlyStorage;
      break;

    case Mode::READ_WRITE_COMPUTE_STORAGE:
    case Mode::READ_WRITE_MAP_COMPUTE_STORAGE:
      bindingType = wgpu::BufferBindingType::Storage;
      break;

    default:
      break;
  }

  return bindingType;
}

VTK_ABI_NAMESPACE_END
