Fix for x86_64 build failure
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_rpc / py / client_test.py
1 #!/usr/bin/env python3
2 # Copyright 2020 The Pigweed Authors
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
5 # use this file except in compliance with the License. You may obtain a copy of
6 # the License at
7 #
8 #     https://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations under
14 # the License.
15 """Tests creating pw_rpc client."""
16
17 import unittest
18
19 from pw_rpc_protos.internal.packet_pb2 import PacketType, RpcPacket
20 from pw_protobuf_compiler import python_protos
21 from pw_status import Status
22
23 from pw_rpc import callback_client, client, packets
24 import pw_rpc.ids
25
# Proto3 fixture: package pw.test1 with two messages and one service,
# PublicService, declaring one method of each RPC kind (unary, server
# streaming, client streaming, and bidirectional streaming).
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

# Proto2 fixture: package pw.test2 with two services, Alpha (one unary
# method) and Bravo (one bidirectional-streaming method). Its empty-ish
# Request message is also used as a response payload in ClientTest.
TEST_PROTO_2 = """\
syntax = "proto2";

package pw.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""
74
75
def _test_setup(output=None):
    """Compiles the test protos and wraps them in a pw_rpc Client.

    Args:
      output: callable installed as the output of channel 1 (ClientTest
        passes a packet recorder); defaults to None.

    Returns:
      A (protos, client) pair: the python_protos.Library compiled from
      TEST_PROTO_1 and TEST_PROTO_2, and a client.Client backed by
      callback_client.Impl with a single channel whose ID is 1.
    """
    library = python_protos.Library.from_strings([TEST_PROTO_1, TEST_PROTO_2])
    impl = callback_client.Impl()
    channels = [client.Channel(1, output)]
    rpc_client = client.Client.from_modules(impl, channels, library.modules())
    return library, rpc_client
81
82
class ChannelClientTest(unittest.TestCase):
    """Tests the ChannelClient obtained from Client.channel()."""
    def setUp(self):
        # _test_setup() is called without an output, so channel 1 has None.
        _, rpc_client = _test_setup()
        self._channel_client = rpc_client.channel(1)

    def test_access_service_client_as_attribute_or_index(self):
        rpcs = self._channel_client.rpcs
        public_service = rpcs.pw.test1.PublicService
        # Name and numeric-ID lookups must yield the very same object.
        self.assertIs(public_service, rpcs['pw.test1.PublicService'])
        self.assertIs(public_service,
                      rpcs[pw_rpc.ids.calculate('pw.test1.PublicService')])

    def test_access_method_client_as_attribute_or_index(self):
        rpcs = self._channel_client.rpcs
        unary = rpcs.pw.test2.Alpha.Unary
        alpha = rpcs['pw.test2.Alpha']
        self.assertIs(unary, alpha['Unary'])
        self.assertIs(unary, alpha[pw_rpc.ids.calculate('Unary')])

    def test_service_name(self):
        service = self._channel_client.rpcs.pw.test2.Alpha.Unary.service
        self.assertEqual(service.name, 'Alpha')
        self.assertEqual(service.full_name, 'pw.test2.Alpha')

    def test_method_name(self):
        method = self._channel_client.rpcs.pw.test2.Alpha.Unary.method
        self.assertEqual(method.name, 'Unary')
        self.assertEqual(method.full_name, 'pw.test2.Alpha.Unary')

    def test_iterate_over_all_methods(self):
        public = self._channel_client.rpcs.pw.test1.PublicService
        test2 = self._channel_client.rpcs.pw.test2
        expected = {
            public.SomeUnary,
            public.SomeServerStreaming,
            public.SomeClientStreaming,
            public.SomeBidiStreaming,
            test2.Alpha.Unary,
            test2.Bravo.BidiStreaming,
        }
        self.assertEqual(set(self._channel_client.methods()), expected)

    def test_check_for_presence_of_services(self):
        rpcs = self._channel_client.rpcs
        for key in ('pw.test1.PublicService',
                    pw_rpc.ids.calculate('pw.test1.PublicService')):
            self.assertIn(key, rpcs)

    def test_check_for_presence_of_missing_services(self):
        # An unqualified service name does not count as present.
        for key in ('PublicService', 'NotAService', -1213):
            self.assertNotIn(key, self._channel_client.rpcs)

    def test_check_for_presence_of_methods(self):
        service = self._channel_client.rpcs.pw.test1.PublicService
        for key in ('SomeUnary', pw_rpc.ids.calculate('SomeUnary')):
            self.assertIn(key, service)

    def test_check_for_presence_of_missing_methods(self):
        # Partial method names must not match.
        service = self._channel_client.rpcs.pw.test1.PublicService
        for key in ('Some', 'Unary', 12345):
            self.assertNotIn(key, service)

    def test_method_fully_qualified_name(self):
        unary = self._channel_client.rpcs.pw.test2.Alpha.Unary
        # Both '/' and '.' are accepted between service and method names.
        for name in ('pw.test2.Alpha/Unary', 'pw.test2.Alpha.Unary'):
            self.assertIs(self._channel_client.method(name), unary)
158
159
class ClientTest(unittest.TestCase):
    """Tests the pw_rpc Client itself, independently of the ClientImpl."""
    def setUp(self):
        self._last_packet_sent_bytes = None
        # Channel 1 hands every outgoing packet to _save_packet.
        self._protos, self._client = _test_setup(self._save_packet)

    def _save_packet(self, packet):
        """Channel output callback: records the latest outgoing packet."""
        self._last_packet_sent_bytes = packet

    def _last_packet_sent(self):
        """Returns the most recently sent packet decoded as an RpcPacket.

        Fails the test if nothing has been sent yet.
        """
        self.assertIsNotNone(self._last_packet_sent_bytes)
        decoded = RpcPacket()
        decoded.MergeFromString(self._last_packet_sent_bytes)
        return decoded

    def _process_response(self, channel_id, service_id, method_id):
        """Feeds the client an encoded response carrying an empty Request.

        Returns the Status reported by Client.process_packet.
        """
        ids = (channel_id, service_id, method_id)
        payload = self._protos.packages.pw.test2.Request()
        return self._client.process_packet(
            packets.encode_response(ids, payload))

    @staticmethod
    def _client_error(service_id, method_id, status):
        """Builds the expected CLIENT_ERROR packet for channel 1."""
        return RpcPacket(type=PacketType.CLIENT_ERROR,
                         channel_id=1,
                         service_id=service_id,
                         method_id=method_id,
                         status=status.value)

    def test_all_methods(self):
        services = self._client.services
        public = services['pw.test1.PublicService'].methods
        expected = {
            public['SomeUnary'],
            public['SomeServerStreaming'],
            public['SomeClientStreaming'],
            public['SomeBidiStreaming'],
            services['pw.test2.Alpha'].methods['Unary'],
            services['pw.test2.Bravo'].methods['BidiStreaming'],
        }
        self.assertEqual(set(self._client.methods()), expected)

    def test_method_present(self):
        some_unary = (self._client.services['pw.test1.PublicService']
                      .methods['SomeUnary'])
        for name in ('pw.test1.PublicService.SomeUnary',
                     'pw.test1.PublicService/SomeUnary'):
            self.assertIs(self._client.method(name), some_unary)

    def test_method_invalid_format(self):
        # A bare method name with no service part is malformed.
        self.assertRaises(ValueError, self._client.method, 'SomeUnary')

    def test_method_not_present(self):
        for missing in ('pw.test1.PublicService/ThisIsNotGood',
                        'nothing.Good'):
            with self.assertRaises(KeyError):
                self._client.method(missing)

    def test_process_packet_invalid_proto_data(self):
        status = self._client.process_packet(b'NOT a packet!')
        self.assertIs(status, Status.DATA_LOSS)

    def test_process_packet_not_for_client(self):
        # A REQUEST packet is not something the client should accept.
        request = RpcPacket(type=PacketType.REQUEST).SerializeToString()
        self.assertIs(self._client.process_packet(request),
                      Status.INVALID_ARGUMENT)

    def test_process_packet_unrecognized_channel(self):
        self.assertIs(self._process_response(123, 456, 789),
                      Status.NOT_FOUND)

    def test_process_packet_unrecognized_service(self):
        self.assertIs(self._process_response(1, 456, 789), Status.OK)
        # The client answers an unknown service with a NOT_FOUND error.
        self.assertEqual(self._last_packet_sent(),
                         self._client_error(456, 789, Status.NOT_FOUND))

    def test_process_packet_unrecognized_method(self):
        service = next(iter(self._client.services))
        self.assertIs(self._process_response(1, service.id, 789), Status.OK)
        self.assertEqual(self._last_packet_sent(),
                         self._client_error(service.id, 789, Status.NOT_FOUND))

    def test_process_packet_non_pending_method(self):
        service = next(iter(self._client.services))
        method = next(iter(service.methods))
        # No call to this method was started, so the client replies with a
        # FAILED_PRECONDITION error.
        self.assertIs(self._process_response(1, service.id, method.id),
                      Status.OK)
        self.assertEqual(
            self._last_packet_sent(),
            self._client_error(service.id, method.id,
                               Status.FAILED_PRECONDITION))
274
275
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()