extern "C" {
#endif
+/// Apply the special region builder for the builtin named Linalg op.
+/// Assert that `op` is a builtin named Linalg op.
+MLIR_CAPI_EXPORTED void
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
+ let extraClassDeclaration = [{
+ using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;
+ RegionBuilderFunType getRegionBuilder(StringRef name) {
+ return namedStructuredOpRegionBuilders.lookup(name);
+ }
+ private:
+ llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
+ }];
}
// Whether a type is a RangeType.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
INSTALL_DIR
python
SOURCES
+ DialectLinalg.cpp
MainModule.cpp
IRAffine.cpp
IRAttributes.cpp
--- /dev/null
+//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/IR.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(py::module &m) {
+ m.def(
+ "fill_builtin_region",
+ [](PyDialectDescriptor &dialect, PyOperation &op) {
+ return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
+ },
+ py::arg("dialect"), py::arg("op"),
+ "Fill the region for `op`, which is assumed to be a builtin named Linalg "
+ "op.");
+}
+
+} // namespace python
+} // namespace mlir
--- /dev/null
+//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
+#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+void populateDialectLinalgSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
#include "PybindUtils.h"
+#include "DialectLinalg.h"
#include "ExecutionEngine.h"
#include "Globals.h"
#include "IRModule.h"
auto executionEngineModule =
m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
populateExecutionEngineSubmodule(executionEngineModule);
+
+ // Define and populate Linalg submodule.
+ auto dialectsModule = m.def_submodule("dialects");
+ auto linalgModule = dialectsModule.def_submodule("linalg");
+ populateDialectLinalgSubmodule(linalgModule);
}
raise NotImplementedError(
f"Emission of composite linalg ops not supported: {op_configs}")
- # TODO: this file should probably not be called dsl.py but rather is a client
- # of the dsl.py.
- from .... import linalg as linalg_ops
- emit_generic = (emit_generic or
- (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
+ ctx = ir.Context.current
+ linalgDialect = ctx.get_dialect_descriptor("linalg")
+ fully_qualified_name = 'linalg.' + self.op_name
+ emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))
op_config = op_configs[0]
if op_config.structured_op:
from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std
+# TODO: resolve name collision for Linalg functionality that is injected inside
+# the _mlir.dialects.linalg directly via pybind.
+from _mlir.dialects.linalg import fill_builtin_region
from .scalar_expr import *
from .config import *
"emit_named_structured_op",
]
-
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
outs: Value):
type_mapping, indexing_maps_attr, iterator_types_attr = \
prepare_common_structured_op(op_config, *ins, outs = outs)
- if not op_class_name in linalg.__dict__.keys():
+ # If we get here, there must exist a builtin class `op_class_name`.
+ ctx = Context.current
+ fully_qualified_name = 'linalg.' + op_name
+ if (not ctx.is_registered_operation(fully_qualified_name) or
+ not op_class_name in linalg.__dict__.keys()):
raise NotImplementedError(
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
+ linalgDialect = ctx.get_dialect_descriptor("linalg")
+ fill_builtin_region(linalgDialect, named_op.operation)
+
if len(out_arg_defs) == 1:
return named_op.result
else:
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
- mlir::linalg::LinalgDialect)
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Apply the special region builder for the builtin named Linalg op.
+/// Assert that `op` is a builtin named Linalg op.
+void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
+ MlirOperation mlirOp) {
+ Operation *op = unwrap(mlirOp);
+ LinalgDialect::RegionBuilderFunType fun =
+ static_cast<LinalgDialect *>(unwrap(linalgDialect))
+ ->getRegionBuilder(op->getName().getStringRef());
+ assert(fun && "Expected a builtin named Linalg op.");
+ assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
+ assert(op->getRegion(0).getBlocks().empty() &&
+ "Expected Linalg op with 0 blocks");
+ SmallVector<Type, 8> argTypes;
+ auto linalgOp = cast<LinalgOp>(op);
+ for (auto t : linalgOp.getShapedOperandTypes())
+ argTypes.push_back(getElementTypeOrSelf(t));
+ OpBuilder b(op->getContext());
+ Region ®ion = op->getRegion(0);
+ Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes);
+ // TODO: allow captures.
+ fun(*body, ValueRange{});
+}
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
// LinalgDialect
//===----------------------------------------------------------------------===//
+/// Trait to check if T provides a `regionBuilder` method.
+template <typename T, typename... Args>
+using has_region_builder = decltype(T::regionBuilder);
+template <typename T>
+using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;
+
+/// SFINAE helper for single C++ class without a `regionBuilder` method (e.g.
+/// an OpInterface).
+template <typename OpType, typename = std::enable_if_t<
+ !detect_has_region_builder<OpType>::value>>
+void addNamedOpBuilderImpl(
+ llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+ // Do nothing.
+}
+
+template <typename OpType,
+ typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
+ typename = void>
+void addNamedOpBuilderImpl(
+ llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+ map.insert(std::make_pair(
+ OpType::getOperationName(),
+ static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder)));
+}
+
+template <typename... OpTypes>
+void addNamedOpBuilders(
+ llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
+ (void)std::initializer_list<int>{0,
+ (addNamedOpBuilderImpl<OpTypes>(map), 0)...};
+}
+
void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations<
#include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
>();
+ // Fill the Linalg-specific OpName to RegionBuilder map.
+ addNamedOpBuilders<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(namedStructuredOpRegionBuilders);
+
addInterfaces<LinalgInlinerInterface>();
}
from mlir.dialects import linalg
from mlir.dialects import std
-
def run(f):
print("\nTEST:", f.__name__)
f()
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
print(module)
-# CHECK-LABEL: TEST: testNamedStructuredOp
+# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
-def testNamedStructuredOp():
+def testNamedStructuredOpCustomForm():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
RankedTensorType.get((16, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
- # CHECK: linalg.matmul
- # TODO: prperly hook up the region.
+ # First check the named form with custom format
+ # CHECK: linalg.matmul
+ # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
+ # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
+ # CHECK-SAME: -> tensor<4x8xf32>
+ # CHECK-NEXT: return
return linalg.matmul(lhs, rhs, outs=[init_result.result])
+ print(module)
+
+# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
+@run
+def testNamedStructuredOpGenericForm():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
+ RankedTensorType.get((16, 8), f32))
+ def named_form(lhs, rhs):
+ init_result = linalg.InitTensorOp([4, 8], f32)
+ # CHECK: "linalg.matmul"(%{{.*}})
+ # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
+ # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
+ # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
+ # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
+ # CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+ # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
+ return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+ module.operation.print(print_generic_op_form=True)
+
+# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
+@run
+def testNamedStructuredAsGenericOp():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
RankedTensorType.get((16, 8), f32))
def generic_form(lhs, rhs):