From 3126004a5a8bef0ac079869626b322c2fdbbd655 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 2 Dec 2019 09:56:58 -0800 Subject: [PATCH] [VectorOps] Add legality rules to broadcast PiperOrigin-RevId: 283360101 --- .../mlir/Dialect/VectorOps/VectorOps.td | 19 ++++++++++++++++++- mlir/lib/Dialect/VectorOps/VectorOps.cpp | 13 ++++++++++--- mlir/test/Dialect/VectorOps/invalid.mlir | 14 ++++++++++++++ mlir/test/Dialect/VectorOps/ops.mlir | 8 ++++++-- 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index c78334dd54ac..c75f9fe02312 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -171,7 +171,24 @@ def Vector_BroadcastOp : let summary = "broadcast operation"; let description = [{ Broadcasts the scalar or k-D vector value in the source operand - to a n-D result vector such that the broadcast makes sense. + to a n-D result vector such that the broadcast makes sense, i.e., + the source operand is duplicated to match the given rank and sizes + in the result vector. The legality rules are: + * the source operand must have the same element type as the result type + * a k-D vector <s_1 x .. x s_k x type> can be broadcast to + a n-D vector <t_1 x .. x t_n x type> if + * k <= n, and + * the sizes in the trailing dimensions n-k < i <= n with j=i+k-n + match exactly as s_j = t_i or s_j = 1: + ``` + t_1 x .. t_n-k x t_n-k+1 x .. x t_i x .. x t_n + s_1 x .. x s_j x .. x s_k + + ``` + The source operand is duplicated over all the missing leading dimensions + and stretched over the trailing dimensions where the source has a non-equal + dimension of 1. These rules imply that any scalar broadcast (k=0) to any + shaped vector with the same element type is always legal. 
Examples: ``` diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index fe320b91439b..6086531e3c70 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -386,10 +386,17 @@ static LogicalResult verify(BroadcastOp op) { if (srcVectorType) { const int64_t srcRank = srcVectorType.getRank(); const int64_t dstRank = dstVectorType.getRank(); - // TODO(ajcbik): implement proper rank testing for broadcast; - // this is just a temporary placeholder check. - if (srcRank > dstRank) { + if (srcRank > dstRank) return op.emitOpError("source rank higher than destination rank"); + // Source has an exact match or singleton value for all trailing dimensions + // (all leading dimensions are simply duplicated). + const int64_t lead = dstRank - srcRank; + for (int64_t i = 0; i < srcRank; i++) { + const int64_t srcDim = srcVectorType.getDimSize(i); + const int64_t dstDim = dstVectorType.getDimSize(lead + i); + if (srcDim != 1 && srcDim != dstDim) + return op.emitOpError("dimension mismatch (") + << srcDim << " vs. " << dstDim << ")"; } } return success(); diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index d672b1bf1406..0fbcb56f3882 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -9,6 +9,20 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) { // ----- +func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) { + // expected-error@+1 {{vector.broadcast' op dimension mismatch (7 vs. 3)}} + %1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32> +} + +// ----- + +func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) { + // expected-error@+1 {{vector.broadcast' op dimension mismatch (4 vs. 
1)}} + %1 = vector.broadcast %arg0 : vector<4x8xf32> to vector<1x8xf32> +} + +// ----- + func @extract_element_vector_type(%arg0: index) { // expected-error@+1 {{expected vector type}} %1 = vector.extractelement %arg0[] : index diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index d167559ac0c8..3824dfe20e44 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -23,12 +23,16 @@ func @vector_transfer_ops(%arg0: memref) { } // CHECK-LABEL: @vector_broadcast -func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> { +func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> { // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32> %0 = vector.broadcast %a : f32 to vector<16xf32> // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32> %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32> - return %1 : vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32> + %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32> + %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32> + return %3 : vector<8x16xf32> } // CHECK-LABEL: @extractelement -- 2.34.1