# -*- Mode: Python -*-

# py3

# requirements:
# pip install pillow exifread rawpy imagehash pybktree

import os
import sys
from os.path import join, splitext

from PIL import Image, ImageChops, UnidentifiedImageError
from imagehash import dhash
import exifread
import pybktree
import warnings
import io
import collections

# TODO: needs a first pass of size,sha256 to find actually identical files.

# One BK-tree entry: the integer dhash of an image and the file it came from.
HashItem = collections.namedtuple ('HashItem', ['hash', 'path'])

def item_distance (a, b):
    """BK-tree metric: Hamming distance between the ``hash`` fields of two items.

    The paths are ignored; only the integer hashes are compared.
    """
    ha, hb = a.hash, b.hash
    return pybktree.hamming_distance (ha, hb)

# Diagnostics/progress go to stderr so stdout carries only the duplicate report.
W = sys.stderr.write

# Extensions (lowercase, leading dot) that get_image decodes through rawpy.
raw_exts = {'.nef', '.arw', '.raw'}

# JPEG thumbnails often have black crop borders.
def trim (im):
    """Crop a near-black border off *im* and return the cropped image.

    The image is first normalised to RGB.  Comparing against a 3-tuple
    background colour is only meaningful in RGB: it raises for 1-band modes
    such as ``L``, compares palette indices for ``P``, and is near-white
    rather than near-black in ``CMYK``.  Downstream hashing (dhash)
    converts to greyscale anyway, so the RGB conversion does not change the
    resulting hash for RGB/RGBA inputs.

    Returns the input unchanged if nothing stands out from the background
    or if the content already spans the whole image.
    """
    if im.mode != 'RGB':
        im = im.convert ('RGB')
    bg = Image.new ('RGB', im.size, (1, 1, 1))
    diff = ImageChops.difference (im, bg)
    # Double the difference and subtract 100, so that only pixels clearly
    # brighter than the near-black background remain non-zero.
    diff = ImageChops.add (diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox and bbox != (0, 0) + im.size:
        W ('[trim]')
        return im.crop (bbox)
    return im

def get_thumbnail (path):
    """Return the EXIF-embedded JPEG thumbnail of *path* as a PIL image.

    Returns None when exifread finds no ``JPEGThumbnail`` tag (or it is
    empty).  The file is closed as soon as the EXIF data has been read; the
    returned image is backed by an in-memory copy of the thumbnail bytes.
    """
    with open (path, 'rb') as f:
        tags = exifread.process_file (f)
    th = tags.get ('JPEGThumbnail')
    if not th:
        return None
    # XXX jpeg thumbnails usually need rotating.  imagehash algs
    #   are sensitive to orientation.  We should probably try to
    #   rotate the thumbnail when the info is available.
    W ('[exifthumb]')
    return Image.open (io.BytesIO (th))

# this will preferentially extract a thumbnail/preview.
def get_raw_image (path, ext):
    """Return a PIL image for the raw camera file at *path*.

    The embedded preview is preferred, because it is much cheaper than a
    full demosaic.  It may be a JPEG blob or an uncompressed bitmap array,
    and both are handled.  If the file has no usable preview, the full raw
    is decoded with ``postprocess()``.  *ext* is accepted for interface
    compatibility and is not used.

    The rawpy handle is closed before returning; every returned image is
    independent of it (a BytesIO copy or a converted array).
    """
    import rawpy
    with rawpy.imread (path) as rimg:
        try:
            th = rimg.extract_thumb()
        except (rawpy.LibRawNoThumbnailError, rawpy.LibRawUnsupportedThumbnailError):
            th = None
        if th is not None and th.format == rawpy.ThumbFormat.JPEG:
            img = Image.open (io.BytesIO (th.data))
            W ('[rawthumb]')
            return img
        if th is not None and th.format == rawpy.ThumbFormat.BITMAP:
            # Bitmap previews come back as a numpy array, not encoded bytes.
            img = Image.fromarray (th.data)
            W ('[rawthumb]')
            return img
        # No usable preview: do a full decode.  Build the image from the
        # returned array so its own shape is used.  rimg.sizes can disagree
        # with the postprocessed output (e.g. when the image is rotated).
        img = Image.fromarray (rimg.postprocess())
        W ('[raw]')
        return img

def get_image (path, ext):
    """Load *path* as a PIL image suitable for hashing, or return None.

    Sources are tried cheapest first:
    - an EXIF-embedded JPEG thumbnail, which is border-trimmed;
    - for extensions in ``raw_exts``, the raw-file reader;
    - otherwise the file itself via ``Image.open``.

    Returns None for decompression bombs and unidentifiable files.  The
    path and source tags are written with ``W``, and the line is always
    terminated, even if an exception escapes.
    """
    W ('%s ' % (path,))
    try:
        thumb = get_thumbnail (path)
        if thumb is not None:
            return trim (thumb)
        if ext.lower() in raw_exts:
            return get_raw_image (path, ext)
        # Ordinary image file.  Image.open only reads the header here, so
        # pixel-decoding errors surface later, when the image is hashed.
        try:
            full = Image.open (path)
        except Image.DecompressionBombError:
            W ("[bomb]")
            return None
        except UnidentifiedImageError:
            W ("[???]")
            return None
        W ("[full]")
        return full
    finally:
        W ('\n')

def search_dir (bkt, dpath, exts):
    """Walk *dpath*, dhash every matching image and add it to BK-tree *bkt*.

    *exts* is a collection of filename extensions.  They are matched
    case-insensitively and may carry a leading dot or surrounding spaces
    (``'jpg'``, ``' .JPG'``).  Empty entries are ignored, so they do not
    select extensionless files.

    A file that raises OSError while being loaded or hashed (unreadable,
    truncated, unidentifiable) is reported on stderr and skipped, so one
    bad file does not abort the whole walk.
    """
    wanted = {e.strip().lstrip('.').lower() for e in exts}
    wanted.discard('')
    for root, _dirs, files in os.walk (dpath):
        for name in files:
            ext = splitext (name)[1]
            if ext[1:].lower() not in wanted:
                continue
            fpath = join (root, name)
            try:
                img = get_image (fpath, ext)
                if img is None:
                    continue
                # dhash loads the pixels, so decode errors surface here.
                h = int (str (dhash (img)), 16)
            except OSError as e:
                W ('%s [error] %s\n' % (fpath, e))
                continue
            bkt.add (HashItem (h, fpath))

def find_dups (bkt, dist=2):
    """Print groups of near-duplicate items from the BK-tree *bkt*.

    Each item (in sorted order) is looked up with ``bkt.find(item, dist)``.
    Any result with more than one member is printed as one
    ``'%3d %s'`` (distance, path) line per member, followed by ``---``.

    The lookup is symmetric, so a set of mutually-close items would
    otherwise be reported once per member.  Each distinct set of paths is
    printed only once.
    """
    reported = set()
    for item in sorted (bkt):
        near = bkt.find (item, dist)
        if len (near) < 2:
            continue
        group = frozenset (found.path for _, found in near)
        if group in reported:
            continue
        reported.add (group)
        for d0, item0 in near:
            print ('%3d %s' % (d0, item0.path))
        print ('---')

if __name__ == '__main__':
    import argparse
    p = argparse.ArgumentParser (description='duplicate image finder')
    p.add_argument ('-e', '--exts', help="image filename extensions", default='jpg,jpeg,tif,tiff,png,raw,arw,nef')
    p.add_argument ('dirs', help="directory to search", metavar="DIR", type=str, nargs='+')
    args = p.parse_args()
    # Normalise the comma-separated list.  Tolerate spaces, upper case and
    # leading dots ("JPG, .png"), and drop empty entries, so a trailing
    # comma does not make extensionless files match.
    exts = {e.strip().lstrip('.').lower() for e in args.exts.split(',')}
    exts.discard('')
    bkt = pybktree.BKTree (item_distance)
    for dpath in args.dirs:
        search_dir (bkt, dpath, exts)
    find_dups (bkt)
