[mlir] Restrict to requiring traits when using InferTensorType trait.
authorJacques Pienaar <jpienaar@google.com>
Mon, 11 Oct 2021 21:56:28 +0000 (14:56 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 11 Oct 2021 21:56:28 +0000 (14:56 -0700)
Avoids running into segfaults accidentally.

Differential Revision: https://reviews.llvm.org/D110297

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

index e153e66..c4f8f2d 100644 (file)
@@ -246,10 +246,24 @@ LogicalResult verifyInferredResultTypes(Operation *op);
 } // namespace detail
 
 namespace OpTrait {
+template <typename ConcreteType>
+class InferTensorType;
+} // namespace OpTrait
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
+
+namespace mlir {
+namespace OpTrait {
 
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.
-/// Requires: Op implements functions of InferShapedTypeOpInterface.
+/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
+///   Less strict is possible (e.g., implements inferReturnTypeComponents and
+///   these always populates all element types and shapes or fails, but this\
+///   trait is currently only used where the interfaces are, so keep it
+///   restricted for now).
 template <typename ConcreteType>
 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
 public:
@@ -258,6 +272,12 @@ public:
                    ValueRange operands, DictionaryAttr attributes,
                    RegionRange regions,
                    SmallVectorImpl<Type> &inferredReturnTypes) {
+    static_assert(
+        ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
+        "requires InferShapedTypeOpInterface to ensure succesful invocation");
+    static_assert(
+        ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
+        "requires InferTypeOpInterface to ensure succesful invocation");
     return ::mlir::detail::inferReturnTensorTypes(
         ConcreteType::inferReturnTypeComponents, context, location, operands,
         attributes, regions, inferredReturnTypes);
@@ -267,7 +287,4 @@ public:
 } // namespace OpTrait
 } // namespace mlir
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
-
 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_