[MLIR] Add type conversion for `shape.shape`
authorFrederik Gossen <frgossen@google.com>
Mon, 8 Jun 2020 09:33:24 +0000 (09:33 +0000)
committerFrederik Gossen <frgossen@google.com>
Mon, 8 Jun 2020 09:34:03 +0000 (09:34 +0000)
Convert `shape.shape` to `tensor<?xindex>` when lowering the `shape` to the
`std` dialect.

Differential Revision: https://reviews.llvm.org/D81161

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index 1ddd681..8deb8b8 100644 (file)
@@ -56,6 +56,10 @@ public:
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
+    addConversion([ctx](shape::ShapeType type) {
+      return RankedTensorType::get({ShapedType::kDynamicSize},
+                                   IndexType::get(ctx));
+    });
   }
 };
 
index c27b408..138a9b2 100644 (file)
@@ -29,3 +29,13 @@ func @index_to_size(%index : index) -> !shape.size {
   %size = shape.index_to_size %index
   return %size : !shape.size
 }
+
+// -----
+
+// Convert `shape` to `tensor<?xindex>` type.
+// CHECK-LABEL: @shape_id
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>)
+func @shape_id(%shape : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[SHAPE]] : tensor<?xindex>
+  return %shape : !shape.shape
+}