[mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts
authorBenjamin Kramer <benny.kra@googlemail.com>
Tue, 16 Feb 2021 18:08:34 +0000 (19:08 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Wed, 17 Feb 2021 10:44:52 +0000 (11:44 +0100)
This is still fairly tricky code, but I tried to untangle it a bit.

Differential Revision: https://reviews.llvm.org/D96800

mlir/include/mlir/Dialect/Traits.h
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Traits.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index aecceaa..c51cadf 100644 (file)
@@ -47,7 +47,7 @@ namespace util {
 bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                          SmallVectorImpl<int64_t> &resultShape);
 
-/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// Returns true if a broadcast between n shapes is guaranteed to be
 /// successful and not result in an error. False does not guarantee that the
 /// shapes are not broadcastable; it might guarantee that they are not
 /// broadcastable or it might mean that this function does not have enough
@@ -59,6 +59,7 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
 /// dimension, while this function will return false because it's possible for
 /// both shapes to have a dimension greater than 1 and different which would
 /// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
 bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                   ArrayRef<int64_t> shape2);
 
index 058c0c5..b1199fb 100644 (file)
@@ -490,38 +490,48 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
   patterns.insert<CstrBroadcastableEqOps>(context);
 }
 
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
-  // TODO: Add folding for the nary case
-  if (operands.size() != 2)
-    return nullptr;
+// Return true if there is exactly one attribute not representing a scalar
+// broadcast.
+static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
+  bool nonScalarSeen = false;
+  for (Attribute a : attributes) {
+    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
+      if (nonScalarSeen)
+        return false;
+      nonScalarSeen = true;
+    }
+  }
+  return true;
+}
 
-  // Both operands are not needed if one is a scalar.
-  if (operands[0] &&
-      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(getContext(), true);
-  if (operands[1] &&
-      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // No broadcasting is needed if all operands but one are scalar.
+  if (hasAtMostSingleNonScalar(operands))
     return BoolAttr::get(getContext(), true);
 
-  if (operands[0] && operands[1]) {
-    auto lhsShape = llvm::to_vector<6>(
-        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    auto rhsShape = llvm::to_vector<6>(
-        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    SmallVector<int64_t, 6> resultShape;
-    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return BoolAttr::get(getContext(), true);
-  }
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &operand : operands) {
+          if (!operand)
+            return false;
+          extents.push_back(llvm::to_vector<6>(
+              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
+    return BoolAttr::get(getContext(), true);
 
   // Lastly, see if folding can be completed based on what constraints are known
   // on the input shapes.
-  SmallVector<int64_t, 6> lhsShape, rhsShape;
-  if (failed(getShapeVec(shapes()[0], lhsShape)))
-    return nullptr;
-  if (failed(getShapeVec(shapes()[1], rhsShape)))
-    return nullptr;
-
-  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &shape : shapes()) {
+          extents.emplace_back();
+          if (failed(getShapeVec(shape, extents.back())))
+            return false;
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
     return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
index b7d1bc8..50f2036 100644 (file)
@@ -15,19 +15,45 @@ using namespace mlir;
 
 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                  ArrayRef<int64_t> shape2) {
-  // Two dimensions are compatible when
-  //   1. they are defined and equal, or
-  //   2. one of them is 1
-  return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
-                      [](auto dimensions) {
-                        auto dim1 = std::get<0>(dimensions);
-                        auto dim2 = std::get<1>(dimensions);
-                        if (dim1 == 1 || dim2 == 1)
-                          return true;
-                        if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
-                          return true;
-                        return false;
-                      });
+  SmallVector<SmallVector<int64_t, 6>, 2> extents;
+  extents.emplace_back(shape1.begin(), shape1.end());
+  extents.emplace_back(shape2.begin(), shape2.end());
+  return staticallyKnownBroadcastable(extents);
+}
+
+bool OpTrait::util::staticallyKnownBroadcastable(
+    ArrayRef<SmallVector<int64_t, 6>> shapes) {
+  assert(!shapes.empty() && "Expected at least one shape");
+  size_t maxRank = shapes[0].size();
+  for (size_t i = 1; i != shapes.size(); ++i)
+    maxRank = std::max(maxRank, shapes[i].size());
+
+  // We look backwards through every column of `shapes`.
+  for (size_t i = 0; i != maxRank; ++i) {
+    bool seenDynamic = false;
+    Optional<int64_t> nonOneDim;
+    for (ArrayRef<int64_t> extent : shapes) {
+      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
+
+      if (dim == 1)
+        continue;
+
+      // Dimensions are compatible when
+      //.  1. One is dynamic, the rest are 1
+      if (ShapedType::isDynamic(dim)) {
+        if (seenDynamic || nonOneDim)
+          return false;
+        seenDynamic = true;
+      }
+
+      //   2. All are 1 or a specific constant.
+      if (nonOneDim && dim != *nonOneDim)
+        return false;
+
+      nonOneDim = dim;
+    }
+  }
+  return true;
 }
 
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
index ba7e479..3a04c16 100644 (file)
@@ -601,6 +601,92 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
 }
 
 // -----
+// Fold ternary broadcastable
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, 1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Fold ternary broadcastable with dynamic ranks
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %cs1 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One scalar and one non-scalar and one unknown cannot be broadcasted at compile time
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One scalar and two unknowns cannot be broadcasted at compile time
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with scalars and a non-scalar can be constant folded
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One scalar and one non-scalar and one unknown cannot be folded.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %cs1 = shape.const_shape [2] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
 
 // Fold `rank` based on constant shape.
 // CHECK-LABEL: @fold_rank