Allow linalg.view to change the underlying elemental type.
authorNicolas Vasilache <ntv@google.com>
Fri, 9 Aug 2019 14:28:51 +0000 (07:28 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 9 Aug 2019 14:29:21 +0000 (07:29 -0700)
This CL adds the ability for linalg.view to act as a bitcast operation.
This will be used when promoting views into faster memory and casting to vector types.

In the process, linalg.view is moved to ODS.

PiperOrigin-RevId: 262556246

mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/lib/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/IR/LinalgTypes.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/test/Linalg/roundtrip.mlir

index a1bff98..4085d06 100644 (file)
@@ -186,47 +186,6 @@ public:
   }
 };
 
-/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
-/// range abstraction on top of an underlying linalg.buffer. This gives an
-/// indexing structure to an otherwise non-indexable linalg.buffer.
-///
-/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
-/// a `view` of the same elemental type as the buffer and of rank the number of
-/// ranges:
-///
-/// ```{.mlir}
-///    %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-///    %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
-/// ```
-class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
-                         OpTrait::HasNoSideEffect> {
-  enum { FirstIndexingOperand = 1 };
-
-public:
-  using Op::Op;
-
-  // Hooks to customize the behavior of this op.
-  static llvm::StringRef getOperationName() { return "linalg.view"; }
-  static void build(Builder *b, OperationState *result, Value *buffer,
-                    llvm::ArrayRef<Value *> indexings);
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-
-  // Op-specific functionality.
-  unsigned getRank() { return getViewType().getRank(); }
-  Type getElementType() { return getViewType().getElementType(); }
-  ViewType getViewType() { return getType().cast<ViewType>(); }
-  Value *getSupportingBuffer() { return getOperand(0); }
-  // Get the underlying indexing at a given rank.
-  Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
-  // Get all the indexings in this view.
-  Operation::operand_range getIndexings() {
-    return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
-  }
-};
-
 #define GET_OP_CLASSES
 #include "mlir/Linalg/IR/LinalgOps.h.inc"
 
index bbbbfad..f7a07fc 100644 (file)
@@ -215,6 +215,51 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
   }];
 }
 
+def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
+    Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
+    Results<(outs View)> {
+  let summary = "view operation";
+  let description = [{
+    The "linalg.view" op produces a linalg.view which is a multi-dimensional
+    range abstraction on top of an underlying linalg.buffer. This gives an
+    indexing structure to an otherwise non-indexable linalg.buffer.
+
+    A "linalg.view" takes a buffer and a variadic number of ranges and produces
+    a `view` of rank the number of ranges. The elemental type may not match the
+    buffer element type:
+
+    Examples:
+    ```
+       %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+       %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
+       %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
+    ```
+  }];
+
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, Value *buffer, "
+    "ArrayRef<Value *> ranges, Type resultType = Type(), "
+    "ArrayRef<NamedAttribute> attrs = {}">];
+
+  let verifier = [{
+    if (getViewType().getRank() != llvm::size(ranges()))
+      return emitOpError("the view rank must be the number of its ranges");
+    return success();
+  }];
+
+  let extraClassDeclaration = [{
+    enum { FirstIndexingOperand = 1 };
+    unsigned getRank() { return getViewType().getRank(); }
+    Type getElementType() { return getViewType().getElementType(); }
+    ViewType getViewType() { return getType().cast<ViewType>(); }
+    /// Get the underlying indexing at a given rank.
+    Value *getRange(unsigned rank) {
+      assert(rank < getRank() && "rank overflow");
+      return *(ranges().begin() + rank);
+    }
+  }];
+}
+
 def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
     Arguments<(ins Variadic<AnyType>:$values)> {
   let summary = "Linalg yield operation";
index f44bea3..5a272a4 100644 (file)
@@ -53,7 +53,7 @@ Value *Aliases::find(Value *v) {
       return it.first->second;
     }
     if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
-      auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
+      auto it = aliases.insert(std::make_pair(v, view.buffer()));
       return it.first->second;
     }
     if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
index 60820ae..6549508 100644 (file)
@@ -67,10 +67,10 @@ SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
   Value *min, *max, *step;
   if (view) {
     // Cannot traverse block arguments, fail.
-    if (isa<BlockArgument>(view.getIndexing(dim)))
+    if (isa<BlockArgument>(view.getRange(dim)))
       return matchFailure();
     // Record min, max, step for further processing.
-    auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
+    auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
     std::tie(min, max, step) =
         std::make_tuple(range.min(), range.max(), range.step());
   } else if (subView) {
@@ -414,97 +414,15 @@ LogicalResult mlir::linalg::StoreOp::verify() {
   return success();
 }
 
-//////////////////////////////////////////////////////////////////////////////
-// ViewOp
-//////////////////////////////////////////////////////////////////////////////
-void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
-                                 Value *buffer, ArrayRef<Value *> indexings) {
-  BufferType bufferType = buffer->getType().cast<BufferType>();
-  result->addOperands({buffer});
-  result->addOperands(indexings);
-  assert(
-      std::none_of(indexings.begin(), indexings.end(),
-                   [](Value *v) { return !v->getType().isa<RangeType>(); }) &&
-      "linalg.view takes only arguments of type linalg.range");
-
-  Type elementType = bufferType.getElementType();
-  result->addTypes(
-      {ViewType::get(b->getContext(), elementType, indexings.size())});
-}
-
-LogicalResult mlir::linalg::ViewOp::verify() {
-  if (llvm::empty(getOperands()))
-    return emitOpError(
-        "requires at least a buffer operand followed by indexings");
-  auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
-  if (!bufferType)
-    return emitOpError("first operand must be of BufferType");
-  unsigned index = 0;
-  for (auto indexing : getIndexings()) {
-    if (!indexing->getType().isa<RangeType>()) {
-      return emitOpError() << index << "^th index must be of range type";
-    }
-    ++index;
-  }
-  if (getViewType().getRank() != index)
-    return emitOpError()
-           << "the rank of the view must be the number of its indexings";
-  return success();
-}
-
-ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
-                                        OperationState *result) {
-  OpAsmParser::OperandType bufferInfo;
-  SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
-  Type bType, type;
-  if (parser->parseOperand(bufferInfo) ||
-      parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
-      parser->parseOptionalAttributeDict(result->attributes) ||
-      parser->parseColon() || parser->parseType(bType) ||
-      parser->parseArrow() || parser->parseType(type)) {
-    return failure();
-  }
-
-  BufferType bufferType = bType.dyn_cast<BufferType>();
-  if (!bufferType) {
-    return parser->emitError(parser->getNameLoc(), "buffer type expected");
-  }
-
-  ViewType viewType = type.dyn_cast<ViewType>();
-  if (!viewType)
-    return parser->emitError(parser->getNameLoc(), "view type expected");
-  if (viewType.getRank() != indexingsInfo.size())
-    return parser->emitError(parser->getNameLoc(), "expected")
-           << viewType.getRank() << " range indexings";
-  return failure(
-      parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
-      (!indexingsInfo.empty() &&
-       parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
-                               result->operands)) ||
-      parser->addTypeToList(viewType, result->types));
-}
-
-// A ViewOp prints as:
-//
-// ```{.mlir}
-//   linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// ```
-//
-// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
-// holding a range.
-void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
-  *p << getOperationName() << " " << *getSupportingBuffer() << "[";
-  interleave(
-      getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
-      [&]() { *p << ", "; });
-  *p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
-}
-
 ///////////////////// Operations defined with Tablegen /////////////////////////
 // For such operations that do not correspond to library calls (i.e. defined in
 // LinalgOps.td), we define an overloaded `print` function and a
 // parse`className` function.
 
+//===----------------------------------------------------------------------===//
+// BufferAllocOp
+//===----------------------------------------------------------------------===//
+
 static void print(OpAsmPrinter *p, BufferAllocOp op) {
   *p << op.getOperationName() << " ";
   if (!llvm::empty(op.size()))
@@ -544,6 +462,10 @@ static LogicalResult verify(BufferAllocOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BufferDeallocOp
+//===----------------------------------------------------------------------===//
+
 static void print(OpAsmPrinter *p, BufferDeallocOp op) {
   *p << op.getOperationName() << " " << *op.buffer();
   p->printOptionalAttrDict(op.getAttrs());
@@ -565,6 +487,10 @@ static void print(OpAsmPrinter *p, BufferSizeOp op) {
   *p << " : " << op.getOperand()->getType();
 }
 
+//===----------------------------------------------------------------------===//
+// BufferSizeOp
+//===----------------------------------------------------------------------===//
+
 static ParseResult parseBufferSizeOp(OpAsmParser *parser,
                                      OperationState *result) {
   OpAsmParser::OperandType op;
@@ -748,6 +674,66 @@ static LogicalResult verify(GenericOp op) {
 }
 
 //===----------------------------------------------------------------------===//
+// ViewOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
+                                 Value *buffer, ArrayRef<Value *> ranges,
+                                 Type resultType,
+                                 ArrayRef<NamedAttribute> attrs) {
+  if (!resultType) {
+    Type elementType = buffer->getType().cast<BufferType>().getElementType();
+    resultType = ViewType::get(b->getContext(), elementType, ranges.size());
+  }
+  build(b, result, resultType, buffer, ranges);
+  result->addAttributes(attrs);
+}
+
+static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType bufferInfo;
+  SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
+  Type bType, vType;
+  if (parser->parseOperand(bufferInfo) ||
+      parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColon() || parser->parseType(bType) ||
+      parser->parseArrow() || parser->parseType(vType)) {
+    return failure();
+  }
+
+  BufferType bufferType = bType.dyn_cast<BufferType>();
+  if (!bufferType) {
+    return parser->emitError(parser->getNameLoc(), "buffer type expected");
+  }
+
+  ViewType viewType = vType.dyn_cast<ViewType>();
+  if (!viewType)
+    return parser->emitError(parser->getNameLoc(), "view type expected");
+  if (viewType.getRank() != rangesInfo.size())
+    return parser->emitError(parser->getNameLoc(), "expected")
+           << viewType.getRank() << " range ranges";
+  return failure(
+      parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
+      (!rangesInfo.empty() &&
+       parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
+                               result->operands)) ||
+      parser->addTypeToList(viewType, result->types));
+}
+
+// A ViewOp prints as:
+//
+// ```{.mlir}
+//   linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+// ```
+//
+// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
+// holding a range.
+static void print(OpAsmPrinter *p, ViewOp op) {
+  *p << op.getOperationName() << " " << *op.buffer() << "[";
+  interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
+  *p << "] : " << op.buffer()->getType() << " -> " << op.getType();
+}
+
+//===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
@@ -808,6 +794,10 @@ static void print(OpAsmPrinter *p, SubViewOp op) {
   *p << " : " << op.getViewType();
 }
 
+//===----------------------------------------------------------------------===//
+// SubViewOp
+//===----------------------------------------------------------------------===//
+
 static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType inputView, resultView;
   Type viewType;
index 61acbce..ca54c33 100644 (file)
@@ -35,7 +35,7 @@ using namespace mlir::linalg;
 mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addTypes<BufferType, RangeType, ViewType>();
-  addOperations<LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
+  addOperations<LoadOp, RangeOp, StoreOp, SliceOp>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Linalg/IR/LinalgOps.cpp.inc"
index f06a09d..6967a9d 100644 (file)
@@ -512,9 +512,9 @@ public:
     desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
 
     // Compute and insert view sizes (max - min along the range).
-    int numIndexings = llvm::size(viewOp.getIndexings());
+    int numRanges = llvm::size(viewOp.ranges());
     Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
-    for (int i = numIndexings - 1; i >= 0; --i) {
+    for (int i = numRanges - 1; i >= 0; --i) {
       // Update stride.
       Value *rangeDescriptor = operands[1 + i];
       Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
index 39213f2..ee91f4c 100644 (file)
@@ -39,6 +39,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
   %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
   %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
   %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
+  %8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
   linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
   return
 }
@@ -51,6 +52,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
 //  CHECK-NEXT:  %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
 //  CHECK-NEXT:  %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
 //  CHECK-NEXT:  %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
+//  CHECK-NEXT:  %{{.*}} = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
 //  CHECK-NEXT:  linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
 
 func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>, %arg3: !linalg.view<f32>) {