// Copyright 2020 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
15 #include "pw_rpc/internal/raw_method.h"
19 #include "gtest/gtest.h"
20 #include "pw_bytes/array.h"
21 #include "pw_protobuf/decoder.h"
22 #include "pw_protobuf/encoder.h"
23 #include "pw_rpc/internal/raw_method_union.h"
24 #include "pw_rpc/server_context.h"
25 #include "pw_rpc/service.h"
26 #include "pw_rpc_private/internal_test_utils.h"
27 #include "pw_rpc_private/method_impl_tester.h"
28 #include "pw_rpc_test_protos/test.pwpb.h"
30 namespace pw::rpc::internal {
33 // Create a fake service for use with the MethodImplTester.
34 class TestRawService final : public Service {
36 StatusWithSize Unary(ServerContext&, ConstByteSpan, ByteSpan) {
37 return StatusWithSize(0);
40 static StatusWithSize StaticUnary(ServerContext&, ConstByteSpan, ByteSpan) {
41 return StatusWithSize(0);
44 void ServerStreaming(ServerContext&, ConstByteSpan, RawServerWriter&) {}
46 static void StaticServerStreaming(ServerContext&,
50 StatusWithSize UnaryWrongArg(ServerContext&, ConstByteSpan, ConstByteSpan) {
51 return StatusWithSize(0);
54 static void StaticUnaryVoidReturn(ServerContext&, ConstByteSpan, ByteSpan) {}
56 Status ServerStreamingBadReturn(ServerContext&,
62 static void StaticServerStreamingMissingArg(ConstByteSpan, RawServerWriter&) {
66 TEST(MethodImplTester, RawMethod) {
67 constexpr MethodImplTester<RawMethod, TestRawService> method_tester;
68 EXPECT_TRUE(method_tester.MethodImplIsValid());
75 RawServerWriter last_writer;
77 void DecodeRawTestRequest(ConstByteSpan request) {
78 protobuf::Decoder decoder(request);
80 while (decoder.Next().ok()) {
81 test::TestRequest::Fields field =
82 static_cast<test::TestRequest::Fields>(decoder.FieldNumber());
85 case test::TestRequest::Fields::INTEGER:
86 decoder.ReadInt64(&last_request.integer);
88 case test::TestRequest::Fields::STATUS_CODE:
89 decoder.ReadUint32(&last_request.status_code);
95 StatusWithSize AddFive(ServerContext&,
96 ConstByteSpan request,
98 DecodeRawTestRequest(request);
100 protobuf::NestedEncoder encoder(response);
101 test::TestResponse::Encoder test_response(&encoder);
102 test_response.WriteValue(last_request.integer + 5);
103 ConstByteSpan payload;
104 encoder.Encode(&payload);
106 return StatusWithSize::Unauthenticated(payload.size());
109 void StartStream(ServerContext&,
110 ConstByteSpan request,
111 RawServerWriter& writer) {
112 DecodeRawTestRequest(request);
113 last_writer = std::move(writer);
116 class FakeService : public Service {
118 FakeService(uint32_t id) : Service(id, kMethods) {}
120 static constexpr std::array<RawMethodUnion, 2> kMethods = {
121 RawMethod::Unary<AddFive>(10u),
122 RawMethod::ServerStreaming<StartStream>(11u),
126 TEST(RawMethod, UnaryRpc_SendsResponse) {
127 std::byte buffer[16];
128 protobuf::NestedEncoder encoder(buffer);
129 test::TestRequest::Encoder test_request(&encoder);
130 test_request.WriteInteger(456);
131 test_request.WriteStatusCode(7);
133 const RawMethod& method = std::get<0>(FakeService::kMethods).raw_method();
134 ServerContextForTest<FakeService> context(method);
135 method.Invoke(context.get(), context.packet(encoder.Encode().value()));
137 EXPECT_EQ(last_request.integer, 456);
138 EXPECT_EQ(last_request.status_code, 7u);
140 const Packet& response = context.output().sent_packet();
141 EXPECT_EQ(response.status(), Status::Unauthenticated());
143 protobuf::Decoder decoder(response.payload());
144 ASSERT_TRUE(decoder.Next().ok());
146 EXPECT_EQ(decoder.ReadInt64(&value), OkStatus());
147 EXPECT_EQ(value, 461);
150 TEST(RawMethod, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
151 std::byte buffer[16];
152 protobuf::NestedEncoder encoder(buffer);
153 test::TestRequest::Encoder test_request(&encoder);
154 test_request.WriteInteger(777);
155 test_request.WriteStatusCode(2);
157 const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
158 ServerContextForTest<FakeService> context(method);
160 method.Invoke(context.get(), context.packet(encoder.Encode().value()));
162 EXPECT_EQ(0u, context.output().packet_count());
163 EXPECT_EQ(777, last_request.integer);
164 EXPECT_EQ(2u, last_request.status_code);
165 EXPECT_TRUE(last_writer.open());
166 last_writer.Finish();
169 TEST(RawServerWriter, Write_SendsPreviouslyAcquiredBuffer) {
170 const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
171 ServerContextForTest<FakeService> context(method);
173 method.Invoke(context.get(), context.packet({}));
175 auto buffer = last_writer.PayloadBuffer();
177 constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
178 std::memcpy(buffer.data(), data.data(), data.size());
180 EXPECT_EQ(last_writer.Write(buffer.first(data.size())), OkStatus());
182 const internal::Packet& packet = context.output().sent_packet();
183 EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE);
184 EXPECT_EQ(packet.channel_id(), context.kChannelId);
185 EXPECT_EQ(packet.service_id(), context.kServiceId);
186 EXPECT_EQ(packet.method_id(), context.get().method().id());
187 EXPECT_EQ(std::memcmp(packet.payload().data(), data.data(), data.size()), 0);
188 EXPECT_EQ(packet.status(), OkStatus());
191 TEST(RawServerWriter, Write_SendsExternalBuffer) {
192 const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
193 ServerContextForTest<FakeService> context(method);
195 method.Invoke(context.get(), context.packet({}));
197 constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
198 EXPECT_EQ(last_writer.Write(data), OkStatus());
200 const internal::Packet& packet = context.output().sent_packet();
201 EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE);
202 EXPECT_EQ(packet.channel_id(), context.kChannelId);
203 EXPECT_EQ(packet.service_id(), context.kServiceId);
204 EXPECT_EQ(packet.method_id(), context.get().method().id());
205 EXPECT_EQ(std::memcmp(packet.payload().data(), data.data(), data.size()), 0);
206 EXPECT_EQ(packet.status(), OkStatus());
209 TEST(RawServerWriter, Write_Closed_ReturnsFailedPrecondition) {
210 const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
211 ServerContextForTest<FakeService, 16> context(method);
213 method.Invoke(context.get(), context.packet({}));
215 last_writer.Finish();
216 constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
217 EXPECT_EQ(last_writer.Write(data), Status::FailedPrecondition());
220 TEST(RawServerWriter, Write_BufferTooSmall_ReturnsOutOfRange) {
221 const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
222 ServerContextForTest<FakeService, 16> context(method);
224 method.Invoke(context.get(), context.packet({}));
226 constexpr auto data =
227 bytes::Array<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16>();
228 EXPECT_EQ(last_writer.Write(data), Status::OutOfRange());
231 TEST(RawServerWriter,
232 Destructor_ReleasesAcquiredBufferWithoutSendingAndCloses) {
233 const RawMethod& method = std::get<1>(FakeService::kMethods).raw_method();
234 ServerContextForTest<FakeService> context(method);
236 method.Invoke(context.get(), context.packet({}));
239 RawServerWriter writer = std::move(last_writer);
240 auto buffer = writer.PayloadBuffer();
241 buffer[0] = std::byte{'!'};
242 // Don't release the buffer.
245 auto output = context.output();
246 EXPECT_EQ(output.packet_count(), 1u);
247 EXPECT_EQ(output.sent_packet().type(), PacketType::SERVER_STREAM_END);
251 } // namespace pw::rpc::internal