[mlir][python] Fix MemRefType IsAFunction in Python bindings
authorAlex Zinenko <zinenko@google.com>
Thu, 14 Oct 2021 09:33:28 +0000 (11:33 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 14 Oct 2021 11:12:37 +0000 (13:12 +0200)
MemRefType was using a wrong `isa` function in the bindings code, which
could lead to invalid IR being constructed. Also run the verifier in
memref dialect tests.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111784

mlir/lib/Bindings/Python/IRTypes.cpp
mlir/python/mlir/dialects/_memref_ops_ext.py
mlir/test/python/dialects/memref.py

index 568cca160a5951465d1e9d69658bb77f1f9d9986..fd9f3efe7405f8f075294074b7d3e7db580a96f7 100644 (file)
@@ -406,7 +406,7 @@ class PyMemRefLayoutMapList;
 /// Ranked MemRef Type subclass - MemRefType.
 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 
index cb25ef105d73f7ab90c776afc10c666f575fd140..9cc22a21c6283fb3276ea7f5dd252d87d428ba3a 100644 (file)
@@ -33,5 +33,5 @@ class LoadOp:
     memref_resolved = _get_op_result_or_value(memref)
     indices_resolved = [] if indices is None else _get_op_results_or_values(
         indices)
-    return_type = memref_resolved.type
+    return_type = MemRefType(memref_resolved.type).element_type
     super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
index e421f9b2fde953f0177d13f4ef10f545de9ae40d..f2eda0a620610f1ae60db20cb670729c8211b149 100644 (file)
@@ -71,3 +71,4 @@ def testCustomBuidlers():
     # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
     # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
     print(module)
+    assert module.operation.verify()