[mlir] Add translation from tensor.reshape to memref.reshape
Author: Ashay Rane <ashay@users.noreply.github.com>
Mon, 9 May 2022 15:41:21 +0000 (17:41 +0200)
Committer: Matthias Springer <springerm@google.com>
Mon, 9 May 2022 15:45:07 +0000 (17:45 +0200)
This patch augments the `tensor-bufferize` pass by adding a conversion
rule to translate ReshapeOp from the `tensor` dialect to the `memref`
dialect, in addition to adding a unit test to validate the translation.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D125031

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir

index 5be8205..b00d87b 100644 (file)
@@ -743,6 +743,54 @@ struct RankOpInterface
   }
 };
 
+/// Bufferization of tensor.reshape. Replace with memref.reshape.
+struct ReshapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
+                                                    tensor::ReshapeOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    if (&opOperand == &op->getOpOperand(1) /* shape */)
+      return true;
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {op->getOpResult(0)};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    auto &srcOperand = reshapeOp->getOpOperand(0);
+    auto srcBuffer = state.getBuffer(rewriter, srcOperand);
+    if (failed(srcBuffer))
+      return failure();
+
+    auto &shapeOperand = reshapeOp->getOpOperand(1);
+    auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
+    if (failed(shapeBuffer))
+      return failure();
+
+    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
+    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+
+    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
+        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -761,5 +809,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
+    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
   });
 }
index 587508c..cd88f2f 100644 (file)
@@ -408,3 +408,30 @@ func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: in
   %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
   return %ret: tensor<8xf32>
 }
+
+// CHECK-LABEL: func @tensor.reshape(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
+func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+
+  // CHECK: %[[two:.*]] = arith.constant 2 : i64
+  %two = arith.constant 2 : i64
+  // CHECK: %[[five:.*]] = arith.constant 5 : i64
+  %five = arith.constant 5 : i64
+
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3xi64>
+  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
+  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
+  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
+  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
+  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
+  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
+  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>
+
+  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
+  // CHECK: return %[[r]]
+  return %reshaped : tensor<2x2x5xf32>
+}