Add a linalg.dim
authorNicolas Vasilache <ntv@google.com>
Mon, 13 May 2019 21:59:55 +0000 (14:59 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:41:02 +0000 (13:41 -0700)
    A linalg.dim operation is used to extract size information from !linalg.view objects passed
    through function call boundaries.

--

PiperOrigin-RevId: 248017488

mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/include/mlir/Linalg/Passes.h
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/Tiling.cpp
mlir/test/Linalg/llvm.mlir
mlir/test/Linalg/roundtrip.mlir

index 2aa1e43..58eb3f0 100644 (file)
@@ -82,10 +82,8 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
 class LinalgOp<string mnemonic, list<OpTrait> props> :
 Op<Linalg_Dialect, mnemonic, props> {
   let arguments = (ins Variadic<View>); // default variadic builder
-
-  let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }];
-
-  let printer = [{ impl::printLinalgLibraryOp(p, *this); }];
+  let parser = [{ return parseLinalgLibraryOp(parser, result); }];
+  let printer = [{ printLinalgLibraryOp(p, *this); }];
 }
 
 def BufferSizeOp :
@@ -93,12 +91,39 @@ def BufferSizeOp :
     Arguments<(ins Buffer)>,
     Results<(outs Index)>
 {
-  let parser = [{
-    return impl::parseBufferSizeOp(parser, result);
+  let parser = [{ return parseBufferSizeOp(parser, result); }];
+  let printer = [{ return printBufferSizeOp(p, *this); }];
+}
+
+def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
+    Arguments<(ins View:$view, APIntAttr:$index)>,
+    Results<(outs Index)> {
+  let summary = "dimension index operation";
+  let description = [{
+    The "linalg.dim" operation takes a linalg.view and returns an
+    "index". It requires a single integer attribute named "index". It
+     returns the size of the specified dimension. For example:
+
+      %1 = linalg.dim %0, 2 : view<?x?x?xf32>
   }];
 
-  let printer = [{
-    return impl::printBufferSizeOp(p, this->getOperation());
+  let parser = [{ return parseDimOp(parser, result); }];
+  let printer = [{ return printDimOp(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *view," "unsigned index",
+    [{
+      result->addOperands(view);
+      result->addAttribute(
+        "index", builder->getIntegerAttr(builder->getIndexType(), index));
+      result->types.push_back(builder->getIndexType());
+    }]>];
+
+  let extraClassDeclaration = [{
+    unsigned getIndex() {
+      return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+    }
   }];
 }
 
index 931de90..2825139 100644 (file)
 namespace mlir {
 class ModulePassBase;
 
-mlir::ModulePassBase *
-createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
+namespace linalg {
+ModulePassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
 
-mlir::ModulePassBase *createLowerLinalgToLLVMPass();
+ModulePassBase *createLowerLinalgToLLVMPass();
+
+} // namespace linalg
 } // namespace mlir
 
 #endif // MLIR_LINALG_PASSES_H_
index da102b3..e6e18bb 100644 (file)
@@ -488,30 +488,19 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
   *p << "] : " << getType();
 }
 
-namespace mlir {
-namespace linalg {
-namespace impl {
-void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
-ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
-void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
-ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
-} // namespace impl
-} // namespace linalg
-
 /// Buffer size prints as:
 ///
 /// ``` {.mlir}
 ///    %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
 /// ```
-void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
-  assert(op->getAbstractOperation() && "unregistered operation");
-  *p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
-  p->printOptionalAttrDict(op->getAttrs());
-  *p << " : " << op->getOperand(0)->getType();
+static void printBufferSizeOp(OpAsmPrinter *p, BufferSizeOp op) {
+  *p << op.getOperationName() << " " << *op.getOperand();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getOperand()->getType();
 }
 
-ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
-                                                  OperationState *result) {
+static ParseResult parseBufferSizeOp(OpAsmParser *parser,
+                                     OperationState *result) {
   OpAsmParser::OperandType op;
   Type type;
   return failure(parser->parseOperand(op) ||
@@ -522,10 +511,44 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
                                        result->types));
 }
 
-#define GET_OP_CLASSES
-#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+static void printDimOp(OpAsmPrinter *p, DimOp op) {
+  *p << op.getOperationName() << " " << *op.getOperand() << ", "
+     << op.getIndex();
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
+  *p << " : " << op.getOperand()->getType();
+}
 
-} // namespace mlir
+static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType operandInfo;
+  IntegerAttr indexAttr;
+  Type type;
+  Type indexType = parser->getBuilder().getIndexType();
+  return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
+                 parser->parseAttribute(indexAttr, indexType, "index",
+                                        result->attributes) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(operandInfo, type, result->operands) ||
+                 parser->addTypeToList(indexType, result->types));
+}
+
+static LogicalResult verify(linalg::DimOp op) {
+  // Check that we have an integer index operand.
+  auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
+  if (!indexAttr)
+    return op.emitOpError("requires an integer attribute named 'index'");
+
+  uint64_t index = indexAttr.getValue().getZExtValue();
+  auto type = op.getOperand()->getType();
+  if (auto viewType = type.dyn_cast<ViewType>()) {
+    if (index >= viewType.getRank())
+      return op.emitOpError("index is out of range");
+  } else {
+    return op.emitOpError("requires an operand with view type");
+  }
+
+  return success();
+}
 
 // A LinalgLibraryOp prints as:
 //
@@ -541,7 +564,7 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
 // ```
 //
 // Where %0, %1 and %2 are ssa-values of type ViewType.
-void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
+static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
   *p << op->getName().getStringRef() << "(";
   interleave(
@@ -553,8 +576,8 @@ void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
       [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
 }
 
-ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
-                                                     OperationState *result) {
+static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
+                                        OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 3> ops;
   SmallVector<Type, 3> types;
   return failure(
@@ -565,6 +588,13 @@ ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
                               result->operands));
 }
 
+namespace mlir {
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+
+} // namespace mlir
+
 // Ideally this should all be Tablegen'd but there is no good story for
 // AffineMap for now.
 SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
index 2d1f5f2..6c4d5c2 100644 (file)
@@ -154,7 +154,7 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder,
   return builder.getArrayAttr(attrs);
 }
 
-// BufferAllocOp creates a new `index` value.
+// BufferAllocOp creates a new `!linalg.buffer` value.
 class BufferAllocOpConversion : public LLVMOpLowering {
 public:
   explicit BufferAllocOpConversion(MLIRContext *context,
@@ -213,7 +213,7 @@ public:
   }
 };
 
-// BufferDeallocOp creates a new `index` value.
+// BufferDeallocOp creates no value.
 class BufferDeallocOpConversion : public LLVMOpLowering {
 public:
   explicit BufferDeallocOpConversion(MLIRContext *context,
@@ -268,6 +268,23 @@ public:
   }
 };
 
+// DimOp creates a new `index` value.
+class DimOpConversion : public LLVMOpLowering {
+public:
+  explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
+      : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto dimOp = cast<linalg::DimOp>(op);
+    auto indexTy = lowering.convertType(rewriter.getIndexType());
+    edsc::ScopedContext context(rewriter, op->getLoc());
+    return {extractvalue(
+        indexTy, operands[0],
+        makePositionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
+  }
+};
+
 namespace {
 // Common functionality for Linalg LoadOp and StoreOp conversion to the
 // LLVM IR Dialect.
@@ -533,10 +550,11 @@ protected:
   llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
     return ConversionListBuilder<
         BufferAllocOpConversion, BufferDeallocOpConversion,
-        BufferSizeOpConversion, DotOpConversion, LoadOpConversion,
-        RangeOpConversion, SliceOpConversion, StoreOpConversion,
-        ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
-                                 *this);
+        BufferSizeOpConversion, DimOpConversion, DotOpConversion,
+        LoadOpConversion, RangeOpConversion, SliceOpConversion,
+        StoreOpConversion, ViewOpConversion>::build(&converterStorage,
+                                                    llvmDialect->getContext(),
+                                                    *this);
   }
 
   Type convertAdditionalType(Type t) override {
@@ -564,7 +582,7 @@ void LowerLinalgToLLVMPass::runOnModule() {
     signalPassFailure();
 }
 
-ModulePassBase *mlir::createLowerLinalgToLLVMPass() {
+ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
   return new LowerLinalgToLLVMPass();
 }
 
index e1fa74d..f50076a 100644 (file)
@@ -360,7 +360,8 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes)
     this->tileSizes.assign(sizes.begin(), sizes.end());
 }
 
-ModulePassBase *mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
+ModulePassBase *
+mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
   return new LinalgTilingPass(tileSizes);
 }
 
index 3213342..143b851 100644 (file)
@@ -73,3 +73,10 @@ func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg
 }
 // CHECK-LABEL: func @dot(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg1: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg2: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
 //       CHECK:   llvm.call @linalg_dot(%arg0, %arg1, %arg2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
+
+func @dim(%arg0: !linalg.view<?x?xf32>) {
+  %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+  return
+}
+// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
+//       CHECK:   %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
index 13e3604..c2eed72 100644 (file)
@@ -52,3 +52,14 @@ func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !lina
 //  CHECK-NEXT:  linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
 //  CHECK-NEXT:  linalg.dot(%arg1, %arg2, %arg3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
 
+func @dim(%arg0: !linalg.view<?x?xf32>) {
+  %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+  return
+}
+// CHECK-LABEL: func @dim(%arg0: !linalg.view<?x?xf32>) {
+//  CHECK-NEXT:   %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+//  CHECK-NEXT:   %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+//  CHECK-NEXT:   linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+