1 // Copyright 2020 The Pigweed Authors
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
7 // https://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
16 #include <type_traits>
18 #include "pw_assert/light.h"
19 #include "pw_bytes/span.h"
20 #include "pw_containers/vector.h"
21 #include "pw_rpc/channel.h"
22 #include "pw_rpc/internal/hash.h"
23 #include "pw_rpc/internal/method_lookup.h"
24 #include "pw_rpc/internal/packet.h"
25 #include "pw_rpc/internal/raw_method.h"
26 #include "pw_rpc/internal/server.h"
30 // Declares a context object that may be used to invoke an RPC. The context is
31 // declared with the name of the implemented service and the method to invoke.
32 // The RPC can then be invoked with the call method.
34 // For a unary RPC, context.call(request) returns the RPC's StatusWithSize, and
35 // the encoded response bytes can be accessed via context.response().
37 // PW_RAW_TEST_METHOD_CONTEXT(my::CoolService, TheMethod) context;
38 // EXPECT_EQ(OkStatus(), context.call(encoded_request).status());
40 // std::memcmp(encoded_response,
41 // context.response().data(),
42 // sizeof(encoded_response)));
44 // For a server streaming RPC, context.call(request) invokes the method. As in a
45 // normal RPC, the method completes when the ServerWriter's Finish method is
46 // called (or it goes out of scope).
48 // PW_RAW_TEST_METHOD_CONTEXT(my::CoolService, TheStreamingMethod) context;
49 // context.call(encoded_request);
51 // EXPECT_TRUE(context.done()); // Check that the RPC completed
52 // EXPECT_EQ(OkStatus(), context.status()); // Check the status
54 // EXPECT_EQ(3u, context.responses().size());
55 // const ByteSpan& response = context.responses()[0]; // check individual responses
57 // for (const ByteSpan& response : context.responses()) {
58 // // iterate over the responses
61 // PW_RAW_TEST_METHOD_CONTEXT forwards its constructor arguments to the
62 // underlying service. For example:
64 // PW_RAW_TEST_METHOD_CONTEXT(MyService, Go) context(service, args);
66 // PW_RAW_TEST_METHOD_CONTEXT takes two optional arguments:
68 // size_t max_responses: maximum responses to store; ignored unless streaming
69 // size_t output_size_bytes: buffer size; must be large enough for a packet
73 // PW_RAW_TEST_METHOD_CONTEXT(MyService, BestMethod, 3, 256) context;
74 // ASSERT_EQ(3u, context.responses().max_size());
// Expands to the ::pw::rpc::RawTestMethodContext specialization for `service`.
// The method name is stringified (#method) and hashed with
// ::pw::rpc::internal::Hash to supply one of the template arguments.
76 #define PW_RAW_TEST_METHOD_CONTEXT(service, method, ...) \
77 ::pw::rpc::RawTestMethodContext<service, \
79 ::pw::rpc::internal::Hash(#method), \
// Forward declaration of the public context type. It carries the default
// template arguments: max_responses defaults to 4 and output_size_bytes to 128.
81 template <typename Service,
84 size_t max_responses = 4,
85 size_t output_size_bytes = 128>
86 class RawTestMethodContext;
88 // Internal classes that implement RawTestMethodContext.
89 namespace internal::test::raw {
91 // A ChannelOutput implementation that stores the outgoing payloads and status.
// The response and buffer vectors are held by reference, so their storage is
// owned by whoever constructs the MessageOutput.
92 template <size_t output_size>
93 class MessageOutput final : public ChannelOutput {
// Fixed-size storage that one response payload is copied into.
95 using ResponseBuffer = std::array<std::byte, output_size>;
// responses:     spans describing each recorded payload.
// buffers:       backing storage for the recorded payloads.
// packet_buffer: the buffer handed out by AcquireBuffer().
97 MessageOutput(Vector<ByteSpan>& responses,
98 Vector<ResponseBuffer>& buffers,
99 ByteSpan packet_buffer)
100 : ChannelOutput("internal::test::raw::MessageOutput"),
101 responses_(responses),
103 packet_buffer_(packet_buffer) {
// Status of the most recently processed packet, unless overridden through
// set_last_status().
107 Status last_status() const { return last_status_; }
108 void set_last_status(Status status) { last_status_ = status; }
// Number of RESPONSE packets counted since the counter was last reset.
110 size_t total_responses() const { return total_responses_; }
// True once a SERVER_STREAM_END packet has been processed.
112 bool stream_ended() const { return stream_ended_; }
// Reset the bookkeeping: no responses counted, stream not ended, status unknown.
117 total_responses_ = 0;
118 stream_ended_ = false;
119 last_status_ = Status::Unknown();
// Always hands out the same packet buffer.
123 ByteSpan AcquireBuffer() override { return packet_buffer_; }
// Decodes the packet held in `buffer` and records its status and payload.
// Defined out of line.
125 Status SendAndReleaseBuffer(std::span<const std::byte> buffer) override;
127 Vector<ByteSpan>& responses_;
128 Vector<ResponseBuffer>& buffers_;
129 ByteSpan packet_buffer_;
130 size_t total_responses_;
135 // Collects everything needed to invoke a particular RPC.
// Bundles the channel output, the channel, the service instance, the response
// storage and the ServerCall that ties them to the looked-up raw method.
136 template <typename Service,
138 size_t max_responses,
140 struct InvocationContext {
// Constructor arguments are forwarded unchanged to the Service constructor.
141 template <typename... Args>
142 InvocationContext(Args&&... args)
// NOTE(review): `output` is declared before `responses`, `buffers` and
// `packet_buffer`, so it is constructed from members whose own initialization
// has not run yet. Verify MessageOutput's constructor only stores them.
143 : output(responses, buffers, packet_buffer),
// The channel is created with the fixed ID 123 (presumably an arbitrary test
// value - TODO confirm) and sends through `output`.
144 channel(Channel::Create<123>(&output)),
// The server is given a span covering just this one channel.
145 server(std::span(&channel, 1)),
146 service(std::forward<Args>(args)...),
// The call binds the server, the channel and the raw method found for
// <Service, method_id>.
147 call(static_cast<internal::Server&>(server),
148 static_cast<internal::Channel&>(channel),
150 MethodLookup::GetRawMethod<Service, method_id>()) {}
152 using ResponseBuffer = std::array<std::byte, output_size>;
154 MessageOutput<output_size> output;
155 rpc::Channel channel;
// Recorded response spans and the buffers backing them; both hold at most
// max_responses entries.
158 Vector<ByteSpan, max_responses> responses;
159 Vector<ResponseBuffer, max_responses> buffers;
// Zero-initialized buffer passed to `output` as its packet buffer.
160 std::array<std::byte, output_size> packet_buffer = {};
161 internal::ServerCall call;
164 // Method invocation context for a unary RPC. Returns the status in call() and
165 // provides the response through the response() method.
166 template <typename Service, auto method, uint32_t method_id, size_t output_size>
// The InvocationContext is instantiated with max_responses fixed to 1.
169 using Context = InvocationContext<Service, method_id, 1, output_size>;
// Constructor arguments are forwarded to the InvocationContext.
173 template <typename... Args>
174 UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
// Gives access to the service instance held by the context.
176 Service& service() { return ctx_.service; }
178 // Invokes the RPC with the provided request. Returns RPC's StatusWithSize.
179 StatusWithSize call(ConstByteSpan request) {
// Add a response slot backed by a value-initialized (zeroed) buffer.
181 ctx_.buffers.emplace_back();
182 ctx_.buffers.back() = {};
183 ctx_.responses.emplace_back();
184 auto& response = ctx_.responses.back();
// Offer the method the whole buffer to write its response into...
185 response = {ctx_.buffers.back().data(), ctx_.buffers.back().size()};
186 auto sws = CallMethodImplFunction<method>(ctx_.call, request, response);
// ...then trim the recorded span to the size reported by the method.
187 response = response.first(sws.size());
191 // Gives access to the RPC's response.
// Asserts that at least one response has been recorded.
192 ConstByteSpan response() const {
193 PW_ASSERT(ctx_.responses.size() > 0u);
194 return ctx_.responses.back();
198 // Method invocation context for a server streaming RPC.
199 template <typename Service,
202 size_t max_responses,
204 class ServerStreamingContext {
207 InvocationContext<Service, method_id, max_responses, output_size>;
// Constructor arguments are forwarded to the InvocationContext.
211 template <typename... Args>
212 ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
// Gives access to the service instance held by the context.
214 Service& service() { return ctx_.service; }
216 // Invokes the RPC with the provided request.
217 void call(ConstByteSpan request) {
// A BaseServerWriter is built for this call and handed to the method as a
// RawServerWriter.
219 BaseServerWriter server_writer(ctx_.call);
220 return CallMethodImplFunction<method>(
221 ctx_.call, request, static_cast<RawServerWriter&>(server_writer));
224 // Returns a server writer which writes responses into the context's buffer.
225 // This should not be called alongside call(); use one or the other.
226 RawServerWriter writer() {
// The locally built writer is moved out to the caller as a RawServerWriter.
228 BaseServerWriter server_writer(ctx_.call);
229 return std::move(static_cast<RawServerWriter&>(server_writer));
232 // Returns the responses that have been recorded. The maximum number of
233 // responses is responses().max_size(). responses().back() is always the most
234 // recent response, even if total_responses() > responses().max_size().
235 const Vector<ByteSpan>& responses() const { return ctx_.responses; }
237 // The total number of responses sent, which may be larger than
238 // responses().max_size().
239 size_t total_responses() const { return ctx_.output.total_responses(); }
241 // True if the stream has terminated.
242 bool done() const { return ctx_.output.stream_ended(); }
244 // The status of the stream. Only valid if done() is true.
245 Status status() const {
247 return ctx_.output.last_status();
251 // Alias to select the type of the context object to use based on the method's type.
// MethodTraits<decltype(method)>::kType, cast to size_t, is used as the index
// into the tuple of context types: index 0 is UnaryContext and index 1 is
// ServerStreamingContext.
253 template <typename Service,
258 using Context = std::tuple_element_t<
259 static_cast<size_t>(MethodTraits<decltype(method)>::kType),
260 std::tuple<UnaryContext<Service, method, method_id, output_size>,
261 ServerStreamingContext<Service,
266 // TODO(hepler): Support client and bidi streaming
// Decodes the packet in `buffer`, records its status, and for RESPONSE packets
// copies the payload into the response storage. A SERVER_STREAM_END packet
// marks the stream as ended; any other packet type crashes.
269 template <size_t output_size>
270 Status MessageOutput<output_size>::SendAndReleaseBuffer(
271 std::span<const std::byte> buffer) {
// Nothing may be sent once the stream has ended, and the buffer must be the
// one handed out by AcquireBuffer().
272 PW_ASSERT(!stream_ended_);
273 PW_ASSERT(buffer.data() == packet_buffer_.data());
275 if (buffer.empty()) {
// The buffer must decode as a packet.
279 Result<internal::Packet> result = internal::Packet::FromBuffer(buffer);
280 PW_ASSERT(result.ok());
// The status is recorded for every packet type, before dispatching on it.
282 last_status_ = result.value().status();
284 switch (result.value().type()) {
285 case internal::PacketType::RESPONSE: {
286 // If we run out of space, the back message is always the most recent.
287 buffers_.emplace_back();
288 buffers_.back() = {};
289 auto response = result.value().payload();
// Copy the payload out of the packet buffer, which AcquireBuffer() hands out
// again for the next packet.
// NOTE(review): the copy size is not checked against output_size here; this
// assumes the payload always fits in a ResponseBuffer - confirm.
290 std::memcpy(&buffers_.back(), response.data(), response.size());
291 responses_.emplace_back();
// The recorded span covers only the payload bytes, not the whole buffer.
292 responses_.back() = {buffers_.back().data(), response.size()};
293 total_responses_ += 1;
296 case internal::PacketType::SERVER_STREAM_END:
297 stream_ended_ = true;
300 PW_CRASH("Unhandled PacketType");
305 } // namespace internal::test::raw
// Public test context type, normally named through PW_RAW_TEST_METHOD_CONTEXT.
// It derives from the internal::test::raw::Context type chosen for the method
// and adds only a forwarding constructor.
307 template <typename Service,
310 size_t max_responses,
311 size_t output_size_bytes>
312 class RawTestMethodContext
313 : public internal::test::raw::Context<Service,
319 // Forwards constructor arguments to the service class.
320 template <typename... ServiceArgs>
321 RawTestMethodContext(ServiceArgs&&... service_args)
322 : internal::test::raw::Context<Service,
327 std::forward<ServiceArgs>(service_args)...) {}
330 } // namespace pw::rpc