# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import

import collections
import codecs
import functools
import glob
import importlib
import imp
import os
import os.path as pathlib
import sys
try:
    from types import FileType  # py2
except ImportError:
    from io import IOBase as FileType  # py3

from .db import database
from .log import logger
from .helpers import (
    Color, lines_diff, print_table, parse_requirements, trim_prefix,
    trim_suffix
)
from .parser import parse_imports, parse_installed_packages
from .pypi import PKGS_URL, Downloader, Updater

from requests.exceptions import HTTPError

# FIXME: dirty workaround..
# Dotted distribution names: when the first two dotted components of an
# import match a key here, ``parse_packages`` uses the mapped value as the
# package name instead of only the first component.
_special_packages = {
    name: name
    for name in (
        "dogpile.cache",
        "dogpile.core",
        "ruamel.yaml",
        "ruamel.ordereddict",
    )
}


class RequirementsGenerator(object):
    """Generate a requirements file from the imports under a package root.

    Imports are matched against the locally installed packages.  Modules
    that cannot be matched are listed, and the user is asked (on stdin)
    whether to look them up in the local PyPI database.

    :param package_root: directory whose sources are scanned.
    :param save_path: requirements file to (over)write.
    :param ignores: forwarded to ``parse_packages``.
    :param cmp_operator: operator written between a package and its
        version, e.g. ``==``.
    :param ref_comments: when true, each requirement is preceded by
        ``# <file>: <linenos>`` comments showing where it is imported.
    """

    def __init__(
        self,
        package_root,
        save_path,
        ignores=None,
        cmp_operator="==",
        ref_comments=False
    ):
        self._package_root = package_root
        self._save_path = save_path
        self._ignores = ignores
        self._cmp_operator = cmp_operator
        self._ref_comments = ref_comments
        self._installed_pkgs = None

    def __call__(self):
        self.generate()

    @property
    def installed_pkgs(self):
        """Installed packages, parsed on first access and then cached."""
        if self._installed_pkgs is None:
            self._installed_pkgs = parse_installed_packages()
        return self._installed_pkgs

    def generate(self):
        """Scan imports, optionally consult PyPI, write the file, show diff."""
        packages, guess = parse_packages(
            self._package_root, self._ignores, self.installed_pkgs
        )

        answer = 'n'
        if guess:
            print(Color.RED('The following modules are not found yet:'))
            self._print_uncertain_modules(guess)
            sys.stdout.write(
                Color.RED(
                    (
                        'Some of them may be not installed in local '
                        'environment.\nTry to search PyPI for the '
                        'missing modules and filter'
                        ' some unnecessary modules? (y/[N]) '
                    )
                )
            )
            sys.stdout.flush()
            answer = sys.stdin.readline().strip().lower()
        search_pypi = answer in ('y', 'yes')

        in_pypi = None
        if search_pypi:
            print(Color.BLUE('Checking modules on the PyPI...'))
            in_pypi = self._check_on_pypi(packages, guess)

        # Snapshot the file before and after writing to report changes.
        old = self._read_requirements()
        self._write_requirements(packages)
        new = self._read_requirements()
        self._print_diff(old, new)

        if in_pypi:
            for name in in_pypi:
                del guess[name]
        if guess and search_pypi:
            print(Color.RED('These modules are not found:'))
            self._print_uncertain_modules(guess)
            print(Color.RED('Maybe you need to update the database.'))

    def _check_on_pypi(self, packages, guess):
        """Look up unresolved module names in the local PyPI database.

        For every candidate package found, its latest version is fetched
        and recorded in *packages* together with the import locations.

        :param packages: ``_RequiredModules`` to extend in place.
        :param guess: mapping of unresolved module name -> locations.
        :return: set of module names for which the database returned at
            least one candidate package.
        """
        in_pypi = set()
        # A single downloader serves every lookup.
        downloader = Downloader()
        for name, locs in guess.items():
            logger.info('Checking %s on the PyPI ...', name)
            with database() as db:
                rows = db.query_all(name)
                pkgs = [row.package for row in rows]
            if pkgs:
                in_pypi.add(name)
            for pkg in _best_matchs(name, pkgs):
                try:
                    latest = downloader.download_package(pkg).version()
                except HTTPError as e:
                    # '%s' (not '%e'): the exception is not a float.
                    logger.error('checking %s failed: %s', pkg, e)
                    continue
                packages.add_locs(pkg, latest, locs)
        return in_pypi

    def _print_uncertain_modules(self, modules):
        for name, locs in modules.items():
            print(
                '  {0} referenced from:\n    {1}'.format(
                    Color.YELLOW(name), '\n    '.join(locs.sorted_items())
                )
            )

    def _read_requirements(self):
        """Return the lines of the requirements file, or None if absent."""
        if not pathlib.isfile(self._save_path):
            return
        with codecs.open(self._save_path, 'rb', 'utf-8') as f:
            return f.readlines()

    @staticmethod
    def _format_requirement(name, version, cmp_operator):
        """Return the requirements-file line (newline included) for a package.

        ``-e`` entries are written as ``-e`` followed by the stored version
        field; packages with a version as ``name <op> version``; packages
        without a (truthy) version as the bare name.
        """
        if name == '-e':
            return '{0} {1}\n'.format(name, version)
        if version:
            return '{0} {1} {2}\n'.format(name, cmp_operator, version)
        return '{0}\n'.format(name)

    def _write_requirements(self, packages):
        """Overwrite ``self._save_path`` with *packages*, encoded as UTF-8."""
        print(
            Color.GREEN(
                'Writing requirements to "{0}"'.format(self._save_path)
            )
        )
        package_root_parent = pathlib.dirname(
            trim_suffix(self._package_root, "/")
        ) + "/"
        ref_comments = self._ref_comments
        cmp_operator = self._cmp_operator

        # Use the same encoding as ``_read_requirements``: reference
        # comments contain file paths that may be non-ASCII.
        with codecs.open(self._save_path, 'w', 'utf-8') as f:
            f.write(
                '# Automatically generated by '
                'https://github.com/damnever/pigar.\n'
            )
            if not ref_comments:
                f.write('\n')
            for name, detail in packages.sorted_items():
                if ref_comments:
                    f.write('\n')
                    f.write(
                        ''.join(
                            '# {0}\n'.format(
                                trim_prefix(c, package_root_parent)
                            ) for c in detail.comments.sorted_items()
                        )
                    )
                # ``detail`` is a namedtuple and always truthy, so the
                # version field itself decides the output form.
                f.write(
                    self._format_requirement(
                        name, detail.version, cmp_operator
                    )
                )

    def _print_diff(self, old, new):
        """Report how *new* differs from *old*; silent if *old* is empty."""
        if not old:
            return
        is_diff, diffs = lines_diff(old, new)
        msg = 'Requirements file has been overwritten, '
        if is_diff:
            msg += 'here is the difference:'
            print('{0}\n{1}'.format(Color.YELLOW(msg), ''.join(diffs)), end='')
        else:
            msg += 'no difference.'
            print(Color.YELLOW(msg))


def check_requirements_latest_versions(
    check_path,
    ignores=None,
    comparison_operator="==",
    ref_comments=False,
):
    """Print a table comparing required versions with the latest on PyPI.

    :param check_path: a requirements file, or a directory searched for
        ``*requirements.txt`` files.  If the directory contains none, a
        ``requirements.txt`` is generated there first.
    :param ignores: forwarded to ``RequirementsGenerator`` when generating.
    :param comparison_operator: forwarded to ``RequirementsGenerator``.
    :param ref_comments: forwarded to ``RequirementsGenerator``.
    """
    logger.debug('Starting check requirements latest version ...')
    files = []
    reqs = {}
    pkg_versions = []
    installed_pkgs = None
    # If no requirements file given, check in current directory.
    if pathlib.isdir(check_path):
        logger.debug('Searching file in "%s" ...', check_path)
        files.extend(glob.glob(pathlib.join(check_path, '*requirements.txt')))
        # If not found in directory, generate requirements.
        if not files:
            print(
                Color.YELLOW(
                    'Requirements file not found, '
                    'generate requirements ...'
                )
            )
            save_path = os.path.join(check_path, 'requirements.txt')
            rg = RequirementsGenerator(
                check_path, save_path, ignores, comparison_operator,
                ref_comments
            )
            rg()
            # Reuse the generator's already-parsed package list.
            installed_pkgs = rg.installed_pkgs
            files.append(save_path)
    else:
        files.append(check_path)
    for fpath in files:
        reqs.update(parse_requirements(fpath))

    logger.debug('Checking requirements latest version ...')
    installed_pkgs = installed_pkgs or parse_installed_packages()
    # Re-key by package name: package -> installed version.
    installed_versions = {v[0]: v[1] for v in installed_pkgs.values()}
    downloader = Downloader()
    for pkg, current in reqs.items():
        # If no version specifies in requirements,
        # check in installed packages.
        if current == '' and pkg in installed_versions:
            current = installed_versions[pkg]
        logger.debug('Checking "%s" latest version ...', pkg)
        try:
            latest = downloader.download_package(pkg).version()
        except HTTPError as e:
            logger.error('checking %s failed: %s', pkg, e)
            # Without this, ``latest`` would be unbound on the first
            # package or stale from the previous one.
            latest = ''
        pkg_versions.append((pkg, current, latest))

    logger.debug('Checking requirements latest version done.')
    print_table(pkg_versions)


def search_packages_by_names(names):
    """Search package information by names(`import XXX`).

    A name provided by an installed package is reported as ``local``.
    Other names are looked up in the local PyPI database, and the latest
    version of every candidate package is fetched.  The results are
    printed as tables, followed by the names that were not found.

    :param names: iterable of import names.
    """
    downloader = Downloader()
    results = collections.defaultdict(list)
    not_found = []

    installed_pkgs = parse_installed_packages()
    for name in names:
        logger.debug('Searching package name for "%s" ...', name)
        # If exists in local environment, do not check on the PyPI.
        if name in installed_pkgs:
            results[name].append(list(installed_pkgs[name]) + ['local'])
            continue
        # Check information on the PyPI.
        with database() as db:
            rows = db.query_all(name)
        if not rows:
            not_found.append(name)
            continue
        for row in rows:
            try:
                version = downloader.download_package(row.package).version()
            except HTTPError as e:
                # '%s' (not '%e'): the exception is not a float.
                logger.error('checking %s failed: %s', row.package, e)
                continue
            results[name].append((row.package, version, 'PyPI'))

    for name in results:
        print('Found package(s) for "{0}":'.format(Color.GREEN(name)))
        print_table(results[name], headers=['PACKAGE', 'VERSION', 'WHERE'])
    if not_found:
        msg = '"{0}" not found.\n'.format(Color.RED(', '.join(not_found)))
        msg += 'Maybe you need to update the database.'
        print(Color.YELLOW(msg))


def update_database():
    """Refresh the local package database by crawling the package index.

    If the updater cannot be constructed, the failure is logged and the
    operation is aborted.  A ``KeyboardInterrupt``/``SystemExit`` during
    the crawl cancels the updater instead of propagating.
    """
    print(Color.GREEN('Starting update database ...'))
    print(Color.YELLOW('The process will take a long time!!!'))
    logger.info('Crawling "{0}" ...'.format(PKGS_URL))

    try:
        crawler = Updater()
    except Exception:
        logger.error("Fail to fetch all packages: ", exc_info=True)
        print(Color.RED('Operation aborted'))
        return

    interrupted = False
    try:
        crawler.run()
        crawler.wait()
    except (KeyboardInterrupt, SystemExit):
        # FIXME(damnever): signal handling here is troublesome.
        interrupted = True
        crawler.cancel()

    if interrupted:
        print(Color.BLUE('Operation canceled!'))
    else:
        print(Color.GREEN('Operation done!'))


def parse_packages(package_root, ignores=None, installed_pkgs=None):
    """Map the imports found under *package_root* to packages.

    :param package_root: directory passed to ``parse_imports``.
    :param ignores: forwarded to ``parse_imports``.
    :param installed_pkgs: mapping of import name -> ``(package, version)``.
        It is parsed from the environment when not given (or empty).
    :return: ``(packages, guess)``.  ``packages`` is a ``_RequiredModules``
        of resolved requirements.  ``guess`` maps each unresolved name to
        the ``_Locations`` importing it.  Names flagged as ``try_`` imports
        (presumably imports inside a ``try`` block) are left out of
        ``guess``.
    """
    imported_modules, user_modules = parse_imports(package_root, ignores)
    installed_pkgs = installed_pkgs or parse_installed_packages()
    packages = _RequiredModules()
    guess = collections.defaultdict(_Locations)
    try_imports = set()

    for module in imported_modules:
        full_name = module.name
        if is_user_module(module, user_modules, package_root):
            logger.debug("ignore imports from user module: %s", full_name)
            continue
        parts = full_name.split('.')
        if is_stdlib(full_name) or is_stdlib(parts[0]):
            logger.debug("ignore imports from stdlib: %s", full_name)
            continue

        # Work out which package-level names this import points at.
        prefix = '.'.join(parts[:2])
        if full_name.startswith('flask.ext.'):
            # Flask extension: needs both flask and flask_<ext>.
            candidates = ['flask', 'flask_' + parts[2]]
        elif prefix in _special_packages:
            # Special cases..
            candidates = [_special_packages[prefix]]
        else:
            # Plain import, or a dotted one reduced to its top-level name.
            candidates = [parts[0]]

        for candidate in candidates:
            if candidate in installed_pkgs:
                pkg_name, version = installed_pkgs[candidate]
                packages.add(pkg_name, version, module.file, module.lineno)
            else:
                guess[candidate].add(module.file, module.lineno)
            if module.try_:
                try_imports.add(candidate)

    # Names flagged via ``try_`` are not reported as missing.
    optional = [candidate for candidate in guess if candidate in try_imports]
    for candidate in optional:
        del guess[candidate]
    return packages, guess


def _best_matchs(name, pkgs):
    """Pick the database packages most likely to provide module *name*.

    :param name: imported module name.
    :param pkgs: list of candidate package names.
    :return: ``[name]`` if *name* is itself one of the candidates,
        otherwise *pkgs* unchanged (possibly empty).
    """
    # If imported name equals to package name, that match wins outright.
    if name in pkgs:
        return [name]
    # If not, return all possible packages.
    return pkgs


def is_user_module(module, user_modules, package_root):
    """Return True when *module* refers to code inside the scanned project.

    Relative imports (a name starting with ``.``) always count as user
    modules.  Otherwise the directories containing ``module.file`` are
    walked upwards, starting from the file's own directory.  The walk stops
    at the first directory not in *user_modules*.  At each level the import
    is resolved relative to that directory and the result is looked up in
    *user_modules*.

    :param module: parsed import exposing ``name`` and ``file``.
    :param user_modules: container of paths.  It appears to hold
        ``/``-separated directory and module paths without the ``.py``
        suffix (inferred from the comparisons below).  TODO confirm
        against ``parse_imports``.
    :param package_root: not used by this function.
    :return: True if the import resolves to a user module, else False.
    """
    name = module.name
    if name.startswith("."):
        return True
    parts = name.split(".")
    # Path of the importing file without its extension (assumes a 3-char
    # suffix such as '.py').
    cur_mod_path = module.file[:-3]
    # NOTE(review): splits on '/' only; Windows '\\' separators are not
    # handled here.
    dir_path_parts = pathlib.dirname(module.file).split("/")
    nparts = len(dir_path_parts)
    for i in range(0, nparts):
        # i becomes nparts, -1, -2, ..., -(nparts - 1).  The slice
        # dir_path_parts[:i] is therefore the file's own directory first,
        # then each successive parent directory.
        i = -i if i > 0 else nparts
        dir_path = "/".join(dir_path_parts[:i])
        # An empty join (the leading '' of an absolute path, or a file with
        # no directory part) is treated as the root '/'.
        if dir_path == "":
            dir_path = "/"
        # Stop climbing once a directory is not a known user module.
        if dir_path not in user_modules:
            break
        # Candidate: the dotted import resolved relative to this directory.
        mod_paths = [pathlib.join(dir_path, "/".join(parts))]
        # If this directory's name equals the import's first component,
        # the import may name this directory's package itself.
        if len(dir_path_parts[:i]) > 0 and dir_path_parts[:i][-1] == parts[0]:
            mod_paths.append(dir_path)
        for mod_path in mod_paths:
            # FIXME(damnever): ignore the current file?
            if mod_path == cur_mod_path:
                continue
            if mod_path in user_modules:
                return True
    return False


def _checked_cache(func):
    """Memoize a one-argument function, keyed on that argument.

    Every result, falsy ones included, is remembered for the life of the
    returned wrapper; the wrapped function runs at most once per key.
    """
    cache = {}

    @functools.wraps(func)
    def _wrapper(name):
        try:
            return cache[name]
        except KeyError:
            pass
        result = cache[name] = func(name)
        return result

    return _wrapper


@_checked_cache
def is_stdlib(name):
    """Check whether it is stdlib module.

    Returns False when *name* cannot be located.  Lookup is first done with
    ``imp.find_module``; if that fails, the module is actually imported
    and looked up again.  A located module also counts as non-stdlib when
    its path contains ``site-packages``/``dist-packages`` or is a ``.py``
    file under a ``bin/`` directory.  Otherwise returns True.  Results are
    memoized per name by ``_checked_cache``.

    NOTE(review): ``imp`` is deprecated and was removed in Python 3.12.
    NOTE(review): the fallback executes the module's top-level code.  Only
    ``ImportError`` is caught, so any other exception raised while
    importing propagates to the caller.
    """
    exist = True
    # Placeholder (file, pathname, description) triple used if lookup fails.
    module_info = ('', '', '')
    try:
        module_info = imp.find_module(name)
    except ImportError:
        try:
            # Fallback: import the module for real, then locate it again.
            importlib.import_module(name)
            module_info = imp.find_module(name)
            # Remove the module from sys.modules after the probe import.
            sys.modules.pop(name)
        except ImportError:
            exist = False
    # Testcase: ResourceWarning
    # find_module may hand back an open file object; close it so it is not
    # leaked.
    if isinstance(module_info[0], FileType):
        module_info[0].close()
    mpath = module_info[1]
    # Modules under third-party install locations (or scripts under bin/)
    # are not treated as stdlib.
    if exist and (
        mpath is not None and (
            'site-packages' in mpath or 'dist-packages' in mpath or
            ('bin/' in mpath and mpath.endswith('.py'))
        )
    ):
        exist = False
    return exist


class _RequiredModules(dict):
    """Required packages: package name -> ``Detail(version, comments)``.

    ``comments`` is a ``_Locations`` recording where the package is
    imported.  ``sorted_items`` is cached and reset on every mutation made
    through this class's methods.
    """

    _Detail = collections.namedtuple('Detail', ['version', 'comments'])

    def __init__(self):
        super(_RequiredModules, self).__init__()
        self._sorted = None

    def add_locs(self, package, version, locations):
        """Record *locations* (file -> linenos) for *package*.

        The locations are copied into a fresh ``_Locations`` so that
        packages never share one mutable object with each other or with
        the caller.  *version* is only used when *package* is new.
        """
        if package not in self:
            self[package] = self._Detail(version, _Locations())
        self[package].comments.extend(locations)
        self._sorted = None

    def add(self, package, version, file, lineno):
        """Record one import location for *package*.

        *version* is only used when *package* is new.
        """
        if package not in self:
            self[package] = self._Detail(version, _Locations())
        self[package].comments.add(file, lineno)
        self._sorted = None

    def sorted_items(self):
        """Return ``(package, detail)`` pairs sorted by package (cached)."""
        if self._sorted is None:
            self._sorted = sorted(self.items())
        return self._sorted

    def remove(self, *names):
        """Drop every given package that is present."""
        for name in names:
            if name in self:
                self.pop(name)
        self._sorted = None


class _Locations(dict):
    """_Locations store code locations(file, linenos).

    Maps a file path to the list of distinct line numbers recorded for it.
    ``sorted_items`` is cached and reset whenever a location is added.
    """
    def __init__(self):
        super(_Locations, self).__init__()
        self._sorted = None

    def add(self, file, lineno):
        """Record *lineno* for *file*, ignoring duplicates."""
        # setdefault keeps earlier line numbers when a duplicate is added.
        linenos = self.setdefault(file, [])
        if lineno not in linenos:
            linenos.append(lineno)
        self._sorted = None

    def extend(self, obj):
        """Merge every (file, linenos) entry of the mapping *obj*."""
        for file, linenos in obj.items():
            for lineno in linenos:
                self.add(file, lineno)

    def sorted_items(self):
        """Return ``'file: n1,n2'`` strings sorted by file (cached)."""
        if self._sorted is None:
            self._sorted = [
                '{0}: {1}'.format(f, ','.join([str(n) for n in sorted(ls)]))
                for f, ls in sorted(self.items())
            ]
        return self._sorted
