[UserSettings] refactor, add tests
[mygpo.git] / mygpo / maintenance / merge.py
blobe3df02bed46f3a4f2f7ff7c5db83a27e96942ea6
1 import collections
3 from django.db import transaction, IntegrityError
4 from django.contrib.contenttypes.models import ContentType
5 from django.db.models import get_models, Model
6 from django.contrib.contenttypes.generic import GenericForeignKey
8 from mygpo.podcasts.models import (MergedUUID, ScopedModel, OrderedModel, Slug,
9 Tag, URL, MergedUUID, Podcast, Episode)
10 from mygpo import utils
11 from mygpo.history.models import HistoryEntry, EpisodeHistoryEntry
12 from mygpo.publisher.models import PublishedPodcast
13 from mygpo.subscriptions.models import Subscription
15 import logging
16 logger = logging.getLogger(__name__)
# PostgreSQL error code 23505 (unique_violation); used to recognise unique
# constraint conflicts when a re-assigned object cannot be saved.
# NOTE(review): this constant is an int, while psycopg2 documents ``pgcode``
# as a string -- keep the types in mind wherever the two are compared.
PG_UNIQUE_VIOLATION = 23505
class IncorrectMergeException(Exception):
    """ Signals an invalid merge request, e.g. merging a podcast into itself """
class PodcastMerger(object):
    """ Merges podcasts and their related objects

    The first podcast is the merge target, all further podcasts are merged
    into it. Episodes are merged per group: the first episode of each group
    is the target for the remaining episodes of that group. """

    def __init__(self, podcasts, actions, groups):
        """ Prepares to merge podcasts[1:] into podcasts[0]

        podcasts: list of podcasts, the first one being the merge target
        actions:  stored on the instance; not used by this class itself
        groups:   iterable of (key, episodes) pairs, each ``episodes`` being
                  a list of episodes that should be merged into one

        Raises IncorrectMergeException if two entries of ``podcasts``
        compare equal. """

        for n, podcast1 in enumerate(podcasts):
            for m, podcast2 in enumerate(podcasts):
                if podcast1 == podcast2 and n != m:
                    raise IncorrectMergeException(
                        "can't merge podcast %s into itself %s" %
                        (podcast1.get_id(), podcast2.get_id()))

        self.podcasts = podcasts
        self.actions = actions
        self.groups = groups

    def merge(self):
        """ Carries out the actual merging

        Returns the merge target (the first podcast). The list of podcasts
        passed to the constructor is left unmodified. """

        logger.info('Start merging of podcasts: %r', self.podcasts)

        # take head and tail without pop(), so that the caller's list is not
        # changed behind its back; an empty list still raises IndexError
        podcast1 = self.podcasts[0]
        duplicates = list(self.podcasts[1:])
        logger.info('Merge target: %r', podcast1)

        self.merge_episodes()
        merge_model_objects(podcast1, duplicates)

        return podcast1

    def merge_episodes(self):
        """ Merges the episodes according to the groups

        Empty groups are skipped. The episode lists of the groups are left
        unmodified. """

        for _, episodes in self.groups:
            if not episodes:
                continue

            # first episode is the target, all others are merged into it
            episode = episodes[0]
            duplicates = list(episodes[1:])
            logger.info('Merging %d episodes', len(duplicates))
            merge_model_objects(episode, duplicates)
def reassign_urls(obj1, obj2):
    """ Reassign all URLs of obj2 to obj1

    The moved URLs are ordered behind the URLs that obj1 already has and
    take over the scope of obj1. A URL that can not be saved because of an
    IntegrityError is deleted instead of being moved. """

    max_order = max([0] + [u.order for u in obj1.urls.all()])

    for n, url in enumerate(obj2.urls.all(), max_order+1):
        url.content_object = obj1
        url.order = n
        url.scope = obj1.scope
        try:
            # run the save in its own atomic block (a savepoint if a
            # transaction is already open), so that a failing save does not
            # break the surrounding transaction for the delete below
            with transaction.atomic():
                url.save()
        except IntegrityError as ie:
            logger.warning('Moving URL failed: %s. Deleting.', str(ie))
            url.delete()
def reassign_merged_uuids(obj1, obj2):
    """ Reassign all IDs of obj2 to obj1

    The ID of obj2 itself is recorded as a merged UUID of obj1, and all
    merged UUIDs of obj2 are pointed to obj1. """

    # remember the ID of obj2 itself
    MergedUUID.objects.create(content_object=obj1, uuid=obj2.id)

    # move the IDs that obj2 has collected so far
    for merged_uuid in obj2.merged_uuids.all():
        merged_uuid.content_object = obj1
        merged_uuid.save()
def reassign_slugs(obj1, obj2):
    """ Reassign all Slugs of obj2 to obj1

    The moved slugs are ordered behind the slugs that obj1 already has and
    take over the scope of obj1. A slug that can not be saved because of an
    IntegrityError is deleted instead of being moved. """

    max_order = max([0] + [s.order for s in obj1.slugs.all()])

    for n, slug in enumerate(obj2.slugs.all(), max_order+1):
        slug.content_object = obj1
        slug.order = n
        slug.scope = obj1.scope
        try:
            # run the save in its own atomic block (a savepoint if a
            # transaction is already open), so that a failing save does not
            # break the surrounding transaction for the delete below
            with transaction.atomic():
                slug.save()
        except IntegrityError as ie:
            logger.warning('Moving Slug failed: %s. Deleting', str(ie))
            slug.delete()
# based on https://djangosnippets.org/snippets/2283/
@transaction.atomic
def merge_model_objects(primary_object, alias_objects=[], keep_old=False):
    """
    Use this function to merge model objects (i.e. Users, Organizations, Polls,
    etc.) and migrate all of the related fields from the alias objects to the
    primary object.

    Usage:
    from django.contrib.auth.models import User
    primary_user = User.objects.get(email='good_email@example.com')
    duplicate_user = User.objects.get(email='good_email+duplicate@example.com')
    merge_model_objects(primary_user, duplicate_user)

    primary_object: the model instance that everything is merged into
    alias_objects:  a single instance or a list of instances of the same
                    class as primary_object (the default list is only read,
                    never modified)
    keep_old:       if true, the alias objects are not deleted

    Returns primary_object after it has been saved.

    Raises TypeError if primary_object is not a Model instance, if an alias
    object is of a different class, or if a related object is of a type that
    reassigned(), before_delete() or merge() do not know.
    """
    if not isinstance(alias_objects, list):
        alias_objects = [alias_objects]

    # check that all aliases are the same class as primary one and that
    # they are subclass of model
    primary_class = primary_object.__class__

    if not issubclass(primary_class, Model):
        raise TypeError('Only django.db.models.Model subclasses can be merged')

    for alias_object in alias_objects:
        if not isinstance(alias_object, primary_class):
            raise TypeError('Only models of same class can be merged')

    # Get a list of all GenericForeignKeys in all models
    # TODO: this is a bit of a hack, since the generics framework should
    # provide a similar method to the ForeignKey field for accessing the
    # generic related fields.
    generic_fields = []
    for model in get_models():
        generic_fields.extend(
            field for field in model.__dict__.values()
            if isinstance(field, GenericForeignKey))

    # names of the fields that are empty in the primary object and may be
    # filled from the alias objects
    blank_local_fields = set(
        field.attname for field in primary_object._meta.local_fields
        if getattr(primary_object, field.attname) in [None, ''])

    # Loop through all alias objects and migrate their data to
    # the primary object.
    for alias_object in alias_objects:
        # Migrate all foreign key references from alias object to
        # primary object.
        for related_object in alias_object._meta.get_all_related_objects():
            # The variable name on the alias_object model.
            alias_varname = related_object.get_accessor_name()
            # The variable name on the related model.
            obj_varname = related_object.field.name
            related_objects = getattr(alias_object, alias_varname)
            for obj in related_objects.all():
                setattr(obj, obj_varname, primary_object)
                reassigned(obj, primary_object)
                obj.save()

        # Migrate all many to many references from alias object to
        # primary object.
        related = alias_object._meta.get_all_related_many_to_many_objects()
        for related_many_object in related:
            alias_varname = related_many_object.get_accessor_name()
            obj_varname = related_many_object.field.name

            if alias_varname is not None:
                # standard case
                related_many_objects = getattr(alias_object,
                                               alias_varname).all()
            else:
                # special case, symmetrical relation, no reverse accessor
                related_many_objects = getattr(alias_object,
                                               obj_varname).all()
            for obj in related_many_objects:
                getattr(obj, obj_varname).remove(alias_object)
                reassigned(obj, primary_object)
                getattr(obj, obj_varname).add(primary_object)

        # Migrate all generic foreign key references from alias
        # object to primary object.
        for field in generic_fields:
            filter_kwargs = {}
            filter_kwargs[field.fk_field] = alias_object._get_pk_val()
            filter_kwargs[field.ct_field] = field.get_content_type(
                alias_object)
            related = field.model.objects.filter(**filter_kwargs)
            for generic_related_object in related:
                setattr(generic_related_object, field.name, primary_object)
                reassigned(generic_related_object, primary_object)
                try:
                    # execute save in a savepoint, so we can resume in the
                    # transaction
                    with transaction.atomic():
                        generic_related_object.save()
                except IntegrityError as ie:
                    # the database driver reports the error code as a
                    # string, the constant is an int: compare both as str
                    cause = getattr(ie, '__cause__', None)
                    pgcode = getattr(cause, 'pgcode', None)
                    if str(pgcode) == str(PG_UNIQUE_VIOLATION):
                        merge(generic_related_object, primary_object)
                    else:
                        # keep going as before, but leave a trace
                        logger.warning('Reassigning %r failed: %s',
                                       generic_related_object, str(ie))

        # Try to fill all missing values in primary object by
        # values of duplicates
        filled_up = set()
        for field_name in blank_local_fields:
            val = getattr(alias_object, field_name)
            if val not in [None, '']:
                setattr(primary_object, field_name, val)
                filled_up.add(field_name)
        blank_local_fields -= filled_up

        if not keep_old:
            before_delete(alias_object, primary_object)
            alias_object.delete()
    primary_object.save()
    return primary_object
def reassigned(obj, new):
    """ Adjusts obj after it has been assigned to the object new

    URLs take over the scope of new and are ordered behind its existing
    URLs. The URLs of an Episode get the scope of its new podcast and are
    saved. Subscriptions and history entries need no adjustment. Any other
    type raises a TypeError. """

    if isinstance(obj, URL):
        # a URL has its parent's scope
        obj.scope = new.scope

        # put it behind all URLs that new already has
        orders = [u.order for u in new.urls.all()]
        obj.order = max([-1] + orders) + 1
        return

    if isinstance(obj, Episode):
        # obj is an Episode, new is a podcast
        for episode_url in obj.urls.all():
            episode_url.scope = new.as_scope
            episode_url.save()
        return

    if isinstance(obj, (Subscription, EpisodeHistoryEntry, HistoryEntry)):
        # nothing to adjust for these types
        return

    raise TypeError('unknown type for reassigning: {objtype}'.format(
        objtype=type(obj)))
def before_delete(old, new):
    """ Records the ID of old as a merged UUID of new

    Only Episodes and Podcasts are supported; any other type raises a
    TypeError. """

    if not isinstance(old, (Episode, Podcast)):
        raise TypeError('unknown type for deleting: {objtype}'.format(
            objtype=type(old)))

    # Episodes and Podcasts are handled the same way
    MergedUUID.objects.create(
        content_type=ContentType.objects.get_for_model(new),
        object_id=new.pk,
        uuid=old.pk,
    )
def merge(moved_obj, new_target):
    """ Resolves the conflict of moved_obj with an object of new_target

    Only URLs are supported, for which nothing needs to be done. Any other
    type raises a TypeError. """

    if isinstance(moved_obj, URL):
        # if we have two conflicting URLs, don't save the second one
        # URLs don't have any interesting properties (except the URL) that
        # we could merge
        pass

    else:
        # report the type of the object that was passed in
        raise TypeError('unknown type for merging: {objtype}'.format(
            objtype=type(moved_obj)))