* Matches how all of the other shaped types are declared.
* No super principled reason fro this ordering beyond that it makes the one that was different be like the rest.
* Also matches ordering of things like ndarray, et al.
Reviewed By: ftynse, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D94812
def build_matmul_buffers_func(func_name, m, k, n, dtype):
- lhs_type = MemRefType.get(dtype, [m, k])
- rhs_type = MemRefType.get(dtype, [k, n])
- result_type = MemRefType.get(dtype, [m, n])
+ lhs_type = MemRefType.get([m, k], dtype)
+ rhs_type = MemRefType.get([k, n], dtype)
+ result_type = MemRefType.get([m, n], dtype)
# TODO: There should be a one-liner for this.
func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
_, entry = FuncOp(func_name, func_type)
def build_matmul_tensors_func(func_name, m, k, n, dtype):
- # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
- # from each other.
lhs_type = RankedTensorType.get([m, k], dtype)
rhs_type = RankedTensorType.get([k, n], dtype)
result_type = RankedTensorType.get([m, n], dtype)
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](PyType &elementType, std::vector<int64_t> shape,
+ [](std::vector<int64_t> shape, PyType &elementType,
std::vector<PyAffineMap> layout, unsigned memorySpace,
DefaultingPyLocation loc) {
SmallVector<MlirAffineMap> maps;
}
return PyMemRefType(elementType.getContext(), t);
},
- py::arg("element_type"), py::arg("shape"),
+ py::arg("shape"), py::arg("element_type"),
py::arg("layout") = py::list(), py::arg("memory_space") = 0,
py::arg("loc") = py::none(), "Create a memref type")
.def_property_readonly("layout", &PyMemRefType::getLayout,
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
- memref = MemRefType.get(f32, shape, memory_space=2)
+ memref = MemRefType.get(shape, f32, memory_space=2)
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
# CHECK: number of affine layout maps: 0
print("memory space:", memref.memory_space)
layout = AffineMap.get_permutation([1, 0])
- memref_layout = MemRefType.get(f32, shape, [layout])
+ memref_layout = MemRefType.get(shape, f32, [layout])
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
print("memref type:", memref_layout)
assert len(memref_layout.layout) == 1
none = NoneType.get()
try:
- memref_invalid = MemRefType.get(none, shape)
+ memref_invalid = MemRefType.get(shape, none)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.