[mlir] Make ShapedTypeComponents contructible from ShapeAdaptor
authorChia-hung Duan <chiahungduan@google.com>
Tue, 8 Mar 2022 20:31:06 +0000 (20:31 +0000)
committerChia-hung Duan <chiahungduan@google.com>
Wed, 9 Mar 2022 03:35:24 +0000 (03:35 +0000)
ValueShapeRange::getShape() returns ShapeAdaptor rather than ShapedType
and ShapeAdaptor allows implicit conversion to bool. It ends up that
ShapedTypeComponents can be constructed with ShapeAdaptor incorrectly.
The reason is that the type trait
  std::is_constructible<ShapeStorageT, Arg>::value
is fulfilled because ShapeAdaptor can be converted to bool and it can be
used to construct ShapeStorageT. In the end, we won't give any warning
or error message when doing things like
  inferredReturnShapes.emplace_back(valueShapeRange.getShape(0));

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D120845

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

index 3ddd2c1..eaf1f6c 100644 (file)
 
 namespace mlir {
 
+class ShapedTypeComponents;
 using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
 
-/// ShapedTypeComponents that represents the components of a ShapedType.
-/// The components consist of
-///  - A ranked or unranked shape with the dimension specification match those
-///    of ShapeType's getShape() (e.g., dynamic dimension represented using
-///    ShapedType::kDynamicSize)
-///  - A element type, may be unset (nullptr)
-///  - A attribute, may be unset (nullptr)
-/// Used by ShapedType type inferences.
-class ShapedTypeComponents {
-  /// Internal storage type for shape.
-  using ShapeStorageT = SmallVector<int64_t, 3>;
-
-public:
-  /// Default construction is an unranked shape.
-  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
-  ShapedTypeComponents(Type elementType)
-      : elementType(elementType), attr(nullptr), ranked(false) {}
-  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
-    ranked = shapedType.hasRank();
-    elementType = shapedType.getElementType();
-    if (ranked)
-      dims = llvm::to_vector<4>(shapedType.getShape());
-  }
-  template <typename Arg, typename = typename std::enable_if_t<
-                              std::is_constructible<ShapeStorageT, Arg>::value>>
-  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
-                       Attribute attr = nullptr)
-      : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
-        ranked(true) {}
-  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
-                       Attribute attr = nullptr)
-      : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
-        ranked(true) {}
-
-  /// Return the dimensions of the shape.
-  /// Requires: shape is ranked.
-  ArrayRef<int64_t> getDims() const {
-    assert(ranked && "requires ranked shape");
-    return dims;
-  }
-
-  /// Return whether the shape has a rank.
-  bool hasRank() const { return ranked; };
-
-  /// Return the element type component.
-  Type getElementType() const { return elementType; };
-
-  /// Return the raw attribute component.
-  Attribute getAttribute() const { return attr; };
-
-private:
-  friend class ShapeAdaptor;
-
-  ShapeStorageT dims;
-  Type elementType;
-  Attribute attr;
-  bool ranked{false};
-};
-
 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
 class ShapeAdaptor {
@@ -137,7 +79,7 @@ public:
   int64_t getNumElements() const;
 
   /// Returns whether valid (non-null) shape.
-  operator bool() const { return !val.isNull(); }
+  explicit operator bool() const { return !val.isNull(); }
 
   /// Dumps textual repesentation to stderr.
   void dump() const;
@@ -148,6 +90,71 @@ private:
   PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
 };
 
+/// ShapedTypeComponents that represents the components of a ShapedType.
+/// The components consist of
+///  - A ranked or unranked shape with the dimension specification match those
+///    of ShapeType's getShape() (e.g., dynamic dimension represented using
+///    ShapedType::kDynamicSize)
+///  - A element type, may be unset (nullptr)
+///  - A attribute, may be unset (nullptr)
+/// Used by ShapedType type inferences.
+class ShapedTypeComponents {
+  /// Internal storage type for shape.
+  using ShapeStorageT = SmallVector<int64_t, 3>;
+
+public:
+  /// Default construction is an unranked shape.
+  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
+  ShapedTypeComponents(Type elementType)
+      : elementType(elementType), attr(nullptr), ranked(false) {}
+  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
+    ranked = shapedType.hasRank();
+    elementType = shapedType.getElementType();
+    if (ranked)
+      dims = llvm::to_vector<4>(shapedType.getShape());
+  }
+  ShapedTypeComponents(ShapeAdaptor adaptor) : attr(nullptr) {
+    ranked = adaptor.hasRank();
+    elementType = adaptor.getElementType();
+    if (ranked)
+      adaptor.getDims(*this);
+  }
+  template <typename Arg, typename = typename std::enable_if_t<
+                              std::is_constructible<ShapeStorageT, Arg>::value>>
+  ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
+                       Attribute attr = nullptr)
+      : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
+        ranked(true) {}
+  ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
+                       Attribute attr = nullptr)
+      : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
+        ranked(true) {}
+
+  /// Return the dimensions of the shape.
+  /// Requires: shape is ranked.
+  ArrayRef<int64_t> getDims() const {
+    assert(ranked && "requires ranked shape");
+    return dims;
+  }
+
+  /// Return whether the shape has a rank.
+  bool hasRank() const { return ranked; };
+
+  /// Return the element type component.
+  Type getElementType() const { return elementType; };
+
+  /// Return the raw attribute component.
+  Attribute getAttribute() const { return attr; };
+
+private:
+  friend class ShapeAdaptor;
+
+  ShapeStorageT dims;
+  Type elementType;
+  Attribute attr;
+  bool ranked{false};
+};
+
 /// Range of values and shapes (corresponding effectively to Shapes dialect's
 /// ValueShape type concept).
 // Currently this exposes the Value (of operands) and Type of the Value. This is