From 46abc76f47425df1f2c28e421d18bded7a402a06 Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Wed, 21 Dec 2016 00:59:33 +0000 Subject: [PATCH] [Orc] Add some static-assert checks to improve the error messages for RPC calls and handler registrations. Also add a unit test for alternate-type serialization/deserialization. llvm-svn: 290223 --- llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h | 133 ++++++++++++++++++++- .../unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp | 126 +++++++++++++++++++ 2 files changed, 257 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 68ce2c0..1355611 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -82,6 +82,17 @@ std::mutex Function::NameMutex; template std::string Function::Name; +/// Provides a typedef for a tuple containing the decayed argument types. +template +class FunctionArgsTuple; + +template +class FunctionArgsTuple { +public: + using Type = std::tuple::type>::type...>; +}; + /// Allocates RPC function ids during autonegotiation. /// Specializations of this class must provide four members: /// @@ -349,8 +360,7 @@ public: using ReturnType = RetT; // A std::tuple wrapping the handler arguments. - using ArgStorage = std::tuple::type>::type...>; + using ArgStorage = typename FunctionArgsTuple::Type; // Call the given handler with the given arguments. template @@ -589,6 +599,84 @@ private: std::vector FreeSequenceNumbers; }; +// Checks that predicate P holds for each corresponding pair of type arguments +// from T1 and T2 tuple. +template class P, typename T1Tuple, + typename T2Tuple> +class RPCArgTypeCheckHelper; + +template class P> +class RPCArgTypeCheckHelper, std::tuple<>> { +public: + static const bool value = true; +}; + +template class P, typename T, typename... Ts, + typename U, typename... Us> +class RPCArgTypeCheckHelper, std::tuple> { +public: + static const bool value = + P::value && + RPCArgTypeCheckHelper, std::tuple>::value; +}; + +template class P, typename T1Sig, typename T2Sig> +class RPCArgTypeCheck { +public: + + using T1Tuple = typename FunctionArgsTuple::Type; + using T2Tuple = typename FunctionArgsTuple::Type; + + static_assert(std::tuple_size::value >= std::tuple_size::value, + "Too many arguments to RPC call"); + static_assert(std::tuple_size::value <= std::tuple_size::value, + "Too few arguments to RPC call"); + + static const bool value = RPCArgTypeCheckHelper::value; +}; + +template +class CanSerialize { +private: + using S = SerializationTraits; + + template + static std::true_type + check(typename std::enable_if< + std::is_same< + decltype(T::serialize(std::declval(), + std::declval())), + Error>::value, + void*>::type); + + template + static std::false_type check(...); + +public: + static const bool value = decltype(check(0))::value; +}; + +template +class CanDeserialize { +private: + using S = SerializationTraits; + + template + static std::true_type + check(typename std::enable_if< + std::is_same< + decltype(T::deserialize(std::declval(), + std::declval())), + Error>::value, + void*>::type); + + template + static std::false_type check(...); + +public: + static const bool value = decltype(check(0))::value; +}; + /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure @@ -603,6 +691,7 @@ template class RPCBase { protected: + class OrcRPCInvalid : public Function { public: static const char *getName() { return "__orc_rpc$invalid"; } @@ -619,6 +708,31 @@ protected: static const char *getName() { return "__orc_rpc$negotiate"; } }; + // Helper predicate for testing for the presence of SerializeTraits + // serializers. + template + class CanSerializeCheck : detail::CanSerialize { + public: + using detail::CanSerialize::value; + + static_assert(value, "Missing serializer for argument (Can't serialize the " + "first template type argument of CanSerializeCheck " + "from the second)"); + }; + + // Helper predicate for testing for the presence of SerializeTraits + // deserializers. + template + class CanDeserializeCheck + : detail::CanDeserialize { + public: + using detail::CanDeserialize::value; + + static_assert(value, "Missing deserializer for argument (Can't deserialize " + "the second template type argument of " + "CanDeserializeCheck from the first)"); + }; + public: /// Construct an RPC instance on a channel. RPCBase(ChannelT &C, bool LazyAutoNegotiation) @@ -643,6 +757,13 @@ public: /// with an error if the return value is abandoned due to a channel error. template Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { + + static_assert( + detail::RPCArgTypeCheck + ::value, + ""); + // Look up the function ID. FunctionIdT FnId; if (auto FnIdOrErr = getRemoteFunctionId()) @@ -738,6 +859,14 @@ protected: /// autonegotiation and execution. template void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + + static_assert( + detail::RPCArgTypeCheck::Type> + ::value, + ""); + FunctionIdT NewFnId = FnIdAllocator.template allocate(); LocalFunctionIds[Func::getPrototype()] = NewFnId; Handlers[NewFnId] = diff --git a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 9ace46d..1914696 100644 --- a/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -58,6 +58,40 @@ private: Queue &OutQueue; }; +class RPCFoo {}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "RPCFoo"; } +}; + +template <> +class SerializationTraits { +public: + static Error serialize(QueueChannel&, const RPCFoo&) { + return Error::success(); + } + + static Error deserialize(QueueChannel&, RPCFoo&) { + return Error::success(); + } +}; + +class RPCBar {}; + +template <> +class SerializationTraits { +public: + static Error serialize(QueueChannel&, const RPCBar&) { + return Error::success(); + } + + static Error deserialize(QueueChannel&, RPCBar&) { + return Error::success(); + } +}; + class DummyRPCAPI { public: @@ -79,6 +113,12 @@ public: public: static const char* getName() { return "AllTheTypes"; } }; + + class CustomType : public Function { + public: + static const char* getName() { return "CustomType"; } + }; + }; class DummyRPCEndpoint : public DummyRPCAPI, @@ -244,3 +284,89 @@ TEST(DummyRPC, TestSerialization) { ServerThread.join(); } + +TEST(DummyRPC, TestCustomType) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addHandler( + [](RPCFoo F) {}); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the CustomType call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)"; + } + }); + + { + // Make an async call. + auto Err = Client.callAsync( + [](Expected FOrErr) { + EXPECT_TRUE(!!FOrErr) + << "Async RPCFoo(RPCFoo) response handler failed"; + return Error::success(); + }, RPCFoo()); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)"; + } + + { + // Poke the client to process the result of the RPCFoo() call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) + << "Client failed to handle response from RPCFoo(RPCFoo)"; + } + + ServerThread.join(); +} + +TEST(DummyRPC, TestWithAltCustomType) { + Queue Q1, Q2; + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); + + std::thread ServerThread([&]() { + Server.addHandler( + [](RPCBar F) {}); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the CustomType call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)"; + } + }); + + { + // Make an async call. + auto Err = Client.callAsync( + [](Expected FOrErr) { + EXPECT_TRUE(!!FOrErr) + << "Async RPCFoo(RPCFoo) response handler failed"; + return Error::success(); + }, RPCBar()); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)"; + } + + { + // Poke the client to process the result of the RPCFoo() call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) + << "Client failed to handle response from RPCFoo(RPCFoo)"; + } + + ServerThread.join(); +} -- 2.7.4