from collections import Counter
from cqi.models.corpora import Corpus as CQiCorpus
from cqi.models.subcorpora import Subcorpus as CQiSubcorpus
from cqi.status import StatusOk as CQiStatusOk
from flask import current_app
import gzip
import json
import math
from app import db
from app.models import Corpus
from .utils import CQiOverSocketIOSessionManager


# Names of the CQi extension functions defined in this module (every public
# ``ext_*`` function below is listed here).
# NOTE(review): presumably consulted by the dispatcher to decide which
# functions may be invoked by name - confirm against the caller.
CQI_EXTENSION_FUNCTION_NAMES: list[str] = [
    'ext_corpus_update_db',
    'ext_corpus_static_data',
    'ext_corpus_paginate_corpus',
    'ext_cqp_paginate_subcorpus',
    'ext_cqp_partial_export_subcorpus',
    'ext_cqp_export_subcorpus',
]


def ext_corpus_update_db(corpus: str) -> CQiStatusOk:
    """Store the size of a CQi corpus as token count in the database.

    The database record is the one belonging to the corpus id of the
    current CQi over Socket.IO session; its ``num_tokens`` column is set to
    the size reported for ``corpus`` and the change is committed.

    Args:
        corpus: Name of the corpus as known to the CQi client.

    Returns:
        A ``CQiStatusOk`` instance once the commit went through.
    """
    session_corpus_id = CQiOverSocketIOSessionManager.get_corpus_id()
    client = CQiOverSocketIOSessionManager.get_cqi_client()
    corpus_record = Corpus.query.get(session_corpus_id)
    corpus_record.num_tokens = client.corpora.get(corpus).size
    db.session.commit()
    return CQiStatusOk()


def ext_corpus_static_data(corpus: str) -> bytes:
    """Return the gzip-compressed JSON static data of a corpus.

    The result is cached in ``cwb/static.json.gz`` below the directory of
    the corpus that belongs to the current CQi over Socket.IO session. If
    the cache file exists its raw bytes are returned unchanged; otherwise
    the data is collected via CQi, written to the cache file and the bytes
    of the freshly written file are returned.

    The JSON document has these top-level keys:

    * ``corpus``: ``bounds`` (``[0, size - 1]``) and ``freqs`` (per
      positional attribute, the result of ``freqs_by_ids`` for all lexicon
      ids).
    * ``p_attrs``: per positional attribute, the result of ``ids_by_cpos``
      for every corpus position.
    * ``s_attrs``: one entry per structural attribute without values,
      holding a ``lexicon`` list and a ``values`` entry. The lexicon is
      filled with ``bounds`` for ``s`` and ``ent`` and with ``bounds`` and
      ``freqs`` for ``text``; it stays empty for all other attributes.
      ``values`` is only set for ``text`` (names of its sub attributes
      with the ``text_`` prefix removed) and ``None`` otherwise.
    * ``values``: ``p_attrs`` (per positional attribute, the result of
      ``values_by_ids`` for all lexicon ids) and ``s_attrs`` (for
      ``text``, one dict per text mapping sub attribute name to value).

    Args:
        corpus: Name of the corpus as known to the CQi client.

    Returns:
        The gzip-compressed, JSON-encoded static data as bytes.
    """
    # Local import: only needed to build a unique temporary file name.
    import uuid

    corpus_id = CQiOverSocketIOSessionManager.get_corpus_id()
    db_corpus = Corpus.query.get(corpus_id)

    static_data_file_path = db_corpus.path / 'cwb' / 'static.json.gz'
    if static_data_file_path.exists():
        with static_data_file_path.open('rb') as f:
            return f.read()

    cqi_client = CQiOverSocketIOSessionManager.get_cqi_client()
    cqi_corpus = cqi_client.corpora.get(corpus)
    cqi_p_attrs = cqi_corpus.positional_attributes.list()
    cqi_s_attrs = cqi_corpus.structural_attributes.list()

    static_data = {
        'corpus': {
            'bounds': [0, cqi_corpus.size - 1],
            'freqs': {}
        },
        'p_attrs': {},
        's_attrs': {},
        'values': {'p_attrs': {}, 's_attrs': {}}
    }

    # The list of all corpus positions is the same for every positional
    # attribute, so it is built once instead of once per attribute.
    cpos_list = list(range(cqi_corpus.size))
    for p_attr in cqi_p_attrs:
        # Shared by the frequency and the value request below.
        p_attr_id_list = list(range(p_attr.lexicon_size))

        current_app.logger.info(f'corpus.freqs.{p_attr.name}')
        static_data['corpus']['freqs'][p_attr.name] = list(p_attr.freqs_by_ids(p_attr_id_list))

        current_app.logger.info(f'p_attrs.{p_attr.name}')
        static_data['p_attrs'][p_attr.name] = list(p_attr.ids_by_cpos(cpos_list))

        current_app.logger.info(f'values.p_attrs.{p_attr.name}')
        static_data['values']['p_attrs'][p_attr.name] = list(p_attr.values_by_ids(p_attr_id_list))
        del p_attr_id_list
    # Release the potentially large list before more data is collected.
    del cpos_list

    for s_attr in cqi_s_attrs:
        # Attributes with values are sub attributes; they are handled
        # together with the attribute they are part of.
        if s_attr.has_values:
            continue

        lexicon = []
        static_data['s_attrs'][s_attr.name] = {'lexicon': lexicon, 'values': None}

        if s_attr.name in ['s', 'ent']:
            ##############################################################
            # A faster way to get cpos boundaries for smaller s_attrs    #
            # Note: Needs more testing, don't use it in production       #
            ##############################################################
            cqi_corpus.query('Last', f'<{s_attr.name}> []* </{s_attr.name}>;')
            cqi_subcorpus = cqi_corpus.subcorpora.get('Last')
            first_match = 0
            last_match = cqi_subcorpus.size - 1
            # Both dumps are evaluated right here, i.e. before the
            # subcorpus is dropped; only the zipping itself is lazy.
            match_boundaries = zip(
                range(first_match, last_match + 1),
                cqi_subcorpus.dump(
                    cqi_subcorpus.fields['match'],
                    first_match,
                    last_match
                ),
                cqi_subcorpus.dump(
                    cqi_subcorpus.fields['matchend'],
                    first_match,
                    last_match
                )
            )
            cqi_subcorpus.drop()
            del cqi_subcorpus, first_match, last_match
            for s_attr_id, lbound, rbound in match_boundaries:
                current_app.logger.info(f's_attrs.{s_attr.name}.lexicon.{s_attr_id}.bounds')
                lexicon.append({'bounds': [lbound, rbound]})
            del match_boundaries

        if s_attr.name != 'text':
            continue

        for s_attr_id in range(s_attr.size):
            # This is a very slow operation, that's why we only use it for
            # the text attribute
            lbound, rbound = s_attr.cpos_by_id(s_attr_id)
            current_app.logger.info(f's_attrs.{s_attr.name}.lexicon.{s_attr_id}.bounds')
            freqs = {}
            text_cpos_list = list(range(lbound, rbound + 1))
            for p_attr in cqi_p_attrs:
                p_attr_ids = p_attr.ids_by_cpos(text_cpos_list)
                current_app.logger.info(f's_attrs.{s_attr.name}.lexicon.{s_attr_id}.freqs.{p_attr.name}')
                # Number of occurrences of each lexicon id within the text
                freqs[p_attr.name] = dict(Counter(p_attr_ids))
                del p_attr_ids
            del text_cpos_list
            lexicon.append({'bounds': [lbound, rbound], 'freqs': freqs})

        sub_s_attrs = cqi_corpus.structural_attributes.list(filters={'part_of': s_attr})
        current_app.logger.info(f's_attrs.{s_attr.name}.values')
        # Sub attribute names are prefixed with '<attribute name>_'; the
        # prefix is cut off to get the bare value name.
        value_names = [
            sub_s_attr.name[(len(s_attr.name) + 1):]
            for sub_s_attr in sub_s_attrs
        ]
        static_data['s_attrs'][s_attr.name]['values'] = value_names
        s_attr_id_list = list(range(s_attr.size))
        # One list of values per sub attribute, in the order of value_names
        sub_s_attr_values = [
            list(sub_s_attr.values_by_ids(s_attr_id_list))
            for sub_s_attr in sub_s_attrs
        ]
        del s_attr_id_list
        current_app.logger.info(f'values.s_attrs.{s_attr.name}')
        static_data['values']['s_attrs'][s_attr.name] = [
            {
                value_name: sub_s_attr_values[value_name_idx][s_attr_id]
                for value_name_idx, value_name in enumerate(value_names)
            } for s_attr_id in range(s_attr.size)
        ]
        del sub_s_attr_values

    current_app.logger.info('Saving static data to file')
    # Write to a uniquely named temporary file in the same directory and
    # move it into place afterwards. This way a failed or interrupted dump
    # never leaves a truncated cache file behind that later calls would
    # return as is.
    tmp_file_path = static_data_file_path.with_name(
        f'{static_data_file_path.name}.{uuid.uuid4().hex}.tmp'
    )
    try:
        with gzip.open(tmp_file_path, 'wt') as f:
            json.dump(static_data, f)
        tmp_file_path.replace(static_data_file_path)
    except BaseException:
        # Clean up the partial file, then let the error propagate.
        tmp_file_path.unlink(missing_ok=True)
        raise
    del static_data
    current_app.logger.info('Sending static data to client')
    with open(static_data_file_path, 'rb') as f:
        return f.read()


def ext_corpus_paginate_corpus(
    corpus: str,
    page: int = 1,
    per_page: int = 20
) -> dict:
    """Return one page of corpus positions together with their lookups.

    Args:
        corpus: Name of the corpus as known to the CQi client.
        page: Number of the requested page (1 indexed).
        per_page: Number of corpus positions per page.

    Returns:
        ``{'code': 416, 'msg': 'Range Not Satisfiable'}`` if ``page`` or
        ``per_page`` is smaller than 1 or ``page`` lies behind the last
        page of a non-empty corpus. Otherwise a dict with the keys
        ``items``, ``lookups``, ``total``, ``per_page``, ``pages``,
        ``page``, ``has_prev``, ``has_next``, ``prev_num`` and
        ``next_num``.
    """
    cqi_corpus = CQiOverSocketIOSessionManager.get_cqi_client().corpora.get(corpus)
    total = cqi_corpus.size

    # Sanity checks; per_page is validated first so that the division
    # below can never see a zero divisor.
    if per_page < 1 or page < 1:
        return {'code': 416, 'msg': 'Range Not Satisfiable'}
    if total > 0 and page > math.ceil(total / per_page):
        return {'code': 416, 'msg': 'Range Not Satisfiable'}

    start_cpos = (page - 1) * per_page
    stop_cpos = min(total, start_cpos + per_page)
    cpos_list = list(range(start_cpos, stop_cpos))
    lookups = _lookups_by_cpos(cqi_corpus, cpos_list)

    page_count = math.ceil(total / per_page)
    # An empty corpus has no pages and therefore no current page
    current_page = page if page_count > 0 else None
    has_prev = current_page > 1 if current_page else False
    has_next = current_page < page_count if current_page else False

    return {
        # the items for the current page
        'items': [cpos_list],
        # the lookups for the items
        'lookups': lookups,
        # the total number of items
        'total': total,
        # the number of items to be displayed on a page
        'per_page': per_page,
        # the total number of pages
        'pages': page_count,
        # the current page number (1 indexed)
        'page': current_page,
        # True if a previous page exists
        'has_prev': has_prev,
        # True if a next page exists
        'has_next': has_next,
        # number of the previous page
        'prev_num': current_page - 1 if has_prev else None,
        # number of the next page
        'next_num': current_page + 1 if has_next else None
    }


def ext_cqp_paginate_subcorpus(
    subcorpus: str,
    context: int = 50,
    page: int = 1,
    per_page: int = 20
) -> dict:
    """Return one page of subcorpus matches together with their lookups.

    Args:
        subcorpus: Subcorpus specifier of the form
            ``<corpus name>:<subcorpus name>``.
        context: Number of corpus positions of context exported to the
            left and to the right of every match.
        page: Number of the requested page (1 indexed).
        per_page: Number of matches per page.

    Returns:
        ``{'code': 416, 'msg': 'Range Not Satisfiable'}`` if ``page`` or
        ``per_page`` is smaller than 1 or ``page`` lies behind the last
        page of a non-empty subcorpus. Otherwise a dict with the keys
        ``items``, ``lookups``, ``total``, ``per_page``, ``pages``,
        ``page``, ``has_prev``, ``has_next``, ``prev_num`` and
        ``next_num``.
    """
    corpus_name, subcorpus_name = subcorpus.split(':', 1)
    cqi_corpus = CQiOverSocketIOSessionManager.get_cqi_client().corpora.get(corpus_name)
    cqi_subcorpus = cqi_corpus.subcorpora.get(subcorpus_name)
    total = cqi_subcorpus.size

    # Sanity checks; per_page is validated first so that the division
    # below can never see a zero divisor.
    if per_page < 1 or page < 1:
        return {'code': 416, 'msg': 'Range Not Satisfiable'}
    if total > 0 and page > math.ceil(total / per_page):
        return {'code': 416, 'msg': 'Range Not Satisfiable'}

    export = _export_subcorpus(
        cqi_subcorpus,
        context=context,
        cutoff=per_page,
        offset=(page - 1) * per_page
    )
    # Everything that remains in the export after removing the matches
    # are the lookups.
    items = export.pop('matches')

    page_count = math.ceil(total / per_page)
    # An empty subcorpus has no pages and therefore no current page
    current_page = page if page_count > 0 else None
    has_prev = current_page > 1 if current_page else False
    has_next = current_page < page_count if current_page else False

    return {
        # the items for the current page
        'items': items,
        # the lookups for the items
        'lookups': export,
        # the total number of items matching the query
        'total': total,
        # the number of items to be displayed on a page
        'per_page': per_page,
        # the total number of pages
        'pages': page_count,
        # the current page number (1 indexed)
        'page': current_page,
        # True if a previous page exists
        'has_prev': has_prev,
        # True if a next page exists
        'has_next': has_next,
        # number of the previous page
        'prev_num': current_page - 1 if has_prev else None,
        # number of the next page
        'next_num': current_page + 1 if has_next else None
    }


def ext_cqp_partial_export_subcorpus(
    subcorpus: str,
    match_id_list: list,
    context: int = 50
) -> dict:
    """Export selected matches of a subcorpus.

    Args:
        subcorpus: Subcorpus specifier of the form
            ``<corpus name>:<subcorpus name>``.
        match_id_list: Ids of the matches to export.
        context: Number of corpus positions of context exported to the
            left and to the right of every match.

    Returns:
        The result of ``_partial_export_subcorpus`` for the resolved
        subcorpus.
    """
    corpus_name, subcorpus_name = subcorpus.split(':', 1)
    client = CQiOverSocketIOSessionManager.get_cqi_client()
    target = client.corpora.get(corpus_name).subcorpora.get(subcorpus_name)
    return _partial_export_subcorpus(target, match_id_list, context=context)


def ext_cqp_export_subcorpus(subcorpus: str, context: int = 50) -> dict:
    """Export all matches of a subcorpus.

    Args:
        subcorpus: Subcorpus specifier of the form
            ``<corpus name>:<subcorpus name>``.
        context: Number of corpus positions of context exported to the
            left and to the right of every match.

    Returns:
        The result of ``_export_subcorpus`` for the resolved subcorpus.
    """
    corpus_name, subcorpus_name = subcorpus.split(':', 1)
    client = CQiOverSocketIOSessionManager.get_cqi_client()
    target = client.corpora.get(corpus_name).subcorpora.get(subcorpus_name)
    return _export_subcorpus(target, context=context)


def _lookups_by_cpos(corpus: CQiCorpus, cpos_list: list[int]) -> dict:
    """Build lookup tables for a list of corpus positions.

    Args:
        corpus: The corpus the positions belong to.
        cpos_list: The corpus positions to build the lookups for.

    Returns:
        A dict with the key ``cpos_lookup``, which maps every corpus
        position to a dict holding the value of each positional attribute
        and the id of each structural attribute (without values) that
        reports an id other than -1 for the position. For every such
        structural attribute that has at least one id in the result and
        at least one sub attribute, an additional ``<name>_lookup`` entry
        maps each of these ids to its sub attribute values.
    """
    cpos_lookup = {cpos: {} for cpos in cpos_list}
    lookups = {'cpos_lookup': cpos_lookup}

    for p_attr in corpus.positional_attributes.list():
        p_attr_values = p_attr.values_by_cpos(cpos_list)
        for cpos, p_attr_value in zip(cpos_list, p_attr_values):
            cpos_lookup[cpos][p_attr.name] = p_attr_value

    for s_attr in corpus.structural_attributes.list():
        # Sub attributes (has_values is True) are not handled on their
        # own but together with the attribute they are part of.
        if s_attr.has_values:
            continue
        s_attr_ids = s_attr.ids_by_cpos(cpos_list)
        for cpos, s_attr_id in zip(cpos_list, s_attr_ids):
            # -1 is skipped: no id is recorded for this position
            if s_attr_id != -1:
                cpos_lookup[cpos][s_attr.name] = s_attr_id
        seen_ids = [s_attr_id for s_attr_id in set(s_attr_ids) if s_attr_id != -1]
        if not seen_ids:
            continue
        sub_attrs = corpus.structural_attributes.list(filters={'part_of': s_attr})
        if not sub_attrs:
            continue
        id_lookup = {s_attr_id: {} for s_attr_id in seen_ids}
        lookups[f'{s_attr.name}_lookup'] = id_lookup
        # Sub attribute names are prefixed with '<attribute name>_'
        prefix_length = len(s_attr.name) + 1
        for sub_attr in sub_attrs:
            value_name = sub_attr.name[prefix_length:]
            sub_attr_values = sub_attr.values_by_ids(seen_ids)
            for s_attr_id, sub_attr_value in zip(seen_ids, sub_attr_values):
                id_lookup[s_attr_id][value_name] = sub_attr_value
    return lookups


def _partial_export_subcorpus(
    subcorpus: CQiSubcorpus,
    match_id_list: list[int],
    context: int = 25
) -> dict:
    """Export selected matches of a subcorpus with context and lookups.

    Args:
        subcorpus: The subcorpus to export from.
        match_id_list: Ids of the matches to export. Ids that are negative
            or not smaller than the subcorpus size are skipped.
        context: Maximum number of corpus positions exported to the left
            and to the right of every match. Negative values are treated
            as 0 (no context).

    Returns:
        ``{'matches': []}`` for an empty subcorpus. Otherwise a dict with
        the key ``matches`` plus all entries of ``_lookups_by_cpos`` for
        the exported corpus positions. Every match is a dict with the keys
        ``num`` (match id), ``c`` (``(match start, match end)``), ``lc``
        and ``rc`` (``(first cpos, last cpos)`` of the left/right context
        or ``None`` if there is none).
    """
    subcorpus_size = subcorpus.size
    if subcorpus_size == 0:
        return {'matches': []}
    # A negative context would yield inverted context ranges and could
    # even drop the match itself from the lookups.
    context = max(0, context)
    corpus = subcorpus.collection.corpus
    last_cpos = corpus.size - 1
    match_field = subcorpus.fields['match']
    matchend_field = subcorpus.fields['matchend']
    cpos_set = set()
    matches = []
    for match_id in match_id_list:
        if match_id < 0 or match_id >= subcorpus_size:
            continue
        match_start = subcorpus.dump(match_field, match_id, match_id)[0]
        match_end = subcorpus.dump(matchend_field, match_id, match_id)[0]
        if match_start == 0 or context == 0:
            # Nothing in front of the match (or no context requested)
            lc = None
            cpos_lbound = match_start
        else:
            cpos_lbound = max(0, match_start - context)
            lc = (cpos_lbound, match_start - 1)
        if match_end == last_cpos or context == 0:
            # Nothing behind the match (or no context requested)
            rc = None
            cpos_rbound = match_end
        else:
            cpos_rbound = min(match_end + context, last_cpos)
            rc = (match_end + 1, cpos_rbound)
        matches.append(
            {'num': match_id, 'lc': lc, 'c': (match_start, match_end), 'rc': rc}
        )
        cpos_set.update(range(cpos_lbound, cpos_rbound + 1))
    lookups = _lookups_by_cpos(corpus, list(cpos_set))
    return {'matches': matches, **lookups}


def _export_subcorpus(
    subcorpus: CQiSubcorpus,
    context: int = 25,
    cutoff: float = float('inf'),
    offset: int = 0
) -> dict:
    """Export a range of matches of a subcorpus with context and lookups.

    Args:
        subcorpus: The subcorpus to export from.
        context: Maximum number of corpus positions exported to the left
            and to the right of every match. Negative values are treated
            as 0 (no context).
        cutoff: Maximum number of matches to export; unlimited by default.
        offset: Id of the first match to export. Negative values are
            treated as 0.

    Returns:
        ``{'matches': []}`` if the subcorpus is empty or the requested
        range contains no match. Otherwise a dict with the key ``matches``
        plus all entries of ``_lookups_by_cpos`` for the exported corpus
        positions. Every match is a dict with the keys ``num`` (match id),
        ``c`` (``(match start, match end)``), ``lc`` and ``rc``
        (``(first cpos, last cpos)`` of the left/right context or ``None``
        if there is none).
    """
    if subcorpus.size == 0:
        return {'matches': []}
    # A negative context would yield inverted context ranges and could
    # even drop the match itself from the lookups.
    context = max(0, context)
    first_match = max(0, offset)
    # Based on the clamped offset, so that a negative offset does not
    # shrink the exported window.
    last_match = min((first_match + cutoff - 1), (subcorpus.size - 1))
    if first_match > last_match:
        # Offset behind the last match or non-positive cutoff: do not
        # request an inverted range from the server.
        return {'matches': []}
    match_boundaries = zip(
        range(first_match, last_match + 1),
        subcorpus.dump(subcorpus.fields['match'], first_match, last_match),
        subcorpus.dump(subcorpus.fields['matchend'], first_match, last_match)
    )
    corpus = subcorpus.collection.corpus
    last_cpos = corpus.size - 1
    cpos_set = set()
    matches = []
    for match_num, match_start, match_end in match_boundaries:
        if match_start == 0 or context == 0:
            # Nothing in front of the match (or no context requested)
            lc = None
            cpos_lbound = match_start
        else:
            cpos_lbound = max(0, match_start - context)
            lc = (cpos_lbound, match_start - 1)
        if match_end == last_cpos or context == 0:
            # Nothing behind the match (or no context requested)
            rc = None
            cpos_rbound = match_end
        else:
            cpos_rbound = min(match_end + context, last_cpos)
            rc = (match_end + 1, cpos_rbound)
        matches.append(
            {'num': match_num, 'lc': lc, 'c': (match_start, match_end), 'rc': rc}
        )
        cpos_set.update(range(cpos_lbound, cpos_rbound + 1))
    lookups = _lookups_by_cpos(corpus, list(cpos_set))
    return {'matches': matches, **lookups}