let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
}
-def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
- Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
- Results<(outs AnyStridedMemRef)> {
+// Base class shared by the memref reshape (linalg.reshape) and tensor reshape
+// (linalg.tensor_reshape) ops. It fixes the builders, the custom assembly
+// format, and a common chunk of extra class declarations; subclasses supply
+// the container kind (strided memref vs. ranked tensor) via Arguments/Results.
+class Linalg_ReshapeLikeOp<string mnemonic> :
+ Linalg_Op<mnemonic, [NoSideEffect]> {
+ let builders = [
+ // Builder for a contracting reshape whose result type is computed from
+ // `src` and `reassociation`.
+ OpBuilder<"Builder *b, OperationState &result, Value src, "
+ "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Builder for a reshape whose result type is passed explicitly. This may be
+ // either a contracting or expanding reshape.
+ OpBuilder<"Builder *b, OperationState &result, Type resultType, Value src,"
+ "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">];
+
+ // Deliberately a plain `code` field (not `let extraClassDeclaration`) so
+ // each subclass can concatenate it (`#`) with its own declarations.
+ code commonExtraClassDeclaration = [{
+ static StringRef getReassociationAttrName() { return "reassociation"; }
+ }];
+ // Printed form: %r = linalg.<mnemonic> %src [maps] : srcType into resType
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type(results)
+ }];
+}
+
+def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
+ Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr:$reassociation)>,
+ Results<(outs AnyStridedMemRef:$result)> {
let summary = "linalg.reshape produces a new view into the operand view";
let description = [{
The `linalg.reshape` op produces a new view whose sizes are a reassociation
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
}];
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+ MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+ }];
+ let hasFolder = 1;
+}
- let builders = [
- // Builder for a contracting reshape whose result type is computed from
- // `view` and `reassociation`.
- OpBuilder<"Builder *b, OperationState &result, Value view, "
- "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
- "ArrayRef<NamedAttribute> attrs = {}">,
- // Builder for a reshape whose result type is passed explicitly. This may be
- // either a contracting or expanding reshape.
- OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
- "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
- "ArrayRef<NamedAttribute> attrs = {}">];
+def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
+ Arguments<(ins AnyTensor:$src,
+ AffineMapArrayAttr:$reassociation)>,
+ Results<(outs AnyTensor:$result)> {
+ let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
+ let description = [{
+ The `linalg.tensor_reshape` op produces a new tensor whose sizes are a
+ reassociation of the original `src`.
- let extraClassDeclaration = [{
- static StringRef getReassociationAttrName() { return "reassociation"; }
- MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
+ A reassociation is defined as a contiguous grouping of dimensions and is
+ represented with an affine map array attribute. In the future,
+ non-contiguous groupings may be allowed (e.g. permutations, reindexings
+ etc).
+
+ A reshape may either collapse or expand dimensions, depending on the
+ relationship between source and target tensor ranks. The verification rule
+ is that the reassociation maps are applied to the tensor with the larger
+ rank to obtain the tensor with the smaller rank. In the case of a dimension
+ expansion, the reassociation maps can be interpreted as inverse maps.
+
+ Examples:
+
+ ```mlir
+ // Dimension collapse (i, j) -> i' and k -> k'
+ %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+ tensor<?x?x?xf32> into tensor<?x?xf32>
+ ```
+
+ ```mlir
+ // Dimension expansion i -> (i', j') and (k) -> (k')
+ %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+ tensor<?x?xf32> into tensor<?x?x?xf32>
+ ```
}];
- let assemblyFormat = [{
- $view $reassociation attr-dict `:` type($view) `into` type(results)
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ RankedTensorType getSrcType() {
+ return src().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getResultType() {
+ return result().getType().cast<RankedTensorType>();
+ }
}];
- let hasFolder = 1;
}
def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reshapeOp = cast<ReshapeOp>(op);
- MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+ MemRefType dstType = reshapeOp.getResultType();
if (!dstType.hasStaticShape())
return failure();
edsc::ScopedContext context(rewriter, op->getLoc());
ReshapeOpOperandAdaptor adaptor(operands);
- BaseViewConversionHelper baseDesc(adaptor.view());
+ BaseViewConversionHelper baseDesc(adaptor.src());
BaseViewConversionHelper desc(typeConverter.convertType(dstType));
desc.setAllocatedPtr(baseDesc.allocatedPtr());
desc.setAlignedPtr(baseDesc.alignedPtr());
}
+/// Builds a contracting (collapsing) memref reshape: the result MemRefType is
+/// inferred by applying `reassociation` to the type of `src`.
void mlir::linalg::ReshapeOp::build(
- Builder *b, OperationState &result, Value view,
+ Builder *b, OperationState &result, Value src,
ArrayRef<ArrayRef<AffineExpr>> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto maps = getSymbolLessAffineMaps(reassociation);
- auto memRefType = view.getType().cast<MemRefType>();
+ auto memRefType = src.getType().cast<MemRefType>();
auto resultType = computeReshapeCollapsedType(memRefType, maps);
- build(b, result, resultType, view, attrs);
+ build(b, result, resultType, src, attrs);
+ // The reassociation maps are carried as an op attribute, not an operand.
result.addAttribute(ReshapeOp::getReassociationAttrName(),
b->getAffineMapArrayAttr(maps));
}
+/// Builds a memref reshape with an explicitly supplied result type; this form
+/// supports both contracting and expanding reshapes.
void mlir::linalg::ReshapeOp::build(
- Builder *b, OperationState &result, Type resultType, Value view,
+ Builder *b, OperationState &result, Type resultType, Value src,
ArrayRef<ArrayRef<AffineExpr>> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto maps = getSymbolLessAffineMaps(reassociation);
- build(b, result, resultType, view, attrs);
+ build(b, result, resultType, src, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(),
b->getAffineMapArrayAttr(maps));
}
-static LogicalResult verify(ReshapeOp op) {
- MemRefType expandedType = op.getViewType();
- MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
+// Common verifier for reshape-like types. Fills `expandedType` and
+// `collapsedType` with the proper `src` or `result` type.
+template <typename Op, typename T>
+LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType, T &collapsedType) {
+ expandedType = op.getSrcType();
+ collapsedType = op.getResultType();
unsigned expandedRank = expandedType.getRank();
unsigned collapsedRank = collapsedType.getRank();
bool isCollapse = expandedRank > collapsedRank;
return op.emitOpError("expected to collapse or expand dims");
if (collapsedRank != op.reassociation().size())
- return op.emitOpError("expected rank of the collapsed view(")
+ return op.emitOpError("expected rank of the collapsed type(")
<< collapsedRank << ") to be the number of reassociation maps("
<< op.reassociation().size() << ")";
auto maps = getAffineMaps(op.reassociation());
if (!isReassociationValid(maps, &invalidIdx))
return op.emitOpError("expected reassociation map #")
<< invalidIdx << " to be valid and contiguous";
+ return success();
+}
+
+static LogicalResult verify(ReshapeOp op) {
+ MemRefType expandedType, collapsedType;
+ if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+ return failure();
+ auto maps = getAffineMaps(op.reassociation());
MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
if (collapsedType != expectedType)
return op.emitOpError("expected collapsed type to be ")
return success();
}
+//===----------------------------------------------------------------------===//
+// TensorReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
+/// Each reassociation map contributes one result dimension whose extent is the
+/// product of the grouped source extents (dynamic if any of them is dynamic).
+static RankedTensorType
+computeTensorReshapeCollapsedType(RankedTensorType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ newShape.reserve(reassociation.size());
+
+ // Use the fact that reassociation is valid to simplify the logic: only use
+ // each map's rank.
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ // The next `dim` source dimensions collapse into one result dimension.
+ auto band = shape.drop_front(currentDim).take_front(dim);
+ int64_t size = 1;
+ // Any dynamic extent in the band makes the collapsed extent dynamic;
+ // otherwise it is the product of the band's static extents.
+ if (llvm::is_contained(band, ShapedType::kDynamicSize))
+ size = ShapedType::kDynamicSize;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+
+ return RankedTensorType::get(newShape, type.getElementType());
+}
+
+/// Builds a contracting (collapsing) tensor_reshape: the result tensor type is
+/// inferred by applying `reassociation` to the ranked tensor type of `src`.
+void mlir::linalg::TensorReshapeOp::build(
+ Builder *b, OperationState &result, Value src,
+ ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto maps = getSymbolLessAffineMaps(reassociation);
+ auto resultType = computeTensorReshapeCollapsedType(
+ src.getType().cast<RankedTensorType>(), maps);
+ build(b, result, resultType, src, attrs);
+ // The reassociation maps are carried as an op attribute, not an operand.
+ result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+ b->getAffineMapArrayAttr(maps));
+}
+
+/// Builds a tensor_reshape with an explicitly supplied result type; this form
+/// supports both contracting and expanding reshapes.
+void mlir::linalg::TensorReshapeOp::build(
+ Builder *b, OperationState &result, Type resultType, Value src,
+ ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto maps = getSymbolLessAffineMaps(reassociation);
+ build(b, result, resultType, src, attrs);
+ result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+ b->getAffineMapArrayAttr(maps));
+}
+
+/// Verifies a tensor_reshape: structural checks are shared with the memref
+/// variant via verifyReshapeLikeTypes, then the collapsed type is recomputed
+/// from the reassociation maps and compared against the op's collapsed type.
+static LogicalResult verify(TensorReshapeOp op) {
+ RankedTensorType expandedType, collapsedType;
+ // Fills expandedType/collapsedType from the op's src/result types and
+ // checks rank/reassociation consistency.
+ if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+ return failure();
+ auto maps = getAffineMaps(op.reassociation());
+ // TODO(ntv): expanding a ? with a non-constant is under-specified. Error
+ // out.
+ RankedTensorType expectedType =
+ computeTensorReshapeCollapsedType(expandedType, maps);
+ if (collapsedType != expectedType)
+ return op.emitOpError("expected collapsed type to be ")
+ << expectedType << ", but got " << collapsedType;
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
// -----
func @reshape(%arg0: memref<?x?x?xf32>) {
- // expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
+ // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
%0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>] :
memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
}
// CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
// CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
- // Reshapes that collapse and expand back a contiguous tensor.
+func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) {
+ // Reshapes that collapse and expand back a contiguous buffer.
%0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
affine_map<(i, j, k) -> (k)>] :
memref<3x4x5xf32> into memref<12x5xf32>
memref<3x4x5xf32> into memref<60xf32>
%r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j, k)>] :
memref<60xf32> into memref<3x4x5xf32>
- // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+ // Reshapes that expand and collapse back a contiguous buffer with some 1's.
%3 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
affine_map<(i, j, k, l, m) -> (k)>,
affine_map<(i, j, k, l, m) -> (l, m)>] :
affine_map<(i, j, k, l, m) -> (k)>,
affine_map<(i, j, k, l, m) -> (l, m)>] :
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+ // Reshapes on tensors.
+ %t0 = linalg.tensor_reshape %arg1 [affine_map<(i, j, k, l, m) -> (i, j)>,
+ affine_map<(i, j, k, l, m) -> (k)>,
+ affine_map<(i, j, k, l, m) -> (l, m)>] :
+ tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+ %rt0 = linalg.tensor_reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+ affine_map<(i, j, k, l, m) -> (k)>,
+ affine_map<(i, j, k, l, m) -> (l, m)>] :
+ tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+ %t1 = linalg.tensor_reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
+ affine_map<(i, j, k, l, m) -> (k)>,
+ affine_map<(i, j, k, l, m) -> (l, m)>] :
+ tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+ %rt1 = linalg.tensor_reshape %t1 [affine_map<(i, j, k, l, m) -> (i)>,
+ affine_map<(i, j, k, l, m) -> (j, k)>,
+ affine_map<(i, j, k, l, m) -> (l, m)>] :
+ tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
return
}
// CHECK-LABEL: func @reshape_static
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
// CHECK: linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+// CHECK: linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
// -----