# -*- Mode: Python; tab-width: 4 -*-

# Note: this should be split into two modules, one which is unaware of
# the distinction between blocking and coroutine sockets.

# Strategies:
#   1) Simple; schedule at the socket level.
#      Just like the mysql client library, when a coroutine accesses
#        the mysql client object, it will automatically detach when
#        the socket gets EWOULDBLOCK.
#   2) Smart; schedule at the request level.
#      Use a separate coroutine to manage the mysql connection.
#      A client coroutine will resume when a response is available.
#   3) Sophisticated; schedule at the row level.
#      Allow a client coroutine to peel off rows one at a time.
#

# Going with #1 for now, for maximum compatibility with blocking socket.

import exceptions
import math
import socket
import string
import sys

class MySQLError (Exception):
    """Raised when the connection fails or the server replies unexpectedly."""

# ===========================================================================
#                           Authentication
# ===========================================================================

# Note: support for an older version of the protocol has been omitted.
#
# The code is based on the file mysql-3.21.33/client/password.c
#
# The auth scheme is challenge/response.  Upon connection the server
# sends an 8-byte challenge message.  This is hashed with the password
# to produce an 8-byte response.  The server side performs an identical
# hash to verify the password is correct.

class random_state:
    """Two-word pseudo-random generator used by the password scramble.

    Both seeds are reduced modulo max_value (0x3FFFFFFF) on construction.
    Each call to rnd() advances the state and returns a float in [0, 1).
    """

    def __init__ (self, seed, seed2):
        modulus = 0x3FFFFFFF
        self.max_value = modulus
        self.seed, self.seed2 = seed % modulus, seed2 % modulus

    def rnd (self):
        m = self.max_value
        # The second word is derived from the *new* first word and the *old* second word.
        first = (self.seed * 3 + self.seed2) % m
        second = (first + self.seed2 + 33) % m
        self.seed, self.seed2 = first, second
        return float(first) / float(m)

def hash_password (password):
    """Hash *password* with the old MySQL algorithm (password.c hash_password).

    Spaces and tabs in *password* are skipped.  Returns a pair of
    non-negative integers, each below 2**31.
    """
    # Only the low 31 bits of the state are returned, and every operation
    # below (xor, +, *, <<) carries information only from low bits upward.
    # Masking to 32 bits each round therefore leaves the result unchanged,
    # like the C original's unsigned wrap-around. It also stops Python's
    # unbounded ints from growing by ~8 bits per character.
    mask32 = 0xFFFFFFFF
    nr = 1345345333
    add = 7
    nr2 = 0x12345671
    for ch in password:
        if ch == ' ' or ch == '\t':
            continue
        tmp = ord(ch)
        # '+' binds tighter than '^': nr ^= (((nr & 63) + add) * tmp + (nr << 8))
        nr = (nr ^ ((((nr & 63) + add) * tmp) + (nr << 8))) & mask32
        nr2 = (nr2 + ((nr2 << 8) ^ nr)) & mask32
        add = add + tmp
    mask31 = (1 << 31) - 1
    return (nr & mask31, nr2 & mask31)

def scramble (message, password):
    """Compute the challenge response for *message* and *password*.

    *message* is the challenge string received from the server.  Returns a
    list of integer character codes, one per character of *message*, or an
    empty list when *password* is empty.
    """
    # MySQL's password.c scramble() only produces a response when a password
    # is set. An account without a password expects an empty response, not
    # bytes derived from hashing the empty string.
    if not password:
        return []
    hash_pass = hash_password (password)
    hash_mess = hash_password (message)
    r = random_state (
        hash_pass[0] ^ hash_mess[0],
        hash_pass[1] ^ hash_mess[1]
        )
    # one pseudo-random code in 64..94 per challenge character
    to = [int (math.floor ((r.rnd() * 31) + 64)) for _ in message]
    # one further draw is XORed into every code
    extra = int (math.floor (r.rnd() * 31))
    return [code ^ extra for code in to]

# ===========================================================================
#                           Packet Protocol
# ===========================================================================

def unpacket (p):
    """Parse a packet header and return (payload_length, sequence_number).

    Only the first four characters of *p* are examined.  They hold a 3-byte
    little-endian length followed by a 1-byte sequence number.
    """
    lo, mid, hi, seq = [ord(ch) for ch in p[:4]]
    return (hi << 16) | (mid << 8) | lo, seq

def packet (data, s=0):
    """Frame *data* as a protocol packet with sequence number *s*.

    The 4-byte header is the payload length as a 3-byte little-endian
    integer, followed by the sequence byte.

    Raises ValueError if *data* is too long for the 3-byte length field.
    """
    l = len(data)
    if l > 0xffffff:
        # Masking the length would silently desynchronise the stream,
        # because the peer would read the wrong number of payload bytes.
        raise ValueError ("packet payload too large: %d bytes" % l)
    header = [chr (l & 0xff), chr ((l >> 8) & 0xff), chr ((l >> 16) & 0xff), chr (s)]
    return ''.join (header) + data

def n_byte_num (data, n):
    """Decode the first *n* characters of *data* as an unsigned little-endian integer."""
    value = 0
    # walk from the most significant byte down, shifting the accumulator each step
    for idx in range(n - 1, -1, -1):
        value = (value << 8) | ord(data[idx])
    return value

def net_field_length (data, pos=0):
    """Decode a length-coded integer whose prefix byte is data[pos].

    Returns (value, consumed), where *consumed* is the number of bytes used
    by the encoding, including the prefix byte.  A prefix of 251 yields
    (None, 1).
    """
    n = ord(data[pos])
    if n < 251:
        return n, 1
    elif n == 251:
        return None, 1
    # Multi-byte forms: the little-endian value follows the prefix byte.
    # Decoding from data[0] would fold the prefix itself into the value and
    # ignore *pos*.
    body = data[pos+1:]
    if n == 252:
        return n_byte_num (body, 2), 3
    elif n == 253:
        return n_byte_num (body, 3), 4
    else:
        # NOTE(review): original author noted "libmysql adds 6, why?" here;
        # this keeps reporting 5 consumed bytes — confirm against the
        # protocol version in use.
        return n_byte_num (body, 4), 5

# hex/character dump helper used by mysql_client's debug output
def dump_hex (s):
    """Return a two-line (hex, characters) dump of *s* for debug output.

    The first string holds ' %02x' for each character of *s*.  The second
    string shows ASCII letters and digits under their hex codes and blanks
    every other character.
    """
    # ascii_letters rather than string.letters: the latter is locale-dependent
    # (and absent in Python 3), so the dump varied between environments.
    visible = string.ascii_letters + string.digits
    hex_cols = []
    char_cols = []
    for ch in s:
        hex_cols.append (' %02x' % ord(ch))
        if ch in visible:
            char_cols.append ('  %c' % ch)
        else:
            char_cols.append ('   ')
    return ''.join (hex_cols), ''.join (char_cols)

class mysql_client:
    """Minimal MySQL client speaking the 3.21-era wire protocol.

    The socket object is supplied by the caller and used only through
    connect/recv/send, so this class does not care whether it is a
    blocking or a coroutine socket.
    """

    # number of clients created so far; tags each client's printed rows
    counter = 0

    def __init__ (self, socket, username, password, address=('127.0.0.1', 3306)):
        # note: the 'socket' parameter shadows the socket module inside this method
        self.socket = socket
        self.username = username
        self.password = password
        self.address = address
        mysql_client.counter = mysql_client.counter + 1
        self.client_number = mysql_client.counter

    def connect (self):
        """Connect the underlying socket to self.address."""
        self.socket.connect (self.address)

    def read (self, n):
        """Read exactly *n* bytes from the socket.

        Raises MySQLError if the peer closes the connection first.
        """
        blocks = []
        while n > 0:
            data = self.socket.recv (n)
            if not data:
                raise MySQLError ("connection closed unexpectedly")
            blocks.append (data)
            n = n - len(data)
        return ''.join (blocks)

    def write (self, data):
        """Send all of *data* and return its length.

        Raises MySQLError if send() reports that nothing was written.
        """
        total = len(data)
        while data:
            n = self.socket.send (data)
            if not n:
                raise MySQLError ("error sending data")
            data = data[n:]
        return total

    # set to a true value to dump every packet sent and received
    debug = 0

    def _dump (self, label, data):
        # print a labelled hex/character dump of a packet payload
        hex_line, char_line = dump_hex (data)
        print (label)
        print (hex_line)
        print (char_line)

    def send_packet (self, data, sequence=0):
        """Frame *data* with sequence number *sequence* and send it."""
        if self.debug:
            self._dump ('--> %03d' % sequence, data)
        self.write (packet (data, sequence))

    def read_packet (self):
        """Read one packet and return (sequence_number, payload)."""
        length, seq = unpacket (self.read (4))
        data = self.read (length)
        if self.debug:
            self._dump ('<-- %03d' % seq, data)
        return seq, data

    def login (self):
        """Perform the challenge/response handshake.

        Reads the server greeting and stores its protocol_version,
        server_version and thread_id on the instance.  Sends the scrambled
        password, then raises MySQLError unless the reply is
        sequence 2 with payload '\\000\\000\\000'.
        """
        seq, data = self.read_packet()
        if self.debug:
            print ('login packet? %r %r' % (seq, data))
        # greeting layout as parsed here: one version byte, a NUL-terminated
        # version string, a 4-byte thread id, then the 8-byte challenge
        self.protocol_version = ord(data[0])
        eos = data.find ('\000')
        self.server_version = data[1:eos]
        self.thread_id = n_byte_num (data[eos+1:eos+5], 4)
        challenge = data[eos+5:eos+13]
        lp = self.build_login_packet (challenge)
        # seems to require a sequence number of one
        self.send_packet (lp, 1)
        seq, data = self.read_packet()
        if (seq, data) != (2, '\000\000\000'):
            raise MySQLError ("login failed: %s" % repr(data))

    def build_login_packet (self, challenge):
        """Build the login payload: fixed prefix, username, NUL, scrambled password."""
        auth = ''.join (map (chr, scramble (challenge, self.password)))
        # 2 bytes of client_capability
        # 3 bytes of max_allowed_packet
        # no idea what they are
        return '\005\000\000\000\020' + self.username + '\000' + auth

    # command codes from mysql-3.21.33/include/mysql_com.h.in; a name's
    # position in this list is the code byte sent on the wire
    cmds = [
        'sleep', 'quit', 'init_db', 'query', 'field_list', 'create_db',
        'drop_db', 'refresh', 'shutdown', 'statistics', 'process_info',
        'connect', 'process_kill', 'debug'
        ]
    cmds = dict (zip (cmds, range (len (cmds))))

    def command (self, command_type, command):
        """Send one command packet: the command's code byte followed by *command*."""
        self.send_packet (chr (self.cmds[command_type]) + command, 0)

    def cmd_use (self, database):
        """Select *database*; raise MySQLError unless the reply is '\\000\\000\\000'."""
        self.command ('init_db', database)
        seq, data = self.read_packet()
        if data != '\000\000\000':
            raise MySQLError (repr(data))

    def _read_until_eof (self):
        # collect packet payloads up to (not including) the one-byte 0xfe marker
        packets = []
        while 1:
            seq, data = self.read_packet()
            if data == chr(0xfe):
                return packets
            packets.append (data)

    def cmd_query (self, query):
        """Run *query*, print each result row, and return the raw row payloads.

        Returns an empty list when the reply carries no result set.  Raises
        MySQLError on an error reply or a field-count mismatch.
        """
        self.command ('query', query)
        seq, data = self.read_packet()
        if data[:1] == chr(0xff):
            # Error reply: no field/row packets follow, so reading them
            # would block forever.
            raise MySQLError ("query failed: %s" % repr(data))
        # the field count is a length-coded integer, not a single byte
        nfields = net_field_length (data)[0]
        if not nfields:
            # Reply without a result set (field count 0): nothing follows.
            return []
        fields = self._read_until_eof()
        if len(fields) != nfields:
            raise MySQLError ("number of fields didn't match")
        rows = self._read_until_eof()
        for i, row in enumerate (rows):
            print ('%03d %03d %s' % (self.client_number, i, repr(row)))
        return rows

    def cmd_quit (self):
        """Send the quit command; no reply is read."""
        self.command ('quit', '')
        # no reply?

def test (s):
    """Demo client body: log in over socket *s*, run a query, and quit.

    Ordinary errors are reported with a traceback instead of propagating.
    """
    try:
        # the credentials and server address look like placeholders — edit before running
        c = mysql_client (s, 'username', 'password', ('1.2.3.4', 3306))
        print ('connecting...')
        c.connect()
        print ('logging in...')
        c.login()
        print (c)
        c.cmd_use ('mysql')
        c.cmd_query ('select * from host')
        c.cmd_quit()
        # 'coroutine' is the module imported by this file's __main__ block
        coroutine.main (None)
    except Exception:
        # Catch ordinary errors only. A bare except would also swallow
        # SystemExit / KeyboardInterrupt.
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    try:
        # Imported at module level so that test() can use the 'coroutine' global.
        import coroutine
        import corosock

        # schedule ten clients, each running test() on its own coroutine socket
        for _ in range(10):
            sock = corosock.coroutine_socket ()
            sock.create_socket (socket.AF_INET, socket.SOCK_STREAM)
            corosock.schedule (coroutine.new (test), (sock,))

        # presumably runs the scheduler with a 30-second timeout — confirm in corosock
        corosock.event_loop (30.0)
    except Exception:
        # Let SystemExit / KeyboardInterrupt propagate instead of swallowing them.
        import traceback
        traceback.print_exc()
