From: KareemErgawy-TomTom
Date: Tue, 16 Feb 2021 06:42:41 +0000 (+0100)
Subject: [MLIR][LinAlg] Start detensoring implementation.
X-Git-Tag: llvmorg-14-init~14325
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=67e0d58de4d338add132838810db70218f1064d8;p=platform%2Fupstream%2Fllvm.git

[MLIR][LinAlg] Start detensoring implementation.

This commit is the first baby step towards detensoring in
linalg-on-tensors.

Detensoring is the process through which a tensor value is converted to
one or potentially more primitive value(s). During this process,
operations with such detensored operands are also converted to an
equivalent form that works on primitives.

The detensoring process is driven by linalg-on-tensor ops. In
particular, a linalg-on-tensor op is checked to see whether *all* its
operands can be detensored. If so, those operands are converted to their
primitive counterparts and the linalg op is replaced by an equivalent op
that takes those new primitive values as operands.

This works towards handling github/google/iree#1159.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96271
---

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5d68328..7d93dd0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -59,6 +59,10 @@ void populateElementwiseToLinalgConversionPatterns(
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 
+/// Create a pass to convert Linalg operations to equivalent operations that
+/// work on primitive types, if possible.
+std::unique_ptr<Pass> createLinalgDetensorizePass();
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index a20289a..e51d08d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -136,4 +136,28 @@ def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
+  let summary = "Detensorize linalg ops";
+  let constructor = "mlir::createLinalgDetensorizePass()";
+  let dependentDialects = [];
+
+  let description = [{
+    Detensoring is the process through which a tensor value is converted to
+    one or potentially more primitive value(s). During this process,
+    operations with such detensored operands are also converted to an
+    equivalent form that works on primitives.
+
+    The detensoring process is driven by linalg-on-tensor ops. In particular,
+    a linalg-on-tensor op is checked to see whether *all* its operands can be
+    detensored. If so, those operands are converted to their primitive
+    counterparts and the linalg op is replaced by an equivalent op that takes
+    those new primitive values as operands. Therefore, the detensoring process
+    can be divided into 2 main logical phases:
+
+    1. Detect/match an op that can be detensored.
+    2. Detensor the operands of the op and replace it with a primitive
+       equivalent.
+  }];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
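To make the two phases concrete before diving into the implementation, here is a
rough before/after sketch of detensoring a 0-D `linalg.generic` addition. It is
not part of the patch; it mirrors the tests added below, the function and SSA
value names are illustrative, and `addf` is the standard-dialect scalar add of
that era:

    #map = affine_map<() -> ()>

    // Before: the addition is wrapped in a linalg-on-tensors op.
    func @add(%a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
      %init = linalg.init_tensor [] : tensor<f32>
      %res = linalg.generic {indexing_maps = [#map, #map, #map],
                             iterator_types = []}
          ins(%a, %b : tensor<f32>, tensor<f32>) outs(%init : tensor<f32>) {
      ^bb0(%ae: f32, %be: f32, %oe: f32):
        %sum = addf %ae, %be : f32
        linalg.yield %sum : f32
      } -> tensor<f32>
      return %res : tensor<f32>
    }

    // After: the op's region is inlined and works on extracted scalars; the
    // result is rebuilt into a 0-D tensor only at the function boundary.
    func @add(%a: tensor<f32>, %b: tensor<f32>) -> tensor<f32> {
      %ae = tensor.extract %a[] : tensor<f32>
      %be = tensor.extract %b[] : tensor<f32>
      %sum = addf %ae, %be : f32
      %t = tensor.from_elements %sum : tensor<1xf32>
      %res = linalg.tensor_reshape %t [] : tensor<1xf32> into tensor<f32>
      return %res : tensor<f32>
    }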
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d988e24..1469371 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
+  Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseToLinalg.cpp
   Fusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
new file mode 100644
index 0000000..2e2e3b9
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -0,0 +1,173 @@
+//===- Detensorize.cpp - Linalg transformations as patterns --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <iterator>
+#include <memory>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Defines the criteria a TensorType must follow in order to be considered
+/// "detensorable".
+///
+/// NOTE: For now, only 0-D tensors are supported.
+///
+/// Returns true if tensorType can be detensored.
+bool canBeDetensored(TensorType tensorType) {
+  return tensorType.hasRank() && tensorType.getRank() == 0;
+}
+
+/// A conversion pattern for detensoring `linalg.generic` ops.
+class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Block *originalBlock = op->getBlock();
+
+    // Gather some information about the op before inlining its region.
+    Block *opEntryBlock = &*op.region().begin();
+    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
+
+    // Split the block containing the op right before it. This way, we have a
+    // clear insertion point where the op's region can be inlined.
+    Block *newBlock = originalBlock->splitBlock(op);
+    rewriter.inlineRegionBefore(op.region(), newBlock);
+    // Now that the op's region is inlined, the operands of its YieldOp are
+    // mapped to the materialized target values. Therefore, we can replace the
+    // op's uses with those of its YieldOp's operands.
+    rewriter.replaceOp(op, yieldOp->getOperands());
+
+    // No need for these intermediate blocks, merge them into 1.
+    rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
+    rewriter.mergeBlocks(newBlock, originalBlock, {});
+
+    rewriter.eraseOp(&*Block::iterator(yieldOp));
+
+    return success();
+  }
+};
+
+class DetensorizeTypeConverter : public TypeConverter {
+public:
+  DetensorizeTypeConverter() {
+    addConversion([](Type type) { return type; });
+
+    // A TensorType that can be detensored is converted to the underlying
+    // element type.
+    addConversion([](TensorType tensorType) -> Type {
+      if (canBeDetensored(tensorType))
+        return tensorType.getElementType();
+
+      return tensorType;
+    });
+
+    // A tensor value is detensored by extracting its element(s).
+    addTargetMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
+    });
+
+    // A detensored value is converted back by creating a new tensor from its
+    // element(s).
+    addSourceMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
+          loc, inputs[0].getType(), inputs[0]);
+
+      // FromElementsOp results in a tensor<1xdtype>, we need to reshape that
+      // to a tensor<dtype> instead.
+      return builder.create<linalg::TensorReshapeOp>(
+          loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+    });
+  }
+};
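In IR terms, the two materializations registered above amount to the following
sketch (not part of the patch; SSA names illustrative). The target
materialization detensors a value with `tensor.extract`; the source
materialization rebuilds the 0-D tensor in two steps, since `from_elements`
produces a 1-D tensor:

    // Target materialization: tensor<f32> -> f32.
    %elem = tensor.extract %t[] : tensor<f32>

    // Source materialization: f32 -> tensor<f32>.
    %t1 = tensor.from_elements %elem : tensor<1xf32>
    %t0 = linalg.tensor_reshape %t1 [] : tensor<1xf32> into tensor<f32>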
+
+/// Canonicalizes the pattern of the form
+///
+/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
+/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
+///                    tensor<i32>
+/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
+///
+/// to just %element.
+struct ExtractFromReshapeFromElements
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    if (extract.indices().size() != 0)
+      return failure();
+
+    auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
+    if (tensorReshape == nullptr)
+      return failure();
+
+    auto tensorFromElements =
+        tensorReshape.getOperand().getDefiningOp<tensor::FromElementsOp>();
+    if (tensorFromElements == nullptr)
+      return failure();
+
+    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+    return success();
+  }
+};
+
+/// @see LinalgDetensorize in Linalg/Passes.td for more details.
+struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    DetensorizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+    target.addLegalDialect<tensor::TensorDialect>();
+    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
+      // If any of the operands or results cannot be detensored, the op is
+      // considered legal and won't be detensored.
+      return llvm::any_of(
+          op.getShapedOperandTypes(), [](ShapedType shapedType) {
+            assert(shapedType.isa<TensorType>());
+            return !canBeDetensored(shapedType.cast<TensorType>());
+          });
+    });
+
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+
+    if (failed(
+            applyPartialConversion(getFunction(), target, std::move(patterns))))
+      signalPassFailure();
+
+    OwningRewritePatternList canonPatterns;
+    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+    if (failed(applyPatternsAndFoldGreedily(getFunction(),
+                                            std::move(canonPatterns))))
+      signalPassFailure();
+
+    // TODO: Properly handle control flow within function boundaries.
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
+  return std::make_unique<LinalgDetensorize>();
+}
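Chaining a source materialization into a later target materialization is
exactly what produces the from_elements/tensor_reshape/extract triple that
ExtractFromReshapeFromElements folds away. A sketch of the fold (not part of
the patch; names illustrative):

    // Before canonicalization:
    %t = tensor.from_elements %element : tensor<1xi32>
    %r = linalg.tensor_reshape %t [] : tensor<1xi32> into tensor<i32>
    %e = tensor.extract %r[] : tensor<i32>

    // After canonicalization, all uses of %e simply use %element.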
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
new file mode 100644
index 0000000..e35a34f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map = affine_map<() -> ()>
+
+func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_simple
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
+
+func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+
+  %3 = linalg.init_tensor [] : tensor<f32>
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %1 : tensor<f32>, tensor<f32>)
+    outs(%3 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = mulf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<f32>
+
+  %6 = linalg.init_tensor [] : tensor<f32>
+  %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%1, %4 : tensor<f32>, tensor<f32>)
+    outs(%6 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = divf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<f32>
+
+  return %7: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_op_sequence
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK-DAG:     %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
+// CHECK:         %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK:         %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
+
+func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    %3 = mulf %2, %arg4 : f32
+    linalg.yield %3 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_multiple_ops
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK:         %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
+
+func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = "foreign.do_something"(%arg3, %arg4) {} : (f32, f32) -> f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_foreign_op
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
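Stitching the CHECK directives of the last test together, the post-pass IR for
detensor_foreign_op should look roughly as follows (a sketch reconstructed from
the CHECK lines, not verbatim pass output; SSA names illustrative). Note that
the rewrite is agnostic to the ops inside the region, so even the unregistered
foreign.do_something ends up with scalar operands:

    func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
      %0 = tensor.extract %arg1[] : tensor<f32>
      %1 = tensor.extract %arg2[] : tensor<f32>
      %2 = "foreign.do_something"(%0, %1) : (f32, f32) -> f32
      %3 = tensor.from_elements %2 : tensor<1xf32>
      %4 = linalg.tensor_reshape %3 [] : tensor<1xf32> into tensor<f32>
      return %4 : tensor<f32>
    }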