# -*- Mode: Python -*-

import coro
import struct
import hashlib
import os

# why: using scp to copy files upstream via my cable modem (comcast) is *horribly* slow.
#  --- example timings to upload a 42MB file ---
#       scp: 284s
# pyftpdlib: 141s
#  udp_xfer:  39s
#
# note: there is *zero* security in this implementation.
#
# usage:
#
#  -- server side --
# python udp_xfer.py -s
#
#  -- client side --
# python udp_xfer.py -c <host-ip> <file-to-send>
#

# shorthand used throughout for progress/diagnostic output (coro.write_stderr)
W = coro.write_stderr

# goal: reliable packet transfer
# non-goal: ordered delivery
# 
# we want to be able to send an arbitrarily-sized object.
# [let's put that in the next layer up]
#
# I think we want at least two layers: one layer is responsible for MTU-sized packets,
#   and manages retransmits and acks.  The second layer builds a reliable transfer on
#   top of the first layer.
#
# we want to push a certain number of packets ahead before
#  waiting for an ack.

# Size of the largest datagram we send, IP/UDP headers included (send_file
# subtracts 28 bytes for them).  1460 is used instead of Ethernet's usual 1500.
# To probe a path, sweep ping sizes 1400..1600 in steps of 20 with the
# don't-fragment bit set (BSD/macOS ping syntax):
#   ping -D -g 1400 -G 1600 -h 20 <server-host>
default_mtu = 1460
# server bind address: every local interface, UDP port 14583
default_addr = ('', 14583)

class PacketTooBig (Exception):
    """Raised by client.send() when a payload is longer than the client's MTU."""

def strip_path (path):
    """Return only the final component of <path>, dropping any directories."""
    return os.path.basename (path)

class server:
    """Receiving side of the reliable packet layer.

    Each incoming datagram is a 32-bit big-endian sequence number followed by
    a payload.  The payload is passed to handle_packet() (a no-op here; meant
    to be overridden) and the sequence number is queued; ack_thread batches
    the queued numbers into ack packets sent back to the peer.
    """

    def __init__ (self, addr=default_addr, mtu=default_mtu):
        self.mtu = mtu
        self.sock = coro.udp_sock()
        self.sock.bind (addr)
        self.pending_acks = coro.fifo()
        self.count = 0
        self.total = 0
        # peer that acks go to; recv_thread sets it from each datagram.
        self.who = None
        # debugging history: every sequence number received / acked.
        self.seq_rcvd = []
        self.acks_sent = []
        # start the worker threads only once all the state they use exists.
        coro.spawn (self.ack_thread)
        coro.spawn (self.recv_thread)

    # the protocol is simple - data goes in one direction, acks in the other
    #  acks can be collected together [though a timer here for delayed ack might help?]
    #  an ack packet is just a set of 32-bit sequence numbers.

    def recv_thread (self):
        """Receive datagrams forever, dispatching payloads and queueing acks."""
        while 1:
            what, who = self.sock.recvfrom (self.mtu * 2)
            if len (what) < 4:
                # too short to hold a sequence number; unpacking it would
                #  raise and kill this thread.
                W ('ignoring runt packet (%d bytes) from %r\n' % (len (what), who))
                continue
            self.count += 1
            if self.count % 100 == 0:
                W ('.')
            # assumes only one <who> right now
            self.who = who
            seq, = struct.unpack ('>L', what[:4])
            self.seq_rcvd.append (seq)
            self.handle_packet (what[4:])
            self.pending_acks.push (seq)

    def ack_thread (self):
        """Every 20ms, ack all sequence numbers received since the last round."""
        while 1:
            # wait to gather up a few to ack
            coro.sleep_relative (0.02)
            acks = list (self.pending_acks.pop_all())
            if not acks:
                continue
            self.acks_sent.extend (acks)
            for packet in self._ack_packets (acks):
                self.sock.sendto (packet, self.who)

    def _ack_packets (self, acks):
        """Pack the sequence numbers <acks> into packets of at most self.mtu bytes.

        The client reads acks with recv(mtu); a longer datagram would be
        truncated there and the acks past the cut silently lost.  (This
        assumes the client uses an MTU no smaller than ours.)
        """
        per_packet = max (1, self.mtu // 4)
        packets = []
        for start in range (0, len (acks), per_packet):
            chunk = acks[start:start + per_packet]
            packets.append (struct.pack ('>%dL' % (len (chunk),), *chunk))
        return packets

    def handle_packet (self, data):
        """Handle one payload (sequence number already stripped); override me."""
        pass

class file_blocks:
    """Reassemble one incoming file from blocks that may arrive out of order.

    Blocks are held in self.blocks until every earlier block has arrived,
    then written to disk in index order while a running SHA-256 is kept over
    the written data.  finish() runs exactly once, when the block count is
    known (last_block) and every block up to it has been written.
    """

    def __init__ (self, name):
        # only the final path component of the sender-supplied name is used,
        #  both for reporting and for the file actually created.
        self.name = strip_path (name)
        self.file = open (self.name, 'wb')
        # index -> data for blocks not yet written; written blocks stay in the
        #  dict as None so that a retransmitted duplicate is still recognized.
        self.blocks = {}
        self.bytes = 0
        self.nblocks = None     # total block count, once known
        self.start = coro.now
        self.next = 0           # index of the next block to write
        self.h = hashlib.sha256()
        self.wrote = 0
        self.finished = False

    def block (self, index, data):
        """Record block <index>; a duplicate of a block already seen is ignored."""
        if index in self.blocks:
            # duplicate
            return
        self.blocks[index] = data
        self.bytes += len (data)
        self.maybe_write()

    def maybe_write (self):
        """Write every consecutive block, starting at self.next, that has arrived."""
        while not self.finished and self.next in self.blocks:
            block = self.blocks[self.next]
            self.file.write (block)
            self.wrote += 1
            self.h.update (block)
            # drop the data but keep the key so duplicates are still detected
            self.blocks[self.next] = None
            self.next += 1
            self._maybe_finish()

    def last_block (self, index):
        """Record that the file consists of <index> blocks (0 .. index-1)."""
        self.nblocks = index
        # the block count is sent after the final data block, so usually every
        #  block has already been written by now - completion must be checked
        #  here as well as in maybe_write (this also covers an empty file).
        self._maybe_finish()

    def _maybe_finish (self):
        # complete the file exactly once, as soon as all blocks are written.
        if (not self.finished
                and self.nblocks is not None
                and self.next == self.nblocks):
            self.finished = True
            self.finish()

    def finish (self):
        """Report transfer statistics and the SHA-256, then close the file."""
        elapsed = coro.now - self.start
        elapsed_sec = float (elapsed) / coro.ticks_per_sec
        if elapsed_sec > 0:
            rate = self.bytes / elapsed_sec
        else:
            # guard against a zero-length interval
            rate = 0.0
        W ('\ndone with file %r (%s) elapsed=%.2f %d bytes/sec\n' % (
            self.name, self.bytes, elapsed_sec, int(rate)
            ))
        W ('sha256: %s\n' % (self.h.hexdigest(),))
        W ('wrote %d blocks\n' % (self.wrote,))
        self.file.close()

class file_receiver (server):
    """server subclass that reassembles files sent by send_file().

    Payloads (after the sequence number stripped by server), big-endian:
      'N' <file_id:H> <name_len:H> <name>   start a new file
      'b' <file_id:H> <index:H> <data>      data block, 16-bit index
      'B' <file_id:H> <index:L> <data>      data block, 32-bit index
      'd' <file_id:H> <nblocks:L>           number of blocks in the file
      'D'                                   transfer finished
    """

    def __init__ (self, addr=default_addr):
        server.__init__ (self, addr)
        # file_id -> file_blocks
        self.files = {}
        # file_id -> [(kind, index, data)] for block/last packets that arrived
        #  before that file's 'N' packet (e.g. the 'N' was lost and is being
        #  retransmitted).  They have already been acked, so they must be kept.
        self.orphans = {}

    def handle_packet (self, data):
        """Dispatch one payload on its one-character type tag."""
        kind = data[0]
        if kind == 'N':
            # new file
            file_id, name_len = struct.unpack ('>HH', data[1:5])
            name = data[5:5+name_len]
            if file_id in self.files:
                # a retransmitted 'N' whose ack was lost: re-creating the
                #  file_blocks would truncate the file and discard progress.
                W ('duplicate new-file packet id=%d\n' % (file_id,))
                return
            self.files[file_id] = file_blocks (name)
            W ('new file id=%d name=%r who=%r\n' % (file_id, name, self.who))
            # replay whatever arrived before we knew about this file
            for which, index, payload in self.orphans.pop (file_id, []):
                if which == 'b':
                    self.handle_block (file_id, index, payload)
                else:
                    self.handle_last (file_id, index)
        elif kind == 'b':
            # small-index data block
            file_id, index, = struct.unpack ('>HH', data[1:5])
            self.handle_block (file_id, index, data[5:])
        elif kind == 'B':
            # big-index data block
            file_id, index, = struct.unpack ('>HL', data[1:7])
            self.handle_block (file_id, index, data[7:])
        elif kind == 'd':
            file_id, index = struct.unpack ('>HL', data[1:7])
            self.handle_last (file_id, index)
        elif kind == 'D':
            self.handle_done()
        else:
            raise ValueError ("unknown packet type", data)

    def handle_block (self, file_id, index, data):
        """Hand a data block to its file, or hold it until the file is announced."""
        probe = self.files.get (file_id)
        if probe is not None:
            probe.block (index, data)
        else:
            W ('orphaned file_id? id=%d index=%d\n' % (file_id, index))
            self.orphans.setdefault (file_id, []).append (('b', index, data))

    def handle_last (self, file_id, index):
        """Pass the block count to its file, or hold it until the file is announced."""
        probe = self.files.get (file_id)
        if probe is not None:
            probe.last_block (index)
        else:
            W ('orphaned file_id? id=%d index=%d\n' % (file_id, index))
            self.orphans.setdefault (file_id, []).append (('d', index, None))

    def handle_done (self):
        # for now, we exit after one file.
        coro.spawn (coro.set_exit)

class XferStats:
    """Running sender counters: packets sent, bytes sent, retransmits."""

    def __init__ (self):
        self.sent_p = self.sent_b = self.rxmit = 0
        self.last = None

    def sample (self):
        """Return a new XferStats holding a snapshot of the three counters."""
        snap = XferStats()
        snap.sent_p, snap.sent_b, snap.rxmit = self.sent_p, self.sent_b, self.rxmit
        return snap

    def delta (self, other):
        """Return (packets, bytes, retransmits) accumulated since snapshot <other>."""
        return tuple (
            getattr (self, field) - getattr (other, field)
            for field in ('sent_p', 'sent_b', 'rxmit')
            )

class client:
    """Sending side of the reliable packet layer.

    send() queues a payload; send_thread prefixes each one with a 32-bit
    big-endian sequence number, transmits it, and keeps it in self.pending
    until ack_thread sees that number acked.  timer_thread re-queues packets
    unacked for longer than retransmit_limit (they go out again under a new
    sequence number), and stats_thread prints throughput once a second.
    The window semaphore caps how many send()-queued packets are outstanding.
    """

    # how long do we wait before retransmitting?  (a fifth of a second, in ticks)
    retransmit_limit = coro.ticks_per_sec / 5

    def __init__ (self, addr=('localhost', 14583), window_size=10, mtu=default_mtu):
        self.addr = addr
        self.mtu = mtu
        self.sock = coro.udp_sock()
        self.sock.connect (addr)
        self.window = coro.semaphore (window_size)
        # (send time, seq, payload) for each transmitted, not-yet-acked packet,
        #  in transmission order (timer_thread relies on that order).
        self.pending = []
        self.tosend = coro.fifo()
        # packets pushed onto self.tosend but not yet taken by send_thread
        self.unsent = 0
        self.resends = 0
        self.late_acks = []
        # set by the producer (send_file) once its last packet is queued
        self.done = False
        # set once the final 'D' packet has been queued, so it is sent only once
        self.finishing = False
        self.stats = XferStats()
        coro.spawn (self.ack_thread)
        coro.spawn (self.send_thread)
        coro.spawn (self.timer_thread)
        coro.spawn (self.stats_thread)

    def ack_thread (self):
        """Receive ack packets and retire the packets they acknowledge.

        Once the producer is done and nothing is queued or unacked, queue the
        'D' packet telling the server to shut down and exit this process.
        """
        try:
            while 1:
                packet = self.sock.recv (self.mtu)
                # an ack packet is a run of 32-bit big-endian sequence numbers;
                #  a trailing partial number is ignored.
                nack = len (packet) // 4
                acks = set (struct.unpack ('>%dL' % (nack,), packet[:nack * 4]))
                # sequence numbers are unique, so each ack retires at most one entry
                before = len (self.pending)
                self.pending[:] = [entry for entry in self.pending if entry[1] not in acks]
                n = before - len (self.pending)
                self.window.release (n)
                # tell the server to shut down, it's all good here.  A packet
                #  that is queued but not yet transmitted (the final 'd', or a
                #  retransmit) is not in self.pending, hence the unsent check.
                if (self.done and not self.pending
                        and not self.unsent and not self.finishing):
                    self.finishing = True
                    self.send (struct.pack ('>c', 'D'))
                    coro.spawn (coro.set_exit)
        except OSError:
            if not self.done:
                raise

    def retransmit (self, packet):
        """Queue <packet> to go out again (it keeps the window slot it already holds)."""
        self.unsent += 1
        self.tosend.push (packet)
        self.resends += 1
        self.stats.rxmit += 1

    def send_thread (self):
        """Transmit queued packets, tagging each with the next sequence number."""
        seq = 0
        try:
            while 1:
                packet = self.tosend.pop()
                self.unsent -= 1
                self.pending.append ((coro.get_now(), seq, packet))
                p0 = struct.pack ('>L', seq) + packet
                self.sock.send (p0)
                self.stats.sent_p += 1
                self.stats.sent_b += len (p0)
                seq += 1
        except OSError:
            if not self.done:
                raise

    def timer_thread (self):
        """Every 100ms, re-queue packets unacked for longer than retransmit_limit."""
        while 1:
            coro.sleep_relative (0.1)
            now = coro.now
            # self.pending is in send order, so stop at the first fresh entry
            while self.pending:
                t, seq, p = self.pending[0]
                if now - t <= self.retransmit_limit:
                    break
                self.late_acks.append (seq)
                del self.pending[0]
                self.retransmit (p)

    def stats_thread (self):
        """Once a second, print packets/bytes/retransmits sent and the send rate."""
        last = self.stats.sample()
        now0 = coro.now
        lines = 0
        while 1:
            coro.sleep_relative (1.0)
            now = coro.now
            dtime = float (now - now0) / coro.ticks_per_sec
            sent_p, sent_b, rxmit = self.stats.delta (last)
            # megabits per second over the interval actually slept, which
            #  can be longer than the nominal one second.
            if dtime > 0:
                rate = (sent_b * 8) / 1000000.0 / dtime
            else:
                rate = 0.0
            if (lines % 30) == 0:
                W ('%10s %10s %10s %10s\n' % ('packets', 'bytes', 'rxmit', 'rate(Mb/s)'))
            W ('%10d %10d %10d %8.2f\n' % (sent_p, sent_b, rxmit, rate))
            last = self.stats.sample()
            lines += 1
            now0 = now

    def send (self, packet):
        """Queue <packet> for reliable delivery, blocking while the window is full.

        Raises PacketTooBig if <packet> is longer than self.mtu.
        """
        if len(packet) > self.mtu:
            raise PacketTooBig
        self.window.acquire (1)
        self.unsent += 1
        self.tosend.push (packet)

file_id_counter = 0

def send_file (client, name, file):
    """Send the contents of <file> through <client> under the basename of <name>.

    Queues a 'N' (new file) packet, then one data block per packet ('b' with
    a 16-bit index for the first 16384 blocks, 'B' with a 32-bit index after
    that), then a 'd' packet carrying the block count.  Sets client.done once
    everything is queued and reports the byte count and SHA-256.
    """
    global file_id_counter
    limit = client.mtu
    base = strip_path (name)
    file_id, file_id_counter = file_id_counter, file_id_counter + 1
    # start a new file transfer
    client.send (struct.pack ('>cHH', 'N', file_id, len(base)) + base)
    count = 0
    digest = hashlib.sha256()
    total = 0
    while True:
        # per-packet overhead: 28 bytes IP/UDP + 4 bytes sequence number
        #  + 5 ('b') or 7 ('B') bytes of block header.
        if count >= 16384:
            overhead = 28 + 4 + 7
            header = struct.pack ('>cHL', 'B', file_id, count)
        else:
            overhead = 28 + 4 + 5
            header = struct.pack ('>cHH', 'b', file_id, count)
        chunk = file.read (limit - overhead)
        digest.update (chunk)
        if not chunk:
            break
        client.send (header + chunk)
        total += len(chunk)
        count += 1
    client.send (struct.pack ('>cHL', 'd', file_id, count))
    client.done = True
    W ('\nsent %d bytes in %d blocks\n' % (total, count))
    W ('sha256: %s\n' % (digest.hexdigest(),))

if __name__ == '__main__':
    import sys
    usage = (
        'usage: %s -s\n'
        '       %s -c <host-ip> <file-to-send> [<mtu>]\n' % (sys.argv[0], sys.argv[0])
        )
    # the client talks to the port the server binds by default
    port = default_addr[1]
    if '-s' in sys.argv:
        s = file_receiver()
        bd = '/tmp/xs.bd'
    elif len (sys.argv) >= 4 and sys.argv[1] == '-c':
        ip = sys.argv[2]
        path = sys.argv[3]
        if len(sys.argv) > 4:
            mtu = int (sys.argv[4])
        else:
            mtu = default_mtu
        c = client ((ip, port), 50, mtu)
        coro.spawn (send_file, c, path, open (path, 'rb'))
        bd = '/tmp/xc.bd'
    else:
        # no role given (or -c without host/file): nothing would be started,
        #  and indexing the missing arguments would raise IndexError.
        sys.stderr.write (usage)
        sys.exit (2)
    coro.set_selfishness (0)
    #import coro.backdoor
    #coro.spawn (coro.backdoor.serve, unix_path=bd)
    coro.event_loop()
