/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup;
import org.apache.sysds.runtime.compress.colgroup.APreAgg;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate;

/**
 * Static helpers used by the compressed-matrix library operations to consolidate,
 * filter and partition lists of column groups.
 */
public final class CLALibUtils {
    protected static final Log LOG = LogFactory.getLog(CLALibUtils.class.getName());

    /**
     * Consolidates the empty and constant column groups of a compressed block.
     * <p>
     * All {@link ColGroupEmpty} groups are merged into a single empty group and all
     * {@link ColGroupConst} groups into a single constant group. The new group list holds
     * every other group in its original relative order, then the (merged) empty group if
     * there is one, then the (merged) constant group if there is one, and is installed on
     * {@code in} via {@code allocateColGroupList}. A block with neither empty nor constant
     * groups is left untouched; a block with exactly one such group still gets a reordered
     * list with that group moved to the end.
     *
     * @param in the compressed block whose column group list is replaced
     */
    public static void combineConstColumns(CompressedMatrixBlock in) {
        final List<AColGroup> empty = new ArrayList<>();
        final List<AColGroup> constant = new ArrayList<>();
        final List<AColGroup> others = new ArrayList<>();
        for (AColGroup g : in.getColGroups()) {
            if (g instanceof ColGroupEmpty) {
                empty.add(g);
            } else if (g instanceof ColGroupConst) {
                constant.add(g);
            } else {
                others.add(g);
            }
        }

        if (empty.isEmpty() && constant.isEmpty()) {
            return; // nothing to consolidate
        }
        if (!empty.isEmpty()) {
            others.add(empty.size() == 1 ? empty.get(0) : combineEmpty(empty));
        }
        if (!constant.isEmpty()) {
            others.add(constant.size() == 1 ? constant.get(0) : combineConst(constant));
        }
        in.allocateColGroupList(others);
    }

    /**
     * Reports whether {@link #filterGroups} would treat any of the given groups specially.
     *
     * @param groups the column groups to inspect
     * @return true if some group is an {@link AMorphingMMColGroup}, a {@link ColGroupConst},
     *         a {@link ColGroupEmpty}, or reports {@code isEmpty()}; false otherwise
     */
    protected static boolean shouldPreFilter(List<AColGroup> groups) {
        for (AColGroup g : groups) {
            // isEmpty() is only consulted for groups that are none of the three special types.
            if (g instanceof AMorphingMMColGroup || g instanceof ColGroupConst || g instanceof ColGroupEmpty
                || g.isEmpty()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Strips groups whose contribution can be expressed through {@code constV}.
     * <p>
     * Empty groups (either {@link ColGroupEmpty} or {@code isEmpty()}) are dropped. Morphing
     * groups are replaced by the result of {@code extractCommon(constV)}. Constant groups are
     * passed to {@code addToCommon(constV)} and dropped. All other groups are kept as-is.
     * NOTE(review): {@code extractCommon}/{@code addToCommon} presumably accumulate into
     * {@code constV} in place (one entry per column) — confirm against those classes.
     *
     * @param groups the input column groups; not modified
     * @param constV accumulator handed to the group methods above, or null to disable filtering
     * @return {@code groups} itself when {@code constV} is null, otherwise a new filtered list
     * @throws NotImplementedException if, after filtering, any entry of {@code constV} is NaN
     *         or infinite
     */
    protected static List<AColGroup> filterGroups(List<AColGroup> groups, double[] constV) {
        if (constV == null) {
            return groups;
        }
        final List<AColGroup> filteredGroups = new ArrayList<>();
        for (AColGroup g : groups) {
            if (g instanceof ColGroupEmpty || g.isEmpty()) {
                // contributes nothing
            } else if (g instanceof AMorphingMMColGroup) {
                filteredGroups.add(((AMorphingMMColGroup) g).extractCommon(constV));
            } else if (g instanceof ColGroupConst) {
                ((ColGroupConst) g).addToCommon(constV);
            } else {
                filteredGroups.add(g);
            }
        }
        return returnGroupIfFiniteNumbers(groups, filteredGroups, constV);
    }

    /**
     * Folds constant/morphing groups into {@code constV} and splits the remainder into
     * pre-aggregatable and other groups.
     * <p>
     * Dispatch order matters: {@link APreAgg} groups are taken as-is (even if they are also
     * morphing or empty); morphing groups are converted with {@code extractCommon(constV)}
     * and must yield an {@link APreAgg} or a {@link ColGroupEmpty}; {@link ColGroupEmpty}
     * groups are dropped; {@link ColGroupConst} groups are folded via
     * {@code addToCommon(constV)}; everything else goes to {@code noPreAggGroups}.
     * Unlike {@link #filterGroups}, no finiteness check is performed on {@code constV}.
     *
     * @param groups         the input column groups; not modified
     * @param constV         accumulator handed to {@code extractCommon}/{@code addToCommon}
     * @param noPreAggGroups output list receiving groups that are not pre-aggregatable
     * @param preAggGroups   output list receiving pre-aggregatable groups
     * @throws DMLCompressionException if a morphing group extracts to an unexpected type
     */
    protected static void filterGroupsAndSplitPreAgg(List<AColGroup> groups, double[] constV, List<AColGroup> noPreAggGroups, List<APreAgg> preAggGroups) {
        for (AColGroup g : groups) {
            if (g instanceof APreAgg) {
                preAggGroups.add((APreAgg) g);
            } else if (g instanceof AMorphingMMColGroup) {
                final AColGroup extracted = ((AMorphingMMColGroup) g).extractCommon(constV);
                if (extracted instanceof APreAgg) {
                    preAggGroups.add((APreAgg) extracted);
                } else if (!(extracted instanceof ColGroupEmpty)) {
                    throw new DMLCompressionException("I did not think this was a problem");
                }
            } else if (g instanceof ColGroupEmpty) {
                // contributes nothing
            } else if (g instanceof ColGroupConst) {
                ((ColGroupConst) g).addToCommon(constV);
            } else {
                noPreAggGroups.add(g);
            }
        }
    }

    /**
     * Splits groups into pre-aggregatable and other groups without any constant folding.
     *
     * @param groups         the input column groups; not modified
     * @param noPreAggGroups output list receiving groups that are not pre-aggregatable
     * @param preAggGroups   output list receiving {@link APreAgg} groups
     * @throws NotImplementedException if a {@link ColGroupConst} is encountered (callers are
     *         expected to have folded constants beforehand)
     */
    protected static void splitPreAgg(List<AColGroup> groups, List<AColGroup> noPreAggGroups, List<APreAgg> preAggGroups) {
        for (AColGroup g : groups) {
            if (g instanceof APreAgg) {
                preAggGroups.add((APreAgg) g);
            } else if (g instanceof ColGroupEmpty) {
                // contributes nothing
            } else if (g instanceof ColGroupConst) {
                throw new NotImplementedException();
            } else {
                noPreAggGroups.add(g);
            }
        }
    }

    /**
     * Returns {@code filteredGroups} provided every entry of {@code constV} is finite.
     *
     * @param groups         the unfiltered groups (unused; kept for call-site symmetry)
     * @param filteredGroups the groups to return
     * @param constV         the accumulated constant values to validate
     * @return {@code filteredGroups}
     * @throws NotImplementedException if any entry of {@code constV} is NaN or infinite
     */
    private static List<AColGroup> returnGroupIfFiniteNumbers(List<AColGroup> groups, List<AColGroup> filteredGroups, double[] constV) {
        for (double v : constV) {
            if (!Double.isFinite(v)) {
                throw new NotImplementedException("Not handling if the values are not finite: " + Arrays.toString(constV));
            }
        }
        return filteredGroups;
    }

    /** Merges several empty groups into one {@link ColGroupEmpty} over their combined columns. */
    private static AColGroup combineEmpty(List<AColGroup> e) {
        return new ColGroupEmpty(combineColIndexes(e));
    }

    /**
     * Merges several {@link ColGroupConst} groups into one constant group.
     * <p>
     * Each input value is written at the position of its column inside the combined index
     * set; if two inputs share a column, the later one in {@code c} wins.
     *
     * @param c constant groups to merge; every element must be a {@link ColGroupConst}
     * @return the merged constant group, as produced by {@code ColGroupConst.create}
     */
    private static AColGroup combineConst(List<AColGroup> c) {
        final IColIndex resCols = combineColIndexes(c);
        final double[] values = new double[resCols.size()];
        for (AColGroup g : c) {
            final ColGroupConst cg = (ColGroupConst) g;
            final IColIndex colIdx = cg.getColIndices();
            final double[] colVals = cg.getValues();
            for (int i = 0; i < colIdx.size(); i++) {
                // map the group-local column position to its slot in the combined index set
                values[resCols.findIndex(colIdx.get(i))] = colVals[i];
            }
        }
        return ColGroupConst.create(resCols, values);
    }

    /** Combines the column indices of all given groups via {@code ColIndexFactory.combine}. */
    private static IColIndex combineColIndexes(List<AColGroup> gs) {
        return ColIndexFactory.combine(gs);
    }

    /**
     * Computes column sums over all groups.
     *
     * @param groups the column groups to sum
     * @param nCols  number of columns, i.e. length of the result array
     * @param nRows  number of rows, forwarded to {@code AColGroup.colSum}
     * @return the array filled by {@code AColGroup.colSum}
     */
    protected static double[] getColSum(List<AColGroup> groups, int nCols, int nRows) {
        return AColGroup.colSum(groups, new double[nCols], nRows);
    }

    /**
     * Appends a {@link ColGroupEmpty} covering every column in {@code [0, nCols)} that no group
     * in {@code colGroups} covers.
     * <p>
     * Nothing is added if some single group already spans {@code nCols} columns, or if all
     * columns are covered. Column indices outside {@code [0, nCols)} are ignored. The empty
     * group's columns are always produced in ascending order (a {@code HashSet} was used
     * previously, whose iteration order is unspecified).
     *
     * @param colGroups the group list to extend in place
     * @param nCols     total number of columns the groups should cover
     */
    protected static void addEmptyColumn(List<AColGroup> colGroups, int nCols) {
        for (AColGroup g : colGroups) {
            if (g.getColIndices().size() == nCols) {
                return; // one group already spans every column
            }
        }

        final boolean[] covered = new boolean[nCols];
        int nCovered = 0;
        for (AColGroup g : colGroups) {
            final IIterate it = g.getColIndices().iterator();
            while (it.hasNext()) {
                final int col = it.next();
                if (col >= 0 && col < nCols && !covered[col]) {
                    covered[col] = true;
                    nCovered++;
                }
            }
        }
        if (nCovered == nCols) {
            return;
        }

        // Scan the mask in order so the resulting column index is ascending.
        final int[] emptyColumns = new int[nCols - nCovered];
        int k = 0;
        for (int i = 0; i < nCols; i++) {
            if (!covered[i]) {
                emptyColumns[k++] = i;
            }
        }
        colGroups.add(new ColGroupEmpty(ColIndexFactory.create(emptyColumns)));
    }
}

