[mlir] Generalize broadcastable trait operands
authorJacques Pienaar <jpienaar@google.com>
Sat, 11 Jan 2020 17:42:18 +0000 (09:42 -0800)
committerJacques Pienaar <jpienaar@google.com>
Mon, 20 Jan 2020 21:02:14 +0000 (13:02 -0800)
Summary:
Generalize broadcastable trait to variadic operands. Update the
documentation that still talked about element type as part of
broadcastable trait (that bug was already fixed). Also rename
Broadcastable to ResultBroadcastableShape to be more explicit that the
trait affects the result shape (it is possible for op to allow
broadcastable operands but not have result shape that is broadcast
compatible with operands).

Doing some intermediate work to have getBroadcastedType take an optional
elementType as input and use that if specified, instead of the common
element type of type1 and type2 in this function.

Differential Revision: https://reviews.llvm.org/D72559

mlir/docs/Traits.md
mlir/include/mlir/Dialect/Traits.h
mlir/include/mlir/IR/OpBase.td
mlir/lib/Dialect/Traits.cpp
mlir/test/Dialect/traits.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/RewriterGen.cpp

index b233f9bef66dd7d4955f7f9f04f7ef7055f48785..63614f8faa0208935d339b3b47363855c139c3c1 100644 (file)
@@ -137,20 +137,20 @@ section goes as follows:
 
 ### Broadcastable
 
-*   `OpTrait::BroadcastableTwoOperandsOneResult` -- `Broadcastable`
+*   `OpTrait::ResultsBroadcastableShape` -- `ResultsBroadcastableShape`
 
-This trait provides the API for operations that are known to have
+This trait adds the property that the operation is known to have
 [broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-operand and result types. Specifically, starting from the most varying
-dimension, each dimension pair of the two operands' types should either be the
-same or one of them is one. Also, the result type should have the corresponding
+operands and its result types' shape is the broadcast compatible with the shape
+of the broadcasted operands. Specifically, starting from the most varying
+dimension, each dimension pair of the two operands' shapes should either be the
+same or one of them is one. Also, the result shape should have the corresponding
 dimension equal to the larger one, if known. Shapes are checked partially if
 ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>` and
 `tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
 broadcast-compatible.
 
-Ths trait assumes the op has two operands and one result, and it asserts if the
-pre-condition is not satisfied.
+This trait requires that the operands are either vector or tensor types.
 
 ### Commutative
 
index 87c8e662a65521eeb187d11d0f8df54016866114..a9c1b9a4f2f9500f77cc134ed6083930c9dfddfd 100644 (file)
@@ -51,23 +51,26 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
-Type getBroadcastedType(Type type1, Type type2);
+///
+/// elementType, if specified, will be used as the element type of the
+/// broadcasted result type. Otherwise it is required that the element type of
+/// type1 and type2 is the same and this element type will be used as the
+/// resultant element type.
+Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr);
+
 } // namespace util
 
-/// This class provides the API for ops that are known to have broadcast-
-/// compatible operand and result types. Specifically,  starting from the
-/// most varying dimension, each dimension pair of the two operands' types
-/// should either be the same or one of them is one. Also, the result type
-/// should have the corresponding dimension equal to the larger one, if known.
-/// Shapes are checked partially if ranks or dimensions are not known. For
-/// example, an op with tensor<? x 2 x f32> and tensor <2 x f32> as operand
-/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
-///
-/// Ths trait assumes the op has two operands and one result, and it asserts
-/// if the pre-condition is not satisfied.
+/// Trait for ops that are known to have broadcast compatible operands and
+/// result types. Specifically, starting from the most varying dimension, each
+/// dimension pair of the operands' shapes should either be the same or one
+/// of them is one. Also, the results's shapes should have the corresponding
+/// dimension equal to the larger one, if known. Shapes are checked partially if
+/// ranks or dimensions are not known. For example, an op with tensor<?x2xf32>
+/// and tensor<2xf32> as operand types and tensor<5x3x2xi16> as the result
+/// type has broadcast compatible operands ns result types.
 template <typename ConcreteType>
-class BroadcastableTwoOperandsOneResult
-    : public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
+class ResultsBroadcastableShape
+    : public TraitBase<ConcreteType, ResultsBroadcastableShape> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
     return impl::verifyCompatibleOperandBroadcast(op);
index 4420ffec1a8495641fd8d95c331ecc916c23df8a..775123362bbff46be9c5aabfa0ad6140b659af48 100644 (file)
@@ -1327,7 +1327,10 @@ class PredOpTrait<string descr, Pred pred> : OpTrait {
 }
 
 // Op supports operand broadcast behavior.
-def Broadcastable  : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
+def ResultsBroadcastableShape :
+  NativeOpTrait<"ResultsBroadcastableShape">;
+// TODO: Alias of the above, remove post integrate.
+def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative  : NativeOpTrait<"IsCommutative">;
 // Op behaves like a function.
index 11aea0936f20d0647e458f91d4c77f5b7c7f51d5..6de66132937f077d62ea2ce1acb35ca53d6bc717 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
@@ -80,25 +81,27 @@ static ArrayRef<int64_t> getShape(Type type) {
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
-Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
-  // Returns the scalar type out of the given type.
-  auto getScalarType = [](Type type) -> Type {
-    if (auto shapedType = type.dyn_cast<ShapedType>())
-      return shapedType.getElementType();
-    return type;
-  };
-
-  // Make sure underlying scalar type is the same.
-  auto scalarType = getScalarType(type1);
-  if (scalarType != getScalarType(type2))
-    return {};
+///
+/// elementType, if specified, will be used as the element type of the
+/// broadcasted result type. Otherwise it is required that the element type of
+/// type1 and type2 is the same and this element type will be used as the
+/// resultant element type.
+Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
+                                       Type elementType) {
+  // If the elementType is not specified, then the use the common element type
+  // of the inputs or fail if there is no common element type.
+  if (!elementType) {
+    elementType = getElementTypeOrSelf(type1);
+    if (elementType != getElementTypeOrSelf(type2))
+      return {};
+  }
 
   // If one of the types is unranked tensor, then the other type shouldn't be
   // vector and the result should have unranked tensor type.
   if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
     if (type1.isa<VectorType>() || type2.isa<VectorType>())
       return {};
-    return UnrankedTensorType::get(scalarType);
+    return UnrankedTensorType::get(elementType);
   }
 
   // Returns the type kind if the given type is a vector or ranked tensor type.
@@ -132,16 +135,18 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
 
   // Compose the final broadcasted type
   if (resultCompositeKind == StandardTypes::Vector)
-    return VectorType::get(resultShape, scalarType);
+    return VectorType::get(resultShape, elementType);
   if (resultCompositeKind == StandardTypes::RankedTensor)
-    return RankedTensorType::get(resultShape, scalarType);
-  return scalarType;
+    return RankedTensorType::get(resultShape, elementType);
+  return elementType;
 }
 
-/// Returns true if the given types has both vector types and tensor types.
-static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
-  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
-         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
+/// Returns a tuple corresponding to whether range has tensor or vector type.
+template <typename iterator_range>
+static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
+  return std::make_tuple(
+      llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
+      llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
 }
 
 static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
@@ -157,55 +162,57 @@ static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
   return true;
 }
 
-LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
-  assert(op->getNumOperands() == 2 &&
-         "only support broadcast check on two operands");
-  assert(op->getNumResults() == 1 &&
-         "only support broadcast check on one result");
-
-  auto type1 = op->getOperand(0).getType();
-  auto type2 = op->getOperand(1).getType();
-  auto retType = op->getResult(0).getType();
+static std::string getShapeString(ArrayRef<int64_t> shape) {
+  // TODO: should replace with printing shape more uniformly across here and
+  // when in type.
+  return formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()));
+}
 
-  // We forbid broadcasting vector and tensor.
-  if (hasBothVectorAndTensorType({type1, type2, retType}))
+LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
+  // Ensure broadcasting only tensor or only vector types.
+  auto operandsHasTensorVectorType =
+      hasTensorOrVectorType(op->getOperandTypes());
+  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
+  if ((std::get<0>(operandsHasTensorVectorType) ||
+       std::get<0>(resultsHasTensorVectorType)) &&
+      (std::get<1>(operandsHasTensorVectorType) ||
+       std::get<1>(resultsHasTensorVectorType)))
     return op->emitError("cannot broadcast vector with tensor");
 
-  if (retType.isa<UnrankedTensorType>())
-    return success();
-
-  bool isUnranked1 = type1.isa<UnrankedTensorType>();
-  bool isUnranked2 = type2.isa<UnrankedTensorType>();
+  auto rankedOperands = make_filter_range(
+      op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
 
-  // If both operands are unranked, then all result shapes are possible.
-  if (isUnranked1 && isUnranked2)
+  // If all operands are unranked, then all result shapes are possible.
+  if (rankedOperands.empty())
     return success();
 
-  // If one of the operands is unranked, then the known dimensions in the result
-  // should be compatible with the other shaped operand.
-  if (isUnranked1 || isUnranked2) {
-    // Result should have higher rank than the shaped operand's rank and then
-    // the result's trailing dimensions should be compatible with the operand
-    // shape.
-    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
-    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
-    if (!areCompatibleShapes(actualSuffix, shape))
-      return op->emitOpError()
-             << "result type " << retType
-             << " has shape incompatible with a ranked operand type";
-    return success();
+  // Compute broadcasted shape of operands (which requires that operands are
+  // broadcast compatible). The results need to be broadcast compatible with
+  // this result shape.
+  SmallVector<int64_t, 4> resultShape;
+  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
+                                  resultShape);
+  for (auto other : make_early_inc_range(rankedOperands)) {
+    SmallVector<int64_t, 4> temp = resultShape;
+    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
+      return op->emitOpError("operands don't have broadcast-compatible shapes");
   }
 
-  // If both operands are shaped, then the computed broadcasted shape should be
-  // compatible with the result shape.
-  SmallVector<int64_t, 4> resultShape;
-  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
-    return op->emitOpError("operands don't have broadcast-compatible shapes");
+  auto rankedResults = make_filter_range(
+      op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
 
-  if (!areCompatibleShapes(resultShape, getShape(retType)))
-    return op->emitOpError() << "result type " << retType
-                             << " does not have shape compatible with the one "
-                                "computed from the operand types";
+  // If all of the results are unranked then no further verfication.
+  if (rankedResults.empty())
+    return success();
 
+  for (auto type : rankedResults) {
+    ArrayRef<int64_t> actualSuffix =
+        getShape(type).take_back(resultShape.size());
+    if (!areCompatibleShapes(actualSuffix, resultShape))
+      return op->emitOpError()
+             << "result type " << getShapeString(getShape(type))
+             << " not broadcast compatible with broadcasted operands's shapes "
+             << getShapeString(resultShape);
+  }
   return success();
 }
index 6ef10407768ff71808a96355fec0c86dce391fd9..aaea63d143617656a586c7d7c03dfbe6cb2d21b2 100644 (file)
@@ -78,7 +78,7 @@ func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x3xi32>) -> tens
 // Check incompatible result type with known dimension
 func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
-  // expected-error @+1 {{does not have shape compatible with the one computed}}
+  // expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands's shapes '4x3x2'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
   return %0 : tensor<4x3x3xi32>
 }
@@ -88,7 +88,7 @@ func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tens
 // Check incompatible result type with known dimension
 func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
 ^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
-  // expected-error @+1 {{does not have shape compatible with the one computed}}
+  // expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands's shapes '8x7x6x5'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
   return %0 : tensor<8x7x6x1xi32>
 }
@@ -123,7 +123,7 @@ func @broadcast_tensor_tensor_tensor(tensor<*xi32>, tensor<*xi32>) -> tensor<2xi
 // Unranked operand and compatible ranked result
 func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
 ^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
-  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
+  %0 = "test.broadcastable"(%arg0, %arg0, %arg1) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
   return %0 : tensor<4x3x2xi32>
 }
 
@@ -131,7 +131,7 @@ func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4
 
 func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32> {
 ^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
-  // expected-error @+1 {{shape incompatible with a ranked operand type}}
+  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '3x2'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }
index 546d21dc699875b04e62f40e7143f305feeb7135..ea729fec8c633e8fd66c79954618eacb9db40650 100644 (file)
@@ -376,8 +376,8 @@ def IfFirstOperandIsNoneThenSoIsSecond :
   let arguments = (ins AnyType:$x, AnyType:$y);
 }
 
-def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
-  let arguments = (ins AnyTensor, AnyTensor);
+def BroadcastableOp : TEST_Op<"broadcastable", [ResultsBroadcastableShape]> {
+  let arguments = (ins Variadic<AnyTensor>);
   let results = (outs AnyTensor);
 }
 
index 5c5b04b0c1b359e606fdf2383de41ac85294ecbf..0c058ffbb90099d32555d3cc101e16f848c0ac15 100644 (file)
@@ -781,11 +781,17 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     return resultValue;
   }
 
-  bool isBroadcastable =
-      resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
+  // TODO: Remove once broadcastable has been updated. This query here is not
+  // really about broadcastable or not, it is about which build method to invoke
+  // and that requires knowledge of whether ODS generated a builder that need
+  // not take return types. That knowledge should be captured in one place
+  // rather than duplicated.
+  bool isResultsBroadcastableShape =
+      resultOp.getTrait("OpTrait::ResultsBroadcastableShape");
   bool usePartialResults = valuePackName != resultValue;
 
-  if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
+  if (isResultsBroadcastableShape || usePartialResults || depth > 0 ||
+      resultIndex < 0) {
     // For these cases (broadcastable ops, op results used both as auxiliary
     // values and replacement values, ops in nested patterns, auxiliary ops), we
     // still need to supply the result types when building the op. But because