[MLIR][Shape] Lower `shape.rank`
authorFrederik Gossen <frgossen@google.com>
Thu, 25 Jun 2020 08:42:40 +0000 (08:42 +0000)
committerFrederik Gossen <frgossen@google.com>
Thu, 25 Jun 2020 08:44:06 +0000 (08:44 +0000)
Lower `shape.rank` to standard dialect.
A shape's size is the same as the extent of the first and only dimension of the
`tensor<?xindex>` it is represented by.

Differential Revision: https://reviews.llvm.org/D82080

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index 6a02bdc..5fd9be0 100644 (file)
@@ -90,6 +90,20 @@ public:
   }
 };
 
+class RankOpConverter : public OpConversionPattern<shape::RankOp> {
+public:
+  using OpConversionPattern<shape::RankOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::RankOp::Adaptor transformed(operands);
+    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
+                                       0);
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -147,6 +161,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      RankOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on
 }
index bfe3c2b..a9b4bf7 100644 (file)
@@ -86,7 +86,6 @@ func @size_const() -> !shape.size {
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
-
 // -----
 
 // Lower `shape_of` for statically shaped tensor.
@@ -115,3 +114,16 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
   %shape = shape.shape_of %arg : tensor<1x5x?xf32>
   return
 }
+
+// -----
+
+// Convert `rank` to `dim` of the first dimension.
+// CHECK-LABEL: @rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @rank(%shape : !shape.shape) -> !shape.size {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK-DAG: return %[[RESULT]] : index
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}