# -*- Mode: Python; tab-width: 4 -*- # Note: this should be split into two modules, one which is unaware of # the distinction between blocking and coroutine sockets. # Strategies: # 1) Simple; schedule at the socket level. # Just like the mysql client library, when a coroutine accesses # the mysql client object, it will automatically detach when # the socket gets EWOULDBLOCK. # 2) Smart; schedule at the request level. # Use a separate coroutine to manage the mysql connection. # A client coroutine will resume when a response is available. # 3) Sophisticated; schedule at the row level. # Allow a client coroutine to peel off rows one at a time. # # Going with #1 for now, for maximum compatibility with blocking socket. import exceptions import math import socket import string import sys class MySQLError (exceptions.Exception): pass # =========================================================================== # Authentication # =========================================================================== # Note: I've ignored the stuff to support an older version of the protocol. # # The code is based on the file mysql-3.21.33/client/password.c # # The auth scheme is challenge/response. Upon connection the server # sends an 8-byte challenge message. This is hashed with the password # to produce an 8-byte response. The server side performs an identical # hash to verify the password is correct. class random_state: def __init__ (self, seed, seed2): self.max_value = 0x3FFFFFFF self.seed = seed % self.max_value self.seed2 = seed2 % self.max_value def rnd (self): self.seed = (self.seed * 3 + self.seed2) % self.max_value self.seed2 = (self.seed + self.seed2 + 33) % self.max_value return float(self.seed)/ float(self.max_value) def hash_password (password): nr=1345345333L add=7 nr2=0x12345671L for ch in password: if (ch == ' ') or (ch == '\t'): continue tmp = ord(ch) nr = nr ^ (((nr & 63) + add) * tmp) + (nr << 8) nr2 = nr2 + ((nr2 << 8) ^ nr) add = add + tmp return ( nr & ((1L<<31)-1L), nr2 & ((1L<<31)-1L) ) def scramble (message, password): hash_pass = hash_password (password) hash_mess = hash_password (message) r = random_state ( hash_pass[0] ^ hash_mess[0], hash_pass[1] ^ hash_mess[1] ) to = [] for ch in message: to.append (int (math.floor ((r.rnd() * 31) + 64))) extra = int (math.floor (r.rnd()*31)) for i in range(len(to)): to[i] = to[i] ^ extra return to # =========================================================================== # Packet Protocol # =========================================================================== def unpacket (p): # 3-byte length, one-byte packet number, followed by packet data a,b,c,s = map (ord, p[:4]) l = a | (b << 8) | (c << 16) # s is a sequence number return l, s def packet (data, s=0): l = len(data) a, b, c = l & 0xff, (l>>8) & 0xff, (l>>16) & 0xff h = map (chr, [a,b,c,s]) return string.join (h,'') + data def n_byte_num (data, n): result = 0 for i in range(n): result = result | (ord(data[i])<<(8*i)) return result def net_field_length (data, pos=0): n = ord(data[pos]) if n < 251: return n, 1 elif n == 251: return None, 1 elif n == 252: return n_byte_num (data, 2), 3 elif n == 253: return n_byte_num (data, 3), 4 else: # libmysql adds 6, why? return n_byte_num (data, 4), 5 # used to generate the dumps below def dump_hex (s): r1 = [] r2 = [] for ch in s: r1.append (' %02x' % ord(ch)) if (ch in string.letters) or (ch in string.digits): r2.append (' %c' % ch) else: r2.append (' ') return string.join (r1, ''), string.join (r2, '') class mysql_client: counter = 0 def __init__ (self, socket, username, password, address=('127.0.0.1', 3306)): self.socket = socket self.username = username self.password = password self.address = address mysql_client.counter = mysql_client.counter + 1 self.client_number = mysql_client.counter def connect (self): self.socket.connect (self.address) def read (self, n): blocks = [] while n: data = self.socket.recv (n) if not data: raise MySQLError, "connection closed unexpectedly" else: blocks.append (data) n = n - len(data) return string.join (blocks, '') def write (self, data): ln = len(data) while data: n = self.socket.send (data) if not n: raise MySQLError, "error sending data" else: data = data[n:] return ln debug = 0 def send_packet (self, data, sequence=0): if self.debug: print '--> %03d' % sequence a, b = dump_hex (data) print a print b self.write (packet (data, sequence)) def read_packet (self): header = self.read (4) # 3-byte length, one-byte packet number, followed by packet data a,b,c,s = map (ord, header) l = a | (b << 8) | (c << 16) # l is length, s is a sequence number data = self.read (l) if self.debug: a,b = dump_hex (data) print '<-- %03d' % s print a print b return s, data def login (self): seq, data = self.read_packet() print 'login packet?', seq, data # unpack the greeting protocol_version = ord(data[0]) eos = string.find (data, '\000') mysql_version = data[1:eos] thread_id = n_byte_num (data[eos+1:eos+5], 4) challenge = data[eos+5:eos+13] auth = ( protocol_version, mysql_version, thread_id, challenge ) lp = self.build_login_packet (challenge) # seems to require a sequence number of one self.send_packet (lp, 1) seq, data = self.read_packet() if (seq, data) != (2, '\000\000\000'): raise MySQLError, "login failed: %s" % repr(data) def build_login_packet (self, challenge): auth = string.join (map (chr, scramble (challenge, self.password)), '') # 2 bytes of client_capability # 3 bytes of max_allowed_packet # no idea what they are return '\005\000\000\000\020' + self.username + '\000' + auth # from mysql-3.21.33/include/mysql_com.h.in # cmds = [ 'sleep', 'quit', 'init_db', 'query', 'field_list', 'create_db', 'drop_db', 'refresh', 'shutdown', 'statistics', 'process_info', 'connect', 'process_kill', 'debug' ] d = {} for i in range (len (cmds)): d[cmds[i]] = i cmds = d del d def command (self, command_type, command): q = chr(self.cmds[command_type]) + command self.send_packet (q, 0) def cmd_use (self, database): self.command ('init_db', database) seq, data = self.read_packet() if data != '\000\000\000': raise MySQLError, repr(data) def cmd_query (self, query): self.command ('query', query) # read fields seq, data = self.read_packet() nfields = ord(data[0]) fields = [] while 1: seq, data = self.read_packet() if data == chr(0xfe): break else: fields.append (data) if len(fields) != nfields: raise MySQLError, "number of fields didn't match" # read rows rows = [] while 1: seq, data = self.read_packet() if data == chr(0xfe): break else: rows.append (data) for i in range(len(rows)): print '%03d %03d %s' % (self.client_number, i, repr(rows[i])) def cmd_quit (self): self.command ('quit', '') # no reply? def test (s): try: c = mysql_client (s, 'rushing', 'fnord', ('192.168.200.6', 3306)) print 'connecting...' c.connect() print 'logging in...' c.login() print c c.cmd_use ('mysql') c.cmd_query ('select * from host') c.cmd_quit() coroutine.main (None) except: import traceback traceback.print_exc() if __name__ == '__main__': try: import coroutine import corosock for i in range(10): s = corosock.coroutine_socket () s.create_socket (socket.AF_INET, socket.SOCK_STREAM) c = coroutine.new (test) corosock.schedule (c, (s,)) corosock.event_loop (30.0) except: import traceback traceback.print_exc()