Add a higher-order vector.outerproduct operation in MLIR
authorNicolas Vasilache <ntv@google.com>
Fri, 9 Aug 2019 13:55:10 +0000 (06:55 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Aug 2019 13:55:36 +0000 (06:55 -0700)
This CL is step 2/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools.

This CL adds the vector.outerproduct operation to the MLIR vector dialect as well as the appropriate roundtrip test. Lowering to LLVM will occur in the following CL.

PiperOrigin-RevId: 262552027

mlir/include/mlir/VectorOps/VectorOps.td
mlir/lib/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir

index ba7ee92..962e53b 100644 (file)
@@ -58,9 +58,11 @@ def ExtractElementOp :
     Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
     the proper position. Degenerates to an element type in the 0-D case.
 
-    Example:
+    Examples:
+    ```
       %1 = vector.extractelement %0[3]: vector<4x8x16xf32>
       %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
+    ```
   }];
   let extraClassDeclaration = [{
     VectorType getVectorType() {
@@ -68,5 +70,30 @@ def ExtractElementOp :
     }
   }];
 }
+def OuterProductOp :
+  Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
+    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs)>,
+    Results<(outs AnyVector)> {
+  let summary = "outerproduct operation";
+  let description = [{
+    Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
 
+    Example:
+    ```
+      %2 = vector.extractelement %0, %1: vector<4xf32>, vector<8xf32>
+      return %2: vector<4x8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getOperandVectorTypeLHS() {
+      return lhs()->getType().cast<VectorType>();
+    }
+    VectorType getOperandVectorTypeRHS() {
+      return rhs()->getType().cast<VectorType>();
+    }
+    VectorType getVectorType() {
+      return getResult()->getType().cast<VectorType>();
+    }
+  }];
+}
 #endif // VECTOR_OPS
index 9de4d93..38267af 100644 (file)
@@ -110,7 +110,51 @@ static LogicalResult verify(ExtractElementOp op) {
   }
   return success();
 }
+//===----------------------------------------------------------------------===//
+// OuterProductOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, OuterProductOp op) {
+  *p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
+  *p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
+}
+
+static ParseResult parseOuterProductOp(OpAsmParser *parser,
+                                       OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
+  Type t0, t1;
+  if (parser->parseOperandList(operandsInfo) || parser->parseColonType(t0) ||
+      parser->parseComma() || parser->parseType(t1))
+    return failure();
+  VectorType v0 = t0.dyn_cast<VectorType>();
+  VectorType v1 = t1.dyn_cast<VectorType>();
+  if (!v0 || !v1)
+    return parser->emitError(parser->getNameLoc(), "expected 2 vector types");
+  VectorType resType = VectorType::get({v0.getDimSize(0), v1.getDimSize(0)},
+                                       v0.getElementType());
+  return failure(parser->resolveOperands(operandsInfo, {t0, t1},
+                                         parser->getCurrentLocation(),
+                                         result->operands) ||
+                 parser->addTypeToList(resType, result->types));
+}
 
+static LogicalResult verify(OuterProductOp op) {
+  VectorType v1 = op.getOperandVectorTypeLHS(),
+             v2 = op.getOperandVectorTypeRHS(), res = op.getVectorType();
+  if (v1.getRank() != 1)
+    return op.emitOpError("expected 1-d vector for operand #1");
+  if (v2.getRank() != 1)
+    return op.emitOpError("expected 1-d vector for operand #2");
+  if (res.getRank() != 2)
+    return op.emitOpError("expected 2-d vector result");
+  if (v1.getDimSize(0) != res.getDimSize(0))
+    return op.emitOpError(
+        "expected first operand dim to match first result dim");
+  if (v2.getDimSize(0) != res.getDimSize(1))
+    return op.emitOpError(
+        "expected second operand dim to match second result dim");
+  return success();
+}
 //===----------------------------------------------------------------------===//
 // VectorTransferReadOp
 //===----------------------------------------------------------------------===//
index 49fcefc..7917f14 100644 (file)
@@ -6,7 +6,6 @@
 func @position_empty(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected non-empty position attribute}}
   %1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
-  return
 }
 
 // -----
@@ -15,7 +14,6 @@ func @position_empty(%arg0: vector<4x8x16xf32>) {
 func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than vector}}
   %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
-  return
 }
 
 // -----
@@ -24,7 +22,6 @@ func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
 func @position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
   %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
-  return
 }
 
 // -----
@@ -33,5 +30,28 @@ func @position_overflow(%arg0: vector<4x8x16xf32>) {
 func @position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
   %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
-  return
+}
+
+// -----
+
+// CHECK-LABEL: outerproduct_non_vector_operand
+func @outerproduct_non_vector_operand(%arg0: f32) {
+  // expected-error@+1 {{expected 2 vector types}}
+  %1 = vector.outerproduct %arg0, %arg0 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: outerproduct_operand_1
+func @outerproduct_operand_1(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
+  // expected-error@+1 {{expected 1-d vector for operand #1}}
+  %1 = vector.outerproduct %arg1, %arg1 : vector<4x8xf32>, vector<4x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: outerproduct_operand_2
+func @outerproduct_operand_2(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>) {
+  // expected-error@+1 {{expected 1-d vector for operand #2}}
+  %1 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<4x8xf32>
 }
index 11928ad..a072b5c 100644 (file)
@@ -9,4 +9,11 @@ func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16x
   // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
   %3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
   return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: outerproduct
+func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
+  //     CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
+  %0 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+  return %0 : vector<4x8xf32>
+}