[mlir] Add a missing pattern to bufferize tensor.rank.
author     Alexander Belyaev <pifon@google.com>
           Tue, 14 Dec 2021 18:58:40 +0000 (19:58 +0100)
committer  Alexander Belyaev <pifon@google.com>
           Tue, 14 Dec 2021 19:04:57 +0000 (20:04 +0100)
Differential Revision: https://reviews.llvm.org/D115745
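
In short, the new BufferizeRankOp pattern rewrites a rank query on a tensor
into a rank query on its bufferized counterpart, roughly as follows (a sketch
only; the SSA names are made up, and the to_memref materialization is
inserted by the bufferization type converter rather than by the pattern
itself):

    // Before bufferization:
    %rank = tensor.rank %t : tensor<*xf32>

    // After bufferization (sketch):
    %m = bufferization.to_memref %t : memref<*xf32>
    %rank = memref.rank %m : memref<*xf32>

This mirrors the CHECK lines added to mlir/test/Dialect/Tensor/bufferize.mlir
below.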

mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/test/Dialect/Tensor/bufferize.mlir

diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index d02328e..0fd5b2d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -24,8 +24,7 @@
 using namespace mlir;
 
 namespace {
-class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
-public:
+struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@@ -36,11 +35,8 @@ public:
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-public:
+struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
@@ -50,11 +46,8 @@ public:
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
-public:
+struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
@@ -64,10 +57,8 @@ public:
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeFromElementsOp
+struct BufferizeFromElementsOp
     : public OpConversionPattern<tensor::FromElementsOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -88,11 +79,8 @@ public:
     return success();
   }
 };
-} // namespace
 
-namespace {
-class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
-public:
+struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -150,44 +138,51 @@ public:
     return success();
   }
 };
-} // namespace
 
-void mlir::populateTensorBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
-               BufferizeFromElementsOp, BufferizeGenerateOp>(
-      typeConverter, patterns.getContext());
-}
+struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
+                                                adaptor.tensor());
+    return success();
+  }
+};
 
-namespace {
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnFunction() override {
     auto *context = &getContext();
     bufferization::BufferizeTypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    bufferization::populateBufferizeMaterializationLegality(target);
 
-    populateTensorBufferizePatterns(typeConverter, patterns);
-    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
-                        tensor::FromElementsOp, tensor::GenerateOp>();
-    target.addLegalDialect<memref::MemRefDialect>();
+    ConversionTarget target(*context);
+    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
     target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                       StandardOpsDialect>(
         [&](Operation *op) { return typeConverter.isLegal(op); });
-    target.addLegalOp<CallOp>();
-    target.addLegalOp<ReturnOp>();
-    target.addLegalDialect<scf::SCFDialect>();
+    target.addLegalOp<CallOp, ReturnOp>();
+    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
+                        tensor::FromElementsOp, tensor::GenerateOp>();
+    bufferization::populateBufferizeMaterializationLegality(target);
 
+    RewritePatternSet patterns(context);
+    populateTensorBufferizePatterns(typeConverter, patterns);
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
   }
 };
+
 } // namespace
 
+void mlir::populateTensorBufferizePatterns(
+    bufferization::BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
+               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
+      typeConverter, patterns.getContext());
+}
+
 std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
   return std::make_unique<TensorBufferizePass>();
 }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 91642f0..5b3bb14 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -11,6 +11,15 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
   return %0 : index
 }
 
+// CHECK-LABEL: func @rank(
+// CHECK-SAME:    %[[TENSOR:.*]]: tensor<*xf32>) -> index {
+// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
+// CHECK:           %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
+func @rank(%arg0: tensor<*xf32>) -> index {
+  %0 = tensor.rank %arg0 : tensor<*xf32>
+  return %0 : index
+}
+
 // CHECK-LABEL:   func @tensor.cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
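
As a rough end-to-end sketch, running the pass on the new test case (the test
file is assumed to be driven by mlir-opt -tensor-bufferize) should produce
something like the IR below. The function signature keeps its tensor type
because the pass only applies a partial conversion; the boundary is bridged by
a bufferization.to_memref materialization:

    func @rank(%arg0: tensor<*xf32>) -> index {
      %0 = bufferization.to_memref %arg0 : memref<*xf32>
      %1 = memref.rank %0 : memref<*xf32>
      return %1 : index
    }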