#!/usr/bin/python

import serial
import time
import argparse
import os
import re
import difflib
import stat
import sys

from mpy_utils.replcontrol import ReplControl, ReplIOFileHandle, ReplIOSerial

parser = argparse.ArgumentParser(
    description="upload files to a device using only the REPL"
)

# Command-line interface, declared as (name, add_argument keyword arguments).
# The order of this table is the order in which the arguments are registered,
# and therefore the order they are listed in --help.
# NOTE(review): source_dir has no nargs, so argparse treats it as required and
# its default of "." never takes effect.
for _arg_name, _arg_options in (
    ("--port", dict(default="/dev/ttyUSB0", help="serial port device")),
    ("--baud", dict(default=115200, type=int, help="port speed in baud")),
    (
        "--pipe",
        dict(action="store_true", help="use stdin/stdout instead of serial port"),
    ),
    ("--delay", dict(default=100.0, type=float, help="delay between lines (ms)")),
    (
        "--reset",
        dict(action="store_true", help="send soft reset (control-D) after upload"),
    ),
    (
        "--delete",
        dict(default=False, action="store_true", help="delete files not in source"),
    ),
    (
        "--download",
        dict(default=False, action="store_true", help="download files first"),
    ),
    (
        "--debug",
        dict(default=False, action="store_true", help="print debugging info"),
    ),
    ("source_dir", dict(type=str, default=".")),
    ("dest_dir", dict(type=str, default="/", nargs="?")),
):
    parser.add_argument(_arg_name, **_arg_options)

args = parser.parse_args()

# Choose the transport: the process's own stdin/stdout when --pipe is given,
# otherwise the serial port named on the command line.
repl_io = (
    ReplIOFileHandle(infh=sys.stdin, outfh=sys.stdout)
    if args.pipe
    else ReplIOSerial(port=args.port, baud=args.baud, delay=args.delay)
)

# Single REPL controller shared by every function in this script.
repl_control = ReplControl(
    io=repl_io,
    delay=args.delay,
    debug=args.debug,
    reset=args.reset,
)


# Selects the files this tool manages: paths ending in .py, .htm/.html, .js,
# .css, .png, .gif or .jpg/.jpeg.  The leading group matches "/component"
# segments whose first character is neither "." nor "/".
# Used with search() in make_file_list() and with match() in delete_files().
# NOTE(review): because the leading group may match zero times, search() can
# succeed starting at the final ".ext", so in make_file_list() the pattern
# effectively checks only the extension (dot-files and files under
# dot-directories are accepted).  Confirm whether that is intended.
path_re = re.compile(r"(/[^./][^/]*)*\.(py|html?|js|css|png|gif|jpe?g)$")


def make_file_list(source_dir, dest_dir=None):
    """Yield ``(source_path, dest_path)`` pairs for the files under *source_dir*.

    Directories are descended into recursively, mirroring their names under
    the destination.  A non-directory entry is yielded only when its
    destination path satisfies ``path_re.search``.

    Args:
        source_dir: local directory to scan.
        dest_dir: directory the entries map to on the device; when ``None``
            the ``dest_dir`` command-line argument is used.
    """
    target_root = args.dest_dir if dest_dir is None else dest_dir
    for entry in os.listdir(source_dir):
        local_path = os.path.join(source_dir, entry)
        remote_path = os.path.join(target_root, entry)

        if os.path.isdir(local_path):
            yield from make_file_list(local_path, remote_path)
            continue

        if path_re.search(remote_path):
            yield (local_path, remote_path)


def ensure_dir_exists(dest_path):
    """Send a remote ``os.mkdir`` for every ancestor directory of *dest_path*.

    Ancestors are processed outermost first, so each directory's parent has
    already been requested by the time the directory itself is.  The root
    directory is never requested.

    NOTE(review): this is issued for directories that may already exist;
    presumably ``repl_control.statement`` tolerates the resulting remote
    error -- confirm.

    Args:
        dest_path: destination file path on the device, absolute or relative.
    """
    parent_dir, _ = os.path.split(dest_path)
    # Stop once nothing but slashes (or nothing at all) is left.  os.path.split
    # returns "" as the parent of a bare relative name and returns "//"
    # unchanged, so testing only for "/" would recurse without end on
    # relative destinations such as "lib/main.py".
    if not parent_dir.strip("/"):
        return
    ensure_dir_exists(parent_dir)
    repl_control.statement("os.mkdir", parent_dir)


# Remote scratch file inside the destination directory: sync_files() writes
# each upload here first and then renames it over the real destination path.
temporary_file_name = os.path.join(args.dest_dir, ".temporary")


# Number of bytes of file content sent per remote ``write`` call.
_CHUNK_SIZE = 50

# Cache used by sync_files() whenever the caller does not pass its own.
# Maps dest_path -> [source mtime, source bytes as last uploaded].  It lives at
# module level so that it persists between successive sync_files() calls.
_FILE_CACHE = {}


def _initialize_repl():
    """Initialise the REPL session and import ``os`` on the device."""
    log("initializing ...")
    repl_control.initialize()
    repl_control.command("import os")


def _pause():
    """Sleep for the ``--delay`` interval (given in milliseconds)."""
    time.sleep(args.delay / 1000.0)


def _write_chunks(remote_fh, data, start, stop):
    """Write ``data[start:stop]`` to *remote_fh* in ``_CHUNK_SIZE`` pieces."""
    for offset in range(start, stop, _CHUNK_SIZE):
        remote_fh.method("write", data[offset : min(offset + _CHUNK_SIZE, stop)])
        _pause()


def sync_files(file_list, file_cache=None):
    """Upload new or changed files to the device, optionally pruning old ones.

    For every ``(source_path, dest_path)`` pair the local file is skipped when
    its mtime -- or, failing that, its content -- matches the cache entry.
    Otherwise the new content is written to ``temporary_file_name`` on the
    device and then renamed over ``dest_path``.  When a cached copy exists the
    temporary file is assembled as a patch: changed spans are sent from the
    host and unchanged spans are copied on the device from the existing remote
    file.  Without a cached copy the whole file is sent in chunks.

    With ``--delete``, cached destinations that no longer appear in
    *file_list* are removed from the device first.  With ``--reset``,
    ``repl_control.reset()`` is called at the end if the REPL was used.

    Args:
        file_list: list of ``(source_path, dest_path)`` tuples.  It is
            iterated twice when ``--delete`` is set, so it must be
            re-iterable (not a generator).
        file_cache: dict mapping ``dest_path`` to ``[mtime, content]``,
            updated in place.  When omitted, a module-wide cache is used so
            that state carries over from one call to the next.
    """
    if file_cache is None:
        file_cache = _FILE_CACHE

    # The REPL is only initialised once there is actually something to send.
    initialized = False

    if args.delete:
        extra_set = set(file_cache) - {entry[1] for entry in file_list}
        if extra_set:
            _initialize_repl()
            initialized = True

            for dest_path in extra_set:
                log("deleting %s" % repr(dest_path))
                repl_control.statement("os.remove", dest_path)
                del file_cache[dest_path]

    for source_path, dest_path in file_list:
        source_mtime = os.stat(source_path).st_mtime
        cached = file_cache.get(dest_path)
        if cached is not None and cached[0] == source_mtime:
            continue

        # Close the local file promptly instead of leaking the handle; this
        # function runs repeatedly for the lifetime of the process.
        with open(source_path, "rb") as source_fh:
            source_text = source_fh.read()
        if cached is not None and cached[1] == source_text:
            # Touched but not modified: remember the new mtime and move on.
            cached[0] = source_mtime
            continue

        if not initialized:
            _initialize_repl()
            initialized = True

        dest_temp_fh = repl_control.variable("open", temporary_file_name, "wb")

        if cached is not None:
            cache_text = cached[1]

            remote_fh = repl_control.variable("open", dest_path, "rb")
            sequence_matcher = difflib.SequenceMatcher(None, cache_text, source_text)
            log(
                "patching %s => %s (%d bytes, %f ratio)"
                % (
                    repr(source_path),
                    repr(dest_path),
                    len(source_text),
                    sequence_matcher.ratio(),
                )
            )

            for tag, i1, i2, j1, j2 in sequence_matcher.get_opcodes():
                if tag in ("replace", "insert"):
                    # New bytes have to be sent from the host.
                    _write_chunks(dest_temp_fh, source_text, j1, j2)
                elif tag == "equal":
                    # Unchanged span: the device copies it out of the existing
                    # remote file, so those bytes are not sent again.
                    repl_control.command(
                        "%s.seek(%d); %s.write(%s.read(%d))"
                        % (
                            remote_fh.get_name(),
                            i1,
                            dest_temp_fh.get_name(),
                            remote_fh.get_name(),
                            j2 - j1,
                        )
                    )
                    _pause()
                # "delete" spans contribute nothing to the new file.
            remote_fh.method("close")
            del remote_fh

        else:
            log(
                "copying %s => %s (%d bytes)"
                % (repr(source_path), repr(dest_path), len(source_text))
            )
            _write_chunks(dest_temp_fh, source_text, 0, len(source_text))

        dest_temp_fh.method("flush")
        dest_temp_fh.method("close")
        del dest_temp_fh

        ensure_dir_exists(dest_path)

        # Replace the destination with the freshly written temporary file.
        repl_control.statement("os.remove", dest_path)
        repl_control.statement("os.rename", temporary_file_name, dest_path)

        file_cache[dest_path] = [source_mtime, source_text]

    if initialized and args.reset:
        log("resetting ...")
        repl_control.reset()


def delete_files(file_list, dest_dir="/"):
    """Walk *dest_dir* on the device and log managed files missing from *file_list*.

    Every entry returned by the remote ``os.listdir`` is stat-ed remotely;
    directories are walked recursively.  A non-directory entry is logged when
    its path satisfies ``path_re.match`` and is not a destination in
    *file_list*.

    NOTE(review): despite the name and the "deleting" message, no remove
    command is sent to the device here -- confirm whether that is intended.

    Args:
        file_list: iterable of ``(source_path, dest_path)`` tuples.
        dest_dir: remote directory to start from.
    """
    wanted_paths = set(entry[1] for entry in file_list)

    for name in repl_control.function("os.listdir", dest_dir):
        remote_path = os.path.join(dest_dir, name)
        remote_stat = repl_control.function("os.stat", remote_path)

        # A bytes result is skipped; presumably it is an error reply rather
        # than a stat tuple -- confirm against ReplControl.function.
        if type(remote_stat) is bytes:
            continue

        if stat.S_ISDIR(remote_stat[0]):
            delete_files(file_list, remote_path)
            continue

        if path_re.match(remote_path):
            if remote_path not in wanted_paths:
                log("deleting %s" % remote_path)


def log(msg):
    """Print *msg* on standard error, followed by a newline."""
    stream = sys.stderr
    print(msg, file=stream)


if args.download:
    # No download_files() is defined or imported anywhere in this script, so
    # calling it could only fail with a NameError.  Exit with a clear message
    # instead until the feature exists.
    sys.exit("--download is not implemented")

# Watch loop: rescan the source tree once a second and push whatever changed.
# Runs until the process is interrupted.
while True:
    files = list(make_file_list(args.source_dir))
    sync_files(files)

    time.sleep(1)
