`shape.get_extent` now accepts extent tensors `tensor<?xindex>` as an argument.
Differential Revision: https://reviews.llvm.org/D84158
}
def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
- let summary = "Gets the specified extent from a shape";
+ let summary = "Gets the specified extent from a shape or extent tensor";
let description = [{
- Gets the extent indexed by `dim` from `shape`.
- If the shape is an error, it returns an error size.
+ Gets the extent indexed by `dim` from the `shape` operand. If the shape is
+ an error then it returns an error size.
}];
- let arguments = (ins
- Shape_ShapeType:$shape,
- Shape_SizeType:$dim
- );
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+ Shape_SizeType:$dim);
let results = (outs Shape_SizeType:$extent);
- let assemblyFormat = "$shape `,` $dim attr-dict";
+ let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
let builders = [
// Builder that allows passing a constant dimension as a simple integer.
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
// CHECK: return %[[RESULT]] : index
%shape = shape.shape_of %arg : tensor<2x3xf32>
- %result = shape.get_extent %shape, %idx
+ %result = shape.get_extent %shape, %idx : !shape.shape
return %result : !shape.size
}
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
// CHECK: return %[[RESULT]] : index
%shape = shape.from_extent_tensor %extents : tensor<?xindex>
- %result = shape.get_extent %shape, %idx
+ %result = shape.get_extent %shape, %idx : !shape.shape
return %result : !shape.size
}
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
- %0 = shape.const_shape [0, 1, 2] : !shape.shape
+ %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%c2 = shape.const_size 2
- %1 = shape.get_extent %0, %c2
+ %1 = shape.get_extent %0, %c2 : tensor<?xindex>
return %1 : !shape.size
}
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
- %0 = shape.const_shape [0, 1, 2] : !shape.shape
+ %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%c3 = shape.const_size 3
- %1 = shape.get_extent %0, %c3
+ %1 = shape.get_extent %0, %c3 : tensor<?xindex>
return %1 : !shape.size
}
// Should not fold.
// CHECK-LABEL: func @not_const
-func @not_const(%arg0: !shape.shape) -> !shape.size {
+func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
// CHECK: shape.get_extent
%c3 = shape.const_size 3
- %0 = shape.get_extent %arg0, %c3
+ %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
return %0 : !shape.size
}
%result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
return %result : i1
}
+
+func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
+ %c0 = shape.const_size 0
+ %result = shape.get_extent %arg, %c0 : !shape.shape
+ return %result : !shape.size
+}
+
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+ %c0 = shape.const_size 0
+ %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+ return %result : !shape.size
+}