Extend vector.outerproduct with an optional 3rd argument
authorNicolas Vasilache <ntv@google.com>
Fri, 16 Aug 2019 10:52:56 +0000 (03:52 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Aug 2019 10:53:26 +0000 (03:53 -0700)
This CL adds an optional third argument to the vector.outerproduct instruction.
When such a third argument is specified, it is added to the result of the outerproduct and  is lowered to FMA intrinsic when the lowering supports it.

In the future, we can add an attribute on the `vector.outerproduct` instruction to modify the operations for which to emit code (e.g. "+/*", "max/+", "min/+", "log/exp" ...).

This CL additionally performs minor cleanups in the vector lowering and adds tests to improve coverage.

This has been independently verified to result in proper fma instructions for haswell as follows.

Input:
```
func @outerproduct_add(%arg0: vector<17xf32>, %arg1: vector<8xf32>, %arg2: vector<17x8xf32>) -> vector<17x8xf32> {
  %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<17xf32>, vector<8xf32>
  return %2 : vector<17x8xf32>
}
}
```

Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```

Output:
```
outerproduct_add:                       # @outerproduct_add
# %bb.0:
        ...
        vmovaps 112(%rbp), %ymm8
        vbroadcastss    %xmm0, %ymm0
        ...
        vbroadcastss    64(%rbp), %ymm15
        vfmadd213ps     144(%rbp), %ymm8, %ymm0 # ymm0 = (ymm8 * ymm0) + mem
        ...
        vfmadd213ps     400(%rbp), %ymm8, %ymm9 # ymm9 = (ymm8 * ymm9) + mem
        ...
```
PiperOrigin-RevId: 263743359

mlir/include/mlir/Conversion/VectorToLLVM/VectorToLLVM.h
mlir/include/mlir/VectorOps/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
mlir/lib/VectorOps/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index 39b7ee2..7334c67 100644 (file)
 #define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
 
 namespace mlir {
+class LLVMTypeConverter;
 class ModulePassBase;
+class OwningRewritePatternList;
 
+/// Collect a set of patterns to convert from the Vector dialect to LLVM.
+void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                            OwningRewritePatternList &patterns);
+
+/// Create a pass to convert vector operations to the LLVMIR dialect.
 ModulePassBase *createLowerVectorToLLVMPass();
 } // namespace mlir
 
index 962e53b..e6f543f 100644 (file)
@@ -72,17 +72,25 @@ def ExtractElementOp :
 }
 def OuterProductOp :
   Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
-    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>,
+    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
     Results<(outs AnyVector)> {
-  let summary = "outerproduct operation";
+  let summary = "vector outerproduct with optional fused add";
   let description = [{
     Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
 
-    Example:
-    ```
+    An optional extra 2-D vector argument may be specified in which case the
+    operation returns the sum of the outer product and the extra vector. When
+    lowered to the LLVMIR dialect, this form emits `llvm.fmuladd`, which can
+    lower to actual `fma` instructions in LLVM.
+
+    Examples
+
       %2 = vector.extractelement %0, %1: vector<4xf32>, vector<8xf32>
       return %2: vector<4x8xf32>
-    ```
+
+      %3 = vector.extractelement %0, %1, %2:
+        vector<4xf32>, vector<8xf32>, vector<4x8xf32>
+      return %3: vector<4x8xf32>
   }];
   let extraClassDeclaration = [{
     VectorType getOperandVectorTypeLHS() {
@@ -91,6 +99,10 @@ def OuterProductOp :
     VectorType getOperandVectorTypeRHS() {
       return rhs()->getType().cast<VectorType>();
     }
+    VectorType getOperandVectorTypeACC() {
+      return (llvm::size(acc()) == 0) ? VectorType() :
+        (*acc().begin())->getType().cast<VectorType>();
+    }
     VectorType getVectorType() {
       return getResult()->getType().cast<VectorType>();
     }
index bf90edb..1e4b8ca 100644 (file)
@@ -79,11 +79,8 @@ public:
     auto positionArrayAttr = extractOp.position();
     // One-shot extraction of vector from array (only requires extractvalue).
     if (resultType.isa<VectorType>()) {
-      Value *extracted =
-          rewriter
-              .create<LLVM::ExtractValueOp>(loc, llvmResultType,
-                                            adaptor.vector(), positionArrayAttr)
-              .getResult();
+      Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
       rewriter.replaceOp(op, extracted);
       return matchSuccess();
     }
@@ -92,29 +89,24 @@ public:
     auto *context = op->getContext();
     Value *extracted = adaptor.vector();
     auto positionAttrs = positionArrayAttr.getValue();
-    auto indexType = rewriter.getIndexType();
+    auto i32Type = rewriter.getIntegerType(32);
     if (positionAttrs.size() > 1) {
       auto nDVectorType = vectorType;
       auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(),
                                             nDVectorType.getElementType());
       auto nMinusOnePositionAttrs =
           ArrayAttr::get(positionAttrs.drop_back(), context);
-      extracted = rewriter
-                      .create<LLVM::ExtractValueOp>(
-                          loc, lowering.convertType(oneDVectorType), extracted,
-                          nMinusOnePositionAttrs)
-                      .getResult();
+      extracted = rewriter.create<LLVM::ExtractValueOp>(
+          loc, lowering.convertType(oneDVectorType), extracted,
+          nMinusOnePositionAttrs);
     }
 
     // Remaining extraction of element from 1-D LLVM vector
     auto position = positionAttrs.back().cast<IntegerAttr>();
-    auto constant = rewriter
-                        .create<LLVM::ConstantOp>(
-                            loc, lowering.convertType(indexType), position)
-                        .getResult();
+    auto constant = rewriter.create<LLVM::ConstantOp>(
+        loc, lowering.convertType(i32Type), position);
     extracted =
-        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant)
-            .getResult();
+        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
     rewriter.replaceOp(op, extracted);
 
     return matchSuccess();
@@ -134,32 +126,38 @@ public:
     auto loc = op->getLoc();
     auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
     auto *ctx = op->getContext();
-    auto vt1 = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
-    auto vt2 = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
-    auto rankV1 = vt1.getUnderlyingType()->getVectorNumElements();
-    auto rankV2 = vt2.getUnderlyingType()->getVectorNumElements();
+    auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
+    auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
+    auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
+    auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
     auto llvmArrayOfVectType = lowering.convertType(
         cast<vector::OuterProductOp>(op).getResult()->getType());
-    Value *desc =
-        rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType).getResult();
-    for (unsigned i = 0, e = rankV1; i < e; ++i) {
-      // Emit the following pattern:
-      //   vec(a[i]) * b -> llvmStructOfVectType[i]
-      Value *a = adaptor.lhs(), *b = adaptor.rhs();
-      // shufflevector explicitly requires i32 /
-      auto attr = rewriter.getI32IntegerAttr(i);
-      SmallVector<Attribute, 4> broadcastAttr(rankV2, attr);
-      auto broadcastArrayAttr = ArrayAttr::get(broadcastAttr, ctx);
-      auto *broadcasted =
-          rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, broadcastArrayAttr)
-              .getResult();
-      auto *multiplied =
-          rewriter.create<LLVM::FMulOp>(loc, broadcasted, b).getResult();
-      desc = rewriter
-                 .create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, desc,
-                                              multiplied,
-                                              positionAttr(rewriter, i))
-                 .getResult();
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+    Value *a = adaptor.lhs(), *b = adaptor.rhs();
+    Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+    SmallVector<Value *, 8> lhs, accs;
+    lhs.reserve(rankLHS);
+    accs.reserve(rankLHS);
+    for (unsigned d = 0, e = rankLHS; d < e; ++d) {
+      // shufflevector explicitly requires i32.
+      auto attr = rewriter.getI32IntegerAttr(d);
+      SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
+      auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
+      Value *aD = nullptr, *accD = nullptr;
+      // 1. Broadcast the element a[d] into vector aD.
+      aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
+      // 2. If acc is present, extract 1-d vector acc[d] into accD.
+      if (acc)
+        accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc,
+                                                     positionAttr(rewriter, d));
+      // 3. Compute aD outer b (plus accD, if relevant).
+      Value *aOuterbD =
+          accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD)
+                     .getResult()
+               : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+      // 4. Insert as value `d` in the descriptor.
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d));
     }
     rewriter.replaceOp(op, desc);
     return matchSuccess();
@@ -167,12 +165,10 @@ public:
 };
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
-static void
-populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                       OwningRewritePatternList &patterns,
-                                       MLIRContext *ctx) {
+void mlir::populateVectorToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
-      ctx, converter);
+      converter.getDialect()->getContext(), converter);
 }
 
 namespace {
@@ -185,7 +181,7 @@ void LowerVectorToLLVMPass::runOnModule() {
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;
   LLVMTypeConverter converter(&getContext());
-  populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
+  populateVectorToLLVMConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 
   ConversionTarget target(getContext());
index 38267af..0bd552e 100644 (file)
@@ -116,45 +116,54 @@ static LogicalResult verify(ExtractElementOp op) {
 
 static void print(OpAsmPrinter *p, OuterProductOp op) {
   *p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+  if (llvm::size(op.acc()) > 0)
+    *p << ", " << **op.acc().begin();
   *p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
 }
 
 static ParseResult parseOuterProductOp(OpAsmParser *parser,
                                        OperationState *result) {
-  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
-  Type t0, t1;
-  if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) ||
-      parser->parseComma() || parser->parseType(t1))
+  SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
+  Type tLHS, tRHS;
+  if (parser->parseOperandList(operandsInfo) || parser->parseColonType(tLHS) ||
+      parser->parseComma() || parser->parseType(tRHS))
     return failure();
-  VectorType v0 = t0.dyn_cast<VectorType>();
-  VectorType v1 = t1.dyn_cast<VectorType>();
-  if (!v0 || !v1)
+  if (operandsInfo.size() < 2)
+    return parser->emitError(parser->getNameLoc(),
+                             "expected at least 2 operands");
+  VectorType vLHS = tLHS.dyn_cast<VectorType>();
+  VectorType vRHS = tRHS.dyn_cast<VectorType>();
+  if (!vLHS || !vRHS)
     return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
-  VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)},
-                                       v0.getElementType());
-  return failure(parser->resolveOperands(operandsInfo, {t0, t1},
-                                         parser->getCurrentLocation(),
-                                         result->operands) ||
-                 parser->addTypeToList(resType, result->types));
+  VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+                                       vLHS.getElementType());
+  return failure(
+      parser->resolveOperand(operandsInfo[0], tLHS, result->operands) ||
+      parser->resolveOperand(operandsInfo[1], tRHS, result->operands) ||
+      (operandsInfo.size() > 2 &&
+       parser->resolveOperand(operandsInfo[2], resType, result->operands)) ||
+      parser->addTypeToList(resType, result->types));
 }
 
 static LogicalResult verify(OuterProductOp op) {
-  VectorType v1 = op.getOperandVectorTypeLHS(),
-             v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType();
-  if (v1.getRank() != 1)
+  VectorType vLHS = op.getOperandVectorTypeLHS(),
+             vRHS = op.getOperandVectorTypeRHS(),
+             vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
+  if (vLHS.getRank() != 1)
     return op.emitOpError("expected 1-d vector for operand #1");
-  if (v2.getRank() != 1)
+  if (vRHS.getRank() != 1)
     return op.emitOpError("expected 1-d vector for operand #2");
-  if (res.getRank() != 2)
+  if (vRES.getRank() != 2)
     return op.emitOpError("expected 2-d vector result");
-  if (v1.getDimSize(0) != res.getDimSize(0))
-    return op.emitOpError(
-        "expected first operand dim to match first result dim");
-  if (v2.getDimSize(0) != res.getDimSize(1))
-    return op.emitOpError(
-        "expected second operand dim to match second result dim");
+  if (vLHS.getDimSize(0) != vRES.getDimSize(0))
+    return op.emitOpError("expected #1 operand dim to match result dim #1");
+  if (vRHS.getDimSize(0) != vRES.getDimSize(1))
+    return op.emitOpError("expected #2 operand dim to match result dim #2");
+  if (vACC && vACC != vRES)
+    return op.emitOpError("expected operand #3 of same type as result type");
   return success();
 }
+
 //===----------------------------------------------------------------------===//
 // VectorTransferReadOp
 //===----------------------------------------------------------------------===//
index f582de1..532a4c2 100644 (file)
@@ -1,33 +1,49 @@
 // RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
 
-func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
-  %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
-  %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
-  return %3 : vector<8xf32>
+func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
+  %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
+  return %2 : vector<2x3xf32>
 }
-// CHECK-LABEL: vec_1d
-//       CHECK:   llvm.undef : !llvm<"[4 x <8 x float>]">
-//     CHECK-5:   llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-//       CHECK:   llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
-//       CHECK:   llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
-//       CHECK:   llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]">
-//       CHECK:   llvm.return {{.*}} : !llvm<"<8 x float>">
+//    CHECK-LABEL: outerproduct
+//          CHECK:   llvm.undef : !llvm<"[2 x <3 x float>]">
+//          CHECK:   llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+//          CHECK:   llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
+//          CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+//          CHECK:   llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+//          CHECK:   llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>">
+//          CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+//          CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
 
-func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
-  %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
-  return %2 : vector<4x8xf32>
+func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
+  return %2 : vector<2x3xf32>
 }
-// CHECK-LABEL: vec_2d
-//       CHECK:   llvm.undef : !llvm<"[4 x <8 x float>]">
-//     CHECK-4:   llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
-//       CHECK:   llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
-//       CHECK:   llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
-//       CHECK:   llvm.return {{.*}} : !llvm<"[4 x <8 x float>]">
+//    CHECK-LABEL: outerproduct_add
+//          CHECK:   llvm.undef : !llvm<"[2 x <3 x float>]">
+//          CHECK:   llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+//          CHECK:   llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+//          CHECK:   "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+//          CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
+//          CHECK:   llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
+//          CHECK:   llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+//          CHECK:   "llvm.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+//          CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
+//          CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
 
-func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> {
-  %0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32>
-  return %0 : vector<8x16xf32>
+func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
+  %0 = vector.extractelement %arg0[0 : i32]: vector<4x3x16xf32>
+  return %0 : vector<3x16xf32>
 }
-// CHECK-LABEL: vec_3d
-//       CHECK:   llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
-//       CHECK:   llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]">
\ No newline at end of file
+// CHECK-LABEL: extract_vec_2d_from_vec_3d
+//       CHECK:   llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+//       CHECK:   llvm.return %{{.*}} : !llvm<"[3 x <16 x float>]">
+
+func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
+  %0 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32]: vector<4x3x16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: extract_element_from_vec_3d
+//       CHECK:   llvm.extractvalue %{{.*}}[0 : i32, 0 : i32] : !llvm<"[4 x [3 x <16 x float>]]">
+//       CHECK:   llvm.constant(0 : i32) : !llvm.i32
+//       CHECK:   llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>">
+//       CHECK:   llvm.return %{{.*}} : !llvm.float
\ No newline at end of file
index 7917f14..ca339e7 100644 (file)
@@ -2,39 +2,54 @@
 
 // -----
 
-// CHECK-LABEL: position_empty
-func @position_empty(%arg0: vector<4x8x16xf32>) {
+func @extract_element_vector_type(%arg0: index) {
+  // expected-error@+1 {{expected vector type}}
+  %1 = vector.extractelement %arg0[] : index
+}
+
+// -----
+
+func @extractelement_position_empty(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected non-empty position attribute}}
   %1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
 }
 
 // -----
 
-// CHECK-LABEL: position_rank_overflow
-func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than vector}}
   %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
 }
 
 // -----
 
-// CHECK-LABEL: position_overflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected position attribute of rank smaller than vector}}
+  %1 = "vector.extractelement" (%arg0) { position = [0 : i32, 0 : i32, 0 : i32, 0 : i32] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
+}
+
+// -----
+
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
   %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
 }
 
 // -----
 
-// CHECK-LABEL: position_underflow
-func @position_overflow(%arg0: vector<4x8x16xf32>) {
+func @extractelement_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
   %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
 }
 
 // -----
 
-// CHECK-LABEL: outerproduct_non_vector_operand
+func @outerproduct_num_operands(%arg0: f32) {
+  // expected-error@+1 {{expected at least 2 operands}}
+  %1 = vector.outerproduct %arg0 : f32, f32
+}
+// -----
+
 func @outerproduct_non_vector_operand(%arg0: f32) {
   // expected-error@+1 {{expected 2 vector types}}
   %1 = vector.outerproduct %arg0, %arg0 : f32, f32
@@ -42,7 +57,6 @@ func @outerproduct_non_vector_operand(%arg0: f32) {
 
 // -----
 
-// CHECK-LABEL: outerproduct_operand_1
 func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
   // expected-error@+1 {{expected 1-d vector for operand #1}}
   %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
@@ -50,8 +64,35 @@ func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
 
 // -----
 
-// CHECK-LABEL: outerproduct_operand_2
 func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
   // expected-error@+1 {{expected 1-d vector for operand #2}}
   %1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32>
 }
+
+// -----
+
+func @outerproduct_result_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+  // expected-error@+1 {{expected 2-d vector result}}
+  %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_1_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+  // expected-error@+1 {{expected #1 operand dim to match result dim #1}}
+  %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<8x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_2_dim_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>) {
+  // expected-error@+1 {{expected #2 operand dim to match result dim #2}}
+  %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, vector<8xf32>) -> (vector<4x16xf32>)
+}
+
+// -----
+
+func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) {
+  // expected-error@+1 {{expected operand #3 of same type as result type}}
+  %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>)
+}
index a072b5c..067345a 100644 (file)
@@ -12,8 +12,10 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
 }
 
 // CHECK-LABEL: outerproduct
-func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
   //     CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
   %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
-  return %0 : vector<4x8xf32>
+  //     CHECK: vector.outerproduct {{.*}}, {{.*}}, {{.*}} : vector<4xf32>, vector<8xf32>
+  %1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
+  return %1 : vector<4x8xf32>
 }