//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#include <vtkm/StaticAssert.h>

#include <vtkm/cont/DeviceAdapterAlgorithm.h>
#include <vtkm/testing/Testing.h>

#include <vtkm/exec/FunctorBase.h>
#include <vtkm/exec/arg/BasicArg.h>
#include <vtkm/exec/arg/Fetch.h>
#include <vtkm/exec/arg/ThreadIndicesBasic.h>

#include <vtkm/internal/FunctionInterface.h>
#include <vtkm/internal/Invocation.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

namespace vtkm
{
namespace exec
{
namespace internal
{
namespace testing
{

/// Execution object handed to the test Fetch specializations. It holds a
/// raw, non-owning pointer into a host std::vector of vtkm::Id values.
struct TestExecObject
{
  /// Creates an object that points at nothing (Values == nullptr).
  VTKM_EXEC_CONT
  TestExecObject()
    : Values(nullptr)
  {
  }

  /// Points Values at the storage of `values`. The vector is not copied, so it
  /// must outlive this object. data() is used instead of &values[0] because
  /// operator[] on an empty vector is undefined behavior; data() is not.
  VTKM_EXEC_CONT
  TestExecObject(std::vector<vtkm::Id>& values)
    : Values(values.data())
  {
  }

  /// Shallow copy: both objects refer to the same storage.
  VTKM_EXEC_CONT
  TestExecObject(const TestExecObject& other)
    : Values(other.Values)
  {
  }

  /// Non-owning pointer to the backing storage.
  vtkm::Id* Values;
};

struct MyOutputToInputMapPortal
{
58
  using ValueType = vtkm::Id;
59 60 61 62 63 64
  VTKM_EXEC_CONT
  vtkm::Id Get(vtkm::Id index) const { return index; }
};

struct MyVisitArrayPortal
{
65
  using ValueType = vtkm::IdComponent;
66 67 68
  vtkm::IdComponent Get(vtkm::Id) const { return 1; }
};

69 70 71 72 73 74 75
struct MyThreadToOutputMapPortal
{
  using ValueType = vtkm::Id;
  VTKM_EXEC_CONT
  vtkm::Id Get(vtkm::Id index) const { return index; }
};

/// Fetch tags that select the test-only Fetch specializations declared in
/// vtkm::exec::arg for TestExecObject.
struct TestFetchTagInput
{
};
struct TestFetchTagOutput
{
};

// Missing TransportTag, but we are not testing that so we can leave it out.
struct TestControlSignatureTagInput
{
86
  using FetchTag = TestFetchTagInput;
87 88 89
};
struct TestControlSignatureTagOutput
{
90
  using FetchTag = TestFetchTagOutput;
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
};
}
}
}
}

namespace vtkm
{
namespace exec
{
namespace arg
{

using namespace vtkm::exec::internal::testing;

template <>
107 108 109 110
struct Fetch<TestFetchTagInput,
             vtkm::exec::arg::AspectTagDefault,
             vtkm::exec::arg::ThreadIndicesBasic,
             TestExecObject>
111
{
112
  using ValueType = vtkm::Id;
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128

  VTKM_EXEC
  ValueType Load(const vtkm::exec::arg::ThreadIndicesBasic& indices,
                 const TestExecObject& execObject) const
  {
    return execObject.Values[indices.GetInputIndex()] + 10 * indices.GetInputIndex();
  }

  VTKM_EXEC
  void Store(const vtkm::exec::arg::ThreadIndicesBasic&, const TestExecObject&, ValueType) const
  {
    // No-op
  }
};

template <>
129 130 131 132
struct Fetch<TestFetchTagOutput,
             vtkm::exec::arg::AspectTagDefault,
             vtkm::exec::arg::ThreadIndicesBasic,
             TestExecObject>
133
{
134
  using ValueType = vtkm::Id;
135 136 137 138 139 140 141 142 143

  VTKM_EXEC
  ValueType Load(const vtkm::exec::arg::ThreadIndicesBasic&, const TestExecObject&) const
  {
    // No-op
    return ValueType();
  }

  VTKM_EXEC
144 145
  void Store(const vtkm::exec::arg::ThreadIndicesBasic& indices,
             const TestExecObject& execObject,
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
             ValueType value) const
  {
    execObject.Values[indices.GetOutputIndex()] = value + 20 * indices.GetOutputIndex();
  }
};
}
}
} // vtkm::exec::arg

namespace vtkm
{
namespace exec
{
namespace internal
{
namespace testing
{

164
using TestControlSignature = void(TestControlSignatureTagInput, TestControlSignatureTagOutput);
165
using TestControlInterface = vtkm::internal::FunctionInterface<TestControlSignature>;
166

167
using TestExecutionSignature1 = void(vtkm::exec::arg::BasicArg<1>, vtkm::exec::arg::BasicArg<2>);
168
using TestExecutionInterface1 = vtkm::internal::FunctionInterface<TestExecutionSignature1>;
169

170
using TestExecutionSignature2 = vtkm::exec::arg::BasicArg<2>(vtkm::exec::arg::BasicArg<1>);
171
using TestExecutionInterface2 = vtkm::internal::FunctionInterface<TestExecutionSignature2>;
172

173 174
using ExecutionParameterInterface =
  vtkm::internal::FunctionInterface<void(TestExecObject, TestExecObject)>;
175

176 177 178 179 180
using InvocationType1 = vtkm::internal::Invocation<ExecutionParameterInterface,
                                                   TestControlInterface,
                                                   TestExecutionInterface1,
                                                   1,
                                                   MyOutputToInputMapPortal,
181 182
                                                   MyVisitArrayPortal,
                                                   MyThreadToOutputMapPortal>;
183 184 185 186 187 188

using InvocationType2 = vtkm::internal::Invocation<ExecutionParameterInterface,
                                                   TestControlInterface,
                                                   TestExecutionInterface2,
                                                   1,
                                                   MyOutputToInputMapPortal,
189 190
                                                   MyVisitArrayPortal,
                                                   MyThreadToOutputMapPortal>;
191 192 193 194 195 196 197 198 199 200

// Not a full worklet, but provides operators that we expect in a worklet.
struct TestWorkletProxy : vtkm::exec::FunctorBase
{
  /// Used with execution signature 1 (void return): output = input + 100.
  VTKM_EXEC
  void operator()(vtkm::Id input, vtkm::Id& output) const { output = input + 100; }

  /// Used with execution signature 2 (value return): returns input + 200.
  VTKM_EXEC
  vtkm::Id operator()(vtkm::Id input) const { return input + 200; }

  /// 1D scheduling. Maps the thread index through threadToOut, then looks up
  /// the input and visit indices for that output index.
  template <typename OutToInArrayType,
            typename VisitArrayType,
            typename ThreadToOutArrayType,
            typename InputDomainType,
            typename G>
  VTKM_EXEC vtkm::exec::arg::ThreadIndicesBasic GetThreadIndices(
    const vtkm::Id& threadIndex,
    const OutToInArrayType& outToIn,
    const VisitArrayType& visit,
    const ThreadToOutArrayType& threadToOut,
    const InputDomainType&,
    const G& globalThreadIndexOffset) const
  {
    const vtkm::Id outIndex = threadToOut.Get(threadIndex);
    return vtkm::exec::arg::ThreadIndicesBasic(
      threadIndex, outToIn.Get(outIndex), visit.Get(outIndex), outIndex, globalThreadIndexOffset);
  }

  /// 3D scheduling. Flattens the index as
  /// threadIndex[0] + 8 * threadIndex[1] + 64 * threadIndex[2] and then defers
  /// to the 1D overload. The strides assume the 8x8x8 domain used by the 3D
  /// tests in this file.
  template <typename OutToInArrayType,
            typename VisitArrayType,
            typename ThreadToOutArrayType,
            typename InputDomainType,
            typename G>
  VTKM_EXEC vtkm::exec::arg::ThreadIndicesBasic GetThreadIndices(
    const vtkm::Id3& threadIndex,
    const OutToInArrayType& outToIn,
    const VisitArrayType& visit,
    const ThreadToOutArrayType& threadToOut,
    const InputDomainType& inputDomain,
    const G& globalThreadIndexOffset) const
  {
    const vtkm::Id flatThreadIndex = vtkm::Dot(threadIndex, vtkm::Id3(1, 8, 64));
    return this->GetThreadIndices(
      flatThreadIndex, outToIn, visit, threadToOut, inputDomain, globalThreadIndexOffset);
  }
};

#define ERROR_MESSAGE "Expected worklet error."

// Not a full worklet, but provides operators that we expect in a worklet.
struct TestWorkletErrorProxy : vtkm::exec::FunctorBase
{
  VTKM_EXEC
  void operator()(vtkm::Id, vtkm::Id) const { this->RaiseError(ERROR_MESSAGE); }

250 251
  template <typename OutToInArrayType,
            typename VisitArrayType,
252
            typename ThreadToOutArrayType,
253
            typename InputDomainType,
254 255
            typename G>
  VTKM_EXEC vtkm::exec::arg::ThreadIndicesBasic GetThreadIndices(
256 257 258
    const vtkm::Id& threadIndex,
    const OutToInArrayType& outToIn,
    const VisitArrayType& visit,
259
    const ThreadToOutArrayType& threadToOut,
260 261
    const InputDomainType&,
    const G& globalThreadIndexOffset) const
262
  {
263
    const vtkm::Id outIndex = threadToOut.Get(threadIndex);
264
    return vtkm::exec::arg::ThreadIndicesBasic(
265
      threadIndex, outToIn.Get(outIndex), visit.Get(outIndex), outIndex, globalThreadIndexOffset);
266 267
  }

268 269
  template <typename OutToInArrayType,
            typename VisitArrayType,
270
            typename ThreadToOutputArrayType,
271
            typename InputDomainType,
272 273
            typename G>
  VTKM_EXEC vtkm::exec::arg::ThreadIndicesBasic GetThreadIndices(
274 275 276
    const vtkm::Id3& threadIndex,
    const OutToInArrayType& outToIn,
    const VisitArrayType& visit,
277
    const ThreadToOutputArrayType& threadToOutput,
278 279
    const InputDomainType&,
    const G& globalThreadIndexOffset) const
280
  {
281
    const vtkm::Id index = vtkm::Dot(threadIndex, vtkm::Id3(1, 8, 64));
282 283 284 285 286 287
    const vtkm::Id outputIndex = threadToOutput.Get(index);
    return vtkm::exec::arg::ThreadIndicesBasic(index,
                                               outToIn.Get(outputIndex),
                                               visit.Get(outputIndex),
                                               outputIndex,
                                               globalThreadIndexOffset);
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494
  }
};

/// Runs TestWorkletProxy through 1D tasks with both execution signatures and
/// checks every output element.
template <typename DeviceAdapter>
void Test1DNormalTaskTilingInvoke()
{
  std::cout << "Testing TaskTiling1D." << std::endl;

  std::vector<vtkm::Id> inputTestValues(100, 5);
  std::vector<vtkm::Id> outputTestValues(100, static_cast<vtkm::Id>(0xDEADDEAD));
  vtkm::internal::FunctionInterface<void(TestExecObject, TestExecObject)> execObjects =
    vtkm::internal::make_FunctionInterface<void>(TestExecObject(inputTestValues),
                                                 TestExecObject(outputTestValues));

  // One alias is shared by both tasks below.
  using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;

  std::cout << "  Try void return." << std::endl;
  TestWorkletProxy worklet;
  InvocationType1 invocation1(execObjects);
  auto task1 = TaskTypes::MakeTask(worklet, invocation1, vtkm::Id());

  // Zero-size buffer: TestWorkletProxy never calls RaiseError.
  vtkm::exec::internal::ErrorMessageBuffer errorMessage(nullptr, 0);
  task1.SetErrorMessageBuffer(errorMessage);

  // Cover [0, 100) with ranges of different lengths.
  task1(0, 90);
  task1(90, 99);
  task1(99, 100); //verify single value ranges work

  // Expected: load adds 10*i, the worklet adds 100, store adds 20*i.
  for (std::size_t i = 0; i < 100; ++i)
  {
    VTKM_TEST_ASSERT(inputTestValues[i] == 5, "Input value changed.");
    VTKM_TEST_ASSERT(outputTestValues[i] ==
                       inputTestValues[i] + 100 + (30 * static_cast<vtkm::Id>(i)),
                     "Output value not set right.");
  }

  std::cout << "  Try return value." << std::endl;
  std::fill(inputTestValues.begin(), inputTestValues.end(), 6);
  std::fill(outputTestValues.begin(), outputTestValues.end(), static_cast<vtkm::Id>(0xDEADDEAD));

  InvocationType2 invocation2(execObjects);
  auto task2 = TaskTypes::MakeTask(worklet, invocation2, vtkm::Id());
  task2.SetErrorMessageBuffer(errorMessage);

  task2(0, 0); //verify zero value ranges work
  task2(0, 90);
  task2(90, 100);

  task2(0, 100); //verify that you can invoke worklets multiple times

  // Expected: load adds 10*i, the worklet adds 200, store adds 20*i.
  for (std::size_t i = 0; i < 100; ++i)
  {
    VTKM_TEST_ASSERT(inputTestValues[i] == 6, "Input value changed.");
    VTKM_TEST_ASSERT(outputTestValues[i] ==
                       inputTestValues[i] + 200 + (30 * static_cast<vtkm::Id>(i)),
                     "Output value not set right.");
  }
}

template <typename DeviceAdapter>
void Test1DErrorTaskTilingInvoke()
{

  std::cout << "Testing TaskTiling1D with an error raised in the worklet." << std::endl;

  std::vector<vtkm::Id> inputTestValues(100, 5);
  std::vector<vtkm::Id> outputTestValues(100, static_cast<vtkm::Id>(0xDEADDEAD));

  TestExecObject arg1(inputTestValues);
  TestExecObject arg2(outputTestValues);

  vtkm::internal::FunctionInterface<void(TestExecObject, TestExecObject)> execObjects =
    vtkm::internal::make_FunctionInterface<void>(arg1, arg2);

  TestWorkletErrorProxy worklet;
  InvocationType1 invocation(execObjects);

  using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;
  auto task = TaskTypes::MakeTask(worklet, invocation, vtkm::Id());

  char message[1024];
  message[0] = '\0';
  vtkm::exec::internal::ErrorMessageBuffer errorMessage(message, 1024);
  task.SetErrorMessageBuffer(errorMessage);

  task(0, 100);

  VTKM_TEST_ASSERT(errorMessage.IsErrorRaised(), "Error not raised correctly.");
  VTKM_TEST_ASSERT(message == std::string(ERROR_MESSAGE), "Got wrong error message.");
}

/// Runs TestWorkletProxy through 3D tasks over an 8x8x8 domain with both
/// execution signatures and checks every output element. The domain size must
/// match the (1, 8, 64) flattening strides in TestWorkletProxy::GetThreadIndices.
template <typename DeviceAdapter>
void Test3DNormalTaskTilingInvoke()
{
  std::cout << "Testing TaskTiling3D." << std::endl;

  std::vector<vtkm::Id> inputTestValues((8 * 8 * 8), 5);
  std::vector<vtkm::Id> outputTestValues((8 * 8 * 8), static_cast<vtkm::Id>(0xDEADDEAD));
  vtkm::internal::FunctionInterface<void(TestExecObject, TestExecObject)> execObjects =
    vtkm::internal::make_FunctionInterface<void>(TestExecObject(inputTestValues),
                                                 TestExecObject(outputTestValues));

  // One alias is shared by both tasks below.
  using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;

  std::cout << "  Try void return." << std::endl;

  TestWorkletProxy worklet;
  InvocationType1 invocation1(execObjects);
  auto task1 = TaskTypes::MakeTask(worklet, invocation1, vtkm::Id3());

  // Each call covers i in [0, 8) for one (j, k) row. Row j + 1 is issued
  // before row j.
  for (vtkm::Id k = 0; k < 8; ++k)
  {
    for (vtkm::Id j = 0; j < 8; j += 2)
    {
      //verify that order is not required
      task1(0, 8, j + 1, k);
      task1(0, 8, j, k);
    }
  }

  // Expected: load adds 10*i, the worklet adds 100, store adds 20*i.
  for (std::size_t i = 0; i < (8 * 8 * 8); ++i)
  {
    VTKM_TEST_ASSERT(inputTestValues[i] == 5, "Input value changed.");
    VTKM_TEST_ASSERT(outputTestValues[i] ==
                       inputTestValues[i] + 100 + (30 * static_cast<vtkm::Id>(i)),
                     "Output value not set right.");
  }

  std::cout << "  Try return value." << std::endl;
  std::fill(inputTestValues.begin(), inputTestValues.end(), 6);
  std::fill(outputTestValues.begin(), outputTestValues.end(), static_cast<vtkm::Id>(0xDEADDEAD));

  InvocationType2 invocation2(execObjects);
  auto task2 = TaskTypes::MakeTask(worklet, invocation2, vtkm::Id3());

  //verify that linear order of values being processed is not presumed
  // Each call covers a single i. With i outermost, consecutive calls jump
  // through the flattened index space.
  for (vtkm::Id i = 0; i < 8; ++i)
  {
    for (vtkm::Id j = 0; j < 8; ++j)
    {
      for (vtkm::Id k = 0; k < 8; ++k)
      {
        task2(i, i + 1, j, k);
      }
    }
  }

  // Expected: load adds 10*i, the worklet adds 200, store adds 20*i.
  for (std::size_t i = 0; i < (8 * 8 * 8); ++i)
  {
    VTKM_TEST_ASSERT(inputTestValues[i] == 6, "Input value changed.");
    VTKM_TEST_ASSERT(outputTestValues[i] ==
                       inputTestValues[i] + 200 + (30 * static_cast<vtkm::Id>(i)),
                     "Output value not set right.");
  }
}

template <typename DeviceAdapter>
void Test3DErrorTaskTilingInvoke()
{
  std::cout << "Testing TaskTiling3D with an error raised in the worklet." << std::endl;

  std::vector<vtkm::Id> inputTestValues((8 * 8 * 8), 5);
  std::vector<vtkm::Id> outputTestValues((8 * 8 * 8), static_cast<vtkm::Id>(0xDEADDEAD));
  vtkm::internal::FunctionInterface<void(TestExecObject, TestExecObject)> execObjects =
    vtkm::internal::make_FunctionInterface<void>(TestExecObject(inputTestValues),
                                                 TestExecObject(outputTestValues));

  TestWorkletErrorProxy worklet;
  InvocationType1 invocation(execObjects);

  using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;
  auto task1 = TaskTypes::MakeTask(worklet, invocation, vtkm::Id3());

  char message[1024];
  message[0] = '\0';
  vtkm::exec::internal::ErrorMessageBuffer errorMessage(message, 1024);
  task1.SetErrorMessageBuffer(errorMessage);

  for (vtkm::Id k = 0; k < 8; ++k)
  {
    for (vtkm::Id j = 0; j < 8; ++j)
    {
      task1(0, 8, j, k);
    }
  }

  VTKM_TEST_ASSERT(errorMessage.IsErrorRaised(), "Error not raised correctly.");
  VTKM_TEST_ASSERT(message == std::string(ERROR_MESSAGE), "Got wrong error message.");
}

/// Entry point for the TaskTiling tests on one device adapter. Runs the 1D
/// checks first, then the 3D checks; within each, normal execution is tested
/// before error propagation.
template <typename DeviceAdapter>
void TestTaskTiling()
{
  // 1D tiling: worklet results, then a worklet that raises an error.
  Test1DNormalTaskTilingInvoke<DeviceAdapter>();
  Test1DErrorTaskTilingInvoke<DeviceAdapter>();

  // 3D tiling: worklet results, then a worklet that raises an error.
  Test3DNormalTaskTilingInvoke<DeviceAdapter>();
  Test3DErrorTaskTilingInvoke<DeviceAdapter>();
}
}
}
}
}