mysql.py :  » Game-2D-3D » XBMC-MythTV » xbmcmythtv » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Game 2D 3D » XBMC MythTV 
XBMC MythTV » xbmcmythtv » mysql.py
"""
A MySQL client library written entirely in python.

$Id: mysql.py,v 1.5 2006/07/25 07:54:22 frooby Exp $

A pure python implementation of an interface to the MySQL database.  This
module is written entirely in python and does not require linking to any
dynamic libraries or compiling of C driver code to connect to a MySQL
database.  This is handy for systems that have a python port but do not have
the MySQL database/libraries/dlls ported yet (e.g. Xbox Media Center).

There are no plans to implement additional features in this module.  It does
what I need it to do and that's as far as I have taken it.  If you would
like to contribute additional functionality to this, feel free to send me
any additions/changes/bug reports.

This python module is based on a pure perl implementation of an interface to
MySQL written by Hiroyuki Oyama.

Here are some simple examples of how to use this module:

import mysql

conn = mysql.Connection(
    host="myhost",
    database="mydb",
    user="myuser",
    password="mypass",
    timeout=30,
    port=3306,
    debug=0 )

if conn.executeSQL( "select * from db where hostname like '%myhost%'" ):
    iter = conn.dictRowIterator()
    for row in iter:
        for k in row.keys():
            v = row[k]
            if not v:
                print k + "=>[(null)]",
            else:
                print k + "=>[" + v + "]",
        print
    print str(iter.getRowCount()) + " rows retrieved"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()

if conn.executeSQL(
    "insert into person (id, first, last) values (null, 'Tom','Warkentin')" ):
    print "insert succeeded: id=" + str(conn.getInsertId())
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()

if conn.executeSQL(
    "update settings set data = 'my new value' where value = 'my data value'" ):
    print "update succeeded"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()

if conn.executeSQL( "delete from db where hostname = 'myhost'" ):
    print "delete succeeded"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()

if conn.createDatabase( "mydb" ):
    print "database created"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()

if conn.dropDatabase( "mydb" ):
    print "database dropped"
else:
    print "ERROR " + str(conn.getErrorCode()) + ": " + conn.getErrorMsg()

conn.close()


Copyright (C) 2004 Tom Warkentin <tom@ixionstudios.com>

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
or at the website: http://www.gnu.org

"""
import re, socket, string, struct

# Module-level slot for a shared Connection object: a hack working around a
# problem passing a Connection object to a class constructor on the Xbox.
# (A `global` statement at module level has no effect, so a plain
# assignment is all that is needed here.)
db = None

def escape( obj ):
    r"""
    Formats a Python value as a MySQL literal, for use in the values of an
    insert or update statement.

    - None                        -> NULL
    - integers (int, long, bool)  -> decimal digits, e.g. 42 -> 42
    - strings (str / unicode)     -> single-quoted, with backslash, single
                                     and double quotes, NUL, newline,
                                     carriage return and Ctrl-Z escaped by a
                                     backslash (the character set MySQL's
                                     mysql_real_escape_string() escapes)
    - anything else               -> returned unchanged

    '%' and '_' are deliberately NOT escaped: outside LIKE patterns MySQL
    would keep the backslash in '\%' and '\_'.

    Example:

    s = "they're here"
    print escape( s )
    # will return: 'they\'re here'

    print escape( None )
    # will return: NULL
    """
    if obj is None:
        return "NULL"

    try:
        integerTypes = (int, long)
        stringTypes = basestring
    except NameError:
        # Python 3 has no separate long / basestring types
        integerTypes = (int,)
        stringTypes = str

    if isinstance( obj, integerTypes ):
        return "%d" % obj
    if isinstance( obj, stringTypes ):
        replacements = {
            "\\": "\\\\",
            "'": "\\'",
            '"': '\\"',
            "\0": "\\0",
            "\n": "\\n",
            "\r": "\\r",
            "\x1a": "\\Z",
        }
        escaped = [replacements.get( ch, ch ) for ch in obj]
        return "'" + "".join( escaped ) + "'"
    return obj

class ClientException( Exception ):
    """
    Raised for problems detected on the client side: a server response the
    client cannot interpret, a missing login reply, or an unsupported
    row iterator type.
    """

class ServerException( Exception ):
    """
    Raised when the server answers the connection greeting or the login
    with an error packet.  The exception argument is the server's error
    text.  Errors from later queries are not raised; they are reported via
    Connection.getErrorCode() / getErrorMsg().
    """

class Connection:
    """
    A connection to a MySQL server, implemented directly on a TCP socket.

    The constructor connects, reads the server greeting and logs in.  After
    that, executeSQL(), createDatabase() and dropDatabase() send one command
    at a time and return None on a server error (details via getErrorCode()
    and getErrorMsg()), 1 for a select (rows via dictRowIterator()), and
    otherwise the number of affected rows.

    Messages are handled as Python 2 byte strings (str).
    """
    BUFFER_LENGTH   = 1024      # bytes requested per socket.recv() call
    CMD_QUIT        = "\x01"
    CMD_QUERY       = "\x03"
    CMD_CREATE_DB   = "\x05"
    CMD_DROP_DB     = "\x06"

    #---------------------------------------------------------------------------
    # class private/protected methods
    #---------------------------------------------------------------------------
    def __andByChar( self, source, mask ):
        """Returns source & mask (no truncation)."""
        return source & mask

    def __andByLong( self, source, mask=0xFFFFFFFF ):
        """Returns source & mask after truncating both operands to 32 bits."""
        return self.__cutOffToLong( source ) & self.__cutOffToLong( mask )

    def __buildMsg( self, cmd, args, flags=0 ):
        """
        Builds a packet to send to the server.

        cmd     - Command byte(s) placed at the start of the payload.
        args    - Strings appended after cmd, separated by NUL bytes.  No
                  checking is done that they are valid for cmd; the caller
                  is trusted to know what the valid args are.
        flags   - Written into the fourth header byte, which the MySQL
                  protocol uses as the packet sequence number (0 for a new
                  command; the login reply to the greeting uses 1).

        The header is the payload length as 3 little-endian bytes followed
        by that byte.  Raises ClientException if the payload does not fit in
        a single packet.
        """
        body = cmd + "\0".join( args )
        length = len( body )
        # a length of 0xFFFFFF means "more packets follow"; splitting is not
        # implemented, so refuse anything that large
        if length >= 0xFFFFFF:
            raise ClientException(
                "message too large for a single packet: %d bytes" % length )

        msg = struct.pack( "<I", length )[:3]
        msg += struct.pack( "B", flags ) + body
        return msg

    def __connect( self ):
        """
        Creates a socket and connects it to the MySQL server.  Setting the
        timeout is best effort: a failure only produces a warning.  The
        socket is closed again if the connect itself fails.
        """
        s = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
        s.setblocking( 1 )

        try:
            s.settimeout( self.timeout )
        except Exception:
            self.__warn( "failed to set socket timeout" )

        connected = 0
        try:
            s.connect( (self.host, self.port) )
            connected = 1
        finally:
            if not connected:
                s.close()
        self.socket = s

    def __cutOffToLong( self, source ):
        """
        Emulates unsigned 32-bit overflow: values above 0xFFFFFFFF are
        reduced modulo 2**32.  Values already in range (including negative
        values) are returned unchanged.
        """
        if source > 0xFFFFFFFF:
            source %= 0x100000000
        return source

    def __debug( self, msg ):
        """
        Outputs the passed message if debug is enabled.
        """
        if self.isDebug > 0:
            print( "mysql: *debug* " + msg )

    def __del__( self ):
        """
        Destructor: closes the connection if it is still open.
        """
        try:
            Connection.close( self )
        except Exception:
            # destructors must not raise; the socket is being discarded anyway
            pass

    def __dumpMsg( self, msg ):
        """
        Dumps the passed message if debugging is enabled.
        """
        if self.isDebug > 0:
            self.__debug( self.__sprintMsg( msg ) )

    def __executeCmd( self, cmd, args ):
        """
        Executes the specified command on the server.

        cmd     - The command to be executed.
        args    - The argument string for the command.  The caller is
                  assumed to know what the valid args are for cmd.

        Returns None on a server error (see getErrorCode/getErrorMsg), 1
        for a select result or a successful database drop, and otherwise
        the affected-row count.  Raises ClientException if the connection
        is closed or the response cannot be interpreted.
        """
        if not self.socket:
            raise ClientException( "connection is closed" )

        msg = self.__buildMsg( cmd, [args] )
        self.__dumpMsg( msg )
        self.socket.sendall( msg )

        result = self.socket.recv( Connection.BUFFER_LENGTH )
        self.__dumpMsg( result )
        self.__resetStatus()

        if len( result ) < 5:
            # too short to contain a header plus status byte; an empty read
            # means the server closed the connection
            raise ClientException(
                "incomplete response: " + self.__sprintMsg( result ) )

        if self.__isError( result ):
            return self.__setErrorByPacket( result )
        elif self.__isSelectQueryResult( result ):
            return self.__getRecordByServer( result )
        elif self.__isUpdateQueryResult( result ):
            # treat database drop as a special case - server returns 0 but we
            # return 1 to indicate success
            if cmd == Connection.CMD_DROP_DB:
                return 1
            else:
                return self.__getAffectedRowsInfoByMsg( result )
        else:
            text = "unknown response: " + self.__sprintMsg( result )
            raise ClientException( text )

    def __getAffectedRowsInfoByMsg( self, msg ):
        """
        Records the affected-row count and insert id from an OK packet and
        returns the affected-row count.
        """
        self.affectedRowsLength = self.__getAffectedRowsLength( msg )
        self.__debug( "affectedRowsLength=" + str( self.affectedRowsLength ) )
        self.insertId = self.__getInsertId( msg )
        self.__debug( "insertId=" + str( self.insertId ) )
        self.serverMsg = ''
        return self.affectedRowsLength

    def __getAffectedRowsLength( self, msg ):
        """Returns the affected-row count that follows the OK status byte."""
        return self.__readLengthCodedBinary( msg, 5 )[0]

    def __getErrorCode( self, msg ):
        """
        Retrieves the 2-byte little-endian error code from an error packet.
        Raises ClientException if msg is not a complete error packet.
        """
        if not self.__isError( msg ) or len( msg ) < 7:
            raise ClientException( "invalid error msg: " + self.__sprintMsg( msg ) )
        return ord( msg[5] ) | (ord( msg[6] ) << 8)

    def __getHash( self, password ):
        """
        Computes the pre-4.1 MySQL password hash of a string (the constants
        below are the ones MySQL's hash_password() uses).  Spaces and tabs
        are skipped.  Returns a tuple of two non-negative 31-bit ints.
        """
        nr = 1345345333
        add = 7
        nr2 = 0x12345671
        for c in password:
            if c == ' ' or c == '\t':
                continue
            tmp = ord( c )
            value = ((self.__andByChar( nr, 63 ) + add) * tmp) + nr * 256
            nr = self.__xorByLong( nr, value )
            # nr2 is not truncated here; only its low bits matter because
            # the result is masked below and the shifted operand is cut to
            # 32 bits
            nr2 += self.__xorByLong( nr2 * 256, nr )
            add += tmp
        return (int( self.__andByLong( nr, 0x7FFFFFFF ) ),
                int( self.__andByLong( nr2, 0x7FFFFFFF ) ))

    def __getInsertId( self, msg ):
        """Returns the insert id that follows the affected-row count."""
        pos = self.__readLengthCodedBinary( msg, 5 )[1]
        return self.__readLengthCodedBinary( msg, pos )[0]

    def __getRecordByServer( self, msg ):
        """
        Keeps reading until the accumulated select result ends with the
        0xfe end marker, stores the raw bytes for dictRowIterator() and
        returns 1.  Raises ClientException if the server closes the
        connection first.

        NOTE(review): the end is detected only by the last byte of each
        chunk being 0xfe; a chunk that happens to end in 0xfe for another
        reason stops reading early.
        """
        chunks = [msg]
        last = msg
        while self.__hasNextMsg( last ):
            last = self.socket.recv( Connection.BUFFER_LENGTH )
            if not last:
                raise ClientException(
                    "connection closed while reading result set" )
            self.__dumpMsg( last )
            chunks.append( last )
        self.selectedRecord = "".join( chunks )
        return 1

    def __getServerInfo( self ):
        """
        Reads the server greeting: protocol version, server version, thread
        id and the 8-byte salt used for password scrambling.  Raises
        ServerException if the server sent an error instead.
        """
        msg = self.socket.recv( Connection.BUFFER_LENGTH )
        self.__dumpMsg( msg )
        if self.__isError( msg ):
            raise ServerException( msg[7:] )

        # packet header: 3-byte little-endian payload length + sequence byte
        msgLen = ord( msg[0] ) | (ord( msg[1] ) << 8) | (ord( msg[2] ) << 16)
        i = 4
        self.__debug( "msgLen is " + str( msgLen ) )

        # get protocol version
        self.protocolVersion = ord( msg[i] )
        i += 1
        self.__debug( "protocol version is " + str( self.protocolVersion ) )
        if self.protocolVersion == 10:
            self.clientCapabilities = 1

        # get NUL-terminated server version
        end = msg.find( "\0", i )
        self.serverVersion = msg[i:end]
        self.__debug( "server version is '" + self.serverVersion + "'" )
        i = end + 1

        # get server thread id
        self.serverThreadId = struct.unpack( "<I", msg[i:i+4] )[0]
        self.__debug( "server thread id is " + str( self.serverThreadId ) )
        i += 4

        # get salt for password hashing
        self.salt = msg[i:i+8]
        self.__debug( "salt is [" + self.salt + "]" )

    def __getServerMsg( self, msg ):
        """Returns the text after an error packet's error code ('' if none)."""
        if len( msg ) < 7:
            return ''
        return msg[7:]

    def __hasNextMsg( self, msg ):
        """True unless msg ends with the 0xfe end marker."""
        return msg[-1] != '\xfe'

    def __hasSelectedRecord( self ):
        return self.selectedRecord

    def __init__( self, host, database, user, password,
                  port=3306, timeout=60.0, debug=0 ):
        """
        Connects and logs in to a MySQL database.

        host, port      - where the server listens
        database        - database selected at login
        user, password  - login credentials
        timeout         - socket timeout in seconds (best effort)
        debug           - when > 0, protocol traffic is printed

        Raises ServerException if the server answers with an error,
        ClientException if no usable login reply arrives, and socket errors
        if the server cannot be reached.  The socket is closed again if the
        handshake fails.
        """
        self.isConnected = 0
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.timeout = timeout
        self.socket = None
        self.salt = None
        self.protocolVersion = None
        self.clientCapabilities = None
        self.isDebug = debug
        self.__debug( "debug output enabled" )

        self.__connect()
        handshakeDone = 0
        try:
            self.__getServerInfo()
            self.__requestAuth()
            handshakeDone = 1
        finally:
            if not handshakeDone:
                self.close()
        self.__resetStatus()

    def __isError( self, msg ):
        """
        True if msg is an error packet (status byte 255) or too short to
        hold a header and status byte.
        """
        if len( msg ) < 5:
            return 1
        return ord( msg[4] ) == 255

    def __isSelectQueryResult( self, msg ):
        if self.__isError( msg ):
            return None
        return ord( msg[4] ) >= 1

    def __isUpdateQueryResult( self, msg ):
        if self.__isError( msg ):
            return None
        return ord( msg[4] ) == 0

    def __readLengthCodedBinary( self, msg, pos ):
        """
        Decodes a MySQL length-coded integer starting at msg[pos].

        A first byte below 251 is the value itself; 251 means NULL (None);
        252, 253 and 254 are followed by a 2-, 3- and 8-byte little-endian
        value.  Returns (value, position just past the value).  Raises
        ClientException if msg is too short.
        """
        if pos >= len( msg ):
            raise ClientException( "truncated packet: " + self.__sprintMsg( msg ) )
        head = ord( msg[pos] )
        pos += 1
        if head < 251:
            return head, pos
        if head == 251:
            return None, pos
        if head == 252:
            size = 2
        elif head == 253:
            size = 3
        else:
            size = 8
        if pos + size > len( msg ):
            raise ClientException( "truncated packet: " + self.__sprintMsg( msg ) )
        value = 0
        for shift in range( size ):
            value |= ord( msg[pos + shift] ) << (8 * shift)
        return value, pos + size

    def __requestAuth( self ):
        """
        Sends the login packet and checks the reply.  On failure the
        connection is closed and ClientException (reply too short) or
        ServerException (server error text) is raised.
        """
        self.__sendLoginMsg()
        authResult = self.socket.recv( Connection.BUFFER_LENGTH )
        self.__dumpMsg( authResult )
        if self.__isError( authResult ):
            self.close()
            if len( authResult ) < 7:
                raise ClientException( "authentication timeout" )
            raise ServerException( authResult[7:] )
        self.__debug( "connected to database successfully" )
        self.isConnected = 1

    def __resetStatus( self ):
        self.insertId = None
        self.serverMsg = ''
        self.errorCode = None
        self.selectedRecord = None

    def __scramblePassword( self, password, salt, flags ):
        """
        Scrambles password with the server's salt for the login packet.
        Returns '' for an empty password.  flags of None or < 1 selects the
        older 25-bit variant; 1 selects the 30-bit variant with an extra XOR
        pass over the output.
        """
        if len( password ) == 0:
            return ''
        hashPass = self.__getHash( password )
        hashMess = self.__getHash( salt )
        if flags is None or flags < 1:
            maxValue = 0x01FFFFFF
            seed = self.__xorByLong( hashPass[0], hashMess[0] ) % maxValue
            seed2 = seed // 2
        else:
            maxValue = 0x3FFFFFFF
            seed = self.__xorByLong( hashPass[0], hashMess[0] ) % maxValue
            seed2 = self.__xorByLong( hashPass[1], hashMess[1] ) % maxValue
        dMax = float( maxValue )

        # one pseudo-random printable character (64..94) per salt byte
        out = []
        for i in range( len( salt ) ):
            seed = int( (seed * 3 + seed2) % maxValue )
            seed2 = int( (seed + seed2 + 33) % maxValue )
            out.append( int( (float( seed ) / dMax) * 31 ) + 64 )

        if flags == 1:
            seed = (seed * 3 + seed2) % maxValue
            seed2 = (seed + seed2 + 33) % maxValue
            extra = int( (float( seed ) / dMax) * 31 )
            out = [ch ^ extra for ch in out]
        return "".join( [chr( ch ) for ch in out] )

    def __sendLoginMsg( self ):
        """
        Sends a login message (user, scrambled password, database) to the
        server as packet sequence number 1.
        """
        msg = self.__buildMsg(
            cmd = "\x8D\x00\x00\x00\x00",
            args = [
                self.user,
                self.__scramblePassword(
                    self.password,
                    self.salt,
                    self.clientCapabilities),
                self.database ],
            flags = 1 )
        self.__dumpMsg( msg )
        self.socket.sendall( msg )

    def __setErrorByPacket( self, msg ):
        """Records the error text and code from an error packet; returns None."""
        self.serverMsg = self.__getServerMsg( msg )
        self.errorCode = self.__getErrorCode( msg )
        return None

    def __sprintMsg( self, msg ):
        """
        Formats msg for debugging: the hex value of every byte, followed by
        the text in parentheses with non-printable bytes shown as '.'.
        """
        printable = re.compile( r'[\d \w\!-\/\:-\@\[-\`\{-\~]' )
        hexText = " ".join( ["%02x" % ord( ch ) for ch in msg] )
        plain = []
        for ch in msg:
            if printable.match( ch ):
                plain.append( ch )
            else:
                plain.append( '.' )
        return hexText + " (" + "".join( plain ) + ")"

    def __warn( self, msg ):
        """
        Method used for displaying warnings.
        """
        print( "warning: " + msg )

    def __xorByLong( self, source, mask=0 ):
        """Returns source ^ mask after truncating both operands to 32 bits."""
        return self.__cutOffToLong( source ) ^ self.__cutOffToLong( mask )

    #---------------------------------------------------------------------------
    # public methods
    #---------------------------------------------------------------------------
    def close( self ):
        """
        Closes a connection to the MySQL server.  Sends the quit command if
        logged in, and always closes the socket, even if sending fails.  Can
        be safely called multiple times (later calls do nothing).
        """
        sock = getattr( self, "socket", None )
        if not sock:
            return
        try:
            if self.isConnected > 0:
                self.__debug( "sending quit command" )
                msg = self.__buildMsg( Connection.CMD_QUIT, [] )
                self.__dumpMsg( msg )
                # mark as disconnected first so a failed send is not retried
                self.isConnected = 0
                sock.sendall( msg )
        finally:
            self.__debug( "closing socket" )
            sock.close()
            self.socket = None

    def createDatabase( self, dbName ):
        """
        Creates the database with the passed dbName.  The user must have
        sufficient privileges to do this.  Returns None on a server error.
        """
        return self.__executeCmd( Connection.CMD_CREATE_DB, dbName )

    def dictRowIterator( self ):
        """
        Returns a dictionary row iterator for the selected data.  Returns None
        if a select statement has not been executed.  Otherwise, returns a
        RowIterator object that can be used to iterate through all the rows
        retrieved.  Each row is a dictionary where the key is of the form
        <table_name>.<column_name> and the key value is the column value
        retrieved from the database.  Null column values are returned as None.
        The selected data is handed over, so a second call returns None.
        """
        if not self.__hasSelectedRecord():
            return None

        iter = RowIterator(
            self.selectedRecord, RowIterator.DICT_ROW, self.isDebug )
        self.selectedRecord = None
        iter.parse()
        return iter

    def dropDatabase( self, dbName ):
        """
        Drops the database named by the passed dbName.  The user must have
        sufficient privileges to do this.  Returns 1 on success, None on a
        server error.
        """
        return self.__executeCmd( Connection.CMD_DROP_DB, dbName )

    def executeSQL( self, sqlText ):
        """
        Executes the passed SQL statement on the server.  Returns None on
        error.  On a successful select, returns 1.  All other statements return
        the number of rows affected.
        """
        return self.__executeCmd( Connection.CMD_QUERY, sqlText )

    def getErrorCode( self ):
        """
        Retrieves the error code returned by the last query to the server.  If
        no error was encountered, returns None.
        """
        return self.errorCode

    def getErrorMsg( self ):
        """
        Returns the server error message returned by the last query.  If no
        error was encountered, returns an empty string.
        """
        return self.serverMsg

    def getInsertId( self ):
        """
        Returns the id of the row that was inserted if the table being
        inserted into has a unique sequence column.  The value returned is
        meaningless if this is called after executing non-insert statements
        or tables that do not have a unique sequence column, and None after
        a select or an error.
        """
        return self.insertId

class RowIterator:
    """
    Iterates over the rows of a raw select result collected by Connection.
    Call parse() once to read the column definitions, then iterate (or call
    next()) to get one row at a time.

    Only DICT_ROW is supported: each row is a dict keyed by
    "<table>.<column>", or just "<column>" when no table name was sent.
    SQL NULL values are returned as None.  All other values are the raw
    strings sent by the server; no type conversion or character-set
    decoding is done.
    """
    DICT_ROW                    = 0
    LIST_ROW                    = 1     # not implemented yet

    # Values of the first byte of a length-coded field (see __getFieldLength)
    NULL_COLUMN                 = 251   # SQL NULL, no data follows
    UNSIGNED_CHAR_COLUMN        = 251   # smaller first bytes are the length itself
    UNSIGNED_SHORT_COLUMN       = 252   # 2-byte little-endian length follows
    UNSIGNED_INT24_COLUMN       = 253   # 3-byte little-endian length follows
    UNSIGNED_INT32_COLUMN       = 254   # 8-byte little-endian length follows:
    UNSIGNED_INT32_PAD_LENGTH   = 4     # 4 low bytes plus this many high bytes

    #---------------------------------------------------------------------------
    # class private/protected methods
    #---------------------------------------------------------------------------
    def __init__( self, msg, rowType, debug=0 ):
        """
        msg     - the raw select-result bytes collected by Connection
        rowType - DICT_ROW (the only supported type)
        debug   - when > 0, parsing progress is printed
        """
        self.msg = msg
        self.position = 0       # offset of the next unread byte in msg
        self.columns = []       # filled by parse(): {'table': .., 'column': ..}
        self.isDebug = debug
        self.rowType = rowType
        self.index = 0          # number of rows handed out by next()

    def __iter__( self ):
        return self

    def __debug( self, msg ):
        if self.isDebug > 0:
            print( "mysql: *debug* " + msg )

    def __fetchRowAsDict( self ):
        """
        Reads one row and returns it as a dict (None at the end of msg).
        If two columns produce the same key, the later one wins.
        """
        if self.__isEndOfMsg():
            return None
        row = {}
        for i in range( self.columnLength ):
            col = self.columns[i]
            if col['table']:
                name = col['table'] + "." + col['column']
            else:
                name = col['column']
            row[name] = self.__getStringAndSeekPos()
        # skip the header of the next packet
        self.position += 4
        return row

    def __getColumnLength( self ):
        """
        Reads the single-byte column count from the first packet, then skips
        that byte and the next packet header.
        """
        self.position += 4      # packet header
        self.columnLength = ord( self.msg[self.position] )
        self.__debug( "column length is " + str( self.columnLength ) )
        self.position += 5      # column count byte + next packet header

    def __getColumnNames( self ):
        """
        Reads the table and column name of every column definition.
        """
        for i in range( self.columnLength ):
            table = self.__getStringAndSeekPos()
            column = self.__getStringAndSeekPos()
            self.columns.append( { 'table': table, 'column': column } )
            # skip the rest of the field definition and the next packet
            # header (14 bytes, per the layout this parser assumes)
            self.position += 14
        # skip the end-of-definitions marker and the next packet header
        self.position += 5
        if self.isDebug > 0:
            names = []
            for col in self.columns:
                names.append( str( col['table'] ) + "." + str( col['column'] ) )
            self.__debug( ", ".join( names ) )

    def __getFieldLength( self ):
        """
        Decodes the length prefix of the next field and advances past it.
        Returns None for an SQL NULL, otherwise the data length in bytes.
        """
        self.__debug(
            "pos=" + str( self.position ) + " len=" + str( len( self.msg ) ) )
        head = ord( self.msg[self.position] )
        self.position += 1

        if head == RowIterator.NULL_COLUMN:
            return None
        if head < RowIterator.UNSIGNED_CHAR_COLUMN:
            return head
        if head == RowIterator.UNSIGNED_SHORT_COLUMN:
            return self.__readLittleEndian( 2 )
        if head == RowIterator.UNSIGNED_INT24_COLUMN:
            return self.__readLittleEndian( 3 )
        return self.__readLittleEndian( 4 + RowIterator.UNSIGNED_INT32_PAD_LENGTH )

    def __getStringAndSeekPos( self ):
        """
        Reads the next length-prefixed field.  Returns None for SQL NULL and
        '' for an empty value.
        """
        length = self.__getFieldLength()
        if length is None:
            return None

        tmpString = self.msg[self.position:self.position+length]
        self.__debug( "tmpString=[" + tmpString + "]" )
        self.position += length
        # NOTE: values are raw server bytes; special characters are not
        # decoded to any particular character set here
        return tmpString

    def __isEndOfMsg( self ):
        """True once only the final marker byte (or nothing) is left."""
        return len( self.msg ) <= self.position + 1

    def __readLittleEndian( self, size ):
        """Reads a size-byte little-endian unsigned integer and advances."""
        value = 0
        chunk = self.msg[self.position:self.position+size]
        for shift in range( len( chunk ) ):
            value |= ord( chunk[shift] ) << (8 * shift)
        self.position += size
        return value

    def parse( self ):
        """Reads the column count and column names; call before iterating."""
        self.__getColumnLength()
        self.__getColumnNames()

    #---------------------------------------------------------------------------
    # class public methods
    #---------------------------------------------------------------------------
    def getRowCount( self ):
        """
        Returns the number of rows retrieved.  This value is incremented as
        each call to next() is made.  If you want the actual number of rows
        retrieved from the database, you will have to call next() until the
        StopIteration exception is raised before calling getRowCount().
        """
        return self.index

    def next( self ):
        """
        Returns the next row retrieved from the database.  Raises the
        StopIteration exception when all rows have been retrieved.
        """
        if self.__isEndOfMsg():
            raise StopIteration
        self.index = self.index + 1
        if self.rowType == RowIterator.DICT_ROW:
            return self.__fetchRowAsDict()
        else:
            raise ClientException( "unsupported row type" )

    # Python 3 iterator protocol name for next()
    __next__ = next
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.