/// operations.
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
+/// Create a pass to convert Linalg operations to equivalent operations that
+/// work on primitive types, if possible.
+///
+/// Currently only `linalg.generic` ops whose shaped operands are all 0-D
+/// tensors are converted; every other op is left untouched.
+std::unique_ptr<Pass> createLinalgDetensorizePass();
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
+  let summary = "Detensorize linalg ops";
+  let constructor = "mlir::createLinalgDetensorizePass()";
+  // The pass creates `linalg.tensor_reshape` ops, so the Linalg dialect has to
+  // be declared as a dependency.
+  // NOTE(review): the pass also creates `tensor.extract` and
+  // `tensor.from_elements`; consider listing `tensor::TensorDialect` as well
+  // once it is confirmed to be declared where the generated pass classes are
+  // included.
+  let dependentDialects = ["linalg::LinalgDialect"];
+
+  let description = [{
+    Detensoring is the process through which a tensor value is converted to one
+    or potentially more primitive value(s). During this process, operations with
+    such detensored operands are also converted to an equivalent form that works
+    on primitives.
+
+    The detensoring process is driven by linalg-on-tensor ops. In particular, a
+    linalg-on-tensor op is checked to see whether *all* its operands can be
+    detensored. If so, those operands are converted to their primitive
+    counterparts and the linalg op is replaced by an equivalent op that takes
+    those new primitive values as operands. Therefore, the detensoring process
+    can be divided into 2 main logical phases:
+
+    1. Detect/match an op that can be detensored.
+    2. Detensor the operands of the op and replace it with a primitive
+       equivalent.
+  }];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp
CodegenStrategy.cpp
+ Detensorize.cpp
DropUnitDims.cpp
ElementwiseToLinalg.cpp
Fusion.cpp
--- /dev/null
+//===- Detensorize.cpp - Linalg detensoring pass --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <iterator>
+#include <memory>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Predicate deciding whether values of the given TensorType may be replaced
+/// by primitive (non-tensor) values.
+///
+/// NOTE: For now, only ranked 0-D tensors qualify.
+///
+/// Returns true if `tensorType` can be detensored.
+bool canBeDetensored(TensorType tensorType) {
+  if (!tensorType.hasRank())
+    return false;
+
+  return tensorType.getRank() == 0;
+}
+
+/// A conversion pattern for detensoring `linalg.generic` ops.
+///
+/// The op's region is inlined into the block that contains the op: the
+/// arguments of the region's entry block are replaced by the (already
+/// converted) `operands`, and the op's results are replaced by the operands of
+/// the region's terminating `linalg.yield`.
+class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Block *originalBlock = op->getBlock();
+
+    // Gather some information about the op before inlining its region.
+    Block *opEntryBlock = &*op.region().begin();
+    // Use `cast` instead of `dyn_cast`: the result is dereferenced
+    // unconditionally below, so an unexpected terminator should trip an
+    // assertion rather than produce a null op.
+    auto yieldOp = cast<YieldOp>(op.region().back().getTerminator());
+
+    // Split the enclosing block right before the op. This way, we have a clear
+    // insertion point in which the op can be inlined. The split goes through
+    // the rewriter, like every other IR mutation in this pattern, instead of
+    // modifying the block directly.
+    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
+    rewriter.inlineRegionBefore(op.region(), newBlock);
+    // Now that op's region is inlined, the operands of its YieldOp are mapped
+    // to the materialized target values. Therefore, we can replace the op's
+    // uses with those of its YieldOp's operands.
+    rewriter.replaceOp(op, yieldOp->getOperands());
+
+    // No need for these intermediate blocks, merge them into 1. Merging the
+    // former entry block substitutes its arguments with `operands`.
+    rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
+    rewriter.mergeBlocks(newBlock, originalBlock, {});
+
+    // The yield was only needed to find the replacement values.
+    rewriter.eraseOp(yieldOp);
+
+    return success();
+  }
+};
+
+/// Type converter that maps detensorable tensor types to their element type
+/// and leaves every other type unchanged. It also registers the
+/// materializations used to move values between the tensor form and the
+/// detensored form.
+class DetensorizeTypeConverter : public TypeConverter {
+public:
+  DetensorizeTypeConverter() {
+    // Fallback rule: a type converts to itself.
+    addConversion([](Type type) { return type; });
+
+    // Tensor types accepted by canBeDetensored() are replaced by their element
+    // type; all other tensor types are kept as they are.
+    addConversion([](TensorType tensorType) -> Type {
+      if (!canBeDetensored(tensorType))
+        return tensorType;
+
+      return tensorType.getElementType();
+    });
+
+    // Tensor -> primitive: extract the element of the first input using an
+    // empty index list.
+    addTargetMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      Value tensor = inputs[0];
+      return builder.create<tensor::ExtractOp>(loc, tensor, ValueRange{});
+    });
+
+    // Primitive -> tensor: build a tensor from the first input and reshape it
+    // to the requested type.
+    addSourceMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      Value element = inputs[0];
+      auto oneElementTensor = builder.create<tensor::FromElementsOp>(
+          loc, element.getType(), element);
+
+      // FromElementsOp yields a tensor<1xdtype>; reshape it into the
+      // tensor<dtype> that was asked for.
+      return builder.create<linalg::TensorReshapeOp>(
+          loc, type, oneElementTensor, ArrayRef<ReassociationExprs>{});
+    });
+  }
+};
+
+/// Folds away the round trip
+///
+/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
+/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
+///   tensor<i32>
+/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
+///
+/// by replacing %extracted_element with %element.
+struct ExtractFromReshapeFromElements
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    // Only extracts without any index are handled.
+    if (!extract.indices().empty())
+      return failure();
+
+    // The extracted tensor must be produced by a tensor_reshape ...
+    auto reshapeOp = extract.tensor().getDefiningOp<TensorReshapeOp>();
+    if (!reshapeOp)
+      return failure();
+
+    // ... whose operand is in turn produced by a from_elements.
+    auto fromElementsOp =
+        reshapeOp.getOperand().getDefiningOp<mlir::tensor::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    // Forward the first element that was packed into the tensor.
+    rewriter.replaceOp(extract, fromElementsOp.getOperand(0));
+    return success();
+  }
+};
+
+/// Function pass that detensors `linalg.generic` ops and then cleans up the
+/// tensor <-> primitive round trips introduced by the conversion.
+///
+/// @see LinalgDetensorize in Linalg/Passes.td for more details.
+struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    DetensorizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    // Everything is legal except the `linalg.generic` ops selected below.
+    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+    target.addLegalDialect<linalg::LinalgDialect>();
+    target.addDynamicallyLegalOp<GenericOp>([](GenericOp op) {
+      // If any of the shaped operands cannot be detensored, the op is
+      // considered legal and won't be detensored. Operands that are not
+      // tensors at all (e.g. memrefs) can never be detensored, so they make
+      // the op legal as well instead of being cast to TensorType blindly.
+      return llvm::any_of(
+          op.getShapedOperandTypes(), [](ShapedType shapedType) {
+            auto tensorType = shapedType.dyn_cast<TensorType>();
+            return !tensorType || !canBeDetensored(tensorType);
+          });
+    });
+
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+
+    // Do not run the clean-up below if the conversion itself failed.
+    if (failed(applyPartialConversion(getFunction(), target,
+                                      std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+
+    // Remove extract(reshape(from_elements(x))) chains left behind by the
+    // materializations of consecutive detensored ops.
+    OwningRewritePatternList canonPatterns;
+    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+    if (failed(applyPatternsAndFoldGreedily(getFunction(),
+                                            std::move(canonPatterns))))
+      signalPassFailure();
+
+    // TODO Properly handle control flow within function boundaries.
+  }
+};
+} // namespace
+
+/// Factory for the detensoring pass: returns a newly allocated
+/// LinalgDetensorize instance owned by the caller.
+std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
+ return std::make_unique<LinalgDetensorize>();
+}
--- /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map = affine_map<() -> ()>
+
+// A single linalg.generic on 0-D tensors whose body holds one op (addf).
+func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+ outs(%0 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %2 = addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_simple
+// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK: return %[[reshaped_tensor_res]]
+
+// Three chained linalg.generic ops on 0-D tensors (addf, mulf, divf); the
+// results of earlier ops are used as inputs of later ones.
+func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+ outs(%0 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %2 = addf %arg3, %arg4 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+
+ %3 = linalg.init_tensor [] : tensor<f32>
+ %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg1, %1 : tensor<f32>, tensor<f32>)
+ outs(%3 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %5 = mulf %arg3, %arg4 : f32
+ linalg.yield %5 : f32
+ } -> tensor<f32>
+
+ %6 = linalg.init_tensor [] : tensor<f32>
+ %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%1, %4 : tensor<f32>, tensor<f32>)
+ outs(%6 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %5 = divf %arg3, %arg4 : f32
+ linalg.yield %5 : f32
+ } -> tensor<f32>
+
+ return %7: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_op_sequence
+// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
+// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK: return %[[reshaped_tensor_res]]
+
+// A single linalg.generic on 0-D tensors whose body holds two ops (addf
+// followed by mulf).
+func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+ outs(%0 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %2 = addf %arg3, %arg4 : f32
+ %3 = mulf %2, %arg4 : f32
+ linalg.yield %3 : f32
+ } -> tensor<f32>
+ return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_multiple_ops
+// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK: %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
+// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK: return %[[reshaped_tensor_res]]
+
+// A linalg.generic on 0-D tensors whose body holds an op from an unregistered
+// dialect ("foreign.do_something").
+func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+ outs(%0 : tensor<f32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %2 = "foreign.do_something"(%arg3, %arg4) {} : (f32, f32) -> f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_foreign_op
+// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK: return %[[reshaped_tensor_res]]