3 # Copyright 2007 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""
43 if sys
.version_info
[0] < 3:
45 from cStringIO
import StringIO
as BytesIO
47 from StringIO
import StringIO
as BytesIO
48 import copy_reg
as copyreg
50 from io
import BytesIO
56 from google
.net
.proto2
.python
.internal
import containers
57 from google
.net
.proto2
.python
.internal
import decoder
58 from google
.net
.proto2
.python
.internal
import encoder
59 from google
.net
.proto2
.python
.internal
import enum_type_wrapper
60 from google
.net
.proto2
.python
.internal
import message_listener
as message_listener_mod
61 from google
.net
.proto2
.python
.internal
import type_checkers
62 from google
.net
.proto2
.python
.internal
import wire_format
63 from google
.net
.proto2
.python
.public
import descriptor
as descriptor_mod
64 from google
.net
.proto2
.python
.public
import message
as message_mod
65 from google
.net
.proto2
.python
.public
import text_format
# Short module-level alias for descriptor_mod.FieldDescriptor; its LABEL_*,
# TYPE_* and CPPTYPE_* constants are compared against throughout this file.
_FieldDescriptor = descriptor_mod.FieldDescriptor
def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary of a message class before it is created.

  Adds class attributes for nested extensions and a __slots__ entry to
  `dictionary`.

  Args:
    bases: Base classes of the message class being created.
    descriptor: Descriptor for the message type.
    dictionary: The class dictionary; modified in place.

  Returns:
    `bases`, unchanged.
  """
  _AddClassAttributesForNestedExtensions(descriptor, dictionary)
  _AddSlots(descriptor, dictionary)
  return bases
def InitMessage(descriptor, cls):
  """Installs all runtime machinery on a freshly created message class.

  Creates the per-class decoder and extension registries, attaches
  encoder/sizer/decoder helpers to every field, adds enum constants,
  __init__, field properties, static methods, message methods and private
  helpers, and registers a copyreg pickler built on obj.__getstate__().

  Args:
    descriptor: Descriptor for the message type.
    cls: The message class being initialized; modified in place.
  """
  cls._decoders_by_tag = {}
  cls._extensions_by_name = {}
  cls._extensions_by_number = {}
  if (descriptor.has_options and
      descriptor.GetOptions().message_set_wire_format):
    # MessageSet items are decoded under a dedicated tag by a decoder that
    # looks extensions up in cls._extensions_by_number.
    cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
        decoder.MessageSetItemDecoder(cls._extensions_by_number))

  # Attach encoding/sizing/decoding helpers to each FieldDescriptor.
  for field in descriptor.fields:
    _AttachFieldHelpers(cls, field)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(descriptor, cls)
  copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The property name; currently `proto_field_name` unchanged (no escaping
    of Python keywords is performed).
  """
  return proto_field_name
def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: Expected to be a FieldDescriptor of an extension that
      extends `message`'s type.

  Raises:
    KeyError: if the handle is not a FieldDescriptor, is not an extension,
      has no containing_type, or extends a different message type.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type
      (currently unused; every message type gets the same slots).
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             # Needed because _Listener holds a
                             # weakref.proxy to the message.
                             '__weakref__',
                             '_oneofs']
def _IsMessageSetExtension(field):
  """Returns True if `field` is an extension in MessageSet format.

  That is: an optional message-typed extension, declared inside its own
  message type, extending a type that uses message_set_wire_format.
  """
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding, sizing, default-value and decoding helpers to a field.

  Sets `_encoder`, `_sizer` and `_default_constructor` on `field_descriptor`
  and registers its decoder(s) in `cls._decoders_by_tag`, keyed by tag bytes.
  Repeated fields of a packable type also get a decoder for the
  length-delimited (packed) encoding, regardless of their `packed` option.

  Args:
    cls: The message class owning the decoder table.
    field_descriptor: The FieldDescriptor (regular field or extension).
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packed = (field_descriptor.has_options and
               field_descriptor.GetOptions().packed)

  if _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, decode_packed):
    """Registers the field's decoder under the tag for `wiretype`."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    cls._decoders_by_tag[tag_bytes] = (
        type_checkers.TYPE_TO_DECODER[field_descriptor.type](
            field_descriptor.number, is_repeated, decode_packed,
            field_descriptor, field_descriptor._default_constructor))

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # Accept packed input even when the field is not declared packed.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside this message as a class attribute.

  Args:
    descriptor: Descriptor whose `extensions_by_name` maps names to
      extension FieldDescriptors.
    dictionary: Class dictionary; each extension is added under its name.
  """
  # items() works on both the Python 2 and Python 3 import branches.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field
def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    # The enum type itself, wrapped so it can map names <-> numbers.
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    # Each enum value as a plain int constant on the class.
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # field.message_type is read lazily, when the container is built.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault

    type_checker = type_checkers.GetTypeChecker(field)
    def MakeRepeatedScalarDefault(message):
      return containers.RepeatedScalarFieldContainer(
          message._listener_for_children, type_checker)
    return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class is looked up at call time, not when this constructor
      # is built.
      result = message_type._concrete_class()
      # A oneof member gets a _OneofListener (as the composite-field property
      # getter does) so that modifying it -- e.g. via __init__ kwargs,
      # MergeFrom() or parsing -- also records it as the oneof's active member.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    return field.default_value
  return MakeScalarDefault
292 def _ReraiseTypeErrorWithFieldName(message_name
, field_name
):
293 """Re-raise the currently-handled TypeError with the field name added."""
294 exc
= sys
.exc_info()[1]
295 if len(exc
.args
) == 1 and type(exc
) is TypeError:
297 exc
= TypeError('%s for field %s.%s' % (str(exc
), message_name
, field_name
))
300 raise type(exc
), exc
, sys
.exc_info()[2]
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts field values as keyword arguments:
  repeated fields are filled from an iterable (message elements are merged
  into new elements), singular message fields are merged from the given
  message, and singular scalars are assigned through their type-checking
  property. TypeErrors from those steps are re-raised with the field name.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Fields supplied here have not been sized yet.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Maps each oneof descriptor to the field descriptor currently set in it.
    self._oneofs = {}
    # () while empty; replaced by a list when unknown fields are appended.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      if field is None:
        # NOTE(review): _GetFieldByName raises ValueError for unknown names
        # rather than returning None, so this branch is not reached as written.
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
          for val in field_value:
            copy.add().MergeFrom(val)
        else:
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        try:
          copy.MergeFrom(field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: if the message has no field called `field_name`.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message has no "%s" field.' % field_name)
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types also get an `Extensions` property that builds a
  new _ExtensionDict wrapping the message on every access.
  """
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    cls.Extensions = property(lambda self: _ExtensionDict(self))
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also sets a `<NAME>_FIELD_NUMBER` class constant holding the field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Fails loudly if FieldDescriptor gains cpp types beyond the ones this
  # dispatch was written for.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + "_FIELD_NUMBER"
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (see
  containers.py).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Lazily create the container on first access.
      field_value = field._default_constructor(self)
      # setdefault(): if another container was stored in the meantime, use
      # that one and discard ours (relies on setdefault being atomic, as it
      # is in CPython).
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # A setter exists only to raise a clearer error than the default.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # Unset fields read as the descriptor's default value.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    self._fields[field] = type_checker.CheckValue(new_value)
    # Inline dirty check avoids a call when the message is already dirty.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof is not None:
    def setter(self, new_value):
      field_setter(self, new_value)
      # Make this field the oneof's active member (clearing any other).
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  message_type = field.message_type

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Lazily create the sub-message; oneof members report modifications
      # through a _OneofListener so the oneof's active member is updated.
      field_value = message_type._concrete_class()
      field_value._SetListener(
          _OneofListener(self, field)
          if field.containing_oneof is not None
          else self._listener_for_children)
      # setdefault(): keep an object stored in the meantime, if any (relies
      # on setdefault being atomic, as it is in CPython).
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # A setter exists only to raise a clearer error than the default.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a `<NAME>_FIELD_NUMBER` class constant for each nested extension.

  Args:
    descriptor: Descriptor whose `extensions_by_name` maps extension names to
      extension FieldDescriptors.
    cls: The class we're constructing.
  """
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)
def _AddStaticMethods(cls):
  """Adds the RegisterExtension() and FromString() static methods to cls."""

  def _DuplicateError(extension_handle, actual_handle):
    """Builds the error for two extensions claiming the same field number."""
    return AssertionError(
        'Extensions "%s" and "%s" both try to extend message type "%s" with '
        'field number %d.' %
        (extension_handle.full_name, actual_handle.full_name,
         cls.DESCRIPTOR.full_name, extension_handle.number))

  def RegisterExtension(extension_handle):
    """Registers `extension_handle` as an extension of this message type.

    Raises:
      AssertionError: if a different extension is already registered under
        the same field number.
    """
    # Check for a field-number clash *before* touching any state:
    # _AttachFieldHelpers() overwrites cls._decoders_by_tag for this tag, so
    # a rejected extension must not get that far.
    existing = cls._extensions_by_number.get(extension_handle.number)
    if existing is not None and existing is not extension_handle:
      raise _DuplicateError(extension_handle, existing)

    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # setdefault() re-checks in case another registration happened meanwhile.
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      raise _DuplicateError(extension_handle, actual_handle)

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extensions are also registered under their message type.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Returns a new message of this type parsed from the bytes `s`."""
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)
def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields().

  Repeated fields count when non-empty, sub-messages when marked present in
  their parent, and stored scalars always.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  else:
    return True
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls."""

  def ListFields(self):
    """Returns (FieldDescriptor, value) pairs for all present fields,
    sorted by field number."""
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls.

  HasField() accepts the name of a singular (non-repeated) field or of a
  oneof; for a oneof it reports whether its recorded active member is set.
  """
  singular_fields = {}
  for field in message_descriptor.fields:
    if field.label != _FieldDescriptor.LABEL_REPEATED:
      singular_fields[field.name] = field
  # Oneofs are looked up by name alongside the singular fields.
  for oneof in message_descriptor.oneofs:
    singular_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = singular_fields[field_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no singular "%s" field.' % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      active = self._oneofs.get(field)
      if active is None:
        return False
      return HasField(self, active.name)

    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(field)
      return value is not None and value._is_present_in_parent
    return field in self._fields

  cls.HasField = HasField
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls."""
  def ClearField(self, field_name):
    """Clears the named field, or the active member of the named oneof.

    Clearing a oneof with no member set is a no-op. Otherwise the message is
    marked modified even if the field was not set.

    Raises:
      ValueError: if the message has no field or oneof called `field_name`.
    """
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message has no "%s" field.' % field_name)
      if oneof not in self._oneofs:
        return
      field = self._oneofs[oneof]

    if field in self._fields:
      # A cleared sub-message keeps its listener pointing at us; the worst it
      # can do is invalidate our cached byte size.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    self._Modified()

  cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""
  def ClearExtension(self, extension_handle):
    """Clears the given extension and marks the message modified.

    Raises:
      KeyError: if `extension_handle` is not a valid extension of this type.
    """
    _VerifyExtensionHandle(self, extension_handle)

    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension
def _AddClearMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs Clear() on cls."""
  def Clear(self):
    """Removes all fields, unknown fields and oneof state."""
    self._fields = {}
    self._unknown_fields = ()
    self._oneofs = {}
    self._Modified()
  cls.Clear = Clear
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""
  def HasExtension(self, extension_handle):
    """Returns whether the given singular extension is set.

    Raises:
      KeyError: if the handle is invalid for this message or is repeated.
    """
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""
  def __eq__(self, other):
    """Messages are equal when they share a DESCRIPTOR, have equal
    ListFields(), and have the same unknown fields in any order."""
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    if self is other:
      return True

    if not self.ListFields() == other.ListFields():
      return False

    # Unknown-field order does not affect equality.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): __str__ renders the text format."""
  def __str__(self):
    return text_format.MessageToString(self)
  cls.__str__ = __str__
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): __unicode__ renders the text format,
  produced as UTF-8 and decoded to a unicode string."""

  def __unicode__(self):
    return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
  cls.__unicode__ = __unicode__
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): installs _SetListener() on cls."""
  def SetListener(self, listener):
    """Sets the listener notified on modification; None installs a
    NullMessageListener."""
    if listener is None:
      self._listener = message_listener_mod.NullMessageListener()
    else:
      self._listener = listener
  cls._SetListener = SetListener
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: if `field_type` has no size function.
  """
  # Only the table lookup is guarded: a KeyError raised *inside* the size
  # function must not be misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls."""

  def ByteSize(self):
    """Returns the serialized size, reusing the cached value while clean.

    Otherwise sums each present field's sizer plus the raw bytes of unknown
    fields, caches the result and clears the dirty flags.
    """
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    size = 0
    for field_descriptor, field_value in self.ListFields():
      size += field_descriptor._sizer(field_value)

    for tag_bytes, value_bytes in self._unknown_fields:
      size += len(tag_bytes) + len(value_bytes)

    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return size

  cls.ByteSize = ByteSize
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self):
    """Serializes the message, refusing if required fields are missing.

    Raises:
      message_mod.EncodeError: if IsInitialized() is false; the error lists
        the paths returned by FindInitializationErrors().
    """
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (
              self.DESCRIPTOR.full_name,
              ','.join(self.FindInitializationErrors())))
    return self.SerializePartialToString()
  cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializePartialToString() and
  _InternalSerialize() on cls."""

  def SerializePartialToString(self):
    """Serializes without checking required fields; returns the bytes."""
    out = BytesIO()
    self._InternalSerialize(out.write)
    return out.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes):
    """Writes present fields (in field-number order) then unknown fields
    verbatim through the `write_bytes` callable."""
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value)
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString() and
  _InternalParse() on cls."""
  def MergeFromString(self, serialized):
    """Parses `serialized` and merges it into this message.

    Returns:
      len(serialized).

    Raises:
      message_mod.DecodeError: on an unexpected end-group tag, truncated
        input, or a struct unpacking error.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # Parsing stops early only on an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:  # "as" form: valid on Python 2.6+ and 3.
      raise message_mod.DecodeError(e)
    return length
  cls.MergeFromString = MergeFromString

  # Bound once, outside the hot parsing loop.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Decodes fields from buffer[pos:end]; returns the position reached.

    Fields without a registered decoder are skipped and kept as
    (tag_bytes, raw value bytes) in _unknown_fields. Returns early if
    SkipField reports -1 (end-group tag).
    """
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder = decoders_by_tag.get(tag_bytes)
      if field_decoder is None:
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        if not unknown_field_list:
          # Replace the shared empty () with a real list on first use.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
    return pos
  cls._InternalParse = InternalParse
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # Reads self._fields directly rather than via HasField()/ListFields().
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Iterate a snapshot so the loop is unaffected if self._fields changes
    # during the nested IsInitialized() calls.
    for field, value in list(self._fields.items()):
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = []

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_extension:
        name = "(%s)" % field.full_name
      else:
        name = field.name

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for i, element in enumerate(value):
          prefix = "%s[%d]." % (name, i)
          errors.extend(prefix + error
                        for error in element.FindInitializationErrors())
      else:
        prefix = name + "."
        errors.extend(prefix + error
                      for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges `msg`, an instance of the same class, into this message.

    Repeated fields are merged into (created if needed) containers, present
    sub-messages are merged recursively, singular scalars are overwritten,
    and unknown fields are appended. Merged oneof members become the active
    member of their oneof, clearing any other member.

    Raises:
      TypeError: if `msg` is not an instance of this message class.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
          if field.containing_oneof is not None:
            self._UpdateOneofState(field)
      else:
        fields[field] = value
        # Same oneof bookkeeping as the scalar property setter; without it a
        # previously set member of this oneof would survive alongside.
        if field.containing_oneof is not None:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""
  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None.

    Raises:
      ValueError: if the message has no oneof called `oneof_name`.
    """
    try:
      field = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    nested_field = self._oneofs.get(field, None)
    if nested_field is not None and self.HasField(nested_field.name):
      return nested_field.name
    return None

  cls.WhichOneof = WhichOneof
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  ClearExtension() and HasExtension() are only added for extendable types.
  """
  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(message_descriptor, cls)
  _AddClearFieldMethod(message_descriptor, cls)
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  _AddClearMethod(message_descriptor, cls)
  _AddEqualsMethod(message_descriptor, cls)
  _AddStrMethod(message_descriptor, cls)
  _AddUnicodeMethod(message_descriptor, cls)
  _AddSetListenerMethod(cls)
  _AddByteSizeMethod(message_descriptor, cls)
  _AddSerializeToStringMethod(message_descriptor, cls)
  _AddSerializePartialToStringMethod(message_descriptor, cls)
  _AddMergeFromStringMethod(message_descriptor, cls)
  _AddIsInitializedMethod(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """
    # Some callers test _cached_byte_size_dirty themselves before calling
    # this, so it must stay a no-op when the bit is already set.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    other_field = self._oneofs.setdefault(field.containing_oneof, field)
    if other_field is not field:
      # pop() rather than del: the previously active member may already be
      # gone from _fields (e.g. a detached sub-message re-reporting through
      # its _OneofListener after ClearField()).
      self._fields.pop(other_field, None)
      self._oneofs[field.containing_oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # Only a weak proxy to the parent is kept, so a child does not keep its
    # parent alive.
    if isinstance(parent_message, weakref.ProxyType):
      self._parent_message_weakref = parent_message
    else:
      self._parent_message_weakref = weakref.proxy(parent_message)

    # When True, Modified() returns without notifying the parent.
    self.dirty = False

  def Modified(self):
    """Forwards a modification to the parent unless already marked dirty."""
    if self.dirty:
      return
    try:
      self._parent_message_weakref._Modified()
    except ReferenceError:
      # The parent was garbage-collected while a client still holds the
      # child; there is nobody left to notify.
      pass
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    try:
      self._parent_message_weakref._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # Parent already garbage-collected; nothing to update.
      pass
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """

    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    Unset singular scalars return the default without storing it; repeated
    and message extensions are created and stored on first access.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        pass
    else:
      # Singular scalar: return the default without inserting it.
      return extension_handle.default_value

    # setdefault(): if another value was stored in the meantime, keep it and
    # discard ours (relies on setdefault being atomic, as it is in CPython).
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Equal when both wrap messages with the same present extensions.

    ListFields() yields (FieldDescriptor, value) tuples, so the filter must
    inspect item[0]; testing the tuple itself raised AttributeError whenever
    any field was set.
    """
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Keep only extension fields.
    my_fields = [item for item in my_fields if item[0].is_extension]
    other_fields = [item for item in other_fields if item[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  # Only meaningful for non-repeated, scalar extension fields.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      KeyError: if the handle is not a valid extension of this message.
      TypeError: if the extension is repeated or composite.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # The type checker is looked up per call.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle)
    self._extended_message._fields[extension_handle] = (
        type_checker.CheckValue(value))
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if no extension is registered
      under `name`.
    """
    return self._extended_message._extensions_by_name.get(name, None)