add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(StandardOps)
+add_subdirectory(VectorOps)
using std_load = ValueBuilder<LoadOp>;
using std_store = OperationBuilder<StoreOp>;
using subi = ValueBuilder<SubIOp>;
-using vector_type_cast = ValueBuilder<VectorTypeCastOp>;
+using vector_type_cast = ValueBuilder<vector::VectorTypeCastOp>;
/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
///
--- /dev/null
+set(LLVM_TARGET_DEFINITIONS VectorOps.td)
+mlir_tablegen(VectorOps.h.inc -gen-op-decls)
+mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRVectorOpsIncGen)
#include "mlir/IR/StandardTypes.h"
namespace mlir {
+namespace vector {
/// Dialect for super-vectorization Ops.
class VectorOpsDialect : public Dialect {
public:
VectorOpsDialect(MLIRContext *context);
+ static StringRef getDialectNamespace() { return "vector"; }
};
/// VectorTransferReadOp performs a blocking read from a scalar memref
LogicalResult verify();
};
+#define GET_OP_CLASSES
+#include "mlir/VectorOps/VectorOps.h.inc"
+
+} // end namespace vector
} // end namespace mlir
#endif // MLIR_VECTOROPS_VECTOROPS_H
--- /dev/null
+//===- VectorOps.td - Vector op definitions ---------------*- tablegen -*-====//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines MLIR vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef VECTOR_OPS
+#else
+#define VECTOR_OPS
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Vector_Dialect : Dialect {
+ let name = "vector";
+ let cppNamespace = "vector";
+}
+
+// Base class for Vector dialect ops.
+class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Vector_Dialect, mnemonic, traits> {
+ // For every vector op, there needs to be a:
+ // * void print(OpAsmPrinter *p, ${C++ class of Op} op)
+ // * LogicalResult verify(${C++ class of Op} op)
+ // * ParseResult parse${C++ class of Op}(OpAsmParser *parser,
+ // OperationState *result)
+ // functions.
+ let printer = [{ return ::print(p, *this); }];
+ let verifier = [{ return ::verify(*this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+def ExtractElementOp :
+ Vector_Op<"extractelement", [NoSideEffect,
+ PredOpTrait<"operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ Arguments<(ins AnyVector:$vector, I32ArrayAttr:$position)>,
+ Results<(outs AnyType)> {
+ let summary = "extractelement operation";
+ let description = [{
+ Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
+ the proper position. Degenerates to an element type in the 0-D case.
+
+ Example:
+ %1 = vector.extractelement %0[3]: vector<4x8x16xf32>
+ %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
+ }];
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return vector()->getType().cast<VectorType>();
+ }
+ }];
+}
+
+#endif // VECTOR_OPS
}
static bool isVectorTransferReadOrWrite(Operation &op) {
- return isa<VectorTransferReadOp>(op) || isa<VectorTransferWriteOp>(op);
+ return isa<vector::VectorTransferReadOp>(op) ||
+ isa<vector::VectorTransferWriteOp>(op);
}
using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
bool mustDivide = false;
(void)mustDivide;
VectorType superVectorType;
- if (auto read = dyn_cast<VectorTransferReadOp>(op)) {
+ if (auto read = dyn_cast<vector::VectorTransferReadOp>(op)) {
superVectorType = read.getResultType();
mustDivide = true;
- } else if (auto write = dyn_cast<VectorTransferWriteOp>(op)) {
+ } else if (auto write = dyn_cast<vector::VectorTransferWriteOp>(op)) {
superVectorType = write.getVectorType();
mustDivide = true;
} else if (op.getNumResults() == 0) {
/// ```
using namespace mlir;
+using vector::VectorTransferReadOp;
+using vector::VectorTransferWriteOp;
#define DEBUG_TYPE "affine-lower-vector-transfers"
void runOnFunction() {
OwningRewritePatternList patterns;
auto *context = &getContext();
- patterns.insert<VectorTransferRewriter<VectorTransferReadOp>,
- VectorTransferRewriter<VectorTransferWriteOp>>(context);
+ patterns.insert<VectorTransferRewriter<vector::VectorTransferReadOp>,
+ VectorTransferRewriter<vector::VectorTransferWriteOp>>(
+ context);
applyPatternsGreedily(getFunction(), std::move(patterns));
}
};
using llvm::SetVector;
using namespace mlir;
+using vector::VectorTransferReadOp;
+using vector::VectorTransferWriteOp;
using functional::makePtrDynCaster;
using functional::map;
return LogicalResult::Failure;
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- auto transfer = b.create<VectorTransferReadOp>(
+ auto transfer = b.create<vector::VectorTransferReadOp>(
opInst->getLoc(), vectorType, memoryOp.getMemRef(),
map(makePtrDynCaster<Value>(), indices), permutationMap);
state->registerReplacement(opInst, transfer.getOperation());
// Sanity checks.
assert(!isa<AffineLoadOp>(opInst) &&
"all loads must have already been fully vectorized independently");
- assert(!isa<VectorTransferReadOp>(opInst) &&
+ assert(!isa<vector::VectorTransferReadOp>(opInst) &&
"vector.transfer_read cannot be further vectorized");
- assert(!isa<VectorTransferWriteOp>(opInst) &&
+ assert(!isa<vector::VectorTransferWriteOp>(opInst) &&
"vector.transfer_write cannot be further vectorized");
if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
return nullptr;
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
- auto transfer = b.create<VectorTransferWriteOp>(
+ auto transfer = b.create<vector::VectorTransferWriteOp>(
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = transfer.getOperation();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/VectorOps
)
+
+add_dependencies(MLIRVectorOps MLIRVectorOpsIncGen)
+
+target_link_libraries(MLIRVectorOps MLIRIR)
using namespace mlir;
// Static initialization for VectorOps dialect registration.
-static DialectRegistration<VectorOpsDialect> VectorOps;
+static DialectRegistration<vector::VectorOpsDialect> VectorOps;
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
+
using namespace mlir;
+using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// VectorOpsDialect
//===----------------------------------------------------------------------===//
-VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
- : Dialect("vector", context) {
+mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
addOperations<VectorTransferReadOp, VectorTransferWriteOp,
VectorTypeCastOp>();
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/VectorOps/VectorOps.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, ExtractElementOp op) {
+ *p << op.getOperationName() << " " << *op.vector() << op.position();
+ p->printOptionalAttrDict(op.getAttrs(), {"position"});
+ *p << " : " << op.vector()->getType();
+}
+
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+ OperationState *result) {
+ llvm::SMLoc attributeLoc, typeLoc;
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType vector;
+ Type type;
+ Attribute attr;
+ if (parser->parseOperand(vector) ||
+ parser->getCurrentLocation(&attributeLoc) ||
+ parser->parseAttribute(attr, "position", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->getCurrentLocation(&typeLoc) || parser->parseColonType(type))
+ return failure();
+
+ auto vectorType = type.dyn_cast<VectorType>();
+ if (!vectorType)
+ return parser->emitError(typeLoc, "expected vector type");
+
+ auto positionAttr = attr.dyn_cast<ArrayAttr>();
+ if (!positionAttr ||
+ static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
+ return parser->emitError(
+ attributeLoc,
+ "expected position attribute of rank smaller than vector");
+
+ Type resType =
+ (static_cast<int64_t>(positionAttr.size()) == vectorType.getRank())
+ ? vectorType.getElementType()
+ : VectorType::get(
+ vectorType.getShape().drop_front(positionAttr.size()),
+ vectorType.getElementType());
+
+ result->attributes = attrs;
+ return failure(parser->resolveOperand(vector, type, result->operands) ||
+ parser->addTypeToList(resType, result->types));
+}
+
+static LogicalResult verify(ExtractElementOp op) {
+ auto positionAttr = op.position().getValue();
+ if (positionAttr.empty())
+ return op.emitOpError("expected non-empty position attribute");
+ if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
+ return op.emitOpError(
+ "expected position attribute of rank smaller than vector");
+ for (auto en : llvm::enumerate(positionAttr)) {
+ auto attr = en.value().dyn_cast<IntegerAttr>();
+ if (!attr || attr.getInt() < 0 ||
+ attr.getInt() > op.getVectorType().getDimSize(en.index()))
+ return op.emitOpError("expected position attribute #")
+ << (en.index() + 1)
+ << " to be a positive integer smaller than the corresponding "
+ "vector dimension";
+ }
+ return success();
}
//===----------------------------------------------------------------------===//
return parser->emitError(parser->getNameLoc(), "vector type expected");
// Extract optional paddingValue.
- // At this point, indexInfo may contain the optional paddingValue, pop it out.
+ // At this point, indexInfo may contain the optional paddingValue, pop it
+ // out.
if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(memrefType.getRank()) +
return success();
}
+
+namespace mlir {
+
+#define GET_OP_CLASSES
+#include "mlir/VectorOps/VectorOps.cpp.inc"
+
+} // namespace mlir
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+// CHECK-LABEL: position_empty
+func @position_empty(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected non-empty position attribute}}
+ %1 = vector.extractelement %arg0[] : vector<4x8x16xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: position_rank_overflow
+func @position_rank_overflow(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected position attribute of rank smaller than vector}}
+ %1 = vector.extractelement %arg0[0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<4x8x16xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: position_overflow
+func @position_overflow(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected position attribute #2 to be a positive integer smaller than the corresponding vector dimension}}
+ %1 = vector.extractelement %arg0[0 : i32, 43 : i32, 0 : i32] : vector<4x8x16xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: position_underflow
+func @position_overflow(%arg0: vector<4x8x16xf32>) {
+ // expected-error@+1 {{expected position attribute #3 to be a positive integer smaller than the corresponding vector dimension}}
+ %1 = vector.extractelement %arg0[0 : i32, 0 : i32, -1 : i32] : vector<4x8x16xf32>
+ return
+}
--- /dev/null
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: extractelement
+func @extractelement(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
+ // CHECK: vector.extractelement {{.*}}[3 : i32] : vector<4x8x16xf32>
+ %1 = vector.extractelement %arg0[3 : i32] : vector<4x8x16xf32>
+ // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32] : vector<4x8x16xf32>
+ %2 = vector.extractelement %arg0[3 : i32, 3 : i32] : vector<4x8x16xf32>
+ // CHECK-NEXT: vector.extractelement {{.*}}[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
+ %3 = vector.extractelement %arg0[3 : i32, 3 : i32, 3 : i32] : vector<4x8x16xf32>
+ return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
+}
\ No newline at end of file