[Linalg] Add a slice op
authorNicolas Vasilache <ntv@google.com>
Fri, 19 Apr 2019 19:55:34 +0000 (12:55 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 24 Apr 2019 05:01:10 +0000 (22:01 -0700)
    This CL adds a linalg.slice op with the proper roundtripping test.
    A slice op allows taking subviews that may be rank-reducing (if some indexing is of index type) or not (if all indexings are of linalg.range type).

    A slice must be constructed directly from a base view (no chains of slices may exist in the IR). Helper functions that fold will be provided for construction if/when necessary.

    This also renames base_view to view.

--

PiperOrigin-RevId: 244406827

mlir/include/mlir/Linalg/LinalgOps.h
mlir/include/mlir/Linalg/LinalgTypes.h
mlir/lib/Linalg/LinalgOps.cpp
mlir/lib/Linalg/LinalgTypes.cpp
mlir/test/Linalg/roundtrip.mlir

index 925129e..2142459 100644 (file)
 
 namespace mlir {
 
-/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range
-/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a
-/// buffer an indexing structure.
-///
-/// A new value of ViewType is constructed from a buffer with a base_view op and
-/// ranges:
+/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
+/// upon which a base view can be laid out to give it indexing semantics.
+/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
+/// (in number of elements).
 ///
 /// ```{.mlir}
-///    %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-///    %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+///     %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
 /// ```
-class BaseViewOp : public mlir::Op<BaseViewOp, mlir::OpTrait::VariadicOperands,
-                                   mlir::OpTrait::OneResult,
-                                   mlir::OpTrait::HasNoSideEffect> {
-  enum { FirstIndexingOperand = 1 };
-
-public:
-  using Op::Op;
-
-  // Hooks to customize the behavior of this op.
-  static llvm::StringRef getOperationName() { return "linalg.base_view"; }
-  static void build(mlir::Builder *b, mlir::OperationState *result,
-                    mlir::Value *buffer,
-                    llvm::ArrayRef<mlir::Value *> indexings);
-  mlir::LogicalResult verify();
-  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
-  void print(mlir::OpAsmPrinter *p);
-
-  // Op-specific functionality.
-  unsigned getRank() { return getViewType().getRank(); }
-  mlir::Type getElementType() { return getViewType().getElementType(); }
-  ViewType getViewType() { return getType().cast<ViewType>(); }
-  mlir::Value *getSupportingBuffer() { return getOperand(0); }
-  // Get the underlying indexing at a given rank.
-  mlir::Value *getIndexing(unsigned rank) {
-    return *(getIndexings().begin() + rank);
-  }
-  // Get all the indexings in this view.
-  mlir::Operation::operand_range getIndexings() {
-    return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()};
-  }
-};
-
-/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
-/// view can be laid out. The size argument is an `i64` (and not an index), so
-/// that we can
 class BufferAllocOp
     : public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
 public:
@@ -89,7 +50,11 @@ public:
   Type getElementType() { return getBufferType().getElementType(); }
 };
 
-/// A BufferDeallocOp is used to free a !linalg.buffer.
+/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
+///
+/// ```{.mlir}
+///     linalg.buffer_dealloc %0 : !linalg.buffer<f32>
+/// ```
 class BufferDeallocOp
     : public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
 public:
@@ -109,8 +74,12 @@ public:
   }
 };
 
-/// A RangeOp is used to create a value of RangeType from 3 values of type index
+/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
 /// that represent the min, max and step values of the range.
+///
+/// ```{.mlir}
+///    %3 = linalg.range %0:%1:%2 : !linalg.range
+/// ```
 class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
                           OpTrait::OneResult, OpTrait::HasNoSideEffect> {
 public:
@@ -130,6 +99,126 @@ public:
   Value *step() { return getOperand(2); }
 };
 
+/// The "linalg.slice" op produces a linalg.view which is a subview of a given
+/// base view. This allows defining a subregion within the underlying buffer to
+/// operate on only a subset of the buffer.
+///
+/// A "linalg.slice" op takes a base view and a variadic number of indexings and
+/// produces a linalg.view of the same elemental type as the buffer. An indexing
+/// is either:
+///   1. a linalg.range, in which case it does not reduce the rank of the parent
+///      view.
+///   2. an index, in which case it reduces the rank of the parent view by one.
+///
+/// The parent view must be a base view (i.e. either a function argument or has
+/// been produced by a linalg.view op). In other words, chains of
+/// linalg.slice operations cannot be constructed in the IR. This defines away
+/// problems related to keeping track of which dimensions of the base view have
+/// been rank-reduced.
+///
+/// Examples:
+///   1. rank-preserving slice:
+///
+/// ```{.mlir}
+///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
+///    !linalg.range, !linalg.view<?x?xf32>
+/// ```
+///
+///   2. rank-reducing slice (from 2-D to 1-D):
+///
+/// ```{.mlir}
+///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
+///    !linalg.range, !linalg.view<?xf32>
+/// ```
+///
+///   3. rank-reducing slice (from 2-D to 0-D):
+///
+/// ```{.mlir}
+///    %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
+///    !linalg.view<f32>
+/// ```
+class ViewOp;
+class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::VariadicOperands,
+                                mlir::OpTrait::OneResult,
+                                mlir::OpTrait::HasNoSideEffect> {
+  enum { FirstIndexingOperand = 1 };
+
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.slice"; }
+  static void build(mlir::Builder *b, mlir::OperationState *result,
+                    mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
+  mlir::LogicalResult verify();
+  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+  void print(mlir::OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  mlir::Type getElementType() { return getViewType().getElementType(); }
+  ViewType getViewType() { return getType().cast<ViewType>(); }
+  Value *getBaseView() { return getOperand(0); }
+  ViewOp getBaseViewOp();
+  ViewType getBaseViewType();
+  unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
+  // Get the underlying indexing at a given rank.
+  mlir::Value *getIndexing(unsigned rank) {
+    return *(getIndexings().begin() + rank);
+  }
+  // Get all the indexings in this view.
+  mlir::Operation::operand_range getIndexings() {
+    return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
+  }
+  // Get the subset of indexings that are of RangeType.
+  SmallVector<Value *, 8> getRanges();
+};
+
+/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
+/// range abstraction on top of an underlying linalg.buffer. This gives an
+/// indexing structure to an otherwise non-indexable linalg.buffer.
+///
+/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
+/// a `view` of the same elemental type as the buffer and of rank the number of
+/// ranges:
+///
+/// ```{.mlir}
+///    %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
+///    %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+/// ```
+class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
+                               mlir::OpTrait::OneResult,
+                               mlir::OpTrait::HasNoSideEffect> {
+  enum { FirstIndexingOperand = 1 };
+
+public:
+  using Op::Op;
+
+  // Hooks to customize the behavior of this op.
+  static llvm::StringRef getOperationName() { return "linalg.view"; }
+  static void build(mlir::Builder *b, mlir::OperationState *result,
+                    mlir::Value *buffer,
+                    llvm::ArrayRef<mlir::Value *> indexings);
+  mlir::LogicalResult verify();
+  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+  void print(mlir::OpAsmPrinter *p);
+
+  // Op-specific functionality.
+  unsigned getRank() { return getViewType().getRank(); }
+  mlir::Type getElementType() { return getViewType().getElementType(); }
+  ViewType getViewType() { return getType().cast<ViewType>(); }
+  mlir::Value *getSupportingBuffer() { return getOperand(0); }
+  // Get the underlying indexing at a given rank.
+  mlir::Value *getIndexing(unsigned rank) {
+    return *(getIndexings().begin() + rank);
+  }
+  // Get all the indexings in this view.
+  mlir::Operation::operand_range getIndexings() {
+    return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
+  }
+};
+
 } // namespace mlir
 
 #endif // MLIR_LINALG_LINALGOPS_H_
index b294bb3..fbb1cbf 100644 (file)
@@ -42,7 +42,9 @@ public:
   void printType(Type type, llvm::raw_ostream &os) const override;
 };
 
-/// A BufferType represents a minimal range abstraction (min, max, step).
+/// A BufferType represents a contiguous block of memory that can be allocated
+/// and deallocated. A buffer cannot be indexed directly, a view must be
+/// laid out on a buffer to give it indexing semantics.
 class BufferTypeStorage;
 class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
 public:
@@ -58,6 +60,14 @@ public:
 };
 
 /// A RangeType represents a minimal range abstraction (min, max, step).
+/// It is constructed by calling the linalg.range op with three values index of
+/// index type:
+///
+/// ```{.mlir}
+///    func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
+///      %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
+///    }
+/// ```
 class RangeType : public Type::TypeBase<RangeType, Type> {
 public:
   // Used for generic hooks in TypeBase.
@@ -74,13 +84,13 @@ public:
 /// A ViewType represents a multi-dimensional range abstraction on top of an
 /// underlying storage type. It is parameterizable by the underlying element
 /// type and the rank of the view.
-/// A new value of ViewType is constructed from a buffer with a base_view op and
+/// A new value of ViewType is constructed from a buffer with a view op and
 /// passing it ranges:
 ///
 /// ```{.mlir}
 ///    %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
 ///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-///    %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+///    %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
 /// ```
 class ViewTypeStorage;
 class ViewType
index 47aae67..2db7696 100644 (file)
 using namespace mlir;
 
 //////////////////////////////////////////////////////////////////////////////
-// BaseViewOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer,
-                             ArrayRef<Value *> indexings) {
-  BufferType bufferType = buffer->getType().cast<BufferType>();
-  result->addOperands({buffer});
-  result->addOperands(indexings);
-  assert(
-      std::none_of(indexings.begin(), indexings.end(),
-                   [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
-      "linalg.base_view takes only arguments of type linalg.range");
-
-  Type elementType = bufferType.getElementType();
-  result->addTypes(
-      {ViewType::get(b->getContext(), elementType, indexings.size())});
-}
-
-LogicalResult mlir::BaseViewOp::verify() {
-  if (llvm::empty(getOperands()))
-    return emitOpError(
-        "requires at least a buffer operand followed by indexings");
-  auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
-  if (!bufferType)
-    return emitOpError("first operand must be of BufferType");
-  unsigned index = 0;
-  for (auto indexing : getIndexings()) {
-    if (!indexing->getType().isa<RangeType>()) {
-      return emitOpError(Twine(index) + "^th index must be of range type");
-    }
-    ++index;
-  }
-  if (getViewType().getRank() != index)
-    return emitOpError(
-        "the rank of the base view must be the number of its indexings");
-  return success();
-}
-
-bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) {
-  OpAsmParser::OperandType bufferInfo;
-  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
-  Type type;
-  if (parser->parseOperand(bufferInfo) ||
-      parser->parseOperandList(indexingsInfo, -1,
-                               OpAsmParser::Delimiter::Square) ||
-      parser->parseOptionalAttributeDict(result->attributes) ||
-      parser->parseColonType(type))
-    return true;
-
-  ViewType viewType = type.dyn_cast<ViewType>();
-  if (!viewType)
-    return parser->emitError(parser->getNameLoc(), "view type expected");
-  if (viewType.getRank() != indexingsInfo.size())
-    return parser->emitError(parser->getNameLoc(),
-                             "expected" + Twine(viewType.getRank()) +
-                                 " range indexings");
-  return parser->resolveOperand(
-             bufferInfo,
-             BufferType::get(type.getContext(), viewType.getElementType()),
-             result->operands) ||
-         (!indexingsInfo.empty() &&
-          parser->resolveOperands(indexingsInfo,
-                                  RangeType::get(type.getContext()),
-                                  result->operands)) ||
-         parser->addTypeToList(viewType, result->types);
-}
-
-// A BaseViewOp prints as:
-//
-// ```{.mlir}
-//   linalg.base_view %0[%1, %2] : !linalg.view<?x?xf32>
-// ```
-//
-// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
-// holding a range.
-void mlir::BaseViewOp::print(OpAsmPrinter *p) {
-  *p << getOperationName() << " " << *getSupportingBuffer() << "[";
-  interleave(
-      getIndexings().begin(), getIndexings().end(),
-      [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
-  *p << "] : " << getType();
-}
-
-//////////////////////////////////////////////////////////////////////////////
 // BufferAllocOp
 //////////////////////////////////////////////////////////////////////////////
 void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
@@ -122,9 +39,8 @@ void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
 }
 
 mlir::LogicalResult mlir::BufferAllocOp::verify() {
-  if (!size() || !size()->getType().isa<IntegerType>() ||
-      !size()->getType().cast<IntegerType>().isInteger(64))
-    return emitOpError("first operand should be of type i64");
+  if (!size() || !size()->getType().isa<IndexType>())
+    return emitOpError("first operand should be of type index");
   if (!VectorType::isValidElementType(getElementType()) &&
       !getElementType().isa<VectorType>())
     return emitOpError("unsupported buffer element type");
@@ -143,14 +59,14 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
 bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType sizeInfo;
   BufferType bufferType;
-  auto int64Ty = parser->getBuilder().getIntegerType(64);
+  auto indexTy = parser->getBuilder().getIndexType();
   if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
     return true;
   if (bufferType.getElementType() != parser->getBuilder().getF32Type())
     return parser->emitError(
         parser->getNameLoc(),
         "Only buffer<f32> supported until mlir::Parser pieces are exposed");
-  return parser->resolveOperands(sizeInfo, int64Ty, result->operands) ||
+  return parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
          parser->addTypeToList(bufferType, result->types);
 }
 
@@ -183,7 +99,6 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
   return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
          parser->resolveOperands(sizeInfo, bufferType, result->operands);
 }
-
 //////////////////////////////////////////////////////////////////////////////
 // RangeOp
 //////////////////////////////////////////////////////////////////////////////
@@ -224,3 +139,218 @@ bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
          parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
          parser->addTypeToList(type, result->types);
 }
+
+//////////////////////////////////////////////////////////////////////////////
+// SliceOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
+                          ArrayRef<Value *> indexings) {
+  result->addOperands({base});
+  result->addOperands(indexings);
+
+  ViewType viewType = base->getType().cast<ViewType>();
+  unsigned rank = viewType.getRank();
+  for (auto *i : indexings)
+    if (!i->getType().isa<RangeType>())
+      rank--;
+  Type elementType = viewType.getElementType();
+  result->addTypes(
+      {ViewType::get(b->getContext(), elementType, indexings.size())});
+}
+
+LogicalResult mlir::SliceOp::verify() {
+  if (llvm::empty(getOperands()))
+    return emitOpError(
+        "requires at least a view operand followed by 'rank' indices");
+  if (!getOperand(0)->getDefiningOp()->isa<ViewOp>())
+    return emitOpError(
+        "requires at least a view operand followed by 'rank' indices");
+
+  auto viewOp = getOperand(0)->getDefiningOp()->dyn_cast<ViewOp>();
+  if (!viewOp)
+    return emitOpError("first operand must come from a ViewOp");
+  unsigned rank = getBaseViewRank();
+  if (llvm::size(getIndexings()) != rank) {
+    return emitOpError("requires at least a view operand followed by " +
+                       Twine(rank) + " indexings");
+  }
+  unsigned index = 0;
+  for (auto indexing : getIndexings()) {
+    if (!indexing->getType().isa<RangeType>() &&
+        !indexing->getType().isa<IndexType>()) {
+      return emitOpError(Twine(index) +
+                         "^th index must be of range or index type");
+    }
+    if (indexing->getType().isa<IndexType>())
+      --rank;
+    ++index;
+  }
+  if (getRank() != rank) {
+    return emitOpError("the rank of the view must be the number of its range "
+                       "indices: " +
+                       Twine(rank));
+  }
+  return success();
+}
+
+bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType baseInfo;
+  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
+  SmallVector<Type, 8> types;
+  if (parser->parseOperand(baseInfo) ||
+      parser->parseOperandList(indexingsInfo, -1,
+                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonTypeList(types))
+    return true;
+
+  if (types.size() != 2 + indexingsInfo.size())
+    return parser->emitError(parser->getNameLoc(),
+                             "unexpected number of types ");
+  ViewType baseViewType = types[0].dyn_cast<ViewType>();
+  if (!baseViewType)
+    return parser->emitError(parser->getNameLoc(),
+                             "view type expected for first type");
+  if (indexingsInfo.size() != baseViewType.getRank())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected " + Twine(baseViewType.getRank()) +
+                                 " indexings");
+  ViewType viewType = types.back().dyn_cast<ViewType>();
+  if (!viewType)
+    return parser->emitError(parser->getNameLoc(), "view type expected");
+
+  ArrayRef<Type> indexingTypes =
+      ArrayRef<Type>(types).drop_front(1).drop_back(1);
+  if (indexingTypes.size() != baseViewType.getRank())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected " + Twine(baseViewType.getRank()) +
+                                 " indexing types");
+  return parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
+         (!indexingsInfo.empty() &&
+          parser->resolveOperands(indexingsInfo, indexingTypes,
+                                  indexingsInfo.front().location,
+                                  result->operands)) ||
+         parser->addTypeToList(viewType, result->types);
+}
+
+// A SliceOp prints as:
+//
+// ```{.mlir}
+//   linalg.slice %0[%1, %2] :
+//     !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
+// ssa-value each holding a range.
+void mlir::SliceOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getBaseView() << "[";
+  interleave(
+      getIndexings().begin(), getIndexings().end(),
+      [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+  *p << "] : " << getBaseViewType();
+  for (auto indexing : getIndexings()) {
+    *p << ", " << indexing->getType();
+  }
+  *p << ", " << getType();
+}
+
+ViewOp mlir::SliceOp::getBaseViewOp() {
+  return getOperand(0)->getDefiningOp()->cast<ViewOp>();
+}
+
+ViewType mlir::SliceOp::getBaseViewType() {
+  return getBaseViewOp().getType().cast<ViewType>();
+}
+
+SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
+  llvm::SmallVector<Value *, 8> res;
+  for (auto *operand : getIndexings()) {
+    if (!operand->getType().isa<IndexType>()) {
+      res.push_back(operand);
+    }
+  }
+  return res;
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// ViewOp
+//////////////////////////////////////////////////////////////////////////////
+void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
+                         ArrayRef<Value *> indexings) {
+  BufferType bufferType = buffer->getType().cast<BufferType>();
+  result->addOperands({buffer});
+  result->addOperands(indexings);
+  assert(
+      std::none_of(indexings.begin(), indexings.end(),
+                   [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
+      "linalg.view takes only arguments of type linalg.range");
+
+  Type elementType = bufferType.getElementType();
+  result->addTypes(
+      {ViewType::get(b->getContext(), elementType, indexings.size())});
+}
+
+LogicalResult mlir::ViewOp::verify() {
+  if (llvm::empty(getOperands()))
+    return emitOpError(
+        "requires at least a buffer operand followed by indexings");
+  auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
+  if (!bufferType)
+    return emitOpError("first operand must be of BufferType");
+  unsigned index = 0;
+  for (auto indexing : getIndexings()) {
+    if (!indexing->getType().isa<RangeType>()) {
+      return emitOpError(Twine(index) + "^th index must be of range type");
+    }
+    ++index;
+  }
+  if (getViewType().getRank() != index)
+    return emitOpError(
+        "the rank of the view must be the number of its indexings");
+  return success();
+}
+
+bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType bufferInfo;
+  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
+  Type type;
+  if (parser->parseOperand(bufferInfo) ||
+      parser->parseOperandList(indexingsInfo, -1,
+                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return true;
+
+  ViewType viewType = type.dyn_cast<ViewType>();
+  if (!viewType)
+    return parser->emitError(parser->getNameLoc(), "view type expected");
+  if (viewType.getRank() != indexingsInfo.size())
+    return parser->emitError(parser->getNameLoc(),
+                             "expected" + Twine(viewType.getRank()) +
+                                 " range indexings");
+  return parser->resolveOperand(
+             bufferInfo,
+             BufferType::get(type.getContext(), viewType.getElementType()),
+             result->operands) ||
+         (!indexingsInfo.empty() &&
+          parser->resolveOperands(indexingsInfo,
+                                  RangeType::get(type.getContext()),
+                                  result->operands)) ||
+         parser->addTypeToList(viewType, result->types);
+}
+
+// A ViewOp prints as:
+//
+// ```{.mlir}
+//   linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
+// holding a range.
+void mlir::ViewOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getSupportingBuffer() << "[";
+  interleave(
+      getIndexings().begin(), getIndexings().end(),
+      [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
+  *p << "] : " << getType();
+}
index e164b0b..fa08f75 100644 (file)
@@ -30,7 +30,7 @@ using namespace mlir;
 mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect("linalg", context) {
   addTypes<BufferType, RangeType, ViewType>();
-  addOperations<BaseViewOp, BufferAllocOp, BufferDeallocOp, RangeOp>();
+  addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
 }
 
 struct mlir::BufferTypeStorage : public mlir::TypeStorage {
index eab2aa3..4327e5d 100644 (file)
@@ -7,28 +7,36 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
 // CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
 //  CHECK-NEXT:  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
 
-func @buffer(%arg0: i64, %arg1: i64) {
-  %0 = muli %arg0, %arg0 : i64
+func @buffer(%arg0: index, %arg1: index) {
+  %0 = muli %arg0, %arg0 : index
   %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
   linalg.buffer_dealloc %1 : !linalg.buffer<f32>
   return
 }
-// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
-//  CHECK-NEXT:  %0 = muli %arg0, %arg0 : i64
+// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
+//  CHECK-NEXT:  %0 = muli %arg0, %arg0 : index
 //  CHECK-NEXT:  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
 //  CHECK-NEXT:  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
 
-func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
-  %0 = muli %arg0, %arg0 : i64
+func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+  %0 = muli %arg0, %arg0 : index
   %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
   %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-  %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+  %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+  %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+  %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
+  %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+  %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
   linalg.buffer_dealloc %1 : !linalg.buffer<f32>
   return
 }
-// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
-//  CHECK-NEXT:  %0 = muli %arg0, %arg0 : i64
+// CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+//  CHECK-NEXT:  %0 = muli %arg0, %arg0 : index
 //  CHECK-NEXT:  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
 //  CHECK-NEXT:  %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-//  CHECK-NEXT:  %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
+//  CHECK-NEXT:  %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
+//  CHECK-NEXT:  %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  CHECK-NEXT:  %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
+//  CHECK-NEXT:  %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+//  CHECK-NEXT:  %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
 //  CHECK-NEXT:  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
\ No newline at end of file