#!/usr/bin/env python
import fnmatch
import optparse
import os
import re
import sys
import textwrap
import time
import warnings
try:
from paste.deploy import appconfig
except ImportError:
appconfig = None
import sqlobject
from sqlobject import col
from sqlobject.classregistry import findClass
from sqlobject.declarative import DeclarativeMeta
from sqlobject.util import moduleloader
# It's not very unsafe to use tempnam like we are doing, so silence the
# RuntimeWarning os.tempnam() emits.  lineno=0 matches any line of this
# module: a hard-coded line number silently stops matching as soon as
# lines above the os.tempnam() call are added or removed.
warnings.filterwarnings(
    'ignore', 'tempnam is a potential security risk.*',
    RuntimeWarning, '.*command', 0)


def nowarning_tempnam(*args, **kw):
    """
    Return ``os.tempnam(*args, **kw)`` without its security RuntimeWarning.

    The warning is suppressed by the module-level filter registered just
    above.  NOTE(review): os.tempnam exists only on Python 2.
    """
    return os.tempnam(*args, **kw)
class SQLObjectVersionTable(sqlobject.SQLObject):
    """
    This table is used to store information about the database and
    its version (used with the record and upgrade commands).

    The record command clears the table and inserts one row naming the
    version the database was recorded at; the upgrade command reads that
    single row back to find the current version.
    """
    class sqlmeta:
        # Name of the table in the database
        table = 'sqlobject_db_version'
    # Version name, i.e. the basename of a version directory
    # (e.g. '2004-12-01b')
    version = col.StringCol()
    # Time the row was written; default comes from col.DateTimeCol.now
    updated = col.DateTimeCol(default=col.DateTimeCol.now)
def db_differences(soClass, conn):
    """
    Returns the differences between a class and the table in a
    connection. Returns [] if no differences are found. This
    function does the best it can; it can miss many differences.

    Each difference is a human-readable string.  Extra database columns
    are reported in sorted order, missing columns in the order the class
    declares them, so the output is deterministic.
    """
    # @@: Repeats a lot from CommandStatus.command, but it's hard
    # to actually factor out the display logic.
    diffs = []
    table = soClass.sqlmeta.table
    if not conn.tableExists(table):
        # A class without columns is not reported as a missing table
        if soClass.sqlmeta.columns:
            diffs.append('Does not exist in database')
        return diffs
    try:
        columns = conn.columnsFromSchema(table, soClass)
    except AttributeError:
        # Database does not support reading columns
        return diffs
    # dbName -> column as found in the database; whatever is left after
    # removing the class's own columns is "extra".
    existing = {}
    for column in columns:
        column = column.withClass(soClass)
        existing[column.dbName] = column
    missing = []
    for column in soClass.sqlmeta.columnList:
        if column.dbName in existing:
            del existing[column.dbName]
        elif column.dbName not in missing:
            missing.append(column.dbName)
    for dbName in sorted(existing):
        diffs.append('Database has extra column: %s' % dbName)
    for dbName in missing:
        diffs.append('Database missing column: %s' % dbName)
    return diffs
class CommandRunner(object):
    """
    Registry of sub-commands and the dispatcher that selects one from argv.
    """

    def __init__(self):
        # command name -> Command subclass
        self.commands = {}
        # alias -> canonical command name
        self.command_aliases = {}

    def run(self, argv):
        """
        Dispatch a ``sys.argv``-style list (program name first).

        The first argument not starting with ``-`` is the command name
        (lower-cased, aliases resolved); it is removed and the remaining
        arguments are passed to the command class, which is instantiated
        as ``cls(invoked_as, command, args, self)`` and then run().
        A missing or unknown command goes through invalid(), which exits.
        """
        invoked_as = argv[0]
        args = list(argv[1:])
        command = None
        for index, arg in enumerate(args):
            if not arg.startswith('-'):
                # this must be a command
                command = arg.lower()
                del args[index]
                break
        if command is None:
            self.invalid('No COMMAND given (try "%s help")'
                         % os.path.basename(invoked_as))
            return
        real_command = self.command_aliases.get(command, command)
        if real_command not in self.commands:
            self.invalid('COMMAND %s unknown' % command)
            return
        runner = self.commands[real_command](
            invoked_as, command, args, self)
        runner.run()

    def register(self, command):
        """Register a Command subclass under its ``name`` and ``aliases``."""
        name = command.name
        self.commands[name] = command
        for alias in command.aliases:
            self.command_aliases[alias] = name

    def invalid(self, msg, code=2):
        """Print ``msg`` and exit the process with status ``code``."""
        print(msg)
        sys.exit(code)
# Process-wide command registry and dispatcher.  Command subclasses add
# themselves through the module-level ``register`` alias when their class
# statement executes (Command.__classinit__), and the __main__ block
# passes sys.argv to ``the_runner``.
the_runner = CommandRunner()
register = the_runner.register
def standard_parser(connection=True, simulate=True,
                    interactive=False, find_modules=True):
    """
    Return a fresh optparse.OptionParser carrying the shared options.

    Registration order (which is also the ``--help`` order):

    * -v/--verbose (counted)                       -- always
    * -n/--simulate                                -- if ``simulate``
    * -c/--connection                              -- if ``connection``
    * -f/--config-file                             -- always
    * -m/--module, -p/--package, --class (append)  -- if ``find_modules``
    * -i/--interactive (counted)                   -- if ``interactive``
    * --egg (append)                               -- always
    """
    parser = optparse.OptionParser()
    add = parser.add_option
    add('-v', '--verbose', dest='verbose', action='count', default=0,
        help='Be verbose (multiple times for more verbosity)')
    if simulate:
        add('-n', '--simulate', dest='simulate', action='store_true',
            help="Don't actually do anything (implies -v)")
    if connection:
        add('-c', '--connection', dest='connection_uri', metavar='URI',
            help="The database connection URI")
    add('-f', '--config-file', dest="config_file", metavar="FILE",
        help="The Paste config file that contains the database URI "
             "(in the database key)")
    if find_modules:
        # Repeatable selectors; every occurrence is appended to a list
        selectors = (
            (('-m', '--module'), 'modules', 'MODULE',
             "Module in which to find SQLObject classes"),
            (('-p', '--package'), 'packages', "PACKAGE",
             "Package to search for SQLObject classes"),
            (('--class',), "class_matchers", "NAME",
             "Select only named classes (wildcards allowed)"),
        )
        for flags, dest, metavar, text in selectors:
            add(*flags, **dict(dest=dest, metavar=metavar,
                               action='append', default=[], help=text))
    if interactive:
        add('-i', '--interactive', dest="interactive", action="count",
            default=0,
            help="Ask before doing anything (use twice to be more careful)")
    add('--egg', dest="eggs", metavar="EGG_SPEC", action="append",
        default=[],
        help="Select modules from the given Egg, using sqlobject.txt")
    return parser
class Command(object):
    """
    Base class for the sub-commands.

    A concrete subclass defines ``name``, ``summary``, ``parser`` (an
    optparse parser, normally from standard_parser()) and command().
    Defining the subclass registers it with the module's command runner,
    which instantiates it with the raw arguments and calls run().
    """

    __metaclass__ = DeclarativeMeta

    # Allowed number of positional arguments (None = no limit)
    min_args = 0
    min_args_error = 'You must provide at least %(min_args)s arguments'
    max_args = 0
    max_args_error = 'You must provide no more than %(max_args)s arguments'
    aliases = ()
    # (options attribute, option name shown in the error) pairs that
    # must have a true value
    required_args = []
    description = None
    help = ''

    def orderClassesByDependencyLevel(self, classes):
        """
        Return classes ordered by their depth in the class dependency
        tree (this is *not* the inheritance tree), from the
        top level (independent) classes to the deepest level.
        The dependency tree is defined by the foreign key relations.

        Classes on the same level keep the order in which they were
        given.  If the foreign keys form a cycle between classes, a
        warning is printed and ``classes`` is returned unchanged.
        """
        # @@: written as a self-contained function for now, to prevent
        # having to modify any core SQLObject component and namespace
        # contamination.
        # yemartin - 2006-08-08
        class SQLObjectCircularReferenceError(Exception):
            pass

        def findReverseDependencies(cls):
            """
            Return a list of classes that cls depends on. Note that
            "depends on" here mean "has a foreign key pointing to".
            """
            depended = []
            for column in cls.sqlmeta.columnList:
                if column.foreignKey:
                    other = findClass(column.foreignKey,
                                      column.soClass.sqlmeta.registry)
                    # A foreign key to the class itself (e.g. a tree's
                    # parent) says nothing about the order between classes.
                    if other is not cls and other not in depended:
                        depended.append(other)
            return depended

        # Cache of already calculated dependency levels.
        dependency_levels = {}

        def calculateDependencyLevel(cls, dependency_stack):
            """
            Recursively calculate the dependency level of cls.
            ``dependency_stack`` holds the classes on the current
            recursion path and is used to detect circular references.
            """
            if cls in dependency_levels:
                return dependency_levels[cls]
            if cls in dependency_stack:
                cycle = dependency_stack[dependency_stack.index(cls):] + [cls]
                raise SQLObjectCircularReferenceError(
                    "Found a circular reference: %s " %
                    ' --> '.join([x.__name__ for x in cycle]))
            dependency_stack.append(cls)
            try:
                depended = findReverseDependencies(cls)
                if depended:
                    level = max([calculateDependencyLevel(x, dependency_stack)
                                 for x in depended]) + 1
                else:
                    level = 0
            finally:
                dependency_stack.pop()
            dependency_levels[cls] = level
            return level

        try:
            for cls in classes:
                calculateDependencyLevel(cls, [])
        except SQLObjectCircularReferenceError as msg:
            # Failsafe: return the classes as-is if a circular reference
            # prevented the dependency levels to be calculated.
            print("Warning: a circular reference was detected in the "
                  "model. Unable to sort the classes by dependency: they "
                  "will be treated in the order they were found. This may "
                  "or may not work depending on your database backend. "
                  "The error was:\n%s" % msg)
            return classes
        # sorted() is stable and the key never compares the classes
        # themselves, so equal levels keep their discovery order.
        return sorted(classes, key=lambda cls: dependency_levels[cls])

    def __classinit__(cls, new_args):
        if cls.__bases__ == (object,):
            # This abstract base class
            return
        register(cls)

    def __init__(self, invoked_as, command_name, args, runner):
        # invoked_as: argv[0]; command_name: the name as typed;
        # args: the remaining (unparsed) arguments; runner: the dispatcher
        self.invoked_as = invoked_as
        self.command_name = command_name
        self.raw_args = args
        self.runner = runner

    def run(self):
        """
        Parse and validate the arguments, apply the configuration file
        (sys_path, process-wide database connection), load egg options,
        then call command().
        """
        self.parser.usage = "%%prog [options]\n%s" % self.summary
        if self.help:
            try:
                width = int(os.environ.get('COLUMNS', 80))
            except ValueError:
                width = 80
            self.parser.usage += '\n' + textwrap.fill(self.help, width - 4)
        self.parser.prog = '%s %s' % (
            os.path.basename(self.invoked_as),
            self.command_name)
        if self.description:
            self.parser.description = self.description
        self.options, self.args = self.parser.parse_args(self.raw_args)
        if (getattr(self.options, 'simulate', False)
                and not self.options.verbose):
            # --simulate implies -v
            self.options.verbose = 1
        if self.min_args is not None and len(self.args) < self.min_args:
            self.runner.invalid(
                self.min_args_error % {'min_args': self.min_args,
                                       'actual_args': len(self.args)})
        if self.max_args is not None and len(self.args) > self.max_args:
            self.runner.invalid(
                self.max_args_error % {'max_args': self.max_args,
                                       'actual_args': len(self.args)})
        for var_name, option_name in self.required_args:
            if not getattr(self.options, var_name, None):
                self.runner.invalid(
                    'You must provide the option %s' % option_name)
        conf = self.config()
        if conf and conf.get('sys_path'):
            update_sys_path(conf['sys_path'], self.options.verbose)
        if conf and conf.get('database'):
            conn = sqlobject.connectionForURI(conf['database'])
            sqlobject.sqlhub.processConnection = conn
        for egg_spec in getattr(self.options, 'eggs', []):
            self.load_options_from_egg(egg_spec)
        self.command()

    def classes(self, require_connection=True,
                require_some=False):
        """
        Return the SQLObject classes selected by --module, --package and
        --egg (filtered by --class), in dependency order.

        A configured connection is assigned to every class; without one,
        classes lacking a connection are an error when
        ``require_connection``.  With ``require_some``, finding nothing
        prints where it looked and exits with status 1.
        """
        found = []
        for module_name in self.options.modules:
            found.extend(self.classes_from_module(
                moduleloader.load_module(module_name)))
        for package_name in self.options.packages:
            found.extend(self.classes_from_package(package_name))
        for egg_spec in self.options.eggs:
            found.extend(self.classes_from_egg(egg_spec))
        matchers = self.options.class_matchers
        if matchers:
            found = [soClass for soClass in found
                     if any(fnmatch.fnmatch(soClass.__name__, matcher)
                            for matcher in matchers)]
        conn = self.connection()
        if conn:
            for soClass in found:
                soClass._connection = conn
        else:
            # _connection may be unset (AttributeError) or empty
            missing = [soClass for soClass in found
                       if not getattr(soClass, '_connection', None)]
            if missing and require_connection:
                self.runner.invalid(
                    'These classes do not have connections set:\n * %s\n'
                    'You must indicate --connection=URI'
                    % '\n * '.join([soClass.__name__
                                    for soClass in missing]))
        if require_some and not found:
            print('No classes found!')
            if self.options.modules:
                print('Looked in modules: %s' % ', '.join(self.options.modules))
            else:
                print('No modules specified')
            if self.options.packages:
                print('Looked in packages: %s' % ', '.join(self.options.packages))
            else:
                print('No packages specified')
            if matchers:
                print('Matching class pattern: %s' % ', '.join(matchers))
            if self.options.eggs:
                print('Looked in eggs: %s' % ', '.join(self.options.eggs))
            else:
                print('No eggs specified')
            sys.exit(1)
        return self.orderClassesByDependencyLevel(found)

    def classes_from_module(self, module):
        """
        Return the SQLObject classes of ``module``: those listed in its
        ``soClasses`` attribute (names or classes) if present, otherwise
        every SQLObject subclass defined in the module itself.
        """
        found = []
        if hasattr(module, 'soClasses'):
            for name_or_class in module.soClasses:
                if isinstance(name_or_class, str):
                    name_or_class = getattr(module, name_or_class)
                found.append(name_or_class)
        else:
            for name in dir(module):
                value = getattr(module, name)
                if (isinstance(value, type)
                        and issubclass(value, sqlobject.SQLObject)
                        and value.__module__ == module.__name__):
                    found.append(value)
        return found

    def connection(self):
        """
        Return a connection from the config file's ``database`` key or
        from --connection, or None if neither is given.
        """
        config = self.config()
        if config is not None:
            if not config.get('database'):
                self.runner.invalid(
                    "No database variable found in config file %s"
                    % self.options.config_file)
            return sqlobject.connectionForURI(config['database'])
        elif getattr(self.options, 'connection_uri', None):
            return sqlobject.connectionForURI(self.options.connection_uri)
        else:
            return None

    def config(self):
        """
        Return the configuration named by --config-file, or None.

        Uses Paste Deploy's appconfig when available (plain paths get a
        ``config:`` prefix), otherwise the minimal reader ini_config().
        """
        config_file = getattr(self.options, 'config_file', None)
        if not config_file:
            return None
        if appconfig:
            if not (config_file.startswith('egg:')
                    or config_file.startswith('config:')):
                config_file = 'config:' + config_file
            return appconfig(config_file,
                             relative_to=os.getcwd())
        return self.ini_config(config_file)

    def ini_config(self, conf_fn):
        """
        Read ``conf_fn`` (optionally ``file#section``, default section
        ``main``) and return its options as a dict.  A section matches
        when named ``[section]`` or ``[app:section]``/
        ``[application:section]``, case-insensitively.  ``%(here)s``
        expands to the file's directory.

        Raises OSError if the file is missing or the section is missing
        or ambiguous.
        """
        conf_section = 'main'
        if '#' in conf_fn:
            conf_fn, conf_section = conf_fn.split('#', 1)
        try:
            from ConfigParser import ConfigParser
        except ImportError:
            # Python 3 name of the module
            from configparser import ConfigParser
        p = ConfigParser()
        # Case-sensitive option names:
        p.optionxform = str
        if not os.path.exists(conf_fn):
            # ConfigParser.read() silently skips files that do not exist
            raise OSError(
                "Config file %s does not exist" % self.options.config_file)
        p.read([conf_fn])
        p.defaults().setdefault(
            'here', os.path.dirname(os.path.abspath(conf_fn)))
        wanted = conf_section.strip().lower()
        possible_sections = []
        for section in p.sections():
            name = section.strip().lower()
            if (wanted == name or
                    (wanted == name.split(':')[-1]
                     and name.split(':')[0] in ('app', 'application'))):
                possible_sections.append(section)
        if not possible_sections:
            raise OSError(
                "Config file %s does not have a section [%s] or [*:%s]"
                % (conf_fn, conf_section, conf_section))
        if len(possible_sections) > 1:
            raise OSError(
                "Config file %s has multiple sections matching %s: %s"
                % (conf_fn, conf_section, ', '.join(possible_sections)))
        config = {}
        for op in p.options(possible_sections[0]):
            config[op] = p.get(possible_sections[0], op)
        return config

    def classes_from_package(self, package_name):
        """
        Import every ``.py`` module below ``package_name`` (except
        ``__init__.py`` files) and return the SQLObject classes found.
        Modules that fail to import are skipped (reported with -v).
        """
        found = []
        package = moduleloader.load_module(package_name)
        package_dir = os.path.dirname(package.__file__)
        for dir_name, subdirs, filenames in os.walk(package_dir):
            # Prune Subversion metadata directories in place so os.walk
            # does not descend into them (a directory's *name* must be
            # checked, not its full path).
            subdirs[:] = [d for d in subdirs if not d.startswith('.svn')]
            # Dotted path of this directory relative to the package; built
            # from the walk itself so dotted package names work too.
            rel_dir = dir_name[len(package_dir):].strip(os.sep)
            prefix = [package_name]
            if rel_dir:
                prefix.extend(rel_dir.split(os.sep))
            for fname in filenames:
                if not fname.endswith('.py') or fname == '__init__.py':
                    continue
                module_name = '.'.join(prefix + [fname[:-3]])
                try:
                    module = moduleloader.load_module(module_name)
                except ImportError as err:
                    if self.options.verbose:
                        print('Could not import module "%s". Error was : "%s"'
                              % (module_name, err))
                    continue
                except Exception as exc:
                    # Best effort: a broken module must not stop the scan
                    if self.options.verbose:
                        print('Unknown exception while processing module "%s" : "%s"'
                              % (module_name, exc))
                    continue
                found.extend(self.classes_from_module(module))
        return found

    def classes_from_egg(self, egg_spec):
        """
        Return the classes of the modules listed (comma-separated) under
        ``db_module`` in the egg's sqlobject.txt.
        """
        found = []
        dist, conf = self.config_from_egg(egg_spec, warn_no_sqlobject=True)
        for mod in conf.get('db_module', '').split(','):
            mod = mod.strip()
            if not mod:
                continue
            if self.options.verbose:
                print('Looking in module %s' % mod)
            found.extend(self.classes_from_module(
                moduleloader.load_module(mod)))
        return found

    def load_options_from_egg(self, egg_spec):
        """
        Fill an unset --output-dir from the egg's ``history_dir`` setting,
        with ``$base`` replaced by the distribution's location.
        """
        dist, conf = self.config_from_egg(egg_spec)
        if (hasattr(self.options, 'output_dir')
                and not self.options.output_dir
                and conf.get('history_dir')):
            history_dir = conf['history_dir']
            history_dir = history_dir.replace('$base', dist.location)
            self.options.output_dir = history_dir

    def config_from_egg(self, egg_spec, warn_no_sqlobject=True):
        """
        Return ``(distribution, settings)`` parsed from the egg's
        sqlobject.txt (``name = value`` lines; ``#`` comments and blank
        lines ignored; names lower-cased).  Returns ``(None, {})`` when
        the egg has no sqlobject.txt.
        """
        import pkg_resources
        dist = pkg_resources.get_distribution(egg_spec)
        if not dist.has_metadata('sqlobject.txt'):
            if warn_no_sqlobject:
                print('No sqlobject.txt in %s egg info' % egg_spec)
            return None, {}
        result = {}
        for line in dist.get_metadata_lines('sqlobject.txt'):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            name, value = line.split('=', 1)
            name = name.strip().lower()
            if name in result:
                print('Warning: %s appears more than once in sqlobject.txt' % name)
            result[name] = value.strip()
        return dist, result

    def command(self):
        raise NotImplementedError

    def _get_prog_name(self):
        return os.path.basename(self.invoked_as)
    prog_name = property(_get_prog_name)

    def ask(self, prompt, safe=False, default=True):
        """
        Ask a yes/no question and return the answer as a bool.

        An empty reply returns ``default``; with -ii (interactive >= 2)
        the default becomes ``safe``.  Repeats until the reply starts
        with y or n.
        """
        try:
            read_line = raw_input
        except NameError:
            read_line = input  # Python 3
        if self.options.interactive >= 2:
            default = safe
        if default:
            prompt += ' [Y/n]? '
        else:
            prompt += ' [y/N]? '
        while True:
            response = read_line(prompt).strip()
            if not response:
                return default
            if response[0].lower() in ('y', 'n'):
                return response[0].lower() == 'y'
            print('Y or N please')

    def shorten_filename(self, fn):
        """
        Shortens a filename to make it relative to the current
        directory (if it can). For display purposes.
        """
        cwd = os.getcwd()
        if fn.startswith(cwd + os.sep):
            fn = fn[len(cwd) + 1:]
        return fn

    def open_editor(self, pretext, breaker=None, extension='.txt'):
        """
        Open $EDITOR on a temporary file containing ``pretext``.

        Return the edited text, or None if nothing was changed or the
        result is blank.  If given, everything after ``breaker`` is
        ignored (in both the original and the edited text).  The
        temporary file is removed afterwards.
        """
        import tempfile
        # mkstemp creates the file securely (unlike os.tempnam)
        fd, fn = tempfile.mkstemp(suffix=extension)
        try:
            with os.fdopen(fd, 'w') as f:
                f.write(pretext)
            print('$EDITOR %s' % fn)
            os.system('$EDITOR "%s"' % fn)
            with open(fn, 'r') as f:
                content = f.read()
        finally:
            os.remove(fn)
        if breaker:
            content = content.split(breaker)[0]
            pretext = pretext.split(breaker)[0]
        if content == pretext or not content.strip():
            return None
        return content
class CommandSQL(Command):
    name = 'sql'
    summary = 'Show SQL CREATE statements'
    parser = standard_parser(simulate=False)

    def command(self):
        """
        Print every CREATE statement first, then every deferred
        constraint, each terminated by ';' and a blank line.
        """
        deferred = []
        for soClass in self.classes():
            if self.options.verbose >= 1:
                print('-- %s from %s' % (soClass.__name__, soClass.__module__))
            create_sql, table_constraints = soClass.createTableSQL()
            print(create_sql.strip() + ';\n')
            deferred.append(table_constraints)
        # Constraints go last so every referenced table already exists
        for table_constraints in deferred:
            for constraint in (table_constraints or ()):
                if constraint:
                    print(constraint.strip() + ';\n')
class CommandList(Command):
    name = 'list'
    summary = 'Show all SQLObject classes found'
    parser = standard_parser(simulate=False, connection=False)

    def command(self):
        """List module.Class for each class found (plus its table with -v)."""
        verbose = self.options.verbose >= 1
        if verbose:
            print('Classes found:')
        for soClass in self.classes(require_connection=False):
            print('%s.%s' % (soClass.__module__, soClass.__name__))
            if verbose:
                print(' Table: %s' % soClass.sqlmeta.table)
class CommandCreate(Command):
    name = 'create'
    summary = 'Create tables'
    parser = standard_parser(interactive=True)
    parser.add_option('--create-db',
                      action='store_true',
                      dest='create_db',
                      help="Create the database")

    def command(self):
        """
        Create the missing tables (optionally the databases first), then
        apply all foreign-key constraints once every table exists.
        """
        verbose = self.options.verbose
        created = existing = 0
        dbs_created = []
        # connection -> list of constraint lists returned by createTable()
        constraints = {}
        for soClass in self.classes(require_some=True):
            conn = soClass._connection
            if self.options.create_db and conn not in dbs_created:
                if self.options.simulate:
                    print('(simulating; cannot create database)')
                else:
                    try:
                        conn.createEmptyDatabase()
                    except conn.module.ProgrammingError as e:
                        if 'already exists' not in str(e):
                            raise
                        print('Database already exists')
                dbs_created.append(conn)
            pending = constraints.setdefault(conn, [])
            exists = conn.tableExists(soClass.sqlmeta.table)
            if verbose >= 1:
                if exists:
                    existing += 1
                    print('%s already exists.' % soClass.__name__)
                else:
                    print('Creating %s' % soClass.__name__)
            if verbose >= 2:
                sql, _extra = soClass.createTableSQL()
                print(sql)
            if self.options.simulate or exists:
                continue
            if (self.options.interactive
                    and not self.ask('Create %s' % soClass.__name__)):
                print('Cancelled')
                continue
            created += 1
            # Constraints are deferred: the referenced tables may not exist yet
            table_constraints = soClass.createTable(applyConstraints=False)
            if table_constraints:
                pending.append(table_constraints)
        for conn, constraint_lists in constraints.items():
            if verbose >= 2:
                print('Creating constraints')
            for constraint_list in constraint_lists:
                for constraint in constraint_list:
                    if constraint:
                        conn.query(constraint)
        if verbose >= 1:
            print('%i tables created (%i already exist)' % (
                created, existing))
class CommandDrop(Command):
    name = 'drop'
    summary = 'Drop tables'
    parser = standard_parser(interactive=True)

    def command(self):
        """
        Drop the tables of all classes found, most dependent class first
        (the reverse of the creation order).
        """
        verbose = self.options.verbose
        dropped = not_existing = 0
        for soClass in reversed(self.classes()):
            exists = soClass._connection.tableExists(soClass.sqlmeta.table)
            if verbose >= 1:
                if not exists:
                    not_existing += 1
                    print('%s does not exist.' % soClass.__name__)
                else:
                    print('Dropping %s' % soClass.__name__)
            if self.options.simulate or not exists:
                continue
            if (self.options.interactive
                    and not self.ask('Drop %s' % soClass.__name__)):
                print('Cancelled')
                continue
            dropped += 1
            soClass.dropTable()
        if verbose >= 1:
            print("%i tables dropped (%i didn't exist)" % (
                dropped, not_existing))
class CommandStatus(Command):
    name = 'status'
    summary = 'Show status of classes vs. database'
    help = ('This command checks the SQLObject definition and checks if '
            'the tables in the database match. It can always test for '
            'missing tables, and on some databases can test for the '
            'existance of other tables. Column types are not currently '
            'checked.')
    parser = standard_parser(simulate=False)

    def print_class(self, soClass):
        """Print the 'Checking <class>...' header at most once per class."""
        if not self.printed:
            self.printed = True
            print('Checking %s...' % soClass.__name__)

    def command(self):
        """
        Report, per class, a missing table or extra/missing columns.
        With -v, every class is announced and a summary line is printed.
        """
        in_sync = out_of_sync = absent = 0
        schema_warned = False
        for soClass in self.classes(require_some=True):
            conn = soClass._connection
            table = soClass.sqlmeta.table
            self.printed = False
            if self.options.verbose:
                self.print_class(soClass)
            if not conn.tableExists(table):
                self.print_class(soClass)
                print(' Does not exist in database')
                absent += 1
                continue
            try:
                db_columns = conn.columnsFromSchema(table, soClass)
            except AttributeError:
                # Backend cannot introspect: warn once, count as in sync
                if not schema_warned:
                    print('Database does not support reading columns')
                    schema_warned = True
                in_sync += 1
                continue
            except AssertionError as e:
                print('Cannot read db table %s: %s' % (table, e))
                continue
            extra = {}
            for db_col in db_columns:
                db_col = db_col.withClass(soClass)
                extra[db_col.dbName] = db_col
            missing = {}
            for class_col in soClass.sqlmeta.columnList:
                if class_col.dbName in extra:
                    del extra[class_col.dbName]
                else:
                    missing[class_col.dbName] = class_col
            if extra:
                self.print_class(soClass)
                for db_col in extra.values():
                    print(' Database has extra column: %s' % db_col.dbName)
            if missing:
                self.print_class(soClass)
                for class_col in missing.values():
                    print(' Database missing column: %s' % class_col.dbName)
            if extra or missing:
                out_of_sync += 1
            else:
                in_sync += 1
        if self.options.verbose:
            print('%i in sync; %i out of sync; %i not in database' % (
                in_sync, out_of_sync, absent))
class CommandHelp(Command):
    name = 'help'
    summary = 'Show help'
    parser = optparse.OptionParser()
    max_args = 1

    def command(self):
        """
        ``help COMMAND`` shows that command's -h output; plain ``help``
        lists every registered command with its summary and aliases.
        """
        if self.args:
            the_runner.run([self.invoked_as, self.args[0], '-h'])
            return
        prog = self.prog_name
        print('Available commands:')
        print(' (use "%s help COMMAND" or "%s COMMAND -h" ' % (prog, prog))
        print(' for more information)')
        entries = sorted(the_runner.commands.items())
        width = max([len(entry_name) for entry_name, _ in entries])
        for entry_name, entry_class in entries:
            padding = ' ' * (width - len(entry_name))
            print('%s:%s %s' % (entry_name, padding, entry_class.summary))
            if entry_class.aliases:
                print('%s (Aliases: %s)' % (
                    ' ' * width, ', '.join(entry_class.aliases)))
class CommandExecute(Command):
    name = 'execute'
    summary = 'Execute SQL statements'
    help = ('Runs SQL statements directly in the database, with no '
            'intervention. Useful when used with a configuration file. '
            'Each argument is executed as an individual statement.')
    parser = standard_parser(find_modules=False)
    parser.add_option('--stdin',
                      help="Read SQL from stdin (normally takes SQL from the command line)",
                      dest="use_stdin",
                      action="store_true")
    max_args = None

    def command(self):
        """
        Execute each positional argument (plus stdin's content with
        --stdin) as one SQL statement on a raw DB-API connection.
        """
        args = self.args
        if self.options.use_stdin:
            if self.options.verbose:
                print("Reading additional SQL from stdin (Ctrl-D or Ctrl-Z to finish)...")
            args.append(sys.stdin.read())
        connection = self.connection()
        if connection is None:
            self.runner.invalid(
                'You must indicate --connection=URI or --config-file')
            return
        self.conn = connection.getConnection()
        self.cursor = self.conn.cursor()
        for sql in args:
            self.execute_sql(sql)

    def execute_sql(self, sql):
        """
        Run one statement and print its result set tab-separated (a
        header of column names, then one line per row, values repr()'d).
        Errors are printed and do not stop the remaining statements.
        """
        if self.options.verbose:
            print(sql)
        try:
            self.cursor.execute(sql)
        except Exception as e:
            if not self.options.verbose:
                print(sql)
            print("****Error:")
            print('  %s' % (e,))
            return
        desc = self.cursor.description
        # Only fetch when the statement produced a result set: DB-API
        # drivers may raise from fetchall() after INSERT/UPDATE/DDL.
        if desc:
            rows = self.cursor.fetchall()
        else:
            rows = []
        if self.options.verbose:
            if not self.cursor.rowcount:
                print("No rows accessed")
            else:
                print("%i rows accessed" % self.cursor.rowcount)
        if desc:
            for column in desc:
                # The first item of each description entry is the name
                sys.stdout.write("%s\t" % column[0])
            sys.stdout.write("\n")
        for row in rows:
            for value in row:
                sys.stdout.write("%r\t" % (value,))
            sys.stdout.write("\n")
        print('')
class CommandRecord(Command):
    name = 'record'
    summary = 'Record historical information about the database status'
    help = ('Record state of table definitions. The state of each '
            'table is written out to a separate file in a directory, '
            'and that directory forms a "version". A table is also '
            'added to your database (%s) that reflects the version the '
            'database is currently at. Use the upgrade command to '
            'sync databases with code.'
            % SQLObjectVersionTable.sqlmeta.table)
    parser = standard_parser()
    parser.add_option('--output-dir',
                      help="Base directory for recorded definitions",
                      dest="output_dir",
                      metavar="DIR",
                      default=None)
    parser.add_option('--no-db-record',
                      help="Don't record version to database",
                      dest="db_record",
                      action="store_false",
                      default=True)
    parser.add_option('--force-create',
                      help="Create a new version even if appears to be "
                      "identical to the last version",
                      action="store_true",
                      dest="force_create")
    parser.add_option('--name',
                      help="The name to append to the version. The "
                      "version should sort after previous versions (so "
                      "any versions from the same day should come "
                      "alphabetically before this version).",
                      dest="version_name",
                      metavar="NAME")
    parser.add_option('--force-db-version',
                      help="Update the database version, and include no "
                      "database information. This is for databases that "
                      "were developed without any interaction with "
                      "this tool, to create a 'beginning' revision.",
                      metavar="VERSION_NAME",
                      dest="force_db_version")
    parser.add_option('--edit',
                      help="Open an editor for the upgrader in the last "
                      "version (using $EDITOR).",
                      action="store_true",
                      dest="open_editor")

    # Version directories are named YYYY-MM-DD[suffix]
    version_regex = re.compile(r'^\d\d\d\d-\d\d-\d\d')

    def command(self):
        """
        Write the CREATE statements of all classes into a new version
        directory (unless they match the last recorded version), record
        the version in the database(s), and optionally open an editor
        for an upgrade script in the previous version directory.
        """
        if self.options.force_db_version:
            self.command_force_db_version()
            return
        v = self.options.verbose
        sim = self.options.simulate
        # Collect once: classes() re-imports modules, re-assigns
        # connections and repeats any warnings on every call.
        classes = self.classes()
        if not classes:
            print("No classes found!")
            return
        output_dir = self.find_output_dir()
        version = os.path.basename(output_dir)
        print("Creating version %s" % version)
        conns = []
        files = {}
        dbName = None
        for cls in classes:
            dbName = cls._connection.dbName
            if cls._connection not in conns:
                conns.append(cls._connection)
            if sim:
                # No definitions are generated in simulation mode
                continue
            fn = cls.__name__ + '_' + dbName + '.sql'
            files[fn] = self._definition_sql(cls, dbName)
        last_version_dir = self.find_last_version()
        if last_version_dir and not self.options.force_create:
            if self._matches_version(last_version_dir, files):
                print("Current status matches version %s"
                      % os.path.basename(last_version_dir))
                return
            if v:
                print("Current state does not match %s"
                      % os.path.basename(last_version_dir))
        if v > 1 and not last_version_dir:
            print("No last version to check")
        if not sim:
            os.mkdir(output_dir)
        if v:
            print('Making directory %s' % self.shorten_filename(output_dir))
        for fn, content in sorted(files.items()):
            if v:
                print(' Writing %s' % self.shorten_filename(fn))
            if not sim:
                with open(os.path.join(output_dir, fn), 'w') as f:
                    f.write(content)
        if self.options.db_record:
            all_diffs = self._schema_differences(classes, conns)
            if all_diffs:
                print('Database does not match schema:')
                print('\n'.join(all_diffs))
            # --no-db-record is the documented way to skip this step
            for conn in conns:
                self.update_db(version, conn)
        else:
            all_diffs = []
        if self.options.open_editor:
            # dbName is that of the last class processed
            self._edit_upgrader(last_version_dir, all_diffs, dbName, version)

    def _definition_sql(self, cls, dbName):
        """Return the commented CREATE (+ constraints) text for ``cls``."""
        create, constraints = cls.createTableSQL()
        if constraints:
            constraints = '\n-- Constraints:\n%s\n' % (
                '\n'.join(constraints))
        else:
            constraints = ''
        return ''.join([
            '-- Exported definition from %s\n'
            % time.strftime('%Y-%m-%dT%H:%M:%S'),
            '-- Class %s.%s\n'
            % (cls.__module__, cls.__name__),
            '-- Database: %s\n'
            % dbName,
            create.strip(),
            '\n',
            constraints])

    def _matches_version(self, version_dir, files):
        """
        Return True if ``version_dir`` holds exactly the definition files
        in ``files`` with equal content (``--`` comment lines ignored).
        """
        v = self.options.verbose
        if v > 1:
            print("Checking %s to see if it is current" % version_dir)
        remaining = files.copy()
        for fn in os.listdir(version_dir):
            if not fn.endswith('.sql') or fn.startswith('upgrade_'):
                # Hand-written upgrade scripts live beside the recorded
                # definitions but are not part of the schema snapshot
                continue
            if fn not in remaining:
                if v > 1:
                    print("Missing file %s" % fn)
                return False
            with open(os.path.join(version_dir, fn), 'r') as f:
                content = f.read()
            if (self.strip_comments(remaining[fn])
                    != self.strip_comments(content)):
                if v > 1:
                    print("Content does not match: %s" % fn)
                return False
            del remaining[fn]
        if remaining:
            if v > 1:
                print("Extra files: %s" % ', '.join(remaining.keys()))
            return False
        return True

    def _schema_differences(self, classes, conns):
        """Return db_differences() lines for every class/connection pair."""
        all_diffs = []
        for cls in classes:
            for conn in conns:
                for diff in db_differences(cls, conn):
                    if len(conns) > 1:
                        diff = ' (%s).%s: %s' % (
                            conn.uri(), cls.sqlmeta.table, diff)
                    else:
                        diff = ' %s: %s' % (cls.sqlmeta.table, diff)
                    all_diffs.append(diff)
        return all_diffs

    def _edit_upgrader(self, last_version_dir, all_diffs, dbName, version):
        """
        Let the user write upgrade_<db>_<version>.sql into the previous
        version directory, pre-filled with the detected differences.
        """
        if not last_version_dir:
            print("Cannot edit upgrader because there is no "
                  "previous version")
            return
        breaker = ('-'*20 + ' lines below this will be ignored '
                   + '-'*20)
        pre_text = breaker + '\n' + '\n'.join(all_diffs)
        text = self.open_editor('\n\n' + pre_text, breaker=breaker,
                                extension='.sql')
        if text is None:
            return
        fn = os.path.join(last_version_dir,
                          'upgrade_%s_%s.sql' %
                          (dbName, version))
        with open(fn, 'w') as f:
            f.write(text)
        print('Wrote to %s' % fn)

    def update_db(self, version, conn):
        """
        Make the version table on ``conn`` hold a single row naming
        ``version``, creating the table first if necessary.  Nothing is
        written in simulation mode.
        """
        v = self.options.verbose
        table = SQLObjectVersionTable.sqlmeta.table
        if not conn.tableExists(table):
            if v:
                print('Creating table %s' % table)
            # createTableSQL() returns (create_statement, constraints)
            sql, _constraints = SQLObjectVersionTable.createTableSQL(
                connection=conn)
            if v > 1:
                print(sql)
            if not self.options.simulate:
                SQLObjectVersionTable.createTable(connection=conn)
        if not self.options.simulate:
            SQLObjectVersionTable.clearTable(connection=conn)
            SQLObjectVersionTable(
                version=version,
                connection=conn)

    def strip_comments(self, sql):
        """Return ``sql`` without its ``--`` comment lines."""
        lines = [l for l in sql.splitlines()
                 if not l.strip().startswith('--')]
        return '\n'.join(lines)

    def base_dir(self):
        """
        Return the history directory: --output-dir, else the config
        file's ``sqlobject_history_dir``, else '.'.  It is created if
        missing (except in simulation mode).
        """
        base = self.options.output_dir
        if base is None:
            config = self.config()
            if config:
                base = config.get('sqlobject_history_dir', '.')
            else:
                base = '.'
        if not os.path.exists(base):
            print('Creating history directory %s' % self.shorten_filename(base))
            if not self.options.simulate:
                os.makedirs(base)
        return base

    def find_output_dir(self):
        """
        Return the directory for a new version: ``YYYY-MM-DD-NAME`` with
        --name (exits if it exists), else the first free of
        ``YYYY-MM-DD``, ``YYYY-MM-DDa``, ``YYYY-MM-DDb``, ...
        """
        today = time.strftime('%Y-%m-%d', time.localtime())
        base = self.base_dir()
        if self.options.version_name:
            path = os.path.join(base, today + '-' +
                                self.options.version_name)
            if os.path.exists(path):
                print("Error, directory already exists: %s"
                      % path)
                sys.exit(1)
            return path
        extra = ''
        while True:
            path = os.path.join(base, today + extra)
            if not os.path.exists(path):
                return path
            if not extra:
                extra = 'a'
            else:
                extra = chr(ord(extra) + 1)

    def find_last_version(self):
        """Return the newest version directory, or None if there is none."""
        base = self.base_dir()
        if not os.path.isdir(base):
            # e.g. simulation mode, where base_dir() does not create it
            return None
        names = [fn for fn in os.listdir(base)
                 if self.version_regex.search(fn)]
        if not names:
            return None
        # Version names sort chronologically as plain strings
        return os.path.join(base, max(names))

    def command_force_db_version(self):
        """
        Handle --force-db-version: make sure the version directory exists
        and record the version in the database without any definitions.
        """
        v = self.options.verbose
        sim = self.options.simulate
        version = self.options.force_db_version
        if not self.version_regex.search(version):
            print("Versions must be in the format YYYY-MM-DD...")
            print("Your version %s does not fit this" % version)
            return
        version_dir = os.path.join(self.base_dir(), version)
        if not os.path.exists(version_dir):
            if v:
                print('Creating %s' % self.shorten_filename(version_dir))
            if not sim:
                os.mkdir(version_dir)
        elif v:
            print('Directory %s exists'
                  % self.shorten_filename(version_dir))
        if self.options.db_record:
            conn = self.connection()
            if conn is None:
                self.runner.invalid(
                    'You must indicate --connection=URI or --config-file')
                return
            self.update_db(version, conn)
class CommandUpgrade(CommandRecord):
    name = 'upgrade'
    summary = 'Update the database to a new version (as created by record)'
    help = ('This command runs scripts (that you write by hand) to '
            'upgrade a database. The database\'s current version is in '
            'the %s table (use record --force-db-version '
            'if a database does not have a %s table), '
            'and upgrade scripts are in the version directory you are '
            'upgrading FROM, named upgrade_DBNAME_VERSION.sql, like '
            '"upgrade_mysql_2004-12-01b.sql".'
            % (SQLObjectVersionTable.sqlmeta.table,
               SQLObjectVersionTable.sqlmeta.table))
    parser = standard_parser(find_modules=False)
    parser.add_option('--upgrade-to',
                      help="Upgrade to the given version (default: newest version)",
                      dest="upgrade_to",
                      metavar="VERSION")
    parser.add_option('--output-dir',
                      help="Base directory for recorded definitions",
                      dest="output_dir",
                      metavar="DIR",
                      default=None)

    # upgrade_<dbname>_<target version>.sql
    upgrade_regex = re.compile(r'^upgrade_([a-z]*)_([^.]*)\.sql$', re.I)

    def _require_connection(self):
        """Return the configured connection, or exit via runner.invalid()."""
        conn = self.connection()
        if conn is None:
            self.runner.invalid(
                'You must indicate --connection=URI or --config-file')
        return conn

    def command(self):
        """
        Run the chain of upgrade scripts leading from the database's
        recorded version to --upgrade-to (default: newest version),
        recording each intermediate version as it is reached.
        """
        v = self.options.verbose
        sim = self.options.simulate
        if self.options.upgrade_to:
            version_to = self.options.upgrade_to
        else:
            fname = self.find_last_version()
            if fname is None:
                print("No version exists, use 'record' command to create one")
                return
            version_to = os.path.basename(fname)
        conn = self._require_connection()
        current = self.current_version()
        if v:
            print('Current version: %s' % current)
        version_list = self.make_plan(current, version_to)
        if not version_list:
            print('Database up to date')
            return
        if v:
            print('Plan:')
            for next_version, upgrader in version_list:
                print(' Use %s to upgrade to %s' % (
                    self.shorten_filename(upgrader), next_version))
        for next_version, upgrader in version_list:
            with open(upgrader) as f:
                sql = f.read()
            if v:
                print("Running:")
                print(sql)
                print('-'*60)
            if not sim:
                try:
                    conn.query(sql)
                except Exception:
                    print("Error in script: %s" % upgrader)
                    raise
            # update_db() itself writes nothing in simulation mode
            self.update_db(next_version, conn)
        print('Done.')

    def current_version(self):
        """Return the single version recorded in the database, or exit."""
        conn = self._require_connection()
        table = SQLObjectVersionTable.sqlmeta.table
        if not conn.tableExists(table):
            print('No %s table!' % table)
            sys.exit(1)
        versions = list(SQLObjectVersionTable.select(connection=conn))
        if not versions:
            print('No rows in %s!' % table)
            sys.exit(1)
        if len(versions) > 1:
            print('Ambiguous %s table' % table)
            sys.exit(1)
        return versions[0].version

    def make_plan(self, current, dest):
        """
        Return ``[(version, upgrade_script_path), ...]`` leading from
        ``current`` to ``dest`` ([] if they are equal).  Exits if some
        step has no upgrade script or the scripts form a loop.
        """
        if current == dest:
            return []
        dbname = self._require_connection().dbName
        plan = []
        seen = set([current])
        while current != dest:
            next_version, upgrader = self.best_upgrade(current, dest, dbname)
            if not upgrader:
                print('No way to upgrade from %s to %s' % (current, dest))
                print('(you need a %s/upgrade_%s_%s.sql script)'
                      % (current, dbname, dest))
                sys.exit(1)
            if next_version in seen:
                # Guards against infinite planning on backward scripts
                print('Upgrade scripts loop back to version %s (from %s)'
                      % (next_version, current))
                sys.exit(1)
            seen.add(next_version)
            plan.append((next_version, upgrader))
            current = next_version
        return plan

    def best_upgrade(self, current, dest, target_dbname):
        """
        Return ``(version, path)`` of the upgrade script in ``current``'s
        directory for ``target_dbname`` that reaches the newest version
        not later than ``dest``, or ``(None, None)`` if there is none.
        """
        current_dir = os.path.join(self.base_dir(), current)
        verbose = self.options.verbose
        if verbose > 1:
            print('Looking in %s for upgraders'
                  % self.shorten_filename(current_dir))
        upgraders = []
        for fn in os.listdir(current_dir):
            match = self.upgrade_regex.search(fn)
            if not match:
                if verbose > 1:
                    print('Not an upgrade script: %s' % fn)
                continue
            dbname = match.group(1)
            version = match.group(2)
            if dbname != target_dbname:
                if verbose > 1:
                    print('Not for this database: %s (want %s)' % (
                        dbname, target_dbname))
                continue
            if version > dest:
                if verbose > 1:
                    print('Version too new: %s (only want %s)' % (
                        version, dest))
                # Must be skipped, not merely reported
                continue
            upgraders.append((version, os.path.join(current_dir, fn)))
        if not upgraders:
            if verbose > 1:
                print('No upgraders found in %s' % current_dir)
            return None, None
        upgraders.sort()
        return upgraders[-1]
def update_sys_path(paths, verbose):
    """
    Put each path of ``paths`` (one path string or a list of them) at the
    front of ``sys.path`` as an absolute path, skipping paths already
    present.  With ``verbose > 1`` each addition is printed.
    """
    try:
        string_types = basestring
    except NameError:
        # Python 3 has no basestring
        string_types = str
    if isinstance(paths, string_types):
        paths = [paths]
    for path in paths:
        path = os.path.abspath(path)
        if path in sys.path:
            continue
        if verbose > 1:
            print('Adding %s to path' % path)
        sys.path.insert(0, path)
if __name__ == '__main__':
    # Script entry point: the runner receives the whole argv, program
    # name included, and dispatches to the selected sub-command.
    the_runner.run(sys.argv)
|