From 3097720b69aca1d4c2e1434ee1e9827ce2e39ed6 Mon Sep 17 00:00:00 2001
From: Julien Fausty <julien.fausty@kitware.com>
Date: Wed, 21 Jun 2023 15:23:10 +0200
Subject: [PATCH] vtkMPICommunicator: add blocking probe methods

---
 Parallel/MPI/vtkMPICommunicator.cxx | 150 ++++++++++++++++++++++++++++
 Parallel/MPI/vtkMPICommunicator.h   |  28 ++++++
 2 files changed, 178 insertions(+)

diff --git a/Parallel/MPI/vtkMPICommunicator.cxx b/Parallel/MPI/vtkMPICommunicator.cxx
index f4439eea9d2..d97ac1b32fe 100644
--- a/Parallel/MPI/vtkMPICommunicator.cxx
+++ b/Parallel/MPI/vtkMPICommunicator.cxx
@@ -487,6 +487,70 @@ int vtkMPICommunicatorIprobe(int source, int tag, int* flag, int* actualSource,
   return retVal;
 }
 
+//------------------------------------------------------------------------------
+int vtkMPICommunicatorProbe(
+  int source, int tag, int* actualSource, MPI_Datatype datatype, int* size, MPI_Comm* handle)
+{
+  if (source == vtkMultiProcessController::ANY_SOURCE)
+  {
+    source = MPI_ANY_SOURCE;
+  }
+  MPI_Status status;
+  int retVal = MPI_Probe(source, tag, *handle, &status);
+  if (retVal == MPI_SUCCESS)
+  {
+    if (actualSource)
+    {
+      *actualSource = status.MPI_SOURCE;
+    }
+    if (size)
+    {
+      return MPI_Get_count(&status, datatype, size);
+    }
+  }
+  return retVal;
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicatorProbe(int source, int tag, int* actualSource, MPI_Datatype datatype,
+  vtkTypeInt64* size, MPI_Comm* handle)
+{
+  if (source == vtkMultiProcessController::ANY_SOURCE)
+  {
+    source = MPI_ANY_SOURCE;
+  }
+  MPI_Status status;
+  int retVal = MPI_Probe(source, tag, *handle, &status);
+  if (retVal == MPI_SUCCESS)
+  {
+    if (actualSource)
+    {
+      *actualSource = status.MPI_SOURCE;
+    }
+    if (size)
+    {
+#ifdef VTKMPI_64BIT_LENGTH
+      MPI_Count countSize = 0;
+      retVal = MPI_Get_count_c(&status, datatype, &countSize);
+      if (retVal == MPI_SUCCESS)
+      {
+        *size = countSize;
+      }
+      return retVal;
+#else
+      int intSize = 0;
+      retVal = MPI_Get_count(&status, datatype, &intSize);
+      if (retVal == MPI_SUCCESS)
+      {
+        *size = intSize;
+      }
+      return retVal;
+#endif
+    }
+  }
+  return retVal;
+}
+
 //------------------------------------------------------------------------------
 // Method for converting an MPI operation to a
 // vtkMultiProcessController::Operation.
@@ -2022,4 +2086,90 @@ int vtkMPICommunicator::Iprobe(
     source, tag, flag, actualSource, MPI_DOUBLE, size, this->MPIComm->Handle));
 }
 
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(int source, int tag, int* actualSource)
+{
+  return CheckForMPIError(vtkMPICommunicatorProbe(
+    source, tag, actualSource, MPI_INT, (vtkIdType*)nullptr, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, int* vtkNotUsed(type), int* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_INT, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, unsigned long* vtkNotUsed(type), int* size)
+{
+  return CheckForMPIError(vtkMPICommunicatorProbe(
+    source, tag, actualSource, MPI_UNSIGNED_LONG, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, const char* vtkNotUsed(type), int* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_CHAR, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, float* vtkNotUsed(type), int* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_FLOAT, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, double* vtkNotUsed(type), int* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_DOUBLE, size, this->MPIComm->Handle));
+}
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, int* vtkNotUsed(type), vtkTypeInt64* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_INT, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, unsigned long* vtkNotUsed(type), vtkTypeInt64* size)
+{
+  return CheckForMPIError(vtkMPICommunicatorProbe(
+    source, tag, actualSource, MPI_UNSIGNED_LONG, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, const char* vtkNotUsed(type), vtkTypeInt64* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_CHAR, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, float* vtkNotUsed(type), vtkTypeInt64* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_FLOAT, size, this->MPIComm->Handle));
+}
+
+//------------------------------------------------------------------------------
+int vtkMPICommunicator::Probe(
+  int source, int tag, int* actualSource, double* vtkNotUsed(type), vtkTypeInt64* size)
+{
+  return CheckForMPIError(
+    vtkMPICommunicatorProbe(source, tag, actualSource, MPI_DOUBLE, size, this->MPIComm->Handle));
+}
+
 VTK_ABI_NAMESPACE_END
diff --git a/Parallel/MPI/vtkMPICommunicator.h b/Parallel/MPI/vtkMPICommunicator.h
index d47b973f8b9..999667cc973 100644
--- a/Parallel/MPI/vtkMPICommunicator.h
+++ b/Parallel/MPI/vtkMPICommunicator.h
@@ -227,6 +227,34 @@ public:
   int Iprobe(int source, int tag, int* flag, int* actualSource, double* type, vtkTypeInt64* size);
   ///@}
 
+  /**
+   * Check if this communicator implements a probe operation (always true for MPI communicator)
+   */
+  bool CanProbe() override { return true; };
+
+  ///@{
+  /**
+   * Blocking test for a message.  Inputs are: source -- the source rank
+   * or ANY_SOURCE; tag -- the tag value.  Outputs are:
+   * actualSource -- the rank sending the message (useful if ANY_SOURCE is used)
+   * if actualSource isn't nullptr; size -- the length of the message in
+   * bytes if flag is true (only set if size isn't nullptr). The return
+   * value is 1 for success and 0 otherwise.
+   */
+  int Probe(int source, int tag, int* actualSource) override;
+  int Probe(int source, int tag, int* actualSource, int* type, int* size);
+  int Probe(int source, int tag, int* actualSource, unsigned long* type, int* size);
+  int Probe(int source, int tag, int* actualSource, const char* type, int* size);
+  int Probe(int source, int tag, int* actualSource, float* type, int* size);
+  int Probe(int source, int tag, int* actualSource, double* type, int* size);
+
+  int Probe(int source, int tag, int* actualSource, int* type, vtkTypeInt64* size);
+  int Probe(int source, int tag, int* actualSource, unsigned long* type, vtkTypeInt64* size);
+  int Probe(int source, int tag, int* actualSource, const char* type, vtkTypeInt64* size);
+  int Probe(int source, int tag, int* actualSource, float* type, vtkTypeInt64* size);
+  int Probe(int source, int tag, int* actualSource, double* type, vtkTypeInt64* size);
+  ///@}
+
   /**
    * Given the request objects of a set of non-blocking operations
    * (send and/or receive) this method blocks until all requests are complete.
-- 
GitLab