Commit b57dc5d2 authored by Robert Maynard's avatar Robert Maynard

Update ArrayHandleVirtual to handle PrepareForInPlace.

parent 3c624614
......@@ -73,10 +73,10 @@ private:
vtkm::cont::DeviceAdapterId devId) const;
void TransferPortalForOutput(vtkm::cont::internal::TransferInfoArray& payload,
OutputMode mode,
vtkm::Id numberOfValues,
vtkm::cont::DeviceAdapterId devId);
vtkm::cont::ArrayHandle<T, S> Handle;
};
......
......@@ -70,12 +70,24 @@ struct PortalWrapperToDevice
bool operator()(DeviceAdapterTag device,
Handle&& handle,
vtkm::Id numberOfValues,
vtkm::cont::internal::TransferInfoArray& payload) const
vtkm::cont::internal::TransferInfoArray& payload,
vtkm::cont::StorageVirtual::OutputMode mode) const
{
auto portal = handle.PrepareForOutput(numberOfValues, device);
using DerivedPortal = vtkm::ArrayPortalWrapper<decltype(portal)>;
vtkm::cont::detail::TransferToDevice<DerivedPortal> transfer;
return transfer(device, payload, portal);
using ACCESS_MODE = vtkm::cont::StorageVirtual::OutputMode;
if (mode == ACCESS_MODE::WRITE)
{
auto portal = handle.PrepareForOutput(numberOfValues, device);
using DerivedPortal = vtkm::ArrayPortalWrapper<decltype(portal)>;
vtkm::cont::detail::TransferToDevice<DerivedPortal> transfer;
return transfer(device, payload, portal);
}
else
{
auto portal = handle.PrepareForInPlace(device);
using DerivedPortal = vtkm::ArrayPortalWrapper<decltype(portal)>;
vtkm::cont::detail::TransferToDevice<DerivedPortal> transfer;
return transfer(device, payload, portal);
}
}
};
}
......@@ -129,11 +141,12 @@ void StorageAny<T, S>::TransferPortalForInput(vtkm::cont::internal::TransferInfo
template <typename T, typename S>
void StorageAny<T, S>::TransferPortalForOutput(vtkm::cont::internal::TransferInfoArray& payload,
vtkm::cont::StorageVirtual::OutputMode mode,
vtkm::Id numberOfValues,
vtkm::cont::DeviceAdapterId devId)
{
vtkm::cont::TryExecuteOnDevice(
devId, detail::PortalWrapperToDevice(), this->Handle, numberOfValues, payload);
devId, detail::PortalWrapperToDevice(), this->Handle, numberOfValues, payload, mode);
}
}
} // namespace vtkm::cont
......
......@@ -239,7 +239,7 @@ public:
vtkm::ArrayPortalRef<T> PrepareForInPlace(vtkm::cont::DeviceAdapterId devId)
{
return make_ArrayPortalRef(
static_cast<const vtkm::ArrayPortalVirtual<T>*>(this->Storage->PrepareForInput(devId)),
static_cast<const vtkm::ArrayPortalVirtual<T>*>(this->Storage->PrepareForInPlace(devId)),
this->GetNumberOfValues());
}
......
......@@ -122,7 +122,6 @@ Storage<void, ::vtkm::cont::StorageTagVirtual>::PrepareForInput(
//need to re-transfer the execution information
auto* payload = this->DeviceTransferState.get();
this->TransferPortalForInput(*payload, devId);
this->HostUpToDate = false;
this->DeviceUpToDate = true;
}
return this->DeviceTransferState->devicePtr();
......@@ -145,7 +144,33 @@ Storage<void, ::vtkm::cont::StorageTagVirtual>::PrepareForOutput(vtkm::Id number
const bool needsUpload = !(this->DeviceTransferState->valid(devId) && this->DeviceUpToDate);
if (needsUpload)
{
this->TransferPortalForOutput(*(this->DeviceTransferState), numberOfValues, devId);
this->TransferPortalForOutput(
*(this->DeviceTransferState), OutputMode::WRITE, numberOfValues, devId);
this->HostUpToDate = false;
this->DeviceUpToDate = true;
}
return this->DeviceTransferState->devicePtr();
}
//--------------------------------------------------------------------
const vtkm::internal::PortalVirtualBase*
Storage<void, ::vtkm::cont::StorageTagVirtual>::PrepareForInPlace(vtkm::cont::DeviceAdapterId devId)
{
if (devId == vtkm::cont::DeviceAdapterTagUndefined())
{
throw vtkm::cont::ErrorBadValue("device should not be VTKM_DEVICE_ADAPTER_UNDEFINED");
}
if (devId == vtkm::cont::DeviceAdapterTagError())
{
throw vtkm::cont::ErrorBadValue("device should not be VTKM_DEVICE_ADAPTER_ERROR");
}
const bool needsUpload = !(this->DeviceTransferState->valid(devId) && this->DeviceUpToDate);
if (needsUpload)
{
vtkm::Id numberOfValues = this->GetNumberOfValues();
this->TransferPortalForOutput(
*(this->DeviceTransferState), OutputMode::READ_WRITE, numberOfValues, devId);
this->HostUpToDate = false;
this->DeviceUpToDate = true;
}
......@@ -178,9 +203,6 @@ Storage<void, ::vtkm::cont::StorageTagVirtual>::GetPortalConstControl() const
vtkm::cont::internal::TransferInfoArray* payload = this->DeviceTransferState.get();
this->ControlPortalForInput(*payload);
}
//We need to mark the device out of date after the conditions
//as they can modify the state of the device
this->DeviceUpToDate = false;
this->HostUpToDate = true;
return this->DeviceTransferState->hostPtr();
}
......@@ -195,16 +217,18 @@ DeviceAdapterId Storage<void, ::vtkm::cont::StorageTagVirtual>::GetDeviceAdapter
void Storage<void, ::vtkm::cont::StorageTagVirtual>::ControlPortalForOutput(
vtkm::cont::internal::TransferInfoArray&)
{
throw vtkm::cont::ErrorBadValue("StorageTagVirtual by default doesn't support output.");
throw vtkm::cont::ErrorBadValue(
"StorageTagVirtual by default doesn't support control side writes.");
}
//--------------------------------------------------------------------
void Storage<void, ::vtkm::cont::StorageTagVirtual>::TransferPortalForOutput(
vtkm::cont::internal::TransferInfoArray&,
OutputMode,
vtkm::Id,
vtkm::cont::DeviceAdapterId)
{
throw vtkm::cont::ErrorBadValue("StorageTagVirtual by default doesn't support output.");
throw vtkm::cont::ErrorBadValue("StorageTagVirtual by default doesn't support exec side writes.");
}
}
}
......
......@@ -127,6 +127,8 @@ public:
const vtkm::internal::PortalVirtualBase* PrepareForOutput(vtkm::Id numberOfValues,
vtkm::cont::DeviceAdapterId devId);
const vtkm::internal::PortalVirtualBase* PrepareForInPlace(vtkm::cont::DeviceAdapterId devId);
//This needs to cause a host side sync!
//This needs to work before we execute on a device
const vtkm::internal::PortalVirtualBase* GetPortalControl();
......@@ -140,6 +142,13 @@ public:
/// returned.
DeviceAdapterId GetDeviceAdapterId() const noexcept;
enum struct OutputMode
{
WRITE,
READ_WRITE
};
private:
//Memory management routines
// virtual void DoAllocate(vtkm::Id numberOfValues) = 0;
......@@ -153,9 +162,12 @@ private:
//Portal routines
virtual void ControlPortalForInput(vtkm::cont::internal::TransferInfoArray& payload) const = 0;
virtual void ControlPortalForOutput(vtkm::cont::internal::TransferInfoArray& payload);
virtual void TransferPortalForInput(vtkm::cont::internal::TransferInfoArray& payload,
vtkm::cont::DeviceAdapterId devId) const = 0;
virtual void TransferPortalForOutput(vtkm::cont::internal::TransferInfoArray& payload,
OutputMode mode,
vtkm::Id numberOfValues,
vtkm::cont::DeviceAdapterId devId);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment