[mlir][Linalg][Python] Create the body of builtin named Linalg ops
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 30 Mar 2021 11:41:41 +0000 (11:41 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 31 Mar 2021 07:58:32 +0000 (07:58 +0000)
This revision adds support to properly add the body of registered
builtin named linalg ops.
At this time, indexing_map and iterator_type support is still
missing so the op is not executable yet.

Differential Revision: https://reviews.llvm.org/D99578

12 files changed:
mlir/include/mlir-c/Dialect/Linalg.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/DialectLinalg.cpp [new file with mode: 0644]
mlir/lib/Bindings/Python/DialectLinalg.h [new file with mode: 0644]
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/lib/CAPI/Dialect/Linalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/test/Bindings/Python/dialects/linalg/ops.py

index be73a5c..06f15f0 100644 (file)
 extern "C" {
 #endif
 
+/// Apply the special region builder for the builtin named Linalg op.
+/// Assert that `op` is a builtin named Linalg op.
+MLIR_CAPI_EXPORTED void
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
 #ifdef __cplusplus
index 5a906ff..007cb6d 100644 (file)
@@ -37,6 +37,14 @@ def Linalg_Dialect : Dialect {
   let dependentDialects = [
     "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
   ];
+  let extraClassDeclaration = [{
+    using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;
+    RegionBuilderFunType getRegionBuilder(StringRef name) {
+      return namedStructuredOpRegionBuilders.lookup(name);
+    }
+    private:
+      llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
+  }];
 }
 
 // Whether a type is a RangeType.
index 71ac601..d94e43b 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/StringMap.h"
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
 
index 43d6275..39192cc 100644 (file)
@@ -69,6 +69,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
   INSTALL_DIR
     python
   SOURCES
+    DialectLinalg.cpp
     MainModule.cpp
     IRAffine.cpp
     IRAttributes.cpp
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
new file mode 100644 (file)
index 0000000..e4ef694
--- /dev/null
@@ -0,0 +1,34 @@
+//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/IR.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(py::module &m) {
+  m.def(
+      "fill_builtin_region",
+      [](PyDialectDescriptor &dialect, PyOperation &op) {
+        return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
+      },
+      py::arg("dialect"), py::arg("op"),
+      "Fill the region for `op`, which is assumed to be a builtin named Linalg "
+      "op.");
+}
+
+} // namespace python
+} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.h b/mlir/lib/Bindings/Python/DialectLinalg.h
new file mode 100644 (file)
index 0000000..3735dbf
--- /dev/null
@@ -0,0 +1,22 @@
+//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
+#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
index 5fe0401..79128f2 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "PybindUtils.h"
 
+#include "DialectLinalg.h"
 #include "ExecutionEngine.h"
 #include "Globals.h"
 #include "IRModule.h"
@@ -225,4 +226,9 @@ PYBIND11_MODULE(_mlir, m) {
   auto executionEngineModule =
       m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
   populateExecutionEngineSubmodule(executionEngineModule);
+
+  // Define and populate Linalg submodule.
+  auto dialectsModule = m.def_submodule("dialects");
+  auto linalgModule = dialectsModule.def_submodule("linalg");
+  populateDialectLinalgSubmodule(linalgModule);
 }
index d6dc989..002ae51 100644 (file)
@@ -61,11 +61,10 @@ class DefinedOpCallable:
       raise NotImplementedError(
           f"Emission of composite linalg ops not supported: {op_configs}")
 
-    # TODO: this file should probably not be called dsl.py but rather is a client
-    # of the dsl.py.
-    from .... import linalg as linalg_ops
-    emit_generic = (emit_generic or 
-      (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
+    ctx = ir.Context.current
+    linalgDialect = ctx.get_dialect_descriptor("linalg")
+    fully_qualified_name = 'linalg.' + self.op_name
+    emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))
 
     op_config = op_configs[0]
     if op_config.structured_op:
index e8e7eb5..2395a42 100644 (file)
@@ -7,6 +7,9 @@ from typing import Dict, Sequence
 from mlir.ir import *
 from mlir.dialects import linalg
 from mlir.dialects import std
+# TODO: resolve name collision for Linalg functionality that is injected inside
+# the _mlir.dialects.linalg directly via pybind.
+from _mlir.dialects.linalg import fill_builtin_region
 
 from .scalar_expr import *
 from .config import *
@@ -16,7 +19,6 @@ __all__ = [
     "emit_named_structured_op",
 ]
 
-
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
                                  *ins: Value,
                                  outs: Value):
@@ -97,11 +99,18 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
   type_mapping, indexing_maps_attr, iterator_types_attr =   \
      prepare_common_structured_op(op_config, *ins, outs = outs)
 
-  if not op_class_name in linalg.__dict__.keys():
+  # If we get here, there must exist a builtin class `op_class_name`.
+  ctx = Context.current
+  fully_qualified_name = 'linalg.' + op_name
+  if (not ctx.is_registered_operation(fully_qualified_name) or
+      not op_class_name in linalg.__dict__.keys()):
     raise NotImplementedError(
         f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
 
   named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+  linalgDialect = ctx.get_dialect_descriptor("linalg")
+  fill_builtin_region(linalgDialect, named_op.operation)
+
   if len(out_arg_defs) == 1:
     return named_op.result
   else:
index da6fd48..1c50aa6 100644 (file)
 #include "mlir/CAPI/Registration.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
-                                      mlir::linalg::LinalgDialect)
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Apply the special region builder for the builtin named Linalg op.
+/// Assert that `op` is a builtin named Linalg op.
+void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
+                                        MlirOperation mlirOp) {
+  Operation *op = unwrap(mlirOp);
+  LinalgDialect::RegionBuilderFunType fun =
+      static_cast<LinalgDialect *>(unwrap(linalgDialect))
+          ->getRegionBuilder(op->getName().getStringRef());
+  assert(fun && "Expected a builtin named Linalg op.");
+  assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
+  assert(op->getRegion(0).getBlocks().empty() &&
+         "Expected Linalg op with 0 blocks");
+  SmallVector<Type, 8> argTypes;
+  auto linalgOp = cast<LinalgOp>(op);
+  for (auto t : linalgOp.getShapedOperandTypes())
+    argTypes.push_back(getElementTypeOrSelf(t));
+  OpBuilder b(op->getContext());
+  Region &region = op->getRegion(0);
+  Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
+  // TODO: allow captures.
+  fun(*body, ValueRange{});
+}
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
index 8cd2d4f..2288f73 100644 (file)
@@ -57,6 +57,38 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
 // LinalgDialect
 //===----------------------------------------------------------------------===//
 
+/// Trait to check if T provides a `regionBuilder` method.
+template <typename T, typename... Args>
+using has_region_builder = decltype(T::regionBuilder);
+template <typename T>
+using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;
+
+/// SFINAE helper for single C++ class without a `regionBuilder` method (e.g.
+/// an OpInterface).
+template <typename OpType, typename = std::enable_if_t<
+                               !detect_has_region_builder<OpType>::value>>
+void addNamedOpBuilderImpl(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  // Do nothing.
+}
+
+template <typename OpType,
+          typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
+          typename = void>
+void addNamedOpBuilderImpl(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  map.insert(std::make_pair(
+      OpType::getOperationName(),
+      static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder)));
+}
+
+template <typename... OpTypes>
+void addNamedOpBuilders(
+    llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+  (void)std::initializer_list<int>{0,
+                                   (addNamedOpBuilderImpl<OpTypes>(map), 0)...};
+}
+
 void mlir::linalg::LinalgDialect::initialize() {
   addTypes<RangeType>();
   addOperations<
@@ -72,6 +104,12 @@ void mlir::linalg::LinalgDialect::initialize() {
 #include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
       >();
 
+  // Fill the Linalg-specific OpName to RegionBuilder map.
+  addNamedOpBuilders<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+      >(namedStructuredOpRegionBuilders);
+
   addInterfaces<LinalgInlinerInterface>();
 }
 
index 8f2eb06..489aa63 100644 (file)
@@ -5,7 +5,6 @@ from mlir.dialects import builtin
 from mlir.dialects import linalg
 from mlir.dialects import std
 
-
 def run(f):
   print("\nTEST:", f.__name__)
   f()
@@ -82,9 +81,9 @@ def testStructuredOpOnBuffers():
   # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
   print(module)
 
-# CHECK-LABEL: TEST: testNamedStructuredOp
+# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
 @run
-def testNamedStructuredOp():
+def testNamedStructuredOpCustomForm():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f32 = F32Type.get()
@@ -93,10 +92,45 @@ def testNamedStructuredOp():
                                    RankedTensorType.get((16, 8), f32))
       def named_form(lhs, rhs):
         init_result = linalg.InitTensorOp([4, 8], f32)
-        # CHECK: linalg.matmul
-        # TODO: prperly hook up the region.
+        # First check the named form with custom format
+        #      CHECK: linalg.matmul
+        # CHECK-SAME:    ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
+        # CHECK-SAME:   outs(%{{.*}} : tensor<4x8xf32>)
+        # CHECK-SAME:   -> tensor<4x8xf32>
+        # CHECK-NEXT: return
         return linalg.matmul(lhs, rhs, outs=[init_result.result])
 
+  print(module)
+
+# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
+@run
+def testNamedStructuredOpGenericForm():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+                                   RankedTensorType.get((16, 8), f32))
+      def named_form(lhs, rhs):
+        init_result = linalg.InitTensorOp([4, 8], f32)
+        #      CHECK: "linalg.matmul"(%{{.*}})
+        # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
+        # CHECK-NEXT:    std.mulf{{.*}} (f32, f32) -> f32
+        # CHECK-NEXT:    std.addf{{.*}} (f32, f32) -> f32
+        # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
+        # CHECK-NEXT:    {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : 
+        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
+        return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+  module.operation.print(print_generic_op_form=True)
+
+# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
+@run
+def testNamedStructuredAsGenericOp():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
       @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                    RankedTensorType.get((16, 8), f32))
       def generic_form(lhs, rhs):