[mlir][linalg] Prepare drop unit dims for scalar operands.
authorTobias Gysi <gysit@google.com>
Fri, 11 Jun 2021 12:53:21 +0000 (12:53 +0000)
committerTobias Gysi <gysit@google.com>
Fri, 11 Jun 2021 13:18:06 +0000 (13:18 +0000)
Adapt drop unit dims for structured ops taking scalar operands.

Differential Revision: https://reviews.llvm.org/D103890

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

index fb59907..102dbdb 100644 (file)
@@ -249,7 +249,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
 };
 
 struct UnitExtentReplacementInfo {
-  RankedTensorType type;
+  Type type;
   AffineMap indexMap;
   ArrayAttr reassociation;
 };
@@ -271,10 +271,10 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
-  SmallVector<AffineExpr, 2> reassociations;
-  SmallVector<Attribute, 4> reassociationMaps;
-  SmallVector<AffineExpr, 4> newIndexExprs;
-  SmallVector<int64_t, 4> newShape;
+  SmallVector<AffineExpr> reassociations;
+  SmallVector<Attribute> reassociationMaps;
+  SmallVector<AffineExpr> newIndexExprs;
+  SmallVector<int64_t> newShape;
 
   int64_t origRank = genericOp.getRank(opOperand);
   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
   };
 
-  unsigned dim = 0;
+  int64_t dim = 0;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
     reassociations.push_back(getAffineDimExpr(dim++, context));
@@ -300,12 +300,16 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     reassociations.clear();
     ++dim;
   }
-  UnitExtentReplacementInfo info = {
-      RankedTensorType::get(newShape,
-                            getElementTypeOrSelf(opOperand->get().getType())),
-      AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
-                     newIndexExprs, context),
-      ArrayAttr::get(context, reassociationMaps)};
+  // Compute the tensor or scalar replacement type.
+  Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+  Type replacementType = elementType == opOperand->get().getType()
+                             ? elementType
+                             : RankedTensorType::get(newShape, elementType);
+  UnitExtentReplacementInfo info = {replacementType,
+                                    AffineMap::get(indexingMap.getNumDims(),
+                                                   indexingMap.getNumSymbols(),
+                                                   newIndexExprs, context),
+                                    ArrayAttr::get(context, reassociationMaps)};
   return info;
 }
 
@@ -331,13 +335,14 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
 
-    SmallVector<AffineMap, 4> newIndexingMaps;
-    SmallVector<ArrayAttr, 4> reassociationMaps;
-    SmallVector<ShapedType, 4> newInputOutputTypes;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<ArrayAttr> reassociationMaps;
+    SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
 
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+      UnitExtentReplacementInfo replacementInfo =
+          replaceUnitExtents(genericOp, opOperand, context);
       reassociationMaps.push_back(replacementInfo.reassociation);
       newIndexingMaps.push_back(replacementInfo.indexMap);
       newInputOutputTypes.push_back(replacementInfo.type);