(Shape_SizeToIndexOp $arg),
(replaceWithValue $arg)>;
+// Derive shape extent directly from shape origin if possible.
+// This circumvents the necessity to materialize the shape in memory.
+def GetExtentShapeOfConversion : Pat<
+ (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
+ (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
+ [],
+ (addBenefit 10)>;
%rank = shape.rank %shape
return %rank : !shape.size
}
+
+// -----
+
+// Express `get_extent` as `std.dim` when it relies directly on the outcome of a
+// `shape_of` operation.
+// CHECK-LABEL: @get_extent_shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
+func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
+ -> !shape.size {
+ // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
+ // CHECK: return %[[RESULT]] : index
+ %shape = shape.shape_of %arg : tensor<2x3xf32>
+ %result = shape.get_extent %shape, %idx
+ return %result : !shape.size
+}
+