}
/// Conditions for elementwise fusion of generic operations.
-static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
- OpOperand *consumerOpOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+ auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
+ auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+
+ // Check that both the producer and the consumer are generic ops.
+ if (!producer || !consumer)
+ return false;
+
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
return false;
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
- if (!consumer.isInputTensor(consumerOpOperand))
+ if (!consumer.isInputTensor(fusedOperand))
return false;
// Get the consumer index map. The number of results of the consumer index
// map must match the number of loops of the producer.
- AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
+ AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand);
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
return false;
- // Currently support only operations with single result.
- if (producer.getNumOutputs() != 1)
- return false;
-
// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
AffineMap producerResultIndexMap =
for (auto pair :
llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
Value operand = std::get<0>(pair);
- if (operand == consumerOpOperand->get())
+ if (operand == fusedOperand->get())
continue;
AffineMap operandMap = std::get<1>(pair);
addToCoveredDims(operandMap);
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
AffineMap consumerToProducerLoopsMap,
- OpOperand *consumerOpOperand,
- unsigned nloops) {
- auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
- auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
+ OpOperand *fusedOperand, unsigned nloops) {
+ auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
+ auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
}
}
// TODO: allow fusing the producer of an output operand.
- assert(consumer.isInputTensor(consumerOpOperand) &&
+ assert(consumer.isInputTensor(fusedOperand) &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
- consumerOpOperand->getOperandNumber())) // input assumption.
+ fusedOperand->getOperandNumber())) // input assumption.
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// Replacing consumerIdx requires getting the cloned, yielded, value from
producerBlock.getArguments().take_front(producer.getNumInputs()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
- // 4.b. Producer output operand/map that is fused needs to be mapped to the
- // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
- assert(producer->getNumResults() == 1 && "expected single result producer");
- if (producer.isInitTensor(producer.getOutputOperand(0))) {
- BlockArgument bbArg = producerBlock.getArguments()
- .drop_front(producer.getNumInputs())
- // TODO: bbArg index of
- .front();
- mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
- }
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
for (BlockArgument bbArg :
consumerBlock.getArguments()
.take_front(consumer.getNumInputs())
- .drop_front(consumerOpOperand->getOperandNumber() + 1))
+ .drop_front(fusedOperand->getOperandNumber() + 1))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
- // 6. All of consumer's output operands.
+
+ // 6. All of the producer's output operands.
+ for (BlockArgument bbArg :
+ producerBlock.getArguments().take_back(producer.getNumOutputs()))
+ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
+
+ // 7. All of consumer's output operands.
for (BlockArgument bbArg :
consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
- // 7. All of producer's output operands except the one fused.
- // TODO: allow fusion of multi-result producers.
- assert(producer->getNumResults() == 1 && "expected single result producer");
// 8. Clone all producer operations except for the yield and index operations
// to the fused operation.
}
// 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
// forward the yield operand.
- auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
- // TODO: allow fusion of multi-result producers.
- assert(producer->getNumResults() == 1 && "expected single result producer");
- unsigned producerResultNumber = 0;
+ auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
+ unsigned producerResultNumber =
+ fusedOperand->get().cast<OpResult>().getResultNumber();
Value replacement =
- mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
+ mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
+
// Sanity checks, if replacement is not already in the mapper then it must be
// produced outside.
- if (replacement == yieldOp.getOperand(producerResultNumber)) {
+ if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
if (auto bb = replacement.dyn_cast<BlockArgument>())
assert(bb.getOwner() != &producerBlock &&
"yielded block argument must have been mapped");
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
"yielded value must have been mapped");
}
- mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
+ mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
replacement);
// 10. Clone operations from the consumer to the fused op.
- for (auto &op : consumerBlock.getOperations())
+ for (auto &op : consumerBlock.without_terminator())
rewriter.clone(op, mapper);
+ // 11. Include the final yield, built from the remapped yield values of both
+ // the producer and the consumer.
+ auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
+ SmallVector<Value> fusedYieldValues;
+ fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
+ consumerYieldOp.getNumOperands());
+ for (auto producerYieldVal : producerYieldOp.getOperands())
+ fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal));
+ for (auto consumerYieldVal : consumerYieldOp.getOperands())
+ fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
+ rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
+
// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
"Ill-formed GenericOp region");
}
-static Optional<SmallVector<Value>>
-fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
- const ControlFusionFn &controlFn,
- PatternRewriter &rewriter) {
- auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
- if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
- !controlFn(producer->getResult(0), *consumerOpOperand))
- return llvm::None;
-
+FailureOr<Operation *>
+mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
+ OpOperand *fusedOperand) {
+ assert(areElementwiseOpsFusable(fusedOperand) &&
+ "expected elementwise operation pre-conditions to pass");
+ auto producerResult = fusedOperand->get().cast<OpResult>();
+ auto producer = cast<GenericOp>(producerResult.getOwner());
+ auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
- assert(consumer.isInputTensor(consumerOpOperand) &&
+ assert(consumer.isInputTensor(fusedOperand) &&
"expected producer of input operand");
// Compute the fused operands list and indexing maps.
- SmallVector<Value> fusedOperands;
+ SmallVector<Value> fusedInputOperands, fusedOutputOperands;
+ SmallVector<Type> fusedResultTypes;
SmallVector<AffineMap> fusedIndexMaps;
- fusedOperands.reserve(producer->getNumOperands() +
- consumer->getNumOperands());
- fusedIndexMaps.reserve(producer->getNumOperands() +
- consumer->getNumOperands());
+ fusedInputOperands.reserve(producer.getNumInputs() + consumer.getNumInputs());
+ fusedOutputOperands.reserve(producer.getNumOutputs() +
+ consumer.getNumOutputs());
+ fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
+ fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
+ consumer.getNumInputsAndOutputs());
// In the following, numbering matches that of `generateFusedTensorOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
SmallVector<OpOperand *>::iterator it =
- llvm::find(consumerInputs, consumerOpOperand);
+ llvm::find(consumerInputs, fusedOperand);
assert(it != consumerInputs.end() && "expected to find the consumer operand");
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
- fusedOperands.push_back(opOperand->get());
+ fusedInputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
}
// 4. Splice in producer's input operands/maps.
- assert(producer->getNumResults() == 1 && "expected single result producer");
AffineMap producerResultIndexMap =
- producer.getTiedIndexingMap(producer.getOutputOperand(0));
+ producer.getTiedIndexingMapForResult(producerResult);
for (OpOperand *opOperand : producer.getInputOperands()) {
- fusedOperands.push_back(opOperand->get());
+ fusedInputOperands.push_back(opOperand->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
opOperand, producerResultIndexMap,
- consumer.getTiedIndexingMap(consumerOpOperand));
- fusedIndexMaps.push_back(map);
- }
- // 4.b. Producer output operand/map that is fused needs to be passed if it is
- // an "initTensor" (i.e. its value is actually read).
- assert(producer->getNumResults() == 1 && "expected single result producer");
- if (producer.isInitTensor(producer.getOutputOperand(0))) {
- fusedOperands.push_back(producer.getOutputOperand(0)->get());
- // Compute indexing maps for the producer args in the fused operation.
- AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
- producer.getOutputOperand(0), producerResultIndexMap,
- consumer.getTiedIndexingMap(consumerOpOperand));
+ consumer.getTiedIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
}
// 5. Remaining consumer's input operands/maps (drop past index
// `consumerIdx`).
for (OpOperand *opOperand :
llvm::make_range(std::next(it), consumerInputs.end())) {
- fusedOperands.push_back(opOperand->get());
+ fusedInputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
}
- // 6. All of consumer's output operands (skip operands: added by the builder).
- for (OpOperand *opOperand : consumer.getOutputOperands())
+
+ // 6. Collect all of the producer outputs.
+ for (OpOperand *opOperand : producer.getOutputOperands()) {
+ fusedOutputOperands.push_back(opOperand->get());
+ AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+ opOperand, producerResultIndexMap,
+ consumer.getTiedIndexingMap(fusedOperand));
+ fusedIndexMaps.push_back(map);
+ fusedResultTypes.push_back(opOperand->get().getType());
+ }
+
+ // 7. All of consumer's output operands (skip operands: added by the builder).
+ for (OpOperand *opOperand : consumer.getOutputOperands()) {
+ fusedOutputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
- // 7. All of producer's output operands/maps except the one fused.
- // TODO: allow fusion of multi-result producers.
- assert(producer->getNumResults() == 1 && "expected single result producer");
+ fusedResultTypes.push_back(opOperand->get().getType());
+ }
// Generate the fused op.
- SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
auto fusedOp = rewriter.create<GenericOp>(
- consumer.getLoc(), consumer->getResultTypes(),
- /*inputs=*/fusedOperands,
- // TODO: handle outputs.
- consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumer.getLoc(), fusedResultTypes, fusedInputOperands,
+ fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.getIteratorTypes(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
// in the input, but going ahead here would result in verification errors.
// So cleanup and abort.
rewriter.eraseOp(fusedOp);
- return llvm::None;
+ return rewriter.notifyMatchFailure(
+ fusedOp, "fused op failed loop bound computation check");
}
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
- AffineMap consumerResultIndexMap =
- consumer.getTiedIndexingMap(consumerOpOperand);
+ AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand);
// tensor index -> producer loop
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
invProducerResultIndexMap.compose(consumerResultIndexMap);
generateFusedElementwiseOpRegion(rewriter, fusedOp,
- consumerToProducerLoopsMap,
- consumerOpOperand, consumer.getNumLoops());
- return SmallVector<Value>(fusedOp->getResults());
-}
-
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
- GenericOp producer, const ControlFusionFn &controlFn) {
- if (producer->getNumResults() != 1)
- return llvm::None;
-
- return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
- rewriter);
+ consumerToProducerLoopsMap, fusedOperand,
+ consumer.getNumLoops());
+ return fusedOp.getOperation();
}
namespace {
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- auto producer =
- dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
- if (!producer || !producer.hasTensorSemantics())
+ if (!areElementwiseOpsFusable(opOperand))
+ continue;
+ if (!controlFn(opOperand))
continue;
- Optional<SmallVector<Value>> fusedOpResults =
- fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
- if (fusedOpResults) {
- rewriter.replaceOp(genericOp, *fusedOpResults);
+
+ FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, opOperand);
+ if (succeeded(fusedOp)) {
+ auto replacements = fusedOp.getValue()->getResults().take_back(
+ genericOp.getNumResults());
+ rewriter.replaceOp(genericOp, replacements);
return success();
}
}
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
}));
+ // Set insertion point to the generic op.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(genericOp);
+
SmallVector<Value> expandedOpOperands;
expandedOpOperands.reserve(genericOp.getNumInputs());
for (OpOperand *opOperand : genericOp.getInputOperands()) {
SmallVector<Value> resultVals;
for (OpResult opResult : genericOp->getOpResults()) {
int64_t resultNumber = opResult.getResultNumber();
- if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
+ if (resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
genericOp.getTiedIndexingMap(
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
- (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
+ (!controlFoldingReshapes(opOperand)))
continue;
Optional<SmallVector<Value>> replacementValues =
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
- GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
- if (!producer || producer.getNumOutputs() != 1 ||
- !isFusableWithReshapeByDimExpansion(producer,
- producer.getOutputOperand(0)) ||
- !controlFoldingReshapes(producer->getResult(0),
- reshapeOp->getOpOperand(0)))
- return failure();
+ auto producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
+ if (!producerResult) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "source not produced by an operation");
+ }
+
+ auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+ if (!producer) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "producer not a generic op");
+ }
+
+ if (!isFusableWithReshapeByDimExpansion(
+ producer,
+ producer.getOutputOperand(producerResult.getResultNumber()))) {
+ return rewriter.notifyMatchFailure(
+ reshapeOp, "failed preconditions of fusion with producer generic op");
+ }
+
+ if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "fusion blocked by control function");
+ }
+
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
- producer, reshapeOp, producer.getOutputOperand(0), rewriter);
- if (!replacementValues)
- return failure();
- rewriter.replaceOp(reshapeOp, *replacementValues);
+ producer, reshapeOp,
+ producer.getOutputOperand(producerResult.getResultNumber()), rewriter);
+ if (!replacementValues) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "fusion by expansion failed");
+ }
+
+ // Find the replacement for the reshape op. Since the replacements have the
+ // same type as the returns of the original generic op, the consumer reshape
+ // op can be replaced by the source of the collapse_shape op that defines
+ // the replacement.
+ Value reshapeReplacement = (*replacementValues)
+ [reshapeOp.getSrc().cast<OpResult>().getResultNumber()];
+ if (auto collapseOp =
+ reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
+ reshapeReplacement = collapseOp.getSrc();
+ }
+ rewriter.replaceOp(reshapeOp, reshapeReplacement);
+ rewriter.replaceOp(producer, *replacementValues);
return success();
}
getCollapsableIterationSpaceDims(genericOp, opOperand,
reshapeOp.getReassociationIndices());
if (collapsableIterationDims.empty() ||
- !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
+ !controlFoldingReshapes(opOperand)) {
continue;
}
RewritePatternSet patterns(context);
// Add folding with reshape by expansion patterns.
- ControlFusionFn defaultControlFn = [](const OpResult &producer,
- const OpOperand &consumer) {
- return producer.hasOneUse();
+ ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
+ Operation *producer = fusedOperand->get().getDefiningOp();
+ return producer && producer->hasOneUse();
};
// Add elementwise op fusion patterns.
linalg.yield %r : f64
} -> tensor<1x8xf64>
- // CHECK-NEXT: %[[R:.*]] = linalg.generic
+ // CHECK-NEXT: %[[R:.*]]:2 = linalg.generic
// CHECK: bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
// CHECK-NEXT: %[[A:.*]] = func.call @compute1(%[[BBA]]) : (f64) -> f64
// CHECK-NEXT: %[[B:.*]] = func.call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
- // CHECK-NEXT: linalg.yield %[[B]] : i32
- // CHECK-NEXT: } -> tensor<1x8xi32>
+ // CHECK-NEXT: linalg.yield %[[A]], %[[B]] : f64, i32
+ // CHECK-NEXT: } -> (tensor<1x8xf64>, tensor<1x8xi32>)
%1 = linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
linalg.yield %r : i32
} -> tensor<1x8xi32>
- // CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
+ // CHECK-NEXT: return %[[R]]#1 : tensor<1x8xi32>
return %1 : tensor<1x8xi32>
}
// -----
-func.func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
+func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
%c1_i32 = arith.constant 1 : i32
%0 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>],
} -> tensor<5000xi32>
return %2 : tensor<5000xi32>
}
-// CHECK-LABEL: func @illegal_fusion(
-// CHECK: %[[PRODUCER:.+]] = linalg.generic
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[PRODUCER]]
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: func @fusion_different_axes(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5000xi64>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<5000xi32>
+// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [5000] : tensor<5000xi64>
+// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [5000] : tensor<5000xi32>
+// CHECK: %[[RESULT:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i64
+// CHECK-SAME: %[[B1:.+]]: i32
+// CHECK-DAG: %[[T0:.+]] = linalg.index 0
+// CHECK-DAG: %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64
+// CHECK-DAG: %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]]
+// CHECK: linalg.yield %[[CAST1]], %[[EXTRACT]]
+// CHECK: return %[[RESULT]]#1
// -----
%4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%5 = arith.addf %arg1, %arg2 : f32
- linalg.yield %5 : f32
+ linalg.yield %5 : f32
} -> tensor<?xf32>
return %4 : tensor<?xf32>
}
%7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
%8 = arith.divf %arg1, %arg2 : f32
- linalg.yield %8 : f32
+ linalg.yield %8 : f32
} -> tensor<?x?xf32>
return %7 : tensor<?x?xf32>
}
+
+// -----
+
+#map = affine_map<() -> ()>
+module {
+ func.func @fuse_multi_result_producer(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<f32> {
+ %0 = linalg.init_tensor [] : tensor<f32>
+ %1 = linalg.init_tensor [] : tensor<f32>
+ %2:2 = linalg.generic {
+ indexing_maps = [#map, #map, #map, #map, #map], iterator_types = []}
+ ins(%arg0, %arg1, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>) outs(%0, %1 : tensor<f32>, tensor<f32>) {
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, %arg9: f32):
+ %4 = arith.addf %arg5, %arg6 : f32
+ %5 = arith.addf %4, %arg7 : f32
+ linalg.yield %4, %5 : f32, f32
+ } -> (tensor<f32>, tensor<f32>)
+ %3 = linalg.generic {
+ indexing_maps = [#map, #map, #map], iterator_types = []}
+ ins(%2#1, %arg1 : tensor<f32>, tensor<f32>) outs(%arg4 : tensor<f32>) {
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+ %4 = arith.addf %arg5, %arg6 : f32
+ %5 = arith.addf %4, %arg6 : f32
+ linalg.yield %5 : f32
+ } -> tensor<f32>
+ return %3 : tensor<f32>
+ }
+}
+// CHECK-LABEL: func.func @fuse_multi_result_producer
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0
+// CHECK-SAME: %[[B0:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME: %[[B1:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[T0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CHECK-DAG: %[[T1:.+]] = arith.addf %[[T0]], %[[B1]]
+// CHECK-DAG: %[[T2:.+]] = arith.addf %[[T1]], %[[B1]]
+// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
+// CHECK: linalg.yield %[[T3]] : f32
+// CHECK: return %[[GENERIC]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @reshape_as_consumer_permutation_with_multiple_results
+ (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+ -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
+ %c:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d2, d0, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
+ outs(%a, %a : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+ ^bb0(%arg0 : f32, %arg1: f32, %s: f32, %t : f32):
+ %1 = arith.addf %arg0, %arg1 : f32
+ linalg.yield %1, %1 : f32, f32
+ } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+ %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]]
+ : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+ return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
+// CHECK: func @reshape_as_consumer_permutation_with_multiple_results
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5]{{\]}}
+// CHECK-DAG: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]{{\]}}
+// CHECK-DAG: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3, 4, 5]{{\]}}
+// CHECK-DAG: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4, 5]{{\]}}
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
+// CHECK-SAME: outs(%[[RESHAPE2]], %[[RESHAPE3]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1