From 83154c541806468802d687a8b3c8f1a65e92199c Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Mon, 26 Oct 2020 17:29:18 -0700 Subject: [PATCH] [mlir] Add bufferization for std.select op. Differential Revision: https://reviews.llvm.org/D90204 --- .../Dialect/StandardOps/Transforms/Bufferize.cpp | 38 +++++++++++++++++++--- mlir/test/Dialect/Standard/bufferize.mlir | 14 ++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index a1b1f0a..9056fbc 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -89,6 +89,24 @@ public: } // namespace namespace { +class BufferizeSelectOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(SelectOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!op.condition().getType().isa()) + return rewriter.notifyMatchFailure(op, "requires scalar condition"); + + SelectOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); + return success(); + } +}; +} // namespace + +namespace { class BufferizeTensorCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -128,10 +146,15 @@ public: void mlir::populateStdBufferizePatterns(MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns - .insert( - typeConverter, context); + patterns.insert< + // clang-format off + BufferizeDynamicTensorFromElementsOp, + BufferizeExtractElementOp, + BufferizeSelectOp, + BufferizeTensorCastOp, + BufferizeTensorFromElementsOp + // clang-format on + >(typeConverter, context); } namespace { @@ -148,6 +171,13 @@ struct StdBufferizePass : public StdBufferizeBase { populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp(); + // We only bufferize the case of tensor selected type and scalar condition, + // as that boils down to a select over memref descriptors (don't need to + // touch the data). + target.addDynamicallyLegalOp([&](SelectOp op) { + return typeConverter.isLegal(op.getType()) || + !op.condition().getType().isa(); + }); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir index 6125998..b2cefe3 100644 --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -61,6 +61,20 @@ func @extract_element(%arg0: tensor, %arg1: index) -> f32 { return %0 : f32 } +// CHECK-LABEL: func @select( +// CHECK-SAME: %[[PRED:.*]]: i1, +// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor, +// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor) -> tensor { +// CHECK: %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref +// CHECK: %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref +// CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref +// CHECK: %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref +// CHECK: return %[[RET]] : tensor +func @select(%arg0: i1, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = select %arg0, %arg1, %arg2 : tensor + return %0 : tensor +} + // CHECK-LABEL: func @tensor_cast( // CHECK-SAME: %[[TENSOR:.*]]: tensor) -> tensor<2xindex> { // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] -- 2.7.4