oracle.py :  » Web-Frameworks » Spyce » spyce-2.1 » sqlalchemy » databases » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Web Frameworks » Spyce 
Spyce » spyce 2.1 » sqlalchemy » databases » oracle.py
# oracle.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php


import sys, StringIO, string, re

import sqlalchemy.util as util
import sqlalchemy.sql as sql
import sqlalchemy.engine as engine
import sqlalchemy.engine.default as default
import sqlalchemy.schema as schema
import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions

try:
    import cx_Oracle
except:
    cx_Oracle = None
        
class OracleNumeric(sqltypes.Numeric):
    def get_col_spec(self):
        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
class OracleInteger(sqltypes.Integer):
    def get_col_spec(self):
        return "INTEGER"
class OracleSmallInteger(sqltypes.Smallinteger):
    def get_col_spec(self):
        return "SMALLINT"
class OracleDateTime(sqltypes.DateTime):
    def get_col_spec(self):
        return "DATE"
# Note:
# Oracle DATE == DATETIME
# Oracle does not allow milliseconds in DATE
# Oracle does not support TIME columns

# only if cx_oracle contains TIMESTAMP
class OracleTimestamp(sqltypes.DateTime):
    def get_col_spec(self):
        return "TIMESTAMP"
    def get_dbapi_type(self, dialect):
        return dialect.TIMESTAMP
        
class OracleText(sqltypes.TEXT):
    def get_col_spec(self):
        return "CLOB"
class OracleString(sqltypes.String):
    def get_col_spec(self):
        return "VARCHAR(%(length)s)" % {'length' : self.length}
class OracleChar(sqltypes.CHAR):
    def get_col_spec(self):
        return "CHAR(%(length)s)" % {'length' : self.length}
class OracleBinary(sqltypes.Binary):
    def get_col_spec(self):
        return "BLOB"
class OracleBoolean(sqltypes.Boolean):
    def get_col_spec(self):
        return "SMALLINT"
    def convert_result_value(self, value, dialect):
        if value is None:
            return None
        return value and True or False
    def convert_bind_param(self, value, dialect):
        if value is True:
            return 1
        elif value is False:
            return 0
        elif value is None:
            return None
        else:
            return value and True or False   

        
colspecs = {
    sqltypes.Integer : OracleInteger,
    sqltypes.Smallinteger : OracleSmallInteger,
    sqltypes.Numeric : OracleNumeric,
    sqltypes.Float : OracleNumeric,
    sqltypes.DateTime : OracleDateTime,
    sqltypes.Date : OracleDateTime,
    sqltypes.String : OracleString,
    sqltypes.Binary : OracleBinary,
    sqltypes.Boolean : OracleBoolean,
    sqltypes.TEXT : OracleText,
    sqltypes.TIMESTAMP : OracleTimestamp,
    sqltypes.CHAR: OracleChar,
}

ischema_names = {
    'VARCHAR2' : OracleString,
    'DATE' : OracleDateTime,
    'DATETIME' : OracleDateTime,
    'NUMBER' : OracleNumeric,
    'BLOB' : OracleBinary,
    'CLOB' : OracleText,
    'TIMESTAMP' : OracleTimestamp
}

constraintSQL = """SELECT
  ac.constraint_name,
  ac.constraint_type,
  LOWER(loc.column_name) AS local_column,
  LOWER(rem.table_name) AS remote_table,
  LOWER(rem.column_name) AS remote_column
FROM all_constraints ac,
  all_cons_columns loc,
  all_cons_columns rem
WHERE ac.table_name = :table_name
AND ac.constraint_type IN ('R','P')
AND ac.owner = :owner
AND ac.owner = loc.owner
AND ac.constraint_name = loc.constraint_name
AND ac.r_owner = rem.owner(+)
AND ac.r_constraint_name = rem.constraint_name(+)
-- order multiple primary keys correctly
ORDER BY ac.constraint_name, loc.position, rem.position"""


def descriptor():
    return {'name':'oracle',
    'description':'Oracle',
    'arguments':[
        ('dsn', 'Data Source Name', None),
        ('user', 'Username', None),
        ('password', 'Password', None)
    ]}

class OracleExecutionContext(default.DefaultExecutionContext):
    def pre_exec(self, engine, proxy, compiled, parameters):
        super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters)
        if self.dialect.auto_setinputsizes:
                self.set_input_sizes(proxy(), parameters)
        
class OracleDialect(ansisql.ANSIDialect):
    def __init__(self, use_ansi=True, auto_setinputsizes=False, module=None, threaded=True, **kwargs):
        self.use_ansi = use_ansi
        self.threaded = threaded
        if module is None:
            self.module = cx_Oracle
        else:
            self.module = module
        self.supports_timestamp = hasattr(self.module, 'TIMESTAMP' )
        self.auto_setinputsizes = auto_setinputsizes
        ansisql.ANSIDialect.__init__(self, **kwargs)

    def dbapi(self):
        return self.module

    def create_connect_args(self, url):
        if url.database:
            # if we have a database, then we have a remote host
            port = url.port
            if port:
                port = int(port)
            else:
                port = 1521
            dsn = self.module.makedsn(url.host,port,url.database)
        else:
            # we have a local tnsname
            dsn = url.host
        opts = dict(
            user=url.username,
            password=url.password,
            dsn = dsn,
            threaded = self.threaded
            )
        opts.update(url.query)
        return ([], opts)
        
    def type_descriptor(self, typeobj):
        return sqltypes.adapt_type(typeobj, colspecs)

    def oid_column_name(self):
        return "rowid"

    def create_execution_context(self):
        return OracleExecutionContext(self)

    def compiler(self, statement, bindparams, **kwargs):
        return OracleCompiler(self, statement, bindparams, **kwargs)
    def schemagenerator(self, *args, **kwargs):
        return OracleSchemaGenerator(*args, **kwargs)
    def schemadropper(self, *args, **kwargs):
        return OracleSchemaDropper(*args, **kwargs)
    def defaultrunner(self, engine, proxy):
        return OracleDefaultRunner(engine, proxy)


    def has_table(self, connection, table_name):
        cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
        return bool( cursor.fetchone() is not None )

    def has_sequence(self, connection, sequence_name):
        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name""", {'name':sequence_name.upper()})
        return bool( cursor.fetchone() is not None )
        
    def reflecttable(self, connection, table):
        preparer = self.identifier_preparer
        if not preparer.should_quote(table):
            name = table.name.upper()
        else:
            name = table.name
        c = connection.execute ("select distinct OWNER from ALL_TAB_COLUMNS where TABLE_NAME = :table_name", {'table_name':name})
        rows = c.fetchall()
        if not rows :
            raise exceptions.NoSuchTableError(table.name)
        else:
            if table.owner is not None:
                if table.owner.upper() in [r[0] for r in rows]:
                    owner = table.owner.upper()
                else:
                    raise exceptions.AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
            else:
                if len(rows)==1:
                    owner = rows[0][0]
                else:
                    raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)

        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS where TABLE_NAME = :table_name and OWNER = :owner", {'table_name':name, 'owner':owner})
        
        while True:
            row = c.fetchone()
            if row is None:
                break
            found_table = True

            #print "ROW:" , row
            (name, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])

            # INTEGER if the scale is 0 and precision is null
            # NUMBER if the scale and precision are both null
            # NUMBER(9,2) if the precision is 9 and the scale is 2
            # NUMBER(3) if the precision is 3 and scale is 0
            #length is ignored except for CHAR and VARCHAR2
            if coltype=='NUMBER' :
                if precision is None and scale is None:
                    coltype = OracleNumeric
                elif precision is None and scale == 0  :
                    coltype = OracleInteger
                else :
                    coltype = OracleNumeric(precision, scale)
            elif coltype=='CHAR' or coltype=='VARCHAR2':
                coltype = ischema_names.get(coltype, OracleString)(length)
            else:
                coltype = re.sub(r'\(\d+\)', '', coltype)
                try:
                    coltype = ischema_names[coltype]
                except KeyError:
                    raise exceptions.AssertionError("Cant get coltype for type '%s'" % coltype)
               
            colargs = []
            if default is not None:
                colargs.append(schema.PassiveDefault(sql.text(default)))
          
            # if name comes back as all upper, assume its case folded 
            if (name.upper() == name): 
                name = name.lower()
            
            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))

       
        c = connection.execute(constraintSQL, {'table_name' : table.name.upper(), 'owner' : owner})
        fks = {}
        while True:
            row = c.fetchone()
            if row is None:
                break
            #print "ROW:" , row                
            (cons_name, cons_type, local_column, remote_table, remote_column) = row
            if cons_type == 'P':
                table.primary_key.add(table.c[local_column])
            elif cons_type == 'R':
                try:
                    fk = fks[cons_name]
                except KeyError:
                   fk = ([], [])
                   fks[cons_name] = fk
                refspec = ".".join([remote_table, remote_column])
                schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
                if local_column not in fk[0]:
                    fk[0].append(local_column)
                if refspec not in fk[1]:
                    fk[1].append(refspec)

        for name, value in fks.iteritems():
            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))

    def do_executemany(self, c, statement, parameters, context=None):
        rowcount = 0
        for param in parameters:
            c.execute(statement, param)
            rowcount += c.rowcount
        if context is not None:
            context._rowcount = rowcount

class OracleCompiler(ansisql.ANSICompiler):
    """oracle compiler modifies the lexical structure of Select statements to work under 
    non-ANSI configured Oracle databases, if the use_ansi flag is False."""
    
    def default_from(self):
        """called when a SELECT statement has no froms, and no FROM clause is to be appended.  
        gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
        return " FROM DUAL"

    def apply_function_parens(self, func):
        return len(func.clauses) > 0

    def visit_join(self, join):
        if self.dialect.use_ansi:
            return ansisql.ANSICompiler.visit_join(self, join)
        
        self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
        self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
        self.strings[join] = self.froms[join]

        if join.isouter:
            # if outer join, push on the right side table as the current "outertable"
            self._outertable = join.right

            # now re-visit the onclause, which will be used as a where clause
            # (the first visit occured via the Join object itself right before it called visit_join())
            join.onclause.accept_visitor(self)

            self._outertable = None

        self.visit_compound(self.wheres[join])
       
    def visit_alias(self, alias):
  """oracle doesnt like 'FROM table AS alias'.  is the AS standard SQL??"""
        self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
        self.strings[alias] = self.get_str(alias.original)
 
    def visit_column(self, column):
        ansisql.ANSICompiler.visit_column(self, column)
        if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
            self.strings[column] = self.strings[column] + "(+)"
       
    def visit_insert(self, insert):
        """inserts are required to have the primary keys be explicitly present.
         mapper will by default not put them in the insert statement to comply
         with autoincrement fields that require they not be present.  so, 
         put them all in for all primary key columns."""
        for c in insert.table.primary_key:
            if not self.parameters.has_key(c.key):
                self.parameters[c.key] = None
        return ansisql.ANSICompiler.visit_insert(self, insert)

    def _TODO_visit_compound_select(self, select):
        """need to determine how to get LIMIT/OFFSET into a UNION for oracle"""
        if getattr(select, '_oracle_visit', False):
            # cancel out the compiled order_by on the select
            if hasattr(select, "order_by_clause"):
                self.strings[select.order_by_clause] = ""
            ansisql.ANSICompiler.visit_compound_select(self, select)
            return
            
        if select.limit is not None or select.offset is not None:
            select._oracle_visit = True
            # to use ROW_NUMBER(), an ORDER BY is required. 
            orderby = self.strings[select.order_by_clause]
            if not orderby:
                orderby = select.oid_column
                orderby.accept_visitor(self)
                orderby = self.strings[orderby]
            class SelectVisitor(sql.ClauseVisitor):
                def visit_select(self, select):
                    select.append_column(sql.column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
            select.accept_visitor(SelectVisitor())
            limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
            if select.offset is not None:
                limitselect.append_whereclause("ora_rn>%d" % select.offset)
                if select.limit is not None:
                    limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
            else:
                limitselect.append_whereclause("ora_rn<=%d" % select.limit)
            limitselect.accept_visitor(self)
            self.strings[select] = self.strings[limitselect]
            self.froms[select] = self.froms[limitselect]
        else:
            ansisql.ANSICompiler.visit_compound_select(self, select)
        
    def visit_select(self, select):
        """looks for LIMIT and OFFSET in a select statement, and if so tries to wrap it in a 
        subquery with row_number() criterion."""
        # TODO: put a real copy-container on Select and copy, or somehow make this
        # not modify the Select statement
        if getattr(select, '_oracle_visit', False):
            # cancel out the compiled order_by on the select
            if hasattr(select, "order_by_clause"):
                self.strings[select.order_by_clause] = ""
            ansisql.ANSICompiler.visit_select(self, select)
            return

        if select.limit is not None or select.offset is not None:
            select._oracle_visit = True
            # to use ROW_NUMBER(), an ORDER BY is required. 
            orderby = self.strings[select.order_by_clause]
            if not orderby:
                orderby = select.oid_column
                orderby.accept_visitor(self)
                orderby = self.strings[orderby]
            select.append_column(sql.column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
            limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
            if select.offset is not None:
                limitselect.append_whereclause("ora_rn>%d" % select.offset)
                if select.limit is not None:
                    limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
            else:
                limitselect.append_whereclause("ora_rn<=%d" % select.limit)
            limitselect.accept_visitor(self)
            self.strings[select] = self.strings[limitselect]
            self.froms[select] = self.froms[limitselect]
        else:
            ansisql.ANSICompiler.visit_select(self, select)
            
    def limit_clause(self, select):
        return ""

    def for_update_clause(self, select):
        if select.for_update=="nowait":
            return " FOR UPDATE NOWAIT"
        else:
            return super(OracleCompiler, self).for_update_clause(select)

class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
    def get_column_specification(self, column, **kwargs):
        colspec = self.preparer.format_column(column)
        colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def visit_sequence(self, sequence):
        if not self.engine.dialect.has_sequence(self.connection, sequence.name):
            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
            self.execute()

class OracleSchemaDropper(ansisql.ANSISchemaDropper):
    def visit_sequence(self, sequence):
        if self.engine.dialect.has_sequence(self.connection, sequence.name):
            self.append("DROP SEQUENCE %s" % sequence.name)
            self.execute()

class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
    def exec_default_sql(self, default):
        c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile()
        return self.proxy(str(c), c.get_params()).fetchone()[0]
    
    def visit_sequence(self, seq):
        return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0]

dialect = OracleDialect
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.