[mlir] Insert tensor.cast only when needed when folding tensor.cast into extract_slice.
authorAlexander Belyaev <pifon@google.com>
Mon, 27 Feb 2023 14:16:27 +0000 (15:16 +0100)
committerAlexander Belyaev <pifon@google.com>
Mon, 27 Feb 2023 14:18:01 +0000 (15:18 +0100)
Differential Revision: https://reviews.llvm.org/D144868

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

index b6359a9..b31ef3b 100644 (file)
@@ -1829,7 +1829,7 @@ public:
         }))
       return failure();
 
-    auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
+    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
     if (!castOp)
       return failure();
 
@@ -1837,17 +1837,20 @@ public:
       return failure();
 
     /// Deduce the type of the result to use for the canonicalized operation.
+    Location loc = sliceOp.getLoc();
+    auto sliceOpType = sliceOp.getType();
     RankedTensorType resultType =
         ExtractSliceOp::inferCanonicalRankReducedResultType(
-            sliceOp.getType().getRank(), sliceOp.getSourceType(),
+            sliceOpType.getRank(), sliceOp.getSourceType(),
             sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
             sliceOp.getMixedStrides());
-    Value newSlice = rewriter.create<ExtractSliceOp>(
-        sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
+    Value newResult = rewriter.create<ExtractSliceOp>(
+        loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
         sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
-                                                newSlice);
+    if (newResult.getType() != sliceOpType)
+      newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
+    rewriter.replaceOp(sliceOp, newResult);
     return success();
   }
 };