/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference.gbp;

import cc.mallet.grmm.inference.gbp.Region;
import cc.mallet.grmm.inference.gbp.RegionGraph;
import cc.mallet.grmm.inference.gbp.RegionGraphGenerator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.UndirectedGrid;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.CollectionUtils;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.logging.Logger;

/**
 * Builds a {@link RegionGraph} for a factor graph by repeated intersection of regions.
 *
 * <p>A {@link BaseRegionComputer} supplies the top-level ("base") regions. The generator
 * then computes the pairwise overlaps of those regions, links every region to each overlap
 * it contains, and repeats the procedure on the overlaps until no overlaps remain.
 *
 * <p>The public and nested-interface signatures deliberately keep raw collection types so
 * that existing callers and implementers continue to compile; elements are cast to
 * {@link Region} internally.
 */
public class ClusterVariationalRegionGenerator
implements RegionGraphGenerator {
    private static final Logger logger = MalletLogger.getLogger(ClusterVariationalRegionGenerator.class.getName());
    /** Strategy that produces the base regions for a model. */
    private final BaseRegionComputer regionComputer;

    /** Creates a generator whose base regions are computed by {@link ByFactorRegionComputer}. */
    public ClusterVariationalRegionGenerator() {
        this(new ByFactorRegionComputer());
    }

    /**
     * Creates a generator that uses the given strategy to compute base regions.
     *
     * @param regionComputer strategy invoked once per call to {@link #constructRegionGraph}
     */
    public ClusterVariationalRegionGenerator(BaseRegionComputer regionComputer) {
        this.regionComputer = regionComputer;
    }

    /**
     * Constructs the region graph for {@code mdl}.
     *
     * <p>Starting from the base regions, each level's overlaps are computed and an edge is
     * added from every region of the level to every overlap whose variables it contains.
     * The overlaps then become the next level. The loop ends when a level is empty.
     *
     * <p>NOTE(review): regions are only handed to the graph through
     * {@code rg.add(from, to)}; a base region that overlaps no other region is therefore
     * never passed to the graph by this method. Confirm that this is intended.
     *
     * @param mdl the model to build a region graph for
     * @return the region graph, after {@code computeInferenceCaches()} has been called on it
     */
    public RegionGraph constructRegionGraph(FactorGraph mdl) {
        RegionGraph rg = new RegionGraph();
        List theseRegions = this.regionComputer.computeBaseRegions(mdl);
        while (!theseRegions.isEmpty()) {
            List overlaps = this.computeOverlaps(theseRegions);
            this.addEdgesForOverlaps(rg, theseRegions, overlaps);
            theseRegions = overlaps;
        }
        rg.computeInferenceCaches();
        logger.info("ClusterVariationalRegionGenerator: Number of regions " + rg.size() + " Number of edges:" + rg.numEdges());
        return rg;
    }

    /**
     * Computes the maximal pairwise overlaps of the given regions.
     *
     * <p>For every ordered pair of distinct regions, the intersection of their variables
     * (and of their factors) forms a candidate region. Candidates with no variables, or
     * whose variables are already contained in a previously collected overlap, are skipped.
     * A second pass removes overlaps whose variables are contained in an overlap that
     * appears later in the list.
     *
     * @param regions list of {@link Region} objects
     * @return a new list of {@link Region} objects; empty if no two regions share a variable
     */
    private List computeOverlaps(List regions) {
        List<Region> overlaps = new ArrayList<Region>();
        for (Object o1 : regions) {
            Region r1 = (Region) o1;
            for (Object o2 : regions) {
                Region r2 = (Region) o2;
                if (r1 == r2) {
                    continue;
                }
                Collection intersection = CollectionUtils.intersection(r1.vars, r2.vars);
                if (intersection.isEmpty() || this.anySubsumes(overlaps, intersection)) {
                    continue;
                }
                Collection ptlSet = CollectionUtils.intersection(r1.factors, r2.factors);
                // toArray(T[]) on a raw Collection is typed Object[]; the runtime array has
                // the component type of the argument, so the casts are safe.
                Variable[] vars = (Variable[]) intersection.toArray(new Variable[intersection.size()]);
                Factor[] ptls = (Factor[]) ptlSet.toArray(new Factor[ptlSet.size()]);
                overlaps.add(new Region(vars, ptls));
            }
        }
        // The check above only compares a candidate against overlaps collected before it,
        // so an earlier overlap may still be contained in a later one. Drop those here.
        ListIterator<Region> it = overlaps.listIterator();
        while (it.hasNext()) {
            Region region = it.next();
            List<Region> laterRegions = overlaps.subList(it.nextIndex(), overlaps.size());
            if (this.anySubsumes(laterRegions, region.vars)) {
                it.remove();
            }
        }
        return overlaps;
    }

    /**
     * Tests whether any region in {@code regions} contains all of {@code vars}.
     *
     * @param regions list of {@link Region} objects
     * @param vars    the variables to look for
     * @return {@code true} if some region's variables contain every element of {@code vars}
     */
    private boolean anySubsumes(List regions, Collection vars) {
        for (Object o : regions) {
            Region region = (Region) o;
            if (region.vars.containsAll(vars)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds an edge {@code from -> to} for every pair in which {@code from}'s variables
     * contain all of {@code to}'s variables.
     *
     * @param rg       graph that receives the edges
     * @param fromList list of candidate parent {@link Region} objects
     * @param toList   list of candidate child {@link Region} objects
     */
    private void addEdgesForOverlaps(RegionGraph rg, List fromList, List toList) {
        for (Object fromObj : fromList) {
            Region from = (Region) fromObj;
            for (Object toObj : toList) {
                Region to = (Region) toObj;
                if (from.vars.containsAll(to.vars)) {
                    rg.add(from, to);
                }
            }
        }
    }

    /**
     * Removes, in place, every region whose variables are contained in another region of
     * the list that has at least as many variables.
     *
     * <p>Each region is compared against the list as it stands at that moment, so when two
     * regions have identical variables the earlier one is removed and the later one is kept.
     *
     * @param regions mutable list of {@link Region} objects; must support removal through
     *                its list iterator
     */
    public static void removeSubsumedRegions(List regions) {
        ListIterator it = regions.listIterator();
        while (it.hasNext()) {
            Region region = (Region) it.next();
            // The scan finishes before the removal, so no live iterator other than `it`
            // observes the structural change.
            if (isSubsumedByAnother(regions, region)) {
                it.remove();
            }
        }
    }

    /** Returns whether a different region in {@code regions}, no smaller than {@code region}, contains all of its variables. */
    private static boolean isSubsumedByAnother(List regions, Region region) {
        for (Object o : regions) {
            Region other = (Region) o;
            if (other == region || other.vars.size() < region.vars.size()) {
                continue;
            }
            if (other.vars.containsAll(region.vars)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds to each region every factor of {@code mdl} whose variables are all contained in
     * that region's variables.
     *
     * @param mdl     model whose factors are scanned
     * @param regions list of {@link Region} objects; each region's {@code factors} is updated in place
     */
    public static void addAllFactors(FactorGraph mdl, List regions) {
        for (Object o : regions) {
            Region region = (Region) o;
            Iterator pIt = mdl.factorsIterator();
            while (pIt.hasNext()) {
                Factor ptl = (Factor) pIt.next();
                if (region.vars.containsAll(ptl.varSet())) {
                    region.factors.add(ptl);
                }
            }
        }
    }

    /**
     * Base-region strategy for grid models: one region per 2x2 block of neighbouring grid
     * variables.
     */
    public static class Grid2x2RegionComputer
    implements BaseRegionComputer {
        /**
         * Computes one region for each 2x2 block of the grid, then attaches to every region
         * the factors of {@code mdl} that fit inside it.
         *
         * @param mdl the model; it is cast to {@link UndirectedGrid}, so any other type
         *            causes a {@link ClassCastException}
         * @return list of {@link Region} objects; empty when the grid's width or height is
         *         less than 2
         */
        public List computeBaseRegions(FactorGraph mdl) {
            List<Region> regions = new ArrayList<Region>();
            UndirectedGrid grid = (UndirectedGrid) mdl;
            for (int x = 0; x < grid.getWidth() - 1; ++x) {
                for (int y = 0; y < grid.getHeight() - 1; ++y) {
                    // The four corners of the block whose lowest-index corner is (x, y).
                    Variable[] vars = new Variable[]{grid.get(x, y), grid.get(x, y + 1), grid.get(x + 1, y + 1), grid.get(x + 1, y)};
                    regions.add(new Region(vars, new Factor[0]));
                }
            }
            ClusterVariationalRegionGenerator.addAllFactors(mdl, regions);
            return regions;
        }
    }

    /**
     * Default base-region strategy: one region per factor of the model, with regions that
     * are contained in another region removed.
     */
    public static class ByFactorRegionComputer
    implements BaseRegionComputer {
        /**
         * Creates a region from each factor of {@code mdl}, removes subsumed regions, and
         * then attaches to every remaining region all factors of the model that fit inside it.
         *
         * @param mdl the model whose factors define the regions
         * @return list of {@link Region} objects
         */
        public List computeBaseRegions(FactorGraph mdl) {
            List<Region> regions = new ArrayList<Region>(mdl.factors().size());
            Iterator it = mdl.factorsIterator();
            while (it.hasNext()) {
                Factor ptl = (Factor) it.next();
                regions.add(new Region(ptl));
            }
            ClusterVariationalRegionGenerator.removeSubsumedRegions(regions);
            ClusterVariationalRegionGenerator.addAllFactors(mdl, regions);
            return regions;
        }
    }

    /** Strategy for choosing the top-level regions of a region graph. */
    public static interface BaseRegionComputer {
        /**
         * Computes the base regions for a model.
         *
         * @param var1 the model
         * @return list of {@link Region} objects
         */
        public List computeBaseRegions(FactorGraph var1);
    }
}

