let hasCanonicalizer = 1;
}
-def Shape_JoinOp : Shape_Op<"join",
+def Shape_MaxOp : Shape_Op<"max",
+ [Commutative, NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Elementwise maximum";
+ let description = [{
+ Computes the elementwise maximum of two sizes or shapes with equal ranks.
+ If either operand is an error, then an error will be propagated to the
+ result. If the input types mismatch or the ranks do not match, then the
+ result is an error.
+ }];
+
+ let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
+ let results = (outs Shape_ShapeOrSizeType:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+ }];
+
+ let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns when two result types are compatible for this op; method used by
+ // InferTypeOpInterface
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
+}
+
+def Shape_MeetOp : Shape_Op<"meet",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands";
let description = [{
contradictory requirements. E.g., using pseudo code
```
- shape.join([*], [*]) -> [*]
- shape.join([*], [1, ?]) -> [1, ?]
- shape.join([1, 2], [1, ?]) -> [1, 2]
- shape.join([*], [1, 2]) -> [1, 2]
- shape.join([], []) -> []
- shape.join([], [*]) -> []
- shape.join([], [?, ?]) -> [invalid]
- shape.join([1, ?], [2, ?, ?]) -> [invalid]
+ shape.meet([*], [*]) -> [*]
+ shape.meet([*], [1, ?]) -> [1, ?]
+ shape.meet([1, 2], [1, ?]) -> [1, 2]
+ shape.meet([*], [1, 2]) -> [1, 2]
+ shape.meet([], []) -> []
+ shape.meet([], [*]) -> []
+ shape.meet([], [?, ?]) -> [invalid]
+ shape.meet([1, ?], [2, ?, ?]) -> [invalid]
```
- `shape.join` also allows specifying an optional error string, that may be
+ `shape.meet` also allows specifying an optional error string, that may be
used to return an error to the user upon mismatch of dimensions.
```mlir
- %c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
+ %c = shape.meet %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];
}];
}
-def Shape_MaxOp : Shape_Op<"max",
- [Commutative, NoSideEffect,
- DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
- let summary = "Elementwise maximum";
- let description = [{
- Computes the elementwise maximum of two sizes or shapes with equal ranks.
- If either operand is an error, then an error will be propagated to the
- result. If the input types mismatch or the ranks do not match, then the
- result is an error.
- }];
-
- let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
- let results = (outs Shape_ShapeOrSizeType:$result);
-
- let assemblyFormat = [{
- $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
- }];
-
- let hasFolder = 1;
-
- let extraClassDeclaration = [{
- // Returns when two result types are compatible for this op; method used by
- // InferTypeOpInterface
- static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
- }];
-}
-
def Shape_MinOp : Shape_Op<"min",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
- %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_unknown() {
%0 = shape.const_shape [4, -1, 92] : !shape.shape
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
- %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed_mismatch() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [2, 57, 92] : !shape.shape
- %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
- %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
+ %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
return %2 : !shape.shape
}
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
- %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
+ %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
- %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
+ %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 5
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
- %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
+ %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}
func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 9
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
- %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
+ %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}