/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.MarginalProbEstimator;
import cc.mallet.topics.TopicAssignment;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.topics.WorkerRunnable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.Randoms;
import gnu.trove.TObjectIntHashMap;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPOutputStream;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class ParallelTopicModel
implements Serializable {
    /** Topic assignments for every document added through addInstances (one per instance). */
    protected ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    /** Word-type vocabulary, taken from the training InstanceList's data alphabet. */
    protected Alphabet alphabet;
    /** Labels for the topics; its size determines numTopics. */
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    /**
     * Packing scheme for typeTopicCounts entries: each non-zero entry holds
     * (count << topicBits) + topic, so the topic is (entry & topicMask) and the
     * count is (entry >> topicBits).
     */
    protected int topicMask;
    protected int topicBits;
    /** Vocabulary size, as read from the alphabet in addInstances. */
    protected int numTypes;
    /** Total number of tokens across all documents in data. */
    protected int totalTokens;
    /** Per-topic Dirichlet parameters; alphaSum is initialised to their total and later reset by optimizeAlpha. */
    protected double[] alpha;
    protected double alphaSum;
    /** Symmetric topic-word smoothing parameter; betaSum is kept as beta * numTypes. */
    protected double beta;
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    /** Per word type, packed topic/count entries (see topicMask); a zero entry ends the used prefix. */
    protected int[][] typeTopicCounts;
    /** Number of tokens currently assigned to each topic. */
    protected int[] tokensPerTopic;
    /**
     * Histograms handed to Dirichlet.learnParameters by optimizeAlpha, indexed by count
     * (sized maxTokens + 1 in initializeHistograms).
     */
    protected int[] docLengthCounts;
    protected int[][] topicDocCounts;
    public int numIterations = 1000;
    public int burninPeriod = 200;
    /** After burn-in, alpha statistics are collected every saveSampleInterval iterations. */
    public int saveSampleInterval = 10;
    /** After burn-in, alpha and beta are re-optimised every optimizeInterval iterations (0 disables). */
    public int optimizeInterval = 50;
    /** Top words are printed every showTopicsInterval iterations (0 disables), wordsPerTopic words each. */
    public int showTopicsInterval = 50;
    public int wordsPerTopic = 7;
    /** Sampling state is written to stateFilename + "." + iteration every saveStateInterval iterations (0 disables). */
    protected int saveStateInterval = 0;
    protected String stateFilename = null;
    /** The model is written to modelFilename + "." + iteration every saveModelInterval iterations (0 disables). */
    protected int saveModelInterval = 0;
    protected String modelFilename = null;
    /** Seed for Randoms; -1 means use an unseeded generator. */
    protected int randomSeed = -1;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = true;
    /** Number of occurrences of each word type in the training data. */
    int[] typeTotals;
    /** Largest value in typeTotals. */
    int maxTypeCount;
    int numThreads = 1;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with the given number of topics, a total alpha equal to the
     * number of topics, and the default beta.
     *
     * @param numberOfTopics number of topics to learn
     */
    public ParallelTopicModel(int numberOfTopics) {
        this(numberOfTopics, (double) numberOfTopics, DEFAULT_BETA);
    }

    /**
     * Creates a model whose topics are labelled "topic0", "topic1", ...
     *
     * @param numberOfTopics number of topics to learn
     * @param alphaSum total of the per-topic alpha values (split evenly across topics)
     * @param beta topic-word smoothing parameter
     */
    public ParallelTopicModel(int numberOfTopics, double alphaSum, double beta) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, beta);
    }

    /**
     * Builds a label alphabet holding "topic0" through "topic(numTopics-1)", in order.
     *
     * @param numTopics number of labels to create
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet labels = new LabelAlphabet();
        int topic = 0;
        while (topic < numTopics) {
            labels.lookupIndex("topic" + topic);
            topic++;
        }
        return labels;
    }

    public ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        }
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.beta = beta;
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("Coded LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /** Returns the model's word-type alphabet. */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** Returns the alphabet of topic labels. */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** Returns the number of topics fixed at construction. */
    public int getNumTopics() {
        return numTopics;
    }

    /** Returns the live (not copied) list of per-document topic assignments. */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /** Sets how many sampling iterations estimate() runs. */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /** Sets the number of iterations before alpha/beta statistics and optimisation begin. */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures periodic printing of top words during estimate().
     *
     * @param interval print every this many iterations (0 disables)
     * @param n number of words to print per topic
     */
    public void setTopicDisplay(int interval, int n) {
        showTopicsInterval = interval;
        wordsPerTopic = n;
    }

    /** Sets the random seed; -1 selects an unseeded generator. */
    public void setRandomSeed(int seed) {
        randomSeed = seed;
    }

    /**
     * Sets how often (after burn-in) alpha and beta are re-optimised; 0 disables.
     * The sampling interval is clamped so that statistics are collected at least
     * as often as they are used.
     */
    public void setOptimizeInterval(int interval) {
        optimizeInterval = interval;
        saveSampleInterval = Math.min(saveSampleInterval, optimizeInterval);
    }

    /** Sets the number of worker threads used by estimate(). */
    public void setNumThreads(int threads) {
        numThreads = threads;
    }

    /**
     * Configures periodic state snapshots during estimate().
     *
     * @param interval write every this many iterations (0 disables)
     * @param filename base name; ".<iteration>" is appended
     */
    public void setSaveState(int interval, String filename) {
        saveStateInterval = interval;
        stateFilename = filename;
    }

    /**
     * Configures periodic serialized-model snapshots during estimate().
     *
     * @param interval write every this many iterations (0 disables)
     * @param filename base name; ".<iteration>" is appended
     */
    public void setSaveSerializedModel(int interval, String filename) {
        saveModelInterval = interval;
        modelFilename = filename;
    }

    /**
     * Adds training documents and gives every token a random initial topic.
     *
     * Sizes each word type's packed topic-count array to min(numTopics, occurrences),
     * appends one TopicAssignment per instance to data, and then rebuilds the count
     * tables and alpha histograms.
     *
     * NOTE(review): typeTopicCounts and typeTotals are rebuilt from this call's
     * {@code training} only, while data keeps accumulating. Calling this more than once
     * therefore looks unsupported; confirm before relying on it.
     *
     * @param training documents whose data are FeatureSequence objects
     * @throws ClassCastException if an instance's data is not a FeatureSequence
     */
    public void addInstances(InstanceList training) {
        this.alphabet = training.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * (double) this.numTypes;
        this.typeTopicCounts = new int[this.numTypes][];
        this.typeTotals = new int[this.numTypes];

        // First pass: count the occurrences of each word type.
        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            for (int position = 0; position < tokens.getLength(); ++position) {
                this.typeTotals[tokens.getIndexAtPosition(position)]++;
            }
        }

        // A type can carry at most min(numTopics, occurrences) distinct topics, which
        // bounds the size of its packed topic-count array.
        this.maxTypeCount = 0;
        for (int type = 0; type < this.numTypes; ++type) {
            if (this.typeTotals[type] > this.maxTypeCount) {
                this.maxTypeCount = this.typeTotals[type];
            }
            this.typeTopicCounts[type] = new int[Math.min(this.numTopics, this.typeTotals[type])];
        }

        // Second pass: draw a random initial topic for every token.
        Randoms random = this.randomSeed == -1 ? new Randoms() : new Randoms(this.randomSeed);
        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < topics.length; ++position) {
                topics[position] = random.nextInt(this.numTopics);
            }
            this.data.add(new TopicAssignment(instance, topicSequence));
        }

        this.buildInitialTypeTopicCounts();
        this.initializeHistograms();
    }

    /**
     * Rebuilds tokensPerTopic and the packed typeTopicCounts arrays from the topic
     * assignments currently stored in data.
     *
     * Each used entry of typeTopicCounts[type] is (count << topicBits) + topic. Used
     * entries form a prefix of the array, and a zero entry marks its end. After an
     * entry is incremented, it is swapped toward the front while it exceeds its
     * predecessor.
     *
     * @throws IllegalStateException if a type has more distinct topics than its packed
     *         array can hold (previously: a stdout message followed by an
     *         ArrayIndexOutOfBoundsException)
     */
    public void buildInitialTypeTopicCounts() {
        Arrays.fill(this.tokensPerTopic, 0);
        // Clear the used prefix of every packed array.
        for (int type = 0; type < this.numTypes; ++type) {
            int[] topicCounts = this.typeTopicCounts[type];
            for (int position = 0; position < topicCounts.length && topicCounts[position] > 0; ++position) {
                topicCounts[position] = 0;
            }
        }
        for (TopicAssignment document : this.data) {
            FeatureSequence tokens = (FeatureSequence) document.instance.getData();
            int[] topics = document.topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); ++position) {
                int topic = topics[position];
                this.tokensPerTopic[topic]++;
                int type = tokens.getIndexAtPosition(position);
                int[] currentTypeTopicCounts = this.typeTopicCounts[type];

                // Find this topic's existing slot, or the first empty slot.
                int index = 0;
                int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
                    if (++index == currentTypeTopicCounts.length) {
                        throw new IllegalStateException("overflow on type " + type);
                    }
                    currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                }

                int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                if (currentValue == 0) {
                    // First token of this topic for this type: append with count 1.
                    currentTypeTopicCounts[index] = (1 << this.topicBits) + topic;
                    continue;
                }
                currentTypeTopicCounts[index] = ((currentValue + 1) << this.topicBits) + topic;
                // Move the incremented entry forward past smaller predecessors.
                while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                    int temp = currentTypeTopicCounts[index];
                    currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                    currentTypeTopicCounts[index - 1] = temp;
                    --index;
                }
            }
        }
    }

    /**
     * Merges every worker's topic totals and packed type-topic counts into this
     * model's tokensPerTopic and typeTopicCounts, which are cleared first.
     *
     * Merged entries use the same packing as buildInitialTypeTopicCounts:
     * (count << topicBits) + topic. After a merge, an entry is swapped toward the
     * front while it exceeds its predecessor.
     *
     * @param runnables workers; the first numThreads entries are read
     * @throws IllegalStateException if a type's merged topics do not fit in its packed
     *         array (previously: a stdout message followed by an
     *         ArrayIndexOutOfBoundsException)
     */
    public void sumTypeTopicCounts(WorkerRunnable[] runnables) {
        Arrays.fill(this.tokensPerTopic, 0);
        for (int type = 0; type < this.numTypes; ++type) {
            int[] targetCounts = this.typeTopicCounts[type];
            for (int position = 0; position < targetCounts.length && targetCounts[position] > 0; ++position) {
                targetCounts[position] = 0;
            }
        }
        for (int thread = 0; thread < this.numThreads; ++thread) {
            int[] sourceTotals = runnables[thread].getTokensPerTopic();
            for (int topic = 0; topic < this.numTopics; ++topic) {
                this.tokensPerTopic[topic] += sourceTotals[topic];
            }
            int[][] sourceTypeTopicCounts = runnables[thread].getTypeTopicCounts();
            for (int type = 0; type < this.numTypes; ++type) {
                int[] sourceCounts = sourceTypeTopicCounts[type];
                int[] targetCounts = this.typeTopicCounts[type];
                for (int sourceIndex = 0; sourceIndex < sourceCounts.length && sourceCounts[sourceIndex] > 0; ++sourceIndex) {
                    int topic = sourceCounts[sourceIndex] & this.topicMask;
                    int count = sourceCounts[sourceIndex] >> this.topicBits;

                    // Find the topic's slot in the target, or the first empty slot.
                    int targetIndex = 0;
                    int currentTopic = targetCounts[targetIndex] & this.topicMask;
                    while (targetCounts[targetIndex] > 0 && currentTopic != topic) {
                        if (++targetIndex == targetCounts.length) {
                            throw new IllegalStateException("overflow in merging on type " + type);
                        }
                        currentTopic = targetCounts[targetIndex] & this.topicMask;
                    }

                    int currentCount = targetCounts[targetIndex] >> this.topicBits;
                    targetCounts[targetIndex] = ((currentCount + count) << this.topicBits) + topic;
                    // Move the grown entry forward past smaller predecessors.
                    while (targetIndex > 0 && targetCounts[targetIndex] > targetCounts[targetIndex - 1]) {
                        int temp = targetCounts[targetIndex];
                        targetCounts[targetIndex] = targetCounts[targetIndex - 1];
                        targetCounts[targetIndex - 1] = temp;
                        --targetIndex;
                    }
                }
            }
        }
    }

    /**
     * Computes totalTokens and allocates the alpha-optimisation histograms, sized to
     * the longest document plus one. Prints the longest length and the token total
     * to standard error.
     */
    private void initializeHistograms() {
        int longestDocument = 0;
        this.totalTokens = 0;
        for (TopicAssignment assignment : this.data) {
            int length = ((FeatureSequence) assignment.instance.getData()).getLength();
            longestDocument = Math.max(longestDocument, length);
            this.totalTokens += length;
        }
        System.err.println("max tokens: " + longestDocument);
        System.err.println("total tokens: " + this.totalTokens);
        int histogramSize = longestDocument + 1;
        this.docLengthCounts = new int[histogramSize];
        this.topicDocCounts = new int[this.numTopics][histogramSize];
    }

    /**
     * Re-estimates alpha from the document statistics gathered by the workers.
     *
     * Positive entries of each worker's document-length and per-topic histograms are
     * added into this model's (cleared) histograms and then zeroed at the source.
     * The merged histograms and alpha are passed to Dirichlet.learnParameters, and
     * its return value becomes alphaSum.
     *
     * @param runnables workers; the first numThreads entries are read
     */
    public void optimizeAlpha(WorkerRunnable[] runnables) {
        Arrays.fill(this.docLengthCounts, 0);
        for (int[] topicHistogram : this.topicDocCounts) {
            Arrays.fill(topicHistogram, 0);
        }
        for (int thread = 0; thread < this.numThreads; ++thread) {
            WorkerRunnable worker = runnables[thread];
            int[] workerLengths = worker.getDocLengthCounts();
            int[][] workerTopicCounts = worker.getTopicDocCounts();
            for (int length = 0; length < workerLengths.length; ++length) {
                if (workerLengths[length] > 0) {
                    this.docLengthCounts[length] += workerLengths[length];
                    workerLengths[length] = 0;
                }
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                int[] source = workerTopicCounts[topic];
                for (int count = 0; count < source.length; ++count) {
                    if (source[count] > 0) {
                        this.topicDocCounts[topic][count] += source[count];
                        source[count] = 0;
                    }
                }
            }
        }
        this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
    }

    /**
     * Re-estimates the symmetric beta and pushes the new value to every worker.
     *
     * Builds two histograms: how many (type, topic) pairs have each non-zero count,
     * and how many topics have each size. Both are passed to
     * Dirichlet.learnSymmetricConcentration. The new beta is printed to standard
     * output.
     *
     * @param runnables workers; the first numThreads entries receive resetBeta
     */
    public void optimizeBeta(WorkerRunnable[] runnables) {
        int[] countHistogram = new int[this.maxTypeCount + 1];
        for (int type = 0; type < this.numTypes; ++type) {
            int[] packed = this.typeTopicCounts[type];
            for (int i = 0; i < packed.length && packed[i] > 0; ++i) {
                countHistogram[packed[i] >> this.topicBits]++;
            }
        }

        int largestTopic = 0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            largestTopic = Math.max(largestTopic, this.tokensPerTopic[topic]);
        }
        int[] topicSizeHistogram = new int[largestTopic + 1];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicSizeHistogram[this.tokensPerTopic[topic]]++;
        }

        this.betaSum = Dirichlet.learnSymmetricConcentration(countHistogram, topicSizeHistogram, this.numTypes, this.betaSum);
        this.beta = this.betaSum / (double) this.numTypes;
        System.out.print("[beta: " + this.formatter.format(this.beta) + "] ");
        for (int thread = 0; thread < this.numThreads; ++thread) {
            runnables[thread].resetBeta(this.beta, this.betaSum);
        }
    }

    /**
     * Runs numIterations sampling sweeps over the documents in data.
     *
     * With numThreads &gt; 1, the documents are split into contiguous slices, and each
     * slice is swept by its own WorkerRunnable on a thread pool. After every sweep the
     * workers' counts are merged into the model (sumTypeTopicCounts) and copied back to
     * every worker. With a single thread, the worker operates directly on the model's
     * count arrays.
     *
     * Along the way this method:
     * - periodically prints top words and saves state/model snapshots;
     * - collects alpha statistics and re-optimises alpha and beta after burn-in;
     * - prints per-iteration timing, and every 10 iterations the model log-likelihood
     *   divided by totalTokens.
     *
     * The thread pool is always shut down, even if a snapshot write fails.
     *
     * @throws IOException if writing a periodic state or model snapshot fails
     * @throws RuntimeException if a sampling worker throws (multi-threaded mode)
     */
    public void estimate() throws IOException {
        long startTime = System.currentTimeMillis();
        WorkerRunnable[] runnables = this.createWorkerRunnables();
        ExecutorService executor = Executors.newFixedThreadPool(this.numThreads);
        try {
            for (int iteration = 1; iteration <= this.numIterations; ++iteration) {
                long iterationStart = System.currentTimeMillis();
                if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                    System.out.println();
                    this.printTopWords(System.out, this.wordsPerTopic, false);
                }
                if (this.saveStateInterval != 0 && iteration % this.saveStateInterval == 0) {
                    this.printState(new File(this.stateFilename + '.' + iteration));
                }
                if (this.saveModelInterval != 0 && iteration % this.saveModelInterval == 0) {
                    this.write(new File(this.modelFilename + '.' + iteration));
                }

                boolean collectAlphaStatistics = iteration > this.burninPeriod
                        && this.optimizeInterval != 0
                        && iteration % this.saveSampleInterval == 0;
                if (this.numThreads > 1) {
                    this.runParallelSweep(executor, runnables, collectAlphaStatistics);
                } else {
                    if (collectAlphaStatistics) {
                        runnables[0].collectAlphaStatistics();
                    }
                    runnables[0].run();
                }

                long elapsedMillis = System.currentTimeMillis() - iterationStart;
                if (elapsedMillis < 1000L) {
                    System.out.print(elapsedMillis + "ms ");
                } else {
                    System.out.print(elapsedMillis / 1000L + "s ");
                }
                if (iteration > this.burninPeriod && this.optimizeInterval != 0 && iteration % this.optimizeInterval == 0) {
                    this.optimizeAlpha(runnables);
                    this.optimizeBeta(runnables);
                    System.out.print("[O " + (System.currentTimeMillis() - iterationStart) + "] ");
                }
                if (iteration % 10 == 0) {
                    System.out.println("<" + iteration + "> ");
                    if (this.printLogLikelihood) {
                        System.out.println(this.modelLogLikelihood() / (double) this.totalTokens);
                    }
                }
                System.out.flush();
            }
        } finally {
            // Pool threads are non-daemon; always release them, even when an iteration throws.
            executor.shutdownNow();
        }
        printTotalTime(startTime);
    }

    /**
     * Builds one WorkerRunnable per thread.
     *
     * With several threads, each worker gets private copies of the global counts and a
     * contiguous slice of data; the last worker takes the remainder. A single worker
     * shares the model's own count arrays and is marked as the only thread.
     */
    private WorkerRunnable[] createWorkerRunnables() {
        WorkerRunnable[] runnables = new WorkerRunnable[this.numThreads];
        int docsPerThread = this.data.size() / this.numThreads;
        int offset = 0;
        if (this.numThreads > 1) {
            for (int thread = 0; thread < this.numThreads; ++thread) {
                int[] runnableTotals = new int[this.numTopics];
                System.arraycopy(this.tokensPerTopic, 0, runnableTotals, 0, this.numTopics);
                int[][] runnableCounts = new int[this.numTypes][];
                for (int type = 0; type < this.numTypes; ++type) {
                    runnableCounts[type] = this.typeTopicCounts[type].clone();
                }
                if (thread == this.numThreads - 1) {
                    docsPerThread = this.data.size() - offset;
                }
                // NOTE(review): with a fixed seed every worker gets an identically seeded generator.
                Randoms random = this.randomSeed == -1 ? new Randoms() : new Randoms(this.randomSeed);
                runnables[thread] = new WorkerRunnable(this.numTopics, this.alpha, this.alphaSum, this.beta, random,
                        this.data, runnableCounts, runnableTotals, offset, docsPerThread);
                runnables[thread].initializeAlphaStatistics(this.docLengthCounts.length);
                offset += docsPerThread;
            }
        } else {
            Randoms random = this.randomSeed == -1 ? new Randoms() : new Randoms(this.randomSeed);
            runnables[0] = new WorkerRunnable(this.numTopics, this.alpha, this.alphaSum, this.beta, random,
                    this.data, this.typeTopicCounts, this.tokensPerTopic, offset, docsPerThread);
            runnables[0].initializeAlphaStatistics(this.docLengthCounts.length);
            runnables[0].makeOnlyThread();
        }
        return runnables;
    }

    /**
     * Runs one sweep on every worker concurrently, waits for all of them to finish,
     * then merges their counts into the model and copies the merged counts back.
     */
    private void runParallelSweep(ExecutorService executor, WorkerRunnable[] runnables, boolean collectAlphaStatistics) {
        Future<?>[] pending = new Future<?>[this.numThreads];
        for (int thread = 0; thread < this.numThreads; ++thread) {
            if (collectAlphaStatistics) {
                runnables[thread].collectAlphaStatistics();
            }
            pending[thread] = executor.submit(runnables[thread]);
        }
        // Waiting on the futures (rather than sleeping and polling a flag) guarantees
        // every sweep has actually completed before the counts are read.
        for (int thread = 0; thread < this.numThreads; ++thread) {
            awaitWorker(pending[thread]);
        }
        this.sumTypeTopicCounts(runnables);
        this.copyGlobalCountsToRunnables(runnables);
    }

    /**
     * Blocks until the task completes, even if the calling thread is interrupted
     * meanwhile; in that case the interrupt status is restored before returning.
     *
     * @throws RuntimeException wrapping the worker's failure, if it threw
     */
    private static void awaitWorker(Future<?> task) {
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    task.get();
                    return;
                } catch (InterruptedException e) {
                    // Counts cannot be merged while a worker may still be sampling, so keep waiting.
                    interrupted = true;
                } catch (ExecutionException e) {
                    throw new RuntimeException("Topic sampling thread failed", e.getCause());
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Overwrites each worker's topic totals and packed type-topic counts with the
     * freshly merged global values, so the next sweep starts from consistent counts.
     */
    private void copyGlobalCountsToRunnables(WorkerRunnable[] runnables) {
        for (int thread = 0; thread < this.numThreads; ++thread) {
            int[] runnableTotals = runnables[thread].getTokensPerTopic();
            System.arraycopy(this.tokensPerTopic, 0, runnableTotals, 0, this.numTopics);
            int[][] runnableCounts = runnables[thread].getTypeTopicCounts();
            for (int type = 0; type < this.numTypes; ++type) {
                int[] targetCounts = runnableCounts[type];
                int[] sourceCounts = this.typeTopicCounts[type];
                for (int index = 0; index < sourceCounts.length; ++index) {
                    if (sourceCounts[index] != 0) {
                        targetCounts[index] = sourceCounts[index];
                    } else if (targetCounts[index] != 0) {
                        // Clear stale worker entries beyond the global array's used prefix.
                        targetCounts[index] = 0;
                    } else {
                        // Both arrays are empty from here on.
                        break;
                    }
                }
            }
        }
    }

    /** Prints the wall-clock time elapsed since startTime as days/hours/minutes/seconds. */
    private static void printTotalTime(long startTime) {
        long seconds = Math.round((double) (System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /**
     * Writes the top words of every topic to a file.
     *
     * @param file destination file (created or truncated)
     * @param numWords number of words to print per topic
     * @param useNewLines one word per line with weights if true, otherwise one line per topic
     * @throws IOException if the file cannot be opened or a write fails
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream(file);
        try {
            this.printTopWords(out, numWords, useNewLines);
            // PrintStream swallows write failures; surface them through the declared IOException.
            if (out.checkError()) {
                throw new IOException("Error writing top words to " + file);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Groups word types by topic.
     *
     * For each topic, returns a TreeSet of IDSorter(type, count) for every type with a
     * non-zero count in that topic. The set order is IDSorter's natural ordering
     * (presumably by descending weight; confirm in IDSorter).
     *
     * @return array indexed by topic
     */
    public TreeSet[] getSortedWords() {
        TreeSet[] topicSortedWords = new TreeSet[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicSortedWords[topic] = new TreeSet<IDSorter>();
        }
        for (int type = 0; type < this.numTypes; ++type) {
            int[] packed = this.typeTopicCounts[type];
            int index = 0;
            while (index < packed.length && packed[index] > 0) {
                int entry = packed[index];
                int topic = entry & this.topicMask;
                int count = entry >> this.topicBits;
                topicSortedWords[topic].add(new IDSorter(type, count));
                ++index;
            }
        }
        return topicSortedWords;
    }

    /**
     * Returns, for each topic, up to numWords word objects in getSortedWords order.
     *
     * @param numWords maximum number of words per topic
     * @return array indexed by topic; each row holds min(numWords, words in topic) entries
     */
    public Object[][] getTopWords(int numWords) {
        TreeSet[] topicSortedWords = this.getSortedWords();
        Object[][] result = new Object[this.numTopics][];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            TreeSet sortedWords = topicSortedWords[topic];
            Object[] words = new Object[Math.min(numWords, sortedWords.size())];
            Iterator iterator = sortedWords.iterator();
            for (int i = 0; i < words.length; ++i) {
                IDSorter info = (IDSorter) iterator.next();
                words[i] = this.alphabet.lookupObject(info.getID());
            }
            result[topic] = words;
        }
        return result;
    }

    /**
     * Prints the top words of every topic.
     *
     * The decompiled original read an uninitialised counter in the single-line branch,
     * and started the counter at 1, printing only numWords - 1 words. Both branches
     * now print up to numWords words, matching getTopWords.
     *
     * @param out destination stream
     * @param numWords maximum number of words per topic
     * @param usingNewLines if true, a header line per topic followed by one
     *        "word TAB weight" line per word; otherwise one line per topic:
     *        "topic TAB alpha TAB word word ..."
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        TreeSet[] topicSortedWords = this.getSortedWords();
        for (int topic = 0; topic < this.numTopics; ++topic) {
            Iterator iterator = topicSortedWords[topic].iterator();
            if (usingNewLines) {
                out.println(topic + "\t" + this.formatter.format(this.alpha[topic]));
                for (int word = 0; word < numWords && iterator.hasNext(); ++word) {
                    IDSorter info = (IDSorter) iterator.next();
                    out.println(this.alphabet.lookupObject(info.getID()) + "\t" + this.formatter.format(info.getWeight()));
                }
            } else {
                out.print(topic + "\t" + this.formatter.format(this.alpha[topic]) + "\t");
                for (int word = 0; word < numWords && iterator.hasNext(); ++word) {
                    IDSorter info = (IDSorter) iterator.next();
                    out.print(this.alphabet.lookupObject(info.getID()) + " ");
                }
                out.print("\n");
            }
        }
    }

    /**
     * Writes an XML report listing each topic's alpha, token total, and top words.
     *
     * Ranks run from 1 to numWords; the original emitted only numWords - 1 words.
     * Word text is now XML-escaped so arbitrary vocabulary yields well-formed output.
     *
     * @param out destination writer
     * @param numWords maximum number of words listed per topic
     */
    public void topicXMLReport(PrintWriter out, int numWords) {
        TreeSet[] topicSortedWords = this.getSortedWords();
        out.println("<?xml version='1.0' ?>");
        out.println("<topicModel>");
        for (int topic = 0; topic < this.numTopics; ++topic) {
            out.println("  <topic id='" + topic + "' alpha='" + this.alpha[topic] + "' totalTokens='" + this.tokensPerTopic[topic] + "'>");
            Iterator iterator = topicSortedWords[topic].iterator();
            for (int rank = 1; rank <= numWords && iterator.hasNext(); ++rank) {
                IDSorter info = (IDSorter) iterator.next();
                String wordText = String.valueOf(this.alphabet.lookupObject(info.getID()));
                out.println("    <word rank='" + rank + "'>" + escapeXmlText(wordText) + "</word>");
            }
            out.println("  </topic>");
        }
        out.println("</topicModel>");
    }

    /** Replaces the five XML special characters with entity references. */
    private static String escapeXmlText(String text) {
        StringBuilder escaped = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); ++i) {
            char c = text.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                case '"':
                    escaped.append("&quot;");
                    break;
                case '\'':
                    escaped.append("&apos;");
                    break;
                default:
                    escaped.append(c);
            }
        }
        return escaped.toString();
    }

    /**
     * Writes an XML report with, for each topic, its top words, its most frequent
     * multi-token phrases, and a "titles" attribute summarising both.
     *
     * A phrase is a run of consecutive tokens in one document that share the same
     * topic assignment. For FeatureSequenceWithBigrams data, a run is also broken
     * wherever getBiIndexAtPosition returns -1.
     *
     * NOTE(review): confirm the following are intended:
     * - word and phrase text is written without XML escaping;
     * - a phrase still open at the end of a document is never counted;
     * - the token that ends a phrase is not used to start a new one;
     * - at most numWords - 1 top words are printed per topic (the counter starts at 1).
     *
     * @param out destination for the XML
     * @param numWords bound on the number of words and phrases listed per topic
     */
    public void topicPhraseXMLReport(PrintWriter out, int numWords) {
        int numTopics = this.getNumTopics();
        // phrases[t] maps an interned phrase string to how often it occurred as a topic-t run.
        TObjectIntHashMap[] phrases = new TObjectIntHashMap[numTopics];
        Alphabet alphabet = this.getAlphabet();
        for (int ti = 0; ti < numTopics; ++ti) {
            phrases[ti] = new TObjectIntHashMap();
        }
        // Pass 1: collect same-topic token runs from every document.
        for (int di = 0; di < this.getData().size(); ++di) {
            TopicAssignment t = this.getData().get(di);
            Instance instance = t.instance;
            FeatureSequence fvs = (FeatureSequence)instance.getData();
            boolean withBigrams = false;
            if (fvs instanceof FeatureSequenceWithBigrams) {
                withBigrams = true;
            }
            int prevtopic = -1;
            int prevfeature = -1;
            int topic = -1;
            StringBuffer sb = null;
            int feature = -1;
            int doclen = fvs.size();
            for (int pi = 0; pi < doclen; ++pi) {
                feature = fvs.getIndexAtPosition(pi);
                topic = this.getData().get((int)di).topicSequence.getIndexAtPosition(pi);
                // Same topic as the remembered token, with no bigram break: extend the phrase.
                if (!(topic != prevtopic || withBigrams && ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) == -1)) {
                    if (sb == null) {
                        // Start a new phrase from the remembered token and this one.
                        sb = new StringBuffer(alphabet.lookupObject(prevfeature).toString() + " " + alphabet.lookupObject(feature));
                        continue;
                    }
                    sb.append(" ");
                    sb.append(alphabet.lookupObject(feature));
                    continue;
                }
                if (sb != null) {
                    // The run ended: count it under its topic, then forget the remembered token.
                    String sbs = sb.toString().intern();
                    // Insert the key with 0 before incrementing, presumably because Trove's
                    // increment() does not add absent keys; confirm against the Trove version in use.
                    if (phrases[prevtopic].get((Object)sbs) == 0) {
                        phrases[prevtopic].put((Object)sbs, 0);
                    }
                    phrases[prevtopic].increment((Object)sbs);
                    prevfeature = -1;
                    prevtopic = -1;
                    sb = null;
                    continue;
                }
                // No open phrase: remember this token as a possible phrase start.
                prevtopic = topic;
                prevfeature = feature;
            }
        }
        out.println("<?xml version='1.0' ?>");
        out.println("<topics>");
        TreeSet[] topicSortedWords = this.getSortedWords();
        // NOTE(review): probs is allocated but never read.
        double[] probs = new double[alphabet.size()];
        for (int ti = 0; ti < numTopics; ++ti) {
            // The opening tag is left open; the titles attribute is appended after it is computed.
            out.print("  <topic id=\"" + ti + "\" alpha=\"" + this.alpha[ti] + "\" totalTokens=\"" + this.tokensPerTopic[ti] + "\" ");
            // Child elements are buffered so they can be emitted after the titles attribute.
            ByteArrayOutputStream bout = new ByteArrayOutputStream();
            PrintStream pout = new PrintStream(bout);
            // Title candidates: top words weighted by count, frequent phrases by 100 * count.
            AugmentableFeatureVector titles = new AugmentableFeatureVector(new Alphabet());
            int word = 1;
            Iterator iterator = topicSortedWords[ti].iterator();
            while (iterator.hasNext() && word < numWords) {
                IDSorter info = (IDSorter)iterator.next();
                pout.println("    <word weight=\"" + info.getWeight() / (double)this.tokensPerTopic[ti] + "\" count=\"" + Math.round(info.getWeight()) + "\">" + alphabet.lookupObject(info.getID()) + "</word>");
                // Only the leading top words (while the counter stays below 20) become title candidates.
                if (++word >= 20) continue;
                titles.add(alphabet.lookupObject(info.getID()), info.getWeight());
            }
            // Rank this topic's phrases by frequency.
            Object[] keys = phrases[ti].keys();
            int[] values = phrases[ti].getValues();
            double[] counts = new double[keys.length];
            for (int i = 0; i < counts.length; ++i) {
                counts[i] = values[i];
            }
            double countssum = MatrixOps.sum(counts);
            // Assumes new Alphabet(keys) assigns index i to keys[i], so indices line up with counts/values; TODO confirm.
            Alphabet alph = new Alphabet(keys);
            RankedFeatureVector rfv = new RankedFeatureVector(alph, counts);
            int max = rfv.numLocations() < numWords ? rfv.numLocations() : numWords;
            for (int ri = 0; ri < max; ++ri) {
                int fi = rfv.getIndexAtRank(ri);
                pout.println("    <phrase weight=\"" + counts[fi] / countssum + "\" count=\"" + values[fi] + "\">" + alph.lookupObject(fi) + "</phrase>");
                // Phrases ranked in the top 20 that occurred more than 20 times become title candidates.
                if (ri >= 20 || values[fi] <= 20) continue;
                titles.add(alph.lookupObject(fi), (double)(100 * values[fi]));
            }
            // Join up to 10 ranked titles. Candidates whose text already appears in the buffer
            // are skipped, and each skip extends the limit by one.
            StringBuffer titlesStringBuffer = new StringBuffer();
            rfv = new RankedFeatureVector(titles.getAlphabet(), titles);
            int numTitles = 10;
            for (int ri = 0; ri < numTitles && ri < rfv.numLocations(); ++ri) {
                if (titlesStringBuffer.indexOf(rfv.getObjectAtRank(ri).toString()) == -1) {
                    titlesStringBuffer.append(rfv.getObjectAtRank(ri));
                    if (ri >= numTitles - 1) continue;
                    titlesStringBuffer.append(", ");
                    continue;
                }
                ++numTitles;
            }
            out.println("titles=\"" + titlesStringBuffer.toString() + "\">");
            out.print(bout.toString());
            out.println("  </topic>");
        }
        out.println("</topics>");
    }

    /**
     * Writes one line per word type: "type word topic:count topic:count ...",
     * listing the non-zero packed entries in stored order.
     *
     * @param file destination file (created or truncated)
     * @throws IOException if the file cannot be opened or a write fails
     */
    public void printTypeTopicCounts(File file) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            for (int type = 0; type < this.numTypes; ++type) {
                StringBuilder buffer = new StringBuilder();
                buffer.append(type).append(' ').append(this.alphabet.lookupObject(type));
                int[] topicCounts = this.typeTopicCounts[type];
                for (int index = 0; index < topicCounts.length && topicCounts[index] > 0; ++index) {
                    int topic = topicCounts[index] & this.topicMask;
                    int count = topicCounts[index] >> this.topicBits;
                    buffer.append(' ').append(topic).append(':').append(count);
                }
                out.println(buffer);
            }
            // PrintWriter swallows write failures; surface them through the declared IOException.
            if (out.checkError()) {
                throw new IOException("Error writing type-topic counts to " + file);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Writes topic-word weights to {@code file}, delegating to
     * {@link #printTopicWordWeights(PrintWriter)}.
     *
     * <p>The writer is always closed, even when writing fails.
     *
     * @param file destination file; it is overwritten
     * @throws IOException if the file cannot be opened or a write error occurs
     */
    public void printTopicWordWeights(File file) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            this.printTopicWordWeights(out);
            // PrintWriter swallows I/O errors; report them to the caller.
            if (out.checkError()) {
                throw new IOException("Error writing topic word weights to " + file);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Prints one tab-separated line {@code "<topic>\t<word>\t<weight>"} for every
     * (topic, word type) pair, iterating topics in the outer loop.
     *
     * <p>The weight is {@code beta} plus the count stored for that topic in the word's
     * packed {@code typeTopicCounts} row. If the topic does not occur in the row, the
     * weight is just {@code beta}.
     *
     * @param out destination writer; it is not closed here
     * @throws IOException declared for API compatibility; {@code PrintWriter} itself does not throw
     */
    public void printTopicWordWeights(PrintWriter out) throws IOException {
        for (int t = 0; t < this.numTopics; t++) {
            for (int w = 0; w < this.numTypes; w++) {
                final int[] packed = this.typeTopicCounts[w];
                double weight = this.beta;
                int i = 0;
                // Packed entries end at the first non-positive slot.
                while (i < packed.length && packed[i] > 0) {
                    if ((packed[i] & this.topicMask) == t) {
                        weight += packed[i] >> this.topicBits;
                        break;
                    }
                    i++;
                }
                out.println(t + "\t" + this.alphabet.lookupObject(w) + "\t" + weight);
            }
        }
    }

    /**
     * Writes per-document topic proportions to {@code file}, delegating to
     * {@link #printDocumentTopics(PrintWriter)}.
     *
     * <p>The writer is always closed, even when writing fails.
     *
     * @param file destination file; it is overwritten
     * @throws IOException if the file cannot be opened or a write error occurs
     */
    public void printDocumentTopics(File file) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            this.printDocumentTopics(out);
            // PrintWriter swallows I/O errors; report them to the caller.
            if (out.checkError()) {
                throw new IOException("Error writing document topics to " + file);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Prints every topic proportion for every document, with no threshold and no
     * per-document limit.
     *
     * @param out destination writer; it is not closed here
     */
    public void printDocumentTopics(PrintWriter out) {
        final double noThreshold = 0.0;
        final int allTopics = -1;
        printDocumentTopics(out, noThreshold, allTopics);
    }

    /**
     * Prints, for each document, its index, its source (or {@code "null-source"}),
     * and its topics sorted by {@link IDSorter} ordering. Each topic is written as
     * {@code "<topic> <proportion> "}, where the proportion is that topic's share
     * of the document's tokens.
     *
     * <p>Printing for a document stops after {@code max} topics, or at the first
     * topic whose proportion is below {@code threshold}. An empty document
     * reports a proportion of 0 for every topic rather than NaN.
     *
     * @param out destination writer; it is not closed here
     * @param threshold minimum proportion a topic needs in order to be printed
     * @param max maximum number of topics per document; a negative value, or one
     *            larger than {@code numTopics}, means all topics
     */
    public void printDocumentTopics(PrintWriter out, double threshold, int max) {
        out.print("#doc source topic proportion ...\n");
        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }
        if (max < 0 || max > this.numTopics) {
            max = this.numTopics;
        }
        for (int doc = 0; doc < this.data.size(); ++doc) {
            TopicAssignment assignment = this.data.get(doc);
            int[] currentDocTopics = assignment.topicSequence.getFeatures();
            Object source = assignment.instance.getSource();
            out.print(doc);
            out.print(' ');
            if (source != null) {
                out.print(source);
            } else {
                out.print("null-source");
            }
            out.print(' ');
            int docLen = currentDocTopics.length;
            for (int token = 0; token < docLen; ++token) {
                topicCounts[currentDocTopics[token]]++;
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                // Guard 0/0: an empty document would otherwise report NaN for every topic.
                // Float arithmetic is kept so printed values match earlier output.
                float proportion = docLen == 0 ? 0.0f : (float) topicCounts[topic] / (float) docLen;
                sortedTopics[topic].set(topic, proportion);
            }
            Arrays.sort(sortedTopics);
            for (int i = 0; i < max && sortedTopics[i].getWeight() >= threshold; ++i) {
                out.print(sortedTopics[i].getID() + " " + sortedTopics[i].getWeight() + " ");
            }
            out.print(" \n");
            Arrays.fill(topicCounts, 0);
        }
    }

    /**
     * Writes the full sampler state (see {@link #printState(PrintStream)}) to
     * {@code f} as gzip-compressed text.
     *
     * <p>The stream is always closed, even when writing fails.
     *
     * @param f destination file; it is overwritten
     * @throws IOException if the file cannot be opened or a write error occurs
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
            // PrintStream swallows I/O errors; report them to the caller.
            if (out.checkError()) {
                throw new IOException("Error writing state to " + f);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Dumps the sampler state as plain text.
     *
     * <p>The output has two comment header lines, one listing {@code alpha} per
     * topic and one giving {@code beta}. They are followed by one line per token:
     * {@code "<doc> <source> <position> <typeIndex> <word> <topic>"}. A document
     * whose source is {@code null} is written with source {@code "NA"}.
     *
     * @param out destination stream; it is not closed here
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");
        StringBuilder alphaLine = new StringBuilder("#alpha : ");
        for (int k = 0; k < this.numTopics; k++) {
            alphaLine.append(this.alpha[k] + " ");
        }
        out.print(alphaLine.toString());
        out.println();
        out.println("#beta : " + this.beta);

        int docIndex = 0;
        for (TopicAssignment assignment : this.data) {
            FeatureSequence tokens = (FeatureSequence) assignment.instance.getData();
            LabelSequence topics = assignment.topicSequence;
            Object rawSource = assignment.instance.getSource();
            String source = rawSource == null ? "NA" : rawSource.toString();
            int length = topics.getLength();
            for (int position = 0; position < length; position++) {
                int typeIndex = tokens.getIndexAtPosition(position);
                int topicIndex = topics.getIndexAtPosition(position);
                out.println(docIndex + " " + source + " " + position + " " + typeIndex + " "
                        + this.alphabet.lookupObject(typeIndex) + " " + topicIndex);
            }
            docIndex++;
        }
    }

    /**
     * Computes the log likelihood of the current topic assignments.
     *
     * <p>The result is the sum of two terms, both evaluated with
     * {@link Dirichlet#logGammaStirling(double)}:
     * <ul>
     *   <li>a document-topic term built from per-document topic counts and {@code alpha};</li>
     *   <li>a topic-word term built from the packed {@code typeTopicCounts},
     *       {@code tokensPerTopic} and {@code beta}.</li>
     * </ul>
     *
     * @return the log likelihood
     * @throws IllegalStateException if the running total becomes NaN. This replaces
     *         an earlier {@code System.exit(1)}, which terminated the whole JVM.
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;
        int[] topicCounts = new int[this.numTopics];
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }

        // Document-topic term.
        for (int doc = 0; doc < this.data.size(); ++doc) {
            LabelSequence topicSequence = this.data.get(doc).topicSequence;
            int[] docTopics = topicSequence.getFeatures();
            for (int token = 0; token < docTopics.length; ++token) {
                topicCounts[docTopics[token]]++;
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double) topicCounts[topic]) - topicLogGammas[topic];
                }
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double) docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += (double) this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // Topic-word term. Each packed entry holds a count in its high bits; the row ends at the first non-positive slot.
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; ++type) {
            int[] packedCounts = this.typeTopicCounts[type];
            for (int index = 0; index < packedCounts.length && packedCounts[index] > 0; ++index) {
                int count = packedCounts[index] >> this.topicBits;
                ++nonZeroTypeTopics;
                logLikelihood += Dirichlet.logGammaStirling(this.beta + (double) count);
                if (Double.isNaN(logLikelihood)) {
                    throw new IllegalStateException("Log likelihood became NaN at type " + type + " with count " + count);
                }
            }
        }

        double betaTimesTypes = this.beta * (double) this.numTypes;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            logLikelihood -= Dirichlet.logGammaStirling(betaTimesTypes + (double) this.tokensPerTopic[topic]);
            if (Double.isNaN(logLikelihood)) {
                throw new IllegalStateException("Log likelihood became NaN after topic " + topic + " " + this.tokensPerTopic[topic]);
            }
        }
        logLikelihood += Dirichlet.logGammaStirling(betaTimesTypes) - Dirichlet.logGammaStirling(this.beta) * (double) nonZeroTypeTopics;
        if (Double.isNaN(logLikelihood)) {
            throw new IllegalStateException("Log likelihood became NaN at the end");
        }
        return logLikelihood;
    }

    /**
     * Creates a {@link TopicInferencer} from the model's count tables and hyperparameters.
     *
     * <p>The data alphabet is taken from the first training instance, so this
     * throws {@code IndexOutOfBoundsException} when no instances have been added.
     * The model's own arrays are passed in directly; this method makes no copies.
     *
     * @return a new inferencer backed by this model's statistics
     */
    public TopicInferencer getInferencer() {
        TopicAssignment firstDocument = this.data.get(0);
        Alphabet dataAlphabet = firstDocument.instance.getDataAlphabet();
        return new TopicInferencer(
                this.typeTopicCounts,
                this.tokensPerTopic,
                dataAlphabet,
                this.alpha,
                this.beta,
                this.betaSum);
    }

    /**
     * Creates a {@link MarginalProbEstimator} from the model's current
     * hyperparameters and count tables.
     *
     * <p>The model's own arrays are passed in directly; this method makes no copies.
     *
     * @return a new estimator backed by this model's statistics
     */
    public MarginalProbEstimator getProbEstimator() {
        final int topics = this.numTopics;
        final double[] topicPriors = this.alpha;
        final int[][] wordTopicCounts = this.typeTopicCounts;
        final int[] topicTotals = this.tokensPerTopic;
        return new MarginalProbEstimator(topics, topicPriors, this.alphaSum, this.beta, wordTopicCounts, topicTotals);
    }

    /**
     * Custom Java serialization hook. It writes the model's fields in a fixed
     * order that {@code readObject} must read back in exactly the same order.
     *
     * <p>NOTE(review): {@code totalTokens} is not written here. A deserialized model
     * therefore starts with the default value 0 unless something restores it;
     * TODO confirm whether callers recompute it.
     *
     * @param out the serialization stream
     * @throws IOException if writing to the stream fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        // Serialization format version tag; readObject reads it back but does not branch on it.
        out.writeInt(0);
        out.writeObject(this.data);
        out.writeObject(this.alphabet);
        out.writeObject(this.topicAlphabet);
        out.writeInt(this.numTopics);
        // Bit layout of the packed typeTopicCounts entries (topic in the low bits, count in the high bits).
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeInt(this.numTypes);
        out.writeObject(this.alpha);
        out.writeDouble(this.alphaSum);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        // Sufficient statistics of the sampler.
        out.writeObject(this.typeTopicCounts);
        out.writeObject(this.tokensPerTopic);
        out.writeObject(this.docLengthCounts);
        out.writeObject(this.topicDocCounts);
        // Training / reporting configuration.
        out.writeInt(this.numIterations);
        out.writeInt(this.burninPeriod);
        out.writeInt(this.saveSampleInterval);
        out.writeInt(this.optimizeInterval);
        out.writeInt(this.showTopicsInterval);
        out.writeInt(this.wordsPerTopic);
        out.writeInt(this.saveStateInterval);
        out.writeObject(this.stateFilename);
        out.writeInt(this.saveModelInterval);
        out.writeObject(this.modelFilename);
        out.writeInt(this.randomSeed);
        out.writeObject(this.formatter);
        out.writeBoolean(this.printLogLikelihood);
        out.writeInt(this.numThreads);
    }

    /**
     * Custom Java deserialization hook. It reads fields in exactly the order that
     * {@code writeObject} writes them; any change to one must be mirrored in the other.
     *
     * <p>Fields this method does not assign (for example {@code totalTokens}) are
     * left at their Java default values after deserialization.
     *
     * @param in the serialization stream
     * @throws IOException if reading from the stream fails
     * @throws ClassNotFoundException if a serialized object's class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Format version tag written by writeObject; read to advance the stream but currently unused.
        int version = in.readInt();
        this.data = (ArrayList)in.readObject();
        this.alphabet = (Alphabet)in.readObject();
        this.topicAlphabet = (LabelAlphabet)in.readObject();
        this.numTopics = in.readInt();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.numTypes = in.readInt();
        this.alpha = (double[])in.readObject();
        this.alphaSum = in.readDouble();
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.typeTopicCounts = (int[][])in.readObject();
        this.tokensPerTopic = (int[])in.readObject();
        this.docLengthCounts = (int[])in.readObject();
        this.topicDocCounts = (int[][])in.readObject();
        this.numIterations = in.readInt();
        this.burninPeriod = in.readInt();
        this.saveSampleInterval = in.readInt();
        this.optimizeInterval = in.readInt();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.saveStateInterval = in.readInt();
        this.stateFilename = (String)in.readObject();
        this.saveModelInterval = in.readInt();
        this.modelFilename = (String)in.readObject();
        this.randomSeed = in.readInt();
        this.formatter = (NumberFormat)in.readObject();
        this.printLogLikelihood = in.readBoolean();
        this.numThreads = in.readInt();
    }

    /**
     * Serializes this model to {@code serializedModelFile} using Java serialization.
     *
     * <p>I/O failures are not thrown. They are reported on {@code System.err}, which
     * preserves the original best-effort contract. The underlying file stream is
     * closed even if serialization fails partway.
     *
     * @param serializedModelFile destination file; it is overwritten
     */
    public void write(File serializedModelFile) {
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(serializedModelFile);
            ObjectOutputStream oos = new ObjectOutputStream(fileOut);
            oos.writeObject(this);
            // Normal-path close flushes buffered serialization data; its failures are reported below.
            oos.close();
            fileOut = null;
        }
        catch (IOException e) {
            System.err.println("Problem serializing ParallelTopicModel to file " + serializedModelFile + ": " + e);
        }
        finally {
            if (fileOut != null) {
                try {
                    fileOut.close();
                }
                catch (IOException ignored) {
                    // The primary failure has already been reported above.
                }
            }
        }
    }

    /**
     * Deserializes a model previously written with {@link #write(File)}.
     *
     * <p>The file is closed even if the stream header or object cannot be read.
     *
     * @param f file containing a serialized {@code ParallelTopicModel}
     * @return the deserialized model
     * @throws Exception if the file cannot be read, is not a valid serialization
     *         stream, or does not contain a {@code ParallelTopicModel}
     *         ({@code ClassCastException})
     */
    public static ParallelTopicModel read(File f) throws Exception {
        FileInputStream fileIn = new FileInputStream(f);
        try {
            ObjectInputStream ois = new ObjectInputStream(fileIn);
            return (ParallelTopicModel) ois.readObject();
        } finally {
            // Closing the file stream releases the only OS resource involved.
            fileIn.close();
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code ParallelTopicModel <instances> [numTopics] [numThreads]}.
     * It loads a serialized {@link InstanceList}, trains a model (alphaSum 50.0,
     * beta {@link #DEFAULT_BETA}) and writes the final sampler state to
     * {@code state.gz}.
     *
     * <p>{@code numTopics} defaults to 200. If {@code numThreads} is omitted, the
     * model's own default thread count is used. Previously, omitting it crashed
     * with {@code ArrayIndexOutOfBoundsException}.
     *
     * @param args command-line arguments as described above
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("Usage: ParallelTopicModel <instance-list-file> [num-topics (default 200)] [num-threads]");
            return;
        }
        try {
            InstanceList training = InstanceList.load(new File(args[0]));
            int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
            ParallelTopicModel lda = new ParallelTopicModel(numTopics, 50.0, DEFAULT_BETA);
            lda.printLogLikelihood = true;
            lda.setTopicDisplay(50, 7);
            lda.addInstances(training);
            if (args.length > 2) {
                lda.setNumThreads(Integer.parseInt(args[2]));
            }
            lda.estimate();
            System.out.println("printing state");
            lda.printState(new File("state.gz"));
            System.out.println("finished printing");
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
}

