let arguments = (ins IndexAttr:$value);
let results = (outs Shape_SizeType:$result);
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value">
+ ];
+
let assemblyFormat = "$value attr-dict";
let hasFolder = 1;
}
let assemblyFormat = "attr-dict $shape";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
// ConstSizeOp
//===----------------------------------------------------------------------===//
+/// Convenience builder: wraps the int64_t `value` in an IndexAttr and
+/// forwards to the generated attribute-taking builder.
+void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
+ int64_t value) {
+ build(builder, result, builder.getIndexAttr(value));
+}
+
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
void ConstSizeOp::getAsmResultNames(
return builder.getIndexAttr(rank);
}
+/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
+/// Constant folding fails in cases where only the rank is constant, not the
+/// shape itself.
+/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
+///
+/// Example:
+///
+/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
+/// %rank = shape.rank %shape
+///
+/// becomes
+///
+/// %rank = shape.const_size 3
+
+namespace {
+/// Rewrite pattern: replaces `shape.rank(shape.shape_of(%t))` with a
+/// `shape.const_size` when `%t` is a ranked tensor, since its rank is then
+/// known at compile time even if individual dimensions are dynamic.
+struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
+ using OpRewritePattern<RankOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(RankOp op,
+ PatternRewriter &rewriter) const override {
+ // Only fire when the shape operand comes directly from a `shape_of` op.
+ auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
+ if (!shapeOfOp)
+ return failure();
+ // The rank is only statically known for ranked tensor operands; bail out
+ // on unranked tensors (covered by the negative test case).
+ auto rankedTensorType =
+ shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+ if (!rankedTensorType)
+ return failure();
+ int64_t rank = rankedTensorType.getRank();
+ rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+ return success();
+ }
+};
+} // namespace
+
+/// Registers canonicalization patterns for `shape.rank` (enabled by
+/// `hasCanonicalizer = 1` in the op definition).
+void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context) {
+ patterns.insert<RankShapeOfCanonicalizationPattern>(context);
+}
+
//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//
%rank = shape.rank %shape
return %rank : !shape.size
}
+
+// -----
+
+// Canonicalize `rank` when the shape is derived from a statically ranked
+// tensor: the rank (3 for tensor<1x2x?xf32>) is known at compile time even
+// though one dimension is dynamic, so the whole chain folds to a const_size.
+// CHECK-LABEL: @canonicalize_rank
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
+// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+// CHECK-DAG: return %[[RESULT]] : !shape.size
+%shape = shape.shape_of %arg : tensor<1x2x?xf32>
+%rank = shape.rank %shape
+return %rank : !shape.size
+}
+
+// -----
+
+// Do not canonicalize `rank` when the shape is derived from an unranked
+// tensor: the rank is unknown at compile time, so both `shape_of` and `rank`
+// must survive canonicalization.
+// CHECK-LABEL: @dont_canonicalize_rank
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
+// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+// CHECK-DAG: return %[[SIZE]] : !shape.size
+%shape = shape.shape_of %arg : tensor<*xf32>
+%rank = shape.rank %shape
+return %rank : !shape.size
+}