# -*- coding: utf-8 -*-
"""
Django Extensions additional model fields

Some fields might require additional dependencies to be installed.
"""

import re
import string

try:
    import uuid
    HAS_UUID = True
except ImportError:
    HAS_UUID = False

try:
    import shortuuid
    HAS_SHORT_UUID = True
except ImportError:
    HAS_SHORT_UUID = False

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.models import DateTimeField, CharField, SlugField, Q, UniqueConstraint
from django.db.models.constants import LOOKUP_SEP
from django.template.defaultfilters import slugify
from django.utils.crypto import get_random_string
from django.utils.encoding import force_str


_DEFAULT_MAX_UNIQUE_QUERY_ATTEMPTS = 100

# Default cap on the number of candidate values AutoSlugField / RandomCharField
# try before giving up; projects may override it with
# settings.EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS.
MAX_UNIQUE_QUERY_ATTEMPTS = getattr(
    settings,
    'EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS',
    _DEFAULT_MAX_UNIQUE_QUERY_ATTEMPTS,
)


class UniqueFieldMixin:
    """
    Mixin for model fields that generate a value which must not collide with
    values already stored in the database.

    Intended to be combined with a Django ``Field`` subclass: the methods
    below read ``self.attname`` and the field's ``unique`` flag.
    """

    def check_is_bool(self, attrname):
        """Raise ``ValueError`` unless the attribute *attrname* holds a ``bool``."""
        if not isinstance(getattr(self, attrname), bool):
            raise ValueError("'{}' argument must be True or False".format(attrname))

    @staticmethod
    def _get_fields(model_cls):
        """
        Return ``(field, declaring_model)`` pairs for *model_cls*'s fields.

        ``declaring_model`` is ``None`` when the field's ``model`` is
        *model_cls* itself, and that other model otherwise (e.g. a field
        inherited from a concrete parent). Only non-relational fields,
        one-to-one relations and many-to-one relations that have a
        ``related_model`` are included.
        """
        return [
            (f, f.model if f.model != model_cls else None)
            for f in model_cls._meta.get_fields()
            if not f.is_relation or f.one_to_one or (f.many_to_one and f.related_model)
        ]

    def get_queryset(self, model_cls, slug_field):
        """
        Return the queryset that candidate values are checked against.

        When *slug_field* is declared on another model (see
        :py:meth:`_get_fields`), that model's default manager is used;
        otherwise *model_cls*'s default manager.
        """
        for field, model in self._get_fields(model_cls):
            if model and field == slug_field:
                return model._default_manager.all()
        return model_cls._default_manager.all()

    def find_unique(self, model_instance, field, iterator, *args):
        """
        Assign and return the first acceptable value drawn from *iterator*.

        A candidate is rejected if it is empty or if an existing row (other
        than *model_instance* itself) would clash with it:

        * if *field* is ``unique``, or the field appears in no
          ``unique_together`` entry / ``UniqueConstraint``, the value must be
          unique across the whole table;
        * otherwise each such group defines a scope made of the other fields'
          current values on *model_instance*, and the value clashes if it is
          already used within *any* of those scopes.

        Whatever *iterator* raises when exhausted (the generators in this
        module raise ``RuntimeError``) propagates to the caller.
        ``*args`` is accepted for compatibility and unused.
        """
        queryset = self.get_queryset(model_instance.__class__, field)
        # Exclude the instance itself so re-saving it does not clash with
        # its own stored value.
        if model_instance.pk:
            queryset = queryset.exclude(pk=model_instance.pk)

        scope = self._uniqueness_scope(model_instance, field)

        new = next(iterator)
        # .exists() issues a cheap EXISTS query instead of fetching rows.
        while not new or queryset.filter(scope, **{self.attname: new}).exists():
            new = next(iterator)
        setattr(model_instance, self.attname, new)
        return new

    def _uniqueness_scope(self, model_instance, field):
        """
        Return a ``Q`` restricting the clash check to rows sharing a
        uniqueness scope with *model_instance*; an empty ``Q`` means the
        whole table.
        """
        meta = model_instance._meta
        groups = [params for params in meta.unique_together if self.attname in params]
        # ``constraints`` / UniqueConstraint exist on Django 2.2+ only.
        # Constraints without ``fields`` (expression-based) never match; a
        # constraint's ``condition`` is ignored, which only makes the check
        # stricter than the database requires.
        groups.extend(
            constraint.fields
            for constraint in getattr(meta, 'constraints', None) or ()
            if isinstance(constraint, UniqueConstraint) and self.attname in constraint.fields
        )

        if getattr(field, 'unique', False) or not groups:
            return Q()

        scope = None
        for group in groups:
            condition = {
                name: getattr(model_instance, name, None)
                for name in group
                if name != self.attname
            }
            if not condition:
                # A group made of this field alone means table-wide uniqueness.
                return Q()
            # Clashing within *any* scope violates a constraint, hence OR.
            scope = Q(**condition) if scope is None else scope | Q(**condition)
        return scope


class AutoSlugField(UniqueFieldMixin, SlugField):
    """
    AutoSlugField

    By default, sets editable=False, blank=True.

    Required arguments:

    populate_from
        Specifies which field, list of fields, or model method
        the slug will be populated from.

        populate_from can traverse a ForeignKey relationship
        by using Django ORM syntax:
            populate_from = 'related_model__field'

    Optional arguments:

    separator
        Defines the used separator (default: '-')

    overwrite
        If set to True, overwrites the slug on every save (default: False)

    overwrite_on_add
        If set to True, a slug is generated when the instance is first saved,
        even if one was already assigned (default: True)

    allow_duplicates
        If set to True, the generated slug is used as-is without checking the
        database for an identical slug (default: False)

    max_unique_query_attempts
        Exclusive upper bound of the numeric suffix tried while searching for
        a unique slug; ``RuntimeError`` is raised once it is reached
        (default: ``settings.EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS`` or 100)

    slugify_function
        Defines the function which will be used to "slugify" a content
        (default: :py:func:`~django.template.defaultfilters.slugify` )

    It is possible to provide custom "slugify" function with
    the ``slugify_function`` function in a model class.

    ``slugify_function`` function in a model class takes priority over
    ``slugify_function`` given as an argument to :py:class:`~AutoSlugField`.

    Example

    .. code-block:: python

        # models.py

        from django.db import models

        from django_extensions.db.fields import AutoSlugField


        class MyModel(models.Model):
            def slugify_function(self, content):
                return content.replace('_', '-').lower()

            title = models.CharField(max_length=42)
            slug = AutoSlugField(populate_from='title')

    Inspired by SmileyChris' Unique Slugify snippet:
    https://www.djangosnippets.org/snippets/690/
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('blank', True)
        kwargs.setdefault('editable', False)

        populate_from = kwargs.pop('populate_from', None)
        if populate_from is None:
            raise ValueError("missing 'populate_from' argument")
        # Keep the value exactly as given; it is re-emitted by deconstruct().
        self._populate_from = populate_from

        if not callable(populate_from):
            if not isinstance(populate_from, (list, tuple)):
                populate_from = (populate_from, )

            if not all(isinstance(e, str) for e in populate_from):
                # Wrap in a 1-tuple: applying % to a multi-element tuple would
                # raise "not all arguments converted" instead of this message.
                raise TypeError("'populate_from' must be str or list[str] or tuple[str], found `%s`" % (populate_from, ))

        self.slugify_function = kwargs.pop('slugify_function', slugify)
        self.separator = kwargs.pop('separator', '-')
        self.overwrite = kwargs.pop('overwrite', False)
        self.check_is_bool('overwrite')
        self.overwrite_on_add = kwargs.pop('overwrite_on_add', True)
        self.check_is_bool('overwrite_on_add')
        self.allow_duplicates = kwargs.pop('allow_duplicates', False)
        self.check_is_bool('allow_duplicates')
        self.max_unique_query_attempts = kwargs.pop('max_unique_query_attempts', MAX_UNIQUE_QUERY_ATTEMPTS)
        super().__init__(*args, **kwargs)

    def _slug_strip(self, value):
        """
        Clean up a slug by removing slug separator characters that occur at
        the beginning or end of a slug.

        If an alternate separator is used, it will also replace any instances
        of the default '-' separator with the new separator.
        """
        re_sep = '(?:-|%s)' % re.escape(self.separator)
        value = re.sub('%s+' % re_sep, self.separator, value)
        return re.sub(r'^%s+|%s+$' % (re_sep, re_sep), '', value)

    @staticmethod
    def slugify_func(content, slugify_function):
        """Return ``slugify_function(content)``, or ``''`` for falsy *content*."""
        if content:
            return slugify_function(content)
        return ''

    def slug_generator(self, original_slug, start):
        """
        Yield candidate slugs: *original_slug* first, then
        ``<slug><separator><i>`` for ``i`` in
        ``range(start, max_unique_query_attempts)``, truncating the base so
        the candidate fits the slug length. Raises ``RuntimeError`` when the
        candidates are exhausted.
        """
        # create_slug() normally sets slug_len; fall back to max_length so the
        # generator also works when driven directly.
        slug_len = getattr(self, 'slug_len', self.max_length)
        yield original_slug
        for i in range(start, self.max_unique_query_attempts):
            slug = original_slug
            end = '%s%s' % (self.separator, i)
            end_len = len(end)
            if slug_len and len(slug) + end_len > slug_len:
                slug = slug[:slug_len - end_len]
                slug = self._slug_strip(slug)
            slug = '%s%s' % (slug, end)
            yield slug
        raise RuntimeError('max slug attempts for %s exceeded (%s)' % (original_slug, self.max_unique_query_attempts))

    def create_slug(self, model_instance, add):
        """
        Return the slug to store for *model_instance*, generating it if needed.

        An existing non-empty slug is returned unchanged unless ``overwrite``
        is set, or *add* is true and ``overwrite_on_add`` is set. A newly
        generated slug is also assigned to the instance.
        """
        slug = getattr(model_instance, self.attname)
        use_existing_slug = bool(slug) and not self.overwrite
        if self.overwrite_on_add and add:
            use_existing_slug = False
        if use_existing_slug:
            # Short-circuit to avoid slug generation when not required.
            return slug

        # get fields to populate from and slug field to set
        populate_from = self._populate_from
        if not isinstance(populate_from, (list, tuple)):
            populate_from = (populate_from, )

        slug_field = model_instance._meta.get_field(self.attname)
        # A ``slugify_function`` defined on the model wins over the field argument.
        slugify_function = getattr(model_instance, 'slugify_function', self.slugify_function)

        def slug_for_field(lookup_value):
            return self.slugify_func(
                self.get_slug_fields(model_instance, lookup_value),
                slugify_function=slugify_function,
            )

        slug = self.separator.join(map(slug_for_field, populate_from))
        # Numeric suffixes for duplicates start at 2 ("slug", "slug-2", ...).
        start = 2

        # strip slug depending on max_length attribute of the slug field
        # and clean-up
        self.slug_len = slug_field.max_length
        if self.slug_len:
            slug = slug[:self.slug_len]
        slug = self._slug_strip(slug)
        original_slug = slug

        if self.allow_duplicates:
            setattr(model_instance, self.attname, slug)
            return slug

        return self.find_unique(
            model_instance, slug_field, self.slug_generator(original_slug, start))

    def get_slug_fields(self, model_instance, lookup_value):
        """
        Return the raw value for one ``populate_from`` entry.

        *lookup_value* is either a callable, called with *model_instance*, or
        a ``LOOKUP_SEP``-separated attribute path followed from
        *model_instance*; a callable found at the end of the path is called
        without arguments. Callable results are converted to ``str``; plain
        attribute values are returned unchanged.
        """
        if callable(lookup_value):
            # A function has been provided
            return "%s" % lookup_value(model_instance)

        lookup_value_path = lookup_value.split(LOOKUP_SEP)
        attr = model_instance
        for elem in lookup_value_path:
            try:
                attr = getattr(attr, elem)
            except AttributeError:
                raise AttributeError(
                    "value {} in AutoSlugField's 'populate_from' argument {} returned an error - {} has no attribute {}".format(
                        elem, lookup_value, attr, elem))

        if callable(attr):
            return "%s" % attr()

        return attr

    def pre_save(self, model_instance, add):
        """Return the (possibly newly generated) slug as ``str``."""
        value = force_str(self.create_slug(model_instance, add))
        return value

    def get_internal_type(self):
        return "SlugField"

    def deconstruct(self):
        """
        Record every argument whose value differs from this field's defaults
        so that the field round-trips through migrations and ``clone()``.

        ``slugify_function`` is not recorded: arbitrary callables (e.g.
        lambdas) cannot always be serialized into migration files.
        """
        name, path, args, kwargs = super().deconstruct()
        kwargs['populate_from'] = self._populate_from
        if not self.separator == '-':
            kwargs['separator'] = self.separator
        if self.overwrite is not False:
            kwargs['overwrite'] = True
        if self.overwrite_on_add is not True:
            kwargs['overwrite_on_add'] = False
        if self.allow_duplicates is not False:
            kwargs['allow_duplicates'] = True
        if self.max_unique_query_attempts != MAX_UNIQUE_QUERY_ATTEMPTS:
            kwargs['max_unique_query_attempts'] = self.max_unique_query_attempts
        return name, path, args, kwargs


class RandomCharField(UniqueFieldMixin, CharField):
    """
    RandomCharField

    By default, sets editable=False, blank=True, unique=False.

    Required arguments:

    length
        Specifies the length of the field

    Optional arguments:

    unique
        If set to True, duplicate entries are not allowed (default: False)

    lowercase
        If set to True, lowercase the alpha characters (default: False)

    uppercase
        If set to True, uppercase the alpha characters (default: False)
        (``lowercase`` and ``uppercase`` are mutually exclusive)

    include_alpha
        If set to True, include alpha characters (default: True)

    include_digits
        If set to True, include digit characters (default: True)

    include_punctuation
        If set to True, include punctuation characters (default: False)

    keep_default
        If set to True, a non-empty value already present when the instance
        is first saved (e.g. from ``default``) is kept instead of being
        replaced by a random one (default: False)

    max_unique_query_attempts
        Maximum number of random candidates drawn when a unique value is
        required; ``RuntimeError`` is raised once exhausted
        (default: ``settings.EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS`` or 100)

    The database is only queried for collisions when the field is ``unique``
    or part of a ``unique_together`` entry or ``UniqueConstraint``.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('blank', True)
        kwargs.setdefault('editable', False)

        self.length = kwargs.pop('length', None)
        if self.length is None:
            raise ValueError("missing 'length' argument")
        kwargs['max_length'] = self.length

        self.lowercase = kwargs.pop('lowercase', False)
        self.check_is_bool('lowercase')
        self.uppercase = kwargs.pop('uppercase', False)
        self.check_is_bool('uppercase')
        if self.uppercase and self.lowercase:
            raise ValueError("the 'lowercase' and 'uppercase' arguments are mutually exclusive")
        self.include_digits = kwargs.pop('include_digits', True)
        self.check_is_bool('include_digits')
        self.include_alpha = kwargs.pop('include_alpha', True)
        self.check_is_bool('include_alpha')
        self.include_punctuation = kwargs.pop('include_punctuation', False)
        self.check_is_bool('include_punctuation')
        self.keep_default = kwargs.pop('keep_default', False)
        self.check_is_bool('keep_default')
        self.max_unique_query_attempts = kwargs.pop('max_unique_query_attempts', MAX_UNIQUE_QUERY_ATTEMPTS)

        # Set unique=False unless it's been set manually.
        if 'unique' not in kwargs:
            kwargs['unique'] = False

        super().__init__(*args, **kwargs)

    def random_char_generator(self, chars):
        """
        Yield up to ``max_unique_query_attempts`` random strings of ``length``
        characters drawn from *chars*, then raise ``RuntimeError``.
        """
        for _ in range(self.max_unique_query_attempts):
            yield get_random_string(self.length, chars)
        raise RuntimeError('max random character attempts exceeded (%s)' % self.max_unique_query_attempts)

    def in_unique_together(self, model_instance):
        """Return True if this field appears in any ``unique_together`` entry."""
        for params in model_instance._meta.unique_together:
            if self.attname in params:
                return True
        return False

    def _in_unique_constraint(self, model_instance):
        """Return True if this field appears in any ``UniqueConstraint`` (Django 2.2+)."""
        constraints = getattr(model_instance._meta, 'constraints', None) or ()
        return any(
            isinstance(constraint, UniqueConstraint) and self.attname in constraint.fields
            for constraint in constraints
        )

    def pre_save(self, model_instance, add):
        """
        Return the value to store, generating a random one when needed.

        The current value is kept when it is not ``''`` and either the
        instance is not being added or ``keep_default`` is set. Otherwise a
        random string is generated, checked against the database only when a
        uniqueness rule covers this field, and assigned to the instance.
        """
        current = getattr(model_instance, self.attname)
        if (not add or self.keep_default) and current != '':
            return current

        population = ''
        if self.include_alpha:
            if self.lowercase:
                population += string.ascii_lowercase
            elif self.uppercase:
                population += string.ascii_uppercase
            else:
                population += string.ascii_letters

        if self.include_digits:
            population += string.digits

        if self.include_punctuation:
            population += string.punctuation

        random_chars = self.random_char_generator(population)
        needs_unique_check = (
            self.unique
            or self.in_unique_together(model_instance)
            or self._in_unique_constraint(model_instance)
        )
        if not needs_unique_check:
            new = next(random_chars)
            setattr(model_instance, self.attname, new)
            return new

        return self.find_unique(
            model_instance,
            model_instance._meta.get_field(self.attname),
            random_chars,
        )

    def get_internal_type(self):
        return "CharField"

    def internal_type(self):
        """Backward-compatible alias of :py:meth:`get_internal_type`."""
        return self.get_internal_type()

    def deconstruct(self):
        """Record every argument that differs from this field's defaults."""
        name, path, args, kwargs = super().deconstruct()
        kwargs['length'] = self.length
        # max_length is derived from length in __init__.
        del kwargs['max_length']
        if self.lowercase is True:
            kwargs['lowercase'] = self.lowercase
        if self.uppercase is True:
            kwargs['uppercase'] = self.uppercase
        if self.include_alpha is False:
            kwargs['include_alpha'] = self.include_alpha
        if self.include_digits is False:
            kwargs['include_digits'] = self.include_digits
        if self.include_punctuation is True:
            kwargs['include_punctuation'] = self.include_punctuation
        if self.keep_default is True:
            kwargs['keep_default'] = self.keep_default
        if self.max_unique_query_attempts != MAX_UNIQUE_QUERY_ATTEMPTS:
            kwargs['max_unique_query_attempts'] = self.max_unique_query_attempts
        if self.unique is True:
            kwargs['unique'] = self.unique
        return name, path, args, kwargs


class CreationDateTimeField(DateTimeField):
    """
    CreationDateTimeField

    By default, sets editable=False, blank=True, auto_now_add=True
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('editable', False)
        kwargs.setdefault('blank', True)
        kwargs.setdefault('auto_now_add', True)
        DateTimeField.__init__(self, *args, **kwargs)

    def get_internal_type(self):
        return "DateTimeField"

    def deconstruct(self):
        """
        Record the arguments whose values differ from *this* class's
        defaults, so the field round-trips through migrations and
        ``clone()``.
        """
        name, path, args, kwargs = super().deconstruct()
        if self.editable is not False:
            kwargs['editable'] = True
        if self.blank is not True:
            kwargs['blank'] = False
        # The base deconstruct() does not record auto_now_add=False (it is
        # DateTimeField's own default), but __init__ above defaults it to
        # True, so an explicit False must be written out or it is lost.
        kwargs['auto_now_add'] = bool(self.auto_now_add)
        return name, path, args, kwargs


class ModificationDateTimeField(CreationDateTimeField):
    """
    ModificationDateTimeField

    By default, sets editable=False, blank=True, auto_now=True
    (editable/blank follow from Django's handling of ``auto_now``).

    Sets value to now every time the object is saved, unless the model
    instance has an ``update_modified`` attribute with a falsy value, in
    which case the current value is kept.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('auto_now', True)
        # Deliberately bypasses CreationDateTimeField.__init__ so that
        # auto_now_add is not defaulted to True alongside auto_now.
        DateTimeField.__init__(self, *args, **kwargs)

    def get_internal_type(self):
        return "DateTimeField"

    def deconstruct(self):
        """
        Extend the parent deconstruction with ``auto_now``.

        ``__init__`` defaults ``auto_now`` to True, so an explicit False has
        to be recorded for the field to round-trip through migrations and
        ``clone()``.
        """
        name, path, args, kwargs = super().deconstruct()
        kwargs['auto_now'] = bool(self.auto_now)
        return name, path, args, kwargs

    def pre_save(self, model_instance, add):
        """Skip the ``auto_now`` update when ``model_instance.update_modified`` is falsy."""
        if not getattr(model_instance, 'update_modified', True):
            return getattr(model_instance, self.attname)
        return super().pre_save(model_instance, add)


class UUIDVersionError(Exception):
    """Raised when a UUID field is asked to generate an unsupported or invalid UUID version."""


class UUIDFieldMixin:
    """
    UUIDFieldMixin

    By default uses UUID version 4 (randomly generated UUID).

    The field supports all UUID versions natively supported by the ``uuid``
    Python module, except version 2.
    For more information see: https://docs.python.org/3/library/uuid.html
    """

    DEFAULT_MAX_LENGTH = 36

    def __init__(self, verbose_name=None, name=None, auto=True, version=4,
                 node=None, clock_seq=None, namespace=None, uuid_name=None, *args,
                 **kwargs):
        if not HAS_UUID:
            raise ImproperlyConfigured("'uuid' module is required for UUIDField. (Do you have Python 2.5 or higher installed ?)")

        kwargs.setdefault('max_length', self.DEFAULT_MAX_LENGTH)

        if auto:
            # Auto-generated values are never entered by hand.
            self.empty_strings_allowed = False
            kwargs['blank'] = True
            kwargs.setdefault('editable', False)

        self.auto = auto
        self.version = version
        self.node = node
        self.clock_seq = clock_seq
        self.namespace = namespace
        # Name hashed by versions 3 and 5; falls back to the ``name`` argument.
        self.uuid_name = uuid_name or name

        super().__init__(verbose_name=verbose_name, *args, **kwargs)

    def create_uuid(self):
        """
        Return a new ``uuid.UUID`` for the configured ``version``.

        Version 4 (or a falsy version) is random; version 1 uses ``node`` and
        ``clock_seq``; versions 3 and 5 hash ``uuid_name`` within
        ``namespace``. Version 2 and unknown versions raise
        :py:class:`UUIDVersionError`.
        """
        if not self.version or self.version == 4:
            return uuid.uuid4()
        elif self.version == 1:
            return uuid.uuid1(self.node, self.clock_seq)
        elif self.version == 2:
            raise UUIDVersionError("UUID version 2 is not supported.")
        elif self.version == 3:
            return uuid.uuid3(self.namespace, self.uuid_name)
        elif self.version == 5:
            return uuid.uuid5(self.namespace, self.uuid_name)
        else:
            raise UUIDVersionError("UUID version %s is not valid." % self.version)

    def pre_save(self, model_instance, add):
        """
        Return the value to store, generating (and assigning) a UUID string
        when ``auto`` is set and the current value is empty or ``None``.
        """
        value = super().pre_save(model_instance, add)
        # ``None`` is falsy, so this also covers the "new instance" case.
        if self.auto and not value:
            value = force_str(self.create_uuid())
            setattr(model_instance, self.attname, value)
        return value

    def formfield(self, **kwargs):
        """Return no form field for auto-generated values."""
        if self.auto:
            return None
        return super().formfield(**kwargs)

    def deconstruct(self):
        """Record every UUID option that differs from its default."""
        name, path, args, kwargs = super().deconstruct()

        if kwargs.get('max_length', None) == self.DEFAULT_MAX_LENGTH:
            del kwargs['max_length']
        if self.auto is not True:
            kwargs['auto'] = self.auto
        if self.version != 4:
            kwargs['version'] = self.version
        if self.node is not None:
            kwargs['node'] = self.node
        if self.clock_seq is not None:
            kwargs['clock_seq'] = self.clock_seq
        if self.namespace is not None:
            kwargs['namespace'] = self.namespace
        if self.uuid_name is not None:
            # Previously wrote self.name (the model attribute name), which
            # silently replaced a custom uuid_name in migrations.
            kwargs['uuid_name'] = self.uuid_name

        return name, path, args, kwargs


class ShortUUIDField(UUIDFieldMixin, CharField):
    """
    ShortUUIDField

    Generates concise (22 characters instead of 36), unambiguous, URL-safe UUIDs.

    Based on `shortuuid`: https://github.com/stochastic-technologies/shortuuid
    """

    DEFAULT_MAX_LENGTH = 22

    def __init__(self, *args, **kwargs):
        # Fail fast, before any field setup, when the optional dependency is
        # missing. max_length already defaults to DEFAULT_MAX_LENGTH inside
        # UUIDFieldMixin.__init__ (the old setdefault after super() was dead).
        if not HAS_SHORT_UUID:
            raise ImproperlyConfigured("'shortuuid' module is required for ShortUUIDField. (Do you have the 'shortuuid' package installed?)")
        super().__init__(*args, **kwargs)

    def create_uuid(self):
        """
        Return a new short UUID string.

        Version 4 (or a falsy version) and version 1 both return a random
        ``shortuuid.uuid()``; version 5 passes ``self.namespace`` to
        ``shortuuid.uuid`` as its ``name`` argument. Versions 2, 3 and any
        other value raise :py:class:`UUIDVersionError`.
        """
        if not self.version or self.version == 4:
            return shortuuid.uuid()
        elif self.version == 1:
            return shortuuid.uuid()
        elif self.version == 2:
            raise UUIDVersionError("UUID version 2 is not supported.")
        elif self.version == 3:
            raise UUIDVersionError("UUID version 3 is not supported.")
        elif self.version == 5:
            return shortuuid.uuid(name=self.namespace)
        else:
            raise UUIDVersionError("UUID version %s is not valid." % self.version)
