[mlir] Add pass to convert elementwise ops to linalg.
authorSean Silva <silvasean@google.com>
Wed, 28 Oct 2020 20:25:48 +0000 (13:25 -0700)
committerSean Silva <silvasean@google.com>
Tue, 10 Nov 2020 21:44:44 +0000 (13:44 -0800)
This patch converts elementwise ops on tensors to linalg.generic ops
with the same elementwise op in the payload (except rewritten to
operate on scalars, obviously). This is a great form for later fusion to
clean up.

E.g.

```
// Compute: %arg0 + %arg1 - %arg2
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = addf %arg0, %arg1 : tensor<?xf32>
  %1 = subf %0, %arg2 : tensor<?xf32>
  return %1 : tensor<?xf32>
}
```

Running this through
`mlir-opt -convert-std-to-linalg -linalg-fusion-for-tensor-ops` we get:

```
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %1 = addf %arg3, %arg4 : f32
    %2 = subf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```

So the elementwise ops on tensors have nicely collapsed into a single
linalg.generic, which is the form we want for further transformations.

Differential Revision: https://reviews.llvm.org/D90354

mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir [new file with mode: 0644]
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir [new file with mode: 0644]

index 24570d3..50aec73 100644 (file)
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -48,6 +50,11 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass();
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
index 7446ca8..9162543 100644 (file)
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
+
+    This pass only converts ops that operate on ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
new file mode 100644 (file)
index 0000000..d26b8f7
--- /dev/null
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11,  22,  33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
index 88242c1..73df73e 100644 (file)
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
   DropUnitDims.cpp
+  ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
new file mode 100644 (file)
index 0000000..a0e5d74
--- /dev/null
@@ -0,0 +1,98 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return type.isa<RankedTensorType>(); });
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
new file mode 100644 (file)
index 0000000..7ea78fe
--- /dev/null
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL:   func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK:   select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}