1 # Copyright 2020 The Pigweed Authors
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
7 # https://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
14 """This module defines the generated code for pw_protobuf C++ classes."""
17 from datetime import datetime
20 from typing import Dict, Iterable, List, Tuple
21 from typing import cast
23 import google.protobuf.descriptor_pb2 as descriptor_pb2
25 from pw_protobuf.output_file import OutputFile
26 from pw_protobuf.proto_tree import ProtoEnum, ProtoMessage, ProtoMessageField
27 from pw_protobuf.proto_tree import ProtoNode
28 from pw_protobuf.proto_tree import build_node_tree
# Name and version of this plugin. Both are written into the banner comment
# at the top of each generated header.
PLUGIN_NAME = 'pw_protobuf'
PLUGIN_VERSION = '0.1.0'

# File extensions for generated C++ files. The header extension replaces the
# .proto file's own extension when naming the generated header.
PROTO_H_EXTENSION = '.pwpb.h'
PROTO_CC_EXTENSION = '.pwpb.cc'

# Generated Encoder classes inherit from PROTOBUF_NAMESPACE::BASE_PROTO_CLASS.
PROTOBUF_NAMESPACE = 'pw::protobuf'
BASE_PROTO_CLASS = 'ProtoMessageEncoder'
40 # protoc captures stdout, so we need to printf debug to stderr.
def debug_print(*values, **print_options):
    """Prints to stderr, forwarding every argument to the built-in print().

    Diagnostics go to stderr because protoc captures the plugin's stdout.
    """
    stream = sys.stderr
    print(*values, file=stream, **print_options)
45 class ProtoMethod(abc.ABC):
46 """Base class for a C++ method for a field in a protobuf message."""
49 field: ProtoMessageField,
53 """Creates an instance of a method.
56 field: the ProtoMessageField to which the method belongs.
57 scope: the ProtoNode namespace in which the method is being defined.
59 self._field: ProtoMessageField = field
60 self._scope: ProtoNode = scope
61 self._root: ProtoNode = root
64 def name(self) -> str:
65 """Returns the name of the method, e.g. DoSomething."""
68 def params(self) -> List[Tuple[str, str]]:
69 """Returns the parameters of the method as a list of (type, name) pairs.
72 [('int', 'foo'), ('const char*', 'bar')]
76 def body(self) -> List[str]:
77 """Returns the method body as a list of source code lines.
81 'int baz = bar[foo];',
82 'return (baz ^ foo) >> 3;'
87 def return_type(self, from_root: bool = False) -> str:
88 """Returns the return type of the method, e.g. int.
90 For non-primitive return types, the from_root argument determines
91 whether the namespace should be relative to the message's scope
92 (default) or the root scope.
96 def in_class_definition(self) -> bool:
97 """Determines where the method should be defined.
99 Returns True if the method definition should be inlined in its class
100 definition, or False if it should be declared in the class and defined
104 def should_appear(self) -> bool: # pylint: disable=no-self-use
105 """Whether the method should be generated."""
def param_string(self) -> str:
    """Returns the method's parameters formatted as a C++ parameter list.

    Each (type, name) pair from params() is rendered as 'type name' and the
    results are joined with ', ', e.g. 'int foo, const char* bar'. A method
    with no parameters yields an empty string.
    """
    # Use descriptive loop names; `type` would shadow the Python builtin.
    return ', '.join(f'{param_type} {param_name}'
                     for param_type, param_name in self.params())
def field_cast(self) -> str:
    """Returns a C++ expression casting the field's Fields value to uint32_t.

    The enumerator name is taken from the field's enum_name().
    """
    enumerator = self._field.enum_name()
    return f'static_cast<uint32_t>(Fields::{enumerator})'
115 def _relative_type_namespace(self, from_root: bool = False) -> str:
116 """Returns relative namespace between method's scope and field type."""
117 scope = self._root if from_root else self._scope
118 type_node = self._field.type_node()
119 assert type_node is not None
120 ancestor = scope.common_ancestor(type_node)
121 namespace = type_node.cpp_namespace(ancestor)
122 assert namespace is not None
126 class SubMessageMethod(ProtoMethod):
127 """Method which returns a sub-message encoder."""
def name(self) -> str:
    """Returns the accessor name: 'Get' + the field's name + 'Encoder'."""
    field_name = self._field.name()
    return f'Get{field_name}Encoder'
def return_type(self, from_root: bool = False) -> str:
    """Returns the C++ type of the sub-message's Encoder class.

    from_root is forwarded to _relative_type_namespace() to select the scope
    the namespace is expressed relative to.
    """
    namespace = self._relative_type_namespace(from_root)
    return f'{namespace}::Encoder'
134 def params(self) -> List[Tuple[str, str]]:
137 def body(self) -> List[str]:
138 line = 'return {}::Encoder(encoder_, {});'.format(
139 self._relative_type_namespace(), self.field_cast())
142 # Submessage methods are not defined within the class itself because the
143 # submessage class may not yet have been defined.
144 def in_class_definition(self) -> bool:
148 class WriteMethod(ProtoMethod):
149 """Base class representing an encoder write method.
151 Write methods have following format (for the proto field foo):
153 Status WriteFoo({params...}) {
154 return encoder_->Write{type}(kFoo, {params...});
def name(self) -> str:
    """Returns the write method's name: 'Write' + the field's name."""
    field_name = self._field.name()
    return f'Write{field_name}'
161 def return_type(self, from_root: bool = False) -> str:
162 return '::pw::Status'
164 def body(self) -> List[str]:
165 params = ', '.join([pair[1] for pair in self.params()])
166 line = 'return encoder_->{}({}, {});'.format(self._encoder_fn(),
167 self.field_cast(), params)
170 def params(self) -> List[Tuple[str, str]]:
171 """Method parameters, defined in subclasses."""
172 raise NotImplementedError()
174 def in_class_definition(self) -> bool:
177 def _encoder_fn(self) -> str:
178 """The encoder function to call.
180 Defined in subclasses.
182 e.g. 'WriteUint32', 'WriteBytes', etc.
184 raise NotImplementedError()
class PackedMethod(WriteMethod):
    """Base class for write methods that take a packed repeated field.

    Behaves like a WriteMethod, except that the method is generated only when
    the field is repeated.
    """

    def _encoder_fn(self) -> str:
        """Names the encoder function to call; defined in subclasses."""
        raise NotImplementedError()

    def should_appear(self) -> bool:
        """Returns True only when the field is repeated."""
        field = self._field
        return field.is_repeated()
200 # The following code defines write methods for each of the
201 # primitive protobuf types.
205 class DoubleMethod(WriteMethod):
206 """Method which writes a proto double value."""
207 def params(self) -> List[Tuple[str, str]]:
208 return [('double', 'value')]
210 def _encoder_fn(self) -> str:
class PackedDoubleMethod(PackedMethod):
    """Write method for a packed repeated double field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedDouble'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of doubles."""
        values_param = ('std::span<const double>', 'values')
        return [values_param]
223 class FloatMethod(WriteMethod):
224 """Method which writes a proto float value."""
225 def params(self) -> List[Tuple[str, str]]:
226 return [('float', 'value')]
228 def _encoder_fn(self) -> str:
class PackedFloatMethod(PackedMethod):
    """Write method for a packed repeated float field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedFloat'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of floats."""
        values_param = ('std::span<const float>', 'values')
        return [values_param]
241 class Int32Method(WriteMethod):
242 """Method which writes a proto int32 value."""
243 def params(self) -> List[Tuple[str, str]]:
244 return [('int32_t', 'value')]
246 def _encoder_fn(self) -> str:
class PackedInt32Method(PackedMethod):
    """Write method for a packed repeated int32 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedInt32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of int32_t."""
        values_param = ('std::span<const int32_t>', 'values')
        return [values_param]
259 class Sint32Method(WriteMethod):
260 """Method which writes a proto sint32 value."""
261 def params(self) -> List[Tuple[str, str]]:
262 return [('int32_t', 'value')]
264 def _encoder_fn(self) -> str:
class PackedSint32Method(PackedMethod):
    """Write method for a packed repeated sint32 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedSint32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of int32_t."""
        values_param = ('std::span<const int32_t>', 'values')
        return [values_param]
class Sfixed32Method(WriteMethod):
    """Write method for a single proto sfixed32 value."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WriteSfixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: an int32_t value."""
        value_param = ('int32_t', 'value')
        return [value_param]
class PackedSfixed32Method(PackedMethod):
    """Write method for a packed repeated sfixed32 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedSfixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of int32_t."""
        values_param = ('std::span<const int32_t>', 'values')
        return [values_param]
295 class Int64Method(WriteMethod):
296 """Method which writes a proto int64 value."""
297 def params(self) -> List[Tuple[str, str]]:
298 return [('int64_t', 'value')]
300 def _encoder_fn(self) -> str:
class PackedInt64Method(PackedMethod):
    """Write method for a packed repeated int64 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedInt64'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of int64_t."""
        values_param = ('std::span<const int64_t>', 'values')
        return [values_param]
313 class Sint64Method(WriteMethod):
314 """Method which writes a proto sint64 value."""
315 def params(self) -> List[Tuple[str, str]]:
316 return [('int64_t', 'value')]
318 def _encoder_fn(self) -> str:
class PackedSint64Method(PackedMethod):
    """Write method for a packed repeated sint64 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedSint64'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of int64_t."""
        values_param = ('std::span<const int64_t>', 'values')
        return [values_param]
class Sfixed64Method(WriteMethod):
    """Write method for a single proto sfixed64 value."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WriteSfixed64'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: an int64_t value."""
        value_param = ('int64_t', 'value')
        return [value_param]
class PackedSfixed64Method(PackedMethod):
    """Method which writes a packed list of sfixed64."""
    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of int64_t."""
        return [('std::span<const int64_t>', 'values')]

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        # Was 'WritePackedSfixed4', a typo: every other packed writer in this
        # file is named WritePacked<Type>, and the scalar counterpart of this
        # class calls WriteSfixed64.
        return 'WritePackedSfixed64'
349 class Uint32Method(WriteMethod):
350 """Method which writes a proto uint32 value."""
351 def params(self) -> List[Tuple[str, str]]:
352 return [('uint32_t', 'value')]
354 def _encoder_fn(self) -> str:
class PackedUint32Method(PackedMethod):
    """Write method for a packed repeated uint32 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedUint32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of uint32_t."""
        values_param = ('std::span<const uint32_t>', 'values')
        return [values_param]
class Fixed32Method(WriteMethod):
    """Write method for a single proto fixed32 value."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WriteFixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a uint32_t value."""
        value_param = ('uint32_t', 'value')
        return [value_param]
class PackedFixed32Method(PackedMethod):
    """Write method for a packed repeated fixed32 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedFixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of uint32_t."""
        values_param = ('std::span<const uint32_t>', 'values')
        return [values_param]
385 class Uint64Method(WriteMethod):
386 """Method which writes a proto uint64 value."""
387 def params(self) -> List[Tuple[str, str]]:
388 return [('uint64_t', 'value')]
390 def _encoder_fn(self) -> str:
class PackedUint64Method(PackedMethod):
    """Write method for a packed repeated uint64 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedUint64'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of uint64_t."""
        values_param = ('std::span<const uint64_t>', 'values')
        return [values_param]
class Fixed64Method(WriteMethod):
    """Write method for a single proto fixed64 value."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WriteFixed64'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a uint64_t value."""
        value_param = ('uint64_t', 'value')
        return [value_param]
class PackedFixed64Method(PackedMethod):
    """Write method for a packed repeated fixed64 field."""

    def _encoder_fn(self) -> str:
        """Names the encoder function called by the generated method."""
        return 'WritePackedFixed64'

    def params(self) -> List[Tuple[str, str]]:
        """Returns the single (type, name) parameter: a span of uint64_t."""
        values_param = ('std::span<const uint64_t>', 'values')
        return [values_param]
421 class BoolMethod(WriteMethod):
422 """Method which writes a proto bool value."""
423 def params(self) -> List[Tuple[str, str]]:
424 return [('bool', 'value')]
426 def _encoder_fn(self) -> str:
430 class BytesMethod(WriteMethod):
431 """Method which writes a proto bytes value."""
432 def params(self) -> List[Tuple[str, str]]:
433 return [('std::span<const std::byte>', 'value')]
435 def _encoder_fn(self) -> str:
439 class StringLenMethod(WriteMethod):
440 """Method which writes a proto string value with length."""
441 def params(self) -> List[Tuple[str, str]]:
442 return [('const char*', 'value'), ('size_t', 'len')]
444 def _encoder_fn(self) -> str:
448 class StringMethod(WriteMethod):
449 """Method which writes a proto string value."""
450 def params(self) -> List[Tuple[str, str]]:
451 return [('const char*', 'value')]
453 def _encoder_fn(self) -> str:
457 class EnumMethod(WriteMethod):
458 """Method which writes a proto enum value."""
459 def params(self) -> List[Tuple[str, str]]:
460 return [(self._relative_type_namespace(), 'value')]
462 def body(self) -> List[str]:
463 line = 'return encoder_->WriteUint32(' \
464 '{}, static_cast<uint32_t>(value));'.format(self.field_cast())
467 def in_class_definition(self) -> bool:
470 def _encoder_fn(self) -> str:
471 raise NotImplementedError()
474 # Mapping of protobuf field types to their method definitions.
475 PROTO_FIELD_METHODS: Dict[int, List] = {
476 descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE:
477 [DoubleMethod, PackedDoubleMethod],
478 descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT:
479 [FloatMethod, PackedFloatMethod],
480 descriptor_pb2.FieldDescriptorProto.TYPE_INT32:
481 [Int32Method, PackedInt32Method],
482 descriptor_pb2.FieldDescriptorProto.TYPE_SINT32:
483 [Sint32Method, PackedSint32Method],
484 descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32:
485 [Sfixed32Method, PackedSfixed32Method],
486 descriptor_pb2.FieldDescriptorProto.TYPE_INT64:
487 [Int64Method, PackedInt64Method],
488 descriptor_pb2.FieldDescriptorProto.TYPE_SINT64:
489 [Sint64Method, PackedSint64Method],
490 descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64:
491 [Sfixed64Method, PackedSfixed64Method],
492 descriptor_pb2.FieldDescriptorProto.TYPE_UINT32:
493 [Uint32Method, PackedUint32Method],
494 descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32:
495 [Fixed32Method, PackedFixed32Method],
496 descriptor_pb2.FieldDescriptorProto.TYPE_UINT64:
497 [Uint64Method, PackedUint64Method],
498 descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64:
499 [Fixed64Method, PackedFixed64Method],
500 descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [BoolMethod],
501 descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [BytesMethod],
502 descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
503 StringLenMethod, StringMethod
505 descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageMethod],
506 descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [EnumMethod],
510 def generate_code_for_message(message: ProtoMessage, root: ProtoNode,
511 output: OutputFile) -> None:
512 """Creates a C++ class for a protobuf message."""
513 assert message.type() == ProtoNode.Type.MESSAGE
515 # Message classes inherit from the base proto message class in codegen.h
516 # and use its constructor.
517 base_class = f'{PROTOBUF_NAMESPACE}::{BASE_PROTO_CLASS}'
519 f'class {message.cpp_namespace(root)}::Encoder : public {base_class} {{'
521 output.write_line(' public:')
523 with output.indent():
524 output.write_line(f'using {BASE_PROTO_CLASS}::{BASE_PROTO_CLASS};')
526 # Generate methods for each of the message's fields.
527 for field in message.fields():
528 for method_class in PROTO_FIELD_METHODS[field.type()]:
529 method = method_class(field, message, root)
530 if not method.should_appear():
535 f'{method.return_type()} '
536 f'{method.name()}({method.param_string()})')
538 if not method.in_class_definition():
539 # Method will be defined outside of the class at the end of
541 output.write_line(f'{method_signature};')
544 output.write_line(f'{method_signature} {{')
545 with output.indent():
546 for line in method.body():
547 output.write_line(line)
548 output.write_line('}')
550 output.write_line('};')
553 def define_not_in_class_methods(message: ProtoMessage, root: ProtoNode,
554 output: OutputFile) -> None:
555 """Defines methods for a message class that were previously declared."""
556 assert message.type() == ProtoNode.Type.MESSAGE
558 for field in message.fields():
559 for method_class in PROTO_FIELD_METHODS[field.type()]:
560 method = method_class(field, message, root)
561 if not method.should_appear() or method.in_class_definition():
565 class_name = f'{message.cpp_namespace(root)}::Encoder'
567 f'inline {method.return_type(from_root=True)} '
568 f'{class_name}::{method.name()}({method.param_string()})')
569 output.write_line(f'{method_signature} {{')
570 with output.indent():
571 for line in method.body():
572 output.write_line(line)
573 output.write_line('}')
def generate_code_for_enum(enum: ProtoEnum, root: ProtoNode,
                           output: OutputFile) -> None:
    """Writes a C++ `enum class` definition for a proto enum.

    Args:
      enum: the enum node to generate; must be of type ENUM.
      root: the node passed to enum.cpp_namespace() to obtain the enum's name.
      output: the file to which the definition is written.

    One `name = number,` line is written, indented, for each (name, number)
    pair in enum.values().
    """
    assert enum.type() == ProtoNode.Type.ENUM

    qualified_name = enum.cpp_namespace(root)
    output.write_line(f'enum class {qualified_name} {{')
    with output.indent():
        for value_name, value_number in enum.values():
            output.write_line(f'{value_name} = {value_number},')
    output.write_line('};')
588 def forward_declare(node: ProtoMessage, root: ProtoNode,
589 output: OutputFile) -> None:
590 """Generates code forward-declaring entities in a message's namespace."""
591 namespace = node.cpp_namespace(root)
593 output.write_line(f'namespace {namespace} {{')
595 # Define an enum defining each of the message's fields and their numbers.
596 output.write_line('enum class Fields {')
597 with output.indent():
598 for field in node.fields():
599 output.write_line(f'{field.enum_name()} = {field.number()},')
600 output.write_line('};')
602 # Declare the message's encoder class and all of its enums.
604 output.write_line('class Encoder;')
605 for child in node.children():
606 if child.type() == ProtoNode.Type.ENUM:
608 generate_code_for_enum(cast(ProtoEnum, child), node, output)
610 output.write_line(f'}} // namespace {namespace}')
def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto file path to the path of its generated C++ header.

    The final extension of proto_file is replaced with PROTO_H_EXTENSION; any
    directory components are kept.
    """
    stem, _ = os.path.splitext(proto_file)
    return stem + PROTO_H_EXTENSION
618 def generate_code_for_package(file_descriptor_proto, package: ProtoNode,
619 output: OutputFile) -> None:
620 """Generates code for a single .pb.h file corresponding to a .proto file."""
622 assert package.type() == ProtoNode.Type.PACKAGE
624 output.write_line(f'// {os.path.basename(output.name())} automatically '
625 f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
626 output.write_line(f'// on {datetime.now()}')
627 output.write_line('#pragma once\n')
628 output.write_line('#include <cstddef>')
629 output.write_line('#include <cstdint>')
630 output.write_line('#include <span>\n')
631 output.write_line('#include "pw_protobuf/codegen.h"')
633 for imported_file in file_descriptor_proto.dependency:
634 generated_header = _proto_filename_to_generated_header(imported_file)
635 output.write_line(f'#include "{generated_header}"')
637 if package.cpp_namespace():
638 file_namespace = package.cpp_namespace()
639 if file_namespace.startswith('::'):
640 file_namespace = file_namespace[2:]
642 output.write_line(f'\nnamespace {file_namespace} {{')
645 if node.type() == ProtoNode.Type.MESSAGE:
646 forward_declare(cast(ProtoMessage, node), package, output)
648 # Define all top-level enums.
649 for node in package.children():
650 if node.type() == ProtoNode.Type.ENUM:
652 generate_code_for_enum(cast(ProtoEnum, node), package, output)
654 # Run through all messages in the file, generating a class for each.
656 if node.type() == ProtoNode.Type.MESSAGE:
658 generate_code_for_message(cast(ProtoMessage, node), package,
661 # Run a second pass through the classes, this time defining all of the
662 # methods which were previously only declared.
664 if node.type() == ProtoNode.Type.MESSAGE:
665 define_not_in_class_methods(cast(ProtoMessage, node), package,
668 if package.cpp_namespace():
669 output.write_line(f'\n}} // namespace {package.cpp_namespace()}')
672 def process_proto_file(proto_file) -> Iterable[OutputFile]:
673 """Generates code for a single .proto file."""
675 # Two passes are made through the file. The first builds the tree of all
676 # message/enum nodes, then the second creates the fields in each. This is
677 # done as non-primitive fields need pointers to their types, which requires
678 # the entire tree to have been parsed into memory.
679 _, package_root = build_node_tree(proto_file)
681 output_filename = _proto_filename_to_generated_header(proto_file.name)
682 output_file = OutputFile(output_filename)
683 generate_code_for_package(proto_file, package_root, output_file)