broadcastable output shape possible for the given inputs.
}];
- let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs,
- OptionalAttr<StrAttr>:$error);
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+ Shape_ShapeOrExtentTensorType:$rhs,
+ OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeType:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict";
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
let hasFolder = 1;
}
If the shape represents an error, this op's behavior is undefined.
}];
- let arguments = (ins Shape_ShapeType:$input);
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
let results = (outs IndexTensor:$result);
- let assemblyFormat = "attr-dict $input `:` type($result)";
+ let assemblyFormat = "attr-dict $input `:` type($input) `->` type($result)";
let hasFolder = 1;
}
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
let results = (outs Shape_SizeOrIndexType:$result);
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &result, Value shape">,
+ ];
+
let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict";
let hasFolder = 1;
let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &result, Value arg">
+ ];
+
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
let summary = "Casts between index types of the shape and standard dialect";
let description = [{
- Converts a `shape.size` to a standard index.
- This operation and its inverse, `index_to_size`, facilitate index conversion
- between the standard and the shape dialect.
- The behavior is undefined for unknown and invalid arguments.
+ Converts a `shape.size` to a standard index. This operation and its
+ inverse, `index_to_size`, facilitate index conversion between the standard
+ and the shape dialect. The behavior is undefined for unknown and invalid
+ arguments.
}];
- let arguments = (ins Shape_SizeType:$arg);
+ let arguments = (ins Shape_SizeOrIndexType:$arg);
let results = (outs Index:$result);
- let assemblyFormat = "$arg attr-dict";
+ let assemblyFormat = "$arg attr-dict `:` type($arg)";
let hasFolder = 1;
let hasCanonicalizer = 1;
- `index` is in the range [-rank(operand),rank(operand)]
}];
- let arguments = (ins Shape_ShapeType:$operand, I32:$index);
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index);
let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
let hasFolder = 1;
}
// TODO: Move the code below and witnesses to a different file.
def Shape_AnyOp : Shape_Op<"any", [Commutative,
- NoSideEffect,
- SameOperandsAndResultType]> {
+ NoSideEffect]> {
let summary = "Return any combination of the input shapes";
let description = [{
This operation takes multiple input shapes or extent tensors and returns
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
let results = (outs Shape_ShapeOrExtentTensorType:$result);
- let assemblyFormat = "$inputs `:` type($result) attr-dict";
let hasFolder = 1;
}
return builder.getIndexAttr(product.getLimitedValue());
}
+// Infer the result type from the operand: an extent-tensor operand
+// (ShapedType) produces a standard `index`, a `!shape.shape` operand
+// produces `!shape.size`.
+void NumElementsOp::build(OpBuilder &builder, OperationState &result,
+                          Value shape) {
+  Type resultTy;
+  if (shape.getType().isa<ShapedType>())
+    resultTy = builder.getIndexType();
+  else
+    resultTy = SizeType::get(builder.getContext());
+  build(builder, result, resultTy, shape);
+}
+
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
return builder.getIndexTensorAttr(type.getShape());
}
+// Infer the result type from the operand: tensor operands yield a
+// dynamically sized extent tensor (tensor<?xindex>), all other operands
+// yield `!shape.shape`.
+void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
+  Type resultTy;
+  if (arg.getType().isa<ShapedType>())
+    resultTy = RankedTensorType::get({ShapedType::kDynamicSize},
+                                     builder.getIndexType());
+  else
+    resultTy = ShapeType::get(builder.getContext());
+  ShapeOfOp::build(builder, result, resultTy, arg);
+}
+
+namespace {
+// Canonicalization pattern: a `shape_of` whose operand is a tensor
+// (ShapedType) but whose result is NOT an extent tensor is rebuilt via the
+// one-argument builder, which infers the extent-tensor result type.
+// NOTE(review): the replacement op's result type differs from the original's
+// (!shape.shape vs. tensor<?xindex>); presumably users of the old result
+// tolerate the new type -- confirm against the dialect's conversion rules.
+struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only applies to tensor operands...
+    if (!op.arg().getType().isa<ShapedType>())
+      return failure();
+    // ...whose result type is not already an extent tensor.
+    if (op.getType().isa<ShapedType>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
+    return success();
+  }
+};
+} // namespace
+
+// Registers the ShapeOfWithTensor pattern for this op's canonicalization.
+void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context) {
+  patterns.insert<ShapeOfWithTensor>(context);
+}
+
//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
- // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
%shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
return
}
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
- // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
%shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
return
}
%b : tensor<?xindex>,
%c : tensor<?xindex>) -> tensor<?xindex> {
// CHECK: return %[[A]] : tensor<?xindex>
- %result = shape.any %a, %b, %c : tensor<?xindex>
+ %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
return %result : tensor<?xindex>
}
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
// `shape.any` with a single operand folds to that operand.
// CHECK: return %[[A]] : tensor<?xindex>
- %result = shape.any %a : tensor<?xindex>
+ %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
return %result : tensor<?xindex>
}
// CHECK: shape.const_shape [7, 2] : !shape.shape
%0 = shape.const_shape [1, 2] : !shape.shape
%1 = shape.const_shape [7, 1] : !shape.shape
- %2 = shape.broadcast %0, %1
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
return %2 : !shape.shape
}
func @f(%arg0 : !shape.shape) -> !shape.shape {
// Broadcast with an empty (scalar) rhs shape folds to the lhs operand.
// CHECK: return %arg0
%0 = shape.const_shape [] : !shape.shape
- %1 = shape.broadcast %arg0, %0
+ %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
return %1 : !shape.shape
}
func @f(%arg0 : !shape.shape) -> !shape.shape {
// Broadcast with an empty (scalar) lhs shape folds to the rhs operand.
// CHECK: return %arg0
%0 = shape.const_shape [] : !shape.shape
- %1 = shape.broadcast %0, %arg0
+ %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
return %1 : !shape.shape
}
// CHECK: return %[[CST]]
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
- %2 = shape.broadcast %0, %1
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
return %2 : !shape.shape
}
// CHECK: shape.broadcast
%0 = shape.const_shape [2] : !shape.shape
%1 = shape.const_shape [7] : !shape.shape
- %2 = shape.broadcast %0, %1
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
return %2 : !shape.shape
}
func @f() -> tensor<2xindex> {
// `to_extent_tensor` of a constant shape folds to a dense index constant.
// CHECK: constant dense<[0, 1]> : tensor<2xindex>
%cs = shape.const_shape [0, 1] : !shape.shape
- %0 = shape.to_extent_tensor %cs : tensor<2xindex>
+ %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
return %0 : tensor<2xindex>
}
// CHECK-NOT: shape.index_cast
%cs = shape.const_size 123
// CHECK: constant 123 : index
- %ci = shape.size_to_index %cs
+ %ci = shape.size_to_index %cs : !shape.size
return %ci : index
}
%cs0 = shape.index_to_size %ci0
// CHECK: %[[CI:.*]] = constant 123 : index
// CHECK-NEXT: return %[[CI]] : index
- %ci1 = shape.size_to_index %cs0
+ %ci1 = shape.size_to_index %cs0 : !shape.size
return %ci1 : index
}
// CHECK-LABEL: func @nonfoldable_size_to_index
func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
// A non-constant operand must not fold; the op survives canonicalization.
// CHECK: shape.size_to_index
- %ci = shape.size_to_index %cs
+ %ci = shape.size_to_index %cs : !shape.size
return %ci : index
}
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
// CHECK-NEXT: return %[[CS]]
%0 = shape.const_shape [2, 3, 4] : !shape.shape
- %1 = shape.any %0, %arg : !shape.shape
+ %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
return %1 : !shape.shape
}
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
// CHECK-NEXT: return %[[CS]] : tensor<?xindex>
%0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
- %1 = shape.any %0, %arg : tensor<?xindex>
+ %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
return %1 : tensor<?xindex>
}
// Folding of any with partially constant operands is not yet implemented.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
// `shape.any` over non-constant operands is not folded (see note above the
// label in the original test file).
- // CHECK-NEXT: %[[CS:.*]] = shape.any
+ // CHECK-NEXT: %[[CS:.*]] = "shape.any"
// CHECK-NEXT: return %[[CS]]
- %1 = shape.any %arg0, %arg1 : !shape.shape
+ %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
return %1 : !shape.shape
}
func @index_to_size_to_index(%index : index) -> index {
// Round-trip index -> size -> index folds away to the original index.
// CHECK: return %[[IDX]] : index
%size = shape.index_to_size %index
- %result = shape.size_to_index %size
+ %result = shape.size_to_index %size : !shape.size
return %result : index
}
// CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size
func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
// Round-trip size -> index -> size folds away to the original size.
// CHECK: return %[[SIZE]] : !shape.size
- %idx = shape.size_to_index %size
+ %idx = shape.size_to_index %size : !shape.size
%result = shape.index_to_size %idx
return %result : !shape.size
}
func @test_broadcast_fixed() {
// Exercises broadcast of two fixed constant shapes; no CHECK lines visible
// here -- presumably a parse/print round-trip test. TODO(review): confirm.
%0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
- %2 = shape.broadcast %0, %1
+ %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
%w3 = shape.const_witness false
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
shape.assuming %w4 -> !shape.shape {
- %2 = shape.any %0, %1 : !shape.shape
+ %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
shape.assuming_yield %2 : !shape.shape
}
return
}
func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
// Round-trips `to_extent_tensor` with a statically shaped result type.
- %0 = shape.to_extent_tensor %arg : tensor<3xindex>
+ %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
return %0 : tensor<3xindex>
}
func @any() {
// `shape.any` accepts both `!shape.shape` and extent-tensor operand types.
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : !shape.shape
- %2 = shape.any %0, %1 : !shape.shape
+ %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
%4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
- %5 = shape.any %3, %4 : tensor<?xindex>
+ %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
return
}