# $SnapHashLicense:
#
# SnapLogic - Open source data services
#
# Copyright (C) 2008-2009, SnapLogic, Inc. All rights reserved.
#
# See http://www.snaplogic.org for more information about
# the SnapLogic project.
#
# This program is free software, distributed under the terms of
# the GNU General Public License Version 2. See the LEGAL file
# at the top of the source tree.
#
# "SnapLogic" is a trademark of SnapLogic, Inc.
#
#
# $
# $Id: SQLServer.py 9280 2009-10-15 16:56:01Z grisha $
"""
Based on pymssql v0.8.0 (http://pymssql.sourceforge.net/).
"""
import codecs
import os
import re
import types
from datetime import timedelta,datetime
from decimal import Decimal
from distutils.version import LooseVersion

import pymssql
import _mssql

from snaplogic.common.snap_exceptions import SnapComponentError
from snaplogic.common.snap_exceptions import SnapException
from snaplogic.components.DBUtils import SnapDBAdapter
from snaplogic.common.data_types import SnapString,SnapDateTime,SnapNumber
# Maps _mssql column type codes to SQL Server native type names.
# NOTE(review): not referenced directly in this module; presumably consumed
# through SnapDBAdapter.type_code_to_native_type(), which the cursor wrapper
# calls with the type code from cursor.description -- confirm.
# NUMBER and DECIMAL both collapse to 'decimal'. 'binary' has no entry in
# NATIVE_TYPE_TO_SNAP_TYPE.
TYPE_CODE_TO_NATIVE_TYPE = dict([
    (_mssql.NUMBER, 'decimal'),
    (_mssql.DECIMAL, 'decimal'),
    (_mssql.STRING, 'varchar'),
    (_mssql.BINARY, 'binary'),
    (_mssql.DATETIME, 'datetime'),
])
# Maps SQL Server native type names (lower case, as produced by
# TYPE_CODE_TO_NATIVE_TYPE or read from information_schema.columns.DATA_TYPE)
# to Snap data types.
# NOTE(review): types such as 'numeric', 'real', 'money' and 'ntext' are not
# listed here; how unknown names are handled is up to SnapDBAdapter.
NATIVE_TYPE_TO_SNAP_TYPE = dict(
    # Date/time types
    smalldatetime=SnapDateTime,
    datetime=SnapDateTime,
    # Character types
    varchar=SnapString,
    char=SnapString,
    nvarchar=SnapString,
    nchar=SnapString,
    text=SnapString,
    # Numeric types
    bigint=SnapNumber,
    int=SnapNumber,
    smallint=SnapNumber,
    tinyint=SnapNumber,
    float=SnapNumber,
    decimal=SnapNumber,
)
class SQLServerCursorWrapper(object):
"""
A wrapper around DB API 2.0 cursor, to handle encoding and type conversion.
Will be returned by L{pymssql.cursor}
"""
def __init__(self, cursor, snap_conn):
"""
Initialize.
@param cursor: DB API 2.0 cursor object, to which most requests will
be delegated.
@type: cursor
@param conn: The L{SnapDBAdapter} object that generated this cursor
@type conn: L{SnapDBAdapter}
"""
self._snap_conn = snap_conn
self._delegate = cursor
if os.name != 'nt':
self._delegate.numbersAsStrings = 1
self._metadata = None
self._string_fields = None
self._num_fields = None
def convert_row(self, row):
"""
Convert a row of data in native data types into a row of Snap types.
@param row: row returned by database
@type row: tuple
@return: row converted to Snap data types
@rtype: list
"""
if self._metadata is not None and self._string_fields is None and self._num_fields is None and self._datetime_fields is None:
return row
if not row:
return row
if self._metadata is None:
self._metadata = self._delegate.description
self._string_fields = None
self._datetime_fields = None
self._num_fields = None
i = 0
for col_metadata in self._metadata:
type_code = col_metadata[1]
native_type = self._snap_conn.type_code_to_native_type(type_code)
snap_type = self._snap_conn.native_type_to_snap_type(native_type)
if snap_type == SnapNumber:
if self._num_fields is None:
self._num_fields = []
self._num_fields.append(i)
elif snap_type == SnapString:
if self._string_fields is None:
self._string_fields = []
self._string_fields.append(i)
elif snap_type == SnapDateTime:
if self._datetime_fields is None:
self._datetime_fields = []
self._datetime_fields.append(i)
i += 1
new_row = list(row)
if self._string_fields is not None:
for i in self._string_fields:
if new_row[i] is not None:
new_val = new_row[i].decode(self._snap_conn._charset)
new_row[i] = new_val
if self._num_fields is not None:
for i in self._num_fields:
val = new_row[i]
if val is not None:
val = str(val)
val = Decimal(val)
new_row[i] = val
if self._datetime_fields is not None:
for i in self._datetime_fields:
val = new_row[i]
if val is not None:
val -= self._snap_conn._utc_offset
new_row[i] = val
return new_row
def convert_results(self, rs):
"""
Convert the result set from native data types to Snap data types.
This is similar to L{convert_row}, except it acts on the entire result
set
@param rs: Result set to convert
@type rs: list or tuple
@return: converted result set
@type: list
"""
if self._metadata is not None and self._string_fields is None and self._num_fields is None and self._datetime_fields is None:
return rs
if not rs:
return rs
converted_rs = []
for row in rs:
new_row = self.convert_row(row)
converted_rs.append(new_row)
return converted_rs
def execute(self, operation, params = ()):
"""
Delegates to PyMSSQL's execute method
"""
self._metadata = None
if type(operation) == unicode:
try:
operation = operation.encode('utf-8')
except UnicodeEncodeError, e:
snap_exc = SnapException("Cannot encode SQL command (%s) into the UTF-8 encoding" % operation)
snap_exc.append(e)
raise snap_exc
return self._delegate.execute(operation, params)
def fetchone(self):
"""
Same as cursor.fetchone() specified in DB API 2.0, except returning
Snap data types.
"""
row = self._delegate.fetchone()
if row is not None:
row = self.convert_row(row)
return row
def fetchmany(self, size=None):
"""
Same as cursor.fetchmany() specified in DB API 2.0, except returning
Snap data types.
"""
if size is None:
rs = self._delegate.fetchmany()
else:
rs = self._delegate.fetchmany(size)
rs = self.convert_results(rs)
return rs
def fetchall(self):
"""
Same as cursor.fetchall() specified in DB API 2.0, except returning
Snap data types.
"""
rs = self._delegate.fetchall()
rs = self.convert_results(rs)
return rs
def __getattr__(self, name):
"""
Used to delegate to the native cursor object those methods that are not
wrapped by this class.
"""
result = getattr(self._delegate, name)
return result
class SQLServer(SnapDBAdapter):
"""
Implementation of L{SnapDBAdapter} for SQLServer.
"""
def __init__(self, *args, **kwargs):
host = kwargs['host']
try:
port = kwargs['port']
del kwargs['port']
except KeyError, e:
# For old-style connection.
if not ':' in host:
port = 1433
# See bug #1988
# pymssql prior to version 1.0 used comma rather than colon to
# separate host and port on Windows
# More details at http://pymssql.sourceforge.net/faq.php
sep = ':'
if os.name == 'nt' and LooseVersion(pymssql.__version__) < LooseVersion('1.0'):
sep = ','
kwargs['host'] = "%s%s%s" % (host, sep, port)
conn = pymssql.connect(**kwargs)
super(SQLServer, self).__init__(conn, pymssql)
self._charset = 'utf-8'
cur = self.cursor()
cur.execute("SELECT DATEDIFF(second,GETUTCDATE(),GETDATE())")
rs = cur.fetchone()
secs = int(rs[0])
self._utc_offset = timedelta(seconds=secs)
try:
cur.execute("SELECT @@VERSION AS version")
rs = cur.fetchone()
self._server_version = rs[0]
self._server_version = self._server_version[len("Microsoft SQL Server"):].strip()
if self._server_version.startswith("2000"):
q = "SELECT default_character_set_name FROM information_schema.schemata WHERE catalog_name = %(db)s"
cur.execute(q, {'db' : kwargs['database']})
rs = cur.fetchone()
charset = rs[0]
if charset == 'iso_1' or charset is None:
charset = 'latin1'
try:
codec = codecs.lookup(charset)
self._charset = codec.name
except LookupError, le:
exc = SnapComponentError("Unknown encoding %s" % charset)
exc.append(le)
raise exc
finally:
cur.close()
def upsert(self, table, row, keys, table_metadata=None):
"""
SQL Server-specific implementation of L{SnapDBAdapter.upsert()
by using MERGE.
"""
if not self._server_version.startswith("2008"):
return SnapDBAdapter.upsert(self, table, row, keys, table_metadata)
field_names = row.keys()
inner_select_clause = ['%%(%s)s AS %s' % (f, f) for f in field_names]
sql = "MERGE INTO " + \
table + \
" t1 USING (SELECT " + \
', '.join(inner_select_clause) + \
" ) t2 ON (";
set_clause = ["t1.%s = t2.%s" % (key, key) for key in keys]
sql += ' AND '.join(set_clause)
sql += ")"
sql += " WHEN MATCHED THEN UPDATE SET "
fields_to_set = list(set(field_names) - set(keys))
update_clause = ["t1.%s = t2.%s" % (f, f) for f in fields_to_set]
sql += ",".join(update_clause)
sql += " WHEN NOT MATCHED THEN INSERT ("
sql += ",".join(field_names)
sql += ") VALUES ("
bound_field_names = self.bindVariableList(field_names)
sql += ",".join(bound_field_names)
sql += ");"
cur = self.cursor()
bind_container = self.fix_bound_values(row)
cur.execute(sql, bind_container)
def cursor(self):
"""
See L{SnapDBAdapter.cursor} and L{SqlServerCursorWrapper}
"""
native_cursor = SnapDBAdapter.cursor(self)
my_cursor = SQLServerCursorWrapper(native_cursor, self)
return my_cursor
def fix_bound_values(self, record):
"""
Given a record (really, a dictionary) whose values are
Python objects, returns a dictionary with the same keys
whose values are Python objects of types that the DB
expects.
@param record: record
@type record: dict
@return: a record with values converted to types DB expects.
@rtype: dict
"""
new_result = {}
for param in record.keys():
value = record[param]
value_t = type(value)
if value_t == Decimal:
value = str(value)
elif value_t == datetime:
if value.tzinfo:
value += value.tzinfo.utcoffset(value)
value += self._utc_offset
new_result[param] = value
return new_result
def get_default_schema(self):
"""
See L{SnapDBAdapter.get_default_schema}. Default here is assumed
to be 'dbo' schema.
"""
return 'dbo'
def list_tables(self, schema = None):
"""
See L{SnapDBAdapter.list_tables}.
"""
if not schema:
schema = self.get_default_schema()
cur = self.cursor()
sql = "SELECT table_name FROM information_schema.tables WHERE LOWER(table_schema) = LOWER(%(schema)s)"
cur.execute(sql, {'schema' : schema})
result = cur.fetchall()
result = [row[0] for row in result]
cur.close()
return result
def get_no_row_query(self, select_sql):
"""
Creates a query to get 0 rows from the underlying SELECT statement.
@param select_sql: An arbitrary SELECT statement
@type select_sql: string
"""
select_sql = select_sql.strip()
if select_sql.upper().startswith("SELECT "):
# Make sure we use SELECT TOP. This is because if select_sql contains
# "ORDER BY" we will get the following error:
# "The ORDER BY clause is invalid in views, inline functions,
# derived tables,and subqueries, unless TOP is also specified."
# See ticket #1989
select_sql = select_sql[len("SELECT "):].strip()
if select_sql.upper().startswith("TOP "):
select_sql = "SELECT %s" % select_sql
else:
select_sql = "SELECT TOP 1 %s" % select_sql
sql = "SELECT TOP 1 * FROM (%s) dummy WHERE 1=0" % select_sql
return sql
def get_snap_view_metadata(self, table_name):
view_def = {}
primary_key = []
view_def['primary_key'] = primary_key
(schema, table_name) = self._parse_table_name(table_name)
view_def['schema'] = schema
field_defs = []
cur = self.cursor()
sql = """
SELECT * FROM
information_schema.columns c LEFT OUTER JOIN
information_schema.key_column_usage k ON
c.table_catalog = k.table_catalog
AND
c.table_schema = k.table_schema
AND
c.table_name = k.table_name
AND
c.column_name = k.column_name
WHERE
LOWER(c.table_schema) = LOWER(%(schema)s)
AND
LOWER(c.table_name) = LOWER(%(table_name)s)
ORDER BY
c.ordinal_position
"""
bind = {'schema' : schema, 'table_name' : table_name}
cur.execute(sql, bind)
result = cur.fetchall()
if not result:
raise SnapComponentError("Table '%s' not found in schema '%s'" % (table_name, schema))
indices = {}
for i in range(len(cur.description)):
meta = cur.description[i]
col_name = meta[0]
if not indices.has_key(col_name):
indices[col_name] = i
for row in result:
# These we need for actual metadata
name = row[indices['COLUMN_NAME']]
data_type = row[indices['DATA_TYPE']].lower()
snap_type = self.native_type_to_snap_type(data_type)
constraint = row[indices['CONSTRAINT_NAME']]
if constraint and constraint.startswith("PK__"):
primary_key.append(name)
desc = []
nullable = row[indices['IS_NULLABLE']]
desc.append("Nullable: %s" % nullable)
precision = row[indices['NUMERIC_PRECISION']]
if precision:
desc.append("Precision: %s" % precision)
else:
precision = row[indices['DATETIME_PRECISION']]
if precision:
desc.append("Precision: %s" % precision)
precision_radix = row[indices['NUMERIC_PRECISION_RADIX']]
if precision_radix:
desc.append("Precision radix: %s" % precision_radix)
scale = row[indices['NUMERIC_SCALE']]
if scale:
desc.append("Scale: %s" % scale)
default = row[indices['COLUMN_DEFAULT']]
if default:
desc.append("Default: %s" % default)
charset = row[indices['CHARACTER_SET_NAME']]
if charset:
desc.append("Character set: %s" % charset)
collation = row[indices['COLLATION_NAME']]
if collation:
desc.append("Collation: %s" % collation)
desc = '; '.join(desc)
field_def = (name, snap_type, desc,)
field_defs.append(field_def)
cur.close()
view_def['fields'] = tuple(field_defs)
return view_def
|