From 9be8219f60e1bbddaebdb271a55ecfb867078899 Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Mon, 3 Jul 2023 18:20:32 +0200 Subject: [PATCH] [mlir][Linalg] Add an interface to decompose complex ops This patch adds an interface, named AggregatedOpInterface, that decomposes complex operations into simpler ones. For now, make the interface specific to Linalg because although the concept is general, the way to materialize it needs some maturing. Use that interface with the softmax operator. Differential Revision: https://reviews.llvm.org/D154363 --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 30 ++++ mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 2 + .../Linalg/TransformOps/LinalgTransformOps.td | 27 ++++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 170 +++++++++++++++++++++ .../Linalg/TransformOps/LinalgTransformOps.cpp | 32 ++++ .../Dialect/Linalg/transform-op-decompose.mlir | 49 ++++++ 6 files changed, 310 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 1aba722..edc393b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -897,4 +897,34 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let verifyWithRegions = 1; } +def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> { + let description = [{ + Interface for decomposing aggregated operations into a sequence of simpler + ops. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Method to decompose the operation into simpler operations. + + On success, this method returns one `Value` per result in the + original operation. + The order of the returned values must match the order of the + original values. + In other words, the returned vector can be used directly with + `RewriterBase::replaceOp(this, returnedValues)`. + }], + /*retType=*/"FailureOr>", + /*methodName=*/"decomposeOperation", + /*args=*/(ins + "OpBuilder &":$b), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return {}; + }] + > + ]; +} + #endif // LINALG_IR_LINALGINTERFACES diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index eb68890..fdb0430 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -14,6 +14,7 @@ #define LINALG_OPS include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -93,6 +94,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", [DestinationStyleOpInterface, PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods { + let description = [{ + TODO + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} +//===----------------------------------------------------------------------===// // RewriteInDestinationPassingStyleOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a40618f..464a30a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2323,6 +2323,176 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b, .reifyResultShapes(b, reifiedReturnShapes); } +// Helper functions for softmax decomposition. +// @{ + +// Helper function to produce the iterator types (reduction or parallel) and +// affine maps for the iterators used in the decomposition of softmax. +// This method creates: +// If allParallel == true: +// - iterator type: {parallel, ..., parallel} +// - affine maps: +// -- identity with inputRank dimensions. +// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), +// where N == inputRank. +// +// If allParallel == false: +// - iterator type at dim(i) == parallel for i != \p dim and +// dim(dim) == reduction. +// - affine map: +// -- identity with inputRank dimensions. +// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), +// where N == inputRank. +static std::tuple, SmallVector> +computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, + int64_t dim, bool allParallel = false) { + SmallVector iteratorTypes(inputRank, + utils::IteratorType::parallel); + if (!allParallel) + iteratorTypes[dim] = utils::IteratorType::reduction; + MLIRContext *ctxt = builder.getContext(); + auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt); + SmallVector affineExprs; + for (int i = 0; i < inputRank; i++) { + if (i != dim) + affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt)); + } + auto reductionMap = + AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt); + SmallVector indexingMaps{identityMap, reductionMap}; + return std::make_tuple(iteratorTypes, indexingMaps); +} + +// Helper function to produce a linalg.generic that computes a reduction on +// dimension \p dim with the operation type \p T. +template +static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, + int64_t dim) { + auto inputType = cast(input.getType()); + ArrayRef inputShape = inputType.getShape(); + int64_t inputRank = inputShape.size(); + auto [iteratorTypes, indexingMaps] = + computeIteratorTypesAndIndexingMaps(builder, inputRank, dim); + assert(indexingMaps.size() == 2 && + "We should have two maps: 1 for the input, 1 for the output"); + assert(indexingMaps[0].isIdentity() && "input map should be identity"); + + auto genericOp = builder.create( + loc, output.getType(), input, output, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value result = b.create(loc, args[0], args[1]); + b.create(loc, result); + }); + return genericOp.getResult(0); +} + +/// Produce a linalg generic that computes the second step of the softmax +/// decomposition: res = exp(input - max), where \p max is the max of \p input +/// on dimension \p dim. +static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, + Value max, Value output, int64_t dim) { + auto inputType = cast(input.getType()); + ArrayRef inputShape = inputType.getShape(); + int64_t inputRank = inputShape.size(); + auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( + builder, inputRank, dim, /*allParallel=*/true); + assert(indexingMaps.size() == 2 && "We should have one map for each input"); + assert(indexingMaps[0].isIdentity() && "input map should be identity"); + // Add the affine map for the output argument. + indexingMaps.push_back(indexingMaps[0]); + auto genericOp = builder.create( + loc, input.getType(), ValueRange{input, max}, output, indexingMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + Value diff = b.create(loc, args[0], args[1]); + Value result = b.create(loc, diff); + b.create(loc, result); + }); + return genericOp.getResult(0); +} + +/// Produce a linalg generic that computes the final step of the softmax +/// decomposition. +/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) { +/// yield n / d +/// } +static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, + Value denominator, Value output, int64_t dim) { + auto inputType = cast(numerator.getType()); + ArrayRef inputShape = inputType.getShape(); + int64_t inputRank = inputShape.size(); + auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( + builder, inputRank, dim, /*allParallel=*/true); + assert(indexingMaps.size() == 2 && + "We should have one map for each input (2)"); + assert(indexingMaps[0].isIdentity() && "Numerator map should be identity"); + // Add the affine map for the output tensor. + indexingMaps.push_back(indexingMaps[0]); + auto genericOp = builder.create( + loc, numerator.getType(), ValueRange{numerator, denominator}, output, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value result = b.create(loc, args[0], args[1]); + b.create(loc, result); + }); + return genericOp.getResult(0); +} +// @} End helper functions for softmax decomposition. + +/// Given an N-dimensional tensor x, this method converts +/// softmax(x) to the following sequence of operations: +/// +/// 1. Compute the max of x along dimension d. This results +/// in a N-1 dimensional tensor m. +/// m = max(x, dim = d) +/// +/// 2. Subtract a broadcasted m from x and exponentiate. This results in +/// a N dimensional tensor z. +/// z = exp(x - m) +/// +/// 3. Compute the sum of z along dimension d. This results in +/// a N-1 dimensional tensor l. +/// l = sum(z, dim = d) +/// +/// 4. Divide z and l. This gives the N-dimensional softmax. +/// softmax = z / l +/// +FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(*this); + Location loc = getLoc(); + Value input = getInput(); + ShapedType inputType = getInputOperandType(); + Type elementType = inputType.getElementType(); + int64_t reductionDim = getDimension(); + SmallVector dims = tensor::getMixedSizes(b, loc, input); + Value outputNd = b.create(loc, dims, elementType); + dims.erase(dims.begin() + reductionDim); + // Step 1: Compute max along dim. + Value output = b.create(loc, dims, elementType); + Value neutralForMaxF = + arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc); + Value neutralForMaxFInit = + b.create(loc, Value{neutralForMaxF}, output).result(); + Value max = + reduce(b, loc, input, neutralForMaxFInit, reductionDim); + + // Step 2: Subtract max from input and exponentiate. + Value numerator = + buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim); + + // Step 3: Compute sum along dim. + Value zero = + arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc); + Value zeroInit = b.create(loc, Value{zero}, output).result(); + Value denominator = + reduce(b, loc, numerator, zeroInit, reductionDim); + + // Step 4: Compute softmax. + Value result = + buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim); + return SmallVector{result}; +} + //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index a51050b..f6e0f27 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -336,6 +336,38 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter, } //===----------------------------------------------------------------------===// +// DecomposeInterfaceOp +//===----------------------------------------------------------------------===// + +// Decompose the target operation if it implements the AggregatedOpInterface. +// Push the decomposed operations (the ones that replaces the values produced by +// \p target) in the `results`. +DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + auto decomposableOp = dyn_cast(target); + if (!decomposableOp) { + failed(rewriter.notifyMatchFailure(target, + "payload is not a decomposable op")); + return emitDefaultSilenceableFailure(target); + } + + FailureOr> maybeNewResults = + decomposableOp.decomposeOperation(rewriter); + if (failed(maybeNewResults)) + return emitDefaultSilenceableFailure(target); + + rewriter.replaceOp(decomposableOp, *maybeNewResults); + for (Value val : *maybeNewResults) { + Operation *definition = val.getDefiningOp(); + if (definition) + results.push_back(definition); + } + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// // EliminateLinalgOpAnchoredEmptyTensorsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir index 052992f..30c2ab8 100644 --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -1,5 +1,8 @@ // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s +// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + // CHECK-LABEL: @conv_2d_nhwc_hwcf // CHECK-SAME: %[[ARG0:.+]]: tensor, // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32> @@ -199,8 +202,54 @@ func.func @pooling_nchw_max(%input: tensor, %filter: tensor<1x?xf32 return %0 : tensor } +func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> { + %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> + return %1 : tensor<2x16x32xf32> +} + +// CHECK-LABEL: func.func @softmax( +//CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> { +// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32> +// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32> +// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32 +// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32> +// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", +// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) { +// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32): +// CHECK: %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32 +// CHECK: linalg.yield %[[D8]] : f32 +// CHECK: } -> tensor<2x16xf32> +// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types = +// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x16x32xf32>) { +// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32): +// CHECK: %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32 +// CHECK: %[[D9:.+]] = math.exp %[[D8]] : f32 +// CHECK: linalg.yield %[[D9]] : f32 +// CHECK: } -> tensor<2x16x32xf32> +// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32> +// CHECK: %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", +// CHECK-SAME: "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) { +// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32): +// CHECK: %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32 +// CHECK: linalg.yield %[[D8]] : f32 +// CHECK: } -> tensor<2x16xf32> +// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types = +// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>) +// CHECK-SAME: outs(%[[D0]] : tensor<2x16x32xf32>) { +// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32): +// CHECK: %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32 +// CHECK: linalg.yield %[[D8]] : f32 +// CHECK: } -> tensor<2x16x32xf32> +// CHECK: return %[[D7]] : tensor<2x16x32xf32> +// CHECK: } + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op + + %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op } -- 2.7.4