[mlir][Linalg] Add a linalg.tensor_reshape to operate on tensors
authorNicolas Vasilache <ntv@google.com>
Mon, 6 Apr 2020 15:18:28 +0000 (11:18 -0400)
committerNicolas Vasilache <ntv@google.com>
Mon, 6 Apr 2020 15:19:17 +0000 (11:19 -0400)
Summary:
This revision adds a tensor_reshape operation that operates on tensors.
In the tensor world the constraints are less stringent and we can allow more
arbitrary dynamic reshapes, as long as they are contractions.

The expansion of a dynamic dimension into multiple dynamic dimensions is under-specified and is punted on for now.

Differential Revision: https://reviews.llvm.org/D77360

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index bf0e1dd4877071c73c223b5be07a36562ea969b6..3e667d98f82297785b97c3be9c6045d827a6c8d6 100644 (file)
@@ -60,9 +60,31 @@ def Linalg_RangeOp :
   let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
 }
 
-def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
-    Results<(outs AnyStridedMemRef)> {
+class Linalg_ReshapeLikeOp<string mnemonic> :
+    Linalg_Op<mnemonic, [NoSideEffect]> {
+  let builders = [
+    // Builder for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<"Builder *b, OperationState &result, Value src, "
+    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+    "ArrayRef<NamedAttribute> attrs = {}">,
+    // Builder for a reshape whose result type is passed explicitly. This may be
+    // either a contracting or expanding reshape.
+    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value src,"
+    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+    "ArrayRef<NamedAttribute> attrs = {}">];
+
+  code commonExtraClassDeclaration = [{
+    static StringRef getReassociationAttrName() { return "reassociation"; }
+  }];
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type(results)
+  }];
+}
+
+def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
+    Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef:$result)> {
   let summary = "linalg.reshape produces a new view into the operand view";
   let description = [{
     The `linalg.reshape` op produces a new view whose sizes are a reassociation
@@ -102,27 +124,55 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
       memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
     ```
   }];
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+  }];
+  let hasFolder = 1;
+}
 
-  let builders = [
-    // Builder for a contracting reshape whose result type is computed from
-    // `view` and `reassociation`.
-    OpBuilder<"Builder *b, OperationState &result, Value view, "
-    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-    "ArrayRef<NamedAttribute> attrs = {}">,
-    // Builder for a reshape whose result type is passed explicitly. This may be
-    // either a contracting or expanding reshape.
-    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
-    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-    "ArrayRef<NamedAttribute> attrs = {}">];
+def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
+    Arguments<(ins AnyTensor:$src,
+                   AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyTensor:$result)> {
+  let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
+  let description = [{
+    The `linalg.reshape` op produces a new tensor whose sizes are a
+    reassociation of the original `src`.
 
-  let extraClassDeclaration = [{
-    static StringRef getReassociationAttrName() { return "reassociation"; }
-    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an affine map array attribute. In the future,
+    non-continuous groupings may be allowed (i.e. permutations, reindexings
+    etc).
+
+    A reshape may either collapse or expand dimensions, depending on the
+    relationship between source and target tensor ranks. The verification rule
+    is that the reassociation maps are applied to the tensor with the larger
+    rank to obtain the tensor with the smaller rank. In the case of a dimension
+    expansion, the reassociation maps can be interpreted as inverse maps.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?x?xf32> into tensor<?x?xf32>
+    ```
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?xf32> into tensor<?x?x?xf32>
+    ```
   }];
-  let assemblyFormat = [{
-    $view $reassociation attr-dict `:` type($view) `into` type(results)
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    RankedTensorType getSrcType() {
+      return src().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return result().getType().cast<RankedTensorType>();
+    }
   }];
-  let hasFolder = 1;
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
index cb66ae9f501325c1b01db7c6fd24e78c6236d846..07c8111941e4b9f7293ad430412e0769bafaa83b 100644 (file)
@@ -164,7 +164,7 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto reshapeOp = cast<ReshapeOp>(op);
-    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+    MemRefType dstType = reshapeOp.getResultType();
 
     if (!dstType.hasStaticShape())
       return failure();
@@ -179,7 +179,7 @@ public:
 
     edsc::ScopedContext context(rewriter, op->getLoc());
     ReshapeOpOperandAdaptor adaptor(operands);
-    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper baseDesc(adaptor.src());
     BaseViewConversionHelper desc(typeConverter.convertType(dstType));
     desc.setAllocatedPtr(baseDesc.allocatedPtr());
     desc.setAlignedPtr(baseDesc.alignedPtr());
index 24dcf7370943149370156690d669506cf1f93d7b..3d81cce0e883617009c8f52f5e0d84bb11898e0e 100644 (file)
@@ -531,30 +531,33 @@ getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Value view,
+    Builder *b, OperationState &result, Value src,
     ArrayRef<ArrayRef<AffineExpr>> reassociation,
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
-  auto memRefType = view.getType().cast<MemRefType>();
+  auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(memRefType, maps);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
                       b->getAffineMapArrayAttr(maps));
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Type resultType, Value view,
+    Builder *b, OperationState &result, Type resultType, Value src,
     ArrayRef<ArrayRef<AffineExpr>> reassociation,
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
                       b->getAffineMapArrayAttr(maps));
 }
 
-static LogicalResult verify(ReshapeOp op) {
-  MemRefType expandedType = op.getViewType();
-  MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
+// Common verifier for reshape-like types. Fills `expandedType` and
+// `collapsedType` with the proper `src` or `result` type.
+template <typename Op, typename T>
+LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType, T &collapsedType) {
+  expandedType = op.getSrcType();
+  collapsedType = op.getResultType();
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   bool isCollapse = expandedRank > collapsedRank;
@@ -568,7 +571,7 @@ static LogicalResult verify(ReshapeOp op) {
     return op.emitOpError("expected to collapse or expand dims");
 
   if (collapsedRank != op.reassociation().size())
-    return op.emitOpError("expected rank of the collapsed view(")
+    return op.emitOpError("expected rank of the collapsed type(")
            << collapsedRank << ") to be the number of reassociation maps("
            << op.reassociation().size() << ")";
   auto maps = getAffineMaps(op.reassociation());
@@ -581,6 +584,14 @@ static LogicalResult verify(ReshapeOp op) {
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
            << invalidIdx << " to be valid and contiguous";
+  return success();
+}
+
+static LogicalResult verify(ReshapeOp op) {
+  MemRefType expandedType, collapsedType;
+  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+    return failure();
+  auto maps = getAffineMaps(op.reassociation());
   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
   if (collapsedType != expectedType)
     return op.emitOpError("expected collapsed type to be ")
@@ -588,6 +599,75 @@ static LogicalResult verify(ReshapeOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TensorReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
+static RankedTensorType
+computeTensorReshapeCollapsedType(RankedTensorType type,
+                                  ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  newShape.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.drop_front(currentDim).take_front(dim);
+    int64_t size = 1;
+    if (llvm::is_contained(band, ShapedType::kDynamicSize))
+      size = ShapedType::kDynamicSize;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+
+  return RankedTensorType::get(newShape, type.getElementType());
+}
+
+void mlir::linalg::TensorReshapeOp::build(
+    Builder *b, OperationState &result, Value src,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  auto resultType = computeTensorReshapeCollapsedType(
+      src.getType().cast<RankedTensorType>(), maps);
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::TensorReshapeOp::build(
+    Builder *b, OperationState &result, Type resultType, Value src,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+static LogicalResult verify(TensorReshapeOp op) {
+  RankedTensorType expandedType, collapsedType;
+  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+    return failure();
+  auto maps = getAffineMaps(op.reassociation());
+  // TODO(ntv): expanding a ? with a non-constant is under-specified. Error
+  // out.
+  RankedTensorType expectedType =
+      computeTensorReshapeCollapsedType(expandedType, maps);
+  if (collapsedType != expectedType)
+    return op.emitOpError("expected collapsed type to be ")
+           << expectedType << ", but got " << collapsedType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
index 7a8291504ae67e19328eac84627f6eb9faacfed5..0041f97d7eea8d2aeff6e6126c46af2493579fc2 100644 (file)
@@ -485,7 +485,7 @@ func @reshape(%arg0: memref<?xf32>) {
 // -----
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
+  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>] :
     memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
 }
index 05d35f8f43e4d357cab00e3fbbeb5b1c52fcd940..c28c671d28851593c31e3af7521dee826ab6c83b 100644 (file)
@@ -505,8 +505,8 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
 // CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
 // CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
 
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
-  // Reshapes that collapse and expand back a contiguous tensor.
+func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) {
+  // Reshapes that collapse and expand back a contiguous buffer.
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
     memref<3x4x5xf32> into memref<12x5xf32>
@@ -523,7 +523,7 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
     memref<3x4x5xf32> into memref<60xf32>
   %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j, k)>] :
     memref<60xf32> into memref<3x4x5xf32>
-  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
   %3 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                              affine_map<(i, j, k, l, m) -> (k)>,
                              affine_map<(i, j, k, l, m) -> (l, m)>] :
@@ -532,6 +532,23 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
                            affine_map<(i, j, k, l, m) -> (k)>,
                            affine_map<(i, j, k, l, m) -> (l, m)>] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  // Reshapes on tensors.
+  %t0 = linalg.tensor_reshape %arg1 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                     affine_map<(i, j, k, l, m) -> (k)>,
+                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+  %rt0 = linalg.tensor_reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                   affine_map<(i, j, k, l, m) -> (k)>,
+                                   affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+  %t1 = linalg.tensor_reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                     affine_map<(i, j, k, l, m) -> (k)>,
+                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+  %rt1 = linalg.tensor_reshape %t1 [affine_map<(i, j, k, l, m) -> (i)>,
+                                    affine_map<(i, j, k, l, m) -> (j, k)>,
+                                    affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
   return
 }
 // CHECK-LABEL: func @reshape_static
@@ -551,6 +568,11 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
 //       CHECK:   linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
 //  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
 
 // -----