"""Revision management for Reversion."""
try:
    set
except NameError:
    # Python 2.3 has no builtin set type. Bind the sets.Set fallback under
    # the builtin name, because the rest of this module calls set().
    from sets import Set as set
try:
    from threading import local
except ImportError:
    from django.utils._threading_local import local  # Python 2.3 fallback.
try:
    from functools import wraps
except ImportError:
    from django.utils.functional import wraps  # Python 2.3, 2.4 fallback.
import sys
from django.contrib.contenttypes.models import ContentType
from django.core import serializers
from django.db import models
from django.db.models.query import QuerySet
from django.db.models.signals import post_save
from reversion.models import Revision,Version
from reversion.storage import VersionFileStorageWrapper
class RevisionManagementError(Exception):

    """
    Exception that is thrown when something goes wrong with revision management.

    For example, it is raised when a revision operation is attempted while no
    revision is active for the current thread.
    """
class RegistrationError(Exception):

    """
    Raised when a model cannot be (un)registered with Reversion.

    Covers registering a model twice, registering a proxy model whose parent is
    unregistered, and looking up or unregistering a model that was never
    registered.
    """
class RegistrationInfo(object):

    """
    Per-model registration record kept by a RevisionManager.

    Holds the serialized field names, the file fields whose storage was
    wrapped, the relationship names to follow, and the serializer format.
    """

    # Slots keep instances small and prevent stray attributes.
    __slots__ = ("fields", "file_fields", "follow", "format")

    def __init__(self, fields, file_fields, follow, format):
        """Stores the four registration settings verbatim."""
        (self.fields, self.file_fields,
         self.follow, self.format) = (fields, file_fields, follow, format)
class RevisionState(local):

    """
    Holds the data of the revision currently being built.

    Because this subclasses `local`, each thread that uses an instance sees
    its own independent copy of these attributes.
    """

    def __init__(self):
        """Starts out in the cleared (empty) state."""
        self.clear()

    def clear(self):
        """Resets every attribute, discarding any pending revision data."""
        # Nesting depth of start()/end() calls.
        self.depth = 0
        self.is_invalid = False
        # Objects added to the revision.
        self.objects = set()
        # (meta class, kwargs) pairs queued by add_meta().
        self.meta = []
        self.user = None
        self.comment = ""
# Serializer name used by RevisionManager.register() when no format is given.
DEFAULT_SERIALIZATION_FORMAT = "json"
class RevisionManager(object):

    """
    Manages the configuration and creation of revisions.

    Models are registered with `register`. While a revision is active (see
    `start`/`end`, the context-manager protocol, or `create_on_success`),
    saves of registered models are collected through the `post_save` signal.
    When the outermost revision ends, they are written out as `Version` rows
    attached to a single `Revision`.
    """

    # `__weakref__` keeps instances weak-referenceable. This is presumably
    # needed by Django's signal dispatcher, which is handed the bound method
    # `self.post_save_receiver` in register().
    __slots__ = "__weakref__", "_registry", "_state",

    def __init__(self):
        """Initializes the revision manager."""
        # Maps registered model classes to their RegistrationInfo.
        self._registry = {}
        # Revision-in-progress data (RevisionState is a threading.local subclass).
        self._state = RevisionState()

    # Registration methods.

    def is_registered(self, model_class):
        """
        Checks whether the given model has been registered with this revision
        manager.
        """
        return model_class in self._registry

    def register(self, model_class, fields=None, follow=(), format=DEFAULT_SERIALIZATION_FORMAT):
        """
        Registers a model with this revision manager.

        `fields` is an iterable of field names to serialize. It defaults to all
        of the model's local fields and local many-to-many fields.

        `follow` is an iterable of attribute names. The objects they refer to
        are added to every revision that contains an instance of this model.

        `format` is the serializer name passed to
        `django.core.serializers.serialize`.

        The storage of every serialized FileField is wrapped in a
        VersionFileStorageWrapper. The manager's `post_save_receiver` is
        connected to the model's post_save signal.

        Raises RegistrationError if the model is already registered, or if it
        is a proxy model whose parent has not been registered.
        """
        # Prevent multiple registration.
        if self.is_registered(model_class):
            raise RegistrationError("%r has already been registered with Reversion." % model_class)
        opts = model_class._meta
        # Ensure the parent model of proxy models is registered.
        if opts.proxy and not self.is_registered(opts.parents.keys()[0]):
            raise RegistrationError("%r is a proxy model, and its parent has not been registered with Reversion." % model_class)
        # Calculate serializable model fields.
        local_fields = opts.local_fields + opts.local_many_to_many
        if fields is None:
            fields = [field.name for field in local_fields]
        fields = tuple(fields)
        # Calculate serializable model file fields. Their storage is wrapped
        # here and unwrapped again by unregister().
        file_fields = []
        for field in local_fields:
            if isinstance(field, models.FileField) and field.name in fields:
                field.storage = VersionFileStorageWrapper(field.storage)
                file_fields.append(field)
        file_fields = tuple(file_fields)
        # Register the generated registration information.
        follow = tuple(follow)
        self._registry[model_class] = RegistrationInfo(fields, file_fields, follow, format)
        # Connect to the post save signal of the model.
        post_save.connect(self.post_save_receiver, model_class)

    def get_registration_info(self, model_class):
        """
        Returns the registration information for the given model class.

        Raises RegistrationError if the model has not been registered.
        """
        try:
            registration_info = self._registry[model_class]
        except KeyError:
            raise RegistrationError("%r has not been registered with Reversion." % model_class)
        else:
            return registration_info

    def unregister(self, model_class):
        """
        Removes a model from version control.

        Restores the original storage of its file fields and disconnects the
        post_save receiver. Raises RegistrationError if the model has not been
        registered.
        """
        try:
            registration_info = self._registry.pop(model_class)
        except KeyError:
            raise RegistrationError("%r has not been registered with Reversion." % model_class)
        else:
            for field in registration_info.file_fields:
                field.storage = field.storage.wrapped_storage
            post_save.disconnect(self.post_save_receiver, model_class)

    # Low-level revision management methods.

    def start(self):
        """
        Begins a revision for this thread.

        This MUST be balanced by a call to `end`. It is recommended that you
        leave these methods alone and instead use the revision context manager
        or the `create_on_success` decorator.
        """
        # Revisions nest; only the outermost end() saves anything.
        self._state.depth += 1

    def is_active(self):
        """Returns whether there is an active revision for this thread."""
        return self._state.depth > 0

    def assert_active(self):
        """Checks for an active revision, throwing an exception if none."""
        if not self.is_active():
            raise RevisionManagementError("There is no active revision for this thread.")

    def add(self, obj):
        """Adds an object to the current revision."""
        self.assert_active()
        self._state.objects.add(obj)

    def set_user(self, user):
        """Sets the user for the current revision."""
        self.assert_active()
        self._state.user = user

    def get_user(self):
        """Gets the user for the current revision."""
        self.assert_active()
        return self._state.user

    user = property(get_user,
                    set_user,
                    doc="The user for the current revision.")

    def set_comment(self, comment):
        """Sets the comment for the current revision."""
        self.assert_active()
        self._state.comment = comment

    def get_comment(self):
        """Gets the comment for the current revision."""
        self.assert_active()
        return self._state.comment

    comment = property(get_comment,
                       set_comment,
                       doc="The comment for the current revision.")

    def add_meta(self, cls, **kwargs):
        """
        Adds a class of meta information to the current revision.

        When the revision is saved, `cls._default_manager.create(revision=...,
        **kwargs)` is called.
        """
        self.assert_active()
        self._state.meta.append((cls, kwargs))

    def invalidate(self):
        """Marks this revision as broken, so that it will not be committed."""
        self.assert_active()
        self._state.is_invalid = True

    def is_invalid(self):
        """Checks whether this revision is invalid."""
        return self._state.is_invalid

    def follow_relationships(self, object_set):
        """
        Follows all the registered relationships in the given set of models to
        yield a set containing the original models plus all their related
        models.

        For each object, every attribute named in its registration's `follow`
        tuple is read:
        - a model instance is followed recursively;
        - for a manager or QuerySet, each object it returns is followed;
        - None is ignored;
        - any other value raises TypeError.

        If a proxy model's parent class is registered, the matching parent
        instance is followed too. Raises RegistrationError if a reached
        object's class is not registered.
        """
        result_set = set()

        def _follow_relationships(obj):
            # Prevent recursion.
            if obj in result_set:
                return
            result_set.add(obj)
            # Follow relations.
            registration_info = self.get_registration_info(obj.__class__)
            for relationship in registration_info.follow:
                # Clear foreign key cache, presumably so the getattr below
                # re-fetches the related object instead of reusing a cached one.
                try:
                    related_field = obj._meta.get_field(relationship)
                except models.FieldDoesNotExist:
                    # Not a local field (e.g. a reverse relation); nothing to clear.
                    pass
                else:
                    if isinstance(related_field, models.ForeignKey):
                        cache_name = related_field.get_cache_name()
                        if hasattr(obj, cache_name):
                            delattr(obj, cache_name)
                # Get the referenced obj(s).
                related = getattr(obj, relationship, None)
                if isinstance(related, models.Model):
                    _follow_relationships(related)
                elif isinstance(related, (models.Manager, QuerySet)):
                    for related_obj in related.all():
                        _follow_relationships(related_obj)
                elif related is not None:
                    raise TypeError("Cannot follow the relationship %r. Expected a model or QuerySet, found %r." % (relationship, related))
            # If a proxy model's parent is registered, add it.
            if obj._meta.proxy:
                parent_cls = obj._meta.parents.keys()[0]
                if self.is_registered(parent_cls):
                    # Use the default manager: a parent class that declares only
                    # custom managers has no `objects` attribute.
                    parent_obj = parent_cls._default_manager.get(pk=obj.pk)
                    _follow_relationships(parent_obj)

        # A plain loop rather than map(): the calls are made only for their
        # side effect on result_set.
        for obj in object_set:
            _follow_relationships(obj)
        return result_set

    def end(self):
        """
        Ends a revision.

        Only when the outermost revision ends (depth returns to zero) is
        anything saved. If objects were added and the revision was not
        invalidated, the following happens:
        - a Revision is created;
        - every added object, plus everything reachable through registered
          `follow` relationships, is serialized into a Version (proxy model
          instances are skipped);
        - each queued meta class is created against the Revision.

        The thread's revision state is cleared afterwards, even if saving
        fails. Raises RevisionManagementError if no revision is active.
        """
        self.assert_active()
        self._state.depth -= 1
        # Handle end of revision conditions here.
        if self._state.depth == 0:
            # Not named `models`, to avoid shadowing the django.db.models import.
            added_objects = self._state.objects
            try:
                if added_objects and not self.is_invalid():
                    # Save a new revision.
                    revision = Revision.objects.create(user=self._state.user,
                                                       comment=self._state.comment)
                    # Follow relationships.
                    revision_set = self.follow_relationships(added_objects)
                    # Because we might have uncommitted data in models, we need to
                    # replace the models in revision_set which might have come from
                    # the db, with the actual models sent to reversion.
                    diff = revision_set.difference(added_objects)
                    revision_set = added_objects.union(diff)
                    # Save version models.
                    for obj in revision_set:
                        # Proxy models should not actually be saved to the revision set.
                        if obj._meta.proxy:
                            continue
                        registration_info = self.get_registration_info(obj.__class__)
                        object_id = unicode(obj.pk)
                        content_type = ContentType.objects.get_for_model(obj)
                        serialized_data = serializers.serialize(registration_info.format, [obj], fields=registration_info.fields)
                        Version.objects.create(revision=revision,
                                               object_id=object_id,
                                               content_type=content_type,
                                               format=registration_info.format,
                                               serialized_data=serialized_data,
                                               object_repr=unicode(obj))
                    for cls, kwargs in self._state.meta:
                        cls._default_manager.create(revision=revision, **kwargs)
            finally:
                self._state.clear()

    # Signal receivers.

    def post_save_receiver(self, instance, sender, **kwargs):
        """Adds registered models to the current revision, if any."""
        if self.is_active():
            self.add(instance)

    # High-level revision management methods.

    def __enter__(self):
        """Enters a block of revision management."""
        self.start()

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Leaves a block of revision management.

        If the block raised, the revision is invalidated before being ended.
        Returns False so that any exception propagates.
        """
        if exc_type is not None:
            self.invalidate()
        self.end()
        return False

    def create_on_success(self, func):
        """
        Creates a revision when the given function exits successfully.

        Returns a wrapper that runs `func` inside a revision. If `func` raises,
        the revision is invalidated (so nothing is saved) and the exception is
        re-raised. The revision is ended in either case.
        """
        def _create_on_success(*args, **kwargs):
            self.start()
            # Nested try blocks rather than a single try/except/finally,
            # presumably to stay compatible with Python 2.4 (see the fallback
            # imports at the top of this module).
            try:
                try:
                    result = func(*args, **kwargs)
                except:
                    # Any exception invalidates the revision before re-raising.
                    self.invalidate()
                    raise
            finally:
                self.end()
            return result
        return wraps(func)(_create_on_success)
# Module-level shared revision manager. Revision-in-progress data lives in a
# thread-local RevisionState, so each thread tracks its own revision; the
# model registry is shared by all threads.
revision = RevisionManager()
|