3 # Copyright 2007 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""
43 if sys
.version_info
[0] < 3:
45 from cStringIO
import StringIO
as BytesIO
47 from StringIO
import StringIO
as BytesIO
48 import copy_reg
as copyreg
50 from io
import BytesIO
56 from google
.net
.proto2
.python
.internal
import containers
57 from google
.net
.proto2
.python
.internal
import decoder
58 from google
.net
.proto2
.python
.internal
import encoder
59 from google
.net
.proto2
.python
.internal
import enum_type_wrapper
60 from google
.net
.proto2
.python
.internal
import message_listener
as message_listener_mod
61 from google
.net
.proto2
.python
.internal
import type_checkers
62 from google
.net
.proto2
.python
.internal
import wire_format
63 from google
.net
.proto2
.python
.public
import descriptor
as descriptor_mod
64 from google
.net
.proto2
.python
.public
import message
as message_mod
65 from google
.net
.proto2
.python
.public
import text_format
# Short alias for descriptor_mod.FieldDescriptor. This module uses it for
# isinstance() checks and for its LABEL_*, CPPTYPE_* and TYPE_* constants.
_FieldDescriptor = descriptor_mod.FieldDescriptor
def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary of a message class before it is created.

  Adds class attributes for every extension nested in |descriptor| and a
  '__slots__' entry. Both modify |dictionary| in place.

  Args:
    bases: Base classes of the message class being created.
    descriptor: Descriptor of the message type.
    dictionary: Class dictionary of the message class being created.

  Returns:
    The base classes to create the class with (|bases|, unchanged). Returning
    them keeps callers that use the result working; callers that ignore it
    are unaffected.
  """
  _AddClassAttributesForNestedExtensions(descriptor, dictionary)
  _AddSlots(descriptor, dictionary)
  return bases
def InitMessage(descriptor, cls):
  """Populates a freshly created message class with its runtime machinery.

  Steps:
  - Creates the per-class decoder and extension registries.
  - Attaches serialization helpers to every field descriptor.
  - Installs enum constants, __init__, field/extension properties, static
    methods, Message methods and private helpers on |cls|.
  - Registers a pickle reducer for the class.

  Args:
    descriptor: Descriptor of the message type.
    cls: The message class to populate.
  """
  cls._decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = {}

  uses_message_set_format = (
      descriptor.has_options and
      descriptor.GetOptions().message_set_wire_format)
  if uses_message_set_format:
    # MessageSet items get a dedicated decoder that resolves extensions by
    # field number.
    item_decoder = decoder.MessageSetItemDecoder(cls._extensions_by_number)
    cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = item_decoder

  # Attach encoder/sizer/decoder helpers to each field descriptor.
  for field_descriptor in descriptor.fields:
    _AttachFieldHelpers(cls, field_descriptor)

  # Install the generated API on the class.
  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(cls)

  def _ReduceMessage(obj):
    # Unpickling calls cls() with no arguments, then restores the state
    # returned by __getstate__().
    return (cls, (), obj.__getstate__())
  copyreg.pickle(cls, _ReduceMessage)
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the name of its public property.

  Clients read (and, for some field kinds, assign) a field's value through
  this property. The property name is currently identical to the field name.

  Args:
    proto_field_name: The field name exactly as it appears (or would appear)
      in a .proto file.

  Returns:
    The attribute name under which the field's property is installed.
  """
  property_name = proto_field_name
  return property_name
def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid for |message|.

  Args:
    message: The message instance the extension is being accessed on.
    extension_handle: The extension's FieldDescriptor.

  Raises:
    KeyError: If |extension_handle| is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other than
      the type of |message|.
  """

  if not isinstance(extension_handle, _FieldDescriptor):
    # Format through a 1-tuple so a tuple passed as the handle is reported
    # instead of breaking the % operator with a TypeError.
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   (extension_handle,))

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  instance attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      Currently unused: every message type has the same slots.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             # Required so _Listener can hold a
                             # weakref.proxy() to the message.
                             '__weakref__']
def _IsMessageSetExtension(field):
  """Tells whether |field| is an extension carried inside a MessageSet.

  That is the case for an optional, message-typed extension whose message
  type is also its extension scope, extending a message type that has the
  message_set_wire_format option set.
  """
  if not field.is_extension:
    return False
  extended_type = field.containing_type
  if not (extended_type.has_options and
          extended_type.GetOptions().message_set_wire_format):
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches serialization helpers to a field and registers its decoders.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor. Adds
  decoders to cls._decoders_by_tag: one for the field's normal wire type and,
  for packable repeated fields, one more for the packed encoding.

  Args:
    cls: The message class being built.
    field_descriptor: FieldDescriptor of a field or extension of cls.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    # Register a decoder keyed by the tag bytes for (field number, wiretype).
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    cls._decoders_by_tag[tag_bytes] = (
        type_checkers.TYPE_TO_DECODER[field_descriptor.type](
            field_descriptor.number, is_repeated, is_packed,
            field_descriptor, field_descriptor._default_constructor))

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # Also accept the packed encoding, regardless of the field's own
    # 'packed' option.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside |descriptor| as a class attribute.

  Args:
    descriptor: Descriptor of the message type being built.
    dictionary: Class dictionary. It gains one entry per nested extension,
      keyed by the extension's name and holding its FieldDescriptor.
  """
  for name, extension in descriptor.extensions_by_name.iteritems():
    # An extension must never silently shadow an existing class attribute.
    assert name not in dictionary
    dictionary[name] = extension
def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for every enum type nested in this message.

  For each nested enum type, cls gets:
  - an attribute named after the enum type, holding an EnumTypeWrapper that
    can name the enum's values;
  - one attribute per enum value, named after the value and holding its
    number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_type)
    setattr(cls, enum_type.name, wrapper)
    for value_descriptor in enum_type.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # field.message_type is read each time a container is built rather
      # than captured here.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # Only the message type is captured; its _concrete_class is looked up
    # each time a default is created.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      # Route the sub-message's modification notices to |message|.
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # Singular scalars simply use the descriptor's default value.
    return field.default_value
  return MakeScalarDefault
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts field values as keyword arguments:
  - Repeated and message fields are copied into fresh containers or
    sub-messages.
  - Singular scalar fields are assigned through their property setter.
  """

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any field set from kwargs below invalidates the cached size.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # _unknown_fields is () while empty; it is replaced by a list once
    # unknown fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.iteritems():
      # NOTE(review): _GetFieldByName raises ValueError for unknown names, so
      # this TypeError is only reached if it ever returns None instead.
      field = _GetFieldByName(message_descriptor, field_name)
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
          # Composite: copy every element into a newly added sub-message.
          for val in field_value:
            copy.add().MergeFrom(val)
        else:
          # Scalar: append all values to the new container.
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        copy.MergeFrom(field_value)
        self._fields[field] = copy
      else:
        setattr(self, field_name, field_value)

  init.__module__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message type has no field named |field_name|.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    # A 1-tuple keeps the formatting safe even if field_name is a tuple.
    raise ValueError('Protocol message has no "%s" field.' % (field_name,))
def _AddPropertiesForFields(descriptor, cls):
  """Adds a property for every field of this message type.

  Extendable message types also get an 'Extensions' property.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensions(self):
    # Each access builds a new _ExtensionDict view over this message.
    return _ExtensionDict(self)
  cls.Extensions = property(_GetExtensions)
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.

  Clients can use this property to get and (in the case of non-repeated
  scalar fields) directly set the value of a protocol message field. Also
  adds a <FIELD_NAME>_FIELD_NUMBER class constant.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if other C++ types are added that need special handling here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + "_FIELD_NUMBER"
  setattr(cls, constant_name, field.number)

  # Exactly one property flavour applies to each field.
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Clients can use this property to get the value of the field. The value is
  the container built by the field's _default_constructor: a
  containers.RepeatedScalarFieldContainer or a
  containers.RepeatedCompositeFieldContainer. Assigning to the property
  raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Lazily construct an empty container for this field.
      field_value = field._default_constructor(self)
      # setdefault() keeps whichever container was stored first, so racing
      # getters share one object. Presumably relies on dict.setdefault being
      # atomic (true in CPython).
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # A setter exists only to raise a more helpful error than the default.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field. Setting a value type-checks it through the field's type checker,
  stores the result, and marks the message modified.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # Unset scalars are not stored in _fields; fall back to the default.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    self._fields[field] = type_checker.CheckValue(new_value)
    # Check the dirty bit inline to skip a call on the common path.
    # _Modified() does nothing when the bit is already set anyway.
    if not self._cached_byte_size_dirty:
      self._Modified()
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field. Clients can use this
  property to get the value of the field, but cannot assign to the property
  directly (doing so raises AttributeError).

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  # The message type is captured once, when the property is created.
  message_type = field.message_type

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Lazily create an empty sub-message that reports modifications back
      # to this message through its listener.
      field_value = message_type._concrete_class()
      field_value._SetListener(self._listener_for_children)
      # If another caller stored a sub-message first, use that one.
      # Presumably relies on dict.setdefault being atomic (true in CPython).
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # A setter exists only to raise a more helpful error than the default.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each extension declared
  in this message type.

  Args:
    descriptor: Descriptor of the message type being built.
    cls: The class we're constructing.
  """
  for name, extension in descriptor.extensions_by_name.iteritems():
    setattr(cls, '%s_FIELD_NUMBER' % name.upper(), extension.number)
def _AddStaticMethods(cls):
  """Adds the RegisterExtension() and FromString() static methods to cls."""

  def RegisterExtension(extension_handle):
    """Registers |extension_handle| as an extension of this message type.

    Raises:
      AssertionError: If a different extension with the same field number is
        already registered for this message type.
    """
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # Insert our extension, failing if one with the same number exists.
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, actual_handle.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extensions are also registered under the full name of
      # their message type.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Returns a new instance of cls parsed from the serialized string |s|."""
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    # Repeated fields count as present only when non-empty.
    return bool(value)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # Sub-messages may be stored (lazily created by a getter) without being
    # set.
    return value._is_present_in_parent
  else:
    # Singular scalars are stored in _fields only once they have been set.
    return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ListFields(self):
    """Returns (FieldDescriptor, value) pairs for every present field and
    extension, ordered by field number."""
    all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  # Only singular (non-repeated) fields support HasField().
  singular_fields = {}
  for field in message_descriptor.fields:
    if field.label != _FieldDescriptor.LABEL_REPEATED:
      singular_fields[field.name] = field

  def HasField(self, field_name):
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % (field_name,))

    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # A stored sub-message may just be a lazily created, unset default.
      value = self._fields.get(field)
      return value is not None and value._is_present_in_parent
    else:
      return field in self._fields

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      raise ValueError('Protocol message has no "%s" field.' % (field_name,))

    # A removed sub-message keeps its listener pointing at us. The worst
    # outcome is a later spurious _Modified() that invalidates our size.
    self._fields.pop(field, None)

    # Always call _Modified(), even if nothing was stored: this is a
    # mutating method.
    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""

  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Like ClearField(): drop any stored value, then always mark the message
    # modified so the cached byte size is invalidated.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def Clear(self):
    # Drop all stored fields and any unknown fields kept from parsing.
    self._fields = {}
    self._unknown_fields = ()
    # Clearing is a mutation: invalidate the cached size and notify parents.
    self._Modified()

  cls.Clear = Clear
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""

  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      # Singular scalar extensions are present exactly when stored.
      return extension_handle in self._fields

    # A stored sub-message may be an unset default created by a lookup.
    sub_message = self._fields.get(extension_handle)
    if sub_message is None:
      return False
    return sub_message._is_present_in_parent

  cls.HasExtension = HasExtension
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def __eq__(self, other):
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    if self is other:
      return True

    if not self.ListFields() == other.ListFields():
      return False

    # Sort unknown fields: their order must not affect equality.
    unknown_fields = sorted(self._unknown_fields)
    other_unknown_fields = sorted(other._unknown_fields)
    return unknown_fields == other_unknown_fields

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def __str__(self):
    # Render the message in protobuf text format.
    return text_format.MessageToString(self)

  cls.__str__ = __str__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def __unicode__(self):
    # Produce UTF-8 text format, then decode it into a unicode object.
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')

  cls.__unicode__ = __unicode__
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods()."""

  def SetListener(self, listener):
    # None is replaced by a NullMessageListener, so that
    # self._listener.Modified() can always be called.
    if listener is None:
      self._listener = message_listener_mod.NullMessageListener()
    else:
      self._listener = listener

  cls._SetListener = SetListener
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any other
  additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: If field_type has no registered size function.
  """
  # Only the table lookup is guarded, so a KeyError raised inside the size
  # function itself is not misreported as an unknown field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ByteSize(self):
    """Returns the serialized size in bytes, using a cache while clean."""
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    size = sum(field_descriptor._sizer(field_value)
               for field_descriptor, field_value in self.ListFields())
    # Unknown fields are re-emitted verbatim: tag bytes plus value bytes.
    size += sum(len(tag_bytes) + len(value_bytes)
                for tag_bytes, value_bytes in self._unknown_fields)

    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    # Children must notify us again on their next modification.
    self._listener_for_children.dirty = False
    return size

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def SerializeToString(self):
    # Only fully initialized messages may be serialized here. Partial output
    # is available through SerializePartialToString().
    if self.IsInitialized():
      return self.SerializePartialToString()
    type_name = self.DESCRIPTOR.full_name
    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (type_name, missing))

  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def SerializePartialToString(self):
    """Serializes the message without checking required fields."""
    out = BytesIO()
    self._InternalSerialize(out.write)
    return out.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes):
    # Known fields in field-number order (via ListFields()), then unknown
    # fields exactly as they were parsed.
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value)
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def MergeFromString(self, serialized):
    """Parses |serialized| and merges it into this message.

    Returns:
      The length of |serialized|.

    Raises:
      message_mod.DecodeError: On truncated or malformed input.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # _InternalParse only stops early when it meets an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Presumably raised by reading past the end of |serialized|.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length

  cls.MergeFromString = MergeFromString

  # Hoisted into locals: InternalParse is the parsing hot loop.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder = decoders_by_tag.get(tag_bytes)
      if field_decoder is None:
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          # SkipField signals an end-group tag with -1; stop and let the
          # caller see that parsing ended early.
          return pos
        if not unknown_field_list:
          # Replace the shared empty () with a real list on first use.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
    return pos

  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors: A list which, if provided, will be populated with the field
        paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # Direct _fields access instead of HasField()/ListFields() for speed.
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Iterate over a snapshot of the items in case _fields changes meanwhile.
    for field, value in list(self._fields.items()):
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings. Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = []

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # Extensions are shown in parentheses using their full name.
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for i, element in enumerate(value):
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _AddMergeFromMethod(cls):
  """Adds the MergeFrom() method to cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of |msg| into this message.

    - Repeated fields are appended to.
    - Set sub-messages are merged recursively.
    - Singular scalars are overwritten.
    - Unknown fields are appended.

    Raises:
      TypeError: If |msg| is not an instance of this message class.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.iteritems():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new container to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Skip sub-messages that were only lazily created, never set.
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new sub-message to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        fields[field] = value

    if msg._unknown_fields:
      if not self._unknown_fields:
        # Replace the shared empty () with a real list before extending.
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
def _AddMessageMethods(message_descriptor, cls):
  """Installs the implementations of all Message methods on cls.

  ClearExtension() and HasExtension() are installed only for extendable
  message types.
  """
  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(message_descriptor, cls)
  _AddClearFieldMethod(message_descriptor, cls)
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddClearMethod,
                     _AddEqualsMethod,
                     _AddStrMethod,
                     _AddUnicodeMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """
    # Some callers (e.g. scalar property setters) check
    # _cached_byte_size_dirty before calling _Modified() as an optimization.
    # If this method ever starts doing work when the bit is already set,
    # those callers must be updated.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  cls._Modified = Modified
  cls.SetInParent = Modified
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from child to parent is a weak proxy, so the
    # parent -> child -> listener -> parent chain is not a strong cycle.
    if isinstance(parent_message, weakref.ProxyType):
      self._parent_message_weakref = parent_message
    else:
      self._parent_message_weakref = weakref.proxy(parent_message)

    # The parent sets this while its cached size is already dirty, so that
    # Modified() can skip walking up the tree again.
    self.dirty = False

  def Modified(self):
    if self.dirty:
      return
    try:
      # Propagate the signal to our parent.
      self._parent_message_weakref._Modified()
    except ReferenceError:
      # The parent was garbage-collected while a client still holds this
      # child and keeps modifying it. This is not an error.
      pass
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """
    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle."""

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        pass
    else:
      # Singular scalar: return the default without storing it.
      return extension_handle.default_value

    # Keep whichever value was stored first, so concurrent lookups share one
    # object. Presumably relies on dict.setdefault being atomic (true in
    # CPython).
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return False

    # ListFields() returns (FieldDescriptor, value) pairs. Keep the pairs
    # whose descriptor is an extension. Testing .is_extension on the pair
    # itself would raise AttributeError.
    my_fields = [item for item in self._extended_message.ListFields()
                 if item[0].is_extension]
    other_fields = [item for item in other._extended_message.ListFields()
                    if item[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  # Only meaningful for non-repeated, scalar extension fields. _Modified()
  # is called after a successful set to update "has" state up the tree.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: If the extension is repeated or message-typed.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # Looking up the type checker on every assignment is slightly wasteful,
    # but this path is expected to be uncommon.
    type_checker = type_checkers.GetTypeChecker(extension_handle)
    self._extended_message._fields[extension_handle] = (
        type_checker.CheckValue(value))
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if no such extension is registered.
    """
    return self._extended_message._extensions_by_name.get(name)