[VectorOps] Add a BroadcastOp to the VectorOps dialect
authorAart Bik <ajcbik@google.com>
Tue, 26 Nov 2019 22:43:03 +0000 (14:43 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 26 Nov 2019 22:43:31 +0000 (14:43 -0800)
PiperOrigin-RevId: 282643305

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index a887a3e..34c2fa9 100644 (file)
@@ -162,6 +162,34 @@ def Vector_ContractionOp :
   }];
 }
 
+def Vector_BroadcastOp :
+  Vector_Op<"broadcast", [NoSideEffect,
+     PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>,
+     PredOpTrait<"dest operand and result have same type",
+                 TCresIsSameAsOpBase<0, 1>>]>,
+    Arguments<(ins AnyType:$source, AnyVector:$dest)>,
+    Results<(outs AnyVector:$vector)> {
+  let summary = "broadcast operation";
+  let description = [{
+    Broadcasts the scalar or k-D vector value in the source to the n-D
+    destination vector of a proper shape such that the broadcast makes sense.
+
+    Examples:
+    ```
+      %0 = constant 0.0 : f32
+      %1 = vector.broadcast %0, %x : f32 into vector<16xf32>
+      %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    Type getSourceType() { return source()->getType(); }
+    VectorType getDestVectorType() {
+      return dest()->getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [NoSideEffect,
      PredOpTrait<"operand and result have same element type",
index b73b771..d09fd0f 100644 (file)
@@ -369,6 +369,47 @@ static LogicalResult verify(ExtractElementOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, BroadcastOp op) {
+  p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
+  p << " : " << op.getSourceType();
+  p << " into " << op.getDestVectorType();
+}
+
+static LogicalResult verify(BroadcastOp op) {
+  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
+  VectorType dstVectorType = op.getDestVectorType();
+  // Scalar to vector broadcast is always valid. A vector
+  // to vector broadcast needs some additional checking.
+  if (srcVectorType) {
+    const int64_t srcRank = srcVectorType.getRank();
+    const int64_t dstRank = dstVectorType.getRank();
+    // TODO(ajcbik): implement proper rank testing for broadcast;
+    // this is just a temporary placeholder check.
+    if (srcRank > dstRank) {
+      return op.emitOpError("source rank higher than destination rank");
+    }
+  }
+  return success();
+}
+
+static ParseResult parseBroadcastOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::OperandType source, dest;
+  Type sourceType;
+  VectorType destType;
+  return failure(parser.parseOperand(source) || parser.parseComma() ||
+                 parser.parseOperand(dest) ||
+                 parser.parseColonType(sourceType) ||
+                 parser.parseKeywordType("into", destType) ||
+                 parser.resolveOperand(source, sourceType, result.operands) ||
+                 parser.resolveOperand(dest, destType, result.operands) ||
+                 parser.addTypeToList(destType, result.types));
+}
+
+//===----------------------------------------------------------------------===//
 // InsertElementOp
 //===----------------------------------------------------------------------===//
 
index 60d5774..92e956e 100644 (file)
@@ -2,6 +2,13 @@
 
 // -----
 
+func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) {
+  // expected-error@+1 {{source rank higher than destination rank}}
+  %2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32>
+}
+
+// -----
+
 func @extract_element_vector_type(%arg0: index) {
   // expected-error@+1 {{expected vector type}}
   %1 = vector.extractelement %arg0[] : index
index a2b1ac3..51dbc4f 100644 (file)
@@ -22,6 +22,15 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
   return
 }
 
+// CHECK-LABEL: @vector_broadcast
+func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) {
+  //      CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32>
+  %0 = vector.broadcast %a, %b : f32 into vector<16xf32>
+  //      CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32>
+  %1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32>
+  return
+}
+
 // CHECK-LABEL: @extractelement
 func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
   //      CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>