# maxdb.py
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Support for the MaxDB database.

This dialect is *not* ported to SQLAlchemy 0.6.

This dialect is *not* tested on SQLAlchemy 0.6.

Overview
--------

The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007
and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM.
The earlier version has severe ``LEFT JOIN`` limitations and will return
incorrect results from even very simple ORM queries.

Only the native Python DB-API is currently supported. ODBC driver support
is a future enhancement.

Connecting
----------

The username is case-sensitive. If you usually connect to the
database with sqlcli and other tools in lower case, you likely need to
use upper case for DB-API.

Implementation Notes
--------------------

Also check the DatabaseNotes page on the wiki for detailed information.

With the 7.6.00.037 driver and Python 2.5, it seems that all DB-API
generated exceptions are broken and can cause Python to crash.

For 'somecol.in_([])' to work, the IN operator's generation must be changed
to cast 'NULL' to a numeric, i.e. NUM(NULL). The DB-API doesn't accept a
bind parameter there, so that particular generation must inline the NULL value,
which depends on [ticket:807].

The DB-API is very picky about where bind params may be used in queries.

Bind params for some functions (e.g. MOD) need type information supplied.
The dialect does not yet do this automatically.

MaxDB will occasionally throw up 'bad sql, compile again' exceptions for
perfectly valid SQL. The dialect does not currently handle these; more
research is needed.

MaxDB 7.5 and SAP DB <= 7.4 reportedly do not support schemas. A very
slightly different version of this dialect would be required to support
those versions, and can easily be added if there is demand. Some other
required components such as a MaxDB-aware 'old oracle style' join compiler
(thetas with (+) outer indicators) are already done and available for
integration; email the devel list if you're interested in working on
this.
"""
import datetime, itertools, re

from sqlalchemy import exc,schema,sql,util,processors
from sqlalchemy import types
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import base
from sqlalchemy.engine import base as engine_base, default, reflection
from sqlalchemy.sql import compiler,visitors
from sqlalchemy.sql import expression as sql_expr
from sqlalchemy.sql import operators
class _StringType(sqltypes.String):
    """Base class for MaxDB string types carrying a MaxDB code type.

    ``encoding`` holds the column's code type (``'unicode'`` disables bind
    encoding). ``_type`` is the DDL keyword; concrete subclasses override it.
    """

    _type = None

    def __init__(self, length=None, encoding=None, **kw):
        super(_StringType, self).__init__(length=length, **kw)
        self.encoding = encoding

    def bind_processor(self, dialect):
        """Return a processor that encodes ``unicode`` binds.

        Values are encoded with ``dialect.encoding``. Columns whose encoding
        is ``'unicode'`` take values unchanged, so no processor is returned.
        """
        if self.encoding == 'unicode':
            return None

        def process(value):
            if not isinstance(value, unicode):
                return value
            return value.encode(dialect.encoding)
        return process

    def result_processor(self, dialect, coltype):
        """Return a processor normalizing string and LONG result values.

        Stream-like values (objects with ``read``) are drained first. Byte
        strings are decoded with ``dialect.encoding`` when unicode conversion
        is enabled on the type or the dialect. Anything else passes through.
        """
        # XXX: this code is probably very slow and one should try (if at all
        # possible) to determine the correct code path on a per-connection
        # basis (ie, here in result_processor, instead of inside the processor
        # function itself) and probably also use a few generic
        # processors, or possibly per query (though there is no mechanism
        # for that yet).
        def process(value):
            # Some sort of LONG: snarf the whole stream, then re-examine.
            while (value is not None and
                   not isinstance(value, (str, unicode)) and
                   hasattr(value, 'read')):
                value = value.read(value.remainingLength())
            if isinstance(value, str) and (self.convert_unicode or
                                           dialect.convert_unicode):
                return value.decode(dialect.encoding)
            # None, unicode, unconverted str, or an unexpected type: as-is.
            return value
        return process
class MaxString(_StringType):
    """MaxDB ``VARCHAR`` string.

    The constructor is inherited unchanged from ``_StringType``
    (``length``, ``encoding`` and String keywords).
    """

    _type = 'VARCHAR'
class MaxUnicode(_StringType):
    """MaxDB ``VARCHAR`` declared with the ``UNICODE`` code type."""

    _type = 'VARCHAR'

    def __init__(self, length=None, **kw):
        # The encoding is fixed for this type. Every other keyword, e.g. the
        # ``convert_unicode`` passed along when a generic Unicode type is
        # adapted, is forwarded rather than silently discarded.
        kw['encoding'] = 'unicode'
        super(MaxUnicode, self).__init__(length=length, **kw)
class MaxChar(_StringType):
    """MaxDB fixed-length ``CHAR`` string."""

    _type = 'CHAR'
class MaxText(_StringType):
    """MaxDB ``LONG`` column, optionally qualified by a code type."""

    _type = 'LONG'

    def get_col_spec(self):
        """Return ``LONG`` followed by the encoding, or by ``UNICODE`` when
        only ``convert_unicode`` is set."""
        if self.encoding is not None:
            qualifier = self.encoding
        elif self.convert_unicode:
            qualifier = 'UNICODE'
        else:
            return 'LONG'
        return ' '.join(('LONG', qualifier))
class MaxNumeric(sqltypes.Numeric):
    """The FIXED (also NUMERIC, DECIMAL) data type.

    ``asdecimal`` defaults to True unless given explicitly.
    """

    def __init__(self, precision=None, scale=None, **kw):
        if 'asdecimal' not in kw:
            kw['asdecimal'] = True
        super(MaxNumeric, self).__init__(scale=scale, precision=precision,
                                         **kw)

    def bind_processor(self, dialect):
        # No bind conversion: values go to the DB-API untouched.
        return None
class MaxTimestamp(sqltypes.DateTime):
    """MaxDB ``TIMESTAMP`` <-> ``datetime.datetime``.

    The string layout depends on ``dialect.datetimeformat``:
    ``'internal'`` is ``YYYYMMDDHHMMSS`` + 6 microsecond digits and
    ``'iso'`` is ``YYYY-MM-DD HH:MM:SS.`` + 6 microsecond digits.
    """

    def bind_processor(self, dialect):
        def process(value):
            # None and pre-formatted strings pass straight through.
            if value is None or isinstance(value, basestring):
                return value
            fmt = dialect.datetimeformat
            if fmt == 'internal':
                pattern = "%Y%m%d%H%M%S"
            elif fmt == 'iso':
                pattern = "%Y-%m-%d %H:%M:%S."
            else:
                raise exc.InvalidRequestError(
                    "datetimeformat '%s' is not supported." % (
                    dialect.datetimeformat,))
            micro = getattr(value, 'microsecond', 0)
            return value.strftime(pattern + ("%06u" % micro))
        return process

    def result_processor(self, dialect, coltype):
        # (start, stop) slices for year, month, day, hour, minute, second
        # and the trailing microseconds.
        fmt = dialect.datetimeformat
        if fmt == 'internal':
            fields = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12), (12, 14),
                      (14, None))
        elif fmt == 'iso':
            fields = ((0, 4), (5, 7), (8, 10), (11, 13), (14, 16), (17, 19),
                      (20, None))
        else:
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." %
                dialect.datetimeformat)

        def process(value):
            if value is None:
                return None
            return datetime.datetime(
                *[int(value[start:stop]) for start, stop in fields])
        return process
class MaxDate(sqltypes.Date):
    """MaxDB ``DATE`` <-> ``datetime.date``.

    ``'internal'`` format is ``YYYYMMDD``; ``'iso'`` is ``YYYY-MM-DD``.
    """

    def bind_processor(self, dialect):
        def process(value):
            # None and pre-formatted strings pass straight through.
            if value is None or isinstance(value, basestring):
                return value
            fmt = dialect.datetimeformat
            if fmt == 'internal':
                return value.strftime("%Y%m%d")
            if fmt == 'iso':
                return value.strftime("%Y-%m-%d")
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." % (
                dialect.datetimeformat,))
        return process

    def result_processor(self, dialect, coltype):
        fmt = dialect.datetimeformat
        if fmt == 'internal':
            month_part, day_part = slice(4, 6), slice(6, 8)
        elif fmt == 'iso':
            month_part, day_part = slice(5, 7), slice(8, 10)
        else:
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." %
                dialect.datetimeformat)

        def process(value):
            if value is None:
                return None
            return datetime.date(int(value[0:4]), int(value[month_part]),
                                 int(value[day_part]))
        return process
class MaxTime(sqltypes.Time):
    """MaxDB ``TIME`` <-> ``datetime.time``.

    Binds are formatted per ``dialect.datetimeformat``: ``HHMMSS`` for
    ``'internal'`` and ``HH:MM:SS`` for ``'iso'``. Results are parsed so
    that the hour field may be two or four digits wide (the previous
    slicing assumed four-digit hours). Four-digit hours are accepted for
    backward compatibility.
    """

    def bind_processor(self, dialect):
        def process(value):
            if value is None:
                return None
            elif isinstance(value, basestring):
                return value
            elif dialect.datetimeformat == 'internal':
                # NOTE(review): MaxDB's INTERNAL time layout may be HHHHMMSS;
                # confirm the server accepts this 6-digit form.
                return value.strftime("%H%M%S")
            elif dialect.datetimeformat == 'iso':
                # ISO times are colon-separated (previously emitted with '-').
                return value.strftime("%H:%M:%S")
            else:
                raise exc.InvalidRequestError(
                    "datetimeformat '%s' is not supported." % (
                    dialect.datetimeformat,))
        return process

    def result_processor(self, dialect, coltype):
        if dialect.datetimeformat == 'internal':
            def process(value):
                if value is None:
                    return None
                # Minutes and seconds are always the trailing four digits;
                # everything before them is the hour (2 or 4 digits).
                return datetime.time(int(value[:-4]), int(value[-4:-2]),
                                     int(value[-2:]))
        elif dialect.datetimeformat == 'iso':
            def process(value):
                if value is None:
                    return None
                # The old code sliced [0:4]/[5:7]/[8:10] (copied from
                # MaxDate), which breaks on 'HH:MM:SS'. Splitting on ':'
                # handles both 'HH:MM:SS' and 'HHHH:MM:SS'.
                hours, minutes, seconds = value.split(':')
                return datetime.time(int(hours), int(minutes), int(seconds))
        else:
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." %
                dialect.datetimeformat)
        return process
class MaxBlob(sqltypes.LargeBinary):
    """MaxDB binary LONG column."""

    def bind_processor(self, dialect):
        # Binds are coerced with the stock ``processors.to_str``.
        return processors.to_str

    def result_processor(self, dialect, coltype):
        def process(value):
            # Non-None results are stream objects; read their full contents.
            if value is not None:
                return value.read(value.remainingLength())
            return None
        return process
class MaxDBTypeCompiler(compiler.GenericTypeCompiler):
    """Render SQLAlchemy types as MaxDB DDL type specifications."""

    def _string_spec(self, string_spec, type_):
        """Return ``<string_spec>(<length>)``, or ``LONG`` when unbounded.

        If the type has a truthy ``encoding`` it is appended upper-cased.
        Generic String/CHAR types have no ``encoding`` attribute at all, so
        it is looked up with a None default (previously AttributeError).
        """
        if type_.length is None:
            spec = 'LONG'
        else:
            spec = '%s(%s)' % (string_spec, type_.length)
        encoding = getattr(type_, 'encoding', None)
        if encoding:
            spec = ' '.join([spec, encoding.upper()])
        return spec

    def visit_text(self, type_):
        """Render ``LONG`` plus the encoding, or ``UNICODE`` if converting."""
        spec = 'LONG'
        if getattr(type_, 'encoding', None):
            spec = ' '.join((spec, type_.encoding))
        elif type_.convert_unicode:
            spec = ' '.join((spec, 'UNICODE'))
        return spec

    def visit_char(self, type_):
        return self._string_spec("CHAR", type_)

    def visit_string(self, type_):
        return self._string_spec("VARCHAR", type_)

    def visit_large_binary(self, type_):
        return "LONG BYTE"

    def visit_numeric(self, type_):
        """Render ``FIXED(p, s)``, ``FIXED(p)``, or ``INTEGER`` when no
        precision is given."""
        if type_.scale and type_.precision:
            return 'FIXED(%s, %s)' % (type_.precision, type_.scale)
        elif type_.precision:
            return 'FIXED(%s)' % type_.precision
        else:
            return 'INTEGER'

    def visit_BOOLEAN(self, type_):
        return "BOOLEAN"
# Generic SQLAlchemy type -> MaxDB-specific subclass; assigned to
# MaxDBDialect.colspecs. (A duplicated sqltypes.Unicode entry was removed;
# the mapping is unchanged.)
colspecs = {
    sqltypes.Numeric: MaxNumeric,
    sqltypes.DateTime: MaxTimestamp,
    sqltypes.Date: MaxDate,
    sqltypes.Time: MaxTime,
    sqltypes.String: MaxString,
    sqltypes.Unicode: MaxUnicode,
    sqltypes.LargeBinary: MaxBlob,
    sqltypes.Text: MaxText,
    sqltypes.CHAR: MaxChar,
    sqltypes.TIMESTAMP: MaxTimestamp,
    sqltypes.BLOB: MaxBlob,
}
# Lower-cased MaxDB catalog DATATYPE names -> SQLAlchemy types, used by
# MaxDBDialect.reflecttable. (A duplicated 'long' entry was removed.)
# NOTE(review): reflecttable passes an ``encoding`` keyword for CHAR/VARCHAR/
# LONG columns; confirm these generic types accept it.
ischema_names = {
    'boolean': sqltypes.BOOLEAN,
    'char': sqltypes.CHAR,
    'character': sqltypes.CHAR,
    'date': sqltypes.DATE,
    'fixed': sqltypes.Numeric,
    'float': sqltypes.FLOAT,
    'int': sqltypes.INT,
    'integer': sqltypes.INT,
    'long binary': sqltypes.BLOB,
    'long unicode': sqltypes.Text,
    'long': sqltypes.Text,
    'smallint': sqltypes.SmallInteger,
    'time': sqltypes.Time,
    'timestamp': sqltypes.TIMESTAMP,
    'varchar': sqltypes.VARCHAR,
}
# TODO: migrate this to sapdb.py
class MaxDBExecutionContext(default.DefaultExecutionContext):
    """Execution context working around MaxDB DB-API quirks."""

    def post_exec(self):
        """Backfill the SERIAL primary key after a single-row INSERT.

        DB-API bug: if there were any functions as values, then do another
        select and pull CURRVAL from the autoincrement column's implicit
        sequence... ugh
        """
        compiled = self.compiled
        if compiled.isinsert and not self.executemany:
            table = compiled.statement.table
            index, serial_col = _autoserial_column(table)
            if serial_col and (not compiled._safeserial or
                               not self._last_inserted_ids or
                               self._last_inserted_ids[index] in (None, 0)):
                if table.schema:
                    template = "SELECT %s.CURRVAL FROM DUAL"
                else:
                    template = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL"
                currval_sql = template % compiled.preparer.format_table(table)
                serial_value = self.cursor.execute(currval_sql).fetchone()[0]
                if not self._last_inserted_ids:
                    # This shouldn't ever be > 1? Right?
                    self._last_inserted_ids = (
                        [None] * len(table.primary_key.columns))
                self._last_inserted_ids[index] = serial_value
        super(MaxDBExecutionContext, self).post_exec()

    def get_result_proxy(self):
        """Use MaxDBResultProxy when any result column is a LONG type."""
        description = self.cursor.description
        long_types = ('Long Binary', 'Long', 'Long Unicode')
        if description is not None and any(
                column[1] in long_types for column in description):
            return MaxDBResultProxy(self)
        return engine_base.ResultProxy(self)

    @property
    def rowcount(self):
        """Row count stored by MaxDBDialect.do_execute, else the cursor's."""
        try:
            return self._rowcount
        except AttributeError:
            return self.cursor.rowcount

    def fire_sequence(self, seq):
        """Return ``seq``'s NEXTVAL, or None for an optional sequence."""
        if not seq.optional:
            return self._execute_scalar("SELECT %s.NEXTVAL FROM DUAL" % (
                self.dialect.identifier_preparer.format_sequence(seq)))
        return None
class MaxDBCachedColumnRow(engine_base.RowProxy):
    """A RowProxy that only runs result_processors once per column.

    Processed values are memoized in ``self.columns``, keyed by whatever key
    (index or name) was used to access them.
    """

    # Attributes assigned in __init__. If __getattr__ is ever asked for one
    # of them, the row is not (yet) initialized. Resolving such a name as a
    # column would recurse through _get_col forever, so refuse it instead.
    _own_attrs = ('columns', '_row', '_parent')

    def __init__(self, parent, row):
        super(MaxDBCachedColumnRow, self).__init__(parent, row)
        self.columns = {}
        self._row = row
        self._parent = parent

    def _get_col(self, key):
        """Return the processed value for ``key``, computing it at most once."""
        if key not in self.columns:
            self.columns[key] = self._parent._get_col(self._row, key)
        return self.columns[key]

    def __iter__(self):
        for i in xrange(len(self._row)):
            yield self._get_col(i)

    def __repr__(self):
        return repr(list(self))

    def __eq__(self, other):
        return ((other is self) or
                (other == tuple([self._get_col(key)
                                 for key in xrange(len(self._row))])))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent.
        return not self.__eq__(other)

    def __getitem__(self, key):
        if isinstance(key, slice):
            indices = key.indices(len(self._row))
            return tuple([self._get_col(i) for i in xrange(*indices)])
        else:
            return self._get_col(key)

    def __getattr__(self, name):
        if name in self._own_attrs:
            raise AttributeError(name)
        try:
            return self._get_col(name)
        except KeyError:
            raise AttributeError(name)
class MaxDBResultProxy(engine_base.ResultProxy):
    """ResultProxy whose rows cache processed column values.

    MaxDBExecutionContext.get_result_proxy selects it for results that
    contain LONG columns.
    """

    _process_row = MaxDBCachedColumnRow
class MaxDBCompiler(compiler.SQLCompiler):
    """SQL compiler emitting MaxDB syntax (DUAL, TOP, LIMIT, WITH LOCK...)."""

    # Generic function names -> MaxDB equivalents.
    function_conversion = {
        'CURRENT_DATE': 'DATE',
        'CURRENT_TIME': 'TIME',
        'CURRENT_TIMESTAMP': 'TIMESTAMP',
        }

    # These functions must be written without parens when called with no
    # parameters.  e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL'
    bare_functions = set([
        'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP',
        'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
        'UTCDATE', 'UTCDIFF'])

    def visit_mod(self, binary, **kw):
        """Render the modulo operator as ``mod(left, right)``."""
        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))

    def default_from(self):
        """Selects with no FROM target read from DUAL."""
        return ' FROM DUAL'

    def for_update_clause(self, select):
        """Map ``select.for_update`` to a MaxDB ``WITH LOCK`` clause.

        True -> exclusive lock; None/falsy -> nothing; 'read' -> shared lock;
        'ignore'/'nowait' -> the corresponding option; any other string is
        upper-cased and appended verbatim; any other truthy value ->
        exclusive lock.
        """
        clause = select.for_update
        if clause is True:
            return " WITH LOCK EXCLUSIVE"
        elif clause is None:
            return ""
        elif clause == "read":
            return " WITH LOCK"
        elif clause == "ignore":
            return " WITH LOCK (IGNORE) EXCLUSIVE"
        elif clause == "nowait":
            return " WITH LOCK (NOWAIT) EXCLUSIVE"
        elif isinstance(clause, basestring):
            return " WITH LOCK %s" % clause.upper()
        elif not clause:
            return ""
        else:
            return " WITH LOCK EXCLUSIVE"

    def function_argspec(self, fn, **kw):
        """Omit the argument list (parens included) for bare functions and
        for calls with no arguments."""
        if fn.name.upper() in self.bare_functions:
            return ""
        elif len(fn.clauses) > 0:
            return compiler.SQLCompiler.function_argspec(self, fn, **kw)
        else:
            return ""

    def visit_function(self, fn, **kw):
        """Rename functions listed in ``function_conversion`` on a clone."""
        transform = self.function_conversion.get(fn.name.upper(), None)
        if transform:
            fn = fn._clone()
            fn.name = transform
        return super(MaxDBCompiler, self).visit_function(fn, **kw)

    def visit_cast(self, cast, **kwargs):
        """Translate CAST into MaxDB's NUM()/CHR(); other casts are dropped."""
        # MaxDB only supports casts * to NUMERIC, * to VARCHAR or
        # date/time to VARCHAR.  Casts of LONGs will fail.
        if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)):
            return "NUM(%s)" % self.process(cast.clause)
        elif isinstance(cast.type, sqltypes.String):
            return "CHR(%s)" % self.process(cast.clause)
        else:
            return self.process(cast.clause)

    def visit_sequence(self, sequence):
        """Render ``<sequence>.NEXTVAL``; optional sequences yield None."""
        if sequence.optional:
            return None
        else:
            return (self.dialect.identifier_preparer.format_sequence(sequence) +
                    ".NEXTVAL")

    class ColumnSnagger(visitors.ClauseVisitor):
        """Count the columns in an expression and remember the last one."""

        def __init__(self):
            self.count = 0
            self.column = None

        def visit_column(self, column):
            self.column = column
            self.count += 1

    def _find_labeled_columns(self, columns, use_labels=False):
        """Map the string form of each single-column expression to its label.

        Explicit labels use the label's name; with ``use_labels`` other
        columns use their ``_label``. Plain strings and multi-column
        expressions are skipped.
        """
        labels = {}
        for column in columns:
            if isinstance(column, basestring):
                continue
            snagger = self.ColumnSnagger()
            snagger.traverse(column)
            if snagger.count == 1:
                if isinstance(column, sql_expr._Label):
                    labels[unicode(snagger.column)] = column.name
                elif use_labels:
                    labels[unicode(snagger.column)] = column._label
        return labels

    def order_by_clause(self, select, **kw):
        """Render ORDER BY, rewriting DISTINCT sort keys to their aliases.

        ORDER BY is dropped in subqueries without LIMIT.

        :raises InvalidRequestError: for a subquery with both ORDER BY and
            LIMIT.
        """
        order_by = self.process(select._order_by_clause, **kw)

        # ORDER BY clauses in DISTINCT queries must reference aliased
        # inner columns by alias name, not true column name.
        if order_by and getattr(select, '_distinct', False):
            labels = self._find_labeled_columns(select.inner_columns,
                                                select.use_labels)
            for needs_alias, label in labels.items():
                pattern = re.compile(r'(^| )(%s)(,| |$)' %
                                     re.escape(needs_alias))
                # A callable replacement inserts the label literally. As a
                # template string, backslashes or a leading digit in the
                # label would be read as group references.
                order_by = pattern.sub(
                    lambda m, label=label: m.group(1) + label + m.group(3),
                    order_by)

        # No ORDER BY in subqueries.
        if order_by:
            if self.is_subquery():
                # It's safe to simply drop the ORDER BY if there is no
                # LIMIT.  Right?  Other dialects seem to get away with
                # dropping order.
                if select._limit:
                    raise exc.InvalidRequestError(
                        "MaxDB does not support ORDER BY in subqueries")
                else:
                    return ""
            return " ORDER BY " + order_by
        else:
            return ""

    def get_select_precolumns(self, select):
        """Emit DISTINCT, and convert a subquery's LIMIT to TOP.

        :raises InvalidRequestError: for a limited subquery with an offset.
        """
        precolumns = 'DISTINCT ' if select._distinct else ''
        if self.is_subquery() and select._limit:
            if select._offset:
                raise exc.InvalidRequestError(
                    'MaxDB does not support LIMIT with an offset.')
            precolumns += 'TOP %s ' % select._limit
        return precolumns

    def limit_clause(self, select):
        """Emit a trailing LIMIT for top-level selects (subqueries use TOP).

        :raises InvalidRequestError: if an offset is requested.
        """
        # The docs say offsets are supported with LIMIT.  But they're not.
        # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate?
        if self.is_subquery():
            # sub queries need TOP
            return ''
        elif select._offset:
            raise exc.InvalidRequestError(
                'MaxDB does not support LIMIT with an offset.')
        else:
            return ' \n LIMIT %s' % (select._limit,)

    def visit_insert(self, insert):
        """Render INSERT and record whether SERIAL ids can be trusted.

        ``_safeserial`` is False when any VALUES entry is a SQL function;
        MaxDBExecutionContext.post_exec then re-reads the id via CURRVAL.
        """
        self.isinsert = True
        self._safeserial = True

        colparams = self._get_colparams(insert)
        for value in (insert.parameters or {}).itervalues():
            if isinstance(value, sql_expr.Function):
                self._safeserial = False
                break

        return ''.join(('INSERT INTO ',
                        self.preparer.format_table(insert.table),
                        ' (',
                        ', '.join([self.preparer.format_column(c[0])
                                   for c in colparams]),
                        ') VALUES (',
                        ', '.join([c[1] for c in colparams]),
                        ')'))
class MaxDBIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier quoting and case folding for MaxDB.

    MaxDB folds unquoted names to upper case, and SQLAlchemy treats
    lower case as case-insensitive. The helpers below translate between
    the two conventions.
    """

    reserved_words = set([
        'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha',
        'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary',
        'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char',
        'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos',
        'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime',
        'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth',
        'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default',
        'degrees', 'delete', 'digits', 'distinct', 'double', 'except',
        'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for',
        'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest',
        'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore',
        'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal',
        'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left',
        'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long',
        'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime',
        'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod',
        'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround',
        'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on',
        'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians',
        'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round',
        'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd',
        'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some',
        'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev',
        'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba',
        'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone',
        'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc',
        'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper',
        'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values',
        'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when',
        'where', 'with', 'year', 'zoned' ])

    def _normalize_name(self, name):
        """Lower-case an all-upper-case catalog name when it needs no quoting.

        Other names are returned unchanged; None passes through.
        """
        if name is not None and name.isupper():
            lowered = name.lower()
            if not self._requires_quotes(lowered):
                return lowered
        return name

    def _denormalize_name(self, name):
        """Upper-case an all-lower-case name that needs no quoting."""
        if (name is not None and name.islower() and
                not self._requires_quotes(name)):
            return name.upper()
        return name

    def _maybe_quote_identifier(self, name):
        """Quote ``name`` only if it requires quoting."""
        return (self.quote_identifier(name) if self._requires_quotes(name)
                else name)
class MaxDBDDLCompiler(compiler.DDLCompiler):
    """DDL compiler handling MaxDB column defaults, SERIAL and sequences."""

    def get_column_specification(self, column, **kw):
        """Return the column's DDL: name, type, NOT NULL and any DEFAULT.

        Rules, in order:

        * a non-optional Sequence default gets no DDL default;
        * otherwise the server default is rendered when there is one;
        * otherwise the first autoincrement primary-key column that is an
          Integer (or a MaxNumeric with a precision) and has no foreign
          keys gets ``DEFAULT SERIAL``.
        """
        colspec = [self.preparer.format_column(column),
                   self.dialect.type_compiler.process(column.type)]

        if not column.nullable:
            colspec.append('NOT NULL')

        default = column.default
        default_str = self.get_column_default_string(column)

        # No DDL default for columns specified with non-optional sequence-
        # this defaulting behavior is entirely client-side. (And as a
        # consequence, non-reflectable.)
        if (default and isinstance(default, schema.Sequence) and
            not default.optional):
            pass
        # Regular default
        elif default_str is not None:
            colspec.append('DEFAULT %s' % default_str)
        # Assign DEFAULT SERIAL heuristically
        elif column.primary_key and column.autoincrement:
            # For SERIAL on a non-primary key member, use
            # DefaultClause(text('SERIAL'))
            try:
                first = [c for c in column.table.primary_key.columns
                         if (c.autoincrement and
                             (isinstance(c.type, sqltypes.Integer) or
                              (isinstance(c.type, MaxNumeric) and
                               c.type.precision)) and
                             not c.foreign_keys)].pop(0)
                if column is first:
                    colspec.append('DEFAULT SERIAL')
            except IndexError:
                pass
        return ' '.join(colspec)

    def get_column_default_string(self, column):
        """Return DDL text for the column's ``DefaultClause`` server default.

        String arguments render bare for Integer columns. Otherwise they
        render as a single-quoted literal, with embedded quotes doubled (the
        same escaping MaxDBDialect.reflecttable applies). SQL expressions
        are compiled with the statement compiler. Returns None when there is
        no DefaultClause.
        """
        server_default = column.server_default
        if not isinstance(server_default, schema.DefaultClause):
            return None
        # Read the *server* default's argument. The previous code read
        # column.default, the client-side default, which may be None.
        arg = server_default.arg
        if isinstance(arg, basestring):
            if isinstance(column.type, sqltypes.Integer):
                return str(arg)
            return "'%s'" % arg.replace("'", "''")
        return unicode(self.sql_compiler.process(arg))

    def visit_create_sequence(self, create):
        """Creates a SEQUENCE.

        TODO: move to module doc?

        start
          With an integer value, set the START WITH option.

        increment
          An integer value to increment by.  Default is the database default.

        maxdb_minvalue
        maxdb_maxvalue
          With an integer value, sets the corresponding sequence option.

        maxdb_no_minvalue
        maxdb_no_maxvalue
          Defaults to False.  If true, sets the corresponding sequence option.

        maxdb_cycle
          Defaults to False.  If true, sets the CYCLE option.

        maxdb_cache
          With an integer value, sets the CACHE option.

        maxdb_no_cache
          Defaults to False.  If true, sets NOCACHE.

        Returns None for optional sequences, or when ``checkfirst`` finds the
        sequence already present.
        """
        sequence = create.element

        # NOTE(review): self.checkfirst / self.connection are not set by this
        # class; confirm the DDLCompiler in use provides them.
        if (not sequence.optional and
            (not self.checkfirst or
             not self.dialect.has_sequence(self.connection, sequence.name))):

            ddl = ['CREATE SEQUENCE',
                   self.preparer.format_sequence(sequence)]

            # Do not overwrite sequence.increment: the user's value is
            # honoured, and None leaves the database default in place.
            if sequence.increment is not None:
                ddl.extend(('INCREMENT BY', str(sequence.increment)))

            if sequence.start is not None:
                ddl.extend(('START WITH', str(sequence.start)))

            # 'maxdb_' keyword arguments, with the prefix stripped.
            opts = dict([(key[6:].lower(), value)
                         for key, value in sequence.kwargs.items()
                         if key.startswith('maxdb_')])

            if 'maxvalue' in opts:
                ddl.extend(('MAXVALUE', str(opts['maxvalue'])))
            elif opts.get('no_maxvalue', False):
                ddl.append('NOMAXVALUE')
            if 'minvalue' in opts:
                ddl.extend(('MINVALUE', str(opts['minvalue'])))
            elif opts.get('no_minvalue', False):
                ddl.append('NOMINVALUE')

            if opts.get('cycle', False):
                ddl.append('CYCLE')

            if 'cache' in opts:
                ddl.extend(('CACHE', str(opts['cache'])))
            elif opts.get('no_cache', False):
                ddl.append('NOCACHE')

            return ' '.join(ddl)
class MaxDBDialect(default.DefaultDialect):
    """SQLAlchemy dialect for SAP MaxDB (experimental; see module docs)."""

    name = 'maxdb'
    supports_alter = True
    supports_unicode_statements = True
    max_identifier_length = 32
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False

    preparer = MaxDBIdentifierPreparer
    statement_compiler = MaxDBCompiler
    ddl_compiler = MaxDBDDLCompiler
    execution_ctx_cls = MaxDBExecutionContext

    ported_sqla_06 = False

    colspecs = colspecs
    ischema_names = ischema_names

    # MaxDB-specific
    # Wire format for date/time values, read by the Max* date/time type
    # processors, which accept 'internal' or 'iso'.
    datetimeformat = 'internal'

    def __init__(self, _raise_known_sql_errors=False, **kw):
        super(MaxDBDialect, self).__init__(**kw)
        self._raise_known = _raise_known_sql_errors

        if self.dbapi is None:
            self.dbapi_type_map = {}
        else:
            # DB-API type names / Python types -> MaxDB type instances.
            self.dbapi_type_map = {
                'Long Binary': MaxBlob(),
                'Long byte_t': MaxBlob(),
                'Long Unicode': MaxText(),
                'Timestamp': MaxTimestamp(),
                'Date': MaxDate(),
                'Time': MaxTime(),
                datetime.datetime: MaxTimestamp(),
                datetime.date: MaxDate(),
                datetime.time: MaxTime(),
            }

    def do_execute(self, cursor, statement, parameters, context=None):
        """Execute; an int returned by the DB-API is kept as the rowcount."""
        res = cursor.execute(statement, parameters)
        if isinstance(res, int) and context is not None:
            context._rowcount = res

    def do_release_savepoint(self, connection, name):
        # Does MaxDB truly support RELEASE SAVEPOINT <id>?  All my attempts
        # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
        # BEGIN SQLSTATE: I7065"
        # Note that ROLLBACK TO works fine.  In theory, a RELEASE should
        # just free up some transactional resources early, before the overall
        # COMMIT/ROLLBACK so omitting it should be relatively ok.
        pass

    def _get_default_schema_name(self, connection):
        """Return the session's CURRENT_SCHEMA, normalized."""
        return self.identifier_preparer._normalize_name(
                connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())

    def has_table(self, connection, table_name, schema=None):
        """Return True if the TABLES catalog lists ``table_name``.

        The lookup uses ``schema`` if given, else CURRENT_SCHEMA.
        """
        denormalize = self.identifier_preparer._denormalize_name
        bind = [denormalize(table_name)]
        if schema is None:
            sql = ("SELECT tablename FROM TABLES "
                   "WHERE TABLES.TABLENAME=? AND"
                   "  TABLES.SCHEMANAME=CURRENT_SCHEMA ")
        else:
            sql = ("SELECT tablename FROM TABLES "
                   "WHERE TABLES.TABLENAME = ? AND"
                   "  TABLES.SCHEMANAME=? ")
            bind.append(denormalize(schema))

        rp = connection.execute(sql, bind)
        return bool(rp.first())

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """Return normalized table names in ``schema`` (default: current)."""
        if schema is None:
            sql = (" SELECT TABLENAME FROM TABLES WHERE "
                   " SCHEMANAME=CURRENT_SCHEMA ")
            rs = connection.execute(sql)
        else:
            sql = (" SELECT TABLENAME FROM TABLES WHERE "
                   " SCHEMANAME=? ")
            matchname = self.identifier_preparer._denormalize_name(schema)
            rs = connection.execute(sql, matchname)
        normalize = self.identifier_preparer._normalize_name
        return [normalize(row[0]) for row in rs]

    def reflecttable(self, connection, table, include_columns):
        """Populate ``table`` from the COLUMNS and FOREIGNKEYCOLUMNS catalogs.

        Referenced tables not yet in ``table.metadata`` are autoloaded as a
        side effect.

        :raises NoSuchTableError: if the catalog lists no columns.
        """
        denormalize = self.identifier_preparer._denormalize_name
        normalize = self.identifier_preparer._normalize_name

        st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
              '  NULLABLE, "DEFAULT", DEFAULTFUNCTION '
              'FROM COLUMNS '
              'WHERE TABLENAME=? AND SCHEMANAME=%s '
              'ORDER BY POS')

        fk = ('SELECT COLUMNNAME, FKEYNAME, '
              '  REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
              '  (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
              '   THEN 1 ELSE 0 END) AS in_schema '
              'FROM FOREIGNKEYCOLUMNS '
              'WHERE TABLENAME=? AND SCHEMANAME=%s '
              'ORDER BY FKEYNAME ')

        params = [denormalize(table.name)]
        if not table.schema:
            st = st % 'CURRENT_SCHEMA'
            fk = fk % 'CURRENT_SCHEMA'
        else:
            st = st % '?'
            fk = fk % '?'
            params.append(denormalize(table.schema))

        rows = connection.execute(st, params).fetchall()
        if not rows:
            raise exc.NoSuchTableError(table.fullname)

        include_columns = set(include_columns or [])

        for row in rows:
            (name, mode, col_type, encoding, length, scale,
             nullable, constant_def, func_def) = row

            name = normalize(name)

            if include_columns and name not in include_columns:
                continue

            type_args, type_kw = [], {}
            if col_type == 'FIXED':
                type_args = length, scale
                # Convert FIXED(10) DEFAULT SERIAL to our Integer
                if (scale == 0 and
                    func_def is not None and func_def.startswith('SERIAL')):
                    col_type = 'INTEGER'
                    type_args = length,
            elif col_type == 'FLOAT':
                # Equality, not the previous substring test ("in 'FLOAT'").
                type_args = length,
            elif col_type in ('CHAR', 'VARCHAR'):
                type_args = length,
                type_kw['encoding'] = encoding
            elif col_type == 'LONG':
                type_kw['encoding'] = encoding

            # Only the lookup belongs in the try: a failure while building
            # the type must not be reported as an unrecognized type.
            try:
                type_cls = ischema_names[col_type.lower()]
            except KeyError:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (col_type, name))
                type_instance = sqltypes.NullType()
            else:
                type_instance = type_cls(*type_args, **type_kw)

            col_kw = {'autoincrement': False}
            col_kw['nullable'] = (nullable == 'YES')
            col_kw['primary_key'] = (mode == 'KEY')

            if func_def is not None:
                if func_def.startswith('SERIAL'):
                    if col_kw['primary_key']:
                        # No special default- let the standard autoincrement
                        # support handle SERIAL pk columns.
                        col_kw['autoincrement'] = True
                    else:
                        # strip current numbering
                        col_kw['server_default'] = schema.DefaultClause(
                            sql.text('SERIAL'))
                        col_kw['autoincrement'] = True
                else:
                    col_kw['server_default'] = schema.DefaultClause(
                        sql.text(func_def))
            elif constant_def is not None:
                col_kw['server_default'] = schema.DefaultClause(sql.text(
                    "'%s'" % constant_def.replace("'", "''")))

            table.append_column(schema.Column(name, type_instance, **col_kw))

        quote = self.identifier_preparer._maybe_quote_identifier
        fk_sets = itertools.groupby(connection.execute(fk, params),
                                    lambda row: row.FKEYNAME)
        for fkeyname, fkey in fk_sets:
            fkey = list(fkey)
            first = fkey[0]
            if include_columns:
                # Keep only FKs whose local columns were all reflected.
                # COLUMNNAME is the raw (upper-case) catalog name, so it is
                # normalized before comparing with include_columns.
                key_cols = set([normalize(r.COLUMNNAME) for r in fkey])
                if not key_cols.issubset(include_columns):
                    continue

            columns, referants = [], []
            for row in fkey:
                columns.append(normalize(row.COLUMNNAME))
                if table.schema or not row.in_schema:
                    referants.append('.'.join(
                        [quote(normalize(row[c]))
                         for c in ('REFSCHEMANAME', 'REFTABLENAME',
                                   'REFCOLUMNNAME')]))
                else:
                    referants.append('.'.join(
                        [quote(normalize(row[c]))
                         for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))

            constraint_kw = {'name': fkeyname.lower()}
            if first.RULE is not None:
                rule = first.RULE
                if rule.startswith('DELETE '):
                    rule = rule[7:]
                constraint_kw['ondelete'] = rule

            table_kw = {}
            if table.schema or not first.in_schema:
                table_kw['schema'] = normalize(first.REFSCHEMANAME)

            ref_key = schema._get_table_key(normalize(first.REFTABLENAME),
                                            table_kw.get('schema'))
            if ref_key not in table.metadata.tables:
                schema.Table(normalize(first.REFTABLENAME),
                             table.metadata,
                             autoload=True, autoload_with=connection,
                             **table_kw)

            constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
                                                     **constraint_kw)
            table.append_constraint(constraint)

    def has_sequence(self, connection, name):
        """Return True if the SEQUENCES catalog lists ``name`` (any schema)."""
        # [ticket:726] makes this schema-aware.
        denormalize = self.identifier_preparer._denormalize_name
        sql = ("SELECT sequence_name FROM SEQUENCES "
               "WHERE SEQUENCE_NAME=? ")

        rp = connection.execute(sql, denormalize(name))
        return bool(rp.first())
def _autoserial_column(table):
    """Finds the effective DEFAULT SERIAL column of a Table, if any.

    Scans the primary-key columns in order for the first one that is an
    Integer/Numeric with ``autoincrement`` set and that either:

    * has an *optional* Sequence default, or
    * has no Sequence default, and has either no client-side default or no
      ``DefaultClause`` server default.

    Returns ``(index, column)`` for that column, else ``(None, None)``.
    """
    for position, candidate in enumerate(table.primary_key.columns):
        numeric = isinstance(candidate.type,
                             (sqltypes.Integer, sqltypes.Numeric))
        if not (numeric and candidate.autoincrement):
            continue
        client_default = candidate.default
        if isinstance(client_default, schema.Sequence):
            is_serial = client_default.optional
        else:
            is_serial = (client_default is None or
                         not isinstance(candidate.server_default,
                                        schema.DefaultClause))
        if is_serial:
            return position, candidate
    return None, None
|