[MLIR][Shape] Fix lowering of `shape.get_extent`
authorFrederik Gossen <frgossen@google.com>
Tue, 30 Jun 2020 08:33:49 +0000 (08:33 +0000)
committerFrederik Gossen <frgossen@google.com>
Tue, 30 Jun 2020 08:35:24 +0000 (08:35 +0000)
The declarative conversion patterns caused crashes in the asan configuration.
The non-declarative implementation circumvents this.

Differential Revision: https://reviews.llvm.org/D82797

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td

index 5fd9be0..7ebcb39 100644 (file)
@@ -90,6 +90,29 @@ public:
   }
 };
 
+class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
+  using OpConversionPattern<GetExtentOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    GetExtentOp::Adaptor transformed(operands);
+
+    // Derive shape extent directly from shape origin if possible.
+    // This circumvents the necessity to materialize the shape in memory.
+    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+                                         transformed.dim());
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(
+        op, rewriter.getIndexType(), transformed.shape(),
+        ValueRange{transformed.dim()});
+    return success();
+  }
+};
+
 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
 public:
   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
@@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      GetExtentOpConverter,
       RankOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on
index 154cf6a..a133548 100644 (file)
@@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat<
     (Shape_SizeToIndexOp $arg),
     (replaceWithValue $arg)>;
 
-// Derive shape extent directly from shape origin if possible.
-// This circumvents the necessity to materialize the shape in memory.
-def GetExtentShapeOfConversion : Pat<
-    (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
-    (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
-    [],
-    (addBenefit 10)>;
-def GetExtentFromExtentTensorConversion : Pattern<
-    (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
-    [
-      (Shape_SizeToIndexOp:$std_idx $idx),
-      (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
-      (Shape_IndexToSizeOp $std_result)
-    ],
-    [],
-    (addBenefit 10)>;
-