/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.RankMaxEnt;
import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labels;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link RankMaxEnt} model that scores a set of candidates per instance.
 *
 * <p>Each training instance's data is a {@link FeatureVectorSequence} with one feature vector
 * per candidate. Its target ({@link Label} or {@link Labels}) holds, as an integer string, the
 * index of the correct candidate. Instances whose target index is {@code -1} are skipped.
 *
 * <p>The objective is {@code sum_i w_i * log(score of correct candidate)} minus a Gaussian prior
 * {@code sum_p p^2 / (2 * gaussianPriorVariance)}. It is maximized with L-BFGS, followed by a
 * conjugate-gradient pass when {@code numIterations == Integer.MAX_VALUE}.
 */
public class RankMaxEntTrainer extends MaxEntTrainer {
    private static final Logger logger = MalletLogger.getLogger(RankMaxEntTrainer.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(RankMaxEntTrainer.class.getName() + "-pl");
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /** Creates a trainer whose settings come from the {@link MaxEntTrainer} no-arg constructor. */
    public RankMaxEntTrainer() {
    }

    /**
     * Creates a trainer with the given Gaussian prior variance.
     *
     * @param gaussianPriorVariance variance of the Gaussian prior used in the objective and its
     *     gradient
     */
    public RankMaxEntTrainer(double gaussianPriorVariance) {
        super(gaussianPriorVariance);
    }

    /**
     * Returns the optimizable objective for {@code ilist}.
     *
     * @param ilist training instances, or {@code null} for an empty, uninitialized objective
     * @return a fresh objective starting from all-zero parameters (no initial classifier)
     */
    public Optimizable.ByGradientValue getMaximizableTrainer(InstanceList ilist) {
        if (ilist == null) {
            return new MaximizableTrainer();
        }
        return new MaximizableTrainer(ilist, null);
    }

    /**
     * Trains a ranker on {@code trainingSet}.
     *
     * <p>Runs up to {@code numIterations} single-step L-BFGS iterations, stopping early on
     * convergence. An {@link IllegalArgumentException} from the optimizer is logged and treated
     * as convergence. When {@code numIterations == Integer.MAX_VALUE}, a conjugate-gradient
     * optimization is run afterwards and stored in {@code this.optimizer}.
     *
     * @param trainingSet instances whose data are {@link FeatureVectorSequence}s
     * @return the trained {@link RankMaxEnt}; if {@code initialClassifier} was set, this is that
     *     same object, updated in place
     */
    public MaxEnt train(InstanceList trainingSet) {
        logger.fine("trainingSet.size() = " + trainingSet.size());
        MaximizableTrainer mt = new MaximizableTrainer(trainingSet, (RankMaxEnt) this.initialClassifier);
        LimitedMemoryBFGS maximizer = new LimitedMemoryBFGS(mt);
        for (int i = 0; i < this.numIterations; ++i) {
            boolean converged;
            try {
                converged = maximizer.optimize(1);
            } catch (IllegalArgumentException e) {
                // Typically raised by the line search; treat as convergence rather than failing.
                logger.log(Level.INFO, "Catching exception; saying converged.", e);
                converged = true;
            }
            if (converged) {
                break;
            }
        }
        if (this.numIterations == Integer.MAX_VALUE) {
            this.optimizer = new ConjugateGradient(mt);
            try {
                this.optimizer.optimize();
            } catch (IllegalArgumentException e) {
                logger.log(Level.INFO, "Catching exception; saying converged.", e);
            }
        }
        progressLogger.info("\n");
        return mt.getClassifier();
    }

    public String toString() {
        return "RankMaxEntTrainer,numIterations=" + this.numIterations + ",gaussianPriorVariance=" + this.gaussianPriorVariance;
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeInt(CURRENT_SERIAL_VERSION);
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        // Consume the version written by writeObject; only version 1 exists, so it is not inspected.
        in.readInt();
    }

    /**
     * Objective (penalized log-likelihood) and gradient for the ranking model.
     *
     * <p>Parameters form a {@code numLabels x numFeatures} row-major matrix with
     * {@code numLabels == 2}. Only row 0 receives data-driven constraint and expectation terms.
     * Row 1 is affected only by the Gaussian prior.
     */
    private class MaximizableTrainer implements Optimizable.ByGradientValue {
        double[] parameters;
        /** Instance-weighted feature counts of the correct candidates (row 0). */
        double[] constraints;
        /** Expectation term accumulated by getValue; completed by getValueGradient. */
        double[] cachedGradient;
        RankMaxEnt theClassifier;
        InstanceList trainingList;
        double cachedValue;
        boolean cachedValueStale;
        boolean cachedGradientStale;
        int numLabels;
        int numFeatures;
        /** Index of the extra always-on "default" feature appended after the data alphabet. */
        int defaultFeatureIndex;
        FeatureSelection featureSelection;
        FeatureSelection[] perLabelFeatureSelection;

        /** Creates an uninitialized objective; all arrays are null. */
        public MaximizableTrainer() {
        }

        /**
         * Builds the objective over {@code ilist} and precomputes the constraint vector.
         *
         * <p>Note: this adds the default feature index to the feature selection(s) obtained from
         * {@code ilist}, mutating them.
         *
         * @param ilist training instances with {@link FeatureVectorSequence} data and
         *     {@link Label}/{@link Labels} targets holding the correct candidate index
         * @param initialClassifier if non-null, its parameter array is shared and optimized in
         *     place, and its feature selections and default index are used
         */
        public MaximizableTrainer(InstanceList ilist, RankMaxEnt initialClassifier) {
            this.trainingList = ilist;
            Alphabet fd = ilist.getDataAlphabet();
            this.numLabels = 2;
            this.numFeatures = fd.size() + 1;
            this.defaultFeatureIndex = this.numFeatures - 1;
            int numParameters = this.numLabels * this.numFeatures;
            // Java zero-initializes new arrays.
            this.parameters = new double[numParameters];
            this.constraints = new double[numParameters];
            this.cachedGradient = new double[numParameters];
            this.featureSelection = ilist.getFeatureSelection();
            this.perLabelFeatureSelection = ilist.getPerLabelFeatureSelection();
            if (this.featureSelection != null) {
                this.featureSelection.add(this.defaultFeatureIndex);
            }
            if (this.perLabelFeatureSelection != null) {
                for (FeatureSelection selection : this.perLabelFeatureSelection) {
                    selection.add(this.defaultFeatureIndex);
                }
            }
            assert (this.featureSelection == null || this.perLabelFeatureSelection == null);
            if (initialClassifier != null) {
                this.theClassifier = initialClassifier;
                this.parameters = this.theClassifier.parameters;
                this.featureSelection = this.theClassifier.featureSelection;
                this.perLabelFeatureSelection = this.theClassifier.perClassFeatureSelection;
                this.defaultFeatureIndex = this.theClassifier.defaultFeatureIndex;
                assert (initialClassifier.getInstancePipe() == ilist.getPipe());
            } else {
                this.theClassifier = new RankMaxEnt(ilist.getPipe(), this.parameters, this.featureSelection, this.perLabelFeatureSelection);
            }
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            logger.fine("Number of instances in training list = " + this.trainingList.size());
            for (Iterator iter = this.trainingList.iterator(); iter.hasNext(); ) {
                Instance instance = (Instance) iter.next();
                double instanceWeight = this.trainingList.getInstanceWeight(instance);
                FeatureVectorSequence fvs = (FeatureVectorSequence) instance.getData();
                Object target = instance.getTarget();
                // For Labels targets only the first label determines the constraint.
                Label label = target instanceof Labels ? ((Labels) target).get(0) : (Label) target;
                int positiveIndex = Integer.parseInt(label.getBestLabel().getEntry().toString());
                if (positiveIndex == -1) {
                    logger.warning("True label is -1. Skipping...");
                    continue;
                }
                FeatureVector fv = fvs.get(positiveIndex);
                Alphabet fdict = fv.getAlphabet();
                assert (fv.getAlphabet() == fd);
                MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, 0, fv, instanceWeight);
                assert (!Double.isNaN(instanceWeight)) : "instanceWeight is NaN";
                boolean hasNaN = false;
                for (int i = 0; i < fv.numLocations(); ++i) {
                    if (Double.isNaN(fv.valueAtLocation(i))) {
                        logger.info("NaN for feature " + fdict.lookupObject(fv.indexAtLocation(i)).toString());
                        hasNaN = true;
                    }
                }
                if (hasNaN) {
                    logger.info("NaN in instance: " + instance.getName());
                }
                // The default feature fires once per instance (row 0).
                this.constraints[this.defaultFeatureIndex] += instanceWeight;
            }
        }

        /** @return the classifier whose parameter array this objective optimizes */
        public RankMaxEnt getClassifier() {
            return this.theClassifier;
        }

        public double getParameter(int index) {
            return this.parameters[index];
        }

        public void setParameter(int index, double v) {
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            this.parameters[index] = v;
        }

        public int getNumParameters() {
            return this.parameters.length;
        }

        /**
         * Copies the current parameters into {@code buff}.
         *
         * @param buff destination; must be non-null with length {@link #getNumParameters()}
         * @throws IllegalArgumentException if {@code buff} is null or has the wrong length
         *     (previously this silently copied into a discarded temporary array)
         */
        public void getParameters(double[] buff) {
            if (buff == null || buff.length != this.parameters.length) {
                throw new IllegalArgumentException("parameter buffer must have length " + this.parameters.length + ", got " + (buff == null ? "null" : String.valueOf(buff.length)));
            }
            System.arraycopy(this.parameters, 0, buff, 0, this.parameters.length);
        }

        /**
         * Replaces the parameters with a copy of {@code buff}, reallocating if the length differs.
         * Note that reallocation detaches this array from the classifier's parameter array.
         */
        public void setParameters(double[] buff) {
            assert (buff != null);
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            if (buff.length != this.parameters.length) {
                this.parameters = new double[buff.length];
            }
            System.arraycopy(buff, 0, this.parameters, 0, buff.length);
        }

        /**
         * Returns the penalized log-likelihood, recomputing it if parameters changed.
         *
         * <p>As a side effect, resets {@code cachedGradient} and fills it with the negated,
         * weighted candidate-score expectations (row 0), and marks the gradient stale so that
         * {@link #getValueGradient} completes it.
         *
         * <p>If any instance yields an infinite per-instance term, evaluation stops immediately
         * and that infinity (negated) is returned. The remaining instances are not processed and
         * the gradient buffer holds only a partial sum.
         */
        public double getValue() {
            if (this.cachedValueStale) {
                this.cachedValue = 0.0;
                this.cachedGradientStale = true;
                MatrixOps.setAll(this.cachedGradient, 0.0);
                for (Iterator iter = this.trainingList.iterator(); iter.hasNext(); ) {
                    Instance instance = (Instance) iter.next();
                    FeatureVectorSequence fvs = (FeatureVectorSequence) instance.getData();
                    double[] scores = new double[fvs.size()];
                    double instanceWeight = this.trainingList.getInstanceWeight(instance);
                    Object target = instance.getTarget();
                    int li;
                    int[] bestPositions = null;
                    if (target instanceof Label) {
                        li = Integer.parseInt(((Label) target).toString());
                    } else if (target instanceof Labels) {
                        // Several correct candidates (ties); the first one is used for the value.
                        Labels labels = (Labels) target;
                        bestPositions = new int[labels.size()];
                        for (int pi = 0; pi < labels.size(); ++pi) {
                            bestPositions[pi] = Integer.parseInt(labels.get(pi).toString());
                        }
                        li = bestPositions[0];
                    } else {
                        li = -1;
                    }
                    // Skip exactly the instances the constructor left out of the constraints.
                    // This applies to both target kinds, so constraints and expectations agree.
                    if (li == -1) {
                        continue;
                    }
                    assert (li >= 0 && li < fvs.size());
                    if (bestPositions == null) {
                        this.theClassifier.getClassificationScores(instance, scores);
                    } else {
                        this.theClassifier.getClassificationScoresForTies(instance, scores, bestPositions);
                    }
                    // Negative weighted log-score of the correct candidate; the sum is negated at the end.
                    double value = -(instanceWeight * Math.log(scores[li]));
                    if (Double.isNaN(value)) {
                        logger.fine("MaxEntTrainer: Instance " + instance.getName() + " has NaN value. log(scores)= " + Math.log(scores[li]) + " scores = " + scores[li] + " has instance weight = " + instanceWeight);
                    }
                    if (Double.isInfinite(value)) {
                        // Despite the message, this aborts the whole evaluation and returns -value.
                        // NOTE(review): presumably meant to make the optimizer's line search back off.
                        logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                        this.cachedValue -= value;
                        this.cachedValueStale = false;
                        return -value;
                    }
                    this.cachedValue += value;
                    for (int si = 0; si < fvs.size(); ++si) {
                        if (scores[si] == 0.0) {
                            continue;
                        }
                        assert (!Double.isInfinite(scores[si]));
                        double weightedScore = -instanceWeight * scores[si];
                        MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, 0, fvs.get(si), weightedScore);
                        // The default feature fires for every candidate (row 0).
                        this.cachedGradient[this.defaultFeatureIndex] += weightedScore;
                    }
                }
                // Gaussian prior penalty over every parameter in the numLabels x numFeatures matrix.
                int numParameters = this.numLabels * this.numFeatures;
                for (int i = 0; i < numParameters; ++i) {
                    double param = this.parameters[i];
                    this.cachedValue += param * param / (2.0 * RankMaxEntTrainer.this.gaussianPriorVariance);
                }
                this.cachedValue *= -1.0;
                this.cachedValueStale = false;
                progressLogger.info("Value (loglikelihood) = " + this.cachedValue);
            }
            return this.cachedValue;
        }

        /**
         * Writes the gradient of {@link #getValue()} into {@code buffer}.
         *
         * <p>The gradient is constraints minus expectations minus {@code parameters / variance}.
         * {@code -Infinity} entries are replaced by 0. Entries outside the active feature
         * selection are zeroed.
         *
         * @param buffer destination of length {@link #getNumParameters()}
         */
        public void getValueGradient(double[] buffer) {
            if (this.cachedGradientStale) {
                if (this.cachedValueStale) {
                    // getValue() fills cachedGradient with the expectation term.
                    this.getValue();
                }
                MatrixOps.plusEquals(this.cachedGradient, this.constraints);
                MatrixOps.plusEquals(this.cachedGradient, this.parameters, -1.0 / RankMaxEntTrainer.this.gaussianPriorVariance);
                MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0);
                if (this.perLabelFeatureSelection == null) {
                    for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
                        MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, labelIndex, 0.0, this.featureSelection, false);
                    }
                } else {
                    for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
                        MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, labelIndex, 0.0, this.perLabelFeatureSelection[labelIndex], false);
                    }
                }
                this.cachedGradientStale = false;
            }
            assert (buffer != null && buffer.length == this.parameters.length);
            System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
        }
    }
}

