[mlir][linalg] Prepare fusion on tensors for scalar operands.
authorTobias Gysi <gysit@google.com>
Wed, 9 Jun 2021 06:31:53 +0000 (06:31 +0000)
committerTobias Gysi <gysit@google.com>
Wed, 9 Jun 2021 07:09:46 +0000 (07:09 +0000)
Adapt fusion on tensors to support structured ops taking scalar operands.

Differential Revision: https://reviews.llvm.org/D103889

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

index d4dbb5a..f65a0fa 100644 (file)
@@ -701,24 +701,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
       }));
 
   SmallVector<Value> expandedOpOperands;
+  expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                                : collapsingReshapeOp.src());
       continue;
     }
-    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-    RankedTensorType expandedOperandType =
-        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                        indexingMap, expansionInfo);
-    if (expandedOperandType != opOperand->get().getType()) {
-      // Reshape the operand to get the right type.
-      SmallVector<ReassociationIndices> reassociation =
-          getReassociationForExpansion(indexingMap, expansionInfo);
-      expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
-          genericOp.getLoc(), expandedOperandType, opOperand->get(),
-          reassociation));
-      continue;
+    if (genericOp.isInputTensor(opOperand)) {
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      RankedTensorType expandedOperandType =
+          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
+                          indexingMap, expansionInfo);
+      if (expandedOperandType != opOperand->get().getType()) {
+        // Reshape the operand to get the right type.
+        SmallVector<ReassociationIndices> reassociation =
+            getReassociationForExpansion(indexingMap, expansionInfo);
+        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
+            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            reassociation));
+        continue;
+      }
     }
     expandedOpOperands.push_back(opOperand->get());
   }
@@ -1035,7 +1038,7 @@ public:
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
       TensorCollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
       if (!reshapeOp)