-//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect --------*- C -*-===//
+//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect -------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-//===----------------------------------------------------------------------===//
+//===---------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_LINALG_H
#define MLIR_C_DIALECT_LINALG_H
#endif
/// Apply the special region builder for the builtin named Linalg op.
+/// The list of `capture` MlirValue is passed as-is to the region builder.
/// Assert that `op` is a builtin named Linalg op.
MLIR_CAPI_EXPORTED void
-mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);
+mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op,
+ intptr_t n, MlirValue const *mlirCaptures);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
void populateDialectLinalgSubmodule(py::module &m) {
m.def(
"fill_builtin_region",
- [](PyDialectDescriptor &dialect, PyOperation &op) {
- return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
+ [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
+ mlirOperands.reserve(captures.size());
+ for (auto v : captures)
+ mlirOperands.push_back(py::cast<PyValue *>(v)->get());
+ mlirLinalgFillBuiltinNamedOpRegion(
+ dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data());
},
- py::arg("dialect"), py::arg("op"),
+ py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
}
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
+# TODO: resolve name collision for Linalg functionality that is injected inside
+# the _mlir.dialects.linalg directly via pybind.
+from _mlir.dialects.linalg import fill_builtin_region
+
+
+def isa(cls : Type, ty : Type):
+ try:
+ cls(ty)
+ return True
+ except ValueError:
+ return False
+
+
+class FillOp:
+ """Extends the linalg.fill op."""
+
+ def __init__(self,
+ output: Value,
+ value: Value,
+ *,
+ loc=None,
+ ip=None):
+ results = []
+ if isa(RankedTensorType, output.type):
+ results = [output.type]
+ op = self.build_generic(results=results,
+ operands=[output, value],
+ attributes=None,
+ loc=loc,
+ ip=ip)
+ OpView.__init__(self, op)
+ linalgDialect = Context.current.get_dialect_descriptor("linalg")
+ fill_builtin_region(linalgDialect, self.operation, [value])
+ # TODO: self.result is None. When len(results) == 1 we expect it to be
+ # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
+ # in the generator of _linalg_ops_gen.py where we have:
+ # ```
+ # def result(self):
+ # return self.operation.results[0] \
+ # if len(self.operation.results) > 1 else None
+ # ```
class InitTensorOp:
#include "mlir-c/Dialect/Linalg.h"
#include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
using namespace mlir;
/// Apply the special region builder for the builtin named Linalg op.
/// Assert that `op` is a builtin named Linalg op.
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
- MlirOperation mlirOp) {
+ MlirOperation mlirOp, intptr_t n,
+ MlirValue const *mlirCaptures) {
Operation *op = unwrap(mlirOp);
+ SmallVector<Value> captures;
+ captures.reserve(n);
+ for (unsigned idx = 0; idx < n; ++idx)
+ captures.push_back(unwrap(mlirCaptures[idx]));
+
LinalgDialect::RegionBuilderFunType fun =
static_cast<LinalgDialect *>(unwrap(linalgDialect))
->getRegionBuilder(op->getName().getStringRef());
assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
assert(op->getRegion(0).getBlocks().empty() &&
"Expected Linalg op with 0 blocks");
+
SmallVector<Type, 8> argTypes;
auto linalgOp = cast<LinalgOp>(op);
for (auto t : linalgOp.getShapedOperandTypes())
argTypes.push_back(getElementTypeOrSelf(t));
+
OpBuilder b(op->getContext());
Region ®ion = op->getRegion(0);
Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes);
- // TODO: allow captures.
- fun(*body, ValueRange{});
+ b.setInsertionPointToStart(body);
+ mlir::edsc::ScopedContext scope(b, op->getLoc());
+ fun(*body, captures);
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
print(module)
+# CHECK-LABEL: TEST: testFill
+@run
+def testFill():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # CHECK-LABEL: func @fill_tensor
+ # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
+ # CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
+ # CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[OUT]], %[[CST]]) : tensor<12x?xf32>, f32 -> tensor<12x?xf32>
+ # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((12, -1), f32))
+ def fill_tensor(out):
+ zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
+ # TODO: FillOp.result is None. When len(results) == 1 we expect it to
+ # be results[0] as per _linalg_ops_gen.py. This seems like an
+ # orthogonal bug in the generator of _linalg_ops_gen.py.
+ return linalg.FillOp(output=out, value=zero).results[0]
+
+ # CHECK-LABEL: func @fill_buffer
+ # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
+ # CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
+ # CHECK-NEXT: linalg.fill(%[[OUT]], %[[CST]]) : memref<12x?xf32>, f32
+ # CHECK-NEXT: return
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((12, -1), f32))
+ def fill_buffer(out):
+ zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
+ linalg.FillOp(output=out, value=zero)
+
+ print(module)
+
# CHECK-LABEL: TEST: testStructuredOpOnTensors
@run