1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc. All rights reserved.
3 # http://code.google.com/p/protobuf/
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are
9 # * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 # * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the
15 # * Neither the name of Google Inc. nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 # This code is meant to work on Python 2.4 and above only.
33 # TODO(robinson): Helpers for verbose, common checks like seeing if a
34 # descriptor's cpp_type is CPPTYPE_MESSAGE.
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""
__author__ = 'robinson@google.com (Will Robinson)'
try:
  from cStringIO import StringIO
except ImportError:
  from StringIO import StringIO
import copy_reg
import struct
import weakref

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format
# Short module-level alias: the LABEL_*, TYPE_* and CPPTYPE_* constants of
# FieldDescriptor are referenced throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary of a message class before it is created.

  Adds a class attribute for every extension nested in this message type and
  a __slots__ list naming the instance attributes the generated methods use.

  Args:
    bases: Base classes of the message class being created.
    descriptor: Descriptor of the message type.
    dictionary: Class dictionary for the new class; mutated in place.

  Returns:
    bases, unchanged.  NOTE(review): presumably the metaclass rebinds its
    bases from this return value -- confirm against the caller.
  """
  _AddClassAttributesForNestedExtensions(descriptor, dictionary)
  _AddSlots(descriptor, dictionary)
  return bases


def InitMessage(descriptor, cls):
  """Finishes a freshly created message class.

  Creates the per-class decoder and extension lookup tables, attaches the
  encoding/decoding helpers to every field, installs properties and methods,
  and registers a pickle reducer for instances of the class.

  Args:
    descriptor: Descriptor of the message type.
    cls: The message class to populate.
  """
  cls._decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = {}

  uses_message_set_format = (descriptor.has_options and
                             descriptor.GetOptions().message_set_wire_format)
  if uses_message_set_format:
    # MessageSet items get a dedicated decoder, which is handed this class's
    # extensions-by-number table.
    cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
        decoder.MessageSetItemDecoder(cls._extensions_by_number))

  # Attach encoder/sizer/decoder helpers to each FieldDescriptor so later
  # lookups are cheap.
  for field_descriptor in descriptor.fields:
    _AttachFieldHelpers(cls, field_descriptor)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(cls)

  def _ReduceForPickle(obj):
    # Pickle as: call cls() with no arguments, then restore obj's state.
    return (cls, (), obj.__getstate__())
  copy_reg.pickle(cls, _ReduceForPickle)


105 # Stateless helpers for GeneratedProtocolMessageType below.
106 # Outside clients should not access these directly.
108 # I opted not to make any of these methods on the metaclass, to make it more
109 # clear that I'm not really using any state there and to keep clients from
110 # thinking that they have direct access to these construction helpers.
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the name of its public Python property.

  The mapping is deliberately the identity.  Python keywords such as "yield"
  are NOT escaped: people rely on getattr()/setattr() with the exact .proto
  name to manipulate fields reflectively, and a field named after a keyword
  is still reachable that way.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name under which the field's property is installed.
  """
  # TODO(kenton): Remove this helper entirely if/when everyone agrees that the
  # identity mapping is the right design.
  property_name = proto_field_name
  return property_name


def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid for message.

  Args:
    message: The message instance the extension handle is used with.
    extension_handle: The object the caller passed as an extension handle.

  Raises:
    KeyError: if extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other
      than message.DESCRIPTOR.
  """

  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      (Not consulted: the slot list is the same for every message type.)
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             # A class with __slots__ cannot be weakly
                             # referenced unless '__weakref__' is listed, and
                             # messages may be reached through weakref proxies.
                             '__weakref__']


def _IsMessageSetExtension(field):
  """Tells whether field is an extension encoded as a MessageSet item.

  True for an optional, message-typed extension whose message type is also
  its extension scope, extending a message whose options enable
  message_set_wire_format.
  """
  if not field.is_extension:
    return False
  extended = field.containing_type
  if not (extended.has_options and
          extended.GetOptions().message_set_wire_format):
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding/decoding helpers to field_descriptor and registers
  its decoder(s) in cls._decoders_by_tag.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor, then
  maps the field's tag bytes to a decoder.  Repeated packable fields also
  get a decoder for the packed (length-delimited) encoding.

  Args:
    cls: The message class being constructed.
    field_descriptor: FieldDescriptor of a field or extension of cls.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, packed):
    # `packed` describes the wire encoding this particular decoder accepts,
    # which may differ from the field's declared packed option.
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    cls._decoders_by_tag[tag_bytes] = (
        type_checkers.TYPE_TO_DECODER[field_descriptor.type](
            field_descriptor.number, is_repeated, packed,
            field_descriptor, field_descriptor._default_constructor))

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside this message type as a class
  attribute (keyed by the extension's name) in the class dictionary.

  Raises:
    AssertionError: if an extension name collides with an existing entry.
  """
  nested_extensions = descriptor.extensions_by_name
  for name in nested_extensions:
    assert name not in dictionary
    dictionary[name] = nested_extensions[name]


def _AddEnumValues(descriptor, cls):
  """Exposes the enums nested in this message type as class attributes.

  For every nested enum type, cls gets an attribute named after the enum
  holding an EnumTypeWrapper (a class-level object that can name enum
  values), plus one integer attribute per enum value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # So field.message_type is only read when a container is created.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts field values as keyword arguments.  Repeated
  and message fields are copied (via extend()/MergeFrom()) into freshly
  constructed values; singular scalar fields go through the field's property
  setter, which type-checks the value.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.iteritems():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for unknown names
      # instead of returning None, so this TypeError branch appears
      # unreachable; kept as a defensive check.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          for val in field_value:
            copy.add().MergeFrom(val)
        else:  # Scalar
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        copy.MergeFrom(field_value)
        self._fields[field] = copy
      else:
        setattr(self, field_name, field_value)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: if the message type has no field with that name.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message has no "%s" field.' % field_name)


def _AddPropertiesForFields(descriptor, cls):
  """Installs one property per declared field and, for extendable message
  types, an Extensions accessor."""
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is a stateless adaptor, so a fresh one is built on every
  # access instead of being cached.
  def _GetExtensions(self):
    return _ExtensionDict(self)
  cls.Extensions = property(_GetExtensions)


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.  Also adds a NAME_FIELD_NUMBER
  class constant holding the field's number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + "_FIELD_NUMBER"
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be a
  container built by the field's _default_constructor (a
  containers.RepeatedScalarFieldContainer or a
  containers.RepeatedCompositeFieldContainer), created on first access.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name
  def setter(self, new_value):
    type_checker.CheckValue(new_value)
    # Storing the value in _fields is what makes HasField() report it as set.
    self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.  Reading the property creates (and stores)
  an empty sub-message on first access; it only counts as set once it reports
  _is_present_in_parent.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  # TODO(komarek): Can anyone explain to me why we cache the message_type this
  # way, instead of referring to field.message_type inside of getter(self)?
  # What if someone sets message_type later on (which makes for simpler
  # dynamic proto descriptor and class creation code).
  message_type = field.message_type

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = message_type._concrete_class()  # use field.message_type?
      field_value._SetListener(self._listener_for_children)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a NAME_FIELD_NUMBER class constant for each extension declared in
  this message type's scope.

  Despite the function's name, no properties are created here; extension
  values are accessed through the Extensions attribute instead.

  Args:
    descriptor: Descriptor of the message type whose nested extensions are
      exposed.
    cls: The class we're constructing.
  """
  extension_dict = descriptor.extensions_by_name
  for extension_name, extension_field in extension_dict.iteritems():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)


def _AddStaticMethods(cls):
  """Adds the RegisterExtension() and FromString() static methods to cls."""
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR

    # Try to insert our extension, failing if an extension with the same number
    # already exists.  This must happen BEFORE _AttachFieldHelpers(): that
    # helper installs a decoder in cls._decoders_by_tag keyed by the field's
    # tag, so attaching first would let a rejected duplicate silently replace
    # the decoder of the extension already registered under that number.
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, actual_handle.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    _AttachFieldHelpers(cls, extension_handle)

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extension.  Also register under type name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    # Build a new instance and parse the serialized bytes s into it.
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields().

  Repeated fields count when non-empty, sub-messages when they report
  _is_present_in_parent, and singular scalars whenever they are stored.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ListFields.

  ListFields() returns the (FieldDescriptor, value) pairs accepted by
  _IsPresent(), sorted by field number.
  """

  def ListFields(self):
    all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.HasField.

  Only singular (non-repeated) fields may be queried; a repeated or unknown
  field name raises ValueError.
  """

  singular_fields = {}
  for field in message_descriptor.fields:
    if field.label != _FieldDescriptor.LABEL_REPEATED:
      singular_fields[field.name] = field

  def HasField(self, field_name):
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Reading a sub-message property stores an empty sub-message in _fields
      # without setting it, so presence is tracked separately.
      value = self._fields.get(field)
      return value is not None and value._is_present_in_parent
    else:
      return field in self._fields
  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ClearField.

  ClearField(name) removes any stored value for the named field and always
  calls _Modified(); an unknown field name raises ValueError.
  """
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)

    if field in self._fields:
      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.ClearExtension."""
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above: drop the stored value if any, and
    # always mark the message as modified.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.Clear.

  Clear() drops all known and unknown field values and calls _Modified().
  """
  def Clear(self):
    # Clear fields.
    self._fields = {}
    self._unknown_fields = ()
    self._Modified()
  cls.Clear = Clear


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.HasExtension.

  Raises KeyError for invalid handles (see _VerifyExtensionHandle) and for
  repeated extensions.
  """
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension


def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__eq__.

  Messages are equal when they share a DESCRIPTOR, have equal ListFields(),
  and carry the same unknown fields regardless of order.
  """
  def __eq__(self, other):
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    if self is other:
      return True

    if self.ListFields() != other.ListFields():
      return False

    # Sort unknown fields because their order shouldn't affect equality test.
    unknown_fields = sorted(self._unknown_fields)
    other_unknown_fields = sorted(other._unknown_fields)

    return unknown_fields == other_unknown_fields

  cls.__eq__ = __eq__


def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__str__ (text format)."""
  def __str__(self):
    return text_format.MessageToString(self)
  cls.__str__ = __str__


def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__unicode__."""

  def __unicode__(self):
    # Render text format as UTF-8 bytes, then decode into a unicode object.
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')
  cls.__unicode__ = __unicode__


def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): installs cls._SetListener.

  Passing None installs a NullMessageListener, so self._listener is never
  None.
  """
  def SetListener(self, listener):
    if listener is None:
      self._listener = message_listener_mod.NullMessageListener()
    else:
      self._listener = listener
  cls._SetListener = SetListener


def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The size computed by the function registered for field_type in
    type_checkers.TYPE_TO_BYTE_SIZE_FN.

  Raises:
    message_mod.EncodeError: if no size function is registered for
      field_type.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Called outside the try so that a KeyError raised inside the size function
  # is not misreported as an unrecognized field type.
  return fn(field_number, value)


def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ByteSize.

  The serialized size is cached and only recomputed while
  _cached_byte_size_dirty is set.
  """

  def ByteSize(self):
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    size = 0
    for field_descriptor, field_value in self.ListFields():
      size += field_descriptor._sizer(field_value)

    # Unknown fields are written back verbatim, so they cost exactly their
    # stored tag and value bytes.
    for tag_bytes, value_bytes in self._unknown_fields:
      size += len(tag_bytes) + len(value_bytes)

    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return size

  cls.ByteSize = ByteSize


def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.SerializeToString."""

  def SerializeToString(self):
    # Serializing requires every required field to be set; otherwise report
    # the paths of the missing ones.
    if self.IsInitialized():
      return self.SerializePartialToString()
    missing_paths = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name, missing_paths))
  cls.SerializeToString = SerializeToString


def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializePartialToString and
  _InternalSerialize on cls."""

  def SerializePartialToString(self):
    # Unlike SerializeToString(), required fields are not checked.
    out = StringIO()
    self._InternalSerialize(out.write)
    return out.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes):
    # Known fields in field-number order (as returned by ListFields()),
    # followed by the unknown fields exactly as they were stored.
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value)
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize


def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString and
  _InternalParse on cls."""
  def MergeFromString(self, serialized):
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except IndexError:
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind these lookups to locals once; InternalParse runs in a tight loop.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    # Parses buffer[pos:end] into self and returns the position where parsing
    # stopped: end on success, or the position of a tag for which SkipField
    # reports -1 (treated by MergeFromString as an end-group tag).
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder = decoders_by_tag.get(tag_bytes)
      if field_decoder is None:
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        # _unknown_fields starts as an immutable (); switch to a list lazily.
        if not unknown_field_list:
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
    return pos
  cls._InternalParse = InternalParse


def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Every required field is set; each set sub-message (including every
    # element of a repeated message field) must itself be initialized.
    for field, value in self._fields.iteritems():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".  Extension fields appear
      as "(full.extension.name)".
    """

    errors = []  # simplify things

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for i, element in enumerate(value):
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors


def _AddMergeFromMethod(cls):
  """Installs cls.MergeFrom.

  MergeFrom(msg) merges repeated fields and set sub-messages via their own
  MergeFrom(), overwrites singular scalar values, and appends msg's unknown
  fields.  msg must be an instance of cls other than self.
  """
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    # Explicit check rather than `assert`, which is stripped under -O.
    # Merging a message into itself would make each repeated container and
    # sub-message call MergeFrom() on itself.
    if msg is self:
      raise AssertionError('MergeFrom() cannot merge a message into itself.')
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.iteritems():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new object to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        fields[field] = value

    if msg._unknown_fields:
      # _unknown_fields may still be the immutable () placeholder.
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom


def _AddMessageMethods(message_descriptor, cls):
  """Installs every Message method implementation on cls.

  Args:
    message_descriptor: Descriptor forwarded to the helpers that need it; its
      is_extendable flag decides whether the extension helpers are run.
    cls: Class receiving the methods (modified in place).
  """
  # The helpers run in a fixed order.  Helpers taking only cls are invoked
  # separately from those that also need the descriptor.
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  if message_descriptor.is_extendable:
    for add_method in (_AddClearExtensionMethod, _AddHasExtensionMethod):
      add_method(cls)

  for add_method in (_AddClearMethod,
                     _AddEqualsMethod,
                     _AddStrMethod,
                     _AddUnicodeMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified() on cls, and also exposes the same function as
  SetInParent().
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.

    On that state change it also marks the listener handed to children as
    dirty and records that this message is present in its parent.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  cls._Modified = Modified
  # SetInParent is the very same function: it marks the message as set.
  cls.SetInParent = Modified
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.  May already be a weakref proxy.
    """
    # This listener establishes a back reference from a child (contained) object
    # to its parent (containing) object.  We make this a weak reference to avoid
    # creating cyclic garbage when the client finishes with the 'parent' object
    # in the tree.
    # Any kind of proxy (plain or callable) is reused as-is, because
    # weakref.proxy() cannot create a weak reference to an existing proxy.
    if isinstance(parent_message, weakref.ProxyTypes):
      self._parent_message_weakref = parent_message
    else:
      self._parent_message_weakref = weakref.proxy(parent_message)

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Notifies the parent message that one of its children was modified."""
    if self.dirty:
      # The parent has already been told; no need to walk up the tree.
      return
    try:
      # Propagate the signal to our parents iff this is the first field set.
      self._parent_message_weakref._Modified()
    except ReferenceError:
      # We can get here if a client has kept a reference to a child object,
      # and is now setting a field on it, but the child's parent has been
      # garbage-collected.  This is not an error.
      pass
# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
# TODO(robinson): Unify error handling of "unknown extension" cases.
# TODO(robinson): Support iteritems()-style iteration over all
# extensions with the "has" bits turned on?
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """

    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    Repeated and message-typed extensions are created on first access and
    stored in the extended message.  For an unset singular scalar extension,
    the default value is returned and nothing is stored.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        # Deliberately ignored.  NOTE(review): presumably raised when a
        # weakly referenced message has already been garbage-collected.
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # ListFields() yields (field_descriptor, value) pairs.  Keep only the
    # pairs whose descriptor is an extension, then compare descriptors and
    # values together.
    my_fields = [item for item in my_fields if item[0].is_extension]
    other_fields = [item for item in other_fields if item[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: extension_handle is a repeated or message-typed extension.
      Any error raised by the type checker's CheckValue() for an invalid
      value also propagates.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle.cpp_type, extension_handle.type)
    type_checker.CheckValue(value)
    self._extended_message._fields[extension_handle] = value
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if no extension has that name.
    """
    return self._extended_message._extensions_by_name.get(name, None)