[Linalg] Add a simple buffer type with alloc/dealloc ops
authorNicolas Vasilache <ntv@google.com>
Thu, 18 Apr 2019 20:56:18 +0000 (13:56 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 24 Apr 2019 05:00:36 +0000 (22:00 -0700)
    This CL adds a linalg.buffer<f32> type and buffer_alloc / buffer_dealloc ops with the proper roundtripping test.

--

PiperOrigin-RevId: 244252306

mlir/include/mlir/Linalg/LinalgOps.h
mlir/include/mlir/Linalg/LinalgTypes.h
mlir/lib/Linalg/LinalgOps.cpp
mlir/lib/Linalg/LinalgTypes.cpp
mlir/test/Linalg/roundtrip.mlir

index 7921822..8e83d7c 100644 (file)
 #define MLIR_LINALG_LINALGOPS_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Linalg/LinalgTypes.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
+/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
+/// view can be laid out. The size argument is an `i64` (and not an index), so
+/// that we can
+class BufferAllocOp
+    : public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
+  static void build(Builder *b, OperationState *result, Type type, Value *size);
+  LogicalResult verify();
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  Value *size() { return getOperand(); }
+  BufferType getBufferType() { return getType().cast<BufferType>(); }
+  Type getElementType() { return getBufferType().getElementType(); }
+};
+
+/// A BufferDeallocOp is used to free a !linalg.buffer.
+class BufferDeallocOp
+    : public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
+  static void build(Builder *b, OperationState *result, Value *buffer);
+  LogicalResult verify();
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  Value *getBuffer() { return getOperand(); }
+  BufferType getBufferType() {
+    return getOperand()->getType().cast<BufferType>();
+  }
+};
+
 /// A RangeOp is used to create a value of RangeType from 3 values of type index
 /// that represent the min, max and step values of the range.
 class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
@@ -30,9 +72,7 @@ class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
 public:
   using Op::Op;
 
-  //////////////////////////////////////////////////////////////////////////////
   // Hooks to customize the behavior of this op.
-  //////////////////////////////////////////////////////////////////////////////
   static llvm::StringRef getOperationName() { return "linalg.range"; }
   static void build(Builder *b, OperationState *result, Value *min, Value *max,
                     Value *step);
@@ -40,9 +80,7 @@ public:
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p);
 
-  //////////////////////////////////////////////////////////////////////////////
   // Op-specific functionality.
-  //////////////////////////////////////////////////////////////////////////////
   Value *min() { return getOperand(0); }
   Value *max() { return getOperand(1); }
   Value *step() { return getOperand(2); }
index 2d2c74e..f08d071 100644 (file)
@@ -25,7 +25,8 @@ namespace mlir {
 class MLIRContext;
 
 enum LinalgTypes {
-  Range = Type::FIRST_LINALG_TYPE,
+  Buffer = Type::FIRST_LINALG_TYPE,
+  Range,
   LAST_USED_LINALG_TYPE = Range,
 };
 
@@ -40,6 +41,22 @@ public:
   void printType(Type type, llvm::raw_ostream &os) const override;
 };
 
+/// A BufferType represents a minimal range abstraction (min, max, step).
+class BufferTypeStorage;
+class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+  /// Construction hook.
+  static BufferType get(MLIRContext *context, Type elementType);
+  /// Used to implement llvm-style cast.
+  static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
+  //////////////////////////////////////////////////////////////////////////////
+  // Type-specific functionality.
+  //////////////////////////////////////////////////////////////////////////////
+  Type getElementType();
+};
+
 /// A RangeType represents a minimal range abstraction (min, max, step).
 class RangeType : public Type::TypeBase<RangeType, Type> {
 public:
index bba47fb..c6260d8 100644 (file)
@@ -28,6 +28,9 @@
 
 using namespace mlir;
 
+//////////////////////////////////////////////////////////////////////////////
+// RangeOp
+//////////////////////////////////////////////////////////////////////////////
 void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
                           Value *max, Value *step) {
   result->addOperands({min, max, step});
@@ -65,3 +68,75 @@ bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
          parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
          parser->addTypeToList(type, result->types);
 }
+
+//////////////////////////////////////////////////////////////////////////////
+// BufferAllocOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
+                                Value *size) {
+  result->addOperands({size});
+  result->addTypes(type);
+}
+
+mlir::LogicalResult mlir::BufferAllocOp::verify() {
+  if (!size() || !size()->getType().isa<IntegerType>() ||
+      !size()->getType().cast<IntegerType>().isInteger(64))
+    return emitOpError("first operand should be of type i64");
+  if (!VectorType::isValidElementType(getElementType()) &&
+      !getElementType().isa<VectorType>())
+    return emitOpError("unsupported buffer element type");
+  return mlir::success();
+}
+
+// A BufferAllocOp prints as:
+//
+// ```{.mlir}
+//   linalg.alloc %0 : !linalg.buffer<f32>
+// ```
+void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *size() << " : " << getType();
+}
+
+bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType sizeInfo;
+  BufferType bufferType;
+  auto int64Ty = parser->getBuilder().getIntegerType(64);
+  if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
+    return true;
+  if (bufferType.getElementType() != parser->getBuilder().getF32Type())
+    return parser->emitError(
+        parser->getNameLoc(),
+        "Only buffer<f32> supported until mlir::Parser pieces are exposed");
+  return parser->resolveOperands(sizeInfo, int64Ty, result->operands) ||
+         parser->addTypeToList(bufferType, result->types);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// BufferDeallocOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::BufferDeallocOp::build(Builder *b, OperationState *result,
+                                  Value *buffer) {
+  result->addOperands({buffer});
+}
+
+mlir::LogicalResult mlir::BufferDeallocOp::verify() {
+  if (!getBuffer()->getType())
+    return emitOpError("first operand should be of type buffer");
+  return mlir::success();
+}
+
+// A BufferDeallocOp prints as:
+//
+// ```{.mlir}
+//   linalg.dealloc %0 : !linalg.buffer<f32>
+// ```
+void mlir::BufferDeallocOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
+}
+
+bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType sizeInfo;
+  BufferType bufferType;
+  return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
+         parser->resolveOperands(sizeInfo, bufferType, result->operands);
+}
index 7aabfd2..822cd70 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "mlir/Linalg/LinalgTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Linalg/LinalgOps.h"
 #include "mlir/Support/LLVM.h"
 
@@ -28,24 +29,73 @@ using namespace mlir;
 
 mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect("linalg", context) {
-  addTypes<RangeType>();
-  addOperations<RangeOp>();
+  addTypes<BufferType, RangeType>();
+  addOperations<BufferAllocOp, BufferDeallocOp, RangeOp>();
 }
 
+struct mlir::BufferTypeStorage : public mlir::TypeStorage {
+  /// Underlying Key type to transport the payload needed to construct a custom
+  /// type in a generic way.
+  struct Key {
+    Key(Type elementType) : elementType(elementType) {}
+    Type elementType;
+  };
+  /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
+  using KeyTy = Key;
+
+  /// Construction in the llvm::BumpPtrAllocator given a key.
+  static BufferTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const Key &key) {
+    return new (allocator.allocate<BufferTypeStorage>()) BufferTypeStorage(key);
+  }
+
+  /// Equality operator for hashing.
+  bool operator==(const Key &key) const {
+    return elementType == key.elementType;
+  }
+
+  /// Hashing for unique'ing.
+  static unsigned hashKey(const Key &key) {
+    return llvm::hash_combine(key.elementType);
+  }
+
+  Type getElementType() { return elementType; };
+
+private:
+  BufferTypeStorage(const Key &key) : elementType(key.elementType) {}
+
+  Type elementType;
+};
+
+BufferType mlir::BufferType::get(MLIRContext *context, Type elementType) {
+  return Base::get(context, LinalgTypes::Buffer, elementType);
+}
+
+Type mlir::BufferType::getElementType() { return getImpl()->getElementType(); }
+
 Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
   MLIRContext *context = getContext();
   if (spec == "range")
     return RangeType::get(getContext());
+  // TODO(ntv): reuse mlir Parser once exposed.
+  if (spec == "buffer<f32>")
+    return BufferType::get(getContext(), FloatType::getF32(getContext()));
   return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
 }
 
 /// RangeType prints as just "range".
+static void print(BufferType bt, raw_ostream &os) {
+  os << "buffer<" << bt.getElementType() << ">";
+}
 static void print(RangeType rt, raw_ostream &os) { os << "range"; }
 
 void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
   switch (type.getKind()) {
   default:
     llvm_unreachable("Unhandled Linalg type");
+  case LinalgTypes::Buffer:
+    print(type.cast<BufferType>(), os);
+    break;
   case LinalgTypes::Range:
     print(type.cast<RangeType>(), os);
     break;
index f98558a..8544ed2 100644 (file)
@@ -5,4 +5,15 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
   return
 }
 // CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
-//  CHECK-NEXT:  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
\ No newline at end of file
+//  CHECK-NEXT:  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
+
+func @buffer(%arg0: i64, %arg1: i64) {
+  %0 = muli %arg0, %arg0 : i64
+  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+  return
+}
+// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
+//  CHECK-NEXT:  %0 = muli %arg0, %arg0 : i64
+//  CHECK-NEXT:  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+//  CHECK-NEXT:  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
\ No newline at end of file