From 9dc351782b5c29f23175dce7933e08440068b1b1 Mon Sep 17 00:00:00 2001
From: Dan Lipsa <dan.lipsa@kitware.com>
Date: Fri, 31 Jan 2025 18:29:33 -0500
Subject: [PATCH] Move constants at class level and refactor _update_accessor

---
 .../Python/vtkmodules/util/xarray_support.py  | 139 +++++++++---------
 1 file changed, 73 insertions(+), 66 deletions(-)

diff --git a/Wrapping/Python/vtkmodules/util/xarray_support.py b/Wrapping/Python/vtkmodules/util/xarray_support.py
index d0e16bad204..8613df31f24 100644
--- a/Wrapping/Python/vtkmodules/util/xarray_support.py
+++ b/Wrapping/Python/vtkmodules/util/xarray_support.py
@@ -39,9 +39,71 @@ class VtkAccessor:
 
 
 class vtkXArrayCFReader(VTKPythonAlgorithmBase):
+    '''Reads data from a file using the XArray readers and then connects
+    the XArray data to the vtkNetCDFCFREader (using zero-copy when
+    possible). At the moment, data is copied for coordinates (because
+    they are converted to double in the reader) and for certain data
+    that is subset either in XArray or in VTK.  Lazy loading in XArray
+    is respected, that is data is accessed only when it is needed.
+    Time is passed to VTK either as an int64 for datetime64 or
+    timedelta64, or as a double (using cftime.toordinal) for cftime.
     '''
 
-    '''
+    _FORWARD_GET = {
+        "CanReadFile",
+        "GetAccessor",
+        "GetAllDimensions",
+
+        "GetNumberOfVariableArrays",
+        "GetAllVariableArrayNames",
+        "GetVariableArrayName",
+        "GetVariableArrayStatus",
+
+        "GetTimeDimensionName",
+        "GetLatitudeDimensionName",
+        "GetLongitudeDimensionName",
+        "GetVerticalDimensionName",
+
+        "GetOutput",
+        "GetOutputType",
+        "GetSphericalCoordinates",
+
+        "GetReplaceFillValueWithNan",
+
+        "GetVariableDimensions",
+        "GetVerticalBias",
+        "GetVerticalScale",
+        "PrintSelf",
+    }
+    _FORWARD_SET = {
+        "SetDimensions",
+
+        "SetTimeDimensionName",
+        "SetLatitudeDimensionName",
+        "SetLongitudeDimensionName",
+        "SetVerticalDimensionName",
+
+
+        "SetSphericalCoordinates",
+        "SphericalCoordinatesOn",
+        "SphericalCoordinatesOff",
+
+        "SetReplaceFillValueWithNan",
+        "ReplaceFillValueWithNanOn",
+        "ReplaceFillValueWithNanOff",
+
+        "SetOutputType",
+        "SetOutputTypeToAutomatic",
+        "SetOutputTypeToImage",
+        "SetOutputTypeToRectilinear",
+        "SetOutputTypeToStructured",
+        "SetOutputTypeToUnstructured",
+
+        "SetVariableArrayStatus",
+        "SetVerticalBias",
+        "SetVerticalScale",
+        "UpdateMetaData",
+    }
 
     def __init__(self):
         VTKPythonAlgorithmBase.__init__(
@@ -59,65 +121,10 @@ class vtkXArrayCFReader(VTKPythonAlgorithmBase):
         self._ndarray_cftime_toordinal = np.frompyfunc(vtkXArrayCFReader._cftime_toordinal, 1, 1)
         # reference to contiguous arrays so that they are not dealocated
         self._arrays = {}
-        self._forward_get = {
-            "CanReadFile",
-            "GetAccessor",
-            "GetAllDimensions",
-
-            "GetNumberOfVariableArrays",
-            "GetAllVariableArrayNames",
-            "GetVariableArrayName",
-            "GetVariableArrayStatus",
-
-            "GetTimeDimensionName",
-            "GetLatitudeDimensionName",
-            "GetLongitudeDimensionName",
-            "GetVerticalDimensionName",
-
-            "GetOutput",
-            "GetOutputType",
-            "GetSphericalCoordinates",
-
-            "GetReplaceFillValueWithNan",
-
-            "GetVariableDimensions",
-            "GetVerticalBias",
-            "GetVerticalScale",
-            "PrintSelf",
-        }
-        self._forward_set = {
-            "SetDimensions",
-
-            "SetTimeDimensionName",
-            "SetLatitudeDimensionName",
-            "SetLongitudeDimensionName",
-            "SetVerticalDimensionName",
-
-
-            "SetSphericalCoordinates",
-            "SphericalCoordinatesOn",
-            "SphericalCoordinatesOff",
-
-            "SetReplaceFillValueWithNan",
-            "ReplaceFillValueWithNanOn",
-            "ReplaceFillValueWithNanOff",
-
-            "SetOutputType",
-            "SetOutputTypeToAutomatic",
-            "SetOutputTypeToImage",
-            "SetOutputTypeToRectilinear",
-            "SetOutputTypeToStructured",
-            "SetOutputTypeToUnstructured",
-
-            "SetVariableArrayStatus",
-            "SetVerticalBias",
-            "SetVerticalScale",
-            "UpdateMetaData",
-        }
 
     def __getattr__(self, name):
-        in_set = name in self._forward_set
-        in_get = name in self._forward_get
+        in_set = name in self._FORWARD_SET
+        in_get = name in self._FORWARD_GET
         if in_set or in_get:
             if in_set:
                 self.Modified()
@@ -145,10 +152,7 @@ class vtkXArrayCFReader(VTKPythonAlgorithmBase):
 
     def SetXArray(self, dsxr):
         self._dsxr = dsxr
-        accessor, timename = self._get_accessor()
-        self._reader.SetAccessor(accessor)
-        if timename:
-            self._reader.SetTimeDimensionName(timename)
+        self._update_accessor()
         self.Modified()
 
     def GetXArray(self):
@@ -162,10 +166,7 @@ class vtkXArrayCFReader(VTKPythonAlgorithmBase):
                 self._dsxr = tree[self._node].to_dataset()
             else:
                 self._dsxr = xr.open_dataset(self._filename)
-            accessor, timename = self._get_accessor()
-            self._reader.SetAccessor(accessor)
-            if timename:
-                self._reader.SetTimeDimensionName(timename)
+            self._update_accessor()
         self._reader.UpdateDataObject()
         roi = self._reader.GetOutputInformation(0)
         if roi.Has(vtkDataObject.DATA_OBJECT()):
@@ -273,6 +274,12 @@ class vtkXArrayCFReader(VTKPythonAlgorithmBase):
             "Could not find a suitable NetCDF type for %s" % (str(numpy_array_type))
         )
 
+    def _update_accessor(self):
+        accessor, timename = self._get_accessor()
+        self._reader.SetAccessor(accessor)
+        if timename:
+            self._reader.SetTimeDimensionName(timename)
+
     def _get_accessor(self):
         acclog = logging.getLogger("_get_accessor_")
         acclog.setLevel(logging.WARNING)
-- 
GitLab