From f345f7e30bd3a8e15052f5669c1977aa088e468f Mon Sep 17 00:00:00 2001
From: gysit
Date: Tue, 8 Mar 2022 17:30:06 +0000
Subject: [PATCH] [mlir][OpDSL] Support pointwise ops with rank zero inputs.

Allow pointwise operations to take rank zero input tensors similarly to
scalar inputs. Use an empty indexing map to broadcast rank zero tensors
to the iteration domain of the operation.

Depends On D120734

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120807
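
For illustration, generalizing an elementwise op with rank zero inputs
produces IR of roughly the following shape, where the rank zero operands
are indexed by the empty affine map (a sketch mirroring the new test
below; the map numbering and value names are illustrative):

  #map0 = affine_map<(d0, d1) -> ()>
  #map1 = affine_map<(d0, d1) -> (d0, d1)>
  func @elemwise_rank_zero(%lhs: tensor<f32>, %rhs: tensor<f32>,
                           %output: tensor<4x8xf32>) -> tensor<4x8xf32> {
    // The rank zero inputs use #map0 and are broadcast to the 4x8 domain.
    %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1],
                         iterator_types = ["parallel", "parallel"]}
        ins(%lhs, %rhs : tensor<f32>, tensor<f32>)
        outs(%output : tensor<4x8xf32>) {
      ^bb0(%a: f32, %b: f32, %c: f32):
        %1 = arith.subf %a, %b : f32
        linalg.yield %1 : f32
    } -> tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }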
---
 mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py     |  6 +++++-
 .../Dialect/Linalg/generalize-named-polymorphic-ops.mlir   | 15 +++++++++++++++
 mlir/test/python/dialects/linalg/opdsl/emit_fill.py        | 13 +++++++++++++
 mlir/test/python/integration/dialects/linalg/opsrun.py     | 16 ++++++++--------
 .../mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp       |  2 +-
 5 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index ff5c405..93baef1 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -187,7 +187,11 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
     if arg_def.operand_def.kind == OperandKind.SCALAR:
       indexing_maps.append(scalar_map)
     if arg_def.operand_def.is_tensor():
-      indexing_maps.append(tensor_map)
+      idx = arg_def.operand_def.registered_index
+      if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+        indexing_maps.append(scalar_map)
+      else:
+        indexing_maps.append(tensor_map)
 
   indexing_maps_attr = ArrayAttr.get(
       [AffineMapAttr.get(am) for am in indexing_maps])
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 425eeb8..0c98629 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -320,3 +320,18 @@ func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %o
 
 // CHECK-LABEL: @generalize_elemwise_mul
 // CHECK: = arith.mulf
+
+// -----
+
+// Verifies pointwise ops support rank zero input tensors
+func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
+                              ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
+                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_rank_zero
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK: = arith.subf
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
index 814a6d2..55ca50b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -15,6 +15,9 @@ T2 = TV.T2
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   O[None] = TypeFn.cast_signed(U, value)
 
+@linalg_structured_op
+def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast_signed(U, I[None])
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -25,6 +28,8 @@
   # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
   # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
   # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+  # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+  # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
   # CHECK-LABEL: @test_fill_0d
   # CHECK: linalg.generic
@@ -42,5 +47,13 @@
   def test_fill_2d(value, init_result):
     return fill_poly(value, outs=[init_result])
 
+  # CHECK-LABEL: @test_fill_rank_zero_3d
+  # CHECK: linalg.generic
+  # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+  # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
+  def test_fill_rank_zero_3d(input, init_result):
+    return fill_rank_zero_poly(input, outs=[init_result])
 
 print(module)
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 458780d..279af2b 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,19 +25,19 @@ func @main() -> f32 attributes {llvm.emit_c_interface} {
   %v1 = arith.constant 1.0 : f32
   %v2 = arith.constant 2.0 : f32
-  %lhs = memref.alloc() : memref<4x8xf32>
+  %lhs = memref.alloc() : memref<f32>
   %rhs = memref.alloc() : memref<4x8xf32>
   %O0 = memref.alloc() : memref<4x8xf32>
   %O1 = memref.alloc() : memref<4x8xf32>
 
-  linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+  linalg.fill(%v1, %lhs) : f32, memref<f32>
   linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
 
   call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
   call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
@@ -212,14 +212,14 @@ def test_elemwise_builtin():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_exp_add_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(lhs, outs=[out])
        linalg.elemwise_binary(out, rhs, outs=[out])
 
      @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
      def elemwise_log_mul_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
@@ -251,14 +251,14 @@ def test_elemwise_generic():
     with InsertionPoint(module.body):
 
      @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
      def elemwise_exp_add_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
        linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
 
      @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
      def elemwise_log_mul_on_buffers(lhs, rhs, out):
        linalg.elemwise_unary(
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 5cade2a..a696360 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -678,7 +678,7 @@ ArrayAttr {0}::indexing_maps() {{
       getNumParallelLoops(), context);
   SmallVector<AffineMap> indexingMaps;
   for (OpOperand *opOperand : getInputAndOutputOperands())
-    indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
+    indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
   return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
 }
 )FMT";
-- 
2.7.4