/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = this->getOperation()->getOperands();
- return {range.begin() + getNumInputsAndOutputBuffers(),
- range.begin() + getNumInputsAndOutputs()};
+ auto base = range.begin() + getNumInputsAndOutputBuffers();
+ return {base, base + $_op.getNumInitTensors()};
}]
>,
InterfaceMethod<
/// allow transformations like tiling to just use the values when cloning
/// `linalgOp`.
SmallVector<Value, 4> getAssumedNonShapedOperands() {
- unsigned numShapedOperands = getNumInputsAndOutputs();
+ unsigned numShapedOperands = getNumShapedOperands();
unsigned nExtraOperands =
getOperation()->getNumOperands() - numShapedOperands;
SmallVector<Value, 4> res;
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
};
} // namespace
+namespace {
+// Deduplicate redundant args of a linalg op.
+// An arg is redundant if it has the same Value and indexing map as another.
+struct DeduplicateInputs : public RewritePattern {
+ // Matches any op; the filtering to generic/indexed_generic ops is done at
+ // the top of matchAndRewrite.
+ DeduplicateInputs(PatternBenefit benefit = 1)
+ : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // This pattern reduces the number of arguments of an op, which breaks
+ // the invariants of semantically charged named ops.
+ if (!isa<GenericOp, IndexedGenericOp>(op))
+ return failure();
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Associate each input to an equivalent "canonical" input that has the same
+ // Value and indexing map.
+ //
+ // In the non-duplicate case, input `i` will have canonical input `i`. But
+ // in the case of duplicated inputs, the canonical input could be some other
+ // input `< i`. That is, a later input will have some earlier input as its
+ // canonical input.
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
+ // For later remapping tasks like deduplicating payload block arguments,
+ // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
+ // convenient.
+ SmallVector<int, 6> canonicalInputIndices;
+ for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) {
+ Value input = linalgOp.getInput(i);
+ AffineMap indexingMap = linalgOp.getInputIndexingMap(i);
+ // STL-like maps have a convenient behavior for our use case here. In the
+ // case of duplicate keys, the insertion is rejected, and the returned
+ // iterator gives access to the value already in the map.
+ auto pair = canonicalInput.insert({{input, indexingMap}, i});
+ canonicalInputIndices.push_back(pair.first->second);
+ }
+
+ // If there are no duplicate args, then bail out.
+ if (canonicalInput.size() == linalgOp.getNumInputs())
+ return failure();
+
+ // The operands for the newly canonicalized op.
+ // Only inputs are deduplicated; output buffers, init tensors, and any
+ // trailing non-shaped operands are appended unchanged.
+ SmallVector<Value, 6> newOperands;
+ for (auto v : llvm::enumerate(linalgOp.getInputs()))
+ if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
+ newOperands.push_back(v.value());
+ llvm::append_range(newOperands, linalgOp.getOutputBuffers());
+ llvm::append_range(newOperands, linalgOp.getInitTensors());
+ llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());
+
+ // Clone the old op with new operands.
+ Operation *newOp = linalgOp.clone(rewriter, op->getLoc(),
+ op->getResultTypes(), newOperands);
+ auto newLinalgOp = cast<LinalgOp>(newOp);
+
+ // Repair the indexing maps by filtering out the ones that have been
+ // eliminated.
+ // Note: at this point the clone still carries the original op's input
+ // count and indexing_maps attribute (setNumInputs only runs further down),
+ // so indexing with the original input positions here is valid.
+ SmallVector<AffineMap, 6> newIndexingMaps;
+ for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
+ if (canonicalInputIndices[i] == i)
+ newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
+ for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
+ newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
+ newOp->setAttr("indexing_maps",
+ rewriter.getAffineMapArrayAttr(newIndexingMaps));
+
+ // Set the number of inputs to the new value. The `clone` call above kept
+ // the value from the original op.
+ newLinalgOp.setNumInputs(canonicalInput.size());
+
+ // linalg.indexed_generic payloads have additional arguments prepended to
+ // the block arg list. The number of such args is one per dimension of the
+ // iteration space.
+ // NOTE(review): AffineMap::getNumInputs() counts dims + symbols; this
+ // assumes the maps are symbol-free so it equals the iteration-space rank
+ // — confirm.
+ int bbArgBaseOffset = 0;
+ if (isa<IndexedGenericOp>(op))
+ bbArgBaseOffset = newIndexingMaps[0].getNumInputs();
+
+ // Repair the payload entry block by RAUW'ing redundant arguments and
+ // erasing them.
+ Block &payload = newOp->getRegion(0).front();
+ for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
+ // Iterate in reverse, so that we erase later args first, preventing the
+ // argument list from shifting unexpectedly and invalidating all our
+ // indices.
+ int reversed = e - i - 1;
+ int canonicalIndex = canonicalInputIndices[reversed];
+ if (canonicalInputIndices[reversed] == reversed)
+ continue;
+ payload.getArgument(bbArgBaseOffset + reversed)
+ .replaceAllUsesWith(
+ payload.getArgument(bbArgBaseOffset + canonicalIndex));
+ payload.eraseArgument(bbArgBaseOffset + reversed);
+ }
+
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+};
+} // namespace
+
#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
MLIRContext *context) { \
results.insert<EraseDeadLinalgOp>(); \
results.insert<FoldTensorCastOp>(); \
+ results.insert<DeduplicateInputs>(); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// Test case: Most basic case. Adding a vector to itself.
+// Expected: the duplicate input (same value, same indexing map) is dropped,
+// leaving two indexing maps and a single payload block argument.
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @basic
+func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: linalg.generic{{.*}}[#[[$MAP]], #[[$MAP]]]
+ // CHECK: ^bb0(%[[BBARG:.*]]: f32):
+ // CHECK: addf %[[BBARG]], %[[BBARG]]
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %ar2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test case: Different indexing maps mean that args are not redundant, despite
+// being the same Value.
+// Expected: no rewrite; both operands and all three indexing maps remain.
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: @distinct_affine_maps
+func @distinct_affine_maps(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: linalg.generic{{.*}}[#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// Test case: Check rewriting mechanics for mixed redundant and
+// non-redundant args.
+// Expected: inputs 0 and 2 merge (same value and map); payload arg 2 is
+// replaced by arg 0 before being erased.
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: @mixed_redundant_non_redundant
+func @mixed_redundant_non_redundant(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: linalg.generic{{.*}}[#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
+ // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32):
+ // CHECK: "test.elementwise_mappable"(%[[BBARG0]], %[[BBARG1]], %[[BBARG0]])
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %1 = "test.elementwise_mappable"(%arg1, %arg2, %arg3) : (f32, f32, f32) -> f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// Test case: Check rewriting mechanics for multiple different redundant args.
+// Expected: the four inputs collapse to two; payload args 2 and 3 are RAUW'd
+// to args 0 and 1 respectively.
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @multiple_different_redundant_args
+func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: linalg.generic{{.*}}[#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+ // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32):
+ // CHECK: "test.elementwise_mappable"(%[[BBARG0]], %[[BBARG1]], %[[BBARG0]], %[[BBARG1]])
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg0, %arg1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32):
+ %1 = "test.elementwise_mappable"(%arg2, %arg3, %arg4, %arg5) : (f32, f32, f32, f32) -> f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test case: linalg.indexed_generic.
+// Other than the payload argument handling, everything else is the same.
+// Expected: the leading index block argument is preserved while the duplicate
+// f32 payload args merge into one.
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @indexed_generic
+func @indexed_generic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: linalg.indexed_generic
+ // CHECK: ^bb0(%{{.*}}: index, %[[BBARG:.*]]: f32):
+ // CHECK: addf %[[BBARG]], %[[BBARG]]
+ %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%index: index, %arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}