// Copyright 2020 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#include <cstdint>

#include "pw_bytes/span.h"
#include "pw_rpc/internal/base_client_call.h"
#include "pw_rpc/internal/method_type.h"
#include "pw_rpc/internal/nanopb_common.h"
#include "pw_status/status.h"
26 // Response handler callback for unary RPC methods.
27 template <typename Response>
28 class UnaryResponseHandler {
30 virtual ~UnaryResponseHandler() = default;
32 // Called when the response is received from the server with the method's
33 // status and the deserialized response struct.
34 virtual void ReceivedResponse(Status status, const Response& response) = 0;
36 // Called when an error occurs internally in the RPC client or server.
37 virtual void RpcError(Status) {}
40 // Response handler callbacks for server streaming RPC methods.
41 template <typename Response>
42 class ServerStreamingResponseHandler {
44 virtual ~ServerStreamingResponseHandler() = default;
46 // Called on every response received from the server with the deserialized
48 virtual void ReceivedResponse(const Response& response) = 0;
50 // Called when the server ends the stream with the overall RPC status.
51 virtual void Complete(Status status) = 0;
53 // Called when an error occurs internally in the RPC client or server.
54 virtual void RpcError(Status) {}
59 // Non-templated nanopb base class providing protobuf encoding and decoding.
60 class BaseNanopbClientCall : public BaseClientCall {
62 Status SendRequest(const void* request_struct);
65 constexpr BaseNanopbClientCall(
66 rpc::Channel* channel,
69 ResponseHandler handler,
70 internal::NanopbMessageDescriptor request_fields,
71 internal::NanopbMessageDescriptor response_fields)
72 : BaseClientCall(channel, service_id, method_id, handler),
73 serde_(request_fields, response_fields) {}
75 constexpr const internal::NanopbMethodSerde& serde() const { return serde_; }
78 internal::NanopbMethodSerde serde_;
81 template <typename Callback>
82 struct CallbackTraits {};
84 template <typename ResponseType>
85 struct CallbackTraits<UnaryResponseHandler<ResponseType>> {
86 using Response = ResponseType;
88 static constexpr MethodType kType = MethodType::kUnary;
91 template <typename ResponseType>
92 struct CallbackTraits<ServerStreamingResponseHandler<ResponseType>> {
93 using Response = ResponseType;
95 static constexpr MethodType kType = MethodType::kServerStreaming;
98 } // namespace internal
100 template <typename Callback>
101 class NanopbClientCall : public internal::BaseNanopbClientCall {
103 constexpr NanopbClientCall(Channel* channel,
107 internal::NanopbMessageDescriptor request_fields,
108 internal::NanopbMessageDescriptor response_fields)
109 : BaseNanopbClientCall(channel,
115 callback_(callback) {}
118 using Traits = internal::CallbackTraits<Callback>;
119 using Response = typename Traits::Response;
121 // Buffer into which the nanopb struct is decoded. Its contents are unknown,
122 // so it is aligned to maximum alignment to accommodate any type.
123 using ResponseBuffer =
124 std::aligned_storage_t<sizeof(Response), alignof(std::max_align_t)>;
128 static void ResponseHandler(internal::BaseClientCall& call,
129 const internal::Packet& packet) {
130 static_cast<NanopbClientCall<Callback>&>(call).HandleResponse(packet);
133 void HandleResponse(const internal::Packet& packet) {
134 if constexpr (Traits::kType == internal::MethodType::kUnary) {
135 InvokeUnaryCallback(packet);
137 if constexpr (Traits::kType == internal::MethodType::kServerStreaming) {
138 InvokeServerStreamingCallback(packet);
142 void InvokeUnaryCallback(const internal::Packet& packet) {
143 if (packet.type() == internal::PacketType::SERVER_ERROR) {
144 callback_.RpcError(packet.status());
148 ResponseBuffer response_struct{};
150 if (serde().DecodeResponse(&response_struct, packet.payload())) {
151 callback_.ReceivedResponse(
153 *std::launder(reinterpret_cast<Response*>(&response_struct)));
155 callback_.RpcError(Status::DataLoss());
161 void InvokeServerStreamingCallback(const internal::Packet& packet) {
162 if (packet.type() == internal::PacketType::SERVER_ERROR) {
163 callback_.RpcError(packet.status());
167 if (packet.type() == internal::PacketType::SERVER_STREAM_END) {
168 callback_.Complete(packet.status());
172 ResponseBuffer response_struct{};
174 if (serde().DecodeResponse(&response_struct, packet.payload())) {
175 callback_.ReceivedResponse(
176 *std::launder(reinterpret_cast<Response*>(&response_struct)));
178 callback_.RpcError(Status::DataLoss());
185 } // namespace pw::rpc