# -*- Mode: Python -*-

# See ISO/IEC 14496-12

# the overall structure is very simple, a (usually) 32-bit size field,
#   followed by a (usually) four-byte 'atom' or 'box' identifier,
#   followed by the data, which may contain more atoms/boxes. (i.e., a
#   tree is defined)
#
# next step: find the reference for the 'avc1' box type.

import struct
import sys

def as_binary (n, width):
    """Return the low <width> bits of <n> as a string of '0'/'1' characters.

    The most significant of the <width> bits comes first; bits of <n>
    above <width> are ignored.  A width of zero gives the empty string.
    """
    # start one position above the top bit; it is shifted down before each test
    probe = 1 << width
    digits = []
    for _ in range (width):
        probe >>= 1
        digits.append ('1' if n & probe else '0')
    return ''.join (digits)

def get_bits (x, start, nbits):
    """Extract <nbits> bits of <x>, beginning at bit <start> (bit 0 is the LSB).

    The extracted field is returned right-aligned, i.e. bit <start> of <x>
    becomes bit 0 of the result.
    """
    # every contributing term is a distinct power of two, so summing them
    # is the same as or-ing them together.
    return sum (
        1 << offset
        for offset in range (nbits)
        if x & (1 << (start + offset))
        )

class atom:
    """Bare record object representing one box ('atom') of the file.

    Only the box type is stored at construction time, under the
    underscore-prefixed attribute <_name>.
    """
    def __init__ (self, name):
        # leading underscore: keeps the type out of the public attributes
        self._name = name

def dump_box (b, depth):
    """Print every public attribute of <b>, one 'name value' pair per line.

    Each line is indented by two spaces per <depth> level, followed by a
    single separating space.  Attributes whose name starts with an
    underscore are skipped.  Names are printed in the order dir() returns
    them.
    """
    indent = '  ' * depth
    for attr in dir (b):
        if not attr.startswith ('_'):
            # a single pre-formatted string prints the same text as the
            # statement form "print indent, attr, value"
            print ('%s %s %s' % (indent, attr, getattr (b, attr)))

class parser:
    """Walk the box tree of an ISO base media file and print what is found.

    <file> is a binary file object; read() is always needed, and tell()
    is used by the container parsers to count consumed bytes.  Each box is
    a 32-bit big-endian size, a four-byte type, then the payload.  When a
    method named parse_<type> exists it is called as
    parse_<type> (box, size, depth), where <size> is the payload size (the
    8 header bytes already subtracted); it stores the decoded fields as
    attributes of <box>.  Boxes without such a method are dumped raw when
    small and skipped otherwise.
    """

    def __init__ (self, file):
        self.file = file
        # types of the boxes currently being parsed, outermost first
        self.context = []
        # handler_type of the most recently parsed 'hdlr' box; parse_stsd
        # uses it to decide what kind of sample entries to expect
        self.current_handler_type = None

    def get (self, nbytes):
        """Read <nbytes> bytes from the file.

        Raises EOFError when nothing could be read, or when fewer than
        <nbytes> bytes were left (a file truncated in mid-field).
        """
        r = self.file.read (nbytes)
        # a short result cannot be decoded by the struct.unpack() callers,
        # so report it as end-of-file, which go() knows how to handle.
        if not r or len (r) < nbytes:
            raise EOFError
        return r

    def get16 (self):
        """Read a big-endian unsigned 16-bit integer."""
        return struct.unpack ('>H', self.get (2))[0]

    def get_16_16 (self):
        """Read two consecutive 16-bit integers and return them as a pair."""
        return self.get16(), self.get16()

    def get32 (self):
        """Read a big-endian unsigned 32-bit integer."""
        return struct.unpack ('>L', self.get (4))[0]

    def get64 (self):
        """Read a big-endian unsigned 64-bit integer."""
        return struct.unpack ('>Q', self.get (8))[0]

    # having trouble guessing what their string datatype is
    #  sometimes it appears to be a pascal string, other times
    #  nul-terminated...
    def get_string (self):
        """Read bytes up to (and consuming) a NUL byte; return them without the NUL."""
        r = []
        while 1:
            ch = self.get (1)
            if ch == '\x00':
                break
            r.append (ch)
        return ''.join (r)

    def get_pascal_string (self):
        """Read a one-byte length followed by that many bytes of string data."""
        sl = ord (self.get (1))
        if sl:
            return self.get (sl)
        else:
            return ''

    def skip (self, nbytes):
        """Consume <nbytes> bytes without keeping them.

        Reads in blocks of at most 16384 bytes.  Raises EOFError (after a
        note on stderr) if the file ends first.
        """
        while nbytes:
            block = self.file.read (min (nbytes, 16384))
            if not block:
                sys.stderr.write (' *** unexpected end of file! ***\n')
                raise EOFError
            nbytes -= len (block)

    def dump (self, depth, s):
        """Print <s> indented by two spaces per <depth> level."""
        print ('%s%s' % ('  ' * depth, s))

    def go (self):
        """Read top-level boxes until the end of the file, then print 'eof'."""
        while 1:
            try:
                self.read()
            except EOFError:
                print ('eof')
                break

    def read (self, depth=0):
        """Read one box at the current file position and print it.

        Returns None.  Raises NotImplementedError for boxes that use the
        64-bit size escape (size field == 1) or the 'uuid' type.
        NOTE: a size field of 0 is not given any special treatment here.
        """
        size = self.get32()
        if size == 1:
            # large size
            raise NotImplementedError ("64-bit box sizes are not supported")
        # read the type now
        kind = self.get (4)
        # uncount the size and type fields themselves...
        size -= 8
        if kind == 'uuid':
            raise NotImplementedError ("'uuid' boxes are not supported")
        if not size:
            # header-only box: nothing to show
            return
        probe = getattr (self, 'parse_%s' % (kind,), None)
        if probe is not None:
            print ('%s[%s %d]' % ('  ' * depth, kind, size))
            self.context.append (kind)
            b = atom (kind)
            probe (b, size, depth)
            dump_box (b, depth+1)
            self.context.pop()
        elif size <= 100:
            # small unknown box: show the raw payload
            payload = self.get (size)
            self.dump (depth, '[%s %d] %r' % (kind, size, payload))
        else:
            self.dump (depth, '[%s %d]' % (kind, size))
            self.skip (size)

    def read_all (self, box, size, depth):
        """Read child boxes until <size> bytes of payload have been consumed."""
        left = size
        # '> 0' rather than '!= 0': a child that overruns its parent must
        # not make us swallow the rest of the file as further children.
        while left > 0:
            start = self.file.tell()
            self.read (depth)
            left -= self.file.tell() - start

    def _versioned_get (self, version):
        """Return the reader for version-dependent fields.

        Version 0 boxes store those fields in 32 bits, version 1 in 64.
        Raises ValueError for any other version.
        """
        if version == 0:
            return self.get32
        elif version == 1:
            return self.get64
        else:
            raise ValueError ("unknown version number %d" % (version,))

    def parse_ftyp (self, box, size, depth):
        """Decode 'ftyp' into box.type = [major, minor, compatible brands...]."""
        major = self.get (4)
        minor, = struct.unpack ('>L', self.get (4))
        box.type = [major, minor]
        size -= 8
        # the rest of the payload is a list of four-byte brands
        while size >= 4:
            box.type.append (self.get (4))
            size -= 4
        if size > 0:
            # payload is not a whole number of brands; drop the remainder
            self.skip (size)

    def parse_mdat (self, box, size, depth):
        """Skip over the payload of an 'mdat' box."""
        self.skip (size)

    def get_version (self):
        """Read an 8-bit version and 24-bit flags; return (version, flags)."""
        version = ord (self.get (1))
        flags = 0
        # flags are three bytes, most significant first
        for i in range (3):
            flags = (flags << 8) | ord (self.get (1))
        return version, flags

    def parse_moov (self, box, size, depth):
        """'moov' is a pure container: read its children."""
        # full of other boxes
        self.read_all (box, size, depth+1)

    def parse_mvhd (self, box, size, depth):
        """Decode the fields of an 'mvhd' box into attributes of <box>."""
        box.version, box.flags = self.get_version()
        get = self._versioned_get (box.version)
        box.ctime = get()
        box.mtime = get()
        box.timescale = self.get32()
        box.duration = get()
        box.rate = self.get32()
        box.volume = self.get16()
        # 16 + 2*32 bits that are read and discarded
        self.get16()
        self.get32()
        self.get32()
        box.matrix = [self.get32() for x in range (9)]
        # six more 32-bit words that are read and discarded
        for i in range (6):
            self.get32()
        box.next_track_ID = self.get32()

    def parse_trak (self, box, size, depth):
        """'trak' is a pure container: read its children."""
        self.read_all (box, size, depth+1)

    def parse_tkhd (self, box, size, depth):
        """Decode the fields of a 'tkhd' box into attributes of <box>."""
        box.version, box.flags = self.get_version()
        get = self._versioned_get (box.version)
        box.ctime = get()
        box.mtime = get()
        box.track_ID = self.get32()
        # 32 bits read and discarded
        self.get32()
        box.duration = get()
        # reserved[2]
        self.get32()
        self.get32()
        box.layer = self.get16()
        box.alternate_group = self.get16()
        box.volume = self.get16()
        # reserved
        self.get16()
        box.matrix = [self.get32() for x in range (9)]
        # width and height are each kept as a pair of 16-bit halves
        box.width = self.get_16_16()
        box.height = self.get_16_16()

    def parse_mdia (self, box, size, depth):
        """'mdia' is a pure container: read its children."""
        self.read_all (box, size, depth+1)

    def parse_mdhd (self, box, size, depth):
        """Decode the fields of an 'mdhd' box into attributes of <box>."""
        box.version, box.flags = self.get_version()
        get = self._versioned_get (box.version)
        box.ctime = get()
        box.mtime = get()
        box.timescale = self.get32()
        box.duration = get()
        box.language = self.get16()
        # trailing 16 bits read and discarded
        self.get16()

    def parse_hdlr (self, box, size, depth):
        """Decode a 'hdlr' box and remember its handler type for parse_stsd."""
        start = self.file.tell()
        box.version, box.flags = self.get_version()
        # pre_defined
        self.get32()
        box.handler_type = self.get (4)
        # hack
        self.current_handler_type = box.handler_type
        # reserved
        self.get32(); self.get32(); self.get32()
        left = size - (self.file.tell() - start)
        # this apparently can't be relied on...
        #box.name = self.get_pascal_string()
        if left > 0:
            # whatever remains of the payload is kept, undecoded, as the name
            box.name = self.get (left)

    def parse_minf (self, box, size, depth):
        """'minf' is a pure container: read its children."""
        self.read_all (box, size, depth+1)

    def parse_vmhd (self, box, size, depth):
        """Decode the fields of a 'vmhd' box into attributes of <box>."""
        box.version, box.flags = self.get_version()
        box.graphicsmode = self.get16()
        box.opcolor = [ self.get16() for x in range (3) ]

    def parse_smhd (self, box, size, depth):
        """Decode the fields of an 'smhd' box into attributes of <box>."""
        box.version, box.flags = self.get_version()
        box.balance = self.get16()
        box.reserved = self.get16()

    def parse_stbl (self, box, size, depth):
        """'stbl' is a pure container: read its children."""
        self.read_all (box, size, depth+1)

    def parse_stsd (self, box, size, depth):
        """Read the entries of an 'stsd' box, each one as an ordinary box.

        Raises NotImplementedError unless the last 'hdlr' box seen declared
        a 'vide' or 'soun' handler.
        """
        version, flags = self.get_version()
        entry_count = self.get32()
        for i in range (entry_count):
            if self.current_handler_type in ('vide', 'soun'):
                # each sample entry is itself a box; let read() dispatch on
                # its type (e.g. 'avc1')
                self.read (depth+1)
            else:
                raise NotImplementedError (
                    "no sample entry parser for handler type %r" % (self.current_handler_type,)
                    )

    def parse_sample_entry (self, depth):
        """Read the common sample entry prefix and print data_reference_index."""
        # six bytes read and discarded
        self.get (6)
        data_reference_index = self.get16()
        self.dump (depth+1, 'data_reference_index %r' % (data_reference_index,))

    def parse_visual_sample_entry (self, box, size, depth):
        """Decode a visual sample entry, then read any boxes that follow it."""
        start = self.file.tell()
        self.parse_sample_entry (depth)
        # 2*16 + 3*32 bits read and discarded
        self.get16()
        self.get16()
        self.get32(); self.get32(); self.get32()
        box.width = self.get16()
        box.height = self.get16()
        box.horizresolution = self.get32()
        box.vertresolution  = self.get32()
        self.get32()
        box.frame_count = self.get16()
        box.compressorname = self.get (32)
        box.depth = self.get16()
        self.get16()
        # whatever is left in the payload is treated as a sequence of
        # optional child boxes.  '> 0' so that an overrun ends the loop
        # instead of reading on to the end of the file.
        while size - (self.file.tell() - start) > 0:
            self.read (depth+1)

    def parse_avc1 (self, box, size, depth):
        """'avc1' is decoded as a visual sample entry."""
        return self.parse_visual_sample_entry (box, size, depth)

    def parse_clap (self, box, size, depth):
        """Decode the eight 32-bit fields of a 'clap' box into <box>."""
        box.cleanApertureWidthN = self.get32()
        box.cleanApertureWidthD = self.get32()
        box.cleanApertureHeightN = self.get32()
        box.cleanApertureHeightD = self.get32()
        box.horizOffN = self.get32()
        box.horizOffD = self.get32()
        box.vertOffN = self.get32()
        box.vertOffD = self.get32()

    def parse_pasp (self, box, size, depth):
        """Decode the two 32-bit fields of a 'pasp' box into <box>."""
        box.hSpacing = self.get32()
        box.vSpacing = self.get32()

    def parse_udta (self, box, size, depth):
        """'udta' is a pure container: read its children."""
        self.read_all (box, size, depth+1)

    def parse_meta (self, box, size, depth):
        """Read the version/flags of a 'meta' box, then its children."""
        box.version, box.flags = self.get_version()
        # the version/flags word just consumed is part of <size>; only the
        # remainder holds child boxes
        self.read_all (box, size - 4, depth+1)

    def parse_ilst (self, box, size, depth):
        """'ilst' is a pure container: read its children."""
        self.read_all (box, size, depth+1)

    parse_dinf = parse_udta

    def parse_dref (self, box, size, depth):
        """Read the entries of a 'dref' box, each one as an ordinary box."""
        box.version, box.flags = self.get_version()
        entry_count = self.get32()
        box.entries = []
        for i in range (entry_count):
            # read() prints the entry itself and returns None, so this list
            # only records one placeholder per entry.
            box.entries.append (self.read (depth+1))

if __name__ == '__main__':
    # usage: <script> <file>  -- dump the box structure of <file> to stdout
    if len (sys.argv) < 2:
        sys.stderr.write ('usage: %s <file>\n' % (sys.argv[0],))
        sys.exit (1)
    f = open (sys.argv[1], 'rb')
    try:
        p = parser (f)
        p.go()
    finally:
        # release the file handle even if parsing dies with an exception
        f.close()
