/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.temporal.ae.feature;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.ctakes.constituency.parser.treekernel.TreeExtractor;
import org.apache.ctakes.constituency.parser.util.AnnotationTreeUtils;
import org.apache.ctakes.core.resource.FileLocator;
import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
import org.apache.ctakes.typesystem.type.syntax.TopTreebankNode;
import org.apache.ctakes.typesystem.type.syntax.TreebankNode;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.utils.distsem.WordEmbeddings;
import org.apache.ctakes.utils.distsem.WordVector;
import org.apache.ctakes.utils.distsem.WordVectorReader;
import org.apache.ctakes.utils.tree.SimpleTree;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;

/**
 * Relation feature extractor that summarises the syntactic configuration of two
 * annotations as the average of embedding vectors.
 *
 * <p>For a pair of annotations the extractor builds a list of label paths from the
 * constituency tree (node types joined with {@code "-"}), looks each path up in a
 * {@link WordEmbeddings} table, and emits one feature per embedding dimension holding
 * the mean of the looked-up vectors. A path without an embedding is shortened from
 * the tail, one label at a time, until a known prefix is found; if none is found the
 * vector stored under {@code "S"} is used.
 */
public class RelationSyntacticEmbeddingFeatureExtractor
implements RelationFeaturesExtractor<IdentifiedAnnotation, IdentifiedAnnotation> {
    /** Key of the vector used when neither a path nor any of its prefixes has an embedding. */
    private static final String FALLBACK_PATH = "S";
    /** Separator placed between node labels inside a path key. */
    private static final String LABEL_SEPARATOR = "-";
    /** Prefix of every emitted feature name; the dimension index is appended to it. */
    private static final String FEATURE_NAME_PREFIX = "syntactic_average_dim_";

    /** Dimensionality reported by {@link #paths}; fixes the number of emitted features. */
    private final int numberOfDimensions;
    /** Embedding table keyed by label path. */
    private final WordEmbeddings paths;

    /**
     * Loads the path embeddings.
     *
     * @param vecFile location of the embeddings, resolved through {@code FileLocator.getAsStream}
     * @throws CleartkExtractorException wrapping any {@link IOException} raised while
     *         opening, reading or closing the stream
     */
    public RelationSyntacticEmbeddingFeatureExtractor(String vecFile) throws CleartkExtractorException {
        // try-with-resources: the stream is closed whether or not reading succeeds.
        try (InputStream vectorStream = FileLocator.getAsStream(vecFile)) {
            this.paths = WordVectorReader.getEmbeddings(vectorStream);
        }
        catch (IOException e) {
            // The cause is preserved in the wrapper; no separate stack-trace dump is needed.
            throw new CleartkExtractorException(e);
        }
        this.numberOfDimensions = this.paths.getDimensionality();
    }

    /**
     * Produces one feature per embedding dimension, named
     * {@code syntactic_average_dim_<index>}, whose value is the mean over all collected
     * paths of that dimension of the path's vector.
     *
     * @param jCas the CAS holding the annotations and their trees
     * @param arg1 first relation argument; its annotation tree must exist for features to be produced
     * @param arg2 second relation argument
     * @return the features, or an empty list when no tree or tree node is available
     * @throws AnalysisEngineProcessException declared by the interface
     */
    public List<Feature> extract(JCas jCas, IdentifiedAnnotation arg1, IdentifiedAnnotation arg2) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<>();
        TopTreebankNode root = AnnotationTreeUtils.getTreeCopy(jCas, (TopTreebankNode)AnnotationTreeUtils.getAnnotationTree(jCas, arg1));
        if (root == null) {
            return features;
        }
        TreebankNode t1 = AnnotationTreeUtils.annotationNode(jCas, arg1);
        TreebankNode t2 = AnnotationTreeUtils.annotationNode(jCas, arg2);
        // NOTE(review): presumably annotationNode can return null when no tree node matches
        // an annotation - confirm. Handled like a missing tree instead of failing with a
        // NullPointerException further down.
        if (t1 == null || t2 == null) {
            return features;
        }

        List<String> pathsBetweenArgs = this.collectPaths(t1, t2);

        // Accumulate in a primitive array, adding in list order (same order of additions
        // as repeatedly calling addVectors, without a new boxed list per path).
        double[] sum = new double[this.numberOfDimensions];
        for (String path : pathsBetweenArgs) {
            WordVector vector = this.lookupVector(path);
            for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
                sum[dim] += vector.getValue(dim);
            }
        }

        // collectPaths always yields at least one path, so the divisor is never zero.
        double pathCount = pathsBetweenArgs.size();
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            // Plain concatenation keeps the name independent of the default locale's digits.
            features.add(new Feature(FEATURE_NAME_PREFIX + dim, Double.valueOf(sum[dim] / pathCount)));
        }
        return features;
    }

    /**
     * Gathers every path key used for one argument pair: first the path(s) linking the two
     * nodes, then the downward paths found inside each node's own subtree.
     */
    private List<String> collectPaths(TreebankNode t1, TreebankNode t2) {
        List<String> collected = new ArrayList<>();
        if (RelationSyntacticEmbeddingFeatureExtractor.covers(t1, t2)) {
            collected.add(this.getPathBetweenNodes(t2, t1, ""));
        } else if (RelationSyntacticEmbeddingFeatureExtractor.covers(t2, t1)) {
            collected.add(this.getPathBetweenNodes(t1, t2, ""));
        } else {
            // Neither span contains the other: use one path per argument up to the node
            // returned by TreeExtractor.getLCA.
            TreebankNode lca = TreeExtractor.getLCA(t1, t2);
            collected.add(this.getPathBetweenNodes(t1, lca, ""));
            collected.add(this.getPathBetweenNodes(t2, lca, ""));
        }
        collected.addAll(this.traverseTreeForDPath(TreeExtractor.getSimpleClone(t1)));
        collected.addAll(this.traverseTreeForDPath(TreeExtractor.getSimpleClone(t2)));
        return collected;
    }

    /** Tells whether the character span of {@code outer} contains the span of {@code inner}. */
    private static boolean covers(TreebankNode outer, TreebankNode inner) {
        return outer.getBegin() <= inner.getBegin() && outer.getEnd() >= inner.getEnd();
    }

    /**
     * Finds the vector for a path, backing off to ever shorter prefixes and finally to
     * the vector stored under {@code "S"}.
     */
    private WordVector lookupVector(String path) {
        String key = path;
        while (key != null && !this.paths.containsKey(key)) {
            key = RelationSyntacticEmbeddingFeatureExtractor.removeTail(key);
        }
        return this.paths.getVector(key != null ? key : FALLBACK_PATH);
    }

    /**
     * Drops the last {@code "-"}-separated label of a path.
     *
     * @return the shortened path, or {@code null} when the path has no separator after
     *         its first character and therefore cannot be shortened
     */
    private static String removeTail(String path) {
        int dashIdx = path.lastIndexOf(LABEL_SEPARATOR);
        return dashIdx > 0 ? path.substring(0, dashIdx) : null;
    }

    /**
     * Builds the label path that leads from {@code ancestor} down to {@code child}.
     * Labels are prepended while walking up the parent chain, so the highest node comes
     * first. The walk stops when the parent is {@code ancestor} (whose label is then
     * prepended as well) or when a node without a parent is reached.
     *
     * @param child    node to start from
     * @param ancestor node at which to stop
     * @param path     path already built below {@code child}; {@code ""} for none
     */
    private String getPathBetweenNodes(TreebankNode child, TreebankNode ancestor, String path) {
        TreebankNode current = child;
        String result = path;
        while (true) {
            String label = current.getNodeType();
            result = "".equals(result) ? label : label + LABEL_SEPARATOR + result;
            TreebankNode parent = current.getParent();
            if (parent == null) {
                return result;
            }
            // NOTE(review): identity comparison - assumes the same Java object is returned
            // for the same tree node; confirm.
            if (parent == ancestor) {
                return parent.getNodeType() + LABEL_SEPARATOR + result;
            }
            current = parent;
        }
    }

    /**
     * Builds the label path from the topmost reachable node down to {@code child}.
     * Not referenced elsewhere in this class.
     *
     * @param child node to start from
     * @param path  path already built below {@code child}; {@code ""} for none
     */
    private String getPathToRoot(TreebankNode child, String path) {
        String result = path;
        for (TreebankNode current = child; current != null; current = current.getParent()) {
            String label = current.getNodeType();
            result = "".equals(result) ? label : label + LABEL_SEPARATOR + result;
        }
        return result;
    }

    /**
     * Lists, for the given node and recursively for its descendants, every downward label
     * chain that starts at that node. For a node whose only child is a leaf, just the
     * node's label and the {@code node-leaf} chain are listed; the leaf is not visited
     * on its own.
     *
     * <p>The order of the returned list is significant to callers that sum vectors in
     * list order.
     */
    private List<String> traverseTreeForDPath(SimpleTree tree) {
        List<String> found = new ArrayList<>();
        String rootStr = tree.cat;
        found.add(rootStr);
        if (RelationSyntacticEmbeddingFeatureExtractor.hasSingleLeafChild(tree)) {
            found.add(rootStr + LABEL_SEPARATOR + tree.children.get(0).cat);
            return found;
        }
        for (SimpleTree subtree : tree.children) {
            found.addAll(this.traverseTreeForDPath(subtree));
            for (String chain : this.getSubTreeStrings(subtree)) {
                found.add(rootStr + LABEL_SEPARATOR + chain);
            }
        }
        return found;
    }

    /** Tells whether the node has exactly one child and that child has no children. */
    private static boolean hasSingleLeafChild(SimpleTree tree) {
        return tree.children.size() == 1 && tree.children.get(0).children.size() == 0;
    }

    /**
     * Lists every downward label chain that starts at {@code subtree}: its own label,
     * followed by that label joined to each chain starting at each child.
     */
    private List<String> getSubTreeStrings(SimpleTree subtree) {
        List<String> chains = new ArrayList<>();
        chains.add(subtree.cat);
        // A single leaf child needs no special case: the leaf yields only its own label,
        // so this loop already produces "label-leaf".
        for (SimpleTree child : subtree.children) {
            for (String chain : this.getSubTreeStrings(child)) {
                chains.add(subtree.cat + LABEL_SEPARATOR + chain);
            }
        }
        return chains;
    }

    /**
     * Computes a smoothed cosine similarity over the first {@code numberOfDimensions}
     * components. Each squared norm starts at 0.01 rather than 0, so the denominator is
     * never zero and the result is slightly smaller in magnitude than the exact cosine.
     *
     * @param vector1 first vector
     * @param vector2 second vector
     * @return the dot product divided by the product of the smoothed norms
     */
    public double computeCosineSimilarity(WordVector vector1, WordVector vector2) {
        double dotProduct = 0.0;
        double norm1 = 0.01;
        double norm2 = 0.01;
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            dotProduct += vector1.getValue(dim) * vector2.getValue(dim);
            norm1 += Math.pow(vector1.getValue(dim), 2.0);
            norm2 += Math.pow(vector2.getValue(dim), 2.0);
        }
        return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    /**
     * Computes a smoothed cosine similarity over the first {@code numberOfDimensions}
     * elements of two lists. Each squared norm starts at 0.01 rather than 0, so the
     * denominator is never zero.
     *
     * @param vector1 first vector; must hold at least {@code numberOfDimensions} elements
     * @param vector2 second vector; must hold at least {@code numberOfDimensions} elements
     * @return the dot product divided by the product of the smoothed norms
     */
    public double computeCosineSimilarity(List<Double> vector1, List<Double> vector2) {
        double dotProduct = 0.0;
        double norm1 = 0.01;
        double norm2 = 0.01;
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            dotProduct += vector1.get(dim) * vector2.get(dim);
            norm1 += Math.pow(vector1.get(dim), 2.0);
            norm2 += Math.pow(vector2.get(dim), 2.0);
        }
        return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    /**
     * Adds a word vector to a list of doubles, component by component.
     *
     * @param vector1 running total; must hold at least {@code numberOfDimensions} elements
     * @param vector2 vector to add
     * @return a new list of {@code numberOfDimensions} sums; {@code vector1} is not modified
     */
    public List<Double> addVectors(List<Double> vector1, WordVector vector2) {
        List<Double> sum = new ArrayList<>(this.numberOfDimensions);
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            sum.add(vector1.get(dim) + vector2.getValue(dim));
        }
        return sum;
    }
}

