}
};
+/// Bufferization of tensor.reshape. Replace with memref.reshape.
+struct ReshapeOpInterface
+ : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
+ tensor::ReshapeOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ if (&opOperand == &op->getOpOperand(1) /* shape */)
+ return true;
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {op->getOpResult(0)};
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ BufferizationState &state) const {
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
+ auto &srcOperand = reshapeOp->getOpOperand(0);
+ auto srcBuffer = state.getBuffer(rewriter, srcOperand);
+ if (failed(srcBuffer))
+ return failure();
+
+ auto &shapeOperand = reshapeOp->getOpOperand(1);
+ auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
+ if (failed(shapeBuffer))
+ return failure();
+
+ auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
+ auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+
+ replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
+ rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+ return success();
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
InsertOp::attachInterface<InsertOpInterface>(*ctx);
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
RankOp::attachInterface<RankOpInterface>(*ctx);
+ ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
});
}
%ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
return %ret: tensor<8xf32>
}
+
+// CHECK-LABEL: func @tensor.reshape(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+
+ // CHECK: %[[two:.*]] = arith.constant 2 : i64
+ %two = arith.constant 2 : i64
+ // CHECK: %[[five:.*]] = arith.constant 5 : i64
+ %five = arith.constant 5 : i64
+
+ // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3xi64>
+ // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
+ // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
+ // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
+ // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
+ // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
+ // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
+ %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>
+
+ // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+ %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
+ // CHECK: return %[[r]]
+ return %reshaped : tensor<2x2x5xf32>
+}