[mlir][Vector] Add a Broadcast::createBroadcastOp helper
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 30 Nov 2022 12:20:18 +0000 (04:20 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 30 Nov 2022 13:32:14 +0000 (05:32 -0800)
This helper handles non trivial cases of broadcast + optional transpose creation
that should not leak to the outside world.

Differential Revision: https://reviews.llvm.org/D139003

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/test-create-broadcast.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 0158257..ff7b79b 100644 (file)
@@ -449,6 +449,23 @@ def Vector_BroadcastOp :
     /// Return the dimensions of the result vector that were formerly ones in the
     /// source tensor and thus correspond to "dim-1" broadcasting.
     llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+
+    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the 
+    /// `broadcastedDims` dimensions in the dstShape are broadcasted.
+    /// This requires (and asserts) that the broadcast is free of dim-1 
+    /// broadcasting.
+    /// Since vector.broadcast only allows expanding leading dimensions, an extra
+    /// vector.transpose may be inserted to make the broadcast possible.
+    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or 
+    /// the helper will assert. This means:
+    ///   1. `dstShape` must not be empty.
+    ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
+    ///   2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
+    //       must match the `value` shape.
+    static Value createOrFoldBroadcastOp(
+      OpBuilder &b, Value value,
+      ArrayRef<int64_t> dstShape,
+      const llvm::SetVector<int64_t> &broadcastedDims);
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
index c4af2d8..b36206c 100644 (file)
@@ -1725,13 +1725,9 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
 
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
-llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
-  VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
-  // Scalar broadcast is without any unit dim broadcast.
-  if (!srcVectorType)
-    return {};
-  ArrayRef<int64_t> srcShape = srcVectorType.getShape();
-  ArrayRef<int64_t> dstShape = getVectorType().getShape();
+static llvm::SetVector<int64_t>
+computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
+                           ArrayRef<int64_t> dstShape) {
   int64_t rankDiff = dstShape.size() - srcShape.size();
   int64_t dstDim = rankDiff;
   llvm::SetVector<int64_t> res;
@@ -1745,6 +1741,129 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
   return res;
 }
 
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+  // Scalar broadcast is without any unit dim broadcast.
+  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return {};
+  return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
+                                      getVectorType().getShape());
+}
+
+static bool allBitsSet(llvm::SmallBitVector &bv, int64_t lb, int64_t ub) {
+  for (int64_t i = lb; i < ub; ++i)
+    if (!bv.test(i))
+      return false;
+  return true;
+}
+
+/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
+/// `broadcastedDims` dimensions in the dstShape are broadcasted.
+/// This requires (and asserts) that the broadcast is free of dim-1
+/// broadcasting.
+/// Since vector.broadcast only allows expanding leading dimensions, an extra
+/// vector.transpose may be inserted to make the broadcast possible.
+/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
+/// the helper will assert. This means:
+///   1. `dstShape` must not be empty.
+///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
+///   2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
+//       must match the `value` shape.
+Value BroadcastOp::createOrFoldBroadcastOp(
+    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
+    const llvm::SetVector<int64_t> &broadcastedDims) {
+  assert(!dstShape.empty() && "unexpected empty dst shape");
+
+  // Well-formedness check.
+  SmallVector<int64_t> checkShape;
+  for (int i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i))
+      continue;
+    checkShape.push_back(dstShape[i]);
+  }
+  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
+         "ill-formed broadcastedDims contains values not confined to "
+         "destVectorShape");
+
+  Location loc = value.getLoc();
+  Type elementType = getElementTypeOrSelf(value.getType());
+  VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
+  VectorType dstVectorType = VectorType::get(dstShape, elementType);
+
+  // Step 2. If scalar -> dstShape broadcast, just do it.
+  if (!srcVectorType) {
+    assert(checkShape.empty() &&
+           "ill-formed createOrFoldBroadcastOp arguments");
+    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
+  }
+
+  assert(srcVectorType.getShape().equals(checkShape) &&
+         "ill-formed createOrFoldBroadcastOp arguments");
+
+  // Step 3. Since vector.broadcast only allows creating leading dims,
+  //   vector -> dstShape broadcast may require a transpose.
+  // Traverse the dims in order and construct:
+  //   1. The leading entries of the broadcastShape that is guaranteed to be
+  //      achievable by a simple broadcast.
+  //   2. The induced permutation for the subsequent vector.transpose that will
+  //      bring us from `broadcastShape` back to he desired `dstShape`.
+  // If the induced permutation is not the identity, create a vector.transpose.
+  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
+  broadcastShape.reserve(dstShape.size());
+  // Consider the example:
+  //   srcShape     = 2x4
+  //   dstShape     = 1x2x3x4x5
+  //   broadcastedDims = [0, 2, 4]
+  //
+  // We want to build:
+  //   broadcastShape  = 1x3x5x2x4
+  //   permutation     = [0, 2, 4,                 1, 3]
+  //                      ---V---           -----V-----
+  //            leading broadcast part      src shape part
+  //
+  // Note that the trailing dims of broadcastShape are exactly the srcShape
+  // by construction.
+  // nextSrcShapeDim is used to keep track of where in the permutation the
+  // "src shape part" occurs.
+  int64_t nextSrcShapeDim = broadcastedDims.size();
+  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
+    if (broadcastedDims.contains(i)) {
+      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
+      // bring it to the head of the broadcastShape.
+      // It will need to be permuted back from `broadcastShape.size() - 1` into
+      // position `i`.
+      broadcastShape.push_back(dstShape[i]);
+      permutation[i] = broadcastShape.size() - 1;
+    } else {
+      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
+      // shape and needs to be permuted into position `i`.
+      // Don't touch `broadcastShape` here, the whole srcShape will be
+      // appended after.
+      permutation[i] = nextSrcShapeDim++;
+    }
+  }
+  // 3.c. Append the srcShape.
+  llvm::append_range(broadcastShape, srcVectorType.getShape());
+
+  // Ensure there are no dim-1 broadcasts.
+  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
+             .empty() &&
+         "unexpected dim-1 broadcast");
+
+  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
+  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
+             vector::BroadcastableToResult::Success &&
+         "must be broadcastable");
+  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
+  // Step 4. If we find any dimension that indeed needs to be permuted,
+  // immediately return a new vector.transpose.
+  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
+    if (permutation[i] != i)
+      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
+  // Otherwise return res.
+  return res;
+}
+
 BroadcastableToResult
 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                 std::pair<int, int> *mismatchingDims) {
diff --git a/mlir/test/Dialect/Vector/test-create-broadcast.mlir b/mlir/test/Dialect/Vector/test-create-broadcast.mlir
new file mode 100644 (file)
index 0000000..f7af184
--- /dev/null
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s --test-create-vector-broadcast --allow-unregistered-dialect --split-input-file | FileCheck %s
+
+func.func @foo(%a : f32) -> vector<1x2xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0, 1>} : (f32) -> vector<1x2xf32>
+  // CHECK: vector.broadcast {{.*}} : f32 to vector<1x2xf32>
+  // CHECK-NOT: vector.transpose
+  return %0:  vector<1x2xf32>
+}
+
+// -----
+
+func.func @foo(%a : vector<2x2xf32>) -> vector<2x2x3xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 2>} 
+    : (vector<2x2xf32>) -> vector<2x2x3xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<2x2xf32> to vector<3x2x2xf32>
+  // CHECK: vector.transpose {{.*}}, [1, 2, 0] : vector<3x2x2xf32> to vector<2x2x3xf32>
+  return %0: vector<2x2x3xf32>
+}
+
+// -----
+
+func.func @foo(%a : vector<3x3xf32>) -> vector<4x3x3xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0>} 
+    : (vector<3x3xf32>) -> vector<4x3x3xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<3x3xf32> to vector<4x3x3xf32>
+  // CHECK-NOT: vector.transpose
+  return %0: vector<4x3x3xf32>
+}
+
+// -----
+
+func.func @foo(%a : vector<2x4xf32>) -> vector<1x2x3x4x5xf32> {
+  %0 = "test_create_broadcast"(%a) {broadcast_dims = array<i64: 0, 2, 4>} 
+    : (vector<2x4xf32>) -> vector<1x2x3x4x5xf32>
+  // CHECK: vector.broadcast {{.*}} : vector<2x4xf32> to vector<1x3x5x2x4xf32>
+  // CHECK: vector.transpose {{.*}}, [0, 3, 1, 4, 2] : vector<1x3x5x2x4xf32> to vector<1x2x3x4x5xf32>
+  return %0: vector<1x2x3x4x5xf32>
+}
index 1bd40e7..00bd07a 100644 (file)
@@ -820,6 +820,38 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
+struct TestCreateVectorBroadcast
+    : public PassWrapper<TestCreateVectorBroadcast,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
+
+  StringRef getArgument() const final { return "test-create-vector-broadcast"; }
+  StringRef getDescription() const final {
+    return "Test optimization transformations for transfer ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    getOperation()->walk([](Operation *op) {
+      if (op->getName().getStringRef() != "test_create_broadcast")
+        return;
+      auto targetShape =
+          op->getResult(0).getType().cast<VectorType>().getShape();
+      auto arrayAttr =
+          op->getAttr("broadcast_dims").cast<DenseI64ArrayAttr>().asArrayRef();
+      llvm::SetVector<int64_t> broadcastedDims;
+      broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
+      OpBuilder b(op);
+      Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
+          b, op->getOperand(0), targetShape, broadcastedDims);
+      op->getResult(0).replaceAllUsesWith(bcast);
+      op->erase();
+    });
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -856,6 +888,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorDistribution>();
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
+
+  PassRegistration<TestCreateVectorBroadcast>();
 }
 } // namespace test
 } // namespace mlir