[VectorOps] Refine BroadcastOp in VectorOps dialect
authorAart Bik <ajcbik@google.com>
Wed, 27 Nov 2019 03:52:02 +0000 (19:52 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 27 Nov 2019 03:52:38 +0000 (19:52 -0800)
Since second argument is always fully overwritten and
shape is define in "to" clause, it is not needed.
Also renamed "into" to "to" now that arg is dropped.

PiperOrigin-RevId: 282686475

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index 34c2fa97e5369c477c6f29ab3486bb66812bba87..c78334dd54acf64617d57ba2868dd270af824f72 100644 (file)
@@ -165,27 +165,25 @@ def Vector_ContractionOp :
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>,
-     PredOpTrait<"dest operand and result have same type",
-                 TCresIsSameAsOpBase<0, 1>>]>,
-    Arguments<(ins AnyType:$source, AnyVector:$dest)>,
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyType:$source)>,
     Results<(outs AnyVector:$vector)> {
   let summary = "broadcast operation";
   let description = [{
-    Broadcasts the scalar or k-D vector value in the source to the n-D
-    destination vector of a proper shape such that the broadcast makes sense.
+    Broadcasts the scalar or k-D vector value in the source operand
+    to a n-D result vector such that the broadcast makes sense.
 
     Examples:
     ```
       %0 = constant 0.0 : f32
-      %1 = vector.broadcast %0, %x : f32 into vector<16xf32>
-      %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32>
+      %1 = vector.broadcast %0 : f32 to vector<16xf32>
+      %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
     ```
   }];
   let extraClassDeclaration = [{
     Type getSourceType() { return source()->getType(); }
-    VectorType getDestVectorType() {
-      return dest()->getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return vector()->getType().cast<VectorType>();
     }
   }];
 }
index d09fd0fc2f2d3673a52e0759a18ea264d9bc513c..fe320b91439b7b1feb4a90eb09ddd5464038671e 100644 (file)
@@ -373,14 +373,14 @@ static LogicalResult verify(ExtractElementOp op) {
 //===----------------------------------------------------------------------===//
 
 static void print(OpAsmPrinter &p, BroadcastOp op) {
-  p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
+  p << op.getOperationName() << " " << *op.source();
   p << " : " << op.getSourceType();
-  p << " into " << op.getDestVectorType();
+  p << " to " << op.getVectorType();
 }
 
 static LogicalResult verify(BroadcastOp op) {
   VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
-  VectorType dstVectorType = op.getDestVectorType();
+  VectorType dstVectorType = op.getVectorType();
   // Scalar to vector broadcast is always valid. A vector
   // to vector broadcast needs some additional checking.
   if (srcVectorType) {
@@ -397,16 +397,14 @@ static LogicalResult verify(BroadcastOp op) {
 
 static ParseResult parseBroadcastOp(OpAsmParser &parser,
                                     OperationState &result) {
-  OpAsmParser::OperandType source, dest;
+  OpAsmParser::OperandType source;
   Type sourceType;
-  VectorType destType;
-  return failure(parser.parseOperand(source) || parser.parseComma() ||
-                 parser.parseOperand(dest) ||
+  VectorType vectorType;
+  return failure(parser.parseOperand(source) ||
                  parser.parseColonType(sourceType) ||
-                 parser.parseKeywordType("into", destType) ||
+                 parser.parseKeywordType("to", vectorType) ||
                  parser.resolveOperand(source, sourceType, result.operands) ||
-                 parser.resolveOperand(dest, destType, result.operands) ||
-                 parser.addTypeToList(destType, result.types));
+                 parser.addTypeToList(vectorType, result.types));
 }
 
 //===----------------------------------------------------------------------===//
index 92e956ef29a13950cbbf75c27030aac6e4a27599..d672b1bf140634697b77cc50258f279f3120ccf5 100644 (file)
@@ -2,9 +2,9 @@
 
 // -----
 
-func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) {
+func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
   // expected-error@+1 {{source rank higher than destination rank}}
-  %2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32>
+  %1 = vector.broadcast %arg0 : vector<4x4xf32> to vector<4xf32>
 }
 
 // -----
index 51dbc4f043592ba5e17c4cc99aabf57edaca2fcf..d167559ac0c86ca3c1106ac20f878816b67db9f9 100644 (file)
@@ -23,12 +23,12 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
 }
 
 // CHECK-LABEL: @vector_broadcast
-func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) {
-  //      CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32>
-  %0 = vector.broadcast %a, %b : f32 into vector<16xf32>
-  //      CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32>
-  %1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32>
-  return
+func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> {
+  //      CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
+  %0 = vector.broadcast %a : f32 to vector<16xf32>
+  //      CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
+  %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
+  return %1 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @extractelement