[MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`
authorFrederik Gossen <frgossen@google.com>
Tue, 28 Jul 2020 15:39:49 +0000 (15:39 +0000)
committerFrederik Gossen <frgossen@google.com>
Tue, 28 Jul 2020 15:40:55 +0000 (15:40 +0000)
Differential Revision: https://reviews.llvm.org/D82848

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index f239d1c..b84b6ba 100644 (file)
@@ -104,6 +104,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
 }
 
 namespace {
+class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
+public:
+  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ConstShapeOpConverter::matchAndRewrite(
+    ConstShapeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  // For now, this lowering supports only extent tensors, not `shape.shape`
+  // types.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  SmallVector<Value, 4> extentOperands;
+  for (auto extent : op.shape()) {
+    extentOperands.push_back(
+        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
+  }
+  Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
+  Type indexTy = rewriter.getIndexType();
+  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
+  return success();
+}
+
+namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
 
@@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
+      ConstShapeOpConverter,
       BinaryOpConversion<MulOp, MulIOp>,
       GetExtentOpConverter,
       RankOpConverter,
index 9336402..7f875f3 100644 (file)
@@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
 
 // -----
 
+// Lower `const_shape` to `tensor_from_elements`.
+// CHECK-LABEL: @const_shape
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape() -> tensor<?xindex> {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[C2:.*]] = constant 2 : index
+  // CHECK: %[[C3:.*]] = constant 3 : index
+  // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]])
+  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  return %shape : tensor<?xindex>
+}
+
+// -----
+
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_three
 // CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>