[MLIR][Shape] Lower `shape.shape_of` to standard dialect
authorFrederik Gossen <frgossen@google.com>
Fri, 19 Jun 2020 15:09:36 +0000 (15:09 +0000)
committerFrederik Gossen <frgossen@google.com>
Fri, 19 Jun 2020 15:21:13 +0000 (15:21 +0000)
Lower `shape.shape_of` to standard dialect.
This lowering supports statically and dynamically shaped tensors.
Support for unranked tensors will be added as part of the lowering to `scf`.

Differential Revision: https://reviews.llvm.org/D82098

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index d02f5e3..6a02bdc 100644 (file)
@@ -38,6 +38,45 @@ public:
   }
 };
 
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ShapeOfOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto tensorVal = transformed.arg();
+    auto tensorTy = tensorVal.getType();
+
+    // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
+    // found in the corresponding pass.
+    if (tensorTy.isa<UnrankedTensorType>())
+      return failure();
+
+    // Build values for individual dimensions.
+    SmallVector<Value, 8> dimValues;
+    auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+    int64_t rank = rankedTensorTy.getRank();
+    for (int64_t i = 0; i < rank; i++) {
+      if (rankedTensorTy.isDynamicDim(i)) {
+        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+        dimValues.push_back(dimVal);
+      } else {
+        int64_t dim = rankedTensorTy.getDimSize(i);
+        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+        dimValues.push_back(dimVal);
+      }
+    }
+
+    // Materialize shape as ranked tensor.
+    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
+                                                      dimValues);
+    return success();
+  }
+};
+
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -107,7 +146,8 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
-      ConstSizeOpConverter>(ctx);
+      ConstSizeOpConverter,
+      ShapeOfOpConversion>(ctx);
   // clang-format on
 }
 
index 1caf005..bfe3c2b 100644 (file)
@@ -86,3 +86,32 @@ func @size_const() -> !shape.size {
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
+
+// -----
+
+// Lower `shape_of` for statically shaped tensor.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+// Lower `shape_of` for dynamically shaped tensor.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+  return
+}