- add sources.
[platform/framework/web/crosswalk.git] / src / third_party / protobuf / python / google / protobuf / internal / python_message.py
1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc.  All rights reserved.
3 # http://code.google.com/p/protobuf/
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are
7 # met:
8 #
9 #     * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 #     * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the
14 # distribution.
15 #     * Neither the name of Google Inc. nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 # This code is meant to work on Python 2.4 and above only.
32 #
33 # TODO(robinson): Helpers for verbose, common checks like seeing if a
34 # descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36 """Contains a metaclass and helper functions used to create
37 protocol message classes from Descriptor objects at runtime.
38
39 Recall that a metaclass is the "type" of a class.
40 (A class is to a metaclass what an instance is to a class.)
41
42 In this case, we use the GeneratedProtocolMessageType metaclass
43 to inject all the useful functionality into the classes
44 output by the protocol compiler at compile-time.
45
46 The upshot of all this is that the real implementation
47 details for ALL pure-Python protocol buffers are *here in
48 this file*.
49 """
50
51 __author__ = 'robinson@google.com (Will Robinson)'
52
53 try:
54   from cStringIO import StringIO
55 except ImportError:
56   from StringIO import StringIO
57 import copy_reg
58 import struct
59 import weakref
60
61 # We use "as" to avoid name collisions with variables.
62 from google.protobuf.internal import containers
63 from google.protobuf.internal import decoder
64 from google.protobuf.internal import encoder
65 from google.protobuf.internal import enum_type_wrapper
66 from google.protobuf.internal import message_listener as message_listener_mod
67 from google.protobuf.internal import type_checkers
68 from google.protobuf.internal import wire_format
69 from google.protobuf import descriptor as descriptor_mod
70 from google.protobuf import message as message_mod
71 from google.protobuf import text_format
72
73 _FieldDescriptor = descriptor_mod.FieldDescriptor
74
75
76 def NewMessage(bases, descriptor, dictionary):
77   _AddClassAttributesForNestedExtensions(descriptor, dictionary)
78   _AddSlots(descriptor, dictionary)
79   return bases
80
81
82 def InitMessage(descriptor, cls):
83   cls._decoders_by_tag = {}
84   cls._extensions_by_name = {}
85   cls._extensions_by_number = {}
86   if (descriptor.has_options and
87       descriptor.GetOptions().message_set_wire_format):
88     cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
89         decoder.MessageSetItemDecoder(cls._extensions_by_number))
90
91   # Attach stuff to each FieldDescriptor for quick lookup later on.
92   for field in descriptor.fields:
93     _AttachFieldHelpers(cls, field)
94
95   _AddEnumValues(descriptor, cls)
96   _AddInitMethod(descriptor, cls)
97   _AddPropertiesForFields(descriptor, cls)
98   _AddPropertiesForExtensions(descriptor, cls)
99   _AddStaticMethods(cls)
100   _AddMessageMethods(descriptor, cls)
101   _AddPrivateHelperMethods(cls)
102   copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
103
104
105 # Stateless helpers for GeneratedProtocolMessageType below.
106 # Outside clients should not access these directly.
107 #
108 # I opted not to make any of these methods on the metaclass, to make it more
109 # clear that I'm not really using any state there and to keep clients from
110 # thinking that they have direct access to these construction helpers.
111
112
113 def _PropertyName(proto_field_name):
114   """Returns the name of the public property attribute which
115   clients can use to get and (in some cases) set the value
116   of a protocol message field.
117
118   Args:
119     proto_field_name: The protocol message field name, exactly
120       as it appears (or would appear) in a .proto file.
121   """
122   # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
123   # nnorwitz makes my day by writing:
124   # """
125   # FYI.  See the keyword module in the stdlib. This could be as simple as:
126   #
127   # if keyword.iskeyword(proto_field_name):
128   #   return proto_field_name + "_"
129   # return proto_field_name
130   # """
131   # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
132   #   getattr() and setattr() to reflectively manipulate field values.  If we
133   #   rename the properties, then every such user has to also make sure to apply
134   #   the same transformation.  Note that currently if you name a field "yield",
135   #   you can still access it just fine using getattr/setattr -- it's not even
136   #   that cumbersome to do so.
137   # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
138   #   position.
139   return proto_field_name
140
141
142 def _VerifyExtensionHandle(message, extension_handle):
143   """Verify that the given extension handle is valid."""
144
145   if not isinstance(extension_handle, _FieldDescriptor):
146     raise KeyError('HasExtension() expects an extension handle, got: %s' %
147                    extension_handle)
148
149   if not extension_handle.is_extension:
150     raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
151
152   if not extension_handle.containing_type:
153     raise KeyError('"%s" is missing a containing_type.'
154                    % extension_handle.full_name)
155
156   if extension_handle.containing_type is not message.DESCRIPTOR:
157     raise KeyError('Extension "%s" extends message type "%s", but this '
158                    'message is of type "%s".' %
159                    (extension_handle.full_name,
160                     extension_handle.containing_type.full_name,
161                     message.DESCRIPTOR.full_name))
162
163
164 def _AddSlots(message_descriptor, dictionary):
165   """Adds a __slots__ entry to dictionary, containing the names of all valid
166   attributes for this message type.
167
168   Args:
169     message_descriptor: A Descriptor instance describing this message type.
170     dictionary: Class dictionary to which we'll add a '__slots__' entry.
171   """
172   dictionary['__slots__'] = ['_cached_byte_size',
173                              '_cached_byte_size_dirty',
174                              '_fields',
175                              '_unknown_fields',
176                              '_is_present_in_parent',
177                              '_listener',
178                              '_listener_for_children',
179                              '__weakref__']
180
181
182 def _IsMessageSetExtension(field):
183   return (field.is_extension and
184           field.containing_type.has_options and
185           field.containing_type.GetOptions().message_set_wire_format and
186           field.type == _FieldDescriptor.TYPE_MESSAGE and
187           field.message_type == field.extension_scope and
188           field.label == _FieldDescriptor.LABEL_OPTIONAL)
189
190
191 def _AttachFieldHelpers(cls, field_descriptor):
192   is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
193   is_packed = (field_descriptor.has_options and
194                field_descriptor.GetOptions().packed)
195
196   if _IsMessageSetExtension(field_descriptor):
197     field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
198     sizer = encoder.MessageSetItemSizer(field_descriptor.number)
199   else:
200     field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
201         field_descriptor.number, is_repeated, is_packed)
202     sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
203         field_descriptor.number, is_repeated, is_packed)
204
205   field_descriptor._encoder = field_encoder
206   field_descriptor._sizer = sizer
207   field_descriptor._default_constructor = _DefaultValueConstructorForField(
208       field_descriptor)
209
210   def AddDecoder(wiretype, is_packed):
211     tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
212     cls._decoders_by_tag[tag_bytes] = (
213         type_checkers.TYPE_TO_DECODER[field_descriptor.type](
214             field_descriptor.number, is_repeated, is_packed,
215             field_descriptor, field_descriptor._default_constructor))
216
217   AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
218              False)
219
220   if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
221     # To support wire compatibility of adding packed = true, add a decoder for
222     # packed values regardless of the field's options.
223     AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
224
225
226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
227   extension_dict = descriptor.extensions_by_name
228   for extension_name, extension_field in extension_dict.iteritems():
229     assert extension_name not in dictionary
230     dictionary[extension_name] = extension_field
231
232
233 def _AddEnumValues(descriptor, cls):
234   """Sets class-level attributes for all enum fields defined in this message.
235
236   Also exporting a class-level object that can name enum values.
237
238   Args:
239     descriptor: Descriptor object for this message type.
240     cls: Class we're constructing for this message type.
241   """
242   for enum_type in descriptor.enum_types:
243     setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
244     for enum_value in enum_type.values:
245       setattr(cls, enum_value.name, enum_value.number)
246
247
248 def _DefaultValueConstructorForField(field):
249   """Returns a function which returns a default value for a field.
250
251   Args:
252     field: FieldDescriptor object for this field.
253
254   The returned function has one argument:
255     message: Message instance containing this field, or a weakref proxy
256       of same.
257
258   That function in turn returns a default value for this field.  The default
259     value may refer back to |message| via a weak reference.
260   """
261
262   if field.label == _FieldDescriptor.LABEL_REPEATED:
263     if field.has_default_value and field.default_value != []:
264       raise ValueError('Repeated field default value not empty list: %s' % (
265           field.default_value))
266     if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
267       # We can't look at _concrete_class yet since it might not have
268       # been set.  (Depends on order in which we initialize the classes).
269       message_type = field.message_type
270       def MakeRepeatedMessageDefault(message):
271         return containers.RepeatedCompositeFieldContainer(
272             message._listener_for_children, field.message_type)
273       return MakeRepeatedMessageDefault
274     else:
275       type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
276       def MakeRepeatedScalarDefault(message):
277         return containers.RepeatedScalarFieldContainer(
278             message._listener_for_children, type_checker)
279       return MakeRepeatedScalarDefault
280
281   if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
282     # _concrete_class may not yet be initialized.
283     message_type = field.message_type
284     def MakeSubMessageDefault(message):
285       result = message_type._concrete_class()
286       result._SetListener(message._listener_for_children)
287       return result
288     return MakeSubMessageDefault
289
290   def MakeScalarDefault(message):
291     # TODO(protobuf-team): This may be broken since there may not be
292     # default_value.  Combine with has_default_value somehow.
293     return field.default_value
294   return MakeScalarDefault
295
296
297 def _AddInitMethod(message_descriptor, cls):
298   """Adds an __init__ method to cls."""
299   fields = message_descriptor.fields
300   def init(self, **kwargs):
301     self._cached_byte_size = 0
302     self._cached_byte_size_dirty = len(kwargs) > 0
303     self._fields = {}
304     # _unknown_fields is () when empty for efficiency, and will be turned into
305     # a list if fields are added.
306     self._unknown_fields = ()
307     self._is_present_in_parent = False
308     self._listener = message_listener_mod.NullMessageListener()
309     self._listener_for_children = _Listener(self)
310     for field_name, field_value in kwargs.iteritems():
311       field = _GetFieldByName(message_descriptor, field_name)
312       if field is None:
313         raise TypeError("%s() got an unexpected keyword argument '%s'" %
314                         (message_descriptor.name, field_name))
315       if field.label == _FieldDescriptor.LABEL_REPEATED:
316         copy = field._default_constructor(self)
317         if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
318           for val in field_value:
319             copy.add().MergeFrom(val)
320         else:  # Scalar
321           copy.extend(field_value)
322         self._fields[field] = copy
323       elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
324         copy = field._default_constructor(self)
325         copy.MergeFrom(field_value)
326         self._fields[field] = copy
327       else:
328         setattr(self, field_name, field_value)
329
330   init.__module__ = None
331   init.__doc__ = None
332   cls.__init__ = init
333
334
335 def _GetFieldByName(message_descriptor, field_name):
336   """Returns a field descriptor by field name.
337
338   Args:
339     message_descriptor: A Descriptor describing all fields in message.
340     field_name: The name of the field to retrieve.
341   Returns:
342     The field descriptor associated with the field name.
343   """
344   try:
345     return message_descriptor.fields_by_name[field_name]
346   except KeyError:
347     raise ValueError('Protocol message has no "%s" field.' % field_name)
348
349
350 def _AddPropertiesForFields(descriptor, cls):
351   """Adds properties for all fields in this protocol message type."""
352   for field in descriptor.fields:
353     _AddPropertiesForField(field, cls)
354
355   if descriptor.is_extendable:
356     # _ExtensionDict is just an adaptor with no state so we allocate a new one
357     # every time it is accessed.
358     cls.Extensions = property(lambda self: _ExtensionDict(self))
359
360
361 def _AddPropertiesForField(field, cls):
362   """Adds a public property for a protocol message field.
363   Clients can use this property to get and (in the case
364   of non-repeated scalar fields) directly set the value
365   of a protocol message field.
366
367   Args:
368     field: A FieldDescriptor for this field.
369     cls: The class we're constructing.
370   """
371   # Catch it if we add other types that we should
372   # handle specially here.
373   assert _FieldDescriptor.MAX_CPPTYPE == 10
374
375   constant_name = field.name.upper() + "_FIELD_NUMBER"
376   setattr(cls, constant_name, field.number)
377
378   if field.label == _FieldDescriptor.LABEL_REPEATED:
379     _AddPropertiesForRepeatedField(field, cls)
380   elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
381     _AddPropertiesForNonRepeatedCompositeField(field, cls)
382   else:
383     _AddPropertiesForNonRepeatedScalarField(field, cls)
384
385
386 def _AddPropertiesForRepeatedField(field, cls):
387   """Adds a public property for a "repeated" protocol message field.  Clients
388   can use this property to get the value of the field, which will be either a
389   _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
390   below).
391
392   Note that when clients add values to these containers, we perform
393   type-checking in the case of repeated scalar fields, and we also set any
394   necessary "has" bits as a side-effect.
395
396   Args:
397     field: A FieldDescriptor for this field.
398     cls: The class we're constructing.
399   """
400   proto_field_name = field.name
401   property_name = _PropertyName(proto_field_name)
402
403   def getter(self):
404     field_value = self._fields.get(field)
405     if field_value is None:
406       # Construct a new object to represent this field.
407       field_value = field._default_constructor(self)
408
409       # Atomically check if another thread has preempted us and, if not, swap
410       # in the new object we just created.  If someone has preempted us, we
411       # take that object and discard ours.
412       # WARNING:  We are relying on setdefault() being atomic.  This is true
413       #   in CPython but we haven't investigated others.  This warning appears
414       #   in several other locations in this file.
415       field_value = self._fields.setdefault(field, field_value)
416     return field_value
417   getter.__module__ = None
418   getter.__doc__ = 'Getter for %s.' % proto_field_name
419
420   # We define a setter just so we can throw an exception with a more
421   # helpful error message.
422   def setter(self, new_value):
423     raise AttributeError('Assignment not allowed to repeated field '
424                          '"%s" in protocol message object.' % proto_field_name)
425
426   doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
427   setattr(cls, property_name, property(getter, setter, doc=doc))
428
429
430 def _AddPropertiesForNonRepeatedScalarField(field, cls):
431   """Adds a public property for a nonrepeated, scalar protocol message field.
432   Clients can use this property to get and directly set the value of the field.
433   Note that when the client sets the value of a field by using this property,
434   all necessary "has" bits are set as a side-effect, and we also perform
435   type-checking.
436
437   Args:
438     field: A FieldDescriptor for this field.
439     cls: The class we're constructing.
440   """
441   proto_field_name = field.name
442   property_name = _PropertyName(proto_field_name)
443   type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
444   default_value = field.default_value
445   valid_values = set()
446
447   def getter(self):
448     # TODO(protobuf-team): This may be broken since there may not be
449     # default_value.  Combine with has_default_value somehow.
450     return self._fields.get(field, default_value)
451   getter.__module__ = None
452   getter.__doc__ = 'Getter for %s.' % proto_field_name
453   def setter(self, new_value):
454     type_checker.CheckValue(new_value)
455     self._fields[field] = new_value
456     # Check _cached_byte_size_dirty inline to improve performance, since scalar
457     # setters are called frequently.
458     if not self._cached_byte_size_dirty:
459       self._Modified()
460
461   setter.__module__ = None
462   setter.__doc__ = 'Setter for %s.' % proto_field_name
463
464   # Add a property to encapsulate the getter/setter.
465   doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
466   setattr(cls, property_name, property(getter, setter, doc=doc))
467
468
469 def _AddPropertiesForNonRepeatedCompositeField(field, cls):
470   """Adds a public property for a nonrepeated, composite protocol message field.
471   A composite field is a "group" or "message" field.
472
473   Clients can use this property to get the value of the field, but cannot
474   assign to the property directly.
475
476   Args:
477     field: A FieldDescriptor for this field.
478     cls: The class we're constructing.
479   """
480   # TODO(robinson): Remove duplication with similar method
481   # for non-repeated scalars.
482   proto_field_name = field.name
483   property_name = _PropertyName(proto_field_name)
484
485   # TODO(komarek): Can anyone explain to me why we cache the message_type this
486   # way, instead of referring to field.message_type inside of getter(self)?
487   # What if someone sets message_type later on (which makes for simpler
488   # dyanmic proto descriptor and class creation code).
489   message_type = field.message_type
490
491   def getter(self):
492     field_value = self._fields.get(field)
493     if field_value is None:
494       # Construct a new object to represent this field.
495       field_value = message_type._concrete_class()  # use field.message_type?
496       field_value._SetListener(self._listener_for_children)
497
498       # Atomically check if another thread has preempted us and, if not, swap
499       # in the new object we just created.  If someone has preempted us, we
500       # take that object and discard ours.
501       # WARNING:  We are relying on setdefault() being atomic.  This is true
502       #   in CPython but we haven't investigated others.  This warning appears
503       #   in several other locations in this file.
504       field_value = self._fields.setdefault(field, field_value)
505     return field_value
506   getter.__module__ = None
507   getter.__doc__ = 'Getter for %s.' % proto_field_name
508
509   # We define a setter just so we can throw an exception with a more
510   # helpful error message.
511   def setter(self, new_value):
512     raise AttributeError('Assignment not allowed to composite field '
513                          '"%s" in protocol message object.' % proto_field_name)
514
515   # Add a property to encapsulate the getter.
516   doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
517   setattr(cls, property_name, property(getter, setter, doc=doc))
518
519
520 def _AddPropertiesForExtensions(descriptor, cls):
521   """Adds properties for all fields in this protocol message type."""
522   extension_dict = descriptor.extensions_by_name
523   for extension_name, extension_field in extension_dict.iteritems():
524     constant_name = extension_name.upper() + "_FIELD_NUMBER"
525     setattr(cls, constant_name, extension_field.number)
526
527
528 def _AddStaticMethods(cls):
529   # TODO(robinson): This probably needs to be thread-safe(?)
530   def RegisterExtension(extension_handle):
531     extension_handle.containing_type = cls.DESCRIPTOR
532     _AttachFieldHelpers(cls, extension_handle)
533
534     # Try to insert our extension, failing if an extension with the same number
535     # already exists.
536     actual_handle = cls._extensions_by_number.setdefault(
537         extension_handle.number, extension_handle)
538     if actual_handle is not extension_handle:
539       raise AssertionError(
540           'Extensions "%s" and "%s" both try to extend message type "%s" with '
541           'field number %d.' %
542           (extension_handle.full_name, actual_handle.full_name,
543            cls.DESCRIPTOR.full_name, extension_handle.number))
544
545     cls._extensions_by_name[extension_handle.full_name] = extension_handle
546
547     handle = extension_handle  # avoid line wrapping
548     if _IsMessageSetExtension(handle):
549       # MessageSet extension.  Also register under type name.
550       cls._extensions_by_name[
551           extension_handle.message_type.full_name] = extension_handle
552
553   cls.RegisterExtension = staticmethod(RegisterExtension)
554
555   def FromString(s):
556     message = cls()
557     message.MergeFromString(s)
558     return message
559   cls.FromString = staticmethod(FromString)
560
561
562 def _IsPresent(item):
563   """Given a (FieldDescriptor, value) tuple from _fields, return true if the
564   value should be included in the list returned by ListFields()."""
565
566   if item[0].label == _FieldDescriptor.LABEL_REPEATED:
567     return bool(item[1])
568   elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
569     return item[1]._is_present_in_parent
570   else:
571     return True
572
573
574 def _AddListFieldsMethod(message_descriptor, cls):
575   """Helper for _AddMessageMethods()."""
576
577   def ListFields(self):
578     all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
579     all_fields.sort(key = lambda item: item[0].number)
580     return all_fields
581
582   cls.ListFields = ListFields
583
584
585 def _AddHasFieldMethod(message_descriptor, cls):
586   """Helper for _AddMessageMethods()."""
587
588   singular_fields = {}
589   for field in message_descriptor.fields:
590     if field.label != _FieldDescriptor.LABEL_REPEATED:
591       singular_fields[field.name] = field
592
593   def HasField(self, field_name):
594     try:
595       field = singular_fields[field_name]
596     except KeyError:
597       raise ValueError(
598           'Protocol message has no singular "%s" field.' % field_name)
599
600     if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
601       value = self._fields.get(field)
602       return value is not None and value._is_present_in_parent
603     else:
604       return field in self._fields
605   cls.HasField = HasField
606
607
608 def _AddClearFieldMethod(message_descriptor, cls):
609   """Helper for _AddMessageMethods()."""
610   def ClearField(self, field_name):
611     try:
612       field = message_descriptor.fields_by_name[field_name]
613     except KeyError:
614       raise ValueError('Protocol message has no "%s" field.' % field_name)
615
616     if field in self._fields:
617       # Note:  If the field is a sub-message, its listener will still point
618       #   at us.  That's fine, because the worst than can happen is that it
619       #   will call _Modified() and invalidate our byte size.  Big deal.
620       del self._fields[field]
621
622     # Always call _Modified() -- even if nothing was changed, this is
623     # a mutating method, and thus calling it should cause the field to become
624     # present in the parent message.
625     self._Modified()
626
627   cls.ClearField = ClearField
628
629
630 def _AddClearExtensionMethod(cls):
631   """Helper for _AddMessageMethods()."""
632   def ClearExtension(self, extension_handle):
633     _VerifyExtensionHandle(self, extension_handle)
634
635     # Similar to ClearField(), above.
636     if extension_handle in self._fields:
637       del self._fields[extension_handle]
638     self._Modified()
639   cls.ClearExtension = ClearExtension
640
641
642 def _AddClearMethod(message_descriptor, cls):
643   """Helper for _AddMessageMethods()."""
644   def Clear(self):
645     # Clear fields.
646     self._fields = {}
647     self._unknown_fields = ()
648     self._Modified()
649   cls.Clear = Clear
650
651
652 def _AddHasExtensionMethod(cls):
653   """Helper for _AddMessageMethods()."""
654   def HasExtension(self, extension_handle):
655     _VerifyExtensionHandle(self, extension_handle)
656     if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
657       raise KeyError('"%s" is repeated.' % extension_handle.full_name)
658
659     if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
660       value = self._fields.get(extension_handle)
661       return value is not None and value._is_present_in_parent
662     else:
663       return extension_handle in self._fields
664   cls.HasExtension = HasExtension
665
666
667 def _AddEqualsMethod(message_descriptor, cls):
668   """Helper for _AddMessageMethods()."""
669   def __eq__(self, other):
670     if (not isinstance(other, message_mod.Message) or
671         other.DESCRIPTOR != self.DESCRIPTOR):
672       return False
673
674     if self is other:
675       return True
676
677     if not self.ListFields() == other.ListFields():
678       return False
679
680     # Sort unknown fields because their order shouldn't affect equality test.
681     unknown_fields = list(self._unknown_fields)
682     unknown_fields.sort()
683     other_unknown_fields = list(other._unknown_fields)
684     other_unknown_fields.sort()
685
686     return unknown_fields == other_unknown_fields
687
688   cls.__eq__ = __eq__
689
690
691 def _AddStrMethod(message_descriptor, cls):
692   """Helper for _AddMessageMethods()."""
693   def __str__(self):
694     return text_format.MessageToString(self)
695   cls.__str__ = __str__
696
697
698 def _AddUnicodeMethod(unused_message_descriptor, cls):
699   """Helper for _AddMessageMethods()."""
700
701   def __unicode__(self):
702     return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
703   cls.__unicode__ = __unicode__
704
705
706 def _AddSetListenerMethod(cls):
707   """Helper for _AddMessageMethods()."""
708   def SetListener(self, listener):
709     if listener is None:
710       self._listener = message_listener_mod.NullMessageListener()
711     else:
712       self._listener = listener
713   cls._SetListener = SetListener
714
715
716 def _BytesForNonRepeatedElement(value, field_number, field_type):
717   """Returns the number of bytes needed to serialize a non-repeated element.
718   The returned byte count includes space for tag information and any
719   other additional space associated with serializing value.
720
721   Args:
722     value: Value we're serializing.
723     field_number: Field number of this value.  (Since the field number
724       is stored as part of a varint-encoded tag, this has an impact
725       on the total bytes required to serialize the value).
726     field_type: The type of the field.  One of the TYPE_* constants
727       within FieldDescriptor.
728   """
729   try:
730     fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
731     return fn(field_number, value)
732   except KeyError:
733     raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
734
735
736 def _AddByteSizeMethod(message_descriptor, cls):
737   """Helper for _AddMessageMethods()."""
738
739   def ByteSize(self):
740     if not self._cached_byte_size_dirty:
741       return self._cached_byte_size
742
743     size = 0
744     for field_descriptor, field_value in self.ListFields():
745       size += field_descriptor._sizer(field_value)
746
747     for tag_bytes, value_bytes in self._unknown_fields:
748       size += len(tag_bytes) + len(value_bytes)
749
750     self._cached_byte_size = size
751     self._cached_byte_size_dirty = False
752     self._listener_for_children.dirty = False
753     return size
754
755   cls.ByteSize = ByteSize
756
757
758 def _AddSerializeToStringMethod(message_descriptor, cls):
759   """Helper for _AddMessageMethods()."""
760
761   def SerializeToString(self):
762     # Check if the message has all of its required fields set.
763     errors = []
764     if not self.IsInitialized():
765       raise message_mod.EncodeError(
766           'Message %s is missing required fields: %s' % (
767           self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
768     return self.SerializePartialToString()
769   cls.SerializeToString = SerializeToString
770
771
772 def _AddSerializePartialToStringMethod(message_descriptor, cls):
773   """Helper for _AddMessageMethods()."""
774
775   def SerializePartialToString(self):
776     out = StringIO()
777     self._InternalSerialize(out.write)
778     return out.getvalue()
779   cls.SerializePartialToString = SerializePartialToString
780
781   def InternalSerialize(self, write_bytes):
782     for field_descriptor, field_value in self.ListFields():
783       field_descriptor._encoder(write_bytes, field_value)
784     for tag_bytes, value_bytes in self._unknown_fields:
785       write_bytes(tag_bytes)
786       write_bytes(value_bytes)
787   cls._InternalSerialize = InternalSerialize
788
789
790 def _AddMergeFromStringMethod(message_descriptor, cls):
791   """Helper for _AddMessageMethods()."""
792   def MergeFromString(self, serialized):
793     length = len(serialized)
794     try:
795       if self._InternalParse(serialized, 0, length) != length:
796         # The only reason _InternalParse would return early is if it
797         # encountered an end-group tag.
798         raise message_mod.DecodeError('Unexpected end-group tag.')
799     except IndexError:
800       raise message_mod.DecodeError('Truncated message.')
801     except struct.error, e:
802       raise message_mod.DecodeError(e)
803     return length   # Return this for legacy reasons.
804   cls.MergeFromString = MergeFromString
805
806   local_ReadTag = decoder.ReadTag
807   local_SkipField = decoder.SkipField
808   decoders_by_tag = cls._decoders_by_tag
809
810   def InternalParse(self, buffer, pos, end):
811     self._Modified()
812     field_dict = self._fields
813     unknown_field_list = self._unknown_fields
814     while pos != end:
815       (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
816       field_decoder = decoders_by_tag.get(tag_bytes)
817       if field_decoder is None:
818         value_start_pos = new_pos
819         new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
820         if new_pos == -1:
821           return pos
822         if not unknown_field_list:
823           unknown_field_list = self._unknown_fields = []
824         unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
825         pos = new_pos
826       else:
827         pos = field_decoder(buffer, new_pos, end, self, field_dict)
828     return pos
829   cls._InternalParse = InternalParse
830
831
832 def _AddIsInitializedMethod(message_descriptor, cls):
833   """Adds the IsInitialized and FindInitializationError methods to the
834   protocol message class."""
835
836   required_fields = [field for field in message_descriptor.fields
837                            if field.label == _FieldDescriptor.LABEL_REQUIRED]
838
839   def IsInitialized(self, errors=None):
840     """Checks if all required fields of a message are set.
841
842     Args:
843       errors:  A list which, if provided, will be populated with the field
844                paths of all missing required fields.
845
846     Returns:
847       True iff the specified message has all required fields set.
848     """
849
850     # Performance is critical so we avoid HasField() and ListFields().
851
852     for field in required_fields:
853       if (field not in self._fields or
854           (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
855            not self._fields[field]._is_present_in_parent)):
856         if errors is not None:
857           errors.extend(self.FindInitializationErrors())
858         return False
859
860     for field, value in self._fields.iteritems():
861       if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
862         if field.label == _FieldDescriptor.LABEL_REPEATED:
863           for element in value:
864             if not element.IsInitialized():
865               if errors is not None:
866                 errors.extend(self.FindInitializationErrors())
867               return False
868         elif value._is_present_in_parent and not value.IsInitialized():
869           if errors is not None:
870             errors.extend(self.FindInitializationErrors())
871           return False
872
873     return True
874
875   cls.IsInitialized = IsInitialized
876
877   def FindInitializationErrors(self):
878     """Finds required fields which are not initialized.
879
880     Returns:
881       A list of strings.  Each string is a path to an uninitialized field from
882       the top-level message, e.g. "foo.bar[5].baz".
883     """
884
885     errors = []  # simplify things
886
887     for field in required_fields:
888       if not self.HasField(field.name):
889         errors.append(field.name)
890
891     for field, value in self.ListFields():
892       if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
893         if field.is_extension:
894           name = "(%s)" % field.full_name
895         else:
896           name = field.name
897
898         if field.label == _FieldDescriptor.LABEL_REPEATED:
899           for i in xrange(len(value)):
900             element = value[i]
901             prefix = "%s[%d]." % (name, i)
902             sub_errors = element.FindInitializationErrors()
903             errors += [ prefix + error for error in sub_errors ]
904         else:
905           prefix = name + "."
906           sub_errors = value.FindInitializationErrors()
907           errors += [ prefix + error for error in sub_errors ]
908
909     return errors
910
911   cls.FindInitializationErrors = FindInitializationErrors
912
913
914 def _AddMergeFromMethod(cls):
915   LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
916   CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
917
918   def MergeFrom(self, msg):
919     if not isinstance(msg, cls):
920       raise TypeError(
921           "Parameter to MergeFrom() must be instance of same class: "
922           "expected %s got %s." % (cls.__name__, type(msg).__name__))
923
924     assert msg is not self
925     self._Modified()
926
927     fields = self._fields
928
929     for field, value in msg._fields.iteritems():
930       if field.label == LABEL_REPEATED:
931         field_value = fields.get(field)
932         if field_value is None:
933           # Construct a new object to represent this field.
934           field_value = field._default_constructor(self)
935           fields[field] = field_value
936         field_value.MergeFrom(value)
937       elif field.cpp_type == CPPTYPE_MESSAGE:
938         if value._is_present_in_parent:
939           field_value = fields.get(field)
940           if field_value is None:
941             # Construct a new object to represent this field.
942             field_value = field._default_constructor(self)
943             fields[field] = field_value
944           field_value.MergeFrom(value)
945       else:
946         self._fields[field] = value
947
948     if msg._unknown_fields:
949       if not self._unknown_fields:
950         self._unknown_fields = []
951       self._unknown_fields.extend(msg._unknown_fields)
952
953   cls.MergeFrom = MergeFrom
954
955
956 def _AddMessageMethods(message_descriptor, cls):
957   """Adds implementations of all Message methods to cls."""
958   _AddListFieldsMethod(message_descriptor, cls)
959   _AddHasFieldMethod(message_descriptor, cls)
960   _AddClearFieldMethod(message_descriptor, cls)
961   if message_descriptor.is_extendable:
962     _AddClearExtensionMethod(cls)
963     _AddHasExtensionMethod(cls)
964   _AddClearMethod(message_descriptor, cls)
965   _AddEqualsMethod(message_descriptor, cls)
966   _AddStrMethod(message_descriptor, cls)
967   _AddUnicodeMethod(message_descriptor, cls)
968   _AddSetListenerMethod(cls)
969   _AddByteSizeMethod(message_descriptor, cls)
970   _AddSerializeToStringMethod(message_descriptor, cls)
971   _AddSerializePartialToStringMethod(message_descriptor, cls)
972   _AddMergeFromStringMethod(message_descriptor, cls)
973   _AddIsInitializedMethod(message_descriptor, cls)
974   _AddMergeFromMethod(cls)
975
976
977 def _AddPrivateHelperMethods(cls):
978   """Adds implementation of private helper methods to cls."""
979
980   def Modified(self):
981     """Sets the _cached_byte_size_dirty bit to true,
982     and propagates this to our listener iff this was a state change.
983     """
984
985     # Note:  Some callers check _cached_byte_size_dirty before calling
986     #   _Modified() as an extra optimization.  So, if this method is ever
987     #   changed such that it does stuff even when _cached_byte_size_dirty is
988     #   already true, the callers need to be updated.
989     if not self._cached_byte_size_dirty:
990       self._cached_byte_size_dirty = True
991       self._listener_for_children.dirty = True
992       self._is_present_in_parent = True
993       self._listener.Modified()
994
995   cls._Modified = Modified
996   cls.SetInParent = Modified
997
998
999 class _Listener(object):
1000
1001   """MessageListener implementation that a parent message registers with its
1002   child message.
1003
1004   In order to support semantics like:
1005
1006     foo.bar.baz.qux = 23
1007     assert foo.HasField('bar')
1008
1009   ...child objects must have back references to their parents.
1010   This helper class is at the heart of this support.
1011   """
1012
1013   def __init__(self, parent_message):
1014     """Args:
1015       parent_message: The message whose _Modified() method we should call when
1016         we receive Modified() messages.
1017     """
1018     # This listener establishes a back reference from a child (contained) object
1019     # to its parent (containing) object.  We make this a weak reference to avoid
1020     # creating cyclic garbage when the client finishes with the 'parent' object
1021     # in the tree.
1022     if isinstance(parent_message, weakref.ProxyType):
1023       self._parent_message_weakref = parent_message
1024     else:
1025       self._parent_message_weakref = weakref.proxy(parent_message)
1026
1027     # As an optimization, we also indicate directly on the listener whether
1028     # or not the parent message is dirty.  This way we can avoid traversing
1029     # up the tree in the common case.
1030     self.dirty = False
1031
1032   def Modified(self):
1033     if self.dirty:
1034       return
1035     try:
1036       # Propagate the signal to our parents iff this is the first field set.
1037       self._parent_message_weakref._Modified()
1038     except ReferenceError:
1039       # We can get here if a client has kept a reference to a child object,
1040       # and is now setting a field on it, but the child's parent has been
1041       # garbage-collected.  This is not an error.
1042       pass
1043
1044
1045 # TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1046 # TODO(robinson): Unify error handling of "unknown extension" crap.
1047 # TODO(robinson): Support iteritems()-style iteration over all
1048 # extensions with the "has" bits turned on?
1049 class _ExtensionDict(object):
1050
1051   """Dict-like container for supporting an indexable "Extensions"
1052   field on proto instances.
1053
1054   Note that in all cases we expect extension handles to be
1055   FieldDescriptors.
1056   """
1057
1058   def __init__(self, extended_message):
1059     """extended_message: Message instance for which we are the Extensions dict.
1060     """
1061
1062     self._extended_message = extended_message
1063
1064   def __getitem__(self, extension_handle):
1065     """Returns the current value of the given extension handle."""
1066
1067     _VerifyExtensionHandle(self._extended_message, extension_handle)
1068
1069     result = self._extended_message._fields.get(extension_handle)
1070     if result is not None:
1071       return result
1072
1073     if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1074       result = extension_handle._default_constructor(self._extended_message)
1075     elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1076       result = extension_handle.message_type._concrete_class()
1077       try:
1078         result._SetListener(self._extended_message._listener_for_children)
1079       except ReferenceError:
1080         pass
1081     else:
1082       # Singular scalar -- just return the default without inserting into the
1083       # dict.
1084       return extension_handle.default_value
1085
1086     # Atomically check if another thread has preempted us and, if not, swap
1087     # in the new object we just created.  If someone has preempted us, we
1088     # take that object and discard ours.
1089     # WARNING:  We are relying on setdefault() being atomic.  This is true
1090     #   in CPython but we haven't investigated others.  This warning appears
1091     #   in several other locations in this file.
1092     result = self._extended_message._fields.setdefault(
1093         extension_handle, result)
1094
1095     return result
1096
1097   def __eq__(self, other):
1098     if not isinstance(other, self.__class__):
1099       return False
1100
1101     my_fields = self._extended_message.ListFields()
1102     other_fields = other._extended_message.ListFields()
1103
1104     # Get rid of non-extension fields.
1105     my_fields    = [ field for field in my_fields    if field.is_extension ]
1106     other_fields = [ field for field in other_fields if field.is_extension ]
1107
1108     return my_fields == other_fields
1109
1110   def __ne__(self, other):
1111     return not self == other
1112
1113   def __hash__(self):
1114     raise TypeError('unhashable object')
1115
1116   # Note that this is only meaningful for non-repeated, scalar extension
1117   # fields.  Note also that we may have to call _Modified() when we do
1118   # successfully set a field this way, to set any necssary "has" bits in the
1119   # ancestors of the extended message.
1120   def __setitem__(self, extension_handle, value):
1121     """If extension_handle specifies a non-repeated, scalar extension
1122     field, sets the value of that field.
1123     """
1124
1125     _VerifyExtensionHandle(self._extended_message, extension_handle)
1126
1127     if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1128         extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1129       raise TypeError(
1130           'Cannot assign to extension "%s" because it is a repeated or '
1131           'composite type.' % extension_handle.full_name)
1132
1133     # It's slightly wasteful to lookup the type checker each time,
1134     # but we expect this to be a vanishingly uncommon case anyway.
1135     type_checker = type_checkers.GetTypeChecker(
1136         extension_handle.cpp_type, extension_handle.type)
1137     type_checker.CheckValue(value)
1138     self._extended_message._fields[extension_handle] = value
1139     self._extended_message._Modified()
1140
1141   def _FindExtensionByName(self, name):
1142     """Tries to find a known extension with the specified name.
1143
1144     Args:
1145       name: Extension full name.
1146
1147     Returns:
1148       Extension field descriptor.
1149     """
1150     return self._extended_message._extensions_by_name.get(name, None)