# $SnapHashLicense:
#
# SnapLogic - Open source data services
#
# Copyright (C) 2008 - 2009, SnapLogic, Inc. All rights reserved.
#
# See http://www.snaplogic.org for more information about
# the SnapLogic project.
#
# This program is free software, distributed under the terms of
# the GNU General Public License Version 2. See the LEGAL file
# at the top of the source tree.
#
# "SnapLogic" is a trademark of SnapLogic, Inc.
#
#
# $
# $Id: Aggregate.py 10330 2009-12-24 22:13:38Z grisha $
"""
Aggregate module.
"""
from snaplogic.common.snap_exceptions import SnapObjTypeError
from snaplogic.common import version_info
import tempfile
import os
from decimal import Decimal
from datetime import datetime
import time
from sqlite3 import dbapi2
import sets
import snaplogic.components as components
from snaplogic.components import computils
from snaplogic.common import snap_log,sqlite_iter
from snaplogic.common.data_types import SnapNumber,SnapString,SnapDateTime
from snaplogic.cc.component_api import ComponentAPI
from snaplogic.cc import prop
from snaplogic.snapi_base import keys
from snaplogic.common.snap_exceptions import SnapComponentError
# Names of the three entries that make up one aggregation specification.
INPUT_FIELD_NAME = "Input field"
FUNCTION = "Aggregate function"
OUTPUT_FIELD_NAME = "Output field"
# Name of the dictionary property describing a single aggregation specification.
AGG_SPEC = "Aggregation specification"
# Name of the list property holding all aggregation specifications.
AGG_SPECS = "Aggregation specifications"
# Name of the (optional) list property naming the fields to group by.
GROUP_BY_FIELDS = "Group by fields"
# strptime format used to parse datetime values read back from the SQLite table.
DATE_FORMAT_STRING = "%Y-%m-%d %H:%M:%S"
class Aggregate(ComponentAPI):
    """
    Aggregate component.

    Copies every record of its single input view into a table of a temporary
    SQLite database, runs one SELECT statement built from the configured
    aggregation specifications (plus the optional group by fields) and writes
    each result row to its single output view.
    """

    api_version = '1.0'
    component_version = '1.2'

    capabilities = {
        ComponentAPI.CAPABILITY_INPUT_VIEW_LOWER_LIMIT  : 1,
        ComponentAPI.CAPABILITY_INPUT_VIEW_UPPER_LIMIT  : 1,
        ComponentAPI.CAPABILITY_OUTPUT_VIEW_LOWER_LIMIT : 1,
        ComponentAPI.CAPABILITY_OUTPUT_VIEW_UPPER_LIMIT : 1,
        ComponentAPI.CAPABILITY_ALLOW_PASS_THROUGH      : False
    }

    component_description = "Aggregate"
    component_label = "Aggregate"
    component_doc_uri = "https://www.snaplogic.org/trac/wiki/Documentation/%s/ComponentRef/Aggregate" % \
                        version_info.doc_uri_version

    # The list of supported aggregate functions
    aggregate_functions = ['count', 'sum', 'avg', 'min', 'max']

    # The subset of functions that require number types as input
    aggregate_number_input_functions = ['sum', 'avg']

    # The subset of functions that require number types as output
    aggregate_number_output_functions = ['count', 'sum', 'avg']

    # The underlying db (sqlite) supports a maximum number of aggregates per statement.
    # This dictates the maximum number of aggregate specifications in the AGG_SPECS below.
    max_aggregates = 2000

    def create_resource_template(self):
        """
        Create Aggregate resource template.

        Defines the required list of aggregation specifications (input field,
        aggregate function, output field) and the list of group by fields.
        """
        input_field_name = prop.SimpleProp(INPUT_FIELD_NAME,
                                           "string",
                                           "Field to aggregate on",
                                           {'lov': [ keys.CONSTRAINT_LOV_INPUT_FIELD] },
                                           True)
        func = prop.SimpleProp(FUNCTION,
                               "string",
                               "Agggregate function",
                               {"lov" : self.aggregate_functions},
                               True)
        output_field_name = prop.SimpleProp(OUTPUT_FIELD_NAME,
                                            "string",
                                            "What output field the result corresponds to",
                                            {'lov': [ keys.CONSTRAINT_LOV_OUTPUT_FIELD] },
                                            True)
        agg_spec = prop.DictProp(AGG_SPEC,
                                 input_field_name,
                                 "Aggregation specification dictionary",
                                 3,
                                 3,
                                 True,
                                 True)
        agg_spec[INPUT_FIELD_NAME] = input_field_name
        agg_spec[FUNCTION] = func
        agg_spec[OUTPUT_FIELD_NAME] = output_field_name
        agg_specs = prop.ListProp("Aggregation specifications",
                                  agg_spec,
                                  "Aggregation specification properties",
                                  1,
                                  self.max_aggregates,
                                  True)
        self.set_property_def(AGG_SPECS, agg_specs)

        group_by_fields = prop.ListProp("Group by fields",
                                        prop.SimpleProp(GROUP_BY_FIELDS,
                                                        "string",
                                                        "Field to group by",
                                                        {'lov': [ keys.CONSTRAINT_LOV_INPUT_FIELD] }))
        self.set_property_def(GROUP_BY_FIELDS, group_by_fields)

    def validate(self, err_obj):
        """
        Validate the aggregation properties against the input and output views.

        Problems are reported by setting messages on err_obj; nothing is returned.
        The following is checked:
          - an output field is not produced by more than one aggregation specification,
          - 'sum' and 'avg' are only applied to number input fields,
          - every output view field is an aggregate output field or a group by field,
          - output view field types agree with the function / source field they come from.

        @param err_obj: Error object on which validation messages are set.
        """
        # All the properties except group by are required so we don't need to check for presence.
        input_views = self.list_input_view_names()
        input_view = self.get_input_view_def(input_views[keys.SINGLE_VIEW])
        input_fields = input_view[keys.VIEW_FIELDS]

        output_views = self.list_output_view_names()
        output_view_name = output_views[keys.SINGLE_VIEW]
        output_view = self.get_output_view_def(output_view_name)
        output_fields = output_view[keys.VIEW_FIELDS]

        # We will build a list of agg_output_field_names for later use in validating the output view fields.
        agg_output_field_names = []
        agg_specs_ok = True
        agg_specs = self.get_property_value(AGG_SPECS)
        for i, spec in enumerate(agg_specs):
            agg_func = spec[FUNCTION]
            agg_input_field_name = spec[INPUT_FIELD_NAME]
            agg_output_field_name = spec[OUTPUT_FIELD_NAME]

            # Check that the output field hasn't been used already.
            if agg_output_field_name in agg_output_field_names:
                err_obj.get_property_err(AGG_SPECS)[i][OUTPUT_FIELD_NAME].set_message(
                    "Expression output field '%s' already specified." % agg_output_field_name)
            agg_output_field_names.append(agg_output_field_name)

            # Check that the input field is a number type for operators which require numbers.
            if agg_func in self.aggregate_number_input_functions:
                for input_field in input_fields:
                    if input_field[keys.FIELD_NAME] == agg_input_field_name and \
                       input_field[keys.FIELD_TYPE] != SnapNumber:
                        agg_specs_ok = False
                        err_obj.get_property_err(AGG_SPECS)[i][FUNCTION].set_message(
                            "Aggregation function '%s' not supported on datatype %s."
                            % (agg_func, input_field[keys.FIELD_TYPE]))

        # Group by is optional. Collect the names for validating the output view.
        group_by_field_names = list(self.get_property_value(GROUP_BY_FIELDS) or [])

        # All output fields can only reference the aggregation output fields or the group by
        # fields, else it is an illegal syntax for single group expressions.
        # Since this is an output view error, multiple errors will overwrite each other. Too bad, the
        # user will have to make more than one pass through validation then.
        output_view_ok = True
        for i, output_field in enumerate(output_fields):
            field = output_field[keys.FIELD_NAME]
            if field not in agg_output_field_names and field not in group_by_field_names:
                output_view_ok = False
                err_obj.get_output_view_err()[output_view_name][keys.VIEW_FIELDS][i].set_message(
                    "Output view field '%s' is not an aggregate output field name nor a group by field name." % field)

        # One last check that can only be made when the above is looking sane...
        if not (agg_specs_ok and output_view_ok):
            return

        # Make sure the output view field types correctly match the source field types where required.
        # Functions that are in exactly one of the two "number" lists produce a type that may
        # legitimately differ from their input type.
        mismatch_types_allowed = set(self.aggregate_number_input_functions).symmetric_difference(
            self.aggregate_number_output_functions)

        for i, output_field in enumerate(output_fields):
            output_field_name = output_field[keys.FIELD_NAME]
            output_field_type = output_field[keys.FIELD_TYPE]
            field_err = err_obj.get_output_view_err()[output_view_name][keys.VIEW_FIELDS][i]

            if output_field_name in agg_output_field_names:
                # This is an agg output column so we must trace the source type via the agg specs.
                agg_index = agg_output_field_names.index(output_field_name)
                source_field_name = agg_specs[agg_index][INPUT_FIELD_NAME]
                agg_func = agg_specs[agg_index][FUNCTION]

                # First make sure that the output type is correct for the agg function that produced it.
                if agg_func in self.aggregate_number_output_functions and \
                   output_field_type != SnapNumber:
                    field_err.set_message(
                        "Output view field '%s' type must be 'number' for aggregate function '%s'."
                        % (output_field_name, agg_func))

                # Next make sure that the input type matches the output type where required.
                # Skip this check if the agg function produces a different type than its input
                # For example count(string) returns a number.
                if agg_func in mismatch_types_allowed:
                    continue
            else:
                # This is a group by column so it should map directly to an input view column.
                source_field_name = output_field_name

            for input_field in input_fields:
                if input_field[keys.FIELD_NAME] == source_field_name and \
                   input_field[keys.FIELD_TYPE] != output_field_type:
                    field_err.set_message(
                        "Output view field '%s' type '%s' does not match corresponding input view field '%s' type '%s'."
                        % (output_field_name, output_field_type,
                           input_field[keys.FIELD_NAME], input_field[keys.FIELD_TYPE]))

    def execute(self, input_views, output_views):
        """
        Run the aggregation.

        @param input_views: Mapping of input view names to input view objects.
        @param output_views: Mapping of output view names to output view objects.

        @raise SnapComponentError: If the input or the output view is not connected.
        """
        try:
            self._output_view = list(output_views.values())[keys.SINGLE_VIEW]
        except IndexError:
            raise SnapComponentError("No output view connected.")

        try:
            self._input_view = list(input_views.values())[keys.SINGLE_VIEW]
        except IndexError:
            raise SnapComponentError("No input view connected.")

        try:
            self._execute()
        finally:
            # Always release the database and remove the temp file, even on failure.
            self._cleanup()

    def _execute(self):
        """
        Stage the input records in a temporary SQLite database, run the
        aggregate query and write its result rows to the output view.
        """
        self._db_file = None
        self._db_file_name = None
        self._cursor = None
        self._con = None

        # Create a temp file to hold the SQLite database.
        # Note that mkstemp opens the file as well, which we don't need,
        # so close the temp file after it's been created.
        (self._db_file, self._db_file_name) = tempfile.mkstemp(".db", "snapagg")
        os.close(self._db_file)

        self._con = dbapi2.connect(self._db_file_name)
        self._cursor = self._con.cursor()
        # Decimal values are stored as floats.
        dbapi2.register_adapter(Decimal, float)

        self._load_input_table()
        output_fields = self._run_aggregate_query()
        self._write_results(output_fields)
        self._output_view.completed()

    def _quote_identifier(self, name):
        """
        Return name as a double-quoted SQL identifier.

        Quoting avoids clashes with sqlite keywords; embedded double quotes are doubled.
        """
        return '"' + name.replace('"', '""') + '"'

    def _load_input_table(self):
        """
        Create table 'agg' with one column per input view field and insert
        every record read from the input view into it.

        @raise SnapObjTypeError: If an input view field has an unsupported type.
        """
        input_field_names = []
        quoted_input_field_names = []
        column_defs = []
        for input_field in self._input_view.fields:
            field = input_field[keys.FIELD_NAME]
            field_type = input_field[keys.FIELD_TYPE]
            if field_type == SnapString:
                column_type = "VARCHAR"
            elif field_type == SnapNumber:
                column_type = "DECIMAL"
            elif field_type == SnapDateTime:
                column_type = "DATETIME"
            else:
                raise SnapObjTypeError('Unknown type %s' % field_type)
            quoted = self._quote_identifier(field)
            input_field_names.append(field)
            quoted_input_field_names.append(quoted)
            column_defs.append(quoted + " " + column_type)

        self._cursor.execute("CREATE TABLE agg (" + ", ".join(column_defs) + ")")

        insert_stmt = "INSERT INTO agg (" + \
                      ",".join(quoted_input_field_names) + \
                      ") VALUES (" + \
                      ",".join(['?' for i in quoted_input_field_names]) + \
                      ")"
        while True:
            record = self._input_view.read_record()
            if record is None:
                break
            vals = [record[field_name] for field_name in input_field_names]
            self._cursor.execute(insert_stmt, vals)

    def _run_aggregate_query(self):
        """
        Build and execute the SELECT statement over table 'agg'.

        @return: Output view field names, in the order of the selected columns.

        @raise SnapComponentError: If an output view field is neither produced by an
                                   aggregation specification nor listed as a group by field.
        """
        output_name_to_function = {}
        output_name_to_input_name = {}
        for spec in self.get_property_value(AGG_SPECS):
            output_name_to_function[spec[OUTPUT_FIELD_NAME]] = spec[FUNCTION]
            output_name_to_input_name[spec[OUTPUT_FIELD_NAME]] = spec[INPUT_FIELD_NAME]

        # Group by is optional, so the property may have no value.
        group_by_fields = self.get_property_value(GROUP_BY_FIELDS) or []

        output_fields = []
        sel_clause = []
        for output_field in self._output_view.field_names:
            if output_field in output_name_to_input_name:
                sel = "%s(%s) AS %s" % (output_name_to_function[output_field],
                                        self._quote_identifier(output_name_to_input_name[output_field]),
                                        self._quote_identifier(output_field))
            elif output_field in group_by_fields:
                sel = self._quote_identifier(output_field)
            else:
                # Without a source there is nothing meaningful to select for this column.
                raise SnapComponentError(
                    "Output view field '%s' is not an aggregate output field name nor a group by field name."
                    % output_field)
            output_fields.append(output_field)
            sel_clause.append(sel)

        stmt = "SELECT " + ",".join(sel_clause) + " FROM agg"
        # Add optional group by clause.
        if group_by_fields:
            stmt += " GROUP BY " + ",".join([self._quote_identifier(elem) for elem in group_by_fields])
        self._cursor.execute(stmt)
        return output_fields

    def _parse_datetime(self, text):
        """
        Convert a datetime string read from the database into a datetime object.

        The part before an optional '.' must match DATE_FORMAT_STRING; digits after
        the '.' are interpreted as a fraction of a second.
        """
        base, dot, fraction = text.partition(".")
        parsed = datetime(*(time.strptime(base, DATE_FORMAT_STRING)[0:6]))
        if fraction:
            # Pad / truncate the fraction to microsecond resolution.
            parsed = parsed.replace(microsecond=int((fraction + "000000")[:6]))
        return parsed

    def _write_results(self, output_fields):
        """
        Write every row of the executed query to the output view.

        @param output_fields: Output view field names, in selected column order.
        """
        output_field_types = [output_field[keys.FIELD_TYPE] for output_field in self._output_view.fields]
        for row in sqlite_iter(self._cursor):
            out_rec = self._output_view.create_record()
            for i, field in enumerate(output_fields):
                value = row[i]
                # NULL results are passed through untouched.
                if value is not None:
                    if output_field_types[i] == SnapNumber:
                        # Go through str() since the database hands back floats / ints.
                        value = Decimal(str(value))
                    elif output_field_types[i] == SnapDateTime:
                        value = self._parse_datetime(value)
                out_rec[field] = value
            self._output_view.write_record(out_rec)

    def _cleanup(self):
        """
        Clean up resources: close cursor and connection, remove the temp database file.

        This is best effort; failures to release a resource are ignored so that
        they do not mask an error raised by the execution itself.
        """
        cursor = getattr(self, '_cursor', None)
        if cursor:
            try:
                cursor.close()
            except Exception:
                pass
        self._cursor = None

        con = getattr(self, '_con', None)
        if con:
            try:
                con.close()
            except Exception:
                pass
        self._con = None

        db_file_name = getattr(self, '_db_file_name', None)
        if db_file_name:
            try:
                os.remove(db_file_name)
            except OSError:
                pass
        self._db_file_name = None

    def upgrade_1_0_to_1_1(self):
        """
        Add LOV constraints

        The property definitions are deleted and recreated with the constraints,
        then the previously configured values are restored.
        """
        # Save property values.
        # We need to recreate the properties, which resets their values.
        agg_specs_value = self.get_property_value(AGG_SPECS)
        group_specs_value = self.get_property_value(GROUP_BY_FIELDS)

        # Delete the properties before recreating them
        self.del_property_def(AGG_SPECS)
        self.del_property_def(GROUP_BY_FIELDS)

        # Redefine properties to include the constraint.
        # This is copy and paste code from create_resource_template for version 1.0.
        # The reason we cannot call create_resource_template directly here is
        # that in a future version of this component the body of create_resource_template
        # may change, so the invocation will no longer be correct.
        # So we simply copy the code from version 1.0 create_resource_template here.
        input_field_name = prop.SimpleProp(INPUT_FIELD_NAME,
                                           "string",
                                           "Field to aggregate on",
                                           {'lov': [ keys.CONSTRAINT_LOV_INPUT_FIELD] },
                                           True)
        func = prop.SimpleProp(FUNCTION,
                               "string",
                               "Agggregate function",
                               {"lov" : self.aggregate_functions},
                               True)
        output_field_name = prop.SimpleProp(OUTPUT_FIELD_NAME,
                                            "string",
                                            "What output field the result corresponds to",
                                            {'lov': [ keys.CONSTRAINT_LOV_OUTPUT_FIELD] },
                                            True)
        agg_spec = prop.DictProp(AGG_SPEC,
                                 input_field_name,
                                 "Aggregation specification dictionary",
                                 3,
                                 3,
                                 True,
                                 True)
        agg_spec[INPUT_FIELD_NAME] = input_field_name
        agg_spec[FUNCTION] = func
        agg_spec[OUTPUT_FIELD_NAME] = output_field_name
        agg_specs = prop.ListProp("Aggregation specifications",
                                  agg_spec,
                                  "Aggregation specification properties",
                                  1,
                                  self.max_aggregates,
                                  True)
        self.set_property_def(AGG_SPECS, agg_specs)

        group_by_fields = prop.ListProp("Group by fields",
                                        prop.SimpleProp(GROUP_BY_FIELDS,
                                                        "string",
                                                        "Field to group by",
                                                        {'lov': [ keys.CONSTRAINT_LOV_INPUT_FIELD] }))
        self.set_property_def(GROUP_BY_FIELDS, group_by_fields)
        # End copy and paste (create_resource_template for version 1.0)

        # Restore the value
        self.set_property_value(AGG_SPECS, agg_specs_value)
        self.set_property_value(GROUP_BY_FIELDS, group_specs_value)

    def upgrade_1_1_to_1_2(self):
        """
        No-op upgrade only to change component doc URI during the upgrade
        which will be done by cc_info before calling this method.
        """
        pass
|