# objectstore.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy import util,exceptions,sql
from sqlalchemy.orm import unitofwork,query
from sqlalchemy.orm.mapper import object_mapper
from sqlalchemy.orm.mapper import class_mapper
import weakref
import sqlalchemy
class SessionTransaction(object):
"""represents a Session-level Transaction. This corresponds to one or
more sqlalchemy.engine.Transaction instances behind the scenes, with one
Transaction per Engine in use.
the SessionTransaction object is **not** threadsafe."""
def __init__(self, session, parent=None, autoflush=True):
self.session = session
self.connections = {}
self.parent = parent
self.autoflush = autoflush
def connection(self, mapper_or_class, entity_name=None):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
if self.parent is not None:
return self.parent.connection(mapper_or_class)
engine = self.session.get_bind(mapper_or_class)
return self.get_or_add(engine)
def _begin(self):
return SessionTransaction(self.session, self)
def add(self, connectable):
if self.connections.has_key(connectable.engine):
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
return self.get_or_add(connectable)
def get_or_add(self, connectable):
# we reference the 'engine' attribute on the given object, which in the case of
# Connection, ProxyEngine, Engine, whatever, should return the original
# "Engine" object that is handling the connection.
if self.connections.has_key(connectable.engine):
return self.connections[connectable.engine][0]
e = connectable.engine
c = connectable.contextual_connect()
if not self.connections.has_key(e):
self.connections[e] = (c, c.begin(), c is not connectable)
return self.connections[e][0]
def commit(self):
if self.parent is not None:
return
if self.autoflush:
self.session.flush()
for t in self.connections.values():
t[1].commit()
self.close()
def rollback(self):
if self.parent is not None:
self.parent.rollback()
return
for k, t in self.connections.iteritems():
t[1].rollback()
self.close()
def close(self):
if self.parent is not None:
return
for t in self.connections.values():
if t[2]:
t[0].close()
self.session.transaction = None
class Session(object):
"""encapsulates a set of objects being operated upon within an object-relational operation.
The Session object is **not** threadsafe. For thread-management of Sessions, see the
sqlalchemy.ext.sessioncontext module."""
def __init__(self, bind_to=None, hash_key=None, import_session=None, echo_uow=False):
if import_session is not None:
self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map)
else:
self.uow = unitofwork.UnitOfWork()
self.bind_to = bind_to
self.binds = {}
self.echo_uow = echo_uow
self.transaction = None
if hash_key is None:
self.hash_key = id(self)
else:
self.hash_key = hash_key
_sessions[self.hash_key] = self
def _get_echo_uow(self):
return self.uow.echo
def _set_echo_uow(self, value):
self.uow.echo = value
echo_uow = property(_get_echo_uow,_set_echo_uow)
def create_transaction(self, **kwargs):
"""returns a new SessionTransaction corresponding to an existing or new transaction.
if the transaction is new, the returned SessionTransaction will have commit control
over the underlying transaction, else will have rollback control only."""
if self.transaction is not None:
return self.transaction._begin()
else:
self.transaction = SessionTransaction(self, **kwargs)
return self.transaction
def connect(self, mapper=None, **kwargs):
"""returns a unique connection corresponding to the given mapper. this connection
will not be part of any pre-existing transactional context."""
return self.get_bind(mapper).connect(**kwargs)
def connection(self, mapper, **kwargs):
"""returns a Connection corresponding to the given mapper. used by the execute()
method which performs select operations for Mapper and Query.
if this Session is transactional,
the connection will be in the context of this session's transaction. otherwise, the connection
is returned by the contextual_connect method, which some Engines override to return a thread-local
connection, and will have close_with_result set to True.
the given **kwargs will be sent to the engine's contextual_connect() method, if no transaction is in progress."""
if self.transaction is not None:
return self.transaction.connection(mapper)
else:
return self.get_bind(mapper).contextual_connect(**kwargs)
def execute(self, mapper, clause, params, **kwargs):
"""using the given mapper to identify the appropriate Engine or Connection to be used for statement execution,
executes the given ClauseElement using the provided parameter dictionary. Returns a ResultProxy corresponding
to the execution's results. If this method allocates a new Connection for the operation, then the ResultProxy's close()
method will release the resources of the underlying Connection, otherwise its a no-op.
"""
return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs)
def scalar(self, mapper, clause, params, **kwargs):
"""works like execute() but returns a scalar result."""
return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs)
def close(self):
"""closes this Session.
"""
self.clear()
if self.transaction is not None:
self.transaction.close()
def clear(self):
"""removes all object instances from this Session. this is equivalent to calling expunge() for all
objects in this Session."""
for instance in self:
self._unattach(instance)
echo = self.uow.echo
self.uow = unitofwork.UnitOfWork()
self.uow.echo = echo
def mapper(self, class_, entity_name=None):
"""given an Class, return the primary Mapper responsible for persisting it"""
return _class_mapper(class_, entity_name = entity_name)
def bind_mapper(self, mapper, bindto):
"""bind the given Mapper to the given Engine or Connection.
All subsequent operations involving this Mapper will use the given bindto."""
self.binds[mapper] = bindto
def bind_table(self, table, bindto):
"""bind the given Table to the given Engine or Connection.
All subsequent operations involving this Table will use the given bindto."""
self.binds[table] = bindto
def get_bind(self, mapper):
"""return the Engine or Connection which is used to execute statements on behalf of the given Mapper.
Calling connect() on the return result will always result in a Connection object. This method
disregards any SessionTransaction that may be in progress.
The order of searching is as follows:
if an Engine or Connection was bound to this Mapper specifically within this Session, returns that
Engine or Connection.
if an Engine or Connection was bound to this Mapper's underlying Table within this Session
(i.e. not to the Table directly), returns that Engine or Conneciton.
if an Engine or Connection was bound to this Session, returns that Engine or Connection.
finally, returns the Engine which was bound directly to the Table's MetaData object.
If no Engine is bound to the Table, an exception is raised.
"""
if mapper is None:
return self.bind_to
elif self.binds.has_key(mapper):
return self.binds[mapper]
elif self.binds.has_key(mapper.mapped_table):
return self.binds[mapper.mapped_table]
elif self.bind_to is not None:
return self.bind_to
else:
e = mapper.mapped_table.engine
if e is None:
raise exceptions.InvalidRequestError("Could not locate any Engine bound to mapper '%s'" % str(mapper))
return e
def query(self, mapper_or_class, entity_name=None, **kwargs):
"""return a new Query object corresponding to this Session and the mapper, or the classes' primary mapper."""
if isinstance(mapper_or_class, type):
return query.Query(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
else:
return query.Query(mapper_or_class, self, **kwargs)
def _sql(self):
class SQLProxy(object):
def __getattr__(self, key):
def call(*args, **kwargs):
kwargs[engine] = self.engine
return getattr(sql, key)(*args, **kwargs)
sql = property(_sql)
def flush(self, objects=None):
"""flush all the object modifications present in this session to the database.
'objects' is a list or tuple of objects specifically to be flushed; if None, all
new and modified objects are flushed."""
self.uow.flush(self, objects)
def get(self, class_, ident, **kwargs):
"""return an instance of the object based on the given identifier, or None if not found.
The ident argument is a scalar or tuple of primary key column values in the order of the
table def's primary key columns.
the entity_name keyword argument may also be specified which further qualifies the underlying
Mapper used to perform the query."""
entity_name = kwargs.get('entity_name', None)
return self.query(class_, entity_name=entity_name).get(ident)
def load(self, class_, ident, **kwargs):
"""return an instance of the object based on the given identifier.
If not found, raises an exception. The method will *remove all pending changes* to the object
already existing in the Session. The ident argument is a scalar or tuple of
primary key columns in the order of the table def's primary key columns.
the entity_name keyword argument may also be specified which further qualifies the underlying
Mapper used to perform the query."""
entity_name = kwargs.get('entity_name', None)
return self.query(class_, entity_name=entity_name).load(ident)
def refresh(self, obj):
"""reload the attributes for the given object from the database, clear any changes made."""
self._validate_persistent(obj)
if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj))
def expire(self, obj):
"""mark the given object as expired.
this will add an instrumentation to all mapped attributes on the instance such that when
an attribute is next accessed, the session will reload all attributes on the instance
from the database.
"""
self._validate_persistent(obj)
def exp():
if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj))
attribute_manager.trigger_history(obj, exp)
def is_expired(self, obj, unexpire=False):
"""return True if the given object has been marked as expired."""
ret = attribute_manager.has_trigger(obj)
if ret and unexpire:
attribute_manager.untrigger_history(obj)
return ret
def expunge(self, object):
"""remove the given object from this Session.
this will free all internal references to the object. cascading will be applied according to the
'expunge' cascade rule."""
for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)):
self.uow._remove_deleted(c)
self._unattach(c)
def save(self, object, entity_name=None):
"""
Add a transient (unsaved) instance to this Session.
This operation cascades the "save_or_update" method to associated instances if the
relation is mapped with cascade="save-update".
The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this
instance.
"""
self._save_impl(object, entity_name=entity_name)
_object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def update(self, object, entity_name=None):
"""Bring the given detached (saved) instance into this Session.
If there is a persistent instance with the same identifier already associated
with this Session, an exception is thrown.
This operation cascades the "save_or_update" method to associated instances if the relation is mapped
with cascade="save-update"."""
self._update_impl(object, entity_name=entity_name)
_object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def save_or_update(self, object, entity_name=None):
"""save or update the given object into this Session.
The presence of an '_instance_key' attribute on the instance determines whether to
save() or update() the instance."""
self._save_or_update_impl(object, entity_name=entity_name)
_object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def _save_or_update_impl(self, object, entity_name=None):
key = getattr(object, '_instance_key', None)
if key is None:
self._save_impl(object, entity_name=entity_name)
else:
self._update_impl(object, entity_name=entity_name)
def delete(self, object, entity_name=None):
"""mark the given instance as deleted.
the delete operation occurs upon flush()."""
for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
self.uow.register_deleted(c)
def merge(self, object, entity_name=None):
"""merge the object into a newly loaded or existing instance from this Session.
note: this method is currently not completely implemented."""
instance = None
for obj in [object] + list(_object_mapper(object).cascade_iterator('merge', object)):
key = getattr(obj, '_instance_key', None)
if key is None:
mapper = _object_mapper(object)
ident = mapper.identity(object)
for k in ident:
if k is None:
raise exceptions.InvalidRequestError("Instance '%s' does not have a full set of identity values, and does not represent a saved entity in the database. Use the add() method to add unsaved instances to this Session." % repr(obj))
key = mapper.identity_key(ident)
u = self.uow
if u.identity_map.has_key(key):
# TODO: copy the state of the given object into this one. tricky !
inst = u.identity_map[key]
else:
inst = self.get(object.__class__, key[1])
if obj is object:
instance = inst
return instance
def _save_impl(self, object, **kwargs):
if hasattr(object, '_instance_key'):
if not self.identity_map.has_key(object._instance_key):
raise exceptions.InvalidRequestError("Instance '%s' is a detached instance or is already persistent in a different Session" % repr(object))
else:
m = _class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None))
# this would be a nice exception to raise...however this is incompatible with a contextual
# session which puts all objects into the session upon construction.
#if m._is_orphan(object):
# raise exceptions.InvalidRequestError("Instance '%s' is an orphan, and must be attached to a parent object to be saved" % (repr(object)))
m._assign_entity_name(object)
self._register_pending(object)
def _update_impl(self, object, **kwargs):
if self._is_attached(object) and object not in self.deleted:
return
if not hasattr(object, '_instance_key'):
raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % repr(object))
self._register_persistent(object)
def _register_pending(self, obj):
self._attach(obj)
self.uow.register_new(obj)
def _register_persistent(self, obj):
self._attach(obj)
self.uow.register_clean(obj)
def _register_deleted(self, obj):
self._attach(obj)
self.uow.register_deleted(obj)
def _attach(self, obj):
"""Attach the given object to this Session."""
if getattr(obj, '_sa_session_id', None) != self.hash_key:
old = getattr(obj, '_sa_session_id', None)
if old is not None and _sessions.has_key(old):
raise exceptions.InvalidRequestError("Object '%s' is already attached to session '%s' (this is '%s')" % (repr(obj), old, id(self)))
# auto-removal from the old session is disabled. but if we decide to
# turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
# and it might be affected by other threads
#try:
# sess = _sessions[old]
#except KeyError:
# sess = None
#if sess is not None:
# sess.expunge(old)
key = getattr(obj, '_instance_key', None)
if key is not None:
self.identity_map[key] = obj
obj._sa_session_id = self.hash_key
def _unattach(self, obj):
self._validate_attached(obj)
del obj._sa_session_id
def _validate_attached(self, obj):
"""validate that the given object is either pending or persistent within this Session."""
if not self._is_attached(obj):
raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % repr(obj))
def _validate_persistent(self, obj):
"""validate that the given object is persistent within this Session."""
self.uow._validate_obj(obj)
def _is_attached(self, obj):
return getattr(obj, '_sa_session_id', None) == self.hash_key
def __contains__(self, obj):
return self._is_attached(obj) and (obj in self.uow.new or self.identity_map.has_key(obj._instance_key))
def __iter__(self):
return iter(list(self.uow.new) + self.uow.identity_map.values())
def _get(self, key):
return self.identity_map[key]
def has_key(self, key):
return self.identity_map.has_key(key)
dirty = property(lambda s:s.uow.locate_dirty(), doc="a Set of all objects marked as 'dirty' within this Session")
deleted = property(lambda s:s.uow.deleted, doc="a Set of all objects marked as 'deleted' within this Session")
new = property(lambda s:s.uow.new, doc="a Set of all objects marked as 'new' within this Session.")
identity_map = property(lambda s:s.uow.identity_map, doc="a WeakValueDictionary consisting of all objects within this Session keyed to their _instance_key value.")
def import_instance(self, *args, **kwargs):
"""deprecated; a synynom for merge()"""
return self.merge(*args, **kwargs)
# this is the AttributeManager instance used to provide attribute behavior on objects.
# to all the "global variable police" out there: its a stateless object.
attribute_manager = unitofwork.attribute_manager
# this dictionary maps the hash key of a Session to the Session itself, and
# acts as a Registry with which to locate Sessions. this is to enable
# object instances to be associated with Sessions without having to attach the
# actual Session object directly to the object instance.
_sessions = weakref.WeakValueDictionary()
def object_session(obj):
"""return the Session to which the given object is bound, or None if none."""
hashkey = getattr(obj, '_sa_session_id', None)
if hashkey is not None:
return _sessions.get(hashkey)
return None
unitofwork.object_session = object_session
from sqlalchemy.orm import mapper
mapper.attribute_manager = attribute_manager
|