App Engine Python SDK version 1.9.12
[gae.git] / python / google / net / proto2 / python / internal / python_message.py
blob60cbcfaf7953d2fea8f1710f1808207c89523212
1 #!/usr/bin/env python
3 # Copyright 2007 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
26 """Contains a metaclass and helper functions used to create
27 protocol message classes from Descriptor objects at runtime.
29 Recall that a metaclass is the "type" of a class.
30 (A class is to a metaclass what an instance is to a class.)
32 In this case, we use the GeneratedProtocolMessageType metaclass
33 to inject all the useful functionality into the classes
34 output by the protocol compiler at compile-time.
36 The upshot of all this is that the real implementation
37 details for ALL pure-Python protocol buffers are *here in
38 this file*.
39 """
42 import sys
43 if sys.version_info[0] < 3:
44 try:
45 from cStringIO import StringIO as BytesIO
46 except ImportError:
47 from StringIO import StringIO as BytesIO
48 import copy_reg as copyreg
49 else:
50 from io import BytesIO
51 import copyreg
52 import struct
53 import weakref
56 from google.net.proto2.python.internal import containers
57 from google.net.proto2.python.internal import decoder
58 from google.net.proto2.python.internal import encoder
59 from google.net.proto2.python.internal import enum_type_wrapper
60 from google.net.proto2.python.internal import message_listener as message_listener_mod
61 from google.net.proto2.python.internal import type_checkers
62 from google.net.proto2.python.internal import wire_format
63 from google.net.proto2.python.public import descriptor as descriptor_mod
64 from google.net.proto2.python.public import message as message_mod
65 from google.net.proto2.python.public import text_format
67 _FieldDescriptor = descriptor_mod.FieldDescriptor
def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary of a message class that is being created.

  Nested extensions are added to `dictionary` as class attributes and a
  '__slots__' entry is installed, both before the class object exists.

  Args:
    bases: Base classes of the class being created.
    descriptor: Descriptor of the message type.
    dictionary: Class dictionary of the class being created; modified in place.

  Returns:
    `bases`, unchanged.
  """
  for prepare in (_AddClassAttributesForNestedExtensions, _AddSlots):
    prepare(descriptor, dictionary)
  return bases
def InitMessage(descriptor, cls):
  """Populates a freshly created message class with all of its machinery.

  Args:
    descriptor: Descriptor of the message type.
    cls: The message class to initialize; modified in place.
  """
  cls._decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = {}

  uses_message_set = (descriptor.has_options and
                      descriptor.GetOptions().message_set_wire_format)
  if uses_message_set:
    # MessageSet items get a dedicated decoder that dispatches on the
    # extensions registered by number.
    item_decoder = decoder.MessageSetItemDecoder(cls._extensions_by_number)
    cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = item_decoder

  # Encoders, sizers and decoders are attached to each field descriptor.
  for field in descriptor.fields:
    _AttachFieldHelpers(cls, field)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(descriptor, cls)

  def _Reduce(obj):
    # Pickle as "call cls() with no arguments, then restore __getstate__()".
    return (cls, (), obj.__getstate__())

  copyreg.pickle(cls, _Reduce)
def _PropertyName(proto_field_name):
  """Maps a proto field name to the name of its public Python property.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get (and in some cases set) the field.
    This is the proto field name itself, unchanged.
  """
  property_name = proto_field_name
  return property_name
def _VerifyExtensionHandle(message, extension_handle):
  """Checks that `extension_handle` is an extension of `message`'s type.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The object supplied by the caller as extension handle.

  Raises:
    KeyError: If the handle is not a field descriptor, is not an extension,
      has no containing type, or extends a different message type.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  # Descriptors are compared by identity, not by value.
  if extension_handle.containing_type is message.DESCRIPTOR:
    return

  names = (extension_handle.full_name,
           extension_handle.containing_type.full_name,
           message.DESCRIPTOR.full_name)
  raise KeyError('Extension "%s" extends message type "%s", but this '
                 'message is of type "%s".' % names)
def _AddSlots(message_descriptor, dictionary):
  """Installs the '__slots__' entry listing every instance attribute.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      Not consulted; the slot list is the same for every message type.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
def _IsMessageSetExtension(field):
  """Tells whether `field` is a MessageSet-style extension.

  Args:
    field: The field descriptor to inspect.

  Returns:
    A truthy value only when the field is an optional, message-typed
    extension of a type that uses message_set_wire_format and whose message
    type is also its extension scope.  The conditions are chained with `and`,
    so the first falsy operand is what gets returned.
  """
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoder, sizer, default constructor and decoders for a field.

  Args:
    cls: The message class whose decoder table is extended.
    field_descriptor: Descriptor of the field (or extension); the encoder,
      sizer and default-value constructor are stored on it.
  """
  number = field_descriptor.number
  field_type = field_descriptor.type
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(number)
    sizer = encoder.MessageSetItemSizer(number)
  else:
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_type]
    field_encoder = make_encoder(number, is_repeated, is_packed)
    make_sizer = type_checkers.TYPE_TO_SIZER[field_type]
    sizer = make_sizer(number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def _RegisterDecoder(wire_type, packed):
    # Decoders are looked up by the exact tag bytes found on the wire.
    tag_bytes = encoder.TagBytes(field_descriptor.number, wire_type)
    make_decoder = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
    cls._decoders_by_tag[tag_bytes] = make_decoder(
        field_descriptor.number, is_repeated, packed,
        field_descriptor, field_descriptor._default_constructor)

  _RegisterDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type], False)

  # A packed decoder is registered for every packable repeated field, whether
  # or not the field itself is declared packed.
  if is_repeated and wire_format.IsTypePackable(field_type):
    _RegisterDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes extensions declared inside a message as class attributes.

  Args:
    descriptor: Descriptor of the message type; its extensions_by_name mapping
      supplies the (name, extension field descriptor) pairs.
    dictionary: Class dictionary being built; receives one entry per nested
      extension, keyed by the extension's name.
  """
  # items() instead of the Python-2-only iteritems(): this module's imports
  # already branch for Python 3, so the body has to run there as well.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field
def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enums as class-level attributes.

  For every enum type, an EnumTypeWrapper is stored under the enum's name and
  each enum value is stored under its own name, mapped to its number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_type)
    setattr(cls, enum_type.name, wrapper)
    for value_descriptor in enum_type.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument, `message` (the message instance containing
    this field, or a weakref proxy of same), which returns a default value for
    the field.  For repeated and message-typed fields the value is wired to
    `message._listener_for_children`.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Captured once here and used by the closure below (previously the
      # local was assigned but the closure re-read field.message_type).
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class is only looked up when a default is actually needed.
      result = message_type._concrete_class()
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    return field.default_value
  return MakeScalarDefault
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added.

  Must be called from inside an `except` block.  A plain TypeError carrying a
  single argument is replaced by a TypeError whose message also names the
  field; any other exception is re-raised as is.  The original traceback is
  kept in both cases.

  Args:
    message_name: Name of the message type, used in the amended text.
    field_name: Name of the field, used in the amended text.
  """
  exc = sys.exc_info()[1]
  tb = sys.exc_info()[2]
  if len(exc.args) == 1 and type(exc) is TypeError:
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  if sys.version_info[0] < 3:
    # The three-argument raise only parses on Python 2, so it is hidden in a
    # constant string to keep this module importable on Python 3.
    exec('raise type(exc), exc, tb')
  else:
    raise exc.with_traceback(tb)
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts initial field values as keyword arguments:
  repeated fields are filled from the given iterable, message fields are
  merged from the given message, and scalar fields go through their property
  setter.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the __init__ method.
  """

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # With any constructor argument the cached size of 0 is already stale.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Maps a oneof descriptor to the field currently set inside that oneof.
    self._oneofs = {}
    # Starts as an empty tuple; swapped for a list once an unknown field is
    # actually stored.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    # items() instead of the Python-2-only iteritems().
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for unknown names and
      # never returns None, so this branch is currently unreachable.  It is
      # left in place so the exception type callers see does not change.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
          for val in field_value:
            copy.add().MergeFrom(val)
        else:
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        try:
          copy.MergeFrom(field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message has no field of that name.
  """
  by_name = message_descriptor.fields_by_name
  try:
    found = by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message has no "%s" field.' % field_name)
  return found
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally get an `Extensions` property, which
  builds an _ExtensionDict for the instance on every access.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  cls.Extensions = property(lambda self: _ExtensionDict(self))
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.

  Clients can use this property to get and (in the case of non-repeated
  scalar fields) directly set the value of a protocol message field.  A
  <FIELD_NAME>_FIELD_NUMBER class constant is added as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Sanity check against the set of C++ types this module knows about.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the field's container, creating it through
  the field's default constructor on first access.  Assigning to the property
  is rejected.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    fresh = field._default_constructor(self)
    # setdefault() keeps whatever is already stored under this field, so a
    # container that appeared in the meantime wins over the fresh one.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Repeated fields can only be changed through their container.
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Setting a value type-checks it, stores it, marks the message as
  modified and, for a oneof member, makes the field the active one.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # An unset field reads as its declared default.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # CheckValue() returns the value that is actually stored.
    self._fields[field] = type_checker.CheckValue(new_value)
    # _Modified() is skipped when the message is already flagged dirty.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof is not None:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.  Reading the property
  returns the sub-message, creating it on first access; assigning to the
  property is rejected.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  message_type = field.message_type

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message

    sub_message = message_type._concrete_class()
    # Members of a oneof report changes through a listener that knows which
    # field they belong to; other fields share the parent's child listener.
    if field.containing_oneof is not None:
      listener = _OneofListener(self, field)
    else:
      listener = self._listener_for_children
    sub_message._SetListener(listener)

    # setdefault() keeps an instance that was stored in the meantime.
    return self._fields.setdefault(field, sub_message)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for every nested extension.

  Args:
    descriptor: Descriptor of the message type; its extensions_by_name
      mapping lists the extensions declared inside the message.
    cls: The class we're constructing.
  """
  # items() instead of the Python-2-only iteritems().
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls."""

  def RegisterExtension(extension_handle):
    """Registers an extension field descriptor as extending this message."""
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # setdefault() leaves an earlier registration for this number in place,
    # so getting a different descriptor back means a field-number clash.
    registered = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if registered is not extension_handle:
      details = (extension_handle.full_name, registered.full_name,
                 cls.DESCRIPTOR.full_name, extension_handle.number)
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' % details)

    by_name = cls._extensions_by_name
    by_name[extension_handle.full_name] = extension_handle

    # MessageSet extensions are also reachable by their message type's name.
    if _IsMessageSetExtension(extension_handle):
      by_name[extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Builds a new message of this class by parsing `s`."""
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from a message's _fields.

  Returns:
    For repeated fields, whether the container is non-empty; for message
    fields, the sub-message's _is_present_in_parent flag; True otherwise.
  """
  descriptor = item[0]
  if descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the method.
  """

  def ListFields(self):
    """Returns the present (descriptor, value) pairs ordered by field number."""
    # items() instead of the Python-2-only iteritems(); a list is built
    # either way.
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the method.
  """
  # Every name HasField() accepts: non-repeated fields first, then oneofs.
  singular_fields = dict(
      (field.name, field) for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED)
  singular_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    """Tells whether the named singular field (or oneof) is set."""
    try:
      target = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # A oneof is set exactly when its currently recorded member is set.
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return target in self._fields

    sub_message = self._fields.get(target)
    return sub_message is not None and sub_message._is_present_in_parent

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the method.
  """
  def ClearField(self, field_name):
    """Clears the named field, or the active member of the named oneof."""
    fields_by_name = message_descriptor.fields_by_name
    if field_name in fields_by_name:
      target = fields_by_name[field_name]
    else:
      oneofs_by_name = message_descriptor.oneofs_by_name
      if field_name not in oneofs_by_name:
        raise ValueError('Protocol message has no "%s" field.' % field_name)
      oneof = oneofs_by_name[field_name]
      if oneof not in self._oneofs:
        # Nothing is set in this oneof; the message is left untouched.
        return
      target = self._oneofs[oneof]

    if target in self._fields:
      del self._fields[target]
      # Forget the oneof bookkeeping if this field was its active member.
      if self._oneofs.get(target.containing_oneof, None) is target:
        del self._oneofs[target.containing_oneof]

    # Reported as a modification even when nothing was stored for the field.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""
  def ClearExtension(self, extension_handle):
    """Removes the extension's value, if any, and marks the message modified."""
    _VerifyExtensionHandle(self, extension_handle)

    # pop() with a default is a no-op when the extension was never set.
    self._fields.pop(extension_handle, None)
    self._Modified()
  cls.ClearExtension = ClearExtension
def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs Clear() on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the method.
  """
  def Clear(self):
    """Drops all field values, unknown fields and oneof bookkeeping."""
    self._fields = {}
    self._unknown_fields = ()
    # The oneof map must be reset together with _fields.  A stale entry would
    # name a field that is no longer stored, and the next assignment to a
    # sibling member of that oneof would fail while trying to delete it.
    self._oneofs = {}
    self._Modified()
  cls.Clear = Clear
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""
  def HasExtension(self, extension_handle):
    """Tells whether the given singular extension is set on this message."""
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields

    # A stored sub-message only counts once it is marked present.
    sub_message = self._fields.get(extension_handle)
    return sub_message is not None and sub_message._is_present_in_parent
  cls.HasExtension = HasExtension
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the method.
  """
  def __eq__(self, other):
    same_kind = (isinstance(other, message_mod.Message) and
                 other.DESCRIPTOR == self.DESCRIPTOR)
    if not same_kind:
      return False

    if self is other:
      return True

    if not self.ListFields() == other.ListFields():
      return False

    # Unknown fields are compared without regard to their order.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the method.
  """
  def __str__(self):
    # str(message) is the message rendered by text_format.
    rendered = text_format.MessageToString(self)
    return rendered
  cls.__str__ = __str__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls.

  Args:
    unused_message_descriptor: Descriptor of the message type (not used).
    cls: The message class receiving the method.
  """

  def __unicode__(self):
    # Render as UTF-8 text, then decode the result.
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')
  cls.__unicode__ = __unicode__
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): installs _SetListener() on cls."""
  def SetListener(self, listener):
    """Stores `listener`; None is replaced by a NullMessageListener."""
    self._listener = (message_listener_mod.NullMessageListener()
                      if listener is None else listener)
  cls._SetListener = SetListener
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any other
  additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number is
      stored as part of a varint-encoded tag, this has an impact on the total
      bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants within
      FieldDescriptor.

  Raises:
    EncodeError: If a KeyError occurs, either because `field_type` has no
      byte-size function or from inside that function.
  """
  try:
    size_fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
    return size_fn(field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the method.
  """

  def ByteSize(self):
    """Returns the serialized size in bytes, using the cache when it is clean."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    known_size = sum(field_descriptor._sizer(field_value)
                     for field_descriptor, field_value in self.ListFields())
    # Unknown fields are kept as (tag bytes, value bytes) pairs.
    unknown_size = sum(len(tag_bytes) + len(value_bytes)
                       for tag_bytes, value_bytes in self._unknown_fields)
    total = known_size + unknown_size

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the method.
  """

  def SerializeToString(self):
    """Serializes the message, insisting that it is fully initialized.

    Raises:
      EncodeError: If IsInitialized() reports missing required fields; the
        text lists the paths from FindInitializationErrors().
    """
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (
              self.DESCRIPTOR.full_name,
              ','.join(self.FindInitializationErrors())))
    return self.SerializePartialToString()
  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs the serialization methods.

  Adds SerializePartialToString() and the _InternalSerialize() it relies on.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the methods.
  """

  def SerializePartialToString(self):
    """Serializes the message without checking required fields."""
    buf = BytesIO()
    self._InternalSerialize(buf.write)
    return buf.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes):
    """Feeds the encoded fields, then the unknown fields, to write_bytes."""
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value)
    # Unknown fields are written back as their stored tag and value bytes.
    for unknown_field in self._unknown_fields:
      tag_bytes, value_bytes = unknown_field
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs the parsing methods.

  Adds MergeFromString() and the _InternalParse() it relies on.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving the methods.
  """
  def MergeFromString(self, serialized):
    """Parses `serialized` into this message and returns its length.

    Raises:
      DecodeError: If parsing stops before the end of the input, or an
        IndexError, TypeError or struct.error occurs while parsing.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # _InternalParse only stops early when SkipField() returns -1.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      # "except X as e" parses on Python 2.6+ and 3; the comma form is
      # Python-2-only.
      raise message_mod.DecodeError(e)
    return length
  cls.MergeFromString = MergeFromString

  # Bound to locals once so the parse loop avoids repeated attribute lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] into this message; returns the end position."""
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder = decoders_by_tag.get(tag_bytes)
      if field_decoder is None:
        # No decoder for this tag: skip the value and keep its raw bytes.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        if not unknown_field_list:
          # The empty-tuple placeholder becomes a real list on first use.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
    return pos
  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the methods.
  """

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # _fields is read directly here instead of going through HasField() and
    # ListFields().
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # A snapshot is iterated rather than the live dict.
    for field, value in list(self._fields.items()):
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = []

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if field.label == _FieldDescriptor.LABEL_REPEATED:
          # enumerate() replaces the Python-2-only xrange() index loop.
          for i, element in enumerate(value):
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs MergeFrom() on cls.

  Args:
    cls: The message class receiving the method.
  """
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of `msg` into this message.

    Repeated fields are extended, present sub-messages are merged
    recursively, scalar values overwrite, and unknown fields are appended.

    Args:
      msg: A message of the same class as this one (and not this very object).

    Raises:
      TypeError: If `msg` is not an instance of this message class.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    # items() instead of the Python-2-only iteritems().
    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # First value for this field here: start from an empty container.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
          # A merged oneof member becomes the active one, displacing any
          # sibling that was set before.
          if field.containing_oneof is not None:
            self._UpdateOneofState(field)
      else:
        fields[field] = value
        # Same oneof bookkeeping the scalar property setter performs.
        if field.containing_oneof is not None:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        # Swap the empty-tuple placeholder for a list before extending.
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs WhichOneof() on cls.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the method.
  """
  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    # A recorded member only counts while HasField() still reports it as set.
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class receiving the methods.
  """
  extendable = message_descriptor.is_extendable

  # Field inspection and clearing.
  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(message_descriptor, cls)
  _AddClearFieldMethod(message_descriptor, cls)
  if extendable:
    # Extension accessors only exist on extendable message types.
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  _AddClearMethod(message_descriptor, cls)

  # Comparison, text rendering and listener plumbing.
  _AddEqualsMethod(message_descriptor, cls)
  _AddStrMethod(message_descriptor, cls)
  _AddUnicodeMethod(message_descriptor, cls)
  _AddSetListenerMethod(cls)

  # Size, serialization, parsing, validation and merging.
  _AddByteSizeMethod(message_descriptor, cls)
  _AddSerializeToStringMethod(message_descriptor, cls)
  _AddSerializePartialToStringMethod(message_descriptor, cls)
  _AddMergeFromStringMethod(message_descriptor, cls)
  _AddIsInitializedMethod(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Args:
    message_descriptor: Descriptor of the message type (not used here).
    cls: The message class receiving _Modified, SetInParent and
      _UpdateOneofState.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """
    if self._cached_byte_size_dirty:
      # Already dirty: nothing new to report.
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument.  Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    active = self._oneofs.setdefault(oneof, field)
    if active is field:
      return
    del self._fields[active]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  # SetInParent is the same function under its public name.
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # Only a weak proxy to the parent is kept. An argument that is already a
    # weakref proxy is stored as-is instead of being wrapped a second time.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # While this flag is true, Modified() does nothing.
    self.dirty = False

  def Modified(self):
    """Forwards the notification to the parent unless already dirty."""
    if not self.dirty:
      try:
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # The weak proxy raises ReferenceError once the parent has been
        # garbage collected; there is nobody left to notify.
        pass
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    self._field = field
    super(_OneofListener, self).__init__(parent_message)

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make our field the active member of its oneof first, then run the
      # regular notification of the base class.
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # Raised by the weak proxy when the parent message no longer exists.
      pass
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """
    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    Repeated and message-typed extensions that have no stored value yet get
    a freshly created value stored in the message's _fields; an unset scalar
    extension returns its default value and stores nothing.
    """

    # The handle is checked by _VerifyExtensionHandle (defined elsewhere in
    # this file) before it is used as a key.
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        # Deliberately ignored. NOTE(review): presumably raised through a
        # weak proxy whose referent is already gone -- confirm.
        pass
    else:
      # Singular scalar: return the default without touching _fields.
      return extension_handle.default_value

    # setdefault keeps a value that is already stored under this handle and
    # returns it; otherwise it stores and returns the one created above.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Returns True iff both messages have equal set extension fields.

    Only instances of the same class are comparable; anything else compares
    unequal.
    """
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Get rid of non-extension fields. ListFields() yields
    # (field_descriptor, value) pairs, so is_extension has to be read from
    # the descriptor inside each pair, not from the pair itself.
    my_fields = [(field, value) for field, value in my_fields
                 if field.is_extension]
    other_fields = [(field, value) for field, value in other_fields
                    if field.is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    """Returns the negation of __eq__."""
    return not self == other

  def __hash__(self):
    """Always raises TypeError: instances are not hashable."""
    raise TypeError('unhashable object')

  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: If the extension is repeated or message-typed.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # The type checker is looked up on every assignment; whatever
    # CheckValue() returns is what gets stored.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle)

    self._extended_message._fields[extension_handle] = (
        type_checker.CheckValue(value))
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the name is not known.
    """
    return self._extended_message._extensions_by_name.get(name, None)