/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.functions;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.WeightedCell;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Base class for Spark functions that turn a target block plus its matching group block
 * into {@link WeightedCell}s keyed by output {@link MatrixIndexes}.
 *
 * <p>Row {@code i} of the group block (first column) labels row {@code i} of the target
 * block. When local pre-aggregation is possible, the block is first reduced via
 * {@link MatrixBlock#groupedAggOperations}; otherwise every target cell is emitted
 * individually, keyed by its group value.
 */
public abstract class ExtractGroup
implements Serializable {
    private static final long serialVersionUID = -7059358143841229966L;
    /** Block size used to convert a block column index into a global column offset. */
    protected long _blen = -1L;
    /** Number of groups; a value {@code <= 0} disables local pre-aggregation. */
    protected long _ngroups = -1L;
    /** Operator; local pre-aggregation is only attempted for an {@link AggregateOperator}. */
    protected Operator _op = null;

    public ExtractGroup(long blen, long ngroups, Operator op) {
        this._blen = blen;
        this._ngroups = ngroups;
        this._op = op;
    }

    /**
     * Extracts keyed cells from one pair of group/target blocks.
     *
     * @param ix     index of the target block; its column index determines the global
     *               column offset {@code (colIx - 1) * _blen} of the emitted cells
     * @param group  block whose first column holds the group value of each row
     * @param target block of values, with the same number of rows as {@code group}
     * @return the extracted (output index, cell) pairs; every cell has weight 1
     * @throws Exception if the two blocks have different row counts, or (when no
     *                   pre-aggregation is done) if a group value is smaller than 1
     */
    protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group, MatrixBlock target) throws Exception {
        // Row i of the group block labels row i of the target block, so both must align.
        if (group.getNumRows() != target.getNumRows()) {
            throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows() + " != " + target.getNumRows());
        }
        long coloff = (ix.getColumnIndex() - 1L) * this._blen;
        // Pre-aggregate locally only if the operator is an aggregate and the
        // (ngroups x ncol) result is known and fits the valid CP dimensions.
        if (this._op instanceof AggregateOperator && this._ngroups > 0L
            && OptimizerUtils.isValidCPDimensions(this._ngroups, target.getNumColumns())) {
            return this.preAggregate(group, target, coloff);
        }
        return extractCells(group, target, coloff);
    }

    /** Runs a local grouped aggregation and emits only its non-zero result cells. */
    private ArrayList<Tuple2<MatrixIndexes, WeightedCell>> preAggregate(MatrixBlock group, MatrixBlock target, long coloff) {
        ArrayList<Tuple2<MatrixIndexes, WeightedCell>> out = new ArrayList<>();
        MatrixBlock tmp = group.groupedAggOperations(target, null, new MatrixBlock(), (int)this._ngroups, this._op);
        for (int i = 0; i < tmp.getNumRows(); ++i) {
            for (int j = 0; j < tmp.getNumColumns(); ++j) {
                double tmpval = tmp.quickGetValue(i, j);
                // Zero cells are skipped in the pre-aggregated output.
                if (tmpval != 0.0) {
                    out.add(newPair(i + 1L, coloff + j + 1L, tmpval));
                }
            }
        }
        return out;
    }

    /** Emits every target cell keyed by (group value, global column index). */
    private static ArrayList<Tuple2<MatrixIndexes, WeightedCell>> extractCells(MatrixBlock group, MatrixBlock target, long coloff) throws Exception {
        ArrayList<Tuple2<MatrixIndexes, WeightedCell>> out = new ArrayList<>();
        for (int i = 0; i < group.getNumRows(); ++i) {
            long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
            if (groupVal < 1L) {
                throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
            }
            for (int j = 0; j < target.getNumColumns(); ++j) {
                out.add(newPair(groupVal, coloff + j + 1L, target.quickGetValue(i, j)));
            }
        }
        return out;
    }

    /** Builds one (row, col) keyed cell with the given value and weight 1. */
    private static Tuple2<MatrixIndexes, WeightedCell> newPair(long row, long col, double value) {
        WeightedCell cell = new WeightedCell();
        cell.setValue(value);
        cell.setWeight(1.0);
        return new Tuple2<>(new MatrixIndexes(row, col), cell);
    }

    /**
     * Variant that receives target blocks from the RDD and looks up the matching group
     * block in a {@link PartitionedBroadcast}, by the target block's row index and block column 1.
     */
    public static class ExtractGroupBroadcast
    extends ExtractGroup
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, WeightedCell> {
        private static final long serialVersionUID = 5709955602290131093L;
        private final PartitionedBroadcast<MatrixBlock> _pbm;

        public ExtractGroupBroadcast(PartitionedBroadcast<MatrixBlock> pbm, long blen, long ngroups, Operator op) {
            super(blen, ngroups, op);
            this._pbm = pbm;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, WeightedCell>> call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
            MatrixIndexes ix = arg._1;
            MatrixBlock group = this._pbm.getBlock((int)ix.getRowIndex(), 1);
            MatrixBlock target = arg._2;
            return this.execute(ix, group, target).iterator();
        }
    }

    /** Variant that receives already-joined (group, target) block pairs from the RDD. */
    public static class ExtractGroupJoin
    extends ExtractGroup
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell> {
        private static final long serialVersionUID = 8890978615936560266L;

        public ExtractGroupJoin(long blen, long ngroups, Operator op) {
            super(blen, ngroups, op);
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, WeightedCell>> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg) throws Exception {
            MatrixIndexes ix = arg._1;
            // The joined value carries the group block first and the target block second.
            MatrixBlock group = arg._2._1;
            MatrixBlock target = arg._2._2;
            return this.execute(ix, group, target).iterator();
        }
    }
}

