[mlir] Add bufferization for std.select op.
authorSean Silva <silvasean@google.com>
Tue, 27 Oct 2020 00:29:18 +0000 (17:29 -0700)
committerSean Silva <silvasean@google.com>
Tue, 27 Oct 2020 18:46:33 +0000 (11:46 -0700)
Differential Revision: https://reviews.llvm.org/D90204

mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir

index a1b1f0a..9056fbc 100644 (file)
@@ -89,6 +89,24 @@ public:
 } // namespace
 
 namespace {
+class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.condition().getType().isa<IntegerType>())
+      return rewriter.notifyMatchFailure(op, "requires scalar condition");
+
+    SelectOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<SelectOp>(
+        op, adaptor.condition(), adaptor.true_value(), adaptor.false_value());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
 class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -128,10 +146,15 @@ public:
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns
-      .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
-              BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
-          typeConverter, context);
+  patterns.insert<
+      // clang-format off
+      BufferizeDynamicTensorFromElementsOp,
+      BufferizeExtractElementOp,
+      BufferizeSelectOp,
+      BufferizeTensorCastOp,
+      BufferizeTensorFromElementsOp
+      // clang-format on
+      >(typeConverter, context);
 }
 
 namespace {
@@ -148,6 +171,13 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
     populateStdBufferizePatterns(context, typeConverter, patterns);
     target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
                         TensorCastOp, TensorFromElementsOp>();
+    // We only bufferize the case of tensor selected type and scalar condition,
+    // as that boils down to a select over memref descriptors (don't need to
+    // touch the data).
+    target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
+      return typeConverter.isLegal(op.getType()) ||
+             !op.condition().getType().isa<IntegerType>();
+    });
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
index 6125998..b2cefe3 100644 (file)
@@ -61,6 +61,20 @@ func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL:   func @select(
+// CHECK-SAME:                 %[[PRED:.*]]: i1,
+// CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
+// CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref<f32>
+// CHECK:           %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK:           %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref<f32>
+// CHECK:           return %[[RET]] : tensor<f32>
+func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  %0 = select %arg0, %arg1, %arg2 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // CHECK-LABEL:   func @tensor_cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]