Fix for x86_64 build fail
[platform/upstream/connectedhomeip.git] / third_party / pigweed / repo / pw_protobuf / py / pw_protobuf / codegen_pwpb.py
1 # Copyright 2020 The Pigweed Authors
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
5 # the License at
6 #
7 #     https://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
13 # the License.
14 """This module defines the generated code for pw_protobuf C++ classes."""
15
16 import abc
17 from datetime import datetime
18 import os
19 import sys
20 from typing import Dict, Iterable, List, Tuple
21 from typing import cast
22
23 import google.protobuf.descriptor_pb2 as descriptor_pb2
24
25 from pw_protobuf.output_file import OutputFile
26 from pw_protobuf.proto_tree import ProtoEnum, ProtoMessage, ProtoMessageField
27 from pw_protobuf.proto_tree import ProtoNode
28 from pw_protobuf.proto_tree import build_node_tree
29
30 PLUGIN_NAME = 'pw_protobuf'
31 PLUGIN_VERSION = '0.1.0'
32
33 PROTO_H_EXTENSION = '.pwpb.h'
34 PROTO_CC_EXTENSION = '.pwpb.cc'
35
36 PROTOBUF_NAMESPACE = 'pw::protobuf'
37 BASE_PROTO_CLASS = 'ProtoMessageEncoder'
38
39
40 # protoc captures stdout, so we need to printf debug to stderr.
41 def debug_print(*args, **kwargs):
42     print(*args, file=sys.stderr, **kwargs)
43
44
45 class ProtoMethod(abc.ABC):
46     """Base class for a C++ method for a field in a protobuf message."""
47     def __init__(
48         self,
49         field: ProtoMessageField,
50         scope: ProtoNode,
51         root: ProtoNode,
52     ):
53         """Creates an instance of a method.
54
55         Args:
56           field: the ProtoMessageField to which the method belongs.
57           scope: the ProtoNode namespace in which the method is being defined.
58         """
59         self._field: ProtoMessageField = field
60         self._scope: ProtoNode = scope
61         self._root: ProtoNode = root
62
63     @abc.abstractmethod
64     def name(self) -> str:
65         """Returns the name of the method, e.g. DoSomething."""
66
67     @abc.abstractmethod
68     def params(self) -> List[Tuple[str, str]]:
69         """Returns the parameters of the method as a list of (type, name) pairs.
70
71         e.g.
72         [('int', 'foo'), ('const char*', 'bar')]
73         """
74
75     @abc.abstractmethod
76     def body(self) -> List[str]:
77         """Returns the method body as a list of source code lines.
78
79         e.g.
80         [
81           'int baz = bar[foo];',
82           'return (baz ^ foo) >> 3;'
83         ]
84         """
85
86     @abc.abstractmethod
87     def return_type(self, from_root: bool = False) -> str:
88         """Returns the return type of the method, e.g. int.
89
90         For non-primitive return types, the from_root argument determines
91         whether the namespace should be relative to the message's scope
92         (default) or the root scope.
93         """
94
95     @abc.abstractmethod
96     def in_class_definition(self) -> bool:
97         """Determines where the method should be defined.
98
99         Returns True if the method definition should be inlined in its class
100         definition, or False if it should be declared in the class and defined
101         later.
102         """
103
104     def should_appear(self) -> bool:  # pylint: disable=no-self-use
105         """Whether the method should be generated."""
106         return True
107
108     def param_string(self) -> str:
109         return ', '.join([f'{type} {name}' for type, name in self.params()])
110
111     def field_cast(self) -> str:
112         return 'static_cast<uint32_t>(Fields::{})'.format(
113             self._field.enum_name())
114
115     def _relative_type_namespace(self, from_root: bool = False) -> str:
116         """Returns relative namespace between method's scope and field type."""
117         scope = self._root if from_root else self._scope
118         type_node = self._field.type_node()
119         assert type_node is not None
120         ancestor = scope.common_ancestor(type_node)
121         namespace = type_node.cpp_namespace(ancestor)
122         assert namespace is not None
123         return namespace
124
125
126 class SubMessageMethod(ProtoMethod):
127     """Method which returns a sub-message encoder."""
128     def name(self) -> str:
129         return 'Get{}Encoder'.format(self._field.name())
130
131     def return_type(self, from_root: bool = False) -> str:
132         return '{}::Encoder'.format(self._relative_type_namespace(from_root))
133
134     def params(self) -> List[Tuple[str, str]]:
135         return []
136
137     def body(self) -> List[str]:
138         line = 'return {}::Encoder(encoder_, {});'.format(
139             self._relative_type_namespace(), self.field_cast())
140         return [line]
141
142     # Submessage methods are not defined within the class itself because the
143     # submessage class may not yet have been defined.
144     def in_class_definition(self) -> bool:
145         return False
146
147
148 class WriteMethod(ProtoMethod):
149     """Base class representing an encoder write method.
150
151     Write methods have following format (for the proto field foo):
152
153         Status WriteFoo({params...}) {
154           return encoder_->Write{type}(kFoo, {params...});
155         }
156
157     """
158     def name(self) -> str:
159         return 'Write{}'.format(self._field.name())
160
161     def return_type(self, from_root: bool = False) -> str:
162         return '::pw::Status'
163
164     def body(self) -> List[str]:
165         params = ', '.join([pair[1] for pair in self.params()])
166         line = 'return encoder_->{}({}, {});'.format(self._encoder_fn(),
167                                                      self.field_cast(), params)
168         return [line]
169
170     def params(self) -> List[Tuple[str, str]]:
171         """Method parameters, defined in subclasses."""
172         raise NotImplementedError()
173
174     def in_class_definition(self) -> bool:
175         return True
176
177     def _encoder_fn(self) -> str:
178         """The encoder function to call.
179
180         Defined in subclasses.
181
182         e.g. 'WriteUint32', 'WriteBytes', etc.
183         """
184         raise NotImplementedError()
185
186
187 class PackedMethod(WriteMethod):
188     """A method for a packed repeated field.
189
190     Same as a WriteMethod, but is only generated for repeated fields.
191     """
192     def should_appear(self) -> bool:
193         return self._field.is_repeated()
194
195     def _encoder_fn(self) -> str:
196         raise NotImplementedError()
197
198
199 #
200 # The following code defines write methods for each of the
201 # primitive protobuf types.
202 #
203
204
205 class DoubleMethod(WriteMethod):
206     """Method which writes a proto double value."""
207     def params(self) -> List[Tuple[str, str]]:
208         return [('double', 'value')]
209
210     def _encoder_fn(self) -> str:
211         return 'WriteDouble'
212
213
214 class PackedDoubleMethod(PackedMethod):
215     """Method which writes a packed list of doubles."""
216     def params(self) -> List[Tuple[str, str]]:
217         return [('std::span<const double>', 'values')]
218
219     def _encoder_fn(self) -> str:
220         return 'WritePackedDouble'
221
222
223 class FloatMethod(WriteMethod):
224     """Method which writes a proto float value."""
225     def params(self) -> List[Tuple[str, str]]:
226         return [('float', 'value')]
227
228     def _encoder_fn(self) -> str:
229         return 'WriteFloat'
230
231
232 class PackedFloatMethod(PackedMethod):
233     """Method which writes a packed list of floats."""
234     def params(self) -> List[Tuple[str, str]]:
235         return [('std::span<const float>', 'values')]
236
237     def _encoder_fn(self) -> str:
238         return 'WritePackedFloat'
239
240
241 class Int32Method(WriteMethod):
242     """Method which writes a proto int32 value."""
243     def params(self) -> List[Tuple[str, str]]:
244         return [('int32_t', 'value')]
245
246     def _encoder_fn(self) -> str:
247         return 'WriteInt32'
248
249
250 class PackedInt32Method(PackedMethod):
251     """Method which writes a packed list of int32."""
252     def params(self) -> List[Tuple[str, str]]:
253         return [('std::span<const int32_t>', 'values')]
254
255     def _encoder_fn(self) -> str:
256         return 'WritePackedInt32'
257
258
259 class Sint32Method(WriteMethod):
260     """Method which writes a proto sint32 value."""
261     def params(self) -> List[Tuple[str, str]]:
262         return [('int32_t', 'value')]
263
264     def _encoder_fn(self) -> str:
265         return 'WriteSint32'
266
267
268 class PackedSint32Method(PackedMethod):
269     """Method which writes a packed list of sint32."""
270     def params(self) -> List[Tuple[str, str]]:
271         return [('std::span<const int32_t>', 'values')]
272
273     def _encoder_fn(self) -> str:
274         return 'WritePackedSint32'
275
276
277 class Sfixed32Method(WriteMethod):
278     """Method which writes a proto sfixed32 value."""
279     def params(self) -> List[Tuple[str, str]]:
280         return [('int32_t', 'value')]
281
282     def _encoder_fn(self) -> str:
283         return 'WriteSfixed32'
284
285
286 class PackedSfixed32Method(PackedMethod):
287     """Method which writes a packed list of sfixed32."""
288     def params(self) -> List[Tuple[str, str]]:
289         return [('std::span<const int32_t>', 'values')]
290
291     def _encoder_fn(self) -> str:
292         return 'WritePackedSfixed32'
293
294
295 class Int64Method(WriteMethod):
296     """Method which writes a proto int64 value."""
297     def params(self) -> List[Tuple[str, str]]:
298         return [('int64_t', 'value')]
299
300     def _encoder_fn(self) -> str:
301         return 'WriteInt64'
302
303
304 class PackedInt64Method(PackedMethod):
305     """Method which writes a proto int64 value."""
306     def params(self) -> List[Tuple[str, str]]:
307         return [('std::span<const int64_t>', 'values')]
308
309     def _encoder_fn(self) -> str:
310         return 'WritePackedInt64'
311
312
313 class Sint64Method(WriteMethod):
314     """Method which writes a proto sint64 value."""
315     def params(self) -> List[Tuple[str, str]]:
316         return [('int64_t', 'value')]
317
318     def _encoder_fn(self) -> str:
319         return 'WriteSint64'
320
321
322 class PackedSint64Method(PackedMethod):
323     """Method which writes a proto sint64 value."""
324     def params(self) -> List[Tuple[str, str]]:
325         return [('std::span<const int64_t>', 'values')]
326
327     def _encoder_fn(self) -> str:
328         return 'WritePackedSint64'
329
330
331 class Sfixed64Method(WriteMethod):
332     """Method which writes a proto sfixed64 value."""
333     def params(self) -> List[Tuple[str, str]]:
334         return [('int64_t', 'value')]
335
336     def _encoder_fn(self) -> str:
337         return 'WriteSfixed64'
338
339
340 class PackedSfixed64Method(PackedMethod):
341     """Method which writes a proto sfixed64 value."""
342     def params(self) -> List[Tuple[str, str]]:
343         return [('std::span<const int64_t>', 'values')]
344
345     def _encoder_fn(self) -> str:
346         return 'WritePackedSfixed4'
347
348
349 class Uint32Method(WriteMethod):
350     """Method which writes a proto uint32 value."""
351     def params(self) -> List[Tuple[str, str]]:
352         return [('uint32_t', 'value')]
353
354     def _encoder_fn(self) -> str:
355         return 'WriteUint32'
356
357
358 class PackedUint32Method(PackedMethod):
359     """Method which writes a proto uint32 value."""
360     def params(self) -> List[Tuple[str, str]]:
361         return [('std::span<const uint32_t>', 'values')]
362
363     def _encoder_fn(self) -> str:
364         return 'WritePackedUint32'
365
366
367 class Fixed32Method(WriteMethod):
368     """Method which writes a proto fixed32 value."""
369     def params(self) -> List[Tuple[str, str]]:
370         return [('uint32_t', 'value')]
371
372     def _encoder_fn(self) -> str:
373         return 'WriteFixed32'
374
375
376 class PackedFixed32Method(PackedMethod):
377     """Method which writes a proto fixed32 value."""
378     def params(self) -> List[Tuple[str, str]]:
379         return [('std::span<const uint32_t>', 'values')]
380
381     def _encoder_fn(self) -> str:
382         return 'WritePackedFixed32'
383
384
385 class Uint64Method(WriteMethod):
386     """Method which writes a proto uint64 value."""
387     def params(self) -> List[Tuple[str, str]]:
388         return [('uint64_t', 'value')]
389
390     def _encoder_fn(self) -> str:
391         return 'WriteUint64'
392
393
394 class PackedUint64Method(PackedMethod):
395     """Method which writes a proto uint64 value."""
396     def params(self) -> List[Tuple[str, str]]:
397         return [('std::span<const uint64_t>', 'values')]
398
399     def _encoder_fn(self) -> str:
400         return 'WritePackedUint64'
401
402
403 class Fixed64Method(WriteMethod):
404     """Method which writes a proto fixed64 value."""
405     def params(self) -> List[Tuple[str, str]]:
406         return [('uint64_t', 'value')]
407
408     def _encoder_fn(self) -> str:
409         return 'WriteFixed64'
410
411
412 class PackedFixed64Method(PackedMethod):
413     """Method which writes a proto fixed64 value."""
414     def params(self) -> List[Tuple[str, str]]:
415         return [('std::span<const uint64_t>', 'values')]
416
417     def _encoder_fn(self) -> str:
418         return 'WritePackedFixed64'
419
420
421 class BoolMethod(WriteMethod):
422     """Method which writes a proto bool value."""
423     def params(self) -> List[Tuple[str, str]]:
424         return [('bool', 'value')]
425
426     def _encoder_fn(self) -> str:
427         return 'WriteBool'
428
429
430 class BytesMethod(WriteMethod):
431     """Method which writes a proto bytes value."""
432     def params(self) -> List[Tuple[str, str]]:
433         return [('std::span<const std::byte>', 'value')]
434
435     def _encoder_fn(self) -> str:
436         return 'WriteBytes'
437
438
439 class StringLenMethod(WriteMethod):
440     """Method which writes a proto string value with length."""
441     def params(self) -> List[Tuple[str, str]]:
442         return [('const char*', 'value'), ('size_t', 'len')]
443
444     def _encoder_fn(self) -> str:
445         return 'WriteString'
446
447
448 class StringMethod(WriteMethod):
449     """Method which writes a proto string value."""
450     def params(self) -> List[Tuple[str, str]]:
451         return [('const char*', 'value')]
452
453     def _encoder_fn(self) -> str:
454         return 'WriteString'
455
456
457 class EnumMethod(WriteMethod):
458     """Method which writes a proto enum value."""
459     def params(self) -> List[Tuple[str, str]]:
460         return [(self._relative_type_namespace(), 'value')]
461
462     def body(self) -> List[str]:
463         line = 'return encoder_->WriteUint32(' \
464             '{}, static_cast<uint32_t>(value));'.format(self.field_cast())
465         return [line]
466
467     def in_class_definition(self) -> bool:
468         return True
469
470     def _encoder_fn(self) -> str:
471         raise NotImplementedError()
472
473
474 # Mapping of protobuf field types to their method definitions.
475 PROTO_FIELD_METHODS: Dict[int, List] = {
476     descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE:
477     [DoubleMethod, PackedDoubleMethod],
478     descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT:
479     [FloatMethod, PackedFloatMethod],
480     descriptor_pb2.FieldDescriptorProto.TYPE_INT32:
481     [Int32Method, PackedInt32Method],
482     descriptor_pb2.FieldDescriptorProto.TYPE_SINT32:
483     [Sint32Method, PackedSint32Method],
484     descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32:
485     [Sfixed32Method, PackedSfixed32Method],
486     descriptor_pb2.FieldDescriptorProto.TYPE_INT64:
487     [Int64Method, PackedInt64Method],
488     descriptor_pb2.FieldDescriptorProto.TYPE_SINT64:
489     [Sint64Method, PackedSint64Method],
490     descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64:
491     [Sfixed64Method, PackedSfixed64Method],
492     descriptor_pb2.FieldDescriptorProto.TYPE_UINT32:
493     [Uint32Method, PackedUint32Method],
494     descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32:
495     [Fixed32Method, PackedFixed32Method],
496     descriptor_pb2.FieldDescriptorProto.TYPE_UINT64:
497     [Uint64Method, PackedUint64Method],
498     descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64:
499     [Fixed64Method, PackedFixed64Method],
500     descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [BoolMethod],
501     descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [BytesMethod],
502     descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
503         StringLenMethod, StringMethod
504     ],
505     descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageMethod],
506     descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [EnumMethod],
507 }
508
509
510 def generate_code_for_message(message: ProtoMessage, root: ProtoNode,
511                               output: OutputFile) -> None:
512     """Creates a C++ class for a protobuf message."""
513     assert message.type() == ProtoNode.Type.MESSAGE
514
515     # Message classes inherit from the base proto message class in codegen.h
516     # and use its constructor.
517     base_class = f'{PROTOBUF_NAMESPACE}::{BASE_PROTO_CLASS}'
518     output.write_line(
519         f'class {message.cpp_namespace(root)}::Encoder : public {base_class} {{'
520     )
521     output.write_line(' public:')
522
523     with output.indent():
524         output.write_line(f'using {BASE_PROTO_CLASS}::{BASE_PROTO_CLASS};')
525
526         # Generate methods for each of the message's fields.
527         for field in message.fields():
528             for method_class in PROTO_FIELD_METHODS[field.type()]:
529                 method = method_class(field, message, root)
530                 if not method.should_appear():
531                     continue
532
533                 output.write_line()
534                 method_signature = (
535                     f'{method.return_type()} '
536                     f'{method.name()}({method.param_string()})')
537
538                 if not method.in_class_definition():
539                     # Method will be defined outside of the class at the end of
540                     # the file.
541                     output.write_line(f'{method_signature};')
542                     continue
543
544                 output.write_line(f'{method_signature} {{')
545                 with output.indent():
546                     for line in method.body():
547                         output.write_line(line)
548                 output.write_line('}')
549
550     output.write_line('};')
551
552
553 def define_not_in_class_methods(message: ProtoMessage, root: ProtoNode,
554                                 output: OutputFile) -> None:
555     """Defines methods for a message class that were previously declared."""
556     assert message.type() == ProtoNode.Type.MESSAGE
557
558     for field in message.fields():
559         for method_class in PROTO_FIELD_METHODS[field.type()]:
560             method = method_class(field, message, root)
561             if not method.should_appear() or method.in_class_definition():
562                 continue
563
564             output.write_line()
565             class_name = f'{message.cpp_namespace(root)}::Encoder'
566             method_signature = (
567                 f'inline {method.return_type(from_root=True)} '
568                 f'{class_name}::{method.name()}({method.param_string()})')
569             output.write_line(f'{method_signature} {{')
570             with output.indent():
571                 for line in method.body():
572                     output.write_line(line)
573             output.write_line('}')
574
575
576 def generate_code_for_enum(enum: ProtoEnum, root: ProtoNode,
577                            output: OutputFile) -> None:
578     """Creates a C++ enum for a proto enum."""
579     assert enum.type() == ProtoNode.Type.ENUM
580
581     output.write_line(f'enum class {enum.cpp_namespace(root)} {{')
582     with output.indent():
583         for name, number in enum.values():
584             output.write_line(f'{name} = {number},')
585     output.write_line('};')
586
587
588 def forward_declare(node: ProtoMessage, root: ProtoNode,
589                     output: OutputFile) -> None:
590     """Generates code forward-declaring entities in a message's namespace."""
591     namespace = node.cpp_namespace(root)
592     output.write_line()
593     output.write_line(f'namespace {namespace} {{')
594
595     # Define an enum defining each of the message's fields and their numbers.
596     output.write_line('enum class Fields {')
597     with output.indent():
598         for field in node.fields():
599             output.write_line(f'{field.enum_name()} = {field.number()},')
600     output.write_line('};')
601
602     # Declare the message's encoder class and all of its enums.
603     output.write_line()
604     output.write_line('class Encoder;')
605     for child in node.children():
606         if child.type() == ProtoNode.Type.ENUM:
607             output.write_line()
608             generate_code_for_enum(cast(ProtoEnum, child), node, output)
609
610     output.write_line(f'}}  // namespace {namespace}')
611
612
613 def _proto_filename_to_generated_header(proto_file: str) -> str:
614     """Returns the generated C++ header name for a .proto file."""
615     return os.path.splitext(proto_file)[0] + PROTO_H_EXTENSION
616
617
618 def generate_code_for_package(file_descriptor_proto, package: ProtoNode,
619                               output: OutputFile) -> None:
620     """Generates code for a single .pb.h file corresponding to a .proto file."""
621
622     assert package.type() == ProtoNode.Type.PACKAGE
623
624     output.write_line(f'// {os.path.basename(output.name())} automatically '
625                       f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
626     output.write_line(f'// on {datetime.now()}')
627     output.write_line('#pragma once\n')
628     output.write_line('#include <cstddef>')
629     output.write_line('#include <cstdint>')
630     output.write_line('#include <span>\n')
631     output.write_line('#include "pw_protobuf/codegen.h"')
632
633     for imported_file in file_descriptor_proto.dependency:
634         generated_header = _proto_filename_to_generated_header(imported_file)
635         output.write_line(f'#include "{generated_header}"')
636
637     if package.cpp_namespace():
638         file_namespace = package.cpp_namespace()
639         if file_namespace.startswith('::'):
640             file_namespace = file_namespace[2:]
641
642         output.write_line(f'\nnamespace {file_namespace} {{')
643
644     for node in package:
645         if node.type() == ProtoNode.Type.MESSAGE:
646             forward_declare(cast(ProtoMessage, node), package, output)
647
648     # Define all top-level enums.
649     for node in package.children():
650         if node.type() == ProtoNode.Type.ENUM:
651             output.write_line()
652             generate_code_for_enum(cast(ProtoEnum, node), package, output)
653
654     # Run through all messages in the file, generating a class for each.
655     for node in package:
656         if node.type() == ProtoNode.Type.MESSAGE:
657             output.write_line()
658             generate_code_for_message(cast(ProtoMessage, node), package,
659                                       output)
660
661     # Run a second pass through the classes, this time defining all of the
662     # methods which were previously only declared.
663     for node in package:
664         if node.type() == ProtoNode.Type.MESSAGE:
665             define_not_in_class_methods(cast(ProtoMessage, node), package,
666                                         output)
667
668     if package.cpp_namespace():
669         output.write_line(f'\n}}  // namespace {package.cpp_namespace()}')
670
671
672 def process_proto_file(proto_file) -> Iterable[OutputFile]:
673     """Generates code for a single .proto file."""
674
675     # Two passes are made through the file. The first builds the tree of all
676     # message/enum nodes, then the second creates the fields in each. This is
677     # done as non-primitive fields need pointers to their types, which requires
678     # the entire tree to have been parsed into memory.
679     _, package_root = build_node_tree(proto_file)
680
681     output_filename = _proto_filename_to_generated_header(proto_file.name)
682     output_file = OutputFile(output_filename)
683     generate_code_for_package(proto_file, package_root, output_file)
684
685     return [output_file]