[mlir][Python][Linalg] Add support for captures in body builder.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 13 Apr 2021 06:25:47 +0000 (06:25 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 16 Apr 2021 08:47:26 +0000 (08:47 +0000)
When Linalg named ops support was added, captures were omitted
from the body builder. This revision adds support for captures
which allows us to write FillOp in a more idiomatic fashion using
the _linalg_ops_ext mixin support.

This raises an issue in the generation of `_linalg_ops_gen.py` where
```
  @property
  def result(self):
    return self.operation.results[0] if len(self.operation.results) > 1 else None
```.
The condition should be `== 1`.

This will be fixed in a separate commit.

Differential Revision: https://reviews.llvm.org/D100363

mlir/include/mlir-c/Dialect/Linalg.h
mlir/lib/Bindings/Python/DialectLinalg.cpp
mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py
mlir/lib/CAPI/Dialect/Linalg.cpp
mlir/test/Bindings/Python/dialects/linalg/ops.py

index 06f15f0..6e20eec 100644 (file)
@@ -1,11 +1,11 @@
-//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect --------*- C -*-===//
+//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect -------*- C -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
 // Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===----------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
 
 #ifndef MLIR_C_DIALECT_LINALG_H
 #define MLIR_C_DIALECT_LINALG_H
@@ -18,9 +18,11 @@ extern "C" {
 #endif
 
 /// Apply the special region builder for the builtin named Linalg op.
+/// The list of `capture` MlirValue is passed as-is to the region builder.
 /// Assert that `op` is a builtin named Linalg op.
 MLIR_CAPI_EXPORTED void
-mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op,
+                                   intptr_t n, MlirValue const *mlirCaptures);
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
index e4ef694..849a003 100644 (file)
@@ -22,10 +22,15 @@ namespace python {
 void populateDialectLinalgSubmodule(py::module &m) {
   m.def(
       "fill_builtin_region",
-      [](PyDialectDescriptor &dialect, PyOperation &op) {
-        return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
+      [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
+        llvm::SmallVector<MlirValue, 4> mlirOperands;
+        mlirOperands.reserve(captures.size());
+        for (auto v : captures)
+          mlirOperands.push_back(py::cast<PyValue *>(v)->get());
+        mlirLinalgFillBuiltinNamedOpRegion(
+            dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data());
       },
-      py::arg("dialect"), py::arg("op"),
+      py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 }
index d787943..4714e69 100644 (file)
@@ -5,6 +5,47 @@
 from typing import Optional, Sequence, Union
 from ..ir import *
 from ._ods_common import get_default_loc_context
+# TODO: resolve name collision for Linalg functionality that is injected inside
+# the _mlir.dialects.linalg directly via pybind.
+from _mlir.dialects.linalg import fill_builtin_region
+
+
+def isa(cls : Type, ty : Type):
+  try:
+    cls(ty)
+    return True
+  except ValueError:
+    return False
+
+
+class FillOp:
+  """Extends the linalg.fill op."""
+
+  def __init__(self,
+               output: Value,
+               value: Value,
+               *,
+               loc=None,
+               ip=None):
+    results = []
+    if isa(RankedTensorType, output.type):
+      results = [output.type]
+    op = self.build_generic(results=results,
+                            operands=[output, value],
+                            attributes=None,
+                            loc=loc,
+                            ip=ip)
+    OpView.__init__(self, op)
+    linalgDialect = Context.current.get_dialect_descriptor("linalg")
+    fill_builtin_region(linalgDialect, self.operation, [value])
+    # TODO: self.result is None. When len(results) == 1 we expect it to be
+    # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
+    # in the generator of _linalg_ops_gen.py where we have:
+    # ```
+    # def result(self):
+    #   return self.operation.results[0] \
+    #     if len(self.operation.results) > 1 else None
+    # ```
 
 
 class InitTensorOp:
index 1c50aa6..6f6e090 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 
 using namespace mlir;
@@ -16,8 +17,14 @@ using namespace mlir::linalg;
 /// Apply the special region builder for the builtin named Linalg op.
 /// Assert that `op` is a builtin named Linalg op.
 void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
-                                        MlirOperation mlirOp) {
+                                        MlirOperation mlirOp, intptr_t n,
+                                        MlirValue const *mlirCaptures) {
   Operation *op = unwrap(mlirOp);
+  SmallVector<Value> captures;
+  captures.reserve(n);
+  for (unsigned idx = 0; idx < n; ++idx)
+    captures.push_back(unwrap(mlirCaptures[idx]));
+
   LinalgDialect::RegionBuilderFunType fun =
       static_cast<LinalgDialect *>(unwrap(linalgDialect))
           ->getRegionBuilder(op->getName().getStringRef());
@@ -25,15 +32,18 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
   assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
   assert(op->getRegion(0).getBlocks().empty() &&
          "Expected Linalg op with 0 blocks");
+
   SmallVector<Type, 8> argTypes;
   auto linalgOp = cast<LinalgOp>(op);
   for (auto t : linalgOp.getShapedOperandTypes())
     argTypes.push_back(getElementTypeOrSelf(t));
+
   OpBuilder b(op->getContext());
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
-  // TODO: allow captures.
-  fun(*body, ValueRange{});
+  b.setInsertionPointToStart(body);
+  mlir::edsc::ScopedContext scope(b, op->getLoc());
+  fun(*body, captures);
 }
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
index afcb582..f153ecb 100644 (file)
@@ -38,6 +38,40 @@ def testInitTensor():
 
   print(module)
 
+# CHECK-LABEL: TEST: testFill
+@run
+def testFill():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      # CHECK-LABEL: func @fill_tensor
+      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
+      #  CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
+      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[OUT]], %[[CST]]) : tensor<12x?xf32>, f32 -> tensor<12x?xf32>
+      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
+      @builtin.FuncOp.from_py_func(
+          RankedTensorType.get((12, -1), f32))
+      def fill_tensor(out):
+        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
+        # TODO: FillOp.result is None. When len(results) == 1 we expect it to
+        # be results[0] as per _linalg_ops_gen.py. This seems like an
+        # orthogonal bug in the generator of _linalg_ops_gen.py.
+        return linalg.FillOp(output=out, value=zero).results[0]
+
+      # CHECK-LABEL: func @fill_buffer
+      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
+      #  CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
+      #  CHECK-NEXT: linalg.fill(%[[OUT]], %[[CST]]) : memref<12x?xf32>, f32
+      #  CHECK-NEXT: return
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((12, -1), f32))
+      def fill_buffer(out):
+        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
+        linalg.FillOp(output=out, value=zero)
+
+  print(module)
+
 
 # CHECK-LABEL: TEST: testStructuredOpOnTensors
 @run