From 6983cf3a57aa6d8619eb39e1625eed5340ba05c7 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 31 Jul 2020 14:17:31 +0000 Subject: [PATCH] [MLIR][Shape] Allow unsafe `shape.broadcast` In a context in which `shape.broadcast` is known not to produce an error value, we want it to operate solely on extent tensors. The operation's behavior is then undefined in the error case as the result type cannot hold this value. Differential Revision: https://reviews.llvm.org/D84933 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 42 ++++++++++++++------------ mlir/test/Dialect/Shape/canonicalize.mlir | 25 +++++++++++++++ mlir/test/Dialect/Shape/invalid.mlir | 10 +++--- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 72e392b..bc7b6048 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -49,25 +49,24 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> { def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> { let summary = "Returns the broadcasted output shape of two inputs"; let description = [{ - Computes the broadcasted output shape following: - 1. If any inputs are unranked, output is unranked; - 2. Else the input array with number of dimensions smaller than the max - input dimension, has 1’s prepended to its shapes and the output shape is - calculated as follows: - - output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined - = rhs[i] if lhs[i] is unknown/undefined - = lhs[i] if rhs[i] == 1 - = rhs[i] if lhs[i] == 1 - = error if lhs[i] != rhs[i] - - Op has an optional string attribute for the error case where there is no - broadcastable output shape possible for the given inputs. - - Op may also return an ExtentTensor, but this should only be done when this - is statically guaranteed to never fail, either because of a dependency on a - cstr_broadcastable operation or other details of the construction of the - program. + Returns the broadcasted shape for two input shapes or extent tensors. Both + operands can be of type `shape.shape` or `tensor`. The result is of + type `shape.shape` and, if both operands are tensors, may be of type + `tensor`. + + If the two operand shapes are of different rank the smaller one is padded + with 1's from the left. The resulting broadcasted shape is then defined as + + result[i] = lhs[i] if lhs[i] == rhs[i] + = lhs[i] if rhs[i] == 1 + = rhs[i] if lhs[i] == 1. + + In case the resulting shape is undefined, i.e. if corresponding extents are + different from each other but none is 1, the result is an error shape. + Likewise error values are propagated if any of the operands holds an error + value. If the result type is an extent tensor (and can therefore not hold + the error value) the behavior may be undefined. The optional string + attribute can be used to describe the error case. }]; let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, @@ -75,8 +74,11 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> { OptionalAttr:$error); let results = (outs Shape_ShapeOrExtentTensorType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; + let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let hasFolder = 1; let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index e18ff14..21c5a68 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -60,6 +60,31 @@ func @f() -> !shape.shape { // ----- +// Basic case including extent tensors. +// CHECK-LABEL: @broadcast +func @broadcast() -> tensor { + // CHECK: shape.const_shape [7, 2] : tensor + %0 = shape.const_shape [1, 2] : tensor + %1 = shape.const_shape [7, 1] : tensor + %2 = shape.broadcast %0, %1 + : tensor, tensor -> tensor + return %2 : tensor +} + +// ----- + +// Basic case including extent tensors. +// CHECK-LABEL: @broadcast +func @broadcast() -> !shape.shape { + // CHECK: shape.const_shape [7, 2] : !shape.shape + %0 = shape.const_shape [1, 2] : tensor + %1 = shape.const_shape [7, 1] : tensor + %2 = shape.broadcast %0, %1 : tensor, tensor -> !shape.shape + return %2 : !shape.shape +} + +// ----- + // Rhs is a scalar. // CHECK-LABEL: func @f func @f(%arg0 : !shape.shape) -> !shape.shape { diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir index 448bd84..eb0ae5a 100644 --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -138,17 +138,19 @@ func @add(%lhs : !shape.size, %rhs : index) -> index { // ----- -func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor { +func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor { // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}} - %result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor + %result = shape.broadcast %arg0, %arg1 + : !shape.shape, !shape.shape -> tensor return %result : tensor } // ----- -func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor) -> tensor { +func @broadcast(%arg0 : !shape.shape, %arg1 : tensor) -> tensor { // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}} - %result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor -> tensor + %result = shape.broadcast %arg0, %arg1 + : !shape.shape, tensor -> tensor return %result : tensor } -- 2.7.4