import re
import numpy as np


# A blank (possibly whitespace-only) line marks a paragraph boundary.
NEWLINE_WHITESPACE_RE = re.compile(r'\n\s*\n')
# Matches numeric tokens such as "1,234" or "3.14", optionally with a
# trailing separator (e.g. a sentence-final period).
# this was (r'^([\d]+[,\.]*)+$')
# but the runtime on that can explode exponentially
# for example, on 111111111111111111111111a
NUMERIC_RE = re.compile(r'^[\d]+([,\.]+[\d]+)*[,\.]*$')
# Any single whitespace character; used to normalize them all to ' '.
WHITESPACE_RE = re.compile(r'\s')

def filter_consecutive_whitespaces(para):
    """Collapse runs of spaces in a paragraph of (char, label) pairs.

    A space unit is dropped whenever the unit immediately before it in the
    *original* sequence was also a space, so every run of spaces keeps only
    its first (char, label) pair.
    """
    kept = []
    prev_char = None
    for char, label in para:
        if not (char == ' ' and prev_char == ' '):
            kept.append((char, label))
        prev_char = char
    return kept

class TokenizationDataset:
    """
    Paragraphs of (character, label) pairs for tokenizer training/evaluation.

    The raw text is split into paragraphs on blank lines, each character is
    paired with an integer label (initialized to 0 here), special whitespace
    characters are normalized to ' ', and consecutive spaces are collapsed.
    """
    def __init__(self, config, text=None, vocab=None, dictionary=None, *args, **kwargs):
        """
        config: dict-like settings; keys read by this class include
            'skip_newline', 'feat_funcs', 'use_dictionary', 'num_dict_feat'
        text: raw text to split into paragraphs (the split below requires a str)
        vocab: object mapping character units to ids via unit2id()
        dictionary: dict with "words", "prefixes", "suffixes" collections;
            only consulted when dictionary features are enabled
        """
        super().__init__(*args, **kwargs)  # forwards all unused arguments
        self.args = config
        self.dictionary = dictionary
        self.vocab = vocab

        # split on blank lines, strip trailing whitespace, drop empty paragraphs
        text_chunks = NEWLINE_WHITESPACE_RE.split(text)
        text_chunks = [pt.rstrip() for pt in text_chunks]
        text_chunks = [pt for pt in text_chunks if pt]
        # every character starts with label 0
        labels = [[0 for _ in pt] for pt in text_chunks]

        skip_newline = self.args.get('skip_newline', False)
        self.data = [[(WHITESPACE_RE.sub(' ', char), label) # substitute special whitespaces
                      for char, label in zip(pt, pc) if not (skip_newline and char == '\n')] # check if newline needs to be eaten
                     for pt, pc in zip(text_chunks, labels)]

        # remove consecutive whitespaces
        self.data = [filter_consecutive_whitespaces(x) for x in self.data]

    def labels(self):
        """
        Returns a list of the labels for all of the sentences in this DataLoader

        One numpy array of int labels per paragraph in self.data.
        Used at eval time to compare to the results, for example
        """
        return [np.array(list(x[1] for x in sent)) for sent in self.data]

    def extract_dict_feat(self, para, idx):
        """
        This function is to extract dictionary features for each character

        para: list of (char, label) pairs for one paragraph
        idx: position of the character to featurize

        Returns a list of 2 * num_dict_feat binary features: for each window
        size w in 1..num_dict_feat, whether the w+1 characters ending
        (forward) or starting (backward) around idx form a dictionary word.
        """
        length = len(para)

        dict_forward_feats = [0 for i in range(self.args['num_dict_feat'])]
        dict_backward_feats = [0 for i in range(self.args['num_dict_feat'])]
        # NOTE(review): the center character is not lower-cased here, while the
        # neighboring characters appended below are -- confirm this is intended
        forward_word = para[idx][0]
        backward_word = para[idx][0]
        prefix = True
        suffix = True
        for window in range(1,self.args['num_dict_feat']+1):
            # concatenate each character and check if words found in dict not, stop if prefix not found
            #check if idx+t is out of bound and if the prefix is already not found
            if (idx + window) <= length-1 and prefix:
                forward_word += para[idx+window][0].lower()
                #check in json file if the word is present as prefix or word or None.
                feat = 1 if forward_word in self.dictionary["words"] else 0
                #if the return value is not 2 or 3 then the checking word is not a valid word in dict.
                dict_forward_feats[window-1] = feat
                #if the dict return 0 means no prefixes found, thus, stop looking for forward.
                if forward_word not in self.dictionary["prefixes"]:
                    prefix = False
            #backward check: similar to forward
            if (idx - window) >= 0 and suffix:
                backward_word = para[idx-window][0].lower() + backward_word
                feat = 1 if backward_word in self.dictionary["words"] else 0
                dict_backward_feats[window-1] = feat
                if backward_word not in self.dictionary["suffixes"]:
                    suffix = False
            #if cannot find both prefix and suffix, then exit the loop
            if not prefix and not suffix:
                break

        return dict_forward_feats + dict_backward_feats

    def para_to_sentences(self, para):
        """ Convert a paragraph to a list of processed sentences.

        Each processed sentence is a tuple of
        (unit id array, label array, feature array, list of raw units).
        With the current implementation the whole paragraph is emitted as a
        single sentence, so the returned list has at most one element.
        """
        res = []
        funcs = []
        # build one feature function per configured position-independent feature
        for feat_func in self.args['feat_funcs']:
            if feat_func == 'end_of_para' or feat_func == 'start_of_para':
                # skip for position-dependent features
                continue
            if feat_func == 'space_before':
                func = lambda x: 1 if x.startswith(' ') else 0
            elif feat_func == 'capitalized':
                func = lambda x: 1 if x[0].isupper() else 0
            elif feat_func == 'numeric':
                func = lambda x: 1 if (NUMERIC_RE.match(x) is not None) else 0
            else:
                raise ValueError('Feature function "{}" is undefined.'.format(feat_func))

            funcs.append(func)

        # stacking all featurize functions
        composite_func = lambda x: [f(x) for f in funcs]

        def process_sentence(sent_units, sent_labels, sent_feats):
            # convert units to vocab ids; keep the raw units for later output
            return (np.array([self.vocab.unit2id(y) for y in sent_units]),
                    np.array(sent_labels),
                    np.array(sent_feats),
                    list(sent_units))

        use_end_of_para = 'end_of_para' in self.args['feat_funcs']
        use_start_of_para = 'start_of_para' in self.args['feat_funcs']
        use_dictionary = self.args['use_dictionary']
        current_units = []
        current_labels = []
        current_feats = []
        for i, (unit, label) in enumerate(para):
            feats = composite_func(unit)
            # position-dependent features
            if use_end_of_para:
                f = 1 if i == len(para)-1 else 0
                feats.append(f)
            if use_start_of_para:
                f = 1 if i == 0 else 0
                feats.append(f)

            #if dictionary feature is selected
            if use_dictionary:
                dict_feats = self.extract_dict_feat(para, i)
                feats = feats + dict_feats

            current_units.append(unit)
            current_labels.append(label)
            current_feats.append(feats)

        if len(current_units) > 0:
            res.append(process_sentence(current_units, current_labels, current_feats))

        return res

    def advance_old_batch(self, eval_offsets, old_batch):
        """
        Advance to a new position in a batch where we have partially processed the batch

        If we have previously built a batch of data and made predictions on them, then when we are trying to make
        prediction on later characters in those paragraphs, we can avoid rebuilding the converted data from scratch
        and just (essentially) advance the indices/offsets from where we read converted data in this old batch.
        In this case, eval_offsets index within the old_batch to advance the strings to process.

        eval_offsets: per-row offsets into old_batch; NOTE this list is
            clamped in place (each offset is capped at the row's true length)
        old_batch: (units, labels, features, raw) as produced by collation

        Returns a new (units, labels, features, raw_units) batch shifted left
        by the offsets and re-padded to a common length.
        """
        unkid = self.vocab.unit2id('<UNK>')
        padid = self.vocab.unit2id('<PAD>')

        ounits, olabels, ofeatures, oraw = old_batch
        feat_size = ofeatures.shape[-1]
        # true (unpadded) length of each row
        lens = (ounits != padid).sum(1).tolist()
        pad_len = max(l-i for i, l in zip(eval_offsets, lens))

        units = np.full((len(ounits), pad_len), padid, dtype=np.int64)
        labels = np.full((len(ounits), pad_len), -1, dtype=np.int32)
        features = np.zeros((len(ounits), pad_len, feat_size), dtype=np.float32)
        raw_units = []

        for i in range(len(ounits)):
            # clamp in place so an offset never passes the end of its row
            eval_offsets[i] = min(eval_offsets[i], lens[i])
            units[i, :(lens[i] - eval_offsets[i])] = ounits[i, eval_offsets[i]:lens[i]]
            labels[i, :(lens[i] - eval_offsets[i])] = olabels[i, eval_offsets[i]:lens[i]]
            features[i, :(lens[i] - eval_offsets[i])] = ofeatures[i, eval_offsets[i]:lens[i]]
            raw_units.append(oraw[i][eval_offsets[i]:lens[i]] + ['<PAD>'] * (pad_len - lens[i] + eval_offsets[i]))

        return units, labels, features, raw_units



def sort_with_indices(data, key=None, reverse=False):
    """
    Sort data and return both the data and the original indices.

    One useful application is to sort by length, which can be done with key=len
    Returns the data as a sorted tuple, then a tuple of the indices each
    element occupied in the original list.  Empty input yields two empty lists.
    """
    if not data:
        return [], []
    if key:
        sort_key = lambda pair: key(pair[1])
    else:
        sort_key = lambda pair: pair[1]
    ranked = sorted(enumerate(data), key=sort_key, reverse=reverse)
    indices, ordered = zip(*ranked)
    return ordered, indices

def unsort(sorted_list, oidx):
    """
    Unsort a sorted list, based on the original idx.

    oidx[i] is the position element i of sorted_list held in the original
    ordering; the result is a list restored to that ordering.
    """
    assert len(sorted_list) == len(oidx), "Number of list elements must match with original indices."
    if len(sorted_list) == 0:
        return []
    return [item for _, item in sorted(zip(oidx, sorted_list))]


class SortedDataset:
    """
    Length-sorted wrapper exposing a TokenizationDataset to a torch DataLoader.

    The torch DataLoader is a different concept from the DataLoader defined in
    this module, and it provides cpu & gpu parallelism.  Using this wrapper in
    output_predictions lets the feature computation for each paragraph happen
    in parallel, which saves quite a bit of time.
    """
    def __init__(self, dataset):
        self.dataset = dataset
        # paragraphs sorted by length; indices remember the original order
        self.data, self.indices = sort_with_indices(self.dataset.data, key=len)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # featurization happens here, so it can run in DataLoader workers
        return self.dataset.para_to_sentences(self.data[index])

    def unsort(self, arr):
        # restore results to the dataset's original paragraph order
        return unsort(arr, self.indices)

    def collate(self, samples):
        """Pad a list of single-sentence samples into one rectangular batch."""
        if any(len(sample) > 1 for sample in samples):
            raise ValueError("Expected all paragraphs to have no preset sentence splits!")
        n_feats = samples[0][0][2].shape[-1]
        pad_id = self.dataset.vocab.unit2id('<PAD>')

        # +1 so that all samples end with at least one pad
        max_len = 1 + max(len(sample[0][3]) for sample in samples)

        unit_batch = np.full((len(samples), max_len), pad_id, dtype=np.int64)
        label_batch = np.full((len(samples), max_len), -1, dtype=np.int32)
        feat_batch = np.zeros((len(samples), max_len, n_feats), dtype=np.float32)
        raw_batch = []
        for row, sample in enumerate(samples):
            sent_units, sent_labels, sent_feats, sent_raw = sample[0]
            unit_batch[row, :len(sent_units)] = sent_units
            label_batch[row, :len(sent_labels)] = sent_labels
            feat_batch[row, :len(sent_feats), :] = sent_feats
            raw_batch.append(sent_raw + ['<PAD>'] * (max_len - len(sent_raw)))

        return unit_batch, label_batch, feat_batch, raw_batch


