[mlir] Remove redundant shape.cstr_broadcastable canonicalization.
authorTres Popp <tpopp@google.com>
Tue, 15 Sep 2020 16:28:59 +0000 (18:28 +0200)
committerTres Popp <tpopp@google.com>
Thu, 17 Sep 2020 07:01:13 +0000 (09:01 +0200)
These canonicalizations are already handled by folding which will occur
in a superset of situations, so they are being removed.

Differential Revision: https://reviews.llvm.org/D87706

mlir/lib/Dialect/Shape/IR/Shape.cpp

index cd72287..3be53ee 100644 (file)
@@ -399,46 +399,6 @@ LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
     return failure();
   }
 }
-
-// For shapes that were created by some operations, we can obtain partial
-// information on the shapes and sometimes determine if they will be
-// broadcastable with that.
-struct CstrBroadcastablePartialInfo
-    : public OpRewritePattern<CstrBroadcastableOp> {
-  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<int64_t, 6> lhsShape, rhsShape;
-    if (failed(getShapeVec(op.lhs(), lhsShape)))
-      return failure();
-    if (failed(getShapeVec(op.rhs(), rhsShape)))
-      return failure();
-    if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
-    return success();
-  }
-};
-
-// Scalars are always broadcastable.
-struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
-  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<int64_t, 6> shape;
-    if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
-      return failure();
-    if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
-      return failure();
-
-    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
-    return success();
-  }
-};
-
 } // namespace
 
 void CstrBroadcastableOp::getCanonicalizationPatterns(
@@ -446,8 +406,7 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
   // Canonicalization patterns have overlap with the considerations during
   // folding in case additional shape information is inferred at some point that
   // does not result in folding.
-  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
-                  CstrBroadcastableScalar>(context);
+  patterns.insert<CstrBroadcastableEqOps>(context);
 }
 
 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {