assertsql.py :  » Database » SQLAlchemy » SQLAlchemy-0.6.0 » lib » sqlalchemy » test » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Database » SQLAlchemy 
SQLAlchemy » SQLAlchemy 0.6.0 » lib » sqlalchemy » test » assertsql.py

from sqlalchemy.interfaces import ConnectionProxy
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.base import Connection
from sqlalchemy import util
import re

class AssertRule(object):
    """Base class for rules that check statements executed during a test.

    Subclasses override the ``process_*`` hooks to inspect executions and
    implement ``is_consumed()`` / ``rule_passed()`` to report the outcome.

    """

    # Outcome of the most recent check: True/False, or None while no check
    # has been applied.  Declared on the base class so consume_final() also
    # works for subclasses that never assign it.
    _result = None

    def process_execute(self, clauseelement, *multiparams, **params):
        """Receive the arguments of a statement-level execute; no-op here."""
        pass

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Receive the arguments of a cursor-level execute; no-op here."""
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has definitely failed.

        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""

        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not been consumed or has failed.

        """

        if self._result is None:
            # Raised explicitly (rather than ``assert False``) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError("Rule has not been consumed")

        return self.is_consumed()

class SQLMatchRule(AssertRule):
    """Rule that records its outcome as a flag plus a failure message.

    ``_result`` stays None until a check has been applied; afterwards it
    holds the truth value of the last check, and ``_errmsg`` describes the
    mismatch when that check failed.

    """

    def __init__(self):
        self._errmsg = ""
        self._result = None

    def rule_passed(self):
        """Return the stored outcome (None when no check has been applied)."""
        return self._result

    def is_consumed(self):
        """Return False if no check ran yet, True if the last check passed.

        A failed check raises AssertionError carrying the stored message.

        """
        outcome = self._result
        if outcome is not None:
            assert outcome, self._errmsg
            return True

        return False
    
class ExactSQL(SQLMatchRule):
    """Rule matching the executed statement string, and optionally params, exactly.

    ``params`` may be a dict, a list of dicts, or a callable that receives
    the execution context and returns either of those.

    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.params = params
        self.sql = sql

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Compare the context's statement and parameters to the expectation.

        Does nothing when ``context`` is falsy; otherwise stores the outcome
        in ``_result`` and, on mismatch, a description in ``_errmsg``.

        """
        if not context:
            return

        received_sql = _process_engine_statement(context.unicode_statement, context)
        received_params = context.compiled_parameters

        # TODO: remove this step once all unit tests
        # are migrated, as ExactSQL should really be *exact* SQL
        expected_sql = _process_assertion_statement(self.sql, context)

        matched = received_sql == expected_sql
        if not self.params:
            # nothing to compare; only used for the failure message
            params = {}
        else:
            params = self.params
            if util.callable(params):
                params = params(context)

            # a single parameter set is compared as a one-element list
            if not isinstance(params, list):
                params = [params]
            matched = matched and params == context.compiled_parameters

        self._result = matched
        if not matched:
            self._errmsg = "Testing for exact statement %r exact params %r, " \
                "received %r with params %r" % (expected_sql, params, received_sql, received_params)
    

class RegexSQL(SQLMatchRule):
    """Rule matching the executed statement against a regular expression.

    When ``params`` is given (dict, list of dicts, or a callable taking the
    execution context), every expected key/value must be present in the
    corresponding received parameter set; extra received keys are ignored.

    """

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.orig_regex = regex
        self.regex = re.compile(regex)
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Check the context's statement and parameters against the rule.

        Does nothing when ``context`` is falsy; otherwise stores the outcome
        in ``_result`` and, on mismatch, a description in ``_errmsg``.

        """
        if not context:
            return

        received_sql = _process_engine_statement(context.unicode_statement, context)
        received_params = context.compiled_parameters

        # the pattern is anchored at the start of the statement (re.match)
        matched = self.regex.match(received_sql) is not None

        if not self.params:
            # nothing to compare; only used for the failure message
            params = {}
        else:
            params = self.params
            if util.callable(params):
                params = params(context)

            if not isinstance(params, list):
                params = [params]

            # do a positive compare only
            for expected, received in zip(params, received_params):
                for key, value in expected.iteritems():
                    mismatch = key not in received or received[key] != value
                    if mismatch:
                        matched = False
                        break

        self._result = matched
        if not matched:
            self._errmsg = "Testing for regex %r partial params %r, "\
                "received %r with params %r" % (self.orig_regex, params, received_sql, received_params)

class CompiledSQL(SQLMatchRule):
    """Rule matching a statement as recompiled with the default dialect.

    ``statement`` is the expected SQL string.  ``params`` is optional and
    may be a dict, a list of dicts, or a callable that receives the
    execution context and returns either of those.  Each expected parameter
    set is filled in with the compiled statement's own parameters before
    being looked up among the received sets, and every received set must be
    accounted for.

    """

    def __init__(self, statement, params=None):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Compare the context's compiled statement and parameters to the rule.

        Does nothing when ``context`` is falsy; otherwise stores the outcome
        in ``_result`` and, on mismatch, a description in ``_errmsg``.

        """
        if not context:
            return

        # Work on a copy: matched entries are removed below, and the
        # context's own list must stay intact for any other rule that
        # inspects the same execution.
        remaining = list(context.compiled_parameters)
        all_received = list(remaining)

        # recompile from the context, using the default dialect
        compiled = context.compiled.statement.\
                compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)

        _received_statement = re.sub(r'\n', '', str(compiled))

        equivalent = self.statement == _received_statement

        # Bound up front so the failure message can always be built, even
        # when no expected parameters were supplied.
        all_params = {}
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params

            if not isinstance(params, list):
                params = [params]

            # Iterate over a copy so the caller's list is never emptied.
            all_params = list(params)
            for expected in all_params:
                param = dict(expected)
                for k, v in context.compiled.params.iteritems():
                    param.setdefault(k, v)

                if param not in remaining:
                    equivalent = False
                    break
                remaining.remove(param)

            # any received parameter set left unmatched is a failure
            if remaining:
                equivalent = False

        self._result = equivalent
        if not self._result:
            self._errmsg = "Testing for compiled statement %r partial params %r, " \
                    "received %r with params %r" % \
                    (self.statement, all_params, _received_statement, all_received)
    
        
class CountStatements(AssertRule):
    """Rule asserting that an exact number of statements was executed.

    The rule never reports itself consumed while statements are running;
    the count is only verified by ``consume_final()``.

    """

    def __init__(self, count):
        self._statement_count = 0
        self.count = count

    def process_execute(self, clauseelement, *multiparams, **params):
        """Tally one more statement-level execute."""
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Cursor-level executes are not counted."""
        pass

    def is_consumed(self):
        """Always False: the rule stays pending until the final check."""
        return False

    def consume_final(self):
        """Assert that the tally equals the desired count, then return True."""
        wanted, seen = self.count, self._statement_count
        assert wanted == seen, "desired statement count %d does not match %d" % (wanted, seen)
        return True
        
class AllOf(AssertRule):
    """Composite rule satisfied once every member rule has passed, in any order.

    Each execution is forwarded to all pending members; one passing member
    is retired per ``is_consumed()`` call.

    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        """Forward a statement-level execute to every pending member."""
        for member in self.rules:
            member.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context, executemany):
        """Forward a cursor-level execute to every pending member."""
        for member in self.rules:
            member.process_cursor_execute(statement, parameters, context, executemany)

    def is_consumed(self):
        """Retire one passing member; return True when none remain.

        Raises AssertionError if members are pending but none of them passed.

        """
        if not self.rules:
            return True

        # iterate over a snapshot, since the set is modified on a match
        for member in tuple(self.rules):
            if not member.rule_passed():
                continue
            # a rule passed, move on
            self.rules.remove(member)
            return len(self.rules) == 0

        assert False, "No assertion rules were satisfied for statement"

    def consume_final(self):
        """Return True only if no member rules are still pending."""
        return len(self.rules) == 0
        
def _process_engine_statement(query, context):
    """Normalize a statement received from the engine for comparison.

    Strips the trailing ``; select scope_identity()`` when the context's
    engine is named 'mssql', and removes all newlines.

    """
    if util.jython:
        # oracle+zxjdbc passes a PyStatement when returning into
        query = unicode(query)

    identity_suffix = '; select scope_identity()'
    if context.engine.name == 'mssql' and query.endswith(identity_suffix):
        query = query[:-len(identity_suffix)]

    return re.sub(r'\n', '', query)
    
def _process_assertion_statement(query, context):
    """Rewrite ``:name`` placeholders in an expected statement for the dialect.

    The target style is read from ``context.dialect.paramstyle``:

    * ``named``    - returned unchanged
    * ``pyformat`` - ``:name`` becomes ``%(name)s``
    * ``qmark``    - ``:name`` becomes ``?``
    * ``format``   - ``:name`` becomes ``%s``
    * ``numeric``  - each placeholder becomes ``:1``, ``:2``, ... in order
      of appearance

    Raises ValueError for any other paramstyle.

    """
    paramstyle = context.dialect.paramstyle
    if paramstyle == 'named':
        return query

    placeholder = r':([\w_]+)'
    if paramstyle == 'pyformat':
        return re.sub(placeholder, r"%(\1)s", query)

    # positional params
    if paramstyle == 'qmark':
        return re.sub(placeholder, "?", query)
    if paramstyle == 'format':
        return re.sub(placeholder, "%s", query)
    if paramstyle == 'numeric':
        # One-element list used as a mutable counter so the closure can
        # update it without ``nonlocal``.
        position = [0]

        def _numbered(match):
            position[0] += 1
            return ":%d" % position[0]

        return re.sub(placeholder, _numbered, query)

    # Previously an unknown style reached re.sub() with a None replacement
    # and failed with an unhelpful TypeError.
    raise ValueError("Unsupported paramstyle %r" % (paramstyle,))

class SQLAssert(ConnectionProxy):
    """Connection proxy that feeds executions to an ordered list of rules.

    While ``rules`` is None (the class-level default) executions pass
    through unchecked.  After ``add_rules()``, each execution is handed to
    the first pending rule, which is dropped once it reports itself
    consumed.

    """

    rules = None

    def add_rules(self, rules):
        """Install ``rules`` (any iterable) as the pending rule list."""
        self.rules = list(rules)

    def statement_complete(self):
        """Give every pending rule its final check.

        Raises AssertionError if any rule is still not satisfied.

        """
        for rule in self.rules:
            if not rule.consume_final():
                raise AssertionError(
                    "All statements are complete, but pending assertion rules remain")

    def clear_rules(self):
        """Drop the installed rules, reverting to the class default of None.

        Safe to call when no rules are installed: removing the instance
        attribute with ``pop`` avoids the AttributeError that ``del`` raises
        in that case.

        """
        self.__dict__.pop('rules', None)

    def execute(self, conn, execute, clauseelement, *multiparams, **params):
        """Run the statement via ``execute``, then apply the first pending rule.

        Raises AssertionError if rules were installed but have all been
        consumed already.

        """
        # Run the statement first; cursor_execute() may be invoked during
        # this call and feeds the same rule.
        result = execute(clauseelement, *multiparams, **params)

        if self.rules is not None:
            if not self.rules:
                raise AssertionError(
                    "All rules have been exhausted, but further statements remain")
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            if rule.is_consumed():
                self.rules.pop(0)

        return result

    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
        """Run the cursor execution, then show it to the first pending rule."""
        result = execute(cursor, statement, parameters, context)

        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context, executemany)

        return result

# Module-level SQLAssert instance.
# NOTE(review): presumably installed by the test harness as the engine's
# connection proxy -- confirm against callers.
asserter = SQLAssert()
    
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.