[mlir][Linalg] Add an interface to decompose complex ops
authorQuentin Colombet <quentin.colombet@gmail.com>
Mon, 3 Jul 2023 16:20:32 +0000 (18:20 +0200)
committerQuentin Colombet <quentin.colombet@gmail.com>
Tue, 18 Jul 2023 17:06:36 +0000 (19:06 +0200)
This patch adds an interface, named AggregatedOpInterface, that decomposes
complex operations into simpler ones.

For now, make the interface specific to Linalg because although the concept
is general, the way to materialize it needs some maturing.

Use that interface with the softmax operator.

Differential Revision: https://reviews.llvm.org/D154363

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir

index 1aba722..edc393b 100644 (file)
@@ -897,4 +897,34 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let verifyWithRegions = 1;
 }
 
+def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
+  let description = [{
+    Interface for decomposing aggregated operations into a sequence of simpler
+    ops.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to decompose the operation into simpler operations.
+
+          On success, this method returns one `Value` per result in the
+          original operation.
+          The order of the returned values must match the order of the
+          original values.
+          In other words, the returned vector can be used directly with
+          `RewriterBase::replaceOp(this, returnedValues)`.
+        }],
+        /*retType=*/"FailureOr<SmallVector<Value>>",
+        /*methodName=*/"decomposeOperation",
+        /*args=*/(ins
+            "OpBuilder &":$b),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return {};
+        }]
+      >
+  ];
+}
+
 #endif // LINALG_IR_LINALGINTERFACES
index eb68890..fdb0430 100644 (file)
@@ -14,6 +14,7 @@
 #define LINALG_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -93,6 +94,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
index f1510f6..5d1f2d8 100644 (file)
@@ -1200,6 +1200,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
 }
 
 //===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    TODO
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+//===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
 //===----------------------------------------------------------------------===//
 
index a40618f..464a30a 100644 (file)
@@ -2323,6 +2323,176 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+// Helper functions for softmax decomposition.
+// @{
+
+// Helper function to produce the iterator types (reduction or parallel) and
+// affine maps for the iterators used in the decomposition of softmax.
+// This method creates:
+// If allParallel == true:
+// - iterator type: {parallel, ..., parallel}
+// - affine maps:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank.
+//
+// If allParallel == false:
+// - iterator type at dim(i) == parallel for i != \p dim and
+//   dim(dim) == reduction.
+// - affine map:
+// -- identity with inputRank dimensions.
+// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//    where N == inputRank.
+static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
+computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
+                                    int64_t dim, bool allParallel = false) {
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  if (!allParallel)
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+  MLIRContext *ctxt = builder.getContext();
+  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
+  SmallVector<AffineExpr, 2> affineExprs;
+  for (int i = 0; i < inputRank; i++) {
+    if (i != dim)
+      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
+  }
+  auto reductionMap =
+      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
+  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
+  return std::make_tuple(iteratorTypes, indexingMaps);
+}
+
+// Helper function to produce a linalg.generic that computes a reduction on
+// dimension \p dim with the operation type \p T.
+template <typename T>
+static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
+                    int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] =
+      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
+  assert(indexingMaps.size() == 2 &&
+         "We should have two maps: 1 for the input, 1 for the output");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<T>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the second step of the softmax
+/// decomposition: res = exp(input - max), where \p max is the max of \p input
+/// on dimension \p dim.
+static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
+                              Value max, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 && "We should have one map for each input");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+  // Add the affine map for the output argument.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
+      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
+        Value result = b.create<math::ExpOp>(loc, diff);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the final step of the softmax
+/// decomposition.
+/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
+///   yield  n / d
+/// }
+static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
+                        Value denominator, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(numerator.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 &&
+         "We should have one map for each input (2)");
+  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
+  // Add the affine map for the output tensor.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+// @} End helper functions for softmax decomposition.
+
+/// Given an N-dimensional tensor x, this method converts
+/// softmax(x) to the following sequence of operations:
+///
+/// 1. Compute the max of x along dimension d. This results
+///    in a N-1 dimensional tensor m.
+///    m = max(x, dim = d)
+///
+/// 2. Subtract a broadcasted m from x and exponentiate. This results in
+///    a N dimensional tensor z.
+///    z = exp(x - m)
+///
+/// 3. Compute the sum of z along dimension d. This results in
+///    a N-1 dimensional tensor l.
+///    l = sum(z, dim = d)
+///
+/// 4. Divide z and l. This gives the N-dimensional softmax.
+///    softmax = z / l
+///
+FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(*this);
+  Location loc = getLoc();
+  Value input = getInput();
+  ShapedType inputType = getInputOperandType();
+  Type elementType = inputType.getElementType();
+  int64_t reductionDim = getDimension();
+  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
+  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  dims.erase(dims.begin() + reductionDim);
+  // Step 1: Compute max along dim.
+  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value neutralForMaxF =
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
+  Value neutralForMaxFInit =
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+  Value max =
+      reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+
+  // Step 2: Subtract max from input and exponentiate.
+  Value numerator =
+      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+
+  // Step 3: Compute sum along dim.
+  Value zero =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
+  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value denominator =
+      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+
+  // Step 4: Compute softmax.
+  Value result =
+      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+  return SmallVector<Value>{result};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
index a51050b..f6e0f27 100644 (file)
@@ -336,6 +336,38 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
 }
 
 //===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+// Decompose the target operation if it implements the AggregatedOpInterface.
+// Push the decomposed operations (the ones that replaces the values produced by
+// \p target) in the `results`.
+DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+  if (!decomposableOp) {
+    failed(rewriter.notifyMatchFailure(target,
+                                       "payload is not a decomposable op"));
+    return emitDefaultSilenceableFailure(target);
+  }
+
+  FailureOr<SmallVector<Value>> maybeNewResults =
+      decomposableOp.decomposeOperation(rewriter);
+  if (failed(maybeNewResults))
+    return emitDefaultSilenceableFailure(target);
+
+  rewriter.replaceOp(decomposableOp, *maybeNewResults);
+  for (Value val : *maybeNewResults) {
+    Operation *definition = val.getDefiningOp();
+    if (definition)
+      results.push_back(definition);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
 // EliminateLinalgOpAnchoredEmptyTensorsOp
 //===----------------------------------------------------------------------===//
 
index 052992f..30c2ab8 100644 (file)
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
+// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -199,8 +202,54 @@ func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
   return %0 : tensor<?x?x1x?xf32>
 }
 
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL:      func.func @softmax(
+//CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFF800000 : f32
+// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
+// CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
+// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
+// CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:        return %[[D7]] : tensor<2x16x32xf32>
+// CHECK:      }
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
+
+  %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
 }