From 9c27fa3821dc5c04f5710e64411815893de160ce Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 9 Jun 2021 06:31:53 +0000 Subject: [PATCH] [mlir][linalg] Prepare fusion on tensors for scalar operands. Adapt fusion on tensors to support structured ops taking scalar operands. Differential Revision: https://reviews.llvm.org/D103889 --- .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 29 ++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index d4dbb5a..f65a0fa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -701,24 +701,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, })); SmallVector expandedOpOperands; + expandedOpOperands.reserve(genericOp.getNumInputs()); for (OpOperand *opOperand : genericOp.getInputOperands()) { if (opOperand == fusableOpOperand) { expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() : collapsingReshapeOp.src()); continue; } - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); - RankedTensorType expandedOperandType = - getExpandedType(opOperand->get().getType().cast(), - indexingMap, expansionInfo); - if (expandedOperandType != opOperand->get().getType()) { - // Reshape the operand to get the right type. - SmallVector reassociation = - getReassociationForExpansion(indexingMap, expansionInfo); - expandedOpOperands.push_back(rewriter.create( - genericOp.getLoc(), expandedOperandType, opOperand->get(), - reassociation)); - continue; + if (genericOp.isInputTensor(opOperand)) { + AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + RankedTensorType expandedOperandType = + getExpandedType(opOperand->get().getType().cast(), + indexingMap, expansionInfo); + if (expandedOperandType != opOperand->get().getType()) { + // Reshape the operand to get the right type. + SmallVector reassociation = + getReassociationForExpansion(indexingMap, expansionInfo); + expandedOpOperands.push_back(rewriter.create( + genericOp.getLoc(), expandedOperandType, opOperand->get(), + reassociation)); + continue; + } } expandedOpOperands.push_back(opOperand->get()); } @@ -1035,7 +1038,7 @@ public: LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { TensorCollapseShapeOp reshapeOp = opOperand->get().getDefiningOp(); if (!reshapeOp) -- 2.7.4