[MLIR] Add `to/from_extent_tensor` lowering to the standard dialect
author    Frederik Gossen <frgossen@google.com>  Mon, 8 Jun 2020 09:37:43 +0000
committer Frederik Gossen <frgossen@google.com>  Mon, 8 Jun 2020 09:38:18 +0000
The operations `to_extent_tensor` and `from_extent_tensor` become no-ops when
lowered to the standard dialect. This is possible because `shape.shape` itself
is lowered to `tensor<?xindex>`, so both operations reduce to type-level
identities.

Differential Revision: https://reviews.llvm.org/D81162
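
For illustration, a minimal sketch of the intended effect on IR, assuming the
conversion pass flag `-convert-shape-to-std` (the function name below is made
up):

  // Before: the operation still bridges the two type worlds.
  func @get_extents(%shape : !shape.shape) -> tensor<?xindex> {
    %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
    return %tensor : tensor<?xindex>
  }

  // After: `!shape.shape` is rewritten to `tensor<?xindex>` and the
  // operation is replaced by its (type-converted) operand.
  func @get_extents(%shape : tensor<?xindex>) -> tensor<?xindex> {
    return %shape : tensor<?xindex>
  }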

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 8deb8b8..5c74cdc 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -19,16 +19,16 @@ using namespace mlir;
 namespace {
 
 /// Conversion patterns.
-class SizeToIndexOpConversion
-    : public OpConversionPattern<shape::SizeToIndexOp> {
+class FromExtentTensorOpConversion
+    : public OpConversionPattern<shape::FromExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+  using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+  matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::SizeToIndexOpOperandAdaptor transformed(operands);
-    rewriter.replaceOp(op.getOperation(), transformed.arg());
+    shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
@@ -47,6 +47,34 @@ public:
   }
 };
 
+class SizeToIndexOpConversion
+    : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::SizeToIndexOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.arg());
+    return success();
+  }
+};
+
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<shape::ToExtentTensorOp> {
+public:
+  using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+    rewriter.replaceOp(op.getOperation(), transformed.input());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -55,6 +83,7 @@ public:
   ShapeTypeConverter(MLIRContext *ctx) {
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
+
     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
     addConversion([ctx](shape::ShapeType type) {
       return RankedTensorType::get({ShapedType::kDynamicSize},
@@ -99,8 +128,10 @@ void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
-      SizeToIndexOpConversion>(ctx);
+      SizeToIndexOpConversion,
+      ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
 
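To see how the two new patterns compose with the pre-existing type conversion,
consider a hypothetical round trip (not part of this commit; the function name
is made up):

  func @roundtrip(%shape : !shape.shape) -> !shape.shape {
    %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
    %same = "shape.from_extent_tensor"(%tensor) : (tensor<?xindex>) -> !shape.shape
    return %same : !shape.shape
  }

Since each pattern replaces its op with the converted operand, the function
collapses to an identity on `tensor<?xindex>` after the conversion:

  func @roundtrip(%shape : tensor<?xindex>) -> tensor<?xindex> {
    return %shape : tensor<?xindex>
  }
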
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 138a9b2..de420a9 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -39,3 +39,26 @@ func @shape_id(%shape : !shape.shape) -> !shape.shape {
   // CHECK: return %[[SHAPE]] : tensor<?xindex>
   return %shape : !shape.shape
 }
+
+// -----
+
+// Lower `to_extent_tensor` operation to no-op.
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @to_extent_tensor(%shape : !shape.shape) -> tensor<?xindex> {
+  // CHECK-NEXT: return %[[SHAPE]] : tensor<?xindex>
+  %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
+  return %tensor : tensor<?xindex>
+}
+
+// -----
+
+// Lower `from_extent_tensor` operation to no-op.
+// CHECK-LABEL: @from_extent_tensor
+// CHECK-SAME: (%[[TENSOR:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @from_extent_tensor(%tensor : tensor<?xindex>) -> !shape.shape {
+  // CHECK-NEXT: return %[[TENSOR]] : tensor<?xindex>
+  %shape = "shape.from_extent_tensor"(%tensor)
+      : (tensor<?xindex>) -> !shape.shape
+  return %shape : !shape.shape
+}
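
The `// -----` separators let `-split-input-file` verify each function as an
independent module. A sketch of the corresponding invocation, assuming the
pass flag is `-convert-shape-to-std` (check the file's actual RUN line):

  // RUN: mlir-opt -split-input-file -convert-shape-to-std %s | FileCheck %s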