# $SnapHashLicense:
#
# SnapLogic - Open source data services
#
# Copyright (C) 2008-2009, SnapLogic, Inc. All rights reserved.
#
# See http://www.snaplogic.org for more information about
# the SnapLogic project.
#
# This program is free software, distributed under the terms of
# the GNU General Public License Version 2. See the LEGAL file
# at the top of the source tree.
#
# "SnapLogic" is a trademark of SnapLogic, Inc.
#
#
# $
# $Id: __init__.py 8268 2009-07-17 17:42:24Z dmitri $
"""
This package contains utilities related to DB access. It loads the specified
DB-API based package to carry out operations. These packages are 3rd party
packages and tend to differ a bit, even though they follow the DB-API 2.0 spec.
"""
__docformat__ = "epytext en"
import re
import threading
import sys
from snaplogic.common.snap_exceptions import SnapException,SnapComponentError
from snaplogic import snapi
from snaplogic.snapi_base import keys,exceptions
from snaplogic import cc
from snaplogic.cc import prop
from snaplogic.common import snap_log,data_types
# Module locking support. _lock guards the _db_classes_loaded cache (see
# _acquireLock/_releaseLock). It is created eagerly at import time so that
# concurrent first callers cannot race to create two different locks.
_lock = threading.RLock()
# Cache of SnapDBAdapter subclasses already loaded, keyed by database type.
_db_classes_loaded = {}
# In the original Connection component, the DB is identified as the first
# element of the connection string (ConnectString), which corresponds to the
# DB API driver module. In the new set of Connection components, DB is
# identified by name (e.g., "MySQL", "Oracle"). This dictionary maps the DB
# identifiers from old-style (driver module name) to new-style (DB name), in
# order to be able to work with old-style Connection resources.
_OLD_STYLE_MAP = {
    'MySQLdb': 'MySQL',
    'cx_Oracle': 'Oracle',
    'pgdb': 'PostgreSQL',
    'pymssql': 'SQLServer',
    'sqlite': 'SQLite',
}
def _acquireLock():
    """
    Acquire the module-level lock that guards the _db_classes_loaded cache.
    The lock (a threading.RLock, so the same thread may acquire it again) is
    created on first use if it does not exist yet. Every call must be paired
    with a call to _releaseLock().

    NOTE: the lazy creation below is not itself synchronized; two threads that
    both observe a missing lock could each create one. Creating the lock at
    module import time avoids that window.
    """
    global _lock
    if _lock is None:
        _lock = threading.RLock()
    _lock.acquire()
def _releaseLock():
    """
    Release the module-level lock previously taken with _acquireLock().
    Does nothing if the lock has not been created.
    """
    lock = _lock
    if not lock:
        return
    lock.release()
def _get_snap_db_connection_class(db_type):
    """
    Get the L{SnapDBAdapter} subclass that handles the given database type.

    The class is looked up as the attribute named C{db_type} of the module
    C{snaplogic.components.DBUtils.<db_type>}. Successful lookups are cached
    in C{_db_classes_loaded}; cache access is guarded by the module-level lock.

    @param db_type: Database type name (e.g. "MySQL", "Oracle").
    @type db_type: string
    @return: A subclass of L{SnapDBAdapter} class
    @rtype: class
    @raise SnapComponentError: If the module cannot be imported, or it does not
        define a class named C{db_type}.
    """
    _acquireLock()
    try:
        if db_type in _db_classes_loaded:
            return _db_classes_loaded[db_type]
        pkg_name = 'snaplogic.components.DBUtils.%s' % db_type
        try:
            # A non-empty fromlist makes __import__ return the named submodule
            # itself rather than the top-level "snaplogic" package.
            snap_db_mod = __import__(pkg_name, globals(), locals(), [''])
        except ImportError:
            # sys.exc_info() keeps this compatible with old and new
            # exception-binding syntax.
            e = sys.exc_info()[1]
            raise SnapComponentError("Unknown database type: %s (%s)" % (db_type, e))
        try:
            snap_db_cls = getattr(snap_db_mod, db_type)
        except AttributeError:
            raise SnapComponentError("Module %s does not define class %s" % (pkg_name, db_type))
        _db_classes_loaded[db_type] = snap_db_cls
        return snap_db_cls
    finally:
        _releaseLock()
def get_db_record_set(cursor, size=None):
    """
    Record set generator: yield the records of an executed query one at a time,
    fetching them from the cursor in batches.

    Iterating while C{cursor.rownumber < cursor.rowcount} would avoid the final
    empty fetch, but that is not supported by some database wrappers. So this
    keeps calling C{fetchmany()} until it returns an empty batch, at the cost of
    one extra call.

    @param cursor: A DB-API 2.0 cursor on which a query has been executed.
    @param size: Optional number of rows to request per C{fetchmany()} call.
        When None, C{fetchmany()} is called without arguments, so the cursor's
        default batch size applies.
    @type size: int
    @return: Generator yielding one record per iteration.
    """
    while True:
        if size is None:
            batch = cursor.fetchmany()
        else:
            batch = cursor.fetchmany(size)
        if not batch:
            break
        for rec in batch:
            yield rec
def get_connection(db_type, kwargs):
    """
    Create a connection adapter for the given database type.

    @param db_type: Database type name used to locate the adapter class.
    @type db_type: string
    @param kwargs: Keyword arguments passed unchanged to the adapter class
        constructor.
    @type kwargs: dict
    @return: A new instance of the adapter class for C{db_type}.
    """
    adapter_cls = _get_snap_db_connection_class(db_type)
    return adapter_cls(**kwargs)
def _split_old_style_connect_string(conn_str):
    """
    Split an old-style ConnectString of the form
    C{<driver>:<name>=<value>,<name>=<value>,...} into driver and arguments.

    Only the first ':' separates the driver from the arguments, so argument
    values may themselves contain ':' (e.g. a DSN such as C{host:1521/orcl}).

    @param conn_str: Value of the ConnectString property; may be empty or None.
    @type conn_str: string
    @return: A tuple (driver, arg_str), or (None, None) if conn_str is empty
        or contains no ':'.
    @rtype: tuple
    """
    if not conn_str or ':' not in conn_str:
        return (None, None)
    (driver, arg_str) = conn_str.split(':', 1)
    return (driver, arg_str)


def _parse_old_style_args(arg_str, conn_str):
    """
    Parse the C{name=value,name=value,...} argument part of an old-style
    ConnectString into keyword arguments for the connection class.

    Only the first '=' of each pair separates name from value, so values
    (e.g. passwords) may contain '='.

    @param arg_str: The part of the ConnectString after the driver prefix.
    @type arg_str: string
    @param conn_str: The full ConnectString; used only in error messages.
    @type conn_str: string
    @return: Mapping of argument names to values, both converted with str().
    @rtype: dict
    @raise SnapComponentError: If a name/value pair contains no '='.
    """
    kwargs = {}
    for nvp in arg_str.split(','):
        pair = nvp.split('=', 1)
        if len(pair) != 2:
            raise SnapComponentError('Malformed connection string (%s) in connection resdef' % conn_str)
        kwargs[str(pair[0])] = str(pair[1])
    return kwargs


def get_connection_from_resdef(connect_resdef):
    """
    Return an appropriate L{SnapDBAdapter} object based on the
    given Connection resource.

    Two kinds of resources are supported:
      - old-style C{snaplogic.components.Connection} resources, whose
        ConnectString property has the form C{<driver>:<name>=<value>,...}
        (deprecated; a warning is logged);
      - new-style C{...Connection<DBType>} resources (optionally with a
        C{_N_N} version suffix), whose C{db_*} properties become the
        connection keyword arguments (with the C{db_} prefix stripped).

    @param connect_resdef: Resdef of a Connection resource.
    @type connect_resdef: L{ResDef}
    @return: An instance of L{SnapDBAdapter}
    @rtype: L{SnapDBAdapter}
    @raise SnapComponentError: If the resource is malformed or is not a DB
        Connection resource, or the connection cannot be created.
    """
    comp_name = connect_resdef.get_component_name()
    if comp_name == 'snaplogic.components.Connection':
        # Old style connection
        conn_str = connect_resdef.get_property_value('ConnectString')
        (driver, arg_str) = _split_old_style_connect_string(conn_str)
        # '' when the string is malformed or names an unknown driver.
        db_type = _OLD_STYLE_MAP.get(driver, '')
        if db_type:
            cc.log(snap_log.LEVEL_WARN, "Use of Connection component is deprecated, please use Connection%s" % db_type)
        else:
            cc.log(snap_log.LEVEL_WARN, "Use of Connection component is deprecated, please use Connection<DBType>")
        if not db_type:
            raise SnapComponentError('Malformed connection string (%s) in connection resdef' % conn_str)
        kwargs = _parse_old_style_args(arg_str, conn_str)
    else:
        # Given component name such as snaplogic.components.ConnectionMySQL
        # or snaplogic.components.ConnectionMySQL_1_0
        # extract "MySQL" (database name) from it.
        m = re.match(r'.*Connection([a-zA-Z0-9]*)(_[0-9]_[0-9])*', comp_name)
        db_type = m.groups()[0] if m is not None else ''
        # An empty name (e.g. "...Connection_1_0") cannot select a DB type.
        if not db_type:
            raise SnapComponentError("Resource definition is not a DB Connection resource")
        kwargs = {}
        for prop_name in connect_resdef.list_property_names():
            if not prop_name.startswith("db_"):
                continue
            db_api_name = prop_name[3:]
            val = connect_resdef.get_property_value(prop_name)
            # NOTE(review): under Python 2, str() on a unicode value uses the
            # default (usually ASCII) codec and fails on non-ASCII text;
            # confirm which encoding the drivers expect.
            if isinstance(val, unicode):
                val = str(val)
            # Only set the param if it's not None.
            # Because parameters are generally optional in DBAPI
            # we don't have to explicitly set them to None if they are None:
            # instead we can just not set them.
            # And setting params explicitly to None causes problems
            # in some DBAPI drivers, such as MySQL driver: it raises this error:
            # "connect() argument 3 must be string, not None".
            # However not setting parameter 3 (password) instead doesn't cause errors.
            if val is not None:
                kwargs[str(db_api_name)] = val
        if not kwargs:
            raise SnapComponentError("Resource definition is not a DB Connection resource")
    try:
        con = get_connection(db_type, kwargs)
    except SnapException:
        raise
    except Exception:
        # sys.exc_info() keeps this compatible with old and new
        # exception-binding syntax.
        e = sys.exc_info()[1]
        snap_exc = SnapComponentError("Cannot create %s connection from resource definition: %s" % (db_type, str(e)))
        snap_exc.append(e)
        raise snap_exc
    return con
class SnapDBAdapter(object):
    """
    Abstract class representing a SnapLogic wrapper around a DB API connection
    object. This wrapper is intended to appropriately convert types and encodings
    between native DB and SnapLogic data types. To this end, some subclasses may
    choose to return native cursor objects, while others may choose to create
    wrappers.

    Each database type is represented by its own subclass, named after the
    database type (e.g. C{MySQL}) and defined in the module
    C{snaplogic.components.DBUtils.<db_type>}.
    """

    def __init__(self, conn, driver):
        """
        Initialize. Intended to be only called by subclasses.

        @param conn: Native DB API 2.0 connection object
        @type conn: Connection
        @param driver: DB API 2.0 module (e.g., cx_Oracle); its C{paramstyle}
            attribute selects how bind variables are written.
        @type driver: module
        @raise SnapComponentError: If the driver's paramstyle is not one of
            the five styles defined by DB-API 2.0.
        """
        super(SnapDBAdapter, self).__init__()
        self._conn = conn
        self._driver = driver
        paramstyle = self._driver.paramstyle
        # Sets the type of container for bind variables, depending on the
        # paramstyle as specified in http://www.python.org/dev/peps/pep-0249/
        if paramstyle in ('qmark', 'numeric', 'format'):
            self.bind_container_type = list
        elif paramstyle in ('named', 'pyformat'):
            self.bind_container_type = dict
        else:
            raise SnapComponentError('Unexpected paramstyle: ' + paramstyle)
        # Replace the placeholder bindVariableList with the paramstyle-specific
        # implementation (_bind_var_list_<paramstyle>).
        try:
            bind_var_list_method_name = '_bind_var_list_%s' % paramstyle
            self.bindVariableList = getattr(self, bind_var_list_method_name)
        except AttributeError:
            raise SnapComponentError('Unexpected paramstyle: ' + paramstyle)

    def close(self):
        """
        Close the connection as specified by http://www.python.org/dev/peps/pep-0249/
        """
        return self._conn.close()

    def commit(self):
        """
        Commit as specified by http://www.python.org/dev/peps/pep-0249/
        """
        return self._conn.commit()

    def rollback(self):
        """
        Roll back as specified by http://www.python.org/dev/peps/pep-0249/
        """
        return self._conn.rollback()

    def cursor(self):
        """
        Get cursor as specified by http://www.python.org/dev/peps/pep-0249/
        """
        cursor = self._conn.cursor()
        return cursor

    def bindVariableList(self, field_names):
        """
        Create a list bind variables suitable for embedding into an
        SQL statement, using the appropriate param style for this
        connection.

        This placeholder is replaced on each instance by __init__ with the
        _bind_var_list_<paramstyle> method matching the driver.

        @param field_names: A list of field names to use for bind variables.
        @type field_names: list of strings
        @return: List of bind variables directly corresponding with the elements in
                 field_names.
        @rtype: list of strings
        """
        pass

    def _apply_fix_bound_values(self, bind_container):
        """
        Pass bind_container through the subclass hook C{fix_bound_values}, if
        the subclass defines one; otherwise return it unchanged.
        """
        fixer = getattr(self, 'fix_bound_values', None)
        if fixer is None:
            return bind_container
        return fixer(bind_container)

    def upsert(self, table, row, keys, table_meta):
        """
        Provides upsert (insert if not exist, else update) functionality.
        This provides a generic upsert functionality, but subclasses are
        encouraged to override with a more efficient DB-specific version
        of this functionality.

        This is not guaranteed to be atomic. It is the responsibility of the
        caller to wrap this in a transaction if necessary.

        @param table: DB table on which to perform upsert
        @type table: str
        @param row: Values to update or insert, as a dict, whose keys are column
                    names and whose values are the values to store in dB
        @type row: dict
        @param keys: A list of column names that serve as keys. If a row whose
                     values for these columns are equal to the values from the
                     row argument exists, an update is performed, otherwise,
                     an insert. If the row consists only of key columns and a
                     matching row exists, nothing is done.
        @type keys: list or tuple
        @param table_meta: Table metadata. Not used by this generic
                           implementation.
        @raise SnapComponentError: If more than one existing row matches the keys.
        """
        keys = list(keys)
        cur = self.cursor()
        try:
            # Look for an existing row with the same key values.
            bound_keys = self.bindVariableList(keys)
            where_clause = ' AND '.join(["%s = %s" % (key, bound)
                                         for key, bound in zip(keys, bound_keys)])
            sql = "SELECT " + ",".join(keys) + " FROM " + table + " WHERE " + where_clause
            if self.bind_container_type == dict:
                bind_container = dict((key, row[key]) for key in keys)
            else:
                bind_container = [row[key] for key in keys]
            cur.execute(sql, self._apply_fix_bound_values(bind_container))
            rs = cur.fetchall()
            if len(rs) > 1:
                explicit_where_clause = ' AND '.join(["%s = %s" % (key, row[key]) for key in keys])
                raise SnapComponentError("Too many rows match 'WHERE %s'" % explicit_where_clause)
            if rs:
                # Update
                key_set = set(keys)
                fields_to_set = [field for field in row if field not in key_set]
                if not fields_to_set:
                    # Every column of the row is a key column and the matching
                    # row already exists: nothing to update (and
                    # "UPDATE ... SET WHERE ..." would be invalid SQL).
                    return
                # Generate SET and WHERE placeholders in one call so positional
                # styles (e.g. numeric :1, :2, ...) number them consecutively
                # and match the order of the bound values below.
                bound = self.bindVariableList(fields_to_set + keys)
                n_set = len(fields_to_set)
                update_clause = ', '.join(["%s = %s" % (field, b)
                                           for field, b in zip(fields_to_set, bound[:n_set])])
                update_where = ' AND '.join(["%s = %s" % (key, b)
                                             for key, b in zip(keys, bound[n_set:])])
                sql = "UPDATE " + table + " SET " + update_clause + " WHERE " + update_where
                if self.bind_container_type == list:
                    bind_container = [row[field] for field in fields_to_set]
                    bind_container += [row[key] for key in keys]
                else:
                    # Copy so a fix_bound_values hook cannot modify the caller's dict.
                    bind_container = dict(row)
                cur.execute(sql, self._apply_fix_bound_values(bind_container))
            else:
                # Insert
                fields_to_insert = list(row.keys())
                bound_fields = self.bindVariableList(fields_to_insert)
                sql = ("INSERT INTO " + table + "(" + ",".join(fields_to_insert)
                       + ") VALUES (" + ','.join(bound_fields) + ")")
                if self.bind_container_type == list:
                    bind_container = [row[field] for field in fields_to_insert]
                else:
                    bind_container = dict(row)
                cur.execute(sql, self._apply_fix_bound_values(bind_container))
        finally:
            cur.close()

    def _bind_var_list_qmark(self, field_names):
        """Implementation of bindVariableList for qmark (?,?,...) paramstyle."""
        return ['?' for field in field_names]

    def _bind_var_list_numeric(self, field_names):
        """Implementation of bindVariableList for numeric (:1,:2,...) paramstyle."""
        return [':' + str(i + 1) for i in range(0, len(field_names))]

    def _bind_var_list_named(self, field_names):
        """Implementation of bindVariableList for named (:field1, :field2,...) paramstyle."""
        return [':' + field for field in field_names]

    def _bind_var_list_format(self, field_names):
        """Implementation of bindVariableList for format (%s,%s,...) paramstyle."""
        return ['%s' for field in field_names]

    def _bind_var_list_pyformat(self, field_names):
        """Implementation of bindVariableList for pyformat (%(field1)s,%(field2)s,...) paramstyle."""
        return ['%(' + field + ')s' for field in field_names]

    def bindValueContainer(self, record):
        """
        Create a proper container of values from the input fields of record.
        Encodes the fields present within record appropriately for the given paramstyle.

        If a subclass defines C{fix_bound_values}, it is applied to the result
        to fix up the values in a DB-specific manner (for example, converting
        datetime objects to strings of appropriate format, or encoding unicode
        objects appropriately).

        @param record: A data record.
        @type record: snaplogic.common.DataTypes.Record
        @return: The values contained within the record appropriately encoded for the given
                 paramstyle
        @rtype: dictionary or list
        """
        if self.bind_container_type == list:
            result = [record[name] for name in record.field_names]
        elif self.bind_container_type == dict:
            result = dict((name, record[name]) for name in record.field_names)
        else:
            raise SnapComponentError('Unexpected paramstyle: ' + self._driver.paramstyle)
        return self._apply_fix_bound_values(result)

    def type_code_to_native_type(self, tc):
        """
        Given the native type code, converts it to native type
        as a string (e.g., 22 to 'VARCHAR'), using the
        C{TYPE_CODE_TO_NATIVE_TYPE} dict defined in the module of the
        concrete subclass.

        @param tc: native typecode
        @type tc: int
        @return: native type
        @rtype: str
        @raise SnapComponentError: if the type code provided is unknown
        """
        mod = sys.modules[self.__module__]
        d = mod.TYPE_CODE_TO_NATIVE_TYPE
        if tc in d:
            native_type = d[tc]
        else:
            raise SnapComponentError("Unknown type code: %s" % tc)
        return native_type

    def native_type_to_snap_type(self, nt):
        """
        Given the native type (e.g., 'VARCHAR'),
        converts it to Snap type (e.g., 'string'), using the lower-cased
        native type as key into the C{NATIVE_TYPE_TO_SNAP_TYPE} dict defined
        in the module of the concrete subclass.

        @param nt: native type
        @type nt: str
        @return: Snap type
        @rtype: str
        @raise SnapComponentError: if the native type is not supported
        """
        nt = nt.lower()
        mod = sys.modules[self.__module__]
        d = mod.NATIVE_TYPE_TO_SNAP_TYPE
        if nt in d:
            return d[nt]
        else:
            raise SnapComponentError("Unsupported type: %s" % nt)

    def limit_rows_clause(self, limit=0):
        """
        Subclasses must override. DB-specific clause to limit number
        of returned rows (useful for discovery, etc.)

        @param limit: Limit number of returned rows
        @type limit: int
        @return: DB-specific clause to limit number of returned rows
                 given the specified limit.
        @rtype: str
        @raise SnapComponentError: always, in this base class.
        """
        raise SnapComponentError("%s must override" % self.__class__)

    def _parse_table_name(self, table_name):
        """
        Parses a table name with an optional schema qualifier into a tuple
        of (schema, table_name). If schema not present, L{get_default_schema}
        is used.

        @param table_name: table name of the format [schema.]table_name.
        @type table_name: str
        @return: a tuple of (schema, table_name)
        @rtype: tuple of str
        @raise SnapComponentError: if the name has more than one '.'.
        """
        db_objects = table_name.split('.')
        l = len(db_objects)
        if l == 1:
            schema = self.get_default_schema()
        elif l == 2:
            schema = db_objects[0]
            table_name = db_objects[1]
        else:
            raise SnapComponentError("Malformed table name: %s; expected [schema.]table" % table_name)
        return (schema, table_name)

    def get_no_row_query(self, select_sql):
        """
        Creates a query to get 0 rows from the underlying SELECT statement.

        @param select_sql: An arbitrary SELECT statement
        @type select_sql: string
        @return: The wrapping query.
        @rtype: string
        """
        limit_clause = self.limit_rows_clause()
        sql = "SELECT * FROM (%s) dummy WHERE 1=0 %s " % (select_sql, limit_clause)
        return sql

    def get_snap_view_metadata_from_select(self, select_sql):
        """
        Retrieves the metadata in the same format as L{get_snap_view_metadata} except
        for an arbitrary SQL SELECT (or CALL) statement.

        The statement is wrapped by L{get_no_row_query} and executed; field
        metadata is taken from the cursor's C{description}. The cursor is
        closed even if an error occurs.

        @param select_sql: SQL statement
        @type select_sql: string
        @return: See L{get_snap_view_metadata}
        """
        cur = self.cursor()
        try:
            sql = self.get_no_row_query(select_sql)
            cur.execute(sql)
            view_def = {}
            view_def['primary_key'] = []
            view_def['schema'] = self.get_default_schema()
            view_def['native_types'] = {}
            field_defs = []
            # DB-API 2.0 description items are (name, type_code, display_size,
            # internal_size, precision, scale, null_ok); the last five may be
            # None when the driver cannot provide them.
            for meta in cur.description:
                field_name = meta[0]
                native_type = self.type_code_to_native_type(meta[1])
                desc = ["Native type: %s" % native_type]
                nullable = meta[6]
                if nullable is not None:
                    desc.append("Nullable: %s" % ("Yes" if nullable else "No"))
                display_size = meta[2]
                if display_size:
                    desc.append("Display size: %s" % display_size)
                internal_size = meta[3]
                if internal_size:
                    desc.append("Internal size: %s" % internal_size)
                precision = meta[4]
                if precision:
                    desc.append("Precision: %s" % precision)
                scale = meta[5]
                if scale:
                    desc.append("Scale: %s" % scale)
                snap_type = self.native_type_to_snap_type(native_type)
                view_def['native_types'][field_name] = native_type
                field_defs.append((field_name, snap_type, '; '.join(desc)))
        finally:
            cur.close()
        view_def['fields'] = tuple(field_defs)
        return view_def

    def get_default_schema(self):
        """
        Get the default schema for currently connected user.

        @return: schema (None in this base class)
        @rtype: str
        """
        # This is a no-op, subclasses must override
        pass

    def list_tables(self, schema=None):
        """
        Lists all tables available for the connected user. If schema is not specified,
        it is assumed to be the default schema (see L{get_default_schema}) for the user.

        @param schema: Schema to use (optional).
        @type schema: str
        @return: List of tables (None in this base class)
        @rtype: list
        """
        # This is a no-op, subclasses must override
        pass

    def get_snap_view_metadata(self, table_name):
        """
        Returns a dict representing metadata needed to create a SnapLogic view.

        @param table_name: Name of the table for which to gather the metadata.
                           May be optionally preceded by schema qualifier, e.g.,
                           "emp" or "scott.emp". If no schema qualifier is provided,
                           default schema (see L{get_default_schema()}) for the connected
                           user will be used.
        @type table_name: str
        @return: view metadata as a dict (None in this base class). This dict
                 has the following keys:
                 - 'primary_key' the value is a list of columns that comprise the primary key
                   in the given table
                 - 'schema' schema the table provided belongs to.
                 - 'native_types' (optional) a dict whose keys are column names and whose values
                   are native SQL types of that column
                 - 'fields' the value is a view definition (a tuple of field definitions,
                   each of which is a tuple of ('field_name', 'field_type', 'field_description')
                   format).
        @rtype: dict
        """
        # This is a no-op, subclasses must override
        pass
|