Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
the proper position. Degenerates to an element type in the 0-D case.
- Example:
+ Examples:
+ ```
%1 = vector.extractelement %0[3]: vector<4x8x16xf32>
%2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
+ ```
}];
let extraClassDeclaration = [{
VectorType getVectorType() {
}
}];
}
+def OuterProductOp :
+ Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
+ Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>,
+ Results<(outs AnyVector)> {
+ let summary = "outerproduct operation";
+ let description = [{
+ Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
+ Example:
+ ```
+ %2 = vector.extractelement %0, %1: vector<4xf32>, vector<8xf32>
+ return %2: vector<4x8xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ VectorType getOperandVectorTypeLHS() {
+ return lhs()->getType().cast<VectorType>();
+ }
+ VectorType getOperandVectorTypeRHS() {
+ return rhs()->getType().cast<VectorType>();
+ }
+ VectorType getVectorType() {
+ return getResult()->getType().cast<VectorType>();
+ }
+ }];
+}
#endif // VECTOR_OPS
}
return success();
}
+//===----------------------------------------------------------------------===//
+// OuterProductOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, OuterProductOp op) {
+ *p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+ *p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
+}
+
+static ParseResult parseOuterProductOp(OpAsmParser *parser,
+ OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
+ Type t0, t1;
+ if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) ||
+ parser->parseComma() || parser->parseType(t1))
+ return failure();
+ VectorType v0 = t0.dyn_cast<VectorType>();
+ VectorType v1 = t1.dyn_cast<VectorType>();
+ if (!v0 || !v1)
+ return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
+ VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)},
+ v0.getElementType());
+ return failure(parser->resolveOperands(operandsInfo, {t0, t1},
+ parser->getCurrentLocation(),
+ result->operands) ||
+ parser->addTypeToList(resType, result->types));
+}
+static LogicalResult verify(OuterProductOp op) {
+ VectorType v1 = op.getOperandVectorTypeLHS(),
+ v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType();
+ if (v1.getRank() != 1)
+ return op.emitOpError("expected 1-d vector for operand #1");
+ if (v2.getRank() != 1)
+ return op.emitOpError("expected 1-d vector for operand #2");
+ if (res.getRank() != 2)
+ return op.emitOpError("expected 2-d vector result");
+ if (v1.getDimSize(0) != res.getDimSize(0))
+ return op.emitOpError(
+ "expected first operand dim to match first result dim");
+ if (v2.getDimSize(0) != res.getDimSize(1))
+ return op.emitOpError(
+ "expected second operand dim to match second result dim");
+ return success();
+}
//===----------------------------------------------------------------------===//
// VectorTransferReadOp
//===----------------------------------------------------------------------===//
func @position_empty(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected non-empty position attribute}}
%1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
- return
}
// -----
func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank smaller than vector}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
- return
}
// -----
func @position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
- return
}
// -----
func @position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
%1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
- return
+}
+
+// -----
+
+// CHECK-LABEL: outerproduct_non_vector_operand
+func @outerproduct_non_vector_operand(%arg0: f32) {
+ // expected-error@+1 {{expected 2 vector types}}
+ %1 = vector.outerproduct %arg0, %arg0 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: outerproduct_operand_1
+func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
+ // expected-error@+1 {{expected 1-d vector for operand #1}}
+ %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: outerproduct_operand_2
+func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
+ // expected-error@+1 {{expected 1-d vector for operand #2}}
+ %1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32>
}
// CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
%3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: outerproduct
+func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+ // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
+ %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+ return %0 : vector<4x8xf32>
+}