Add a higher-order vector.extractelement operation in MLIR
authorNicolas Vasilache <ntv@google.com>
Fri, 9 Aug 2019 12:58:19 +0000 (05:58 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Aug 2019 12:58:47 +0000 (05:58 -0700)
This CL is step 2/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools.

This CL adds the vector.extractelement operation to the MLIR vector dialect as well as the appropriate roundtrip test. Lowering to LLVM will occur in the following CL.

PiperOrigin-RevId: 262545089

15 files changed:
mlir/include/mlir/CMakeLists.txt
mlir/include/mlir/EDSC/Intrinsics.h
mlir/include/mlir/VectorOps/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/VectorOps/VectorOps.h
mlir/include/mlir/VectorOps/VectorOps.td [new file with mode: 0644]
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Analysis/VectorAnalysis.cpp
mlir/lib/Transforms/LowerVectorTransfers.cpp
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/lib/Transforms/Vectorize.cpp
mlir/lib/VectorOps/CMakeLists.txt
mlir/lib/VectorOps/DialectRegistration.cpp
mlir/lib/VectorOps/VectorOps.cpp
mlir/test/Dialect/VectorOps/invalid.mlir [new file with mode: 0644]
mlir/test/Dialect/VectorOps/ops.mlir [new file with mode: 0644]

index 55843c0..202b40b 100644 (file)
@@ -4,3 +4,4 @@ add_subdirectory(EDSC)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(StandardOps)
+add_subdirectory(VectorOps)
index 6870e02..98e9cea 100644 (file)
@@ -214,7 +214,7 @@ using select = ValueBuilder<SelectOp>;
 using std_load = ValueBuilder<LoadOp>;
 using std_store = OperationBuilder<StoreOp>;
 using subi = ValueBuilder<SubIOp>;
-using vector_type_cast = ValueBuilder<VectorTypeCastOp>;
+using vector_type_cast = ValueBuilder<vector::VectorTypeCastOp>;
 
 /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
 ///
diff --git a/mlir/include/mlir/VectorOps/CMakeLists.txt b/mlir/include/mlir/VectorOps/CMakeLists.txt
new file mode 100644 (file)
index 0000000..6cc7e44
--- /dev/null
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS VectorOps.td)
+mlir_tablegen(VectorOps.h.inc -gen-op-decls)
+mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRVectorOpsIncGen)
index 434cda1..47cd8a1 100644 (file)
 #include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
+namespace vector {
 
 /// Dialect for super-vectorization Ops.
 class VectorOpsDialect : public Dialect {
 public:
   VectorOpsDialect(MLIRContext *context);
+  static StringRef getDialectNamespace() { return "vector"; }
 };
 
 /// VectorTransferReadOp performs a blocking read from a scalar memref
@@ -201,6 +203,10 @@ public:
   LogicalResult verify();
 };
 
+#define GET_OP_CLASSES
+#include "mlir/VectorOps/VectorOps.h.inc"
+
+} // end namespace vector
 } // end namespace mlir
 
 #endif // MLIR_VECTOROPS_VECTOROPS_H
diff --git a/mlir/include/mlir/VectorOps/VectorOps.td b/mlir/include/mlir/VectorOps/VectorOps.td
new file mode 100644 (file)
index 0000000..ba7ee92
--- /dev/null
@@ -0,0 +1,72 @@
+//===- VectorOps.td - Vector op definitions ---------------*- tablegen -*-====//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines MLIR vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef VECTOR_OPS
+#else
+#define VECTOR_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Vector_Dialect : Dialect {
+  let name = "vector";
+  let cppNamespace = "vector";
+}
+
+// Base class for Vector dialect ops.
+class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Vector_Dialect, mnemonic, traits> {
+  // For every vector op, there needs to be a:
+  //   * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+  //   * LogicalResult verify(${C++ class of Op} op)
+  //   * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+  //                                         OperationState *result)
+  // functions.
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+def ExtractElementOp :
+  Vector_Op<"extractelement", [NoSideEffect,
+     PredOpTrait<"operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(ins AnyVector:$vector, I32ArrayAttr:$position)>,
+    Results<(outs AnyType)> {
+  let summary = "extractelement operation";
+  let description = [{
+    Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
+    the proper position. Degenerates to an element type in the 0-D case.
+
+    Example:
+      %1 = vector.extractelement %0[3]: vector<4x8x16xf32>
+      %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector()->getType().cast<VectorType>();
+    }
+  }];
+}
+
+#endif // VECTOR_OPS
index 0b487ba..743907b 100644 (file)
@@ -289,7 +289,8 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
 }
 
 static bool isVectorTransferReadOrWrite(Operation &op) {
-  return isa<VectorTransferReadOp>(op) || isa<VectorTransferWriteOp>(op);
+  return isa<vector::VectorTransferReadOp>(op) ||
+         isa<vector::VectorTransferWriteOp>(op);
 }
 
 using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
index 7bb28e9..2306156 100644 (file)
@@ -194,10 +194,10 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op,
   bool mustDivide = false;
   (void)mustDivide;
   VectorType superVectorType;
-  if (auto read = dyn_cast<VectorTransferReadOp>(op)) {
+  if (auto read = dyn_cast<vector::VectorTransferReadOp>(op)) {
     superVectorType = read.getResultType();
     mustDivide = true;
-  } else if (auto write = dyn_cast<VectorTransferWriteOp>(op)) {
+  } else if (auto write = dyn_cast<vector::VectorTransferWriteOp>(op)) {
     superVectorType = write.getVectorType();
     mustDivide = true;
   } else if (op.getNumResults() == 0) {
index cda62d9..ded0732 100644 (file)
@@ -84,6 +84,8 @@
 /// ```
 
 using namespace mlir;
+using vector::VectorTransferReadOp;
+using vector::VectorTransferWriteOp;
 
 #define DEBUG_TYPE "affine-lower-vector-transfers"
 
@@ -362,8 +364,9 @@ struct LowerVectorTransfersPass
   void runOnFunction() {
     OwningRewritePatternList patterns;
     auto *context = &getContext();
-    patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
-                    VectorTransferRewriter<VectorTransferWriteOp>>(context);
+    patterns.insert<VectorTransferRewriter<vector::VectorTransferReadOp>,
+                    VectorTransferRewriter<vector::VectorTransferWriteOp>>(
+        context);
     applyPatternsGreedily(getFunction(), std::move(patterns));
   }
 };
index d345801..17acc92 100644 (file)
@@ -146,6 +146,8 @@ using llvm::dbgs;
 using llvm::SetVector;
 
 using namespace mlir;
+using vector::VectorTransferReadOp;
+using vector::VectorTransferWriteOp;
 
 using functional::makePtrDynCaster;
 using functional::map;
index 9470ca5..ce25406 100644 (file)
@@ -829,7 +829,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv,
       return LogicalResult::Failure;
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
     LLVM_DEBUG(permutationMap.print(dbgs()));
-    auto transfer = b.create<VectorTransferReadOp>(
+    auto transfer = b.create<vector::VectorTransferReadOp>(
         opInst->getLoc(), vectorType, memoryOp.getMemRef(),
         map(makePtrDynCaster<Value>(), indices), permutationMap);
     state->registerReplacement(opInst, transfer.getOperation());
@@ -1027,9 +1027,9 @@ static Operation *vectorizeOneOperation(Operation *opInst,
   // Sanity checks.
   assert(!isa<AffineLoadOp>(opInst) &&
          "all loads must have already been fully vectorized independently");
-  assert(!isa<VectorTransferReadOp>(opInst) &&
+  assert(!isa<vector::VectorTransferReadOp>(opInst) &&
          "vector.transfer_read cannot be further vectorized");
-  assert(!isa<VectorTransferWriteOp>(opInst) &&
+  assert(!isa<vector::VectorTransferWriteOp>(opInst) &&
          "vector.transfer_write cannot be further vectorized");
 
   if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
@@ -1055,7 +1055,7 @@ static Operation *vectorizeOneOperation(Operation *opInst,
       return nullptr;
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
     LLVM_DEBUG(permutationMap.print(dbgs()));
-    auto transfer = b.create<VectorTransferWriteOp>(
+    auto transfer = b.create<vector::VectorTransferWriteOp>(
         opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
     auto *res = transfer.getOperation();
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
index 6c0ff68..0e76501 100644 (file)
@@ -5,3 +5,7 @@ add_llvm_library(MLIRVectorOps
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/VectorOps
   )
+
+add_dependencies(MLIRVectorOps MLIRVectorOpsIncGen)
+
+target_link_libraries(MLIRVectorOps MLIRIR)
index 94132ff..aedba31 100644 (file)
@@ -19,4 +19,4 @@
 using namespace mlir;
 
 // Static initialization for VectorOps dialect registration.
-static DialectRegistration<VectorOpsDialect> VectorOps;
+static DialectRegistration<vector::VectorOpsDialect> VectorOps;
index 580dd66..9de4d93 100644 (file)
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
+
 using namespace mlir;
+using namespace mlir::vector;
 
 //===----------------------------------------------------------------------===//
 // VectorOpsDialect
 //===----------------------------------------------------------------------===//
 
-VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
-    : Dialect("vector", context) {
+mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
   addOperations<VectorTransferReadOp, VectorTransferWriteOp,
                 VectorTypeCastOp>();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/VectorOps/VectorOps.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, ExtractElementOp op) {
+  *p << op.getOperationName() << " " << *op.vector() << op.position();
+  p->printOptionalAttrDict(op.getAttrs(), {"position"});
+  *p << " : " << op.vector()->getType();
+}
+
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+                                         OperationState *result) {
+  llvm::SMLoc attributeLoc, typeLoc;
+  SmallVector<NamedAttribute, 4> attrs;
+  OpAsmParser::OperandType vector;
+  Type type;
+  Attribute attr;
+  if (parser->parseOperand(vector) ||
+      parser->getCurrentLocation(&attributeLoc) ||
+      parser->parseAttribute(attr, "position", attrs) ||
+      parser->parseOptionalAttributeDict(attrs) ||
+      parser->getCurrentLocation(&typeLoc) || parser->parseColonType(type))
+    return failure();
+
+  auto vectorType = type.dyn_cast<VectorType>();
+  if (!vectorType)
+    return parser->emitError(typeLoc, "expected vector type");
+
+  auto positionAttr = attr.dyn_cast<ArrayAttr>();
+  if (!positionAttr ||
+      static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
+    return parser->emitError(
+        attributeLoc,
+        "expected position attribute of rank smaller than vector");
+
+  Type resType =
+      (static_cast<int64_t>(positionAttr.size()) == vectorType.getRank())
+          ? vectorType.getElementType()
+          : VectorType::get(
+                vectorType.getShape().drop_front(positionAttr.size()),
+                vectorType.getElementType());
+
+  result->attributes = attrs;
+  return failure(parser->resolveOperand(vector, type, result->operands) ||
+                 parser->addTypeToList(resType, result->types));
+}
+
+static LogicalResult verify(ExtractElementOp op) {
+  auto positionAttr = op.position().getValue();
+  if (positionAttr.empty())
+    return op.emitOpError("expected non-empty position attribute");
+  if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
+    return op.emitOpError(
+        "expected position attribute of rank smaller than vector");
+  for (auto en : llvm::enumerate(positionAttr)) {
+    auto attr = en.value().dyn_cast<IntegerAttr>();
+    if (!attr || attr.getInt() < 0 ||
+        attr.getInt() > op.getVectorType().getDimSize(en.index()))
+      return op.emitOpError("expected position attribute #")
+             << (en.index() + 1)
+             << " to be a positive integer smaller than the corresponding "
+                "vector dimension";
+  }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -147,7 +220,8 @@ ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
     return parser->emitError(parser->getNameLoc(), "vector type expected");
 
   // Extract optional paddingValue.
-  // At this point, indexInfo may contain the optional paddingValue, pop it out.
+  // At this point, indexInfo may contain the optional paddingValue, pop it
+  // out.
   if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
     return parser->emitError(parser->getNameLoc(),
                              "expected " + Twine(memrefType.getRank()) +
@@ -419,3 +493,10 @@ LogicalResult VectorTypeCastOp::verify() {
 
   return success();
 }
+
+namespace mlir {
+
+#define GET_OP_CLASSES
+#include "mlir/VectorOps/VectorOps.cpp.inc"
+
+} // namespace mlir
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
new file mode 100644 (file)
index 0000000..49fcefc
--- /dev/null
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+// CHECK-LABEL: position_empty
+func @position_empty(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected non-empty position attribute}}
+  %1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: position_rank_overflow
+func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected position attribute of rank smaller than vector}}
+  %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: position_overflow
+func @position_overflow(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
+  %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: position_underflow
+func @position_overflow(%arg0: vector<4x8x16xf32>) {
+  // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
+  %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
+  return
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
new file mode 100644 (file)
index 0000000..11928ad
--- /dev/null
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: extractelement
+func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
+  //      CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>
+  %1 = vector.extractelement %arg0[3 : i32] : vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
+  %2 = vector.extractelement %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
+  %3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
+  return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
+}
\ No newline at end of file