# $SnapHashLicense:
#
# SnapLogic - Open source data services
#
# Copyright (C) 2008-2009, SnapLogic, Inc. All rights reserved.
#
# See http://www.snaplogic.org for more information about
# the SnapLogic project.
#
# This program is free software, distributed under the terms of
# the GNU General Public License Version 2. See the LEGAL file
# at the top of the source tree.
#
# "SnapLogic" is a trademark of SnapLogic, Inc.
#
#
# $
# $Id: PostgreSQL.py 10045 2009-12-09 03:12:51Z grisha $
from time import strptime
import pgdb
from datetime import datetime,time,date,timedelta
from decimal import Decimal
from snaplogic.components.DBUtils import SnapDBAdapter
from snaplogic.common.snap_exceptions import SnapComponentError
from snaplogic.common.data_types import SnapString,SnapDateTime,SnapNumber
"""
Based on pgdb http://www.pygresql.org/pgdb.html#the-pgdb-module
"""
# Maps pgdb cursor.description type codes to the PostgreSQL native type
# names used by information_schema. Codes not listed here are assumed to
# already be native type names and are passed through unchanged
# (see PostgreSQL.type_code_to_native_type).
TYPE_CODE_TO_NATIVE_TYPE = {
'timestamp' : 'timestamp without time zone',
'timestamptz' : 'timestamp with time zone',
'time' : 'time without time zone',
'timetz' : 'time with time zone',
'varchar' : 'character varying',
'char' : 'character',
'bpchar' : 'character',
'int2' : 'smallint',
'int8' : 'bigint',
'int4' : 'integer',
'float4' : 'real',
'float8' : 'double precision',
'serial8' : 'bigserial',
'serial4' : 'serial',
'bool' : 'boolean',
}
# Maps PostgreSQL native type names to Snap data types, used by
# PostgresCursorWrapper to decide how each result column is converted.
# NOTE(review): 'boolean' maps to SnapNumber (values go through Decimal in
# convert_row) -- no Snap boolean type is imported here, presumably
# intentional; confirm against the SnapLogic type system.
NATIVE_TYPE_TO_SNAP_TYPE = {
'timestamp without time zone': SnapDateTime,
'timestamp with time zone' : SnapDateTime,
'time without time zone': SnapDateTime,
'time with time zone' : SnapDateTime,
'date' : SnapDateTime,
'character varying' : SnapString,
'character' : SnapString,
'text' : SnapString,
'money' : SnapNumber,
'bigint' : SnapNumber,
'integer' : SnapNumber,
'smallint' : SnapNumber,
'double precision' : SnapNumber,
'real' : SnapNumber,
'bigserial' : SnapNumber,
'serial' : SnapNumber,
'boolean' : SnapNumber,
'numeric' : SnapNumber,
}
class PostgresCursorWrapper(object):
    """
    A wrapper around a DB API 2.0 cursor, handling encoding and the
    conversion of PostgreSQL native values into Snap data types.
    Will be returned by L{PostgreSQL.cursor}.
    """

    def __init__(self, cursor, snap_conn):
        """
        Initialize.
        @param cursor: DB API 2.0 cursor object, to which most requests will
            be delegated.
        @type cursor: cursor
        @param snap_conn: The L{SnapDBAdapter} object that generated this cursor
        @type snap_conn: L{SnapDBAdapter}
        """
        self._delegate = cursor
        # Per-result-set conversion plan, computed lazily on first fetch
        # (see _classify_columns()).
        self._metadata = None
        self._date_fields = None   # dict: column index -> pgdb type code
        self._num_fields = None    # list of numeric column indices
        self._str_fields = None    # list of string column indices
        self._snap_conn = snap_conn
        # Date/time values are parsed assuming UTC; make the server agree.
        self.execute("SET TIMEZONE TO UTC")
        self.execute("SET NAMES '%s'" % snap_conn.char_set)

    def execute(self, operation, params = None):
        """
        Same as cursor.execute() specified in DB API 2.0, except unicode
        statements are encoded to UTF-8 before reaching the driver.
        """
        # A new statement invalidates any cached result-set metadata.
        self._metadata = None
        if type(operation) == unicode:
            operation = operation.encode('utf-8')
        if params is None:
            return self._delegate.execute(operation)
        return self._delegate.execute(operation, params)

    @staticmethod
    def _split_fractional_seconds(str_val):
        """
        Split trailing fractional seconds off a native date/time string.
        @param str_val: native value, e.g. '12:30:45.5'
        @type str_val: str
        @return: (string without the fractional part, timedelta holding the
            fractional part, or None if there was none)
        @rtype: tuple
        """
        if '.' not in str_val:
            return str_val, None
        dot_idx = str_val.rindex('.')
        frac = str_val[dot_idx + 1:]
        # Postgres may return fewer than 6 fractional digits; scale up to
        # microseconds.
        micros = int(frac) * 10 ** (6 - len(frac))
        return str_val[:dot_idx], timedelta(0, 0, micros)

    @staticmethod
    def _split_time_zone(str_val):
        """
        Split a trailing time zone offset off a native date/time string.
        @param str_val: native value ending in an offset, e.g.
            '12:30:45-08' or '12:30:45+05:30'
        @type str_val: str
        @return: (string without the offset, signed timedelta which, when
            added to the parsed local value, yields UTC)
        @rtype: tuple
        """
        if '+' in str_val:
            tz_idx = str_val.rindex('+')
            delta_sign = -1
        else:
            # rindex() finds the offset's '-' rather than a date separator
            # because the offset is always the last component.
            tz_idx = str_val.rindex('-')
            delta_sign = 1
        tz_str = str_val[tz_idx + 1:]
        if ':' in tz_str:
            (hours, mins) = tz_str.split(':')
            hours = int(hours)
            mins = int(mins)
        else:
            hours = int(tz_str)
            mins = 0
        return str_val[:tz_idx], delta_sign * timedelta(hours=hours, minutes=mins)

    def _classify_columns(self):
        """
        Inspect the cursor description and record, per column index, which
        conversion (date/time, numeric or string) its values require.
        @raise SnapComponentError: if a column has an unsupported type.
        """
        self._metadata = self._delegate.description
        self._date_fields = None
        self._num_fields = None
        self._str_fields = None
        for i, col_metadata in enumerate(self._metadata):
            type_code = col_metadata[1]
            # These are returned as str by pgdb...
            native_type = self._snap_conn.type_code_to_native_type(type_code)
            snap_type = self._snap_conn.native_type_to_snap_type(native_type)
            if snap_type == SnapDateTime:
                if self._date_fields is None:
                    self._date_fields = {}
                self._date_fields[i] = type_code
            elif snap_type == SnapNumber:
                if self._num_fields is None:
                    self._num_fields = []
                self._num_fields.append(i)
            elif snap_type == SnapString:
                if self._str_fields is None:
                    self._str_fields = []
                self._str_fields.append(i)
            else:
                raise SnapComponentError("Unsupported type: %s" % native_type)

    def _convert_date_value(self, type_code, str_val):
        """
        Parse a native date/time string into a datetime.
        @param type_code: pgdb type code: 'timestamp', 'timestamptz',
            'date', 'time' or 'timetz'
        @type type_code: str
        @param str_val: native string value
        @type str_val: str
        @return: parsed value; bare times are combined with today's date,
            zoned values are normalized to UTC. An unrecognized type code
            returns the value unchanged.
        @rtype: datetime
        """
        if type_code not in ('timestamp', 'timestamptz', 'date', 'time', 'timetz'):
            # Unknown code: leave the raw value untouched (matches the
            # historical fall-through behavior).
            return str_val
        tz_delta = None
        if type_code in ('timestamptz', 'timetz'):
            str_val, tz_delta = self._split_time_zone(str_val)
        micros_delta = None
        if type_code != 'date':
            str_val, micros_delta = self._split_fractional_seconds(str_val)
        if type_code in ('timestamp', 'timestamptz'):
            tt = strptime(str_val, '%Y-%m-%d %H:%M:%S')
            dt = datetime(*tt[:6])
        elif type_code == 'date':
            tt = strptime(str_val, '%Y-%m-%d')
            dt = datetime(*tt[:6])
        else:
            # 'time' / 'timetz': anchor the time-of-day on today's date.
            tt = strptime(str_val, '%H:%M:%S')
            dt = datetime.combine(datetime.today(), time(*tt[3:6]))
        if micros_delta is not None:
            dt += micros_delta
        if tz_delta is not None:
            dt += tz_delta
        return dt

    def convert_row(self, row):
        """
        Convert a row of data in native data types into a row of Snap types.
        @param row: row returned by database
        @type row: tuple
        @return: row converted to Snap data types (returned unchanged when
            no column requires conversion)
        @rtype: list
        """
        if self._metadata is not None and self._date_fields is None \
                and self._num_fields is None and self._str_fields is None:
            # Result set already inspected; nothing needs conversion.
            return row
        if not row:
            return row
        if self._metadata is None:
            self._classify_columns()
        new_row = list(row)
        if self._num_fields is not None:
            for idx in self._num_fields:
                val = row[idx]
                if val is None:
                    continue
                t = type(val)
                if t == int or t == long or t == bool:
                    new_row[idx] = Decimal(val)
                else:
                    # Floats go through str to get an exact decimal literal.
                    new_row[idx] = Decimal(str(val))
        if self._str_fields is not None:
            for idx in self._str_fields:
                val = row[idx]
                if val is not None:
                    new_row[idx] = val.decode('utf-8')
        if self._date_fields is not None:
            for idx in self._date_fields.keys():
                str_val = row[idx]
                if str_val is not None:
                    new_row[idx] = self._convert_date_value(
                        self._date_fields[idx], str_val)
        return new_row

    def convert_results(self, rs):
        """
        Convert the result set from native data types to Snap data types.
        This is similar to L{convert_row}, except it acts on the entire result
        set.
        @param rs: Result set to convert
        @type rs: list or tuple
        @return: converted result set
        @rtype: list
        """
        if self._metadata is not None and self._date_fields is None \
                and self._num_fields is None and self._str_fields is None:
            return rs
        if not rs:
            return rs
        return [self.convert_row(row) for row in rs]

    def fetchone(self):
        """
        Same as cursor.fetchone() specified in DB API 2.0, except returning
        Snap data types.
        """
        row = self._delegate.fetchone()
        if row is not None:
            row = self.convert_row(row)
        return row

    def fetchmany(self, size=None):
        """
        Same as cursor.fetchmany() specified in DB API 2.0, except returning
        Snap data types.
        """
        if size is None:
            size = self._delegate.arraysize
        return self.convert_results(self._delegate.fetchmany(size))

    def fetchall(self):
        """
        Same as cursor.fetchall() specified in DB API 2.0, except returning
        Snap data types.
        """
        return self.convert_results(self._delegate.fetchall())

    def __getattr__(self, name):
        """
        Used to delegate to the native cursor object those methods that are not
        wrapped by this class.
        """
        return getattr(self._delegate, name)
class PostgreSQL(SnapDBAdapter):
    """
    Implementation of L{SnapDBAdapter} for PostgreSQL.
    """

    def __init__(self, *args, **kwargs):
        """
        Connect to PostgreSQL via pgdb.

        pgdb has no separate 'port' keyword: a non-default port must be
        folded into 'host' as "host:port". An optional 'charset' keyword
        (default 'UTF8') selects the connection character set; all other
        keywords are passed to pgdb.connect() unchanged.
        """
        try:
            port = kwargs.pop('port')
            kwargs['host'] = "%s:%s" % (kwargs['host'], port)
        except KeyError:
            pass
        self.char_set = kwargs.pop('charset', 'UTF8')
        conn = pgdb.connect(**kwargs)
        super(PostgreSQL, self).__init__(conn, pgdb)

    def cursor(self):
        """
        See L{SnapDBAdapter.cursor} and L{PostgresCursorWrapper}
        """
        native_cursor = SnapDBAdapter.cursor(self)
        return PostgresCursorWrapper(native_cursor, self)

    def fix_bound_values(self, record):
        """
        Given a record (really, a dictionary) whose values are
        Python objects, returns a dictionary with the same keys
        whose values are Python objects of types that the DB
        expects.
        @param record: record
        @type record: dict
        @return: a record with values converted to types DB expects.
        @rtype: dict
        """
        new_result = {}
        for param, value in record.items():
            if type(value) == Decimal:
                # Postgres complains here:
                # Native error: ("error 'do not know how to handle type <class 'decimal.Decimal'>' in 'INIT'",)
                # See http://mailman.vex.net/pipermail/pygresql/2008-September/001979.html
                int_val = int(value)
                # Pass integral decimals as int, everything else as float.
                value = int_val if value == int_val else float(value)
            new_result[param] = value
        return new_result

    def get_default_schema(self):
        """
        See L{SnapDBAdapter.get_default_schema}. Default here is assumed
        to be "public"
        """
        return 'public'

    def list_tables(self, schema = None):
        """
        See L{SnapDBAdapter.list_tables}.
        """
        cur = self.cursor()
        if not schema:
            schema = self.get_default_schema()
        sql = "SELECT table_name FROM information_schema.tables WHERE LOWER(table_schema) = LOWER(%(schema)s)"
        cur.execute(sql, {'schema' : schema})
        result = [row[0] for row in cur.fetchall()]
        cur.close()
        return result

    def limit_rows_clause(self, limit=1):
        """
        See L{SnapDBAdapter.limit_rows_clause()}
        """
        return "LIMIT %s" % limit

    def type_code_to_native_type(self, tc):
        """
        See L{SnapDBAdapter.native_typecode_to_native_type()}
        """
        # In Postgres driver, type code and native type are sometimes the
        # same; fall through to the code itself when it is not mapped.
        return TYPE_CODE_TO_NATIVE_TYPE.get(tc, tc)

    def get_snap_view_metadata(self, table_name):
        """
        See L{SnapDBAdapter.get_snap_view_metadata}.
        """
        view_def = {}
        primary_key = []
        view_def['primary_key'] = primary_key
        field_defs = []
        cur = self.cursor()
        (schema, table_name) = self._parse_table_name(table_name)
        view_def['schema'] = schema
        # Join columns to their constraints so primary-key membership comes
        # back alongside each column's type information.
        sql = """
            SELECT
                c.column_name,
                constraint_type,
                c.column_default, c.is_nullable, c.data_type, c.character_maximum_length, c.numeric_precision,
                c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, c.interval_precision,
                c.collation_name, c.character_set_name
            FROM
                information_schema.columns c
                    LEFT OUTER JOIN information_schema.constraint_column_usage ccu
                ON
                    c.table_catalog = ccu.table_catalog
                AND
                    c.table_schema = ccu.table_schema
                AND
                    c.table_name = ccu.table_name
                AND
                    c.column_name = ccu.column_name
                    LEFT OUTER JOIN
                        information_schema.table_constraints tc
                ON
                    tc.table_catalog = ccu.table_catalog
                AND
                    tc.table_schema = ccu.table_schema
                AND
                    tc.table_name = ccu.table_name
                AND
                    tc.constraint_name = ccu.constraint_name
            WHERE
                LOWER(c.table_schema) = LOWER(%(schema)s)
            AND
                LOWER(c.table_name) = LOWER(%(table_name)s)
            ORDER BY ordinal_position ASC;
        """
        bind = {'schema' : schema, 'table_name' : table_name}
        cur.execute(sql, bind)
        result = cur.fetchall()
        if not result:
            raise SnapComponentError("Table '%s' not found in schema '%s'" % (table_name, schema))
        # Map result column names to positions so rows can be read by name.
        indices = {}
        for i, meta in enumerate(cur.description):
            indices[meta[0]] = i
        for row in result:
            # These we need for actual metadata
            name = row[indices['column_name']]
            data_type = row[indices['data_type']].lower()
            snap_type = self.native_type_to_snap_type(data_type)
            if row[indices['constraint_type']] == 'PRIMARY KEY':
                primary_key.append(name)
            # The rest goes into the human-readable field description.
            desc = []
            desc.append("Nullable: %s" % row[indices['is_nullable']])
            # Report the first applicable precision: numeric, then
            # datetime, then interval.
            for prec_col in ('numeric_precision', 'datetime_precision', 'interval_precision'):
                precision = row[indices[prec_col]]
                if precision:
                    desc.append("Precision: %s" % precision)
                    break
            scale = row[indices['numeric_scale']]
            if scale:
                desc.append("Scale: %s" % scale)
            default = row[indices['column_default']]
            if default:
                desc.append("Default: %s" % default)
            charset = row[indices['character_set_name']]
            if charset:
                desc.append("Character set: %s" % charset)
            collation = row[indices['collation_name']]
            if collation:
                # BUG FIX: this was mislabeled "Character set" (copy-paste).
                desc.append("Collation: %s" % collation)
            field_def = (name, snap_type, '; '.join(desc),)
            field_defs.append(field_def)
        cur.close()
        view_def['fields'] = tuple(field_defs)
        return view_def
|