From 2f38ffbcb300fa75bd97ec70a446540bcba4ee7c Mon Sep 17 00:00:00 2001 From: Peter Goldsborough Date: Mon, 17 Dec 2018 16:07:14 -0800 Subject: [PATCH] Add a correctness check for C++ types to custom operators (#15247) Summary: The JIT uses `int64_t` for its integer type and `double` for its floating point type, but users quite often want to write `int` or `float` and that currently fails in not-so-nice ways for custom ops. This PR adds a simple `static_assert` to catch these common failure cases. zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/15247 Differential Revision: D13493941 Pulled By: goldsborough fbshipit-source-id: c1cd0d10ab5838c75f167c0bdb57e45a0bc1344e --- torch/csrc/jit/custom_operator.h | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h index 051a7bf..33e3865 100644 --- a/torch/csrc/jit/custom_operator.h +++ b/torch/csrc/jit/custom_operator.h @@ -9,10 +9,31 @@ #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { namespace detail { +/// Checks the static C++ type `T` for correctness to catch common error cases. +template +void checkStaticTypes() { + // Give nice error messages for some of the common error cases. + // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT + static_assert( + !std::is_integral::value || std::is_same::value, + "INVALID TYPE: Only int64_t is supported as an integral argument type"); + static_assert( + !std::is_same::value, + "INVALID TYPE: float is not supported as an argument type, use double instead"); +} + +template +void checkStaticTypes() { + checkStaticTypes(); + checkStaticTypes(); +} + template std::vector createArgumentVectorFromTypes(Indices indices) { + checkStaticTypes...>(); // Arguments are named "_" return {Argument("_" + std::to_string(Is), getTypePtr>())...}; } @@ -33,6 +54,7 @@ std::vector createReturns(std::tuple* tuple) { /// Create a single-element `vector` for simple (non-tuple) return types. template std::vector createReturns(ReturnType*) { + checkStaticTypes>(); return {Argument("_1", getTypePtr>())}; } @@ -50,9 +72,11 @@ std::vector createArgumentVectorFromTraits(Indices indices) { template FunctionSchema createFunctionSchemaFromTraits(const std::string& name) { using ReturnType = typename FunctionTraits::return_type; + auto arguments = createArgumentVectorFromTraits( typename MakeIndices::indices{}); auto returns = createReturns(static_cast(nullptr)); + return {name, arguments, returns}; } -- 2.7.4