/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.classification.algorithm;

import com.carrotsearch.hppc.IntArrayList;
import com.googlecode.clearnlp.classification.algorithm.AbstractAlgorithm;
import com.googlecode.clearnlp.classification.train.AbstractTrainSpace;
import com.googlecode.clearnlp.util.UTArray;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/**
 * Online trainer that updates a flat weight vector one instance at a time,
 * scaling every step by the running sum of squared gradients of the touched
 * weight (AdaGrad-style per-weight step size).
 *
 * <p>The per-label gradient of an instance is {@code indicator(label == gold) - p(label)},
 * where {@code p} is whatever the inherited {@code normalize} turns the raw
 * linear scores into (presumably a softmax-like normalization; confirm against
 * {@code AbstractAlgorithm}).
 *
 * <p>Weights are stored in a single array addressed through the inherited
 * {@code getWeightIndex(labelSize, label, featureIndex)}.
 */
public class AdaGradLR
extends AbstractAlgorithm {
    /** Number of passes over the training instances. */
    protected int n_iter;
    /** Source of randomness used to shuffle the instance order on every pass. */
    protected Random r_rand;
    /** Numerator of the per-weight step size. */
    protected double d_alpha;
    /** Constant added to the square root of the accumulated squared gradient. */
    protected double d_rho;

    /**
     * Creates a trainer.
     *
     * @param iter  number of passes over the training instances.
     * @param alpha numerator of the per-weight step size.
     * @param rho   constant added to the denominator of the step size.
     * @param rand  random generator used for shuffling the instances.
     */
    public AdaGradLR(int iter, double alpha, double rho, Random rand) {
        this.n_iter = iter;
        this.r_rand = rand;
        this.d_alpha = alpha;
        this.d_rho = rho;
    }

    /**
     * Trains a fresh, zero-initialised weight vector of size
     * {@code featureSize * labelSize} on the given space.
     *
     * @param space      training space providing instances and dimensions.
     * @param numThreads not used by this implementation.
     * @return the trained weight vector.
     */
    @Override
    public double[] getWeight(AbstractTrainSpace space, int numThreads) {
        double[] weights = new double[space.getFeatureSize() * space.getLabelSize()];
        updateWeight(space, weights);
        return weights;
    }

    /**
     * Continues training in place on the weight vector held by the model of
     * the given space.
     *
     * @param space training space whose model weights are updated.
     */
    public void updateWeight(AbstractTrainSpace space) {
        updateWeight(space, space.getModel().getWeights());
    }

    /**
     * Runs {@code n_iter} passes over the instances of {@code space}, updating
     * {@code weights} in place after every single instance.
     *
     * @param space   training space providing labels, feature indices and
     *                (optionally) feature values.
     * @param weights weight vector to update in place.
     */
    public void updateWeight(AbstractTrainSpace space, double[] weights) {
        final int D = space.getFeatureSize();
        final int L = space.getLabelSize();
        final int N = space.getInstanceSize();
        final double[] gs = new double[D * L];
        final IntArrayList ys = space.getYs();
        final ArrayList<int[]> xs = space.getXs();
        final ArrayList<double[]> vs = space.getVs();
        // Stays null (binary features) unless the space reports feature values.
        double[] vi = null;

        for (int iter = 0; iter < n_iter; iter++) {
            final int[] order = getShuffledIndices(N);
            // The squared-gradient accumulator is restarted on every pass.
            Arrays.fill(gs, 0.0);

            for (int idx : order) {
                final int yi = ys.get(idx);
                final int[] xi = xs.get(idx);
                if (space.hasWeight()) {
                    vi = vs.get(idx);
                }

                final double[] grad = getGradients(L, yi, xi, vi, weights);
                // Accumulate first so the step size already includes this instance.
                updateCounts(L, gs, grad, xi, vi);
                updateWeights(L, gs, grad, xi, vi, weights);
            }
        }
    }

    /**
     * Builds a random permutation of {@code 0 .. N-1} using {@code r_rand}.
     *
     * @param N number of indices.
     * @return a newly allocated, shuffled index array of length {@code N}.
     */
    protected int[] getShuffledIndices(int N) {
        final int[] indices = new int[N];
        for (int i = 0; i < N; i++) {
            indices[i] = i;
        }

        // Forward Fisher-Yates: position i receives a uniformly chosen element of [i, N).
        for (int i = 0; i < N; i++) {
            final int j = i + r_rand.nextInt(N - i);
            UTArray.swap(indices, i, j);
        }

        return indices;
    }

    /**
     * Computes the per-label gradient for one instance: the normalized scores
     * are negated and 1 is added at the gold label.
     *
     * @param L       number of labels.
     * @param y       gold label index of the instance.
     * @param x       feature indices of the instance.
     * @param v       feature values aligned with {@code x}, or {@code null}
     *                when every feature has an implicit value of 1.
     * @param weights current weight vector.
     * @return a new array of length {@code L} holding the gradient per label.
     */
    protected double[] getGradients(int L, int y, int[] x, double[] v, double[] weights) {
        final double[] grad = getScores(L, x, v, weights);
        AdaGradLR.normalize(grad);

        for (int label = 0; label < L; label++) {
            grad[label] *= -1.0;
        }
        grad[y] += 1.0;

        return grad;
    }

    /**
     * Computes the raw linear score of every label for one instance.
     *
     * @param L       number of labels.
     * @param x       feature indices of the instance.
     * @param v       feature values aligned with {@code x}, or {@code null}.
     * @param weights current weight vector.
     * @return a new array of length {@code L} with one score per label.
     */
    private double[] getScores(int L, int[] x, double[] v, double[] weights) {
        final double[] scores = new double[L];
        final int len = x.length;

        if (v == null) {
            for (int i = 0; i < len; i++) {
                final int feature = x[i];
                for (int label = 0; label < L; label++) {
                    scores[label] += weights[getWeightIndex(L, label, feature)];
                }
            }
        } else {
            for (int i = 0; i < len; i++) {
                final int feature = x[i];
                final double value = v[i];
                for (int label = 0; label < L; label++) {
                    scores[label] += weights[getWeightIndex(L, label, feature)] * value;
                }
            }
        }

        return scores;
    }

    /**
     * Adds the squared gradient of one instance to the accumulator {@code gs}
     * for every (label, active feature) pair.
     *
     * @param L    number of labels.
     * @param gs   accumulated squared gradients, updated in place.
     * @param grad per-label gradient of the instance.
     * @param x    feature indices of the instance.
     * @param v    feature values aligned with {@code x}, or {@code null}.
     */
    protected void updateCounts(int L, double[] gs, double[] grad, int[] x, double[] v) {
        final int len = x.length;

        // Squared per-label gradient, shared by all features of the instance.
        final double[] squared = new double[L];
        for (int label = 0; label < L; label++) {
            squared[label] = grad[label] * grad[label];
        }

        if (v == null) {
            for (int i = 0; i < len; i++) {
                final int feature = x[i];
                for (int label = 0; label < L; label++) {
                    gs[getWeightIndex(L, label, feature)] += squared[label];
                }
            }
        } else {
            for (int i = 0; i < len; i++) {
                final int feature = x[i];
                final double valueSq = v[i] * v[i];
                for (int label = 0; label < L; label++) {
                    gs[getWeightIndex(L, label, feature)] += valueSq * squared[label];
                }
            }
        }
    }

    /**
     * Applies the gradient of one instance to {@code weights}, scaling each
     * step by {@link #getUpdate(int, double[], int, int)}.
     *
     * @param L       number of labels.
     * @param gs      accumulated squared gradients.
     * @param grad    per-label gradient of the instance.
     * @param x       feature indices of the instance.
     * @param v       feature values aligned with {@code x}, or {@code null}.
     * @param weights weight vector, updated in place.
     */
    protected void updateWeights(int L, double[] gs, double[] grad, int[] x, double[] v, double[] weights) {
        final int len = x.length;

        if (v == null) {
            for (int i = 0; i < len; i++) {
                final int feature = x[i];
                for (int label = 0; label < L; label++) {
                    weights[getWeightIndex(L, label, feature)] += getUpdate(L, gs, label, feature) * grad[label];
                }
            }
        } else {
            for (int i = 0; i < len; i++) {
                final int feature = x[i];
                for (int label = 0; label < L; label++) {
                    weights[getWeightIndex(L, label, feature)] += getUpdate(L, gs, label, feature) * grad[label] * v[i];
                }
            }
        }
    }

    /**
     * Returns the step size for the weight of label {@code y} and feature
     * {@code x}: {@code d_alpha / (d_rho + sqrt(gs[index]))}.
     *
     * <p>When the denominator is exactly zero (only possible with
     * {@code d_rho == 0} and no accumulated gradient for that weight) the step
     * size is 0 rather than an infinite or NaN value, so that the subsequent
     * multiplication by the gradient cannot write NaN or infinity into the
     * weight vector.
     *
     * @param L  number of labels.
     * @param gs accumulated squared gradients.
     * @param y  label index.
     * @param x  feature index.
     * @return the step size for the addressed weight.
     */
    protected double getUpdate(int L, double[] gs, int y, int x) {
        final double denominator = d_rho + Math.sqrt(gs[getWeightIndex(L, y, x)]);
        if (denominator == 0.0) {
            return 0.0;
        }
        return d_alpha / denominator;
    }
}

