}
};
+class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
+ using OpConversionPattern<GetExtentOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ GetExtentOp::Adaptor transformed(operands);
+
+ // Derive shape extent directly from shape origin if possible.
+ // This circumvents the necessity to materialize the shape in memory.
+ if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+ rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+ transformed.dim());
+ return success();
+ }
+
+ rewriter.replaceOpWithNewOp<ExtractElementOp>(
+ op, rewriter.getIndexType(), transformed.shape(),
+ ValueRange{transformed.dim()});
+ return success();
+ }
+};
+
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
ConstSizeOpConverter,
+ GetExtentOpConverter,
RankOpConverter,
ShapeOfOpConversion>(ctx);
// clang-format on
(Shape_SizeToIndexOp $arg),
(replaceWithValue $arg)>;
-// Derive shape extent directly from shape origin if possible.
-// This circumvents the necessity to materialize the shape in memory.
-def GetExtentShapeOfConversion : Pat<
- (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
- (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
- [],
- (addBenefit 10)>;
-def GetExtentFromExtentTensorConversion : Pattern<
- (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
- [
- (Shape_SizeToIndexOp:$std_idx $idx),
- (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
- (Shape_IndexToSizeOp $std_result)
- ],
- [],
- (addBenefit 10)>;
-