Support retrieving the splat value from DenseElementsAttrs in Python
authorAdam Paszke <apaszke@google.com>
Tue, 21 Mar 2023 15:26:06 +0000 (08:26 -0700)
committerJacques Pienaar <jpienaar@google.com>
Tue, 21 Mar 2023 15:43:17 +0000 (08:43 -0700)
This is especially convenient when trying to resize the splat.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D146510

mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/array_attributes.py

index c59a54b..40598ec 100644 (file)
@@ -777,6 +777,16 @@ public:
                                [](PyDenseElementsAttribute &self) -> bool {
                                  return mlirDenseElementsAttrIsSplat(self);
                                })
+        .def("get_splat_value",
+             [](PyDenseElementsAttribute &self) -> PyAttribute {
+               if (!mlirDenseElementsAttrIsSplat(self)) {
+                 throw SetPyError(
+                     PyExc_ValueError,
+                     "get_splat_value called on a non-splat attribute");
+               }
+               return PyAttribute(self.getContext(),
+                                  mlirDenseElementsAttrGetSplatValue(self));
+             })
         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
 
index b618802..c1f1633 100644 (file)
@@ -43,6 +43,7 @@ def testGetDenseElementsSplatInt():
     print(attr)
     # CHECK: is_splat: True
     print("is_splat:", attr.is_splat)
+    assert attr.get_splat_value() == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@@ -55,6 +56,7 @@ def testGetDenseElementsSplatFloat():
     attr = DenseElementsAttr.get_splat(shaped_type, element)
     # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
     print(attr)
+    assert attr.get_splat_value() == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors