Change script for apply upstream code
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_rpc / nanopb / public / pw_rpc / test_method_context.h
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15
16 #include <tuple>
17 #include <utility>
18
19 #include "pw_assert/assert.h"
20 #include "pw_containers/vector.h"
21 #include "pw_preprocessor/arguments.h"
22 #include "pw_rpc/channel.h"
23 #include "pw_rpc/internal/hash.h"
24 #include "pw_rpc/internal/nanopb_method.h"
25 #include "pw_rpc/internal/packet.h"
26 #include "pw_rpc/internal/server.h"
27 #include "pw_rpc/internal/service_method_traits.h"
28
29 namespace pw::rpc {
30
31 // Declares a context object that may be used to invoke an RPC. The context is
32 // declared with a pointer to the service member function (&Service::Method).
33 // The RPC can then be invoked with the call method.
34 //
35 // For a unary RPC, context.call(request) returns the status, and the response
36 // struct can be accessed via context.response().
37 //
38 //   pw::rpc::TestMethodContext<&my::CoolService::TheMethod> context;
39 //   EXPECT_EQ(Status::Ok(), context.call({.some_arg = 123}));
40 //   EXPECT_EQ(500, context.response().some_response_value);
41 //
42 // For a server streaming RPC, context.call(request) invokes the method. As in a
43 // normal RPC, the method completes when the ServerWriter's Finish method is
44 // called (or it goes out of scope).
45 //
46 //   pw::rpc::TestMethodContext<&my::CoolService::TheStreamingMethod> context;
47 //   context.call({.some_arg = 123});
48 //
49 //   EXPECT_TRUE(context.done());  // Check that the RPC completed
50 //   EXPECT_EQ(Status::Ok(), context.status());  // Check the status
51 //
52 //   EXPECT_EQ(3u, context.responses().size());
53 //   EXPECT_EQ(123, context.responses()[0].value); // check individual responses
54 //
55 //   for (const MyResponse& response : context.responses()) {
56 //     // iterate over the responses
57 //   }
58 //
59 // TestMethodContext forwards its constructor arguments to the underlying
60 // serivce. For example:
61 //
62 //   pw::rpc::TestMethodContext<&MyService::Go> context(serivce, args);
63 //
64 // pw::rpc::TestMethodContext takes two optional template arguments:
65 //
66 //   size_t max_responses: maximum responses to store; ignored unless streaming
67 //   size_t output_size_bytes: buffer size; must be large enough for a packet
68 //
69 // Example:
70 //
71 //   pw::rpc::TestMethodContext<&MyService::BestMethod, 3, 256> context;
72 //   ASSERT_EQ(3u, context.responses().max_size());
73 //
74 template <auto method, size_t max_responses = 4, size_t output_size_bytes = 128>
75 class TestMethodContext;
76
77 // Internal classes that implement TestMethodContext.
78 namespace internal::test {
79
80 // A ChannelOutput implementation that stores the outgoing payloads and status.
81 template <typename Response>
82 class MessageOutput final : public ChannelOutput {
83  public:
84   MessageOutput(const internal::NanopbMethod& method,
85                 Vector<Response>& responses,
86                 std::span<std::byte> buffer)
87       : ChannelOutput("internal::test::MessageOutput"),
88         method_(method),
89         responses_(responses),
90         buffer_(buffer) {
91     clear();
92   }
93
94   Status last_status() const { return last_status_; }
95   void set_last_status(Status status) { last_status_ = status; }
96
97   size_t total_responses() const { return total_responses_; }
98
99   bool stream_ended() const { return stream_ended_; }
100
101   void clear();
102
103  private:
104   std::span<std::byte> AcquireBuffer() override { return buffer_; }
105
106   Status SendAndReleaseBuffer(size_t size) override;
107
108   const internal::NanopbMethod& method_;
109   Vector<Response>& responses_;
110   std::span<std::byte> buffer_;
111   size_t total_responses_;
112   bool stream_ended_;
113   Status last_status_;
114 };
115
116 // Collects everything needed to invoke a particular RPC.
117 template <auto method, size_t max_responses, size_t output_size>
118 struct InvocationContext {
119   using Request = internal::Request<method>;
120   using Response = internal::Response<method>;
121
122   template <typename... Args>
123   InvocationContext(Args&&... args)
124       : output(ServiceMethodTraits<method>::method(), responses, buffer),
125         channel(Channel::Create<123>(&output)),
126         server(std::span(&channel, 1)),
127         service(std::forward<Args>(args)...),
128         call(static_cast<internal::Server&>(server),
129              static_cast<internal::Channel&>(channel),
130              service,
131              ServiceMethodTraits<method>::method()) {}
132
133   MessageOutput<Response> output;
134
135   rpc::Channel channel;
136   rpc::Server server;
137   typename ServiceMethodTraits<method>::Service service;
138   Vector<Response, max_responses> responses;
139   std::array<std::byte, output_size> buffer = {};
140
141   internal::ServerCall call;
142 };
143
144 // Method invocation context for a unary RPC. Returns the status in call() and
145 // provides the response through the response() method.
146 template <auto method, size_t output_size>
147 class UnaryContext {
148  private:
149   InvocationContext<method, 1, output_size> ctx_;
150
151  public:
152   using Request = typename decltype(ctx_)::Request;
153   using Response = typename decltype(ctx_)::Response;
154
155   template <typename... Args>
156   UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
157
158   // Invokes the RPC with the provided request. Returns the status.
159   Status call(const Request& request) {
160     ctx_.output.clear();
161     ctx_.responses.emplace_back();
162     ctx_.responses.back() = {};
163     return (ctx_.service.*method)(
164         ctx_.call.context(), request, ctx_.responses.back());
165   }
166
167   // Gives access to the RPC's response.
168   const Response& response() const {
169     PW_CHECK_UINT_GT(ctx_.responses.size(), 0);
170     return ctx_.responses.back();
171   }
172 };
173
174 // Method invocation context for a server streaming RPC.
175 template <auto method, size_t max_responses, size_t output_size>
176 class ServerStreamingContext {
177  private:
178   InvocationContext<method, max_responses, output_size> ctx_;
179
180  public:
181   using Request = typename decltype(ctx_)::Request;
182   using Response = typename decltype(ctx_)::Response;
183
184   template <typename... Args>
185   ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
186
187   // Invokes the RPC with the provided request.
188   void call(const Request& request) {
189     ctx_.output.clear();
190     internal::BaseServerWriter server_writer(ctx_.call);
191     return (ctx_.service.*method)(
192         ctx_.call.context(),
193         request,
194         static_cast<ServerWriter<Response>&>(server_writer));
195   }
196
197   // Returns a server writer which writes responses into the context's buffer.
198   // This should not be called alongside call(); use one or the other.
199   ServerWriter<Response> writer() {
200     ctx_.output.clear();
201     internal::BaseServerWriter server_writer(ctx_.call);
202     return std::move(static_cast<ServerWriter<Response>&>(server_writer));
203   }
204
205   // Returns the responses that have been recorded. The maximum number of
206   // responses is responses().max_size(). responses().back() is always the most
207   // recent response, even if total_responses() > responses().max_size().
208   const Vector<Response>& responses() const { return ctx_.responses; }
209
210   // The total number of responses sent, which may be larger than
211   // responses.max_size().
212   size_t total_responses() const { return ctx_.output.total_responses(); }
213
214   // True if the stream has terminated.
215   bool done() const { return ctx_.output.stream_ended(); }
216
217   // The status of the stream. Only valid if done() is true.
218   Status status() const {
219     PW_CHECK(done());
220     return ctx_.output.last_status();
221   }
222 };
223
224 // Alias to select the type of the context object to use based on which type of
225 // RPC it is for.
226 template <auto method, size_t responses, size_t output_size>
227 using Context = std::tuple_element_t<
228     static_cast<size_t>(internal::RpcTraits<decltype(method)>::kType),
229     std::tuple<
230         internal::test::UnaryContext<method, output_size>,
231         internal::test::ServerStreamingContext<method, responses, output_size>
232         // TODO(hepler): Support client and bidi streaming
233         >>;
234
235 template <typename Response>
236 void MessageOutput<Response>::clear() {
237   responses_.clear();
238   total_responses_ = 0;
239   stream_ended_ = false;
240   last_status_ = Status::Unknown();
241 }
242
243 template <typename Response>
244 Status MessageOutput<Response>::SendAndReleaseBuffer(size_t size) {
245   PW_CHECK(!stream_ended_);
246
247   if (size == 0u) {
248     return Status::Ok();
249   }
250
251   Result<internal::Packet> result =
252       internal::Packet::FromBuffer(std::span(buffer_.data(), size));
253
254   last_status_ = result.status();
255
256   switch (result.value().type()) {
257     case internal::PacketType::RESPONSE:
258       // If we run out of space, the back message is always the most recent.
259       responses_.emplace_back();
260       responses_.back() = {};
261       PW_CHECK(
262           method_.DecodeResponse(result.value().payload(), &responses_.back()));
263       total_responses_ += 1;
264       break;
265     case internal::PacketType::SERVER_STREAM_END:
266       stream_ended_ = true;
267       break;
268     default:
269       PW_CRASH("Unhandled PacketType");
270   }
271   return Status::Ok();
272 }
273
274 }  // namespace internal::test
275
276 template <auto method, size_t max_responses, size_t output_size_bytes>
277 class TestMethodContext
278     : public internal::test::Context<method, max_responses, output_size_bytes> {
279  public:
280   // Forwards constructor arguments to the service class.
281   template <typename... ServiceArgs>
282   TestMethodContext(ServiceArgs&&... service_args)
283       : internal::test::Context<method, max_responses, output_size_bytes>(
284             std::forward<ServiceArgs>(service_args)...) {}
285 };
286
287 }  // namespace pw::rpc