From 45aaa67fceadaeb3bf76bab21d36c1f337d97491 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Ingo=20M=C3=BCller?= Date: Fri, 26 May 2023 09:39:22 +0000 Subject: [PATCH] [mlir][tensor] Fix one-shot bufferization of tensor.reshape. I believe that the previous implementation did not work on any input. It called getMemRefType with `layout = {}`, presumably with the intention to create a MemrefType with identity layout. However, the implementation of that function returns a MemrefType with *unknown* layout if it is provided with a default-constructed layout attribute. This patch uses getMemRefTypeWithStaticIdentityLayout instead, with has identical behavior except for the case of a default-constructed layout, which it passes on as-is to the MemrefType. This problem did not surface in the test because tensor.reshape was not tested with -one-shot-bufferize. This patch introduces a test copied from the tests for -tesnor-bufferize adapted in as follows: since the test is run with "bufferize-function-boundaries", a tensor that is passed into the function is bufferized into a memref with unknown layout, which wouldn't be a valid intput for memref.reshape, so the tests now uses a tensor constructed with arith.constant inside of the function. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D151544 --- .../Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 4 ++-- mlir/test/Dialect/Tensor/one-shot-bufferize.mlir | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 545a9d0..1a4fc3b 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -992,8 +992,8 @@ struct ReshapeOpInterface getBuffer(rewriter, reshapeOp.getShape(), options); if (failed(srcBuffer) || failed(shapeBuffer)) return failure(); - auto resultMemRefType = getMemRefType( - reshapeOp.getResult(), options, /*layout=*/{}, + auto resultMemRefType = getMemRefTypeWithStaticIdentityLayout( + reshapeOp.getResult().getType(), cast(srcBuffer->getType()).getMemorySpace()); replaceOpWithNewBufferizedOp( rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index a4c868c..399fd05 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -398,3 +398,21 @@ func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: t // CHECK: } return } + +// ----- + +// CHECK-LABEL: func @tensor.reshape( +func.func @tensor.reshape() -> tensor<2x2x5xf32> { + // CHECK-DAG: %[[M1:.*]] = memref.cast %{{.*}} : memref<2x10xf32> to memref + %t1_static = arith.constant dense<0.> : tensor<2x10xf32> + %t1 = tensor.cast %t1_static : tensor<2x10xf32> to tensor + + // CHECK: %[[SHAPE:.*]] = memref.get_global @{{.*}} : memref<3xi64> + %shape = arith.constant dense<[2, 2, 5]> : tensor<3xi64> + + // CHECK: %[[RESHAPED:.*]] = memref.reshape %[[M1]](%[[SHAPE]]) : (memref, memref<3xi64>) -> memref<2x2x5xf32> + %reshaped = tensor.reshape %t1(%shape) : (tensor, tensor<3xi64>) -> tensor<2x2x5xf32> + + // CHECK: return %[[RESHAPED]] + return %reshaped : tensor<2x2x5xf32> +} -- 2.7.4