1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc. All rights reserved.
3 # https://developers.google.com/protocol-buffers/
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are
9 # * Redistributions of source code must retain the above copyright
10 # notice, this list of conditions and the following disclaimer.
11 # * Redistributions in binary form must reproduce the above
12 # copyright notice, this list of conditions and the following disclaimer
13 # in the documentation and/or other materials provided with the
15 # * Neither the name of Google Inc. nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains helper functions used to create protocol message classes from
Descriptor objects at runtime backed by the protocol buffer C++ API.
"""

__author__ = 'petar@google.com (Petar Petrov)'
import copy_reg
import operator

from google.protobuf.internal import _net_proto2___python
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import message
# Label and type constants mirrored from the C++ extension module under
# module-private names.
_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED
_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL
_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE
_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE
def GetDescriptorPool():
  """Creates a new DescriptorPool C++ object.

  Returns:
    The object produced by the extension's NewCDescriptorPool().
  """
  return _net_proto2___python.NewCDescriptorPool()
# Module-level C++ descriptor pool shared by the field/extension lookup
# helpers in this module.
_pool = GetDescriptorPool()
def GetFieldDescriptor(full_field_name):
  """Searches for a field descriptor given a full field name.

  Args:
    full_field_name: Fully qualified field name to look up in the module pool.

  Returns:
    Whatever the pool's FindFieldByName() returns for that name.
  """
  return _pool.FindFieldByName(full_field_name)
def BuildFile(content):
  """Registers a new proto file in the underlying C++ descriptor pool.

  Args:
    content: File content forwarded unchanged to the extension's BuildFile().
  """
  _net_proto2___python.BuildFile(content)
def GetExtensionDescriptor(full_extension_name):
  """Searches for an extension descriptor given a full field name.

  Args:
    full_extension_name: Fully qualified extension name.

  Returns:
    Whatever the pool's FindExtensionByName() returns for that name.
  """
  return _pool.FindExtensionByName(full_extension_name)
def NewCMessage(full_message_name):
  """Creates a new C++ protocol message by its name.

  Args:
    full_message_name: Fully qualified message type name.

  Returns:
    The C++ message object created by the extension.
  """
  return _net_proto2___python.NewCMessage(full_message_name)
def ScalarProperty(cdescriptor):
  """Returns a scalar property for the given descriptor.

  Args:
    cdescriptor: C++ field descriptor handed to the message's GetScalar and
      SetScalar calls.

  Returns:
    A property whose getter and setter forward to ``self._cmsg``.
  """

  def Getter(self):
    return self._cmsg.GetScalar(cdescriptor)

  def Setter(self, value):
    self._cmsg.SetScalar(cdescriptor, value)

  return property(Getter, Setter)
def CompositeProperty(cdescriptor, message_type):
  """Returns a read-only Python property for the given composite field.

  The wrapper for the sub-message is created on first access and cached in
  ``self._composite_fields`` under the field name, so later reads return the
  same object.

  Args:
    cdescriptor: C++ field descriptor of the sub-message field.
    message_type: Descriptor whose ``_concrete_class`` wraps the sub-message.

  Returns:
    A getter-only property.
  """

  def Getter(self):
    sub_message = self._composite_fields.get(cdescriptor.name, None)
    if sub_message is None:
      cmessage = self._cmsg.NewSubMessage(cdescriptor)
      sub_message = message_type._concrete_class(__cmessage=cmessage)
      self._composite_fields[cdescriptor.name] = sub_message
    return sub_message

  return property(Getter)
class RepeatedScalarContainer(object):
  """Container for repeated scalar fields.

  A list-like view over one repeated scalar field: every operation reads from
  or writes to ``msg._cmsg`` directly, so no values are cached here.
  """

  __slots__ = ['_message', '_cfield_descriptor', '_cmsg']

  def __init__(self, msg, cfield_descriptor):
    # Hold a reference to the owning message (not otherwise read here).
    self._message = msg
    self._cmsg = msg._cmsg
    self._cfield_descriptor = cfield_descriptor

  def append(self, value):
    self._cmsg.AddRepeatedScalar(self._cfield_descriptor, value)

  def extend(self, sequence):
    for element in sequence:
      self.append(element)

  def insert(self, key, value):
    values = self[slice(None, None, None)]
    values.insert(key, value)
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)

  def remove(self, value):
    values = self[slice(None, None, None)]
    values.remove(value)
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)

  def __setitem__(self, key, value):
    values = self[slice(None, None, None)]
    values[key] = value
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)

  def __getitem__(self, key):
    return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key)

  def __delitem__(self, key):
    self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key)

  def __len__(self):
    return len(self[slice(None, None, None)])

  def __eq__(self, other):
    """Compares element-wise against any sequence; raises for non-sequences."""
    if self is other:
      return True
    if not operator.isSequenceType(other):
      raise TypeError(
          'Can only compare repeated scalar fields against sequences.')
    # We are presumably comparing against some other sequence type.
    return other == self[slice(None, None, None)]

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  def sort(self, *args, **kwargs):
    """Sorts the field in place; accepts the same arguments as sorted()."""
    # Maintain compatibility with the previous interface.
    if 'sort_function' in kwargs:
      kwargs['cmp'] = kwargs.pop('sort_function')
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor,
                                    sorted(self, *args, **kwargs))
def RepeatedScalarProperty(cdescriptor):
  """Returns a Python property for the given repeated scalar field.

  The container is created on first access and cached in
  ``self._composite_fields`` under the field name.  Assigning to the property
  raises AttributeError.

  Args:
    cdescriptor: C++ field descriptor of the repeated field.

  Returns:
    A property with a descriptive ``__doc__``.
  """

  def Getter(self):
    container = self._composite_fields.get(cdescriptor.name, None)
    if container is None:
      container = RepeatedScalarContainer(self, cdescriptor)
      self._composite_fields[cdescriptor.name] = container
    return container

  def Setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % cdescriptor.name)

  doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
  return property(Getter, Setter, doc=doc)
class RepeatedCompositeContainer(object):
  """Container for repeated composite fields.

  Elements live in the C++ message ``msg._cmsg``; each access wraps them in
  ``subclass`` with the owning message passed as ``__owner``.
  """

  __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg']

  def __init__(self, msg, cfield_descriptor, subclass):
    self._message = msg
    self._cmsg = msg._cmsg
    self._subclass = subclass
    self._cfield_descriptor = cfield_descriptor

  def add(self, **kwargs):
    """Appends a new element and returns it wrapped in the element class."""
    cmessage = self._cmsg.AddMessage(self._cfield_descriptor)
    return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs)

  def extend(self, elem_seq):
    """Extends by appending the given sequence of elements of the same type
    as this one, copying each individual message.
    """
    for elem in elem_seq:
      self.add().MergeFrom(elem)

  def remove(self, value):
    # TODO(protocol-devel): This is inefficient as it needs to generate a
    # message pointer for each message only to do index(). Move this to a C++
    # extension function.
    self.__delitem__(self[slice(None, None, None)].index(value))

  def MergeFrom(self, other):
    """Appends a copy of every element of ``other``."""
    for elem in other[:]:
      self.add().MergeFrom(elem)

  def __getitem__(self, key):
    """Returns the wrapped element(s) at ``key``.

    When the underlying call returns a list (presumably for slice keys), a
    list of wrappers is returned.
    """
    cmessages = self._cmsg.GetRepeatedMessage(self._cfield_descriptor, key)
    subclass = self._subclass
    if not isinstance(cmessages, list):
      return subclass(__cmessage=cmessages, __owner=self._message)

    return [subclass(__cmessage=m, __owner=self._message) for m in cmessages]

  def __delitem__(self, key):
    self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key)

  def __len__(self):
    return self._cmsg.FieldLength(self._cfield_descriptor)

  def __eq__(self, other):
    """Compares the current instance with another one."""
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    messages = self[slice(None, None, None)]
    other_messages = other[slice(None, None, None)]
    return messages == other_messages

  def __ne__(self, other):
    # Python 2 does not derive != from __eq__; keep them consistent.
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  def sort(self, cmp=None, key=None, reverse=False, **kwargs):
    """Sorts the elements in place by swapping them in the C++ message.

    Args:
      cmp: Optional Python 2 style comparison function.
      key: Optional function mapping an element to its sort key.
      reverse: If true, sort in descending order.
      **kwargs: Accepts the legacy ``sort_function`` alias for ``cmp``.
    """
    # Maintain compatibility with the old interface.
    if cmp is None and 'sort_function' in kwargs:
      cmp = kwargs.pop('sort_function')

    # The cmp function, if provided, is passed the results of the key function,
    # so we only need to wrap one of them.
    if key is None:
      index_key = self.__getitem__
    else:
      index_key = lambda i: key(self[i])

    sort_args = {'key': index_key, 'reverse': reverse}
    if cmp is not None:
      # Only pass cmp when given; the keyword does not exist on Python 3.
      sort_args['cmp'] = cmp
    # order[dest] is the original index of the element that belongs at dest.
    order = sorted(range(len(self)), **sort_args)

    # Apply the permutation with swaps.  Track where each original element
    # currently sits, so cycles longer than two are resolved correctly.
    position = list(range(len(order)))  # original index -> current slot
    occupant = list(range(len(order)))  # current slot -> original index
    for dest, orig in enumerate(order):
      src = position[orig]
      if src == dest:
        continue
      self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src)
      displaced = occupant[dest]
      occupant[dest], occupant[src] = orig, displaced
      position[orig], position[displaced] = dest, src
def RepeatedCompositeProperty(cdescriptor, message_type):
  """Returns a Python property for the given repeated composite field.

  The container is created on first access and cached in
  ``self._composite_fields`` under the field name.  Assigning to the property
  raises AttributeError.

  Args:
    cdescriptor: C++ field descriptor of the repeated field.
    message_type: Descriptor whose ``_concrete_class`` wraps each element.

  Returns:
    A property with a descriptive ``__doc__``.
  """

  def Getter(self):
    container = self._composite_fields.get(cdescriptor.name, None)
    if container is None:
      container = RepeatedCompositeContainer(
          self, cdescriptor, message_type._concrete_class)
      self._composite_fields[cdescriptor.name] = container
    return container

  def Setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % cdescriptor.name)

  doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
  return property(Getter, Setter, doc=doc)
class ExtensionDict(object):
  """Extension dictionary added to each protocol message.

  Scalar optional extensions are read and written straight through the C++
  message.  Repeated and message-typed extensions get a handle (container or
  sub-message wrapper) that is created lazily and cached in ``_values``.
  """

  def __init__(self, msg):
    self._message = msg
    self._cmsg = msg._cmsg
    # Handles keyed by extension descriptor; __setitem__ also records
    # scalar values here.
    self._values = {}

  @staticmethod
  def _CheckExtensionHandle(extension):
    """Raises KeyError unless ``extension`` is a descriptor.FieldDescriptor."""
    # Imported lazily, as in the original methods, to avoid an import cycle.
    from google.protobuf import descriptor
    if not isinstance(extension, descriptor.FieldDescriptor):
      raise KeyError('Bad extension %r.' % (extension,))

  def __setitem__(self, extension, value):
    self._CheckExtensionHandle(extension)
    cdescriptor = extension._cdescriptor
    if (cdescriptor.label != _LABEL_OPTIONAL or
        cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
      raise TypeError('Extension %r is repeated and/or a composite type.' % (
          extension.full_name,))
    self._cmsg.SetScalar(cdescriptor, value)
    self._values[extension] = value

  def __getitem__(self, extension):
    self._CheckExtensionHandle(extension)

    cdescriptor = extension._cdescriptor
    if (cdescriptor.label != _LABEL_REPEATED and
        cdescriptor.cpp_type != _CPPTYPE_MESSAGE):
      return self._cmsg.GetScalar(cdescriptor)

    ext = self._values.get(extension, None)
    if ext is not None:
      return ext

    ext = self._CreateNewHandle(extension)
    self._values[extension] = ext
    return ext

  def ClearExtension(self, extension):
    self._CheckExtensionHandle(extension)
    self._cmsg.ClearFieldByDescriptor(extension._cdescriptor)
    self._values.pop(extension, None)

  def HasExtension(self, extension):
    self._CheckExtensionHandle(extension)
    return self._cmsg.HasFieldByDescriptor(extension._cdescriptor)

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the name is not registered.
    """
    return self._message._extensions_by_name.get(name, None)

  def _CreateNewHandle(self, extension):
    """Builds the handle for a repeated or message-typed extension."""
    cdescriptor = extension._cdescriptor
    if (cdescriptor.label != _LABEL_REPEATED and
        cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
      cmessage = self._cmsg.NewSubMessage(cdescriptor)
      return extension.message_type._concrete_class(__cmessage=cmessage)

    if cdescriptor.label == _LABEL_REPEATED:
      if cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
        return RepeatedCompositeContainer(
            self._message, cdescriptor, extension.message_type._concrete_class)
      else:
        return RepeatedScalarContainer(self._message, cdescriptor)
    # This shouldn't happen!  __getitem__ only gets here for repeated or
    # message-typed extensions; raise explicitly so -O cannot hide it.
    raise AssertionError(
        'No handle type for extension %r.' % (extension.full_name,))
def NewMessage(bases, message_descriptor, dictionary):
  """Creates a new protocol message *class*.

  Populates ``dictionary`` with nested-extension attributes, enum values and
  the per-field C++ descriptors/slots.

  Args:
    bases: Base classes of the class being created.
    message_descriptor: Descriptor of the message type.
    dictionary: Class dictionary to populate in place.

  Returns:
    ``bases``, unchanged.
  """
  _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
  _AddEnumValues(message_descriptor, dictionary)
  _AddDescriptors(message_descriptor, dictionary)
  return bases
def InitMessage(message_descriptor, cls):
  """Finishes setting up a newly created protocol message *class*.

  Adds ``__init__``, the message methods and the extension field-number
  constants to ``cls``, and registers a pickle reducer for it.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class to populate.
  """
  cls._extensions_by_name = {}
  _AddInitMethod(message_descriptor, cls)
  _AddMessageMethods(message_descriptor, cls)
  _AddPropertiesForExtensions(message_descriptor, cls)
  # Instances pickle as cls() plus their state.  __getstate__ is not among the
  # methods attached here; NOTE(review): assumed to come from a base class.
  copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
391 def _AddDescriptors(message_descriptor, dictionary):
392 """Sets up a new protocol message class dictionary.
395 message_descriptor: A Descriptor instance describing this message type.
396 dictionary: Class dictionary to which we'll add a '__slots__' entry.
398 dictionary['__descriptors'] = {}
399 for field in message_descriptor.fields:
400 dictionary['__descriptors'][field.name] = GetFieldDescriptor(
403 dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [
404 '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS']
407 def _AddEnumValues(message_descriptor, dictionary):
408 """Sets class-level attributes for all enum fields defined in this message.
411 message_descriptor: Descriptor object for this message type.
412 dictionary: Class dictionary that should be populated.
414 for enum_type in message_descriptor.enum_types:
415 dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type)
416 for enum_value in enum_type.values:
417 dictionary[enum_value.name] = enum_value.number
420 def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary):
421 """Adds class attributes for the nested extensions."""
422 extension_dict = message_descriptor.extensions_by_name
423 for extension_name, extension_field in extension_dict.iteritems():
424 assert extension_name not in dictionary
425 dictionary[extension_name] = extension_field
428 def _AddInitMethod(message_descriptor, cls):
429 """Adds an __init__ method to cls."""
431 # Create and attach message field properties to the message class.
432 # This can be done just once per message class, since property setters and
433 # getters are passed the message instance.
434 # This makes message instantiation extremely fast, and at the same time it
435 # doesn't require the creation of property objects for each message instance,
436 # which saves a lot of memory.
437 for field in message_descriptor.fields:
438 field_cdescriptor = cls.__descriptors[field.name]
439 if field.label == _LABEL_REPEATED:
440 if field.cpp_type == _CPPTYPE_MESSAGE:
441 value = RepeatedCompositeProperty(field_cdescriptor, field.message_type)
443 value = RepeatedScalarProperty(field_cdescriptor)
444 elif field.cpp_type == _CPPTYPE_MESSAGE:
445 value = CompositeProperty(field_cdescriptor, field.message_type)
447 value = ScalarProperty(field_cdescriptor)
448 setattr(cls, field.name, value)
450 # Attach a constant with the field number.
451 constant_name = field.name.upper() + '_FIELD_NUMBER'
452 setattr(cls, constant_name, field.number)
454 def Init(self, **kwargs):
455 """Message constructor."""
456 cmessage = kwargs.pop('__cmessage', None)
458 self._cmsg = cmessage
460 self._cmsg = NewCMessage(message_descriptor.full_name)
462 # Keep a reference to the owner, as the owner keeps a reference to the
463 # underlying protocol buffer message.
464 owner = kwargs.pop('__owner', None)
468 if message_descriptor.is_extendable:
469 self.Extensions = ExtensionDict(self)
471 # Reference counting in the C++ code is broken and depends on
472 # the Extensions reference to keep this object alive during unit
473 # tests (see b/4856052). Remove this once b/4945904 is fixed.
474 self._HACK_REFCOUNTS = self
475 self._composite_fields = {}
477 for field_name, field_value in kwargs.iteritems():
478 field_cdescriptor = self.__descriptors.get(field_name, None)
479 if not field_cdescriptor:
480 raise ValueError('Protocol message has no "%s" field.' % field_name)
481 if field_cdescriptor.label == _LABEL_REPEATED:
482 if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
483 field_name = getattr(self, field_name)
484 for val in field_value:
485 field_name.add().MergeFrom(val)
487 getattr(self, field_name).extend(field_value)
488 elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
489 getattr(self, field_name).MergeFrom(field_value)
491 setattr(self, field_name, field_value)
493 Init.__module__ = None
498 def _IsMessageSetExtension(field):
499 """Checks if a field is a message set extension."""
500 return (field.is_extension and
501 field.containing_type.has_options and
502 field.containing_type.GetOptions().message_set_wire_format and
503 field.type == _TYPE_MESSAGE and
504 field.message_type == field.extension_scope and
505 field.label == _LABEL_OPTIONAL)
508 def _AddMessageMethods(message_descriptor, cls):
509 """Adds the methods to a protocol message class."""
510 if message_descriptor.is_extendable:
512 def ClearExtension(self, extension):
513 self.Extensions.ClearExtension(extension)
515 def HasExtension(self, extension):
516 return self.Extensions.HasExtension(extension)
518 def HasField(self, field_name):
519 return self._cmsg.HasField(field_name)
521 def ClearField(self, field_name):
522 child_cmessage = None
523 if field_name in self._composite_fields:
524 child_field = self._composite_fields[field_name]
525 del self._composite_fields[field_name]
527 child_cdescriptor = self.__descriptors[field_name]
528 # TODO(anuraag): Support clearing repeated message fields as well.
529 if (child_cdescriptor.label != _LABEL_REPEATED and
530 child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
531 child_field._owner = None
532 child_cmessage = child_field._cmsg
534 if child_cmessage is not None:
535 self._cmsg.ClearField(field_name, child_cmessage)
537 self._cmsg.ClearField(field_name)
540 cmessages_to_release = []
541 for field_name, child_field in self._composite_fields.iteritems():
542 child_cdescriptor = self.__descriptors[field_name]
543 # TODO(anuraag): Support clearing repeated message fields as well.
544 if (child_cdescriptor.label != _LABEL_REPEATED and
545 child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
546 child_field._owner = None
547 cmessages_to_release.append((child_cdescriptor, child_field._cmsg))
548 self._composite_fields.clear()
549 self._cmsg.Clear(cmessages_to_release)
551 def IsInitialized(self, errors=None):
552 if self._cmsg.IsInitialized():
554 if errors is not None:
555 errors.extend(self.FindInitializationErrors());
558 def SerializeToString(self):
559 if not self.IsInitialized():
560 raise message.EncodeError(
561 'Message %s is missing required fields: %s' % (
562 self._cmsg.full_name, ','.join(self.FindInitializationErrors())))
563 return self._cmsg.SerializeToString()
565 def SerializePartialToString(self):
566 return self._cmsg.SerializePartialToString()
568 def ParseFromString(self, serialized):
570 self.MergeFromString(serialized)
572 def MergeFromString(self, serialized):
573 byte_size = self._cmsg.MergeFromString(serialized)
575 raise message.DecodeError('Unable to merge from string.')
578 def MergeFrom(self, msg):
579 if not isinstance(msg, cls):
581 "Parameter to MergeFrom() must be instance of same class: "
582 "expected %s got %s." % (cls.__name__, type(msg).__name__))
583 self._cmsg.MergeFrom(msg._cmsg)
585 def CopyFrom(self, msg):
586 self._cmsg.CopyFrom(msg._cmsg)
589 return self._cmsg.ByteSize()
591 def SetInParent(self):
592 return self._cmsg.SetInParent()
594 def ListFields(self):
596 field_list = self._cmsg.ListFields()
597 fields_by_name = cls.DESCRIPTOR.fields_by_name
598 for is_extension, field_name in field_list:
600 extension = cls._extensions_by_name[field_name]
601 all_fields.append((extension, self.Extensions[extension]))
603 field_descriptor = fields_by_name[field_name]
605 (field_descriptor, getattr(self, field_name)))
606 all_fields.sort(key=lambda item: item[0].number)
609 def FindInitializationErrors(self):
610 return self._cmsg.FindInitializationErrors()
613 return str(self._cmsg)
615 def __eq__(self, other):
618 if not isinstance(other, self.__class__):
620 return self.ListFields() == other.ListFields()
622 def __ne__(self, other):
623 return not self == other
626 raise TypeError('unhashable object')
628 def __unicode__(self):
629 # Lazy import to prevent circular import when text_format imports this file.
630 from google.protobuf import text_format
631 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
633 # Attach the local methods to the message class.
634 for key, value in locals().copy().iteritems():
635 if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'):
636 setattr(cls, key, value)
640 def RegisterExtension(extension_handle):
641 extension_handle.containing_type = cls.DESCRIPTOR
642 cls._extensions_by_name[extension_handle.full_name] = extension_handle
644 if _IsMessageSetExtension(extension_handle):
645 # MessageSet extension. Also register under type name.
646 cls._extensions_by_name[
647 extension_handle.message_type.full_name] = extension_handle
648 cls.RegisterExtension = staticmethod(RegisterExtension)
650 def FromString(string):
652 msg.MergeFromString(string)
654 cls.FromString = staticmethod(FromString)
658 def _AddPropertiesForExtensions(message_descriptor, cls):
659 """Adds properties for all fields in this protocol message type."""
660 extension_dict = message_descriptor.extensions_by_name
661 for extension_name, extension_field in extension_dict.iteritems():
662 constant_name = extension_name.upper() + '_FIELD_NUMBER'
663 setattr(cls, constant_name, extension_field.number)