} // namespace detail
namespace OpTrait {
+template <typename ConcreteType>
+class InferTensorType;
+} // namespace OpTrait
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
+
+namespace mlir {
+namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
-/// Requires: Op implements functions of InferShapedTypeOpInterface.
+/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
+/// Less strict is possible (e.g., implements inferReturnTypeComponents and
+/// these always populates all element types and shapes or fails, but this\
+/// trait is currently only used where the interfaces are, so keep it
+/// restricted for now).
template <typename ConcreteType>
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
+ static_assert(
+ ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
+ "requires InferShapedTypeOpInterface to ensure succesful invocation");
+ static_assert(
+ ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
+ "requires InferTypeOpInterface to ensure succesful invocation");
return ::mlir::detail::inferReturnTensorTypes(
ConcreteType::inferReturnTypeComponents, context, location, operands,
attributes, regions, inferredReturnTypes);
} // namespace OpTrait
} // namespace mlir
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/InferTypeOpInterface.h.inc"
-
#endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_