def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
let summary = "Returns the broadcasted output shape of two inputs";
let description = [{
- Computes the broadcasted output shape following:
- 1. If any inputs are unranked, output is unranked;
- 2. Else the input array with number of dimensions smaller than the max
- input dimension, has 1’s prepended to its shapes and the output shape is
- calculated as follows:
-
- output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined
- = rhs[i] if lhs[i] is unknown/undefined
- = lhs[i] if rhs[i] == 1
- = rhs[i] if lhs[i] == 1
- = error if lhs[i] != rhs[i]
-
- Op has an optional string attribute for the error case where there is no
- broadcastable output shape possible for the given inputs.
-
- Op may also return an ExtentTensor, but this should only be done when this
- is statically guaranteed to never fail, either because of a dependency on a
- cstr_broadcastable operation or other details of the construction of the
- program.
+ Returns the broadcasted shape for two input shapes or extent tensors. Both
+ operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
+ type `shape.shape` and, if both operands are tensors, may be of type
+ `tensor<?xindex>`.
+
+ If the two operand shapes are of different rank the smaller one is padded
+ with 1's from the left. The resulting broadcasted shape is then defined as
+
+ result[i] = lhs[i] if lhs[i] == rhs[i]
+ = lhs[i] if rhs[i] == 1
+ = rhs[i] if lhs[i] == 1.
+
+ In case the resulting shape is undefined, i.e. if corresponding extents are
+ different from each other but none is 1, the result is an error shape.
+ Likewise error values are propagated if any of the operands holds an error
+ value. If the result type is an extent tensor (and can therefore not hold
+ the error value) the behavior may be undefined. The optional string
+ attribute can be used to describe the error case.
}];
let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrExtentTensorType:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)";
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+ }];
+ let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
let hasFolder = 1;
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
// -----
+// Basic case including extent tensors.
+// CHECK-LABEL: @broadcast
+func @broadcast() -> tensor<?xindex> {
+ // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
+ %0 = shape.const_shape [1, 2] : tensor<?xindex>
+ %1 = shape.const_shape [7, 1] : tensor<?xindex>
+ %2 = shape.broadcast %0, %1
+ : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+ return %2 : tensor<?xindex>
+}
+
+// -----
+
+// Basic case including extent tensors.
+// CHECK-LABEL: @broadcast
+func @broadcast() -> !shape.shape {
+ // CHECK: shape.const_shape [7, 2] : !shape.shape
+ %0 = shape.const_shape [1, 2] : tensor<?xindex>
+ %1 = shape.const_shape [7, 1] : tensor<?xindex>
+ %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+ return %2 : !shape.shape
+}
+
+// -----
+
// Rhs is a scalar.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
// -----
-func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
+func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
- %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex>
+ %result = shape.broadcast %arg0, %arg1
+ : !shape.shape, !shape.shape -> tensor<?xindex>
return %result : tensor<?xindex>
}
// -----
-func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
+func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
- %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex>
+ %result = shape.broadcast %arg0, %arg1
+ : !shape.shape, tensor<?xindex> -> tensor<?xindex>
return %result : tensor<?xindex>
}