class LinalgOp<string mnemonic, list<OpTrait> props> :
Op<Linalg_Dialect, mnemonic, props> {
let arguments = (ins Variadic<View>); // default variadic builder
-
- let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }];
-
- let printer = [{ impl::printLinalgLibraryOp(p, *this); }];
+ let parser = [{ return parseLinalgLibraryOp(parser, result); }];
+ let printer = [{ printLinalgLibraryOp(p, *this); }];
}
def BufferSizeOp :
Arguments<(ins Buffer)>,
Results<(outs Index)>
{
- let parser = [{
- return impl::parseBufferSizeOp(parser, result);
+ let parser = [{ return parseBufferSizeOp(parser, result); }];
+ let printer = [{ return printBufferSizeOp(p, *this); }];
+}
+
+def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
+ Arguments<(ins View:$view, APIntAttr:$index)>,
+ Results<(outs Index)> {
+ let summary = "dimension index operation";
+ let description = [{
+ The "linalg.dim" operation takes a linalg.view and returns an
+ "index". It requires a single integer attribute named "index". It
+ returns the size of the specified dimension. For example:
+
+ %1 = linalg.dim %0, 2 : view<?x?x?xf32>
}];
- let printer = [{
- return impl::printBufferSizeOp(p, this->getOperation());
+ let parser = [{ return parseDimOp(parser, result); }];
+ let printer = [{ return printDimOp(p, *this); }];
+ let verifier = [{ return ::verify(*this); }];
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState *result, Value *view," "unsigned index",
+ [{
+ result->addOperands(view);
+ result->addAttribute(
+ "index", builder->getIntegerAttr(builder->getIndexType(), index));
+ result->types.push_back(builder->getIndexType());
+ }]>];
+
+ let extraClassDeclaration = [{
+ unsigned getIndex() {
+ return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+ }
}];
}
namespace mlir {
class ModulePassBase;
-mlir::ModulePassBase *
-createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
+namespace linalg {
+ModulePassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
-mlir::ModulePassBase *createLowerLinalgToLLVMPass();
+ModulePassBase *createLowerLinalgToLLVMPass();
+
+} // namespace linalg
} // namespace mlir
#endif // MLIR_LINALG_PASSES_H_
*p << "] : " << getType();
}
-namespace mlir {
-namespace linalg {
-namespace impl {
-void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
-ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
-void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
-ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
-} // namespace impl
-} // namespace linalg
-
/// Buffer size prints as:
///
/// ``` {.mlir}
/// %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
/// ```
-void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
- assert(op->getAbstractOperation() && "unregistered operation");
- *p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
- p->printOptionalAttrDict(op->getAttrs());
- *p << " : " << op->getOperand(0)->getType();
+static void printBufferSizeOp(OpAsmPrinter *p, BufferSizeOp op) {
+ *p << op.getOperationName() << " " << *op.getOperand();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.getOperand()->getType();
}
-ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
- OperationState *result) {
+static ParseResult parseBufferSizeOp(OpAsmParser *parser,
+ OperationState *result) {
OpAsmParser::OperandType op;
Type type;
return failure(parser->parseOperand(op) ||
result->types));
}
-#define GET_OP_CLASSES
-#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+static void printDimOp(OpAsmPrinter *p, DimOp op) {
+ *p << op.getOperationName() << " " << *op.getOperand() << ", "
+ << op.getIndex();
+ p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
+ *p << " : " << op.getOperand()->getType();
+}
-} // namespace mlir
+static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType operandInfo;
+ IntegerAttr indexAttr;
+ Type type;
+ Type indexType = parser->getBuilder().getIndexType();
+ return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
+ parser->parseAttribute(indexAttr, indexType, "index",
+ result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(operandInfo, type, result->operands) ||
+ parser->addTypeToList(indexType, result->types));
+}
+
+static LogicalResult verify(linalg::DimOp op) {
+ // Check that we have an integer index operand.
+ auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
+ if (!indexAttr)
+ return op.emitOpError("requires an integer attribute named 'index'");
+
+ uint64_t index = indexAttr.getValue().getZExtValue();
+ auto type = op.getOperand()->getType();
+ if (auto viewType = type.dyn_cast<ViewType>()) {
+ if (index >= viewType.getRank())
+ return op.emitOpError("index is out of range");
+ } else {
+ return op.emitOpError("requires an operand with view type");
+ }
+
+ return success();
+}
// A LinalgLibraryOp prints as:
//
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
-void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
+static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
*p << op->getName().getStringRef() << "(";
interleave(
[&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
}
-ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
- OperationState *result) {
+static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
+ OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
result->operands));
}
+namespace mlir {
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+
+} // namespace mlir
+
// Ideally this should all be Tablegen'd but there is no good story for
// AffineMap for now.
SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
return builder.getArrayAttr(attrs);
}
-// BufferAllocOp creates a new `index` value.
+// BufferAllocOp creates a new `!linalg.buffer` value.
class BufferAllocOpConversion : public LLVMOpLowering {
public:
explicit BufferAllocOpConversion(MLIRContext *context,
}
};
-// BufferDeallocOp creates a new `index` value.
+// BufferDeallocOp creates no value.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
explicit BufferDeallocOpConversion(MLIRContext *context,
}
};
+// DimOp creates a new `index` value.
+class DimOpConversion : public LLVMOpLowering {
+public:
+ explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+ : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
+
+ SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ auto dimOp = cast<linalg::DimOp>(op);
+ auto indexTy = lowering.convertType(rewriter.getIndexType());
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ return {extractvalue(
+ indexTy, operands[0],
+ makePositionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
+ }
+};
+
namespace {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
return ConversionListBuilder<
BufferAllocOpConversion, BufferDeallocOpConversion,
- BufferSizeOpConversion, DotOpConversion, LoadOpConversion,
- RangeOpConversion, SliceOpConversion, StoreOpConversion,
- ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
- *this);
+ BufferSizeOpConversion, DimOpConversion, DotOpConversion,
+ LoadOpConversion, RangeOpConversion, SliceOpConversion,
+ StoreOpConversion, ViewOpConversion>::build(&converterStorage,
+ llvmDialect->getContext(),
+ *this);
}
Type convertAdditionalType(Type t) override {
signalPassFailure();
}
-ModulePassBase *mlir::createLowerLinalgToLLVMPass() {
+ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
return new LowerLinalgToLLVMPass();
}
this->tileSizes.assign(sizes.begin(), sizes.end());
}
-ModulePassBase *mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
+ModulePassBase *
+mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
return new LinalgTilingPass(tileSizes);
}
}
// CHECK-LABEL: func @dot(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg1: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg2: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
// CHECK: llvm.call @linalg_dot(%arg0, %arg1, %arg2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
+
+func @dim(%arg0: !linalg.view<?x?xf32>) {
+ %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+ return
+}
+// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
+// CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
// CHECK-NEXT: linalg.dot(%arg1, %arg2, %arg3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+func @dim(%arg0: !linalg.view<?x?xf32>) {
+ %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+ linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+ return
+}
+// CHECK-LABEL: func @dim(%arg0: !linalg.view<?x?xf32>) {
+// CHECK-NEXT: %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+