/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.wordstore;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.HuffmanNode;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * String-keyed vocabulary of {@link VocabularyWord}s used while building a word model.
 *
 * <p>Words live in a {@link ConcurrentHashMap} keyed by their label. After
 * {@link #updateHuffmanCodes()} has run, every word carries a {@link HuffmanNode}. It can then
 * also be looked up by that node's index through {@link #getVocabularyWordByIdx(Integer)}.
 *
 * <p>When {@code hugeModelExpected} is set, a "scavenger" periodically drops rare words while
 * the vocabulary is being built (see {@link #activateScavenger()}).
 */
public class VocabularyHolder implements Serializable {
    /** All known words, keyed by label. */
    private final Map<String, VocabularyWord> vocabulary = new ConcurrentHashMap<String, VocabularyWord>();
    /** Huffman index -> word; rebuilt by updateHuffmanCodes() and restored in readObject(). */
    private transient Map<Integer, VocabularyWord> idxMap = new ConcurrentHashMap<Integer, VocabularyWord>();
    private int minWordFrequency = 0;
    private boolean hugeModelExpected = false;
    /** Number of scavenger rounds a rare word's count history is tracked for. */
    private int retentionDelay = 3;
    private VocabCache vocabCache;
    /** The scavenger runs once every this-many newly added words (huge-model mode only). */
    private int scavengerThreshold = 2000000;
    /** Lazily computed and cached by totalWordsBeyondLimit(). */
    private long totalWordOccurrences = 0L;
    private transient AtomicLong hiddenWordsCounter = new AtomicLong(0L);
    private AtomicInteger totalWordCount = new AtomicInteger(0);
    private Logger logger = LoggerFactory.getLogger(VocabularyHolder.class);
    /** Capacity of the code/point buffers stored in each HuffmanNode. */
    private static final int MAX_CODE_LENGTH = 40;

    protected VocabularyHolder() {
    }

    /**
     * Creates a holder pre-populated from the tokens of an existing cache.
     *
     * <p>Each token's label and element frequency are copied. If the token already has Huffman
     * points, its code is copied too. When more than one word was transferred, Huffman codes are
     * then recomputed from the copied counts.
     *
     * @param cache         source cache; must not be null
     * @param markAsSpecial value passed to {@code setSpecial} for every transferred word
     * @throws NullPointerException if {@code cache} is null
     */
    protected VocabularyHolder(@NonNull VocabCache<? extends SequenceElement> cache, boolean markAsSpecial) {
        if (cache == null) {
            throw new NullPointerException("cache is marked @NonNull but is null");
        }
        this.vocabCache = cache;
        for (SequenceElement element : cache.tokens()) {
            VocabularyWord vw = new VocabularyWord(element.getLabel());
            vw.setCount((int) element.getElementFrequency());
            vw.setSpecial(markAsSpecial);
            if (element.getPoints() != null && !element.getPoints().isEmpty()) {
                vw.setHuffmanNode(buildNode(element.getCodes(), element.getPoints(), element.getCodeLength(), element.getIndex()));
            }
            this.vocabulary.put(vw.getWord(), vw);
        }
        if (this.numWords() > 1) {
            this.updateHuffmanCodes();
        }
        this.logger.info("Init from VocabCache is complete. {} word(s) were transferred.", this.numWords());
    }

    /**
     * Restores the transient state after deserialization.
     *
     * <p>Field initializers do not run for deserialized objects, so without this the transient
     * fields would be null. {@code idxMap} is rebuilt from the words' Huffman nodes.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.idxMap = new ConcurrentHashMap<Integer, VocabularyWord>();
        this.hiddenWordsCounter = new AtomicLong(0L);
        for (VocabularyWord word : this.vocabulary.values()) {
            if (word.getHuffmanNode() != null) {
                this.idxMap.put(word.getHuffmanNode().getIdx(), word);
            }
        }
    }

    /**
     * Builds a {@link HuffmanNode} from list-based code and point data.
     *
     * @param codes   code bits, copied into a buffer of at least {@code MAX_CODE_LENGTH} bytes
     * @param points  inner-node indices
     * @param codeLen code length stored in the node (narrowed to byte)
     * @param index   word index stored in the node
     * @return the new node
     */
    public static HuffmanNode buildNode(List<Byte> codes, List<Integer> points, int codeLen, int index) {
        return new HuffmanNode(listToArray(codes), listToArray(points, MAX_CODE_LENGTH), index, (byte) codeLen);
    }

    /** Copies this holder's words into the cache it was created from and empties the holder. */
    public void transferBackToVocabCache() {
        this.transferBackToVocabCache(this.vocabCache, true);
    }

    /** Copies this holder's words into {@code cache} and empties the holder. */
    public void transferBackToVocabCache(VocabCache cache) {
        this.transferBackToVocabCache(cache, true);
    }

    /**
     * Copies every non-empty word into {@code cache} as a {@link VocabWord}.
     *
     * <p>Historical gradients, Huffman index/codes/points and counts are carried over.
     *
     * @param cache       destination; only {@link InMemoryLookupCache} is supported
     * @param emptyHolder when true, this holder's word and index maps are cleared afterwards
     * @throws IllegalStateException if {@code cache} is not an {@link InMemoryLookupCache}
     */
    public void transferBackToVocabCache(VocabCache cache, boolean emptyHolder) {
        if (!(cache instanceof InMemoryLookupCache)) {
            throw new IllegalStateException("Sorry, only InMemoryLookupCache use implemented.");
        }
        InMemoryLookupCache lookupCache = (InMemoryLookupCache) cache;
        for (VocabularyWord word : this.words()) {
            if (word.getWord().isEmpty()) {
                continue;
            }
            VocabWord vocabWord = new VocabWord(1.0, word.getWord());
            if (word.getHistoricalGradient() != null) {
                INDArray gradient = Nd4j.create((double[]) word.getHistoricalGradient());
                vocabWord.setHistoricalGradient(gradient);
            }
            lookupCache.getVocabs().put(word.getWord(), vocabWord);
            lookupCache.getTokens().put(word.getWord(), vocabWord);
            HuffmanNode node = word.getHuffmanNode();
            if (node != null) {
                vocabWord.setIndex(node.getIdx());
                vocabWord.setCodeLength(node.getLength());
                vocabWord.setPoints(arrayToList(node.getPoint(), (int) node.getLength()));
                vocabWord.setCodes(arrayToList(node.getCode(), (int) node.getLength()));
                cache.addWordToIndex(node.getIdx(), word.getWord());
            }
            // the VocabWord starts with frequency 1.0, so only the remainder is added
            if (word.getCount() > 1) {
                cache.incrementWordCount(word.getWord(), word.getCount() - 1);
            }
        }
        if (emptyHolder) {
            this.idxMap.clear();
            this.vocabulary.clear();
        }
    }

    protected void setScavengerActivationThreshold(int threshold) {
        this.scavengerThreshold = threshold;
    }

    /** Returns the first {@code codeLen} entries of {@code array} as a list. */
    public static List<Byte> arrayToList(byte[] array, int codeLen) {
        List<Byte> result = new ArrayList<Byte>(codeLen);
        for (int x = 0; x < codeLen; ++x) {
            result.add(array[x]);
        }
        return result;
    }

    /**
     * Copies {@code code} into a zero-padded byte array.
     *
     * <p>The array has at least {@code MAX_CODE_LENGTH} bytes. It grows to fit longer input
     * instead of throwing {@link ArrayIndexOutOfBoundsException}.
     */
    public static byte[] listToArray(List<Byte> code) {
        byte[] array = new byte[Math.max(MAX_CODE_LENGTH, code.size())];
        for (int x = 0; x < code.size(); ++x) {
            array[x] = code.get(x);
        }
        return array;
    }

    /**
     * Copies {@code points} into an int array of exactly {@code points.size()} elements.
     *
     * <p>NOTE(review): {@code codeLen} is currently ignored; kept for API compatibility.
     */
    public static int[] listToArray(List<Integer> points, int codeLen) {
        int[] array = new int[points.size()];
        for (int x = 0; x < points.size(); ++x) {
            array[x] = points.get(x);
        }
        return array;
    }

    /** Returns the first {@code codeLen} entries of {@code array} as a list. */
    public static List<Integer> arrayToList(int[] array, int codeLen) {
        List<Integer> result = new ArrayList<Integer>(codeLen);
        for (int x = 0; x < codeLen; ++x) {
            result.add(array[x]);
        }
        return result;
    }

    /** Returns a live view of all words (backed by the internal map). */
    public Collection<VocabularyWord> getVocabulary() {
        return this.vocabulary.values();
    }

    public VocabularyWord getVocabularyWordByString(String word) {
        return this.vocabulary.get(word);
    }

    /** Looks up a word by Huffman index; only populated after {@link #updateHuffmanCodes()}. */
    public VocabularyWord getVocabularyWordByIdx(Integer id) {
        return this.idxMap.get(id);
    }

    public boolean containsWord(String word) {
        return this.vocabulary.containsKey(word);
    }

    /** Increments the count of {@code word} by one; unknown words are ignored. */
    public void incrementWordCounter(String word) {
        VocabularyWord vw = this.vocabulary.get(word);
        if (vw != null) {
            vw.incrementCount();
        }
    }

    /**
     * Adds {@code word} if it is not already present; existing words are left untouched.
     *
     * <p>In huge-model mode with {@code minWordFrequency > 1}, every {@code scavengerThreshold}-th
     * newly added word triggers {@link #activateScavenger()}.
     */
    public void addWord(String word) {
        if (this.vocabulary.containsKey(word)) {
            return;
        }
        VocabularyWord vw = new VocabularyWord(word);
        if (this.hugeModelExpected) {
            vw.setFrequencyShift(new byte[this.retentionDelay]);
        }
        // another thread may have inserted the same word since the containsKey check
        if (this.vocabulary.putIfAbsent(word, vw) != null) {
            return;
        }
        // threshold <= 0 would otherwise throw ArithmeticException on the modulo
        if (this.hugeModelExpected && this.minWordFrequency > 1 && this.scavengerThreshold > 0
                && this.hiddenWordsCounter.incrementAndGet() % (long) this.scavengerThreshold == 0L) {
            this.activateScavenger();
        }
    }

    /** Stores {@code word} under its label, replacing any existing entry. */
    public void addWord(VocabularyWord word) {
        this.vocabulary.put(word.getWord(), word);
    }

    /**
     * Merges another holder's words into this one.
     *
     * <p>Words unknown here are added (the same instances are shared). For words already
     * present, the other holder's count is added to this holder's count. The source holder is
     * not modified.
     *
     * @param holder the vocabulary to merge in; merging a holder into itself is a no-op
     */
    public void consumeVocabulary(VocabularyHolder holder) {
        if (holder == this) {
            return;
        }
        for (VocabularyWord word : holder.getVocabulary()) {
            VocabularyWord existing = this.vocabulary.get(word.getWord());
            if (existing == null) {
                this.addWord(word);
            } else if (existing != word) {
                // an instance shared by an earlier merge already carries the count
                existing.setCount(existing.getCount() + word.getCount());
            }
        }
    }

    /**
     * Drops rare words whose count has stayed low across consecutive scavenger rounds.
     *
     * <p>Only words that are not special, have a count below {@code minWordFrequency}, and have
     * a frequency-shift history are considered. Any other word has its history discarded. A
     * candidate is removed when all of the following hold:
     * <ul>
     *   <li>its current count is at most {@code max(minWordFrequency / 5, 2)};</li>
     *   <li>its newest recorded count is positive and at most that activation value;</li>
     *   <li>its newest recorded count equals its oldest recorded count.</li>
     * </ul>
     */
    protected synchronized void activateScavenger() {
        int initialSize = this.vocabulary.size();
        int activation = Math.max(this.minWordFrequency / 5, 2);
        int last = this.retentionDelay - 1;
        List<VocabularyWord> snapshot = new ArrayList<VocabularyWord>(this.vocabulary.values());
        for (VocabularyWord word : snapshot) {
            if (word.isSpecial() || word.getCount() >= this.minWordFrequency || word.getFrequencyShift() == null) {
                word.setFrequencyShift(null);
                continue;
            }
            word.getFrequencyShift()[word.getRetentionStep()] = (byte) word.getCount();
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Current state> Activation: [{}], retention info: {}", activation, Arrays.toString(word.getFrequencyShift()));
            }
            if (word.getCount() <= activation
                    && word.getFrequencyShift()[last] > 0
                    && word.getFrequencyShift()[last] <= activation
                    && word.getFrequencyShift()[last] == word.getFrequencyShift()[0]) {
                this.vocabulary.remove(word.getWord());
                // keep the index map in sync, as truncateVocabulary does
                if (word.getHuffmanNode() != null) {
                    this.idxMap.remove(word.getHuffmanNode().getIdx());
                }
            }
            if (word.getRetentionStep() < last) {
                word.incrementRetentionStep();
                continue;
            }
            // history window is full: shift left to make room for the next round
            for (int x = 1; x < this.retentionDelay; ++x) {
                word.getFrequencyShift()[x - 1] = word.getFrequencyShift()[x];
            }
        }
        this.logger.info("Scavenger was activated. Vocab size before: [{}],  after: [{}]", initialSize, this.vocabulary.size());
    }

    /** Clears Huffman nodes, frequency history and counts of every word. */
    public void resetWordCounters() {
        for (VocabularyWord word : this.getVocabulary()) {
            word.setHuffmanNode(null);
            word.setFrequencyShift(null);
            word.setCount(0);
        }
    }

    public int numWords() {
        return this.vocabulary.size();
    }

    /** Removes non-special words whose count is below {@code minWordFrequency}. */
    public void truncateVocabulary() {
        this.truncateVocabulary(this.minWordFrequency);
    }

    /**
     * Removes every non-special word whose count is below {@code threshold}.
     *
     * <p>Removed words are also dropped from the index map.
     */
    public void truncateVocabulary(int threshold) {
        this.logger.debug("Truncating vocabulary to minWordFrequency: [{}]", threshold);
        for (Map.Entry<String, VocabularyWord> entry : this.vocabulary.entrySet()) {
            VocabularyWord vw = entry.getValue();
            if (vw == null || vw.isSpecial() || vw.getCount() >= threshold) {
                continue;
            }
            // ConcurrentHashMap iterators tolerate removal during iteration
            this.vocabulary.remove(entry.getKey());
            if (vw.getHuffmanNode() != null) {
                this.idxMap.remove(vw.getHuffmanNode().getIdx());
            }
        }
    }

    /**
     * Builds a Huffman tree over the word counts and assigns every word its code, points and
     * index.
     *
     * <p>Words are indexed by descending count (see {@link #words()}). The index map is rebuilt
     * afterwards.
     *
     * @return the words in index order, each with a fresh {@link HuffmanNode}
     */
    public List<VocabularyWord> updateHuffmanCodes() {
        List<VocabularyWord> vocab = this.words();
        int size = vocab.size();
        if (size == 1) {
            // A lone word has no inner nodes; the general walk below would never reach the root.
            vocab.get(0).setHuffmanNode(new HuffmanNode(new byte[MAX_CODE_LENGTH], new int[MAX_CODE_LENGTH], 0, (byte) 0));
        } else if (size > 1) {
            buildHuffmanTree(vocab);
        }
        this.idxMap.clear();
        for (VocabularyWord word : vocab) {
            this.idxMap.put(word.getHuffmanNode().getIdx(), word);
        }
        return vocab;
    }

    /** Assigns Huffman nodes to {@code vocab} (size >= 2, sorted by descending count). */
    private static void buildHuffmanTree(List<VocabularyWord> vocab) {
        int size = vocab.size();
        // long: inner-node sums can exceed Integer.MAX_VALUE on large corpora
        long[] count = new long[size * 2 + 1];
        int[] parentNode = new int[size * 2 + 1];
        byte[] binary = new byte[size * 2 + 1];
        for (int a = 0; a < size; ++a) {
            count[a] = vocab.get(a).getCount();
        }
        for (int a = size; a < size * 2; ++a) {
            count[a] = Long.MAX_VALUE;
        }
        // leaves are consumed from the tail (pos1, smallest counts), inner nodes from pos2 upward
        int pos1 = size - 1;
        int pos2 = size;
        for (int a = 0; a < size - 1; ++a) {
            int min1i = pos1 >= 0 && count[pos1] < count[pos2] ? pos1-- : pos2++;
            int min2i = pos1 >= 0 && count[pos1] < count[pos2] ? pos1-- : pos2++;
            count[size + a] = count[min1i] + count[min2i];
            parentNode[min1i] = size + a;
            parentNode[min2i] = size + a;
            binary[min2i] = 1;
        }
        int root = size * 2 - 2;
        byte[] code = new byte[MAX_CODE_LENGTH];
        int[] point = new int[MAX_CODE_LENGTH];
        for (int a = 0; a < size; ++a) {
            // walk leaf -> root, recording the branch bit and the node visited at each step
            int b = a;
            int depth = 0;
            do {
                code[depth] = binary[b];
                point[depth] = b;
                ++depth;
                b = parentNode[b];
            } while (b != root);
            // reverse into root -> leaf order; points become inner-node offsets (root = size - 2)
            byte[] lcode = new byte[MAX_CODE_LENGTH];
            int[] lpoint = new int[MAX_CODE_LENGTH];
            lpoint[0] = size - 2;
            for (int j = 0; j < depth; ++j) {
                lcode[depth - j - 1] = code[j];
                lpoint[depth - j] = point[j] - size;
            }
            vocab.get(a).setHuffmanNode(new HuffmanNode(lcode, lpoint, a, (byte) depth));
        }
    }

    /**
     * Returns the Huffman index of {@code word}.
     *
     * @return the index, or -1 if the word is unknown or has no Huffman node yet
     */
    public int indexOf(String word) {
        VocabularyWord vw = this.vocabulary.get(word);
        if (vw == null || vw.getHuffmanNode() == null) {
            return -1;
        }
        return vw.getHuffmanNode().getIdx();
    }

    /** Returns a new list of all words sorted by descending count. */
    public List<VocabularyWord> words() {
        List<VocabularyWord> vocab = new ArrayList<VocabularyWord>(this.vocabulary.values());
        Collections.sort(vocab, new Comparator<VocabularyWord>() {
            @Override
            public int compare(VocabularyWord o1, VocabularyWord o2) {
                return Integer.compare(o2.getCount(), o1.getCount());
            }
        });
        return vocab;
    }

    /**
     * Returns the sum of all word counts.
     *
     * <p>The sum is computed on the first call that finds the cached value at zero and is
     * cached afterwards; later vocabulary changes are not reflected.
     */
    public long totalWordsBeyondLimit() {
        if (this.totalWordOccurrences == 0L) {
            for (VocabularyWord word : this.vocabulary.values()) {
                this.totalWordOccurrences += (long) word.getCount();
            }
        }
        return this.totalWordOccurrences;
    }

    /** Fluent builder for {@link VocabularyHolder}. */
    public static class Builder {
        private VocabCache cache = null;
        private int minWordFrequency = 0;
        private boolean hugeModelExpected = false;
        private int scavengerThreshold = 2000000;
        private int retentionDelay = 3;

        /** Pre-populates the holder from {@code cache}; its words are marked special. */
        public Builder externalCache(@NonNull VocabCache cache) {
            if (cache == null) {
                throw new NullPointerException("cache is marked @NonNull but is null");
            }
            this.cache = cache;
            return this;
        }

        public Builder minWordFrequency(int threshold) {
            this.minWordFrequency = threshold;
            return this;
        }

        public Builder hugeModelExpected(boolean reallyExpected) {
            this.hugeModelExpected = reallyExpected;
            return this;
        }

        public Builder scavengerActivationThreshold(int threshold) {
            this.scavengerThreshold = threshold;
            return this;
        }

        /** @throws IllegalStateException if {@code delay < 2} */
        public Builder scavengerRetentionDelay(int delay) {
            if (delay < 2) {
                throw new IllegalStateException("Delay < 2 doesn't really makes sense");
            }
            this.retentionDelay = delay;
            return this;
        }

        public VocabularyHolder build() {
            VocabularyHolder holder = this.cache != null ? new VocabularyHolder(this.cache, true) : new VocabularyHolder();
            holder.minWordFrequency = this.minWordFrequency;
            holder.hugeModelExpected = this.hugeModelExpected;
            holder.scavengerThreshold = this.scavengerThreshold;
            holder.retentionDelay = this.retentionDelay;
            return holder;
        }
    }
}

