App Engine Python SDK version 1.9.3
[gae.git] / python / google / net / proto2 / python / internal / python_message.py
blob 45f011023a4b8fe0ed4b7164e9bebb347d69c96f
1 #!/usr/bin/env python
3 # Copyright 2007 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
26 """Contains a metaclass and helper functions used to create
27 protocol message classes from Descriptor objects at runtime.
29 Recall that a metaclass is the "type" of a class.
30 (A class is to a metaclass what an instance is to a class.)
32 In this case, we use the GeneratedProtocolMessageType metaclass
33 to inject all the useful functionality into the classes
34 output by the protocol compiler at compile-time.
36 The upshot of all this is that the real implementation
37 details for ALL pure-Python protocol buffers are *here in
38 this file*.
39 """
42 import sys
43 if sys.version_info[0] < 3:
44 try:
45 from cStringIO import StringIO as BytesIO
46 except ImportError:
47 from StringIO import StringIO as BytesIO
48 import copy_reg as copyreg
49 else:
50 from io import BytesIO
51 import copyreg
52 import struct
53 import weakref
56 from google.net.proto2.python.internal import containers
57 from google.net.proto2.python.internal import decoder
58 from google.net.proto2.python.internal import encoder
59 from google.net.proto2.python.internal import enum_type_wrapper
60 from google.net.proto2.python.internal import message_listener as message_listener_mod
61 from google.net.proto2.python.internal import type_checkers
62 from google.net.proto2.python.internal import wire_format
63 from google.net.proto2.python.public import descriptor as descriptor_mod
64 from google.net.proto2.python.public import message as message_mod
65 from google.net.proto2.python.public import text_format
# Short module-level alias for descriptor_mod.FieldDescriptor; its LABEL_*,
# TYPE_* and CPPTYPE_* constants are referenced throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary of a message class before it is created.

  Adds one class attribute per extension nested in the message type and
  installs the __slots__ list used by message instances.

  Args:
    bases: Passed through untouched and returned.
    descriptor: Descriptor object for this message type.
    dictionary: Class dictionary, modified in place.

  Returns:
    `bases`, unchanged.
  """
  _AddClassAttributesForNestedExtensions(descriptor, dictionary)
  _AddSlots(descriptor, dictionary)
  return bases
def InitMessage(descriptor, cls):
  """Populates a freshly created message class with its runtime machinery.

  Creates the per-class decoder and extension registries, attaches
  encoders/sizers/decoders to every field, installs enum constants, __init__,
  field properties, extension constants, static methods, message methods and
  private helpers, and finally registers a pickling reducer for the class.

  Args:
    descriptor: Descriptor object for this message type.
    cls: The message class being initialized; modified in place.
  """
  cls._decoders_by_tag = decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = extensions_by_number = {}

  uses_message_set_wire_format = (
      descriptor.has_options and
      descriptor.GetOptions().message_set_wire_format)
  if uses_message_set_wire_format:
    # The MessageSet item decoder is handed this class's
    # extensions-by-number registry.
    decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
        decoder.MessageSetItemDecoder(extensions_by_number))

  for field in descriptor.fields:
    _AttachFieldHelpers(cls, field)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(cls)

  def _Reduce(obj):
    # Reducer contract: rebuild via cls() and restore obj.__getstate__().
    return (cls, (), obj.__getstate__())
  copyreg.pickle(cls, _Reduce)
def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute for a field.

  Clients use this attribute to get and (in some cases) set the value of a
  protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly as it appears
      (or would appear) in a .proto file.

  Returns:
    The property name. The field name is used verbatim: no escaping or
    renaming is applied.
  """
  # Identity mapping on purpose: getattr()/setattr() with the .proto field
  # name keeps working.
  return proto_field_name
def _VerifyExtensionHandle(message, extension_handle):
  """Verifies that extension_handle is a valid extension of message's type.

  Args:
    message: The message instance the extension is used with.
    extension_handle: The object the caller passed as an extension handle.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other than
      message.DESCRIPTOR.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    # Format a 1-tuple so that a tuple-valued handle is printed as a whole
    # rather than being unpacked by the % operator (which would raise
    # TypeError instead of the intended KeyError).
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   (extension_handle,))

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, naming every instance attribute.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      It is not consulted; every message type gets the same slots.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = (
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      # Slotted classes need this to be weakly referenceable; _Listener
      # holds weakref proxies to messages.
      '__weakref__',
  )
  dictionary['__slots__'] = list(slot_names)
def _IsMessageSetExtension(field):
  """Returns True if field is an extension that uses MessageSet item encoding.

  That is: an optional, message-typed extension whose extension scope is its
  own message type, extending a type with message_set_wire_format enabled.
  """
  if not field.is_extension:
    return False
  # containing_type is only inspected once we know field is an extension.
  containing_type = field.containing_type
  if not (containing_type.has_options and
          containing_type.GetOptions().message_set_wire_format):
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches serialization helpers to a field and registers its decoders.

  Sets _encoder, _sizer and _default_constructor on field_descriptor. Adds a
  decoder to cls._decoders_by_tag for the field's natural wire type and, for
  repeated fields of a packable type, a second decoder for the
  length-delimited (packed) form.

  Args:
    cls: The message class whose decoder table is extended.
    field_descriptor: FieldDescriptor of a regular field or an extension.
  """
  number = field_descriptor.number
  field_type = field_descriptor.type
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(number)
    sizer = encoder.MessageSetItemSizer(number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_type](
        number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  default_constructor = _DefaultValueConstructorForField(field_descriptor)
  field_descriptor._default_constructor = default_constructor

  # (wire type, packed) pairs to register a decoder for. Packable repeated
  # fields get both tags so either encoding can be parsed.
  wire_forms = [(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type], False)]
  if is_repeated and wire_format.IsTypePackable(field_type):
    wire_forms.append((wire_format.WIRETYPE_LENGTH_DELIMITED, True))

  make_decoder = type_checkers.TYPE_TO_DECODER[field_type]
  for wiretype, packed in wire_forms:
    tag_bytes = encoder.TagBytes(number, wiretype)
    cls._decoders_by_tag[tag_bytes] = make_decoder(
        number, is_repeated, packed, field_descriptor, default_constructor)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension nested in descriptor as a class attribute.

  Args:
    descriptor: Descriptor object for this message type.
    dictionary: Class dictionary; receives one entry per nested extension,
      mapping the extension's name to its field descriptor.
  """
  # items() works on Python 2 and 3; iteritems() is Python 2 only, while this
  # module also carries a Python 3 import path.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    # An extension name must not clobber an existing class attribute.
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field
def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for the enums nested in this message type.

  For every nested enum type, cls receives an EnumTypeWrapper (an object
  that can name enum values) under the enum type's name, plus one attribute
  per enum value, named after the value and holding its number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_type)
    setattr(cls, enum_type.name, wrapper)
    for value_descriptor in enum_type.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field: an empty
  repeated container, a new sub-message wired to |message|'s child listener,
  or the field's scalar default_value.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      # Wrap in a 1-tuple: a tuple-valued default would otherwise be unpacked
      # by the % operator and raise TypeError instead of ValueError.
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value,))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      # Modifications of the sub-message are reported to the parent.
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # `message` is unused; scalars share the descriptor's default value.
    return field.default_value
  return MakeScalarDefault
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets all per-instance state and accepts one
  keyword argument per field: repeated fields take an iterable of elements
  (message elements are merged into new entries), singular message fields
  take a message to merge from, and scalar fields are assigned through their
  property setter (and therefore type-checked).

  An unknown keyword raises ValueError (from _GetFieldByName).
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any initial field value makes the cached size of 0 stale.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    # items() works on Python 2 and 3; iteritems() is Python 2 only.
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for unknown names, so `field` is
      # never None here.
      field = _GetFieldByName(message_descriptor, field_name)
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
          for val in field_value:
            copy.add().MergeFrom(val)
        else:
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        copy.MergeFrom(field_value)
        self._fields[field] = copy
      else:
        setattr(self, field_name, field_value)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message type has no field with that name.
  """
  by_name = message_descriptor.fields_by_name
  if field_name not in by_name:
    raise ValueError('Protocol message has no "%s" field.' % field_name)
  return by_name[field_name]
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type, plus an
  Extensions property when the message type is extendable."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if not descriptor.is_extendable:
    return

  def GetExtensions(self):
    # A new _ExtensionDict wrapping this message is built on every access.
    return _ExtensionDict(self)
  cls.Extensions = property(GetExtensions)
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.

  Clients can use this property to get and (in the case of non-repeated
  scalar fields) directly set the value of a protocol message field. A
  <NAME>_FIELD_NUMBER class constant holding the field number is added too.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Trips if FieldDescriptor gains new C++ types; the dispatch below only
  # distinguishes message from non-message types.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the field's container (built on first access
  by field._default_constructor: a RepeatedScalarFieldContainer or a
  RepeatedCompositeFieldContainer). Assigning to the property raises
  AttributeError; clients modify the container instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Lazily create the container; setdefault never replaces a value that is
    # already stored, and returns whichever one ends up in _fields.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  setattr(cls, _PropertyName(proto_field_name),
          property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field. Setting the value stores it in _fields after type-checking and
  marks the message as modified (which also sets the "has" state).

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # Unset fields are absent from _fields and read as the declared default.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    self._fields[field] = type_checker.CheckValue(new_value)
    # _Modified() is a no-op while the dirty bit is set, so checking it here
    # only skips a redundant call on this frequently used path.
    if not self._cached_byte_size_dirty:
      self._Modified()

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field. Clients can read the
  property (an empty sub-message is created on first access) but cannot
  assign to it directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  message_type = field.message_type

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message
    # First access: build an empty sub-message that reports its modifications
    # to this message, and keep whichever instance ends up stored.
    sub_message = message_type._concrete_class()
    sub_message._SetListener(self._listener_for_children)
    return self._fields.setdefault(field, sub_message)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, _PropertyName(proto_field_name),
          property(getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each nested extension.

  Args:
    descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  # items() works on Python 2 and 3; iteritems() is Python 2 only.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)
def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""

  def RegisterExtension(extension_handle):
    """Registers extension_handle as an extension of cls.

    Raises:
      AssertionError: If a different extension is already registered under
        the same field number. In that case neither the handle nor cls is
        modified.
    """
    number = extension_handle.number
    # Check for a conflict BEFORE mutating anything: attaching helpers would
    # otherwise overwrite the existing extension's decoder in
    # cls._decoders_by_tag and rebind the handle's containing_type even
    # though registration then fails.
    registered = cls._extensions_by_number.get(number)
    if registered is not None and registered is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, registered.full_name,
           cls.DESCRIPTOR.full_name, number))

    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)
    cls._extensions_by_number[number] = extension_handle

    cls._extensions_by_name[extension_handle.full_name] = extension_handle
    if _IsMessageSetExtension(extension_handle):
      # MessageSet extensions are also registered under their message
      # type's full name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Returns a new instance of cls parsed from the serialized string s."""
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, returns true if the
  value should be included in the list returned by ListFields().

  Repeated fields count when non-empty, sub-messages when marked present in
  their parent, and any stored scalar value always counts.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ListFields() to cls.

  ListFields() returns the (FieldDescriptor, value) pairs of all present
  fields (see _IsPresent), sorted by field number.
  """

  def ListFields(self):
    # items() works on Python 2 and 3; iteritems() is Python 2 only.
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds HasField() to cls.

  HasField() only accepts names of non-repeated fields and raises ValueError
  otherwise. A sub-message field counts as set only when it is marked
  present in this message; a scalar field counts when it is stored.
  """
  singular_fields = dict(
      (field.name, field) for field in message_descriptor.fields
      if field.label != _FieldDescriptor.LABEL_REPEATED)

  def HasField(self, field_name):
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    sub_message = self._fields.get(field)
    return sub_message is not None and sub_message._is_present_in_parent

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ClearField() to cls.

  ClearField() raises ValueError for unknown field names. Otherwise it
  removes any stored value and always marks the message as modified, even
  if the field was not set.
  """
  def ClearField(self, field_name):
    field = message_descriptor.fields_by_name.get(field_name)
    if field is None:
      raise ValueError('Protocol message has no "%s" field.' % field_name)

    # A removed sub-message may keep its listener pointing here; at worst
    # that later invalidates this message's cached size.
    self._fields.pop(field, None)

    # Clearing is a mutation, so the message is marked modified regardless.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): adds ClearExtension() to cls.

  After validating the handle, ClearExtension() removes any stored value and
  always marks the message as modified, mirroring ClearField().
  """
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds Clear() to cls.

  Clear() drops every stored field and unknown field, then marks the message
  as modified.
  """
  def Clear(self):
    self._fields, self._unknown_fields = {}, ()
    self._Modified()

  cls.Clear = Clear
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): adds HasExtension() to cls.

  HasExtension() raises KeyError for invalid or repeated extension handles.
  A message-typed extension counts as set only when marked present in this
  message; a scalar extension counts when it is stored.
  """
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    stored = self._fields.get(extension_handle)
    return stored is not None and stored._is_present_in_parent

  cls.HasExtension = HasExtension
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds __eq__ to cls.

  Two messages are equal when `other` is a Message with the same DESCRIPTOR,
  both have equal ListFields(), and both hold the same unknown fields
  irrespective of their order.
  """
  def __eq__(self, other):
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False
    if self is other:
      return True
    if self.ListFields() != other.ListFields():
      return False
    # Unknown fields are compared order-insensitively.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): makes str() render the text format."""
  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered
  cls.__str__ = __str__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds __unicode__ to cls.

  The text format is rendered as UTF-8 bytes and then decoded.
  """
  def __unicode__(self):
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')
  cls.__unicode__ = __unicode__
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): adds _SetListener() to cls.

  Passing None installs a NullMessageListener instead.
  """
  def SetListener(self, listener):
    if listener is None:
      listener = message_listener_mod.NullMessageListener()
    self._listener = listener
  cls._SetListener = SetListener
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any other
  additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number is
      stored as part of a varint-encoded tag, this has an impact on the
      total bytes required to serialize the value.)
    field_type: The type of the field. One of the TYPE_* constants within
      FieldDescriptor.

  Raises:
    message_mod.EncodeError: If field_type has no registered size function.
  """
  # Only the table lookup is guarded: a KeyError raised inside the size
  # function itself must not be misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ByteSize() to cls.

  ByteSize() returns the serialized size in bytes. It recomputes the size
  only while _cached_byte_size_dirty is set, then caches the result and
  clears the dirty flags on the message and on its child listener.
  """

  def ByteSize(self):
    if self._cached_byte_size_dirty:
      total = sum(descriptor._sizer(value)
                  for descriptor, value in self.ListFields())
      # Unknown fields are written back verbatim, tag bytes included.
      total += sum(len(tag) + len(payload)
                   for tag, payload in self._unknown_fields)
      self._cached_byte_size = total
      self._cached_byte_size_dirty = False
      self._listener_for_children.dirty = False
    return self._cached_byte_size

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds SerializeToString() to cls.

  SerializeToString() raises message_mod.EncodeError, listing the missing
  field paths, when a required field is unset; otherwise it delegates to
  SerializePartialToString().
  """

  def SerializeToString(self):
    # Refuse to serialize a message whose required fields are not all set.
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (
              self.DESCRIPTOR.full_name,
              ','.join(self.FindInitializationErrors())))
    return self.SerializePartialToString()
  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds SerializePartialToString() and
  _InternalSerialize() to cls. Neither checks required fields."""

  def InternalSerialize(self, write_bytes):
    # Known fields first, in ListFields() order; then unknown fields verbatim.
    for descriptor, value in self.ListFields():
      descriptor._encoder(write_bytes, value)
    for tag, payload in self._unknown_fields:
      write_bytes(tag)
      write_bytes(payload)

  def SerializePartialToString(self):
    buf = BytesIO()
    self._InternalSerialize(buf.write)
    return buf.getvalue()

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds MergeFromString() and
  _InternalParse() to cls."""

  def MergeFromString(self, serialized):
    """Merges the serialized message into self; returns len(serialized).

    Raises:
      message_mod.DecodeError: On an unexpected end-group tag, a truncated
        buffer, or a struct unpacking error.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # _InternalParse only stops early when it meets an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Reading past the end of the buffer surfaces as one of these.
      raise message_mod.DecodeError('Truncated message.')
    # `except X as e` works on Python 2.6+ and 3; the old `except X, e` form
    # is a SyntaxError on Python 3.
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length
  cls.MergeFromString = MergeFromString

  # Bound once here so the parse loop uses local lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] into self; returns the position parsing stopped.

    Tags without a registered decoder are skipped and kept verbatim in
    _unknown_fields. A return value other than `end` means SkipField
    reported -1 (the position of that tag is returned).
    """
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder = decoders_by_tag.get(tag_bytes)
      if field_decoder is None:
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        if not unknown_field_list:
          # _unknown_fields starts as an empty tuple; switch to a list.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
    return pos
  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # This message's own required fields: a sub-message only counts as set
    # when it is marked present in this message.
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Recurse into every stored sub-message, including optional and repeated
    # ones, since they may have required fields of their own.
    for field, value in list(self._fields.items()):
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = []

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if field.label == _FieldDescriptor.LABEL_REPEATED:
          # enumerate() replaces the Python 2-only xrange(len(...)) indexing.
          for i, element in enumerate(value):
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): adds MergeFrom() to cls.

  MergeFrom(msg) merges msg into self: repeated fields are merged into
  self's containers, sub-messages marked present in msg are merged
  recursively, scalar values from msg overwrite self's, and msg's unknown
  fields are appended. It raises TypeError if msg is not an instance of cls.
  """
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    # items() works on Python 2 and 3; iteritems() is Python 2 only.
    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new container to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages not marked present in msg are skipped.
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new sub-message to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        fields[field] = value

    if msg._unknown_fields:
      if not self._unknown_fields:
        # _unknown_fields may be an immutable empty tuple; switch to a list.
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The helpers run in a fixed order; the extension helpers only run for
  extendable message types.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddClearMethod,
                     _AddEqualsMethod,
                     _AddStrMethod,
                     _AddUnicodeMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified(), and SetInParent() as an alias of it.
  """

  def Modified(self):
    """Marks the message dirty and present in its parent, notifying the
    listener, but only on the transition from clean to dirty."""
    # Some callers (e.g. scalar setters) skip calling this when the dirty bit
    # is already set, relying on it being a no-op in that state.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  cls._Modified = Modified
  cls.SetInParent = Modified
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents. This helper
  provides that back reference through a weak proxy.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # Only a weak proxy is kept so children do not keep their parent alive;
    # an existing proxy is reused instead of being wrapped again.
    if isinstance(parent_message, weakref.ProxyType):
      parent_proxy = parent_message
    else:
      parent_proxy = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_proxy

    # While True, Modified() does not notify the parent. The parent's
    # _Modified() sets this flag and its ByteSize() clears it.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # The parent was garbage-collected while a child is still in use;
        # there is nobody left to notify.
        pass
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """
    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    Repeated and message-typed extensions are created on first access and
    stored in the message; unset scalar extensions return the descriptor's
    default_value without storing anything.
    """
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        pass
    else:
      # Singular scalar: return the default without storing it.
      return extension_handle.default_value

    # setdefault never replaces an already-stored value, so the caller gets
    # whichever instance actually lives in _fields.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # ListFields() yields (FieldDescriptor, value) pairs, so the descriptor
    # is item[0]. Filtering on item.is_extension raised AttributeError as
    # soon as any field was set.
    my_fields = [item for item in my_fields if item[0].is_extension]
    other_fields = [item for item in other_fields if item[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: If the extension is repeated or message-typed.
    """
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # The type checker is looked up per assignment rather than cached.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle)

    self._extended_message._fields[extension_handle] = (
        type_checker.CheckValue(value))
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if no extension has that name.
    """
    return self._extended_message._extensions_by_name.get(name, None)