}];
}
+def Vector_BroadcastOp :
+ Vector_Op<"broadcast", [NoSideEffect,
+ PredOpTrait<"source operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ PredOpTrait<"dest operand and result have same type",
+ TCresIsSameAsOpBase<0, 1>>]>,
+ Arguments<(ins AnyType:$source, AnyVector:$dest)>,
+ Results<(outs AnyVector:$vector)> {
+ let summary = "broadcast operation";
+ let description = [{
+ Broadcasts the scalar or k-D vector value in the source to the n-D
+ destination vector of a proper shape such that the broadcast makes sense.
+
+ Examples:
+ ```
+ %0 = constant 0.0 : f32
+ %1 = vector.broadcast %0, %x : f32 into vector<16xf32>
+ %2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ Type getSourceType() { return source()->getType(); }
+ VectorType getDestVectorType() {
+ return dest()->getType().cast<VectorType>();
+ }
+ }];
+}
+
def Vector_ExtractElementOp :
Vector_Op<"extractelement", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
}
//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, BroadcastOp op) {
+ p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
+ p << " : " << op.getSourceType();
+ p << " into " << op.getDestVectorType();
+}
+
+static LogicalResult verify(BroadcastOp op) {
+ VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
+ VectorType dstVectorType = op.getDestVectorType();
+ // Scalar to vector broadcast is always valid. A vector
+ // to vector broadcast needs some additional checking.
+ if (srcVectorType) {
+ const int64_t srcRank = srcVectorType.getRank();
+ const int64_t dstRank = dstVectorType.getRank();
+ // TODO(ajcbik): implement proper rank testing for broadcast;
+ // this is just a temporary placeholder check.
+ if (srcRank > dstRank) {
+ return op.emitOpError("source rank higher than destination rank");
+ }
+ }
+ return success();
+}
+
+static ParseResult parseBroadcastOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType source, dest;
+ Type sourceType;
+ VectorType destType;
+ return failure(parser.parseOperand(source) || parser.parseComma() ||
+ parser.parseOperand(dest) ||
+ parser.parseColonType(sourceType) ||
+ parser.parseKeywordType("into", destType) ||
+ parser.resolveOperand(source, sourceType, result.operands) ||
+ parser.resolveOperand(dest, destType, result.operands) ||
+ parser.addTypeToList(destType, result.types));
+}
+
+//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//
// -----
+func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) {
+ // expected-error@+1 {{source rank higher than destination rank}}
+ %2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32>
+}
+
+// -----
+
func @extract_element_vector_type(%arg0: index) {
// expected-error@+1 {{expected vector type}}
%1 = vector.extractelement %arg0[] : index
return
}
+// CHECK-LABEL: @vector_broadcast
+func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) {
+ // CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32>
+ %0 = vector.broadcast %a, %b : f32 into vector<16xf32>
+ // CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32>
+ %1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32>
+ return
+}
+
// CHECK-LABEL: @extractelement
func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
// CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>