/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.ytex.kernel;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ctakes.ytex.kernel.FileUtil;
import org.apache.ctakes.ytex.kernel.FoldGenerator;
import org.apache.ctakes.ytex.kernel.InstanceData;
import org.apache.ctakes.ytex.kernel.KernelContextHolder;
import org.apache.ctakes.ytex.kernel.KernelUtil;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFoldInstance;

/**
 * Generates class-stratified cross-validation folds.
 * <p>
 * Instances are grouped by class and each class is distributed over the folds separately, so every
 * fold receives instances of every class. Each fold receives at least {@code nMinPerClass} instances of
 * a class, or all of them when the class has no more than {@code nMinPerClass} instances. When an even
 * split cannot honour that minimum, folds are sampled independently and may overlap.
 * <p>
 * Folds are either stored through a {@link ClassifierEvaluationDao}
 * ({@link #generateRuns(String, String, String, int, int, Integer, int)}) or returned as nested maps
 * ({@link #generateRuns(SortedMap, int, int, Integer, int)}).
 */
public class FoldGeneratorImpl implements FoldGenerator {
    private static final Log log = LogFactory.getLog(FoldGeneratorImpl.class);
    ClassifierEvaluationDao classifierEvaluationDao;
    KernelUtil kernelUtil;

    /**
     * Splits instances into {@code nFolds} folds, stratified by class.
     * <p>
     * Each class's id list is shuffled in place and then distributed as follows:
     * <ul>
     * <li>at most {@code nMinPerClass} ids: every fold receives all of them;</li>
     * <li>an even split would leave fewer than {@code nMinPerClass} ids per fold: each fold receives
     * exactly {@code nMinPerClass} ids sampled independently, so folds may overlap;</li>
     * <li>otherwise: the shuffled ids are cut into contiguous, disjoint blocks of equal size, and the
     * last fold also takes the remainder.</li>
     * </ul>
     *
     * @param mapClassToInstanceId class name to instance ids; the id lists are shuffled in place
     * @param nFolds number of folds, at least 1
     * @param nMinPerClass minimum number of instances of each class per fold
     * @param r source of randomness (consumed in class iteration order)
     * @return {@code nFolds} sets of instance ids, one per fold
     * @throws IllegalArgumentException if {@code nFolds < 1}
     */
    private static List<Set<Long>> createFolds(Map<String, List<Long>> mapClassToInstanceId, int nFolds, int nMinPerClass, Random r) {
        requirePositiveFoldCount(nFolds);
        List<Set<Long>> folds = new ArrayList<Set<Long>>(nFolds);
        for (int i = 0; i < nFolds; ++i) {
            folds.add(new HashSet<Long>());
        }
        for (List<Long> instanceIds : mapClassToInstanceId.values()) {
            Collections.shuffle(instanceIds, r);
            int classSize = instanceIds.size();
            int blockSize = classSize / nFolds;
            for (int i = 0; i < nFolds; ++i) {
                Set<Long> fold = folds.get(i);
                if (classSize <= nMinPerClass) {
                    // too few instances to split: every fold gets the whole class
                    fold.addAll(instanceIds);
                } else if (blockSize < nMinPerClass) {
                    // an even split would undershoot the minimum: sample this fold independently
                    fold.addAll(sampleInstanceIds(instanceIds, nMinPerClass, r));
                } else {
                    int start = i * blockSize;
                    // the last fold absorbs the remainder of the integer division
                    int end = i == nFolds - 1 ? classSize : start + blockSize;
                    fold.addAll(instanceIds.subList(start, end));
                }
            }
        }
        return folds;
    }

    /**
     * Draws exactly {@code count} distinct ids from {@code instanceIds}.
     * <p>
     * The list is walked cyclically from a random start, and each visited id is kept with probability
     * {@code count / instanceIds.size()}. The caller must ensure {@code instanceIds.size() > count} and
     * that the ids are distinct. In this class they are keys of a map, so they are distinct; otherwise
     * the loop might never reach {@code count}.
     */
    private static Set<Long> sampleInstanceIds(List<Long> instanceIds, int count, Random r) {
        int size = instanceIds.size();
        double keepProbability = (double) count / (double) size;
        Set<Long> sample = new HashSet<Long>();
        int index = (int) (r.nextDouble() * (double) size);
        while (sample.size() < count) {
            if (index >= size) {
                index = 0;
            }
            if (r.nextDouble() <= keepProbability) {
                sample.add(instanceIds.get(index));
            }
            ++index;
        }
        return sample;
    }

    /** Rejects fold counts below 1, which cannot be split into folds. */
    private static void requirePositiveFoldCount(int nFolds) {
        if (nFolds < 1) {
            throw new IllegalArgumentException("nFolds must be at least 1, got " + nFolds);
        }
    }

    /** Creates the random source: seeded with {@code nSeed} if given, otherwise with the current time. */
    private static Random newRandom(Integer nSeed) {
        return new Random(nSeed != null ? nSeed.longValue() : System.currentTimeMillis());
    }

    /**
     * Command-line entry point.
     * <p>
     * Reads settings from the property file named by {@code -prop} and hands them to the
     * {@link FoldGenerator} bean obtained from {@link KernelContextHolder}. The properties read are:
     * <ul>
     * <li>{@code org.apache.ctakes.ytex.corpusName} (required);</li>
     * <li>{@code org.apache.ctakes.ytex.splitName};</li>
     * <li>{@code instanceClassQuery} (required);</li>
     * <li>{@code folds} (default 2), {@code runs} (default 5) and {@code minPerClass} (default 1);</li>
     * <li>{@code rand}, an optional random seed.</li>
     * </ul>
     * Prints usage when called without arguments or with unparseable options. Exits with status 1 when
     * a required property is missing.
     */
    public static void main(String[] args) throws ParseException, IOException {
        Options options = new Options();
        OptionBuilder.withArgName("prop");
        OptionBuilder.hasArg();
        OptionBuilder.withDescription("property file with query to retrieve instance id - label - class triples");
        options.addOption(OptionBuilder.create("prop"));
        try {
            if (args.length == 0) {
                printHelp(options);
                return;
            }
            CommandLine line = new GnuParser().parse(options, args);
            Properties props = FileUtil.loadProperties(line.getOptionValue("prop"), true);
            String corpusName = props.getProperty("org.apache.ctakes.ytex.corpusName");
            String splitName = props.getProperty("org.apache.ctakes.ytex.splitName");
            String query = props.getProperty("instanceClassQuery");
            int folds = Integer.parseInt(props.getProperty("folds", "2"));
            int runs = Integer.parseInt(props.getProperty("runs", "5"));
            int minPerClass = Integer.parseInt(props.getProperty("minPerClass", "1"));
            Integer rand = null;
            if (props.containsKey("rand")) {
                rand = Integer.valueOf(props.getProperty("rand"));
            }
            boolean argsOk = true;
            if (corpusName == null) {
                log.error("missing parameter: org.apache.ctakes.ytex.corpusName");
                argsOk = false;
            }
            if (query == null) {
                log.error("missing parameter: instanceClassQuery");
                argsOk = false;
            }
            if (!argsOk) {
                printHelp(options);
                System.exit(1);
            } else {
                FoldGenerator generator = (FoldGenerator) KernelContextHolder.getApplicationContext().getBean(FoldGenerator.class);
                generator.generateRuns(corpusName, splitName, query, folds, minPerClass, rand, runs);
            }
        } catch (ParseException pe) {
            printHelp(options);
        }
    }

    /** Prints command-line usage for {@link #main(String[])}. */
    private static void printHelp(Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp("java org.apache.ctakes.ytex.kernel.FoldGeneratorImpl splits training data into mxn training/test sets for mxn-fold cross validation", options);
    }

    /**
     * Generates one run of folds for every label in {@code instances} and saves them via the DAO.
     * <p>
     * For each label only the first run / first fold / first train-test entry is used as the
     * instance-to-class map (see {@link #firstInstanceClassMap}). The {@code labels} and {@code query}
     * parameters are not used by this method.
     *
     * @param labels unused
     * @param instances loaded instance data
     * @param corpusName corpus name stored with each fold
     * @param splitName split name stored with each fold
     * @param run run number stored with each fold
     * @param query unused
     * @param nFolds number of folds, at least 1
     * @param nMinPerClass minimum instances of each class per fold
     * @param r source of randomness
     */
    public void generateFolds(Set<String> labels, InstanceData instances, String corpusName, String splitName, int run, String query, int nFolds, int nMinPerClass, Random r) {
        for (Map.Entry<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelEntry
                : instances.getLabelToInstanceMap().entrySet()) {
            SortedMap<Long, String> mapInstanceIdToClass = firstInstanceClassMap(labelEntry.getValue());
            List<Set<Long>> folds = createFolds(nFolds, nMinPerClass, r, mapInstanceIdToClass);
            this.insertFolds(folds, corpusName, splitName, labelEntry.getKey(), run);
        }
    }

    /**
     * Unpeels a run map down to its first instance-id-to-class map.
     * <p>
     * The lookup follows the first run, then the first fold, then the first Boolean key, ignoring
     * everything else.
     * <p>
     * NOTE(review): in Boolean order the first key is {@code false}, whereas
     * {@link #generateRuns(SortedMap, int, int, Integer, int)} reads key {@code true}. Presumably
     * unsplit data holds a single entry here — confirm.
     */
    private static SortedMap<Long, String> firstInstanceClassMap(SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>> runMap) {
        SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>> foldMap = runMap.values().iterator().next();
        SortedMap<Boolean, SortedMap<Long, String>> trainTestMap = foldMap.values().iterator().next();
        return trainTestMap.values().iterator().next();
    }

    /**
     * Groups instances by class (classes in sorted order, ids in map order) and builds stratified folds.
     *
     * @param nFolds number of folds, at least 1
     * @param nMinPerClass minimum instances of each class per fold
     * @param r source of randomness
     * @param mapInstanceIdToClass instance id to class name
     * @return {@code nFolds} sets of instance ids
     */
    private static List<Set<Long>> createFolds(int nFolds, int nMinPerClass, Random r, SortedMap<Long, String> mapInstanceIdToClass) {
        SortedMap<String, List<Long>> mapClassToInstanceId = new TreeMap<String, List<Long>>();
        for (Map.Entry<Long, String> instance : mapInstanceIdToClass.entrySet()) {
            List<Long> classInstanceIds = mapClassToInstanceId.get(instance.getValue());
            if (classInstanceIds == null) {
                classInstanceIds = new ArrayList<Long>();
                mapClassToInstanceId.put(instance.getValue(), classInstanceIds);
            }
            classInstanceIds.add(instance.getKey());
        }
        return createFolds(mapClassToInstanceId, nFolds, nMinPerClass, r);
    }

    /**
     * Loads instances with {@code query}, asks the DAO to delete folds stored under
     * {@code corpusName}/{@code splitName}, then generates and saves {@code nRuns} runs of folds.
     * <p>
     * {@code nFolds} is validated before anything is deleted, so invalid input leaves stored folds intact.
     *
     * @param nSeed random seed; if {@code null} the current time is used
     * @throws IllegalArgumentException if {@code nFolds < 1}
     */
    @Override
    public void generateRuns(String corpusName, String splitName, String query, int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
        requirePositiveFoldCount(nFolds);
        Random r = newRandom(nSeed);
        Set<String> labels = new TreeSet<String>();
        InstanceData instances = this.kernelUtil.loadInstances(query);
        this.getClassifierEvaluationDao().deleteCrossValidationFoldByName(corpusName, splitName);
        for (int run = 1; run <= nRuns; ++run) {
            this.generateFolds(labels, instances, corpusName, splitName, run, query, nFolds, nMinPerClass, r);
        }
    }

    public ClassifierEvaluationDao getClassifierEvaluationDao() {
        return this.classifierEvaluationDao;
    }

    public KernelUtil getKernelUtil() {
        return this.kernelUtil;
    }

    /**
     * Saves one {@link CrossValidationFold} per fold (fold numbers start at 1).
     * <p>
     * For fold {@code k}, every id of every fold is added. The second constructor argument is
     * {@code false} for ids taken from fold {@code k} and {@code true} for ids from the other folds
     * (presumably the "train" flag).
     * <p>
     * NOTE(review): when folds overlap, an id may be added with both flags. Whether the set
     * de-duplicates depends on {@code CrossValidationFoldInstance.equals}.
     */
    private void insertFolds(List<Set<Long>> folds, String corpusName, String splitName, String label, int run) {
        for (int foldNum = 1; foldNum <= folds.size(); ++foldNum) {
            Set<CrossValidationFoldInstance> instanceIds = new HashSet<CrossValidationFoldInstance>();
            for (int trainFoldNum = 1; trainFoldNum <= folds.size(); ++trainFoldNum) {
                for (long instanceId : folds.get(trainFoldNum - 1)) {
                    instanceIds.add(new CrossValidationFoldInstance(instanceId, trainFoldNum != foldNum));
                }
            }
            this.classifierEvaluationDao.saveFold(new CrossValidationFold(corpusName, splitName, label, run, foldNum, instanceIds));
        }
    }

    public void setClassifierEvaluationDao(ClassifierEvaluationDao classifierEvaluationDao) {
        this.classifierEvaluationDao = classifierEvaluationDao;
    }

    public void setKernelUtil(KernelUtil kernelUtil) {
        this.kernelUtil = kernelUtil;
    }

    /**
     * Builds folds in memory instead of storing them.
     * <p>
     * For each label, the instances under run 0 / fold 0 / key {@code true} are split into
     * {@code nRuns} runs (numbered from 1). Each run holds {@code nFolds} folds (numbered from 1). For
     * each fold, key {@code false} maps the instances in that fold and key {@code true} maps all
     * remaining instances.
     * <p>
     * NOTE(review): a label without a run 0 / fold 0 entry causes a NullPointerException.
     *
     * @param labelToInstanceMap label to run to fold to train/test to instance id to class
     * @param nFolds number of folds, at least 1
     * @param nMinPerClass minimum instances of each class per fold
     * @param nSeed random seed; if {@code null} the current time is used
     * @param nRuns number of runs to generate
     * @return the generated folds, in the same nested layout as the input
     */
    @Override
    public SortedMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> generateRuns(
            SortedMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelToInstanceMap,
            int nFolds, int nMinPerClass, Integer nSeed, int nRuns) {
        SortedMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelToInstanceFoldMap =
                new TreeMap<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>>();
        Random r = newRandom(nSeed);
        for (Map.Entry<String, SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>> labelRun : labelToInstanceMap.entrySet()) {
            SortedMap<Long, String> instanceClassMap = labelRun.getValue().get(0).get(0).get(true);
            SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>> runMap =
                    new TreeMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>();
            labelToInstanceFoldMap.put(labelRun.getKey(), runMap);
            for (int run = 1; run <= nRuns; ++run) {
                List<Set<Long>> folds = createFolds(nFolds, nMinPerClass, r, instanceClassMap);
                SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>> foldMap =
                        new TreeMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>();
                runMap.put(run, foldMap);
                for (int foldNum = 1; foldNum <= folds.size(); ++foldNum) {
                    foldMap.put(foldNum, splitTrainTest(instanceClassMap, folds.get(foldNum - 1)));
                }
            }
        }
        return labelToInstanceFoldMap;
    }

    /**
     * Partitions {@code instanceClassMap} into a train/test map.
     * <p>
     * Key {@code false} holds the ids contained in {@code testIds}; key {@code true} holds all other ids.
     */
    private static SortedMap<Boolean, SortedMap<Long, String>> splitTrainTest(SortedMap<Long, String> instanceClassMap, Set<Long> testIds) {
        SortedMap<Long, String> train = new TreeMap<Long, String>();
        SortedMap<Long, String> test = new TreeMap<Long, String>();
        for (Map.Entry<Long, String> instanceClass : instanceClassMap.entrySet()) {
            Long instanceId = instanceClass.getKey();
            if (testIds.contains(instanceId)) {
                test.put(instanceId, instanceClass.getValue());
            } else {
                train.put(instanceId, instanceClass.getValue());
            }
        }
        SortedMap<Boolean, SortedMap<Long, String>> trainTestMap = new TreeMap<Boolean, SortedMap<Long, String>>();
        trainTestMap.put(Boolean.TRUE, train);
        trainTestMap.put(Boolean.FALSE, test);
        return trainTestMap;
    }
}

