[Orc] Add some static-assert checks to improve the error messages for RPC calls
authorLang Hames <lhames@gmail.com>
Wed, 21 Dec 2016 00:59:33 +0000 (00:59 +0000)
committerLang Hames <lhames@gmail.com>
Wed, 21 Dec 2016 00:59:33 +0000 (00:59 +0000)
and handler registrations.

Also add a unit test for alternate-type serialization/deserialization.

llvm-svn: 290223

llvm/include/llvm/ExecutionEngine/Orc/RPCUtils.h
llvm/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

index 68ce2c0..1355611 100644 (file)
@@ -82,6 +82,17 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
 template <typename DerivedFunc, typename RetT, typename... ArgTs>
 std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
 
+/// Provides a typedef for a tuple containing the decayed argument types.
+template <typename T>
+class FunctionArgsTuple;
+
+template <typename RetT, typename... ArgTs>
+class FunctionArgsTuple<RetT(ArgTs...)> {
+public:
+  using Type = std::tuple<typename std::decay<
+                 typename std::remove_reference<ArgTs>::type>::type...>;
+};
+
 /// Allocates RPC function ids during autonegotiation.
 /// Specializations of this class must provide four members:
 ///
@@ -349,8 +360,7 @@ public:
   using ReturnType = RetT;
 
   // A std::tuple wrapping the handler arguments.
-  using ArgStorage = std::tuple<typename std::decay<
-      typename std::remove_reference<ArgTs>::type>::type...>;
+  using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type;
 
   // Call the given handler with the given arguments.
   template <typename HandlerT>
@@ -589,6 +599,84 @@ private:
   std::vector<SequenceNumberT> FreeSequenceNumbers;
 };
 
+// Checks that predicate P holds for each corresponding pair of type arguments
+// from T1 and T2 tuple.
+template <template<class, class> class P, typename T1Tuple,
+          typename T2Tuple>
+class RPCArgTypeCheckHelper;
+
+template <template<class, class> class P>
+class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
+public:
+  static const bool value = true;
+};
+
+template <template<class, class> class P, typename T, typename... Ts,
+          typename U, typename... Us>
+class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
+public:
+  static const bool value =
+    P<T, U>::value &&
+    RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
+};
+
+template <template<class, class> class P, typename T1Sig, typename T2Sig>
+class RPCArgTypeCheck {
+public:
+
+  using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
+  using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
+
+  static_assert(std::tuple_size<T1Tuple>::value >= std::tuple_size<T2Tuple>::value,
+                "Too many arguments to RPC call");
+  static_assert(std::tuple_size<T1Tuple>::value <= std::tuple_size<T2Tuple>::value,
+                "Too few arguments to RPC call");
+
+  static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
+};
+
+template <typename ChannelT, typename WireT, typename ConcreteT>
+class CanSerialize {
+private:
+  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
+
+  template <typename T>
+  static std::true_type
+  check(typename std::enable_if<
+                   std::is_same<
+                     decltype(T::serialize(std::declval<ChannelT&>(),
+                                           std::declval<const ConcreteT&>())),
+                     Error>::value,
+                   void*>::type);
+
+  template <typename>
+  static std::false_type check(...);
+
+public:
+  static const bool value = decltype(check<S>(0))::value;
+};
+
+template <typename ChannelT, typename WireT, typename ConcreteT>
+class CanDeserialize {
+private:
+  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
+
+  template <typename T>
+  static std::true_type
+  check(typename std::enable_if<
+                   std::is_same<
+                     decltype(T::deserialize(std::declval<ChannelT&>(),
+                                             std::declval<ConcreteT&>())),
+                     Error>::value,
+                   void*>::type);
+
+  template <typename>
+  static std::false_type check(...);
+
+public:
+  static const bool value = decltype(check<S>(0))::value;
+};
+
 /// Contains primitive utilities for defining, calling and handling calls to
 /// remote procedures. ChannelT is a bidirectional stream conforming to the
 /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
@@ -603,6 +691,7 @@ template <typename ImplT, typename ChannelT, typename FunctionIdT,
           typename SequenceNumberT>
 class RPCBase {
 protected:
+
   class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
   public:
     static const char *getName() { return "__orc_rpc$invalid"; }
@@ -619,6 +708,31 @@ protected:
     static const char *getName() { return "__orc_rpc$negotiate"; }
   };
 
+  // Helper predicate for testing for the presence of SerializeTraits
+  // serializers.
+  template <typename WireT, typename ConcreteT>
+  class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
+  public:
+    using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
+
+    static_assert(value, "Missing serializer for argument (Can't serialize the "
+                         "first template type argument of CanSerializeCheck "
+                         "from the second)");
+  };
+
+  // Helper predicate for testing for the presence of SerializeTraits
+  // deserializers.
+  template <typename WireT, typename ConcreteT>
+  class CanDeserializeCheck
+    : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
+  public:
+    using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
+
+    static_assert(value, "Missing deserializer for argument (Can't deserialize "
+                         "the second template type argument of "
+                         "CanDeserializeCheck from the first)");
+  };
+
 public:
   /// Construct an RPC instance on a channel.
   RPCBase(ChannelT &C, bool LazyAutoNegotiation)
@@ -643,6 +757,13 @@ public:
   /// with an error if the return value is abandoned due to a channel error.
   template <typename Func, typename HandlerT, typename... ArgTs>
   Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
+
+    static_assert(
+      detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
+                              void(ArgTs...)>
+        ::value,
+      "");
+
     // Look up the function ID.
     FunctionIdT FnId;
     if (auto FnIdOrErr = getRemoteFunctionId<Func>())
@@ -738,6 +859,14 @@ protected:
   /// autonegotiation and execution.
   template <typename Func, typename HandlerT>
   void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) {
+
+    static_assert(
+      detail::RPCArgTypeCheck<CanDeserializeCheck,
+                              typename Func::Type,
+                              typename detail::HandlerTraits<HandlerT>::Type>
+        ::value,
+      "");
+
     FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
     LocalFunctionIds[Func::getPrototype()] = NewFnId;
     Handlers[NewFnId] =
index 9ace46d..1914696 100644 (file)
@@ -58,6 +58,40 @@ private:
   Queue &OutQueue;
 };
 
+class RPCFoo {};
+
+template <>
+class RPCTypeName<RPCFoo> {
+public:
+  static const char* getName() { return "RPCFoo"; }
+};
+
+template <>
+class SerializationTraits<QueueChannel, RPCFoo, RPCFoo> {
+public:
+  static Error serialize(QueueChannel&, const RPCFoo&) {
+    return Error::success();
+  }
+
+  static Error deserialize(QueueChannel&, RPCFoo&) {
+    return Error::success();
+  }
+};
+
+class RPCBar {};
+
+template <>
+class SerializationTraits<QueueChannel, RPCFoo, RPCBar> {
+public:
+  static Error serialize(QueueChannel&, const RPCBar&) {
+    return Error::success();
+  }
+
+  static Error deserialize(QueueChannel&, RPCBar&) {
+    return Error::success();
+  }
+};
+
 class DummyRPCAPI {
 public:
 
@@ -79,6 +113,12 @@ public:
   public:
     static const char* getName() { return "AllTheTypes"; }
   };
+
+  class CustomType : public Function<CustomType, RPCFoo(RPCFoo)> {
+  public:
+    static const char* getName() { return "CustomType"; }
+  };
+
 };
 
 class DummyRPCEndpoint : public DummyRPCAPI,
@@ -244,3 +284,89 @@ TEST(DummyRPC, TestSerialization) {
 
   ServerThread.join();
 }
+
+TEST(DummyRPC, TestCustomType) {
+  Queue Q1, Q2;
+  DummyRPCEndpoint Client(Q1, Q2);
+  DummyRPCEndpoint Server(Q2, Q1);
+
+  std::thread ServerThread([&]() {
+      Server.addHandler<DummyRPCAPI::CustomType>(
+          [](RPCFoo F) {});
+
+      {
+        // Poke the server to handle the negotiate call.
+        auto Err = Server.handleOne();
+        EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
+      }
+
+      {
+        // Poke the server to handle the CustomType call.
+        auto Err = Server.handleOne();
+        EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)";
+      }
+  });
+
+  {
+    // Make an async call.
+    auto Err = Client.callAsync<DummyRPCAPI::CustomType>(
+        [](Expected<RPCFoo> FOrErr) {
+          EXPECT_TRUE(!!FOrErr)
+            << "Async RPCFoo(RPCFoo) response handler failed";
+          return Error::success();
+        }, RPCFoo());
+    EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)";
+  }
+
+  {
+    // Poke the client to process the result of the RPCFoo() call.
+    auto Err = Client.handleOne();
+    EXPECT_FALSE(!!Err)
+      << "Client failed to handle response from RPCFoo(RPCFoo)";
+  }
+
+  ServerThread.join();
+}
+
+TEST(DummyRPC, TestWithAltCustomType) {
+  Queue Q1, Q2;
+  DummyRPCEndpoint Client(Q1, Q2);
+  DummyRPCEndpoint Server(Q2, Q1);
+
+  std::thread ServerThread([&]() {
+      Server.addHandler<DummyRPCAPI::CustomType>(
+          [](RPCBar F) {});
+
+      {
+        // Poke the server to handle the negotiate call.
+        auto Err = Server.handleOne();
+        EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
+      }
+
+      {
+        // Poke the server to handle the CustomType call.
+        auto Err = Server.handleOne();
+        EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)";
+      }
+  });
+
+  {
+    // Make an async call.
+    auto Err = Client.callAsync<DummyRPCAPI::CustomType>(
+        [](Expected<RPCBar> FOrErr) {
+          EXPECT_TRUE(!!FOrErr)
+            << "Async RPCFoo(RPCFoo) response handler failed";
+          return Error::success();
+        }, RPCBar());
+    EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)";
+  }
+
+  {
+    // Poke the client to process the result of the RPCFoo() call.
+    auto Err = Client.handleOne();
+    EXPECT_FALSE(!!Err)
+      << "Client failed to handle response from RPCFoo(RPCFoo)";
+  }
+
+  ServerThread.join();
+}