[mlir][python] expose the shape property of shaped types
authorAlex Zinenko <zinenko@google.com>
Tue, 2 Nov 2021 15:44:37 +0000 (16:44 +0100)
committerAlex Zinenko <zinenko@google.com>
Wed, 3 Nov 2021 09:49:12 +0000 (10:49 +0100)
This has been missing in the original definition of shaped types.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D113025

mlir/lib/Bindings/Python/IRTypes.cpp
mlir/test/python/ir/builtin_types.py

index 1cfd799..89fdb1f 100644 (file)
@@ -284,6 +284,19 @@ public:
         },
         "Returns whether the given value is used as a placeholder for dynamic "
         "strides and offsets in shaped types.");
+    c.def_property_readonly(
+        "shape",
+        [](PyShapedType &self) {
+          self.requireHasRank();
+
+          std::vector<int64_t> shape;
+          int64_t rank = mlirShapedTypeGetRank(self);
+          shape.reserve(rank);
+          for (int64_t i = 0; i < rank; ++i)
+            shape.push_back(mlirShapedTypeGetDimSize(self, i));
+          return shape;
+        },
+        "Returns the shape of the ranked shaped type as a list of integers.");
   }
 
 private:
index c5b32e8..7d881b9 100644 (file)
@@ -315,6 +315,9 @@ def testRankedTensorType():
     # Encoding should be None.
     assert RankedTensorType.get(shape, f32).encoding is None
 
+    tensor = RankedTensorType.get(shape, f32)
+    assert tensor.shape == shape
+
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run
@@ -396,6 +399,8 @@ def testMemRefType():
     else:
       print("Exception not produced")
 
+    assert memref.shape == shape
+
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
 @run