[mlir][OpDSL] Support pointwise ops with rank zero inputs.
authorgysit <gysit@google.com>
Tue, 8 Mar 2022 17:30:06 +0000 (17:30 +0000)
committergysit <gysit@google.com>
Tue, 8 Mar 2022 17:39:47 +0000 (17:39 +0000)
Allow pointwise operations to take rank zero input tensors similarly to scalar inputs. Use an empty indexing map to broadcast rank zero tensors to the iteration domain of the operation.

Depends On D120734

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120807

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_fill.py
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index ff5c405..93baef1 100644 (file)
@@ -187,7 +187,11 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
       if arg_def.operand_def.kind == OperandKind.SCALAR:
         indexing_maps.append(scalar_map)
       if arg_def.operand_def.is_tensor():
-        indexing_maps.append(tensor_map)
+        idx = arg_def.operand_def.registered_index
+        if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+          indexing_maps.append(scalar_map)
+        else:
+          indexing_maps.append(tensor_map)
     indexing_maps_attr = ArrayAttr.get(
         [AffineMapAttr.get(am) for am in indexing_maps])
 
index 425eeb8..0c98629 100644 (file)
@@ -320,3 +320,18 @@ func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %o
 
 // CHECK-LABEL: @generalize_elemwise_mul
 // CHECK:        = arith.mulf
+
+// -----
+
+// Verifies pointwise ops support rank zero input tensors
+func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
+                              ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
+                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_rank_zero
+// CHECK:       linalg.generic
+// CHECK-SAME:  iterator_types = ["parallel", "parallel"]
+// CHECK:        = arith.subf
index 814a6d2..55ca50b 100644 (file)
@@ -15,6 +15,9 @@ T2 = TV.T2
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   O[None] = TypeFn.cast_signed(U, value)
 
+@linalg_structured_op
+def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast_signed(U, I[None])
 
 with Context() as ctx, Location.unknown():
   module = Module.create()
@@ -25,6 +28,8 @@ with Context() as ctx, Location.unknown():
     # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
     # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
     # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+    # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+    # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
     # CHECK-LABEL: @test_fill_0d
     # CHECK: linalg.generic
@@ -42,5 +47,13 @@ with Context() as ctx, Location.unknown():
     def test_fill_2d(value, init_result):
       return fill_poly(value, outs=[init_result])
 
+    # CHECK-LABEL: @test_fill_rank_zero_3d
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
+    def test_fill_rank_zero_3d(input, init_result):
+      return fill_rank_zero_poly(input, outs=[init_result])
 
 print(module)
index 458780d..279af2b 100644 (file)
@@ -25,19 +25,19 @@ func @main() -> f32 attributes {llvm.emit_c_interface} {
   %v1 = arith.constant 1.0 : f32
   %v2 = arith.constant 2.0 : f32
 
-  %lhs = memref.alloc() : memref<4x8xf32>
+  %lhs = memref.alloc() : memref<f32>
   %rhs = memref.alloc() : memref<4x8xf32>
   %O0 = memref.alloc() : memref<4x8xf32>
   %O1 = memref.alloc() : memref<4x8xf32>
-  linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
+  linalg.fill(%v1, %lhs) : f32, memref<f32>
   linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
   linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
 
   call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
   call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
-    (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
+    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
@@ -212,14 +212,14 @@ def test_elemwise_builtin():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_exp_add_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(lhs, outs=[out])
         linalg.elemwise_binary(out, rhs, outs=[out])
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_log_mul_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
@@ -251,14 +251,14 @@ def test_elemwise_generic():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_exp_add_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
         linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
 
       @builtin.FuncOp.from_py_func(
-          MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
+          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
           MemRefType.get((4, 8), f32))
       def elemwise_log_mul_on_buffers(lhs, rhs, out):
         linalg.elemwise_unary(
index 5cade2a..a696360 100644 (file)
@@ -678,7 +678,7 @@ ArrayAttr {0}::indexing_maps() {{
     getNumParallelLoops(), context);
   SmallVector<AffineMap> indexingMaps;
   for (OpOperand *opOperand : getInputAndOutputOperands())
-    indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
+    indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
   return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
 }
 )FMT";