/*
 * Decompiled with CFR 0.152.
 */
package net.myrrix.online.som;

import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.Comparator;
import net.myrrix.common.LangUtils;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;
import net.myrrix.online.som.Node;
import org.apache.commons.math3.distribution.PascalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a square self-organizing map (SOM) of {@link Node}s from a set of feature vectors keyed by
 * {@code long} ID.
 *
 * <p>The build proceeds in phases: the map is seeded with vectors sampled from the input, node centers
 * are then pulled toward the input vectors with a decaying learning rate and neighborhood radius,
 * vectors are assigned to their best-matching node, each node's members are sorted by score, and
 * finally every node receives a 3-component projection of its center.</p>
 *
 * <p>Similarity between a vector and a node center is the dot product divided by the product of both
 * norms (cosine similarity, assuming {@code SimpleVectorMath.norm} is the Euclidean norm).</p>
 */
public final class SelfOrganizingMaps {
    private static final Logger log = LoggerFactory.getLogger(SelfOrganizingMaps.class);
    public static final double DEFAULT_MIN_DECAY = 1.0E-5;
    public static final double DEFAULT_INIT_LEARNING_RATE = 0.5;

    /** Orders (score, ID) pairs by descending score; shared so that it is not re-created per node. */
    private static final Comparator<Pair<Double, Long>> BY_SCORE_DESCENDING =
            new Comparator<Pair<Double, Long>>() {
                @Override
                public int compare(Pair<Double, Long> a, Pair<Double, Long> b) {
                    double scoreA = a.getFirst();
                    double scoreB = b.getFirst();
                    if (scoreA > scoreB) {
                        return -1;
                    }
                    if (scoreA < scoreB) {
                        return 1;
                    }
                    return 0;
                }
            };

    private final double minDecay;
    private final double initLearningRate;

    /**
     * Creates an instance using {@link #DEFAULT_MIN_DECAY} and {@link #DEFAULT_INIT_LEARNING_RATE}.
     */
    public SelfOrganizingMaps() {
        this(DEFAULT_MIN_DECAY, DEFAULT_INIT_LEARNING_RATE);
    }

    /**
     * @param minDecay training stops once the decay factor falls below this value; must be positive
     * @param initLearningRate learning rate applied at the start of training; must be in (0,1]
     * @throws IllegalArgumentException if either argument is out of range
     */
    public SelfOrganizingMaps(double minDecay, double initLearningRate) {
        // Guava's Preconditions only substitutes %s placeholders (not SLF4J-style {}).
        Preconditions.checkArgument(minDecay > 0.0, "Min decay must be positive: %s", minDecay);
        Preconditions.checkArgument(initLearningRate > 0.0 && initLearningRate <= 1.0,
                "Learning rate should be in (0,1]: %s", initLearningRate);
        this.minDecay = minDecay;
        this.initLearningRate = initLearningRate;
    }

    /**
     * Builds a map, choosing the sampling rate automatically.
     *
     * @param vectors feature vectors keyed by ID; must be non-null and non-empty
     * @param maxMapSize maximum number of nodes along each side of the square map; must be positive
     * @return the trained map as a square grid of nodes
     * @see #buildSelfOrganizedMap(FastByIDMap, int, double)
     */
    public Node[][] buildSelfOrganizedMap(FastByIDMap<float[]> vectors, int maxMapSize) {
        return buildSelfOrganizedMap(vectors, maxMapSize, Double.NaN);
    }

    /**
     * Builds a map from the given vectors.
     *
     * @param vectors feature vectors keyed by ID; must be non-null and non-empty
     * @param maxMapSize maximum number of nodes along each side of the square map; must be positive
     * @param samplingRate fraction of vectors to use, in (0,1], or {@code NaN} to derive a rate such
     *  that, on average, about one vector is expected per node of a map of maximum size
     * @return the trained map as a square grid with at least one node; each node's assigned IDs are
     *  sorted by descending score and its 3-component projection is filled in
     * @throws NullPointerException if {@code vectors} is null
     * @throws IllegalArgumentException if any argument is out of range
     */
    public Node[][] buildSelfOrganizedMap(FastByIDMap<float[]> vectors, int maxMapSize, double samplingRate) {
        Preconditions.checkNotNull(vectors);
        Preconditions.checkArgument(!vectors.isEmpty(), "No vectors");
        Preconditions.checkArgument(maxMapSize > 0, "Max map size must be positive: %s", maxMapSize);
        Preconditions.checkArgument(
                Double.isNaN(samplingRate) || (samplingRate > 0.0 && samplingRate <= 1.0),
                "Sampling rate must be in (0,1] or NaN: %s", samplingRate);

        if (Double.isNaN(samplingRate)) {
            // Multiply in double: maxMapSize * maxMapSize overflows int for sizes above 46340.
            double expectedNodeSize = (double) vectors.size() / ((double) maxMapSize * (double) maxMapSize);
            samplingRate = expectedNodeSize > 1.0 ? 1.0 / expectedNodeSize : 1.0;
        }
        log.debug("Sampling rate: {}", samplingRate);

        // Truncation yields 0 when size * samplingRate < 1; an empty map is useless and breaks the
        // later phases (log(0), division by zero), so always keep at least one node.
        int sampledSide = (int) FastMath.sqrt((double) vectors.size() * samplingRate);
        int mapSize = FastMath.max(1, FastMath.min(maxMapSize, sampledSide));

        Node[][] map = buildInitialMap(vectors, mapSize);
        sketchMapParallel(vectors, samplingRate, map);
        for (Node[] mapRow : map) {
            for (Node node : mapRow) {
                node.clearAssignedIDs();
            }
        }
        assignVectorsParallel(vectors, samplingRate, map);
        sortMembers(map);

        int numFeatures = vectors.entrySet().iterator().next().getValue().length;
        buildProjections(numFeatures, map);
        return map;
    }

    /**
     * Training phase: walks the vectors once, in the map's iteration order, and moves the neighborhood
     * of each vector's best-matching node toward the vector. The decay factor is
     * {@code exp(-t / sigma)} for the t-th vector; iteration stops as soon as it drops below
     * {@code minDecay}. Despite its name, this method runs sequentially on the calling thread.
     */
    private void sketchMapParallel(FastByIDMap<float[]> vectors, double samplingRate, Node[][] map) {
        int mapSize = map.length;
        // For a 1x1 map log(1) == 0, so sigma is infinite and the decay factor stays at 1.
        double sigma = (double) vectors.size() * samplingRate / Math.log(mapSize);
        int t = 0;
        for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
            float[] vector = entry.getValue();
            double decayFactor = FastMath.exp((double) (-t) / sigma);
            ++t;
            if (decayFactor < minDecay) {
                break;
            }
            int[] bmuCoordinates = findBestMatchingUnit(vector, map);
            if (bmuCoordinates != null) {
                updateNeighborhood(map, vector, bmuCoordinates[0], bmuCoordinates[1], decayFactor);
            }
        }
    }

    /**
     * Assignment phase: records each (optionally sampled) vector's ID, together with its similarity
     * score, on its best-matching node. Vectors with no finite-scoring node are skipped. Despite its
     * name, this method runs sequentially on the calling thread.
     */
    private static void assignVectorsParallel(FastByIDMap<float[]> vectors, double samplingRate, Node[][] map) {
        boolean doSample = samplingRate < 1.0;
        RandomGenerator random = RandomManager.getRandom();
        for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
            // Keep each vector with probability samplingRate.
            if (doSample && random.nextDouble() > samplingRate) {
                continue;
            }
            float[] vector = entry.getValue();
            int[] bmuCoordinates = findBestMatchingUnit(vector, map);
            if (bmuCoordinates == null) {
                continue;
            }
            Node node = map[bmuCoordinates[0]][bmuCoordinates[1]];
            float[] center = node.getCenter();
            double currentScore =
                    SimpleVectorMath.dot(vector, center) / (SimpleVectorMath.norm(center) * SimpleVectorMath.norm(vector));
            node.addAssignedID(new Pair<Double, Long>(currentScore, entry.getKey()));
        }
    }

    /**
     * Creates a {@code mapSize} x {@code mapSize} grid whose nodes are each constructed from a vector
     * picked from the input. When there are more vectors than nodes, a random number of keys (drawn
     * from a Pascal distribution with r = 1 and p = nodes / vectors) is skipped before each pick;
     * the key iterator restarts from the beginning whenever it is exhausted.
     */
    private static Node[][] buildInitialMap(FastByIDMap<float[]> vectors, int mapSize) {
        double p = (double) mapSize * (double) mapSize / (double) vectors.size();
        // p >= 1 means there are at least as many nodes as vectors: take keys consecutively.
        PascalDistribution pascalDistribution =
                p >= 1.0 ? null : new PascalDistribution(RandomManager.getRandom(), 1, p);
        LongPrimitiveIterator keyIterator = vectors.keySetIterator();
        Node[][] map = new Node[mapSize][mapSize];
        for (Node[] mapRow : map) {
            for (int j = 0; j < mapSize; ++j) {
                if (pascalDistribution != null) {
                    keyIterator.skip(pascalDistribution.sample());
                }
                while (!keyIterator.hasNext()) {
                    // Ran off the end: wrap around and, if sampling, skip again.
                    keyIterator = vectors.keySetIterator();
                    Preconditions.checkState(keyIterator.hasNext());
                    if (pascalDistribution != null) {
                        keyIterator.skip(pascalDistribution.sample());
                    }
                }
                float[] sampledVector = vectors.get(keyIterator.nextLong());
                mapRow[j] = new Node(sampledVector);
            }
        }
        return map;
    }

    /**
     * Finds the node whose center has the highest similarity to {@code vector}.
     *
     * @return {@code {row, column}} of the best node, or {@code null} if no node produced a finite
     *  score (for example when the vector or every center has zero norm)
     */
    private static int[] findBestMatchingUnit(float[] vector, Node[][] map) {
        int mapSize = map.length;
        double vectorNorm = SimpleVectorMath.norm(vector);
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestI = -1;
        int bestJ = -1;
        for (int i = 0; i < mapSize; ++i) {
            Node[] mapRow = map[i];
            for (int j = 0; j < mapSize; ++j) {
                float[] center = mapRow[j].getCenter();
                double currentScore =
                        SimpleVectorMath.dot(vector, center) / (SimpleVectorMath.norm(center) * vectorNorm);
                if (LangUtils.isFinite(currentScore) && currentScore > bestScore) {
                    bestScore = currentScore;
                    bestI = i;
                    bestJ = j;
                }
            }
        }
        // bestI and bestJ are always set together, so checking one of them is sufficient.
        if (bestI == -1) {
            return null;
        }
        return new int[] {bestI, bestJ};
    }

    /**
     * Moves the centers of the nodes around the best-matching unit toward {@code V}. Both the
     * learning rate ({@code initLearningRate * decayFactor}) and the neighborhood radius
     * ({@code mapSize * decayFactor}) shrink with the decay factor; within the neighborhood the update
     * is weighted by a Gaussian of the grid distance to the best-matching unit.
     */
    private void updateNeighborhood(Node[][] map, float[] V, int bmuI, int bmuJ, double decayFactor) {
        int mapSize = map.length;
        double neighborhoodRadius = (double) mapSize * decayFactor;
        // Bounding box of the neighborhood, clipped to the map; upper bounds are exclusive.
        int minI = FastMath.max(0, (int) FastMath.floor((double) bmuI - neighborhoodRadius));
        int maxI = FastMath.min(mapSize, (int) FastMath.ceil((double) bmuI + neighborhoodRadius));
        int minJ = FastMath.max(0, (int) FastMath.floor((double) bmuJ - neighborhoodRadius));
        int maxJ = FastMath.min(mapSize, (int) FastMath.ceil((double) bmuJ + neighborhoodRadius));

        // Loop-invariant terms, computed once rather than per cell.
        double learningRate = initLearningRate * decayFactor;
        double twoRadiusSquared = 2.0 * neighborhoodRadius * neighborhoodRadius;

        for (int i = minI; i < maxI; ++i) {
            Node[] mapRow = map[i];
            for (int j = minJ; j < maxJ; ++j) {
                double currentDistance = distance(i, j, bmuI, bmuJ);
                double theta = FastMath.exp(-(currentDistance * currentDistance) / twoRadiusSquared);
                double learningTheta = learningRate * theta;
                float[] center = mapRow[j].getCenter();
                int length = center.length;
                for (int k = 0; k < length; ++k) {
                    center[k] += (float) (learningTheta * (V[k] - center[k]));
                }
            }
        }
    }

    /** Sorts the assigned (score, ID) pairs of every node in place, highest score first. */
    private static void sortMembers(Node[][] map) {
        for (Node[] mapRow : map) {
            for (Node node : mapRow) {
                Collections.sort(node.getAssignedIDs(), BY_SCORE_DESCENDING);
            }
        }
    }

    /**
     * Fills each node's 3-component projection. Each center is shifted by the mean of all centers and
     * its cosine against three vectors obtained from {@code RandomUtils.randomUnitVector} is mapped
     * from [-1,1] to [0,1]. A center equal to the mean has no direction and gets the neutral value
     * 0.5 in every component.
     */
    private static void buildProjections(int numFeatures, Node[][] map) {
        int mapSize = map.length;
        float[] mean = new float[numFeatures];
        for (Node[] mapRow : map) {
            for (int j = 0; j < mapSize; ++j) {
                add(mapRow[j].getCenter(), mean);
            }
        }
        divide(mean, mapSize * mapSize);

        RandomGenerator random = RandomManager.getRandom();
        float[] rBasis = RandomUtils.randomUnitVector(numFeatures, random);
        float[] gBasis = RandomUtils.randomUnitVector(numFeatures, random);
        float[] bBasis = RandomUtils.randomUnitVector(numFeatures, random);

        for (Node[] mapRow : map) {
            for (int j = 0; j < mapSize; ++j) {
                // Work on a copy so the node's center itself is not modified.
                float[] W = mapRow[j].getCenter().clone();
                subtract(mean, W);
                double norm = SimpleVectorMath.norm(W);
                float[] projection3D = mapRow[j].getProjection3D();
                projection3D[0] = projectOntoBasis(W, rBasis, norm);
                projection3D[1] = projectOntoBasis(W, gBasis, norm);
                projection3D[2] = projectOntoBasis(W, bBasis, norm);
            }
        }
    }

    /**
     * Maps the normalized dot product of {@code centered} and {@code basis} from [-1,1] to [0,1].
     * A zero {@code norm} yields 0.5 instead of the NaN that dividing by it would produce.
     */
    private static float projectOntoBasis(float[] centered, float[] basis, double norm) {
        double cosine = norm == 0.0 ? 0.0 : SimpleVectorMath.dot(centered, basis) / norm;
        return (float) ((1.0 + cosine) / 2.0);
    }

    /** Adds {@code from} element-wise into {@code to}. */
    private static void add(float[] from, float[] to) {
        int length = from.length;
        for (int i = 0; i < length; ++i) {
            to[i] += from[i];
        }
    }

    /** Subtracts {@code toSubtract} element-wise from {@code from}, in place. */
    private static void subtract(float[] toSubtract, float[] from) {
        int length = toSubtract.length;
        for (int i = 0; i < length; ++i) {
            from[i] -= toSubtract[i];
        }
    }

    /** Divides every element of {@code x} by {@code by}, in place. */
    private static void divide(float[] x, float by) {
        int length = x.length;
        for (int i = 0; i < length; ++i) {
            x[i] /= by;
        }
    }

    /** Euclidean distance between grid cells (i1, j1) and (i2, j2). */
    private static double distance(int i1, int j1, int i2, int j2) {
        int diff1 = i1 - i2;
        int diff2 = j1 - j2;
        return FastMath.sqrt(diff1 * diff1 + diff2 * diff2);
    }
}

