[mlir][python] Swap shape and element_type order for MemRefType.
authorStella Laurenzo <stellaraccident@gmail.com>
Wed, 20 Jan 2021 00:02:02 +0000 (16:02 -0800)
committerStella Laurenzo <stellaraccident@gmail.com>
Wed, 20 Jan 2021 00:03:19 +0000 (16:03 -0800)
* Matches how all of the other shaped types are declared.
* No super principled reason fro this ordering beyond that it makes the one that was different be like the rest.
* Also matches ordering of things like ndarray, et al.

Reviewed By: ftynse, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94812

mlir/examples/python/linalg_matmul.py
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_types.py

index e9be189..0bd3c12 100644 (file)
@@ -31,9 +31,9 @@ def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
 
 
 def build_matmul_buffers_func(func_name, m, k, n, dtype):
-  lhs_type = MemRefType.get(dtype, [m, k])
-  rhs_type = MemRefType.get(dtype, [k, n])
-  result_type = MemRefType.get(dtype, [m, n])
+  lhs_type = MemRefType.get([m, k], dtype)
+  rhs_type = MemRefType.get([k, n], dtype)
+  result_type = MemRefType.get([m, n], dtype)
   # TODO: There should be a one-liner for this.
   func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
   _, entry = FuncOp(func_name, func_type)
@@ -49,8 +49,6 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype):
 
 
 def build_matmul_tensors_func(func_name, m, k, n, dtype):
-  # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
-  # from each other.
   lhs_type = RankedTensorType.get([m, k], dtype)
   rhs_type = RankedTensorType.get([k, n], dtype)
   result_type = RankedTensorType.get([m, n], dtype)
index 63bdd0c..3c9f79e 100644 (file)
@@ -2832,7 +2832,7 @@ public:
   static void bindDerived(ClassTy &c) {
     c.def_static(
          "get",
-         [](PyType &elementType, std::vector<int64_t> shape,
+         [](std::vector<int64_t> shape, PyType &elementType,
             std::vector<PyAffineMap> layout, unsigned memorySpace,
             DefaultingPyLocation loc) {
            SmallVector<MlirAffineMap> maps;
@@ -2856,7 +2856,7 @@ public:
            }
            return PyMemRefType(elementType.getContext(), t);
          },
-         py::arg("element_type"), py::arg("shape"),
+         py::arg("shape"), py::arg("element_type"),
          py::arg("layout") = py::list(), py::arg("memory_space") = 0,
          py::arg("loc") = py::none(), "Create a memref type")
         .def_property_readonly("layout", &PyMemRefType::getLayout,
index 64b684e..7402c64 100644 (file)
@@ -326,7 +326,7 @@ def testMemRefType():
     f32 = F32Type.get()
     shape = [2, 3]
     loc = Location.unknown()
-    memref = MemRefType.get(f32, shape, memory_space=2)
+    memref = MemRefType.get(shape, f32, memory_space=2)
     # CHECK: memref type: memref<2x3xf32, 2>
     print("memref type:", memref)
     # CHECK: number of affine layout maps: 0
@@ -335,7 +335,7 @@ def testMemRefType():
     print("memory space:", memref.memory_space)
 
     layout = AffineMap.get_permutation([1, 0])
-    memref_layout = MemRefType.get(f32, shape, [layout])
+    memref_layout = MemRefType.get(shape, f32, [layout])
     # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
     print("memref type:", memref_layout)
     assert len(memref_layout.layout) == 1
@@ -346,7 +346,7 @@ def testMemRefType():
 
     none = NoneType.get()
     try:
-      memref_invalid = MemRefType.get(none, shape)
+      memref_invalid = MemRefType.get(shape, none)
     except ValueError as e:
       # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
       # CHECK: or complex type.