"""
A MySQL client library written entirely in python.
$Id: mysql.py,v 1.5 2006/07/25 07:54:22 frooby Exp $
A pure python implementation of an interface to the MySQL database. This
module is written entirely in python and does not require linking to any
dynamic libraries or compiling of C driver code to connect to a MySQL
database. This is handy for systems that have a python port but do not have
the MySQL database/libraries/dlls ported yet (e.g. Xbox Media Center).
There are no plans to implement additional features in this module. It does
what I need it to do and that's as far as I have taken it. If you would
like to contribute additional functionality to this, feel free to send me
any additions/changes/bug reports.
This python module is based on a pure perl implementation of an interface to
MySQL written by Hiroyuki Oyama.
Here are some simple examples of how to use this module:
import mysql
conn = mysql.Connection(
    host="myhost",
    database="mydb",
    user="myuser",
    password="mypass",
    timeout=30,
    port=3306,
    debug=0 )
if conn.executeSQL( "select * from db where hostname like '%myhost%'" ):
    iter = conn.dictRowIterator()
    for row in iter:
        for k in row.keys():
            v = row[k]
            if not v:
                print k + "=>[(null)]",
            else:
                print k + "=>[" + v + "]",
        print
    print str(iter.getRowCount()) + " rows retrieved"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()
if conn.executeSQL(
        "insert into person (id, first, last) values (null, 'Tom','Warkentin')" ):
    print "insert succeeded: id=" + str(conn.getInsertId())
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()
if conn.executeSQL(
        "update settings set data = 'my new value' where value = 'my data value'" ):
    print "update succeeded"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()
if conn.executeSQL( "delete from db where hostname = 'myhost'" ):
    print "delete succeeded"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()
if conn.createDatabase( "mydb" ):
    print "database created"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()
if conn.dropDatabase( "mydb" ):
    print "database dropped"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()
conn.close()
Copyright (C) 2004 Tom Warkentin <tom@ixionstudios.com>
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
or at the website: http://www.gnu.org
"""
import re, socket, string, struct
# Module-level placeholder kept from the original "ugly hack" for a problem
# passing a Connection object to a class constructor on the Xbox. (A
# "global" statement at module level has no effect, so none is needed.)
db = None
# Types escape() formats as integers / quoted strings. On Python 2 this
# includes long and unicode; on Python 3 those names do not exist.
try:
    _INTEGER_TYPES = ( int, long )
    _STRING_TYPES = ( str, unicode )
except NameError:
    _INTEGER_TYPES = ( int, )
    _STRING_TYPES = ( str, )

# Characters that must be backslash-escaped inside a quoted MySQL string
# literal, with their escape sequences.
_ESCAPE_MAP = {
    "\0": "\\0",
    "\n": "\\n",
    "\r": "\\r",
    "\\": "\\\\",
    "'": "\\'",
    '"': '\\"',
    "\x1a": "\\Z",
}

def escape( obj ):
    r"""
    Formats obj as a MySQL literal for use in an insert or update statement.
      None                   -> the string NULL
      integers (incl. long)  -> decimal digits, e.g. "42"
      str / unicode          -> single-quoted, with backslash, both quote
                                characters, NUL, newline, carriage return
                                and Ctrl-Z backslash-escaped
      anything else          -> returned unchanged
    Example:
        s = "they're here"
        print escape( s )
        # will return: 'they\'re here'
        print escape( None )
        # will return: NULL
    """
    if obj is None:
        return "NULL"
    # Note: bool is an int subclass, so True/False become "1"/"0".
    if isinstance( obj, _INTEGER_TYPES ):
        return "%d" % obj
    if isinstance( obj, _STRING_TYPES ):
        # Escape only what MySQL string literals require. re.escape (used
        # previously) is a regex escaper: it mangled NUL bytes and, on
        # newer Pythons, leaves "'" unescaped.
        escaped = "".join( [ _ESCAPE_MAP.get( ch, ch ) for ch in obj ] )
        return "'%s'" % escaped
    return obj
class ClientException( Exception ):
    """
    Raised for problems detected on the client side, such as a server
    response this module does not recognise.
    """
class ServerException( Exception ):
    """
    Raised when the server answers the connection greeting or the login
    with an error; the exception text is the server's message.
    """
class Connection:
    """
    A connection to a MySQL server, opened and authenticated by the
    constructor.
    Statements are run with executeSQL(), createDatabase() and
    dropDatabase(). Errors the server reports for those commands are not
    raised: the call returns None and the details are available from
    getErrorCode() and getErrorMsg(). Rows from a select are read with
    dictRowIterator(). Call close() when done.
    Socket data is handled as Python 2 byte strings (str); single bytes are
    decoded with ord().
    """
    BUFFER_LENGTH = 1024
    CMD_QUIT = "\x01"
    CMD_QUERY = "\x03"
    CMD_CREATE_DB = "\x05"
    CMD_DROP_DB = "\x06"
    # Largest packet body the 3-byte length field (see __buildMsg) can hold.
    MAX_PACKET_BODY = 0xFFFFFF
    # Characters __sprintMsg shows verbatim in a dump; others become '.'.
    _PRINTABLE_CHAR = re.compile( r'[\d \w\!-\/\:-\@\[-\`\{-\~]' )
    #---------------------------------------------------------------------------
    # class private/protected methods
    #---------------------------------------------------------------------------
    def __andByChar( self, source, mask ):
        """Returns source & mask. Neither operand is wrapped to 32 bits."""
        return source & mask

    def __andByLong( self, source, mask=0xFFFFFFFF ):
        """Returns source & mask after wrapping both operands to 32 bits."""
        return self.__cutOffToLong( source ) & self.__cutOffToLong( mask )

    def __buildMsg( self, cmd, args, flags=0 ):
        """
        Builds a packet to send to the server.
        cmd   - Command byte(s) that begin the packet body.
        args  - List of strings appended after cmd, joined with NUL bytes.
                No checking is done that they are valid for cmd; the caller
                is trusted to know what the valid args are.
        flags - Value written to the fourth header byte. The login packet
                passes 1; commands pass 0. NOTE(review): in the MySQL wire
                protocol this byte is the packet sequence number - confirm.
        Returns the 3-byte little-endian body length, the flags byte and the
        body, concatenated.
        Raises ClientException if the body does not fit the length field.
        """
        body = cmd + "\0".join( args )
        if len( body ) > Connection.MAX_PACKET_BODY:
            raise ClientException( "packet too large: %d bytes" % len( body ) )
        # Packing "<I" and keeping the low three bytes produces the same
        # bytes as the old "<H" + NUL encoding for short bodies, but no
        # longer fails with struct.error for bodies of 64KB or more.
        header = struct.pack( "<I", len( body ) )[:3]
        header += struct.pack( "B", flags )
        return header + body

    def __connect( self ):
        """
        Creates a TCP socket and connects it to (self.host, self.port).
        Tries to apply self.timeout but carries on, with a warning, if the
        timeout cannot be set. If the connect fails, the socket is closed
        and the error is re-raised.
        """
        s = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
        s.setblocking( 1 )
        try:
            s.settimeout( self.timeout )
        except Exception:
            # This used to call self.warn(), which does not exist, so the
            # "best effort" path raised AttributeError instead.
            self.__warn( "failed to set socket timeout" )
        try:
            s.connect( ( self.host, self.port ) )
        except Exception:
            s.close()
            raise
        self.socket = s

    def __cutOffToLong( self, source ):
        """
        Wraps a value above 0xFFFFFFFF into the unsigned 32-bit range.
        Values <= 0xFFFFFFFF are returned unchanged.
        """
        if source > 0xFFFFFFFF:
            # Same result as repeatedly subtracting 2**32 (the old loop),
            # but O(1) regardless of the magnitude of source.
            source &= 0xFFFFFFFF
        return source

    def __debug( self, msg ):
        """
        Outputs the passed message if debug is enabled.
        """
        if self.isDebug > 0:
            print( "mysql: *debug* " + msg )

    def __del__( self ):
        """
        Destructor: closes the connection if it is still open.
        """
        Connection.close( self )

    def __dumpMsg( self, msg ):
        """
        Dumps the passed message (hex + printable text) if debugging is
        enabled.
        """
        if self.isDebug > 0:
            self.__debug( self.__sprintMsg( msg ) )

    def __executeCmd( self, cmd, args ):
        """
        Executes the specified command on the server.
        cmd  - The command to be executed.
        args - The single argument string for the command. The caller is
               assumed to know what the valid args are for the cmd passed.
        Returns None if the server reported an error (see getErrorCode()),
        1 for a select result, 1 for a successful drop database, and the
        affected-row count for other successful commands.
        Raises ClientException if the server closes the connection or sends
        a response that cannot be classified.
        """
        msg = self.__buildMsg( cmd, [args] )
        self.__dumpMsg( msg )
        # sendall: a plain send() may transmit only part of a long query.
        self.socket.sendall( msg )
        result = self.socket.recv( Connection.BUFFER_LENGTH )
        self.__dumpMsg( result )
        self.__resetStatus()
        if not result:
            raise ClientException( "connection closed by server" )
        if self.__isError( result ):
            return self.__setErrorByPacket( result )
        elif self.__isSelectQueryResult( result ):
            return self.__getRecordByServer( result )
        elif self.__isUpdateQueryResult( result ):
            # treat database drop as a special case - server returns 0 but we
            # return 1 to indicate success
            if cmd == Connection.CMD_DROP_DB:
                return 1
            else:
                return self.__getAffectedRowsInfoByMsg( result )
        else:
            text = "unknown response: " + self.__sprintMsg( result )
            raise ClientException( text )

    def __getAffectedRowsInfoByMsg( self, msg ):
        """
        Records the affected-row count and insert id from a successful
        (non-select) response, clears serverMsg and returns the row count.
        """
        self.affectedRowsLength = self.__getAffectedRowsLength( msg )
        self.__debug( "affectedRowsLength=%s" % self.affectedRowsLength )
        self.insertId = self.__getInsertId( msg )
        self.__debug( "insertId=%s" % self.insertId )
        self.serverMsg = ''
        return self.affectedRowsLength

    def __getAffectedRowsLength( self, msg ):
        """
        Returns the affected-row count: the length-coded integer that
        follows the 4-byte header and the status byte.
        """
        return self.__readLengthCodedInt( msg, 5 )[0]

    def __getColumnLength( self, msg ):
        "Retrieves the number of columns in a server response."
        return ord( msg[4] )

    def __getErrorCode( self, msg ):
        """
        Retrieves the 2-byte little-endian error code from an error message.
        Raises ClientException if msg is not an error message or is too
        short to contain the code.
        """
        if not self.__isError( msg ):
            raise ClientException( "invalid error msg: " + self.__sprintMsg( msg ) )
        if len( msg ) < 7:
            raise ClientException( "truncated error msg: " + self.__sprintMsg( msg ) )
        return struct.unpack( "<H", msg[5:7] )[0]

    def __getHash( self, password ):
        """
        Calculates a pair of 31-bit hash values from password (spaces and
        tabs are ignored). Used by __scramblePassword on both the password
        and the server salt.
        """
        nr = 1345345333
        add = 7
        nr2 = 0x12345671
        for c in password:
            if c in ( ' ', '\t' ):
                continue
            tmp = ord( c )
            value = ( ( self.__andByChar( nr, 63 ) + add ) * tmp ) + nr * 256
            nr = self.__xorByLong( nr, value )
            nr2 += self.__xorByLong( nr2 * 256, nr )
            add += tmp
        return ( int( self.__andByLong( nr, 0x7FFFFFFF ) ),
                 int( self.__andByLong( nr2, 0x7FFFFFFF ) ) )

    def __getInsertId( self, msg ):
        """
        Returns the insert id: the length-coded integer that follows the
        affected-row count. The old code assumed the row count was one byte
        and only understood the 0xfc (2-byte) form, so ids >= 65536 came
        back as 253.
        """
        pos = self.__readLengthCodedInt( msg, 5 )[1]
        return self.__readLengthCodedInt( msg, pos )[0]

    def __getRecordByServer( self, msg ):
        """
        Reads the remainder of a select result from the socket and stores
        the complete response in self.selectedRecord. Returns 1.
        Raises ClientException if the server closes the connection before
        the end of the result is seen.
        """
        chunks = [ msg ]
        last = msg
        # NOTE(review): the end of the result is detected by the last byte
        # received being 0xfe; a data byte of 0xfe that happens to end a
        # recv() chunk would stop the read early.
        while self.__hasNextMsg( last ):
            last = self.socket.recv( Connection.BUFFER_LENGTH )
            if not last:
                # Without this check an empty recv() looped forever.
                raise ClientException( "connection closed while reading result" )
            self.__dumpMsg( last )
            chunks.append( last )
        self.selectedRecord = "".join( chunks )
        return 1

    def __getServerInfo( self ):
        """
        Reads the server greeting and records the protocol version, server
        version string, server thread id and the 8-byte salt used to
        scramble the password. Sets clientCapabilities to 1 when the
        protocol version is 10.
        Raises ServerException if the greeting is an error message or is
        shorter than a header plus one byte.
        """
        msg = self.socket.recv( Connection.BUFFER_LENGTH )
        self.__dumpMsg( msg )
        if self.__isError( msg ):
            raise ServerException( msg[7:] )
        # get length of message; only the low three header bytes are the
        # length, so the fourth byte is masked off
        msgLen = struct.unpack( "<I", msg[0:4] )[0] & 0xFFFFFF
        i = 4
        self.__debug( "msgLen is " + str( msgLen ) )
        # get protocol version
        self.protocolVersion = ord( msg[i] )
        i += 1
        self.__debug( "protocol version is " + str( self.protocolVersion ) )
        if self.protocolVersion == 10:
            self.clientCapabilities = 1
        # get server version (NUL-terminated)
        end = msg.find( "\0", i )
        self.serverVersion = msg[i:end]
        self.__debug( "server version is '" + self.serverVersion + "'" )
        i = end + 1
        # get server thread id
        self.serverThreadId = struct.unpack( "<I", msg[i:i+4] )[0]
        self.__debug( "server thread id is " + str( self.serverThreadId ) )
        i += 4
        # get salt for password hashing
        self.salt = msg[i:i+8]
        self.__debug( "salt is [" + self.salt + "]" )

    def __getServerMsg( self, msg ):
        """Returns the text after the 7th byte of msg, or '' if shorter."""
        if len( msg ) < 7:
            return ''
        return msg[7:]

    def __hasNextMsg( self, msg ):
        """True unless the last byte of msg is 0xfe."""
        return msg[-1] != '\xfe'

    def __hasSelectedRecord( self ):
        """Returns the stored select result (None if there is none)."""
        return self.selectedRecord

    def __init__( self, host, database, user, password,
                  port=3306, timeout=60.0, debug=0 ):
        """
        Opens and authenticates a connection to a MySQL database.
        host, port     - address of the server.
        database       - database name sent in the login packet.
        user, password - credentials sent in the login packet.
        timeout        - socket timeout in seconds (applied if possible).
        debug          - when > 0, debug messages and packet dumps are
                         printed.
        Raises ServerException if the server rejects the greeting or the
        login, ClientException on an authentication timeout, and socket
        errors from connecting. The socket is closed before any of these
        propagate.
        """
        self.isConnected = 0
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.timeout = timeout
        self.socket = None
        self.salt = None
        self.protocolVersion = None
        self.clientCapabilities = None
        self.isDebug = debug
        self.__debug( "debug output enabled" )
        succeeded = 0
        try:
            self.__connect()
            self.__getServerInfo()
            self.__requestAuth()
            succeeded = 1
        finally:
            if not succeeded:
                # Don't leave a half-open socket behind on failure.
                self.close()
        self.__resetStatus()

    def __isError( self, msg ):
        """
        Returns true if msg is an error message (0xff after the 4-byte
        header) or is too short to hold a header plus one byte.
        """
        if len( msg ) < 5:
            # Was "< 4": a 4-byte msg then crashed in ord('').
            return 1
        return ord( msg[4] ) == 255

    def __isSelectQueryResult( self, msg ):
        """True if msg is not an error and its 5th byte is >= 1."""
        if self.__isError( msg ):
            return None
        return ord( msg[4] ) >= 1

    def __isUpdateQueryResult( self, msg ):
        """True if msg is not an error and its 5th byte is 0."""
        if self.__isError( msg ):
            return None
        return ord( msg[4] ) == 0

    def __readLengthCodedInt( self, msg, pos ):
        """
        Decodes the length-coded integer starting at msg[pos].
        A first byte below 251 is the value itself; 251 marks NULL; 252,
        253 and 254 are followed by a 2, 3 or 8 byte little-endian value.
        Returns (value, position just past the integer); value is None for
        the NULL marker. Raises ClientException for any other first byte.
        """
        head = ord( msg[pos] )
        if head < 251:
            return head, pos + 1
        if head == 251:
            return None, pos + 1
        widths = { 252: 2, 253: 3, 254: 8 }
        if head not in widths:
            raise ClientException( "invalid length-coded integer: %d" % head )
        width = widths[head]
        value = 0
        for shift in range( width ):
            value |= ord( msg[pos + 1 + shift] ) << ( 8 * shift )
        return value, pos + 1 + width

    def __requestAuth( self ):
        """
        Sends the login message and checks the server's answer.
        On an error reply the connection is closed and ClientException
        ("authentication timeout", for replies shorter than 7 bytes) or
        ServerException (with the server's text) is raised. On success
        isConnected is set to 1.
        """
        self.__sendLoginMsg()
        authResult = self.socket.recv( Connection.BUFFER_LENGTH )
        self.__dumpMsg( authResult )
        if self.__isError( authResult ):
            self.close()
            if len( authResult ) < 7:
                raise ClientException( "authentication timeout" )
            raise ServerException( authResult[7:] )
        self.__debug( "connected to database successfully" )
        self.isConnected = 1

    def __resetStatus( self ):
        """Clears the per-command result state."""
        self.insertId = None
        self.serverMsg = ''
        self.errorCode = None
        self.selectedRecord = None

    def __scramblePassword( self, password, salt, flags ):
        """
        Returns password scrambled with salt: one character per salt
        character, or '' for an empty password.
        flags < 1 uses a 0x01FFFFFF modulus and derives the second seed from
        the first; otherwise a 0x3FFFFFFF modulus and both hash halves are
        used, and when flags == 1 every output character is XORed with one
        further pseudo-random value.
        NOTE(review): this appears to be MySQL's pre-4.1 ("old password")
        scramble - confirm before relying on it for newer servers.
        """
        if len( password ) == 0:
            return ''
        hashPass = self.__getHash( password )
        hashMess = self.__getHash( salt )
        if flags < 1:
            maxValue = 0x01FFFFFF
            seed = self.__xorByLong( hashPass[0], hashMess[0] ) % maxValue
            seed2 = seed // 2
        else:
            maxValue = 0x3FFFFFFF
            seed = self.__xorByLong( hashPass[0], hashMess[0] ) % maxValue
            seed2 = self.__xorByLong( hashPass[1], hashMess[1] ) % maxValue
        out = []
        for _ in range( len( salt ) ):
            seed = int( ( seed * 3 + seed2 ) % maxValue )
            seed2 = int( ( seed + seed2 + 33 ) % maxValue )
            out.append( int( float( seed ) / float( maxValue ) * 31 ) + 64 )
        if flags == 1:
            seed = ( seed * 3 + seed2 ) % maxValue
            extra = int( float( seed ) / float( maxValue ) * 31 )
            out = [ ch ^ extra for ch in out ]
        return "".join( [ chr( ch ) for ch in out ] )

    def __sendLoginMsg( self ):
        """
        Sends a login message carrying the user, the scrambled password and
        the database name.
        NOTE(review): the leading "\x8D\x00\x00\x00\x00" is presumably the
        client capability flags plus maximum packet size - confirm.
        """
        msg = self.__buildMsg(
            cmd = "\x8D\x00\x00\x00\x00",
            args = [
                self.user,
                self.__scramblePassword(
                    self.password,
                    self.salt,
                    self.clientCapabilities ),
                self.database ],
            flags = 1 )
        self.__dumpMsg( msg )
        self.socket.sendall( msg )

    def __setErrorByPacket( self, msg ):
        """Records the server message and error code; returns None."""
        self.serverMsg = self.__getServerMsg( msg )
        self.errorCode = self.__getErrorCode( msg )
        return None

    def __sprintMsg( self, msg ):
        """
        Formats msg for debugging: space-separated hex bytes followed by the
        printable characters in parentheses ('.' for the rest).
        """
        hexText = " ".join( [ "%02x" % ord( ch ) for ch in msg ] )
        transChars = []
        for ch in msg:
            if Connection._PRINTABLE_CHAR.match( ch ):
                transChars.append( ch )
            else:
                transChars.append( '.' )
        return hexText + " (" + "".join( transChars ) + ")"

    def __warn( self, msg ):
        """
        Method used for displaying warnings.
        """
        print( "warning: " + msg )

    def __xorByLong( self, source, mask=0 ):
        """Returns source ^ mask after wrapping both operands to 32 bits."""
        return self.__cutOffToLong( source ) ^ self.__cutOffToLong( mask )

    #---------------------------------------------------------------------------
    # public methods
    #---------------------------------------------------------------------------
    def close( self ):
        """
        Closes the connection to the MySQL server, sending the quit command
        first if authentication succeeded. Safe to call more than once and
        on a partly constructed object; later calls do nothing.
        """
        sock = getattr( self, "socket", None )
        if not sock:
            return
        try:
            if getattr( self, "isConnected", 0 ) > 0:
                self.__debug( "sending quit command" )
                msg = self.__buildMsg( Connection.CMD_QUIT, [] )
                self.__dumpMsg( msg )
                try:
                    sock.sendall( msg )
                except socket.error:
                    # Best effort: the server may already have gone away.
                    self.__warn( "failed to send quit command" )
        finally:
            self.isConnected = 0
            self.__debug( "closing socket" )
            sock.close()
            self.socket = None

    def createDatabase( self, dbName ):
        """
        Creates the database with the passed dbName. The user must have
        sufficient privileges to do this. Returns None on error, otherwise
        the affected-row count reported by the server.
        """
        return self.__executeCmd( Connection.CMD_CREATE_DB, dbName )

    def dictRowIterator( self ):
        """
        Returns a dictionary row iterator for the selected data. Returns None
        if a select statement has not been executed. Otherwise, returns a
        RowIterator object that can be used to iterate through all the rows
        retrieved. Each row is a dictionary where the key is of the form
        <table_name>.<column_name> (just <column_name> when the table name is
        empty) and the value is the column value retrieved from the
        database. Null column values are returned as None. The selected
        data can only be iterated once.
        """
        if not self.__hasSelectedRecord():
            return None
        rowIter = RowIterator(
            self.selectedRecord, RowIterator.DICT_ROW, self.isDebug )
        self.selectedRecord = None
        rowIter.parse()
        return rowIter

    def dropDatabase( self, dbName ):
        """
        Drops the database named by the passed dbName. The user must have
        sufficient privileges to do this. Returns 1 on success, None on
        error.
        """
        return self.__executeCmd( Connection.CMD_DROP_DB, dbName )

    def executeSQL( self, sqlText ):
        """
        Executes the passed SQL statement on the server. Returns None on
        error. On a successful select, returns 1. All other statements return
        the number of rows affected (note: 0 when no rows changed, which is
        also false).
        """
        return self.__executeCmd( Connection.CMD_QUERY, sqlText )

    def getErrorCode( self ):
        """
        Retrieves the error code returned by the last query to the server. If
        no error was encountered, returns None.
        """
        return self.errorCode

    def getErrorMsg( self ):
        """
        Returns the server error message returned by the last query. If no
        error was encountered, returns an empty string.
        """
        return self.serverMsg

    def getInsertId( self ):
        """
        Returns the id of the row that was inserted if the table being
        inserted into has a unique sequence column. The value returned is
        meaningless if this is called after executing non-insert statements
        or tables that do not have a unique sequence column.
        """
        return self.insertId
class RowIterator:
    """
    Iterator over the rows of a select result stored by Connection.
    Call parse() once to read the column count and names, then iterate (or
    call next()) to get one row at a time. Each row is a dictionary keyed by
    <table>.<column> (or <column> when the table name is empty); NULL values
    are None. Values are returned as the raw strings sent by the server.
    """
    DICT_ROW = 0
    LIST_ROW = 1 # not implemented yet
    NULL_COLUMN = 251
    UNSIGNED_CHAR_COLUMN = 251
    UNSIGNED_SHORT_COLUMN = 252
    UNSIGNED_INT24_COLUMN = 253
    UNSIGNED_INT32_COLUMN = 254
    UNSIGNED_INT32_PAD_LENGTH = 4
    #---------------------------------------------------------------------------
    # class private/protected methods
    #---------------------------------------------------------------------------
    def __init__( self, msg, rowType, debug=0 ):
        """
        msg     - the complete select response received from the server.
        rowType - RowIterator.DICT_ROW (the only supported type).
        debug   - when > 0, debug messages are printed.
        """
        self.msg = msg
        self.position = 0
        self.columns = []
        self.isDebug = debug
        self.rowType = rowType
        self.index = 0

    def __iter__( self ):
        return self

    def __debug( self, msg ):
        """Prints msg if debugging is enabled."""
        if self.isDebug > 0:
            print( "mysql: *debug* " + msg )

    def __fetchRowAsDict( self ):
        """
        Reads one row at the current position and returns it as a dict, or
        None if the end of the message has been reached.
        """
        if self.__isEndOfMsg():
            return None
        row = {}
        for i in range( self.columnLength ):
            info = self.columns[i]
            if info['table']:
                key = info['table'] + "." + info['column']
            else:
                key = info['column']
            row[key] = self.__getStringAndSeekPos()
        # skip the 4 bytes before the next row (presumably the next packet's
        # header - compare the 4-byte header built by Connection)
        self.position += 4
        return row

    def __getColumnLength( self ):
        """Reads the column count and moves to the first column entry."""
        # skip the 4-byte header of the first packet
        self.position += 4
        self.columnLength = ord( self.msg[self.position] )
        self.__debug( "column length is " + str( self.columnLength ) )
        # skip the count byte plus the next 4-byte header
        self.position += 5

    def __getColumnNames( self ):
        """
        Reads the table and column name of every column into self.columns.
        NOTE(review): assumes each column entry is two length-coded strings
        followed by exactly 14 more bytes (fixed-size field info plus the
        next header) - confirm for the server versions in use.
        """
        for i in range( self.columnLength ):
            self.columns.append(
                { 'table': self.__getStringAndSeekPos(),
                  'column': self.__getStringAndSeekPos() } )
            self.position += 14
        # skip the end-of-columns marker byte and the next 4-byte header
        self.position += 5
        if self.isDebug > 0:
            names = [ str( col['table'] ) + "." + str( col['column'] )
                      for col in self.columns ]
            self.__debug( ", ".join( names ) )

    def __getFieldLength( self ):
        """
        Reads a length-coded integer at the current position and returns
        it, or None for the NULL marker (251). A first byte below 251 is the
        value; 252 and 253 are followed by a 2 or 3 byte little-endian value;
        otherwise 4 bytes are read and 4 more skipped.
        """
        self.__debug(
            "pos=" + str( self.position ) + " len=" + str( len( self.msg ) ) )
        head = ord( self.msg[self.position] )
        self.position += 1
        if head == RowIterator.NULL_COLUMN:
            return None
        if head < RowIterator.UNSIGNED_CHAR_COLUMN:
            return head
        if head == RowIterator.UNSIGNED_SHORT_COLUMN:
            return self.__readUnsigned( 2 )
        if head == RowIterator.UNSIGNED_INT24_COLUMN:
            # The old code computed "a + b << 8 + c << 16", which Python
            # parses as ((a + b) << (8 + c)) << 16.
            return self.__readUnsigned( 3 )
        # The old code shifted the 4th byte by 32 instead of 24.
        length = self.__readUnsigned( 4 )
        self.position += RowIterator.UNSIGNED_INT32_PAD_LENGTH
        return length

    def __readUnsigned( self, count ):
        """
        Reads a count-byte little-endian unsigned integer at the current
        position and advances past it.
        Raises ClientException if fewer than count bytes remain.
        """
        raw = self.msg[self.position:self.position + count]
        if len( raw ) < count:
            raise ClientException( "truncated result set" )
        value = 0
        for shift, ch in enumerate( raw ):
            value |= ord( ch ) << ( 8 * shift )
        self.position += count
        return value

    def __getStringAndSeekPos( self ):
        """
        Reads a length-prefixed string at the current position. Returns None
        for a NULL value and '' for an empty one (the old "if not length"
        check turned empty strings into None).
        """
        length = self.__getFieldLength()
        if length is None:
            return None
        tmpString = self.msg[self.position:self.position+length]
        self.__debug( "tmpString=[" + tmpString + "]" )
        self.position += length
        # TODO: values are returned as raw bytes. A former commented-out
        # hack remapped two-byte sequences starting with '\xc3' so French
        # text displayed correctly; a proper decoding step is still needed.
        return tmpString

    def __isEndOfMsg( self ):
        """True when at most one byte remains after the current position."""
        return len( self.msg ) <= self.position + 1

    def parse( self ):
        """Reads the column count and column names; call before iterating."""
        self.__getColumnLength()
        self.__getColumnNames()

    #---------------------------------------------------------------------------
    # class public methods
    #---------------------------------------------------------------------------
    def getRowCount( self ):
        """
        Returns the number of rows retrieved. This value is incremented as
        each call to next() is made. If you want the actual number of rows
        retrieved from the database, you will have to call next() until the
        StopIteration exception is raised before calling getRowCount().
        """
        return self.index

    def next( self ):
        """
        Returns the next row retrieved from the database. Raises the
        StopIteration exception when all rows have been retrieved, and
        ClientException for an unsupported row type.
        """
        if self.__isEndOfMsg():
            raise StopIteration
        self.index = self.index + 1
        if self.rowType == RowIterator.DICT_ROW:
            return self.__fetchRowAsDict()
        raise ClientException( "unsupported row type" )

    # Python 3 iteration protocol name for next().
    __next__ = next
|