From: Sean Silva
Date: Wed, 28 Oct 2020 20:25:48 +0000 (-0700)
Subject: [mlir] Add pass to convert elementwise ops to linalg.
X-Git-Tag: llvmorg-13-init~6500
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=53a0d45db6d0f33dfbb724c99ce2560ae25473c2;p=platform%2Fupstream%2Fllvm.git

[mlir] Add pass to convert elementwise ops to linalg.

This patch converts elementwise ops on tensors to linalg.generic ops
with the same elementwise op in the payload (except rewritten to
operate on scalars, obviously). This is a great form for later fusion
to clean up.

E.g.

```
// Compute: %arg0 + %arg1 - %arg2
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = addf %arg0, %arg1 : tensor<?xf32>
  %1 = subf %0, %arg2 : tensor<?xf32>
  return %1 : tensor<?xf32>
}
```
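
Each elementwise op is first converted into its own linalg.generic; the
fusion pass is what collapses them afterwards. A hand-written sketch
(illustrative, not verbatim pass output) of the intermediate IR after
`-convert-elementwise-to-linalg` alone:

```
#map = affine_map<(d0) -> (d0)>
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %2 = addf %arg3, %arg4 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%0, %arg2 : tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg5: f32, %arg6: f32):
    %3 = subf %arg5, %arg6 : f32
    linalg.yield %3 : f32
  } -> tensor<?xf32>
  return %1 : tensor<?xf32>
}
```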

Running the original function through `mlir-opt -convert-elementwise-to-linalg -linalg-fusion-for-tensor-ops` we get:

```
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %1 = addf %arg3, %arg4 : f32
    %2 = subf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```

So the elementwise ops on tensors have nicely collapsed into a single
linalg.generic, which is the form we want for further transformations.

Differential Revision: https://reviews.llvm.org/D90354
---

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 24570d3..50aec73 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -48,6 +50,11 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 7446ca8..9162543 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
+
+    This pass only converts ops that operate on ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
new file mode 100644
index 0000000..d26b8f7
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11, 22, 33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 88242c1..73df73e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
   DropUnitDims.cpp
+  ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
new file mode 100644
index 0000000..a0e5d74
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -0,0 +1,98 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return type.isa<RankedTensorType>(); });
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
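+    // Replace the op with a linalg.generic over the same tensor operands; the
+    // body recreates the matched op on the scalar block arguments and yields
+    // its result, so the payload stays the same elementwise computation. The
+    // scalar op's result types are the element types of the tensor results.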
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
new file mode 100644
index 0000000..7ea78fe
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK:   select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created
+// scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}
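
Note that, per the `isElementwiseMappableOpOnRankedTensors` guard, the
conversion only fires when all operands are ranked tensors; elementwise ops on
scalars or unranked tensors are marked legal and left untouched by
`applyPartialConversion`. A hand-written sketch (not part of the patch) of
inputs the pass ignores:

```
// Scalar operands are not ranked tensors, so this stays a plain addf.
func @scalar(%arg0: f32, %arg1: f32) -> f32 {
  %0 = addf %arg0, %arg1 : f32
  return %0 : f32
}

// An unranked tensor operand fails the RankedTensorType check, so this is
// also left as-is (see the TODO in isElementwiseMappableOpOnRankedTensors).
func @unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = addf %arg0, %arg1 : tensor<*xf32>
  return %0 : tensor<*xf32>
}
```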