/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.temporal.ae.feature;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.ctakes.core.resource.FileLocator;
import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
import org.apache.ctakes.typesystem.type.syntax.WordToken;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.utils.distsem.WordEmbeddings;
import org.apache.ctakes.utils.distsem.WordVector;
import org.apache.ctakes.utils.distsem.WordVectorReader;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.cleartk.ml.Feature;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;

public class RelationEmbeddingFeatureExtractor
implements RelationFeaturesExtractor<IdentifiedAnnotation, IdentifiedAnnotation> {
    private int numberOfDimensions;
    private WordEmbeddings words = null;

    public RelationEmbeddingFeatureExtractor(String vecFile) throws CleartkExtractorException {
        try {
            this.words = WordVectorReader.getEmbeddings((InputStream)FileLocator.getAsStream((String)vecFile));
        }
        catch (IOException e) {
            e.printStackTrace();
            throw new CleartkExtractorException((Throwable)e);
        }
        this.numberOfDimensions = this.words.getDimensionality();
    }

    public List<Feature> extract(JCas jCas, IdentifiedAnnotation arg1, IdentifiedAnnotation arg2) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<Feature>();
        List preWords = null;
        List afterWords = null;
        List wordsOfArgs1 = null;
        List wordsOfArgs2 = null;
        if (arg1.getBegin() < arg2.getBegin()) {
            preWords = JCasUtil.selectPreceding((JCas)jCas, WordToken.class, (AnnotationFS)arg1, (int)2);
            afterWords = JCasUtil.selectFollowing((JCas)jCas, WordToken.class, (AnnotationFS)arg2, (int)2);
            wordsOfArgs1 = JCasUtil.selectCovered((JCas)jCas, WordToken.class, (AnnotationFS)arg1);
            wordsOfArgs2 = JCasUtil.selectCovered((JCas)jCas, WordToken.class, (AnnotationFS)arg2);
        } else {
            preWords = JCasUtil.selectPreceding((JCas)jCas, WordToken.class, (AnnotationFS)arg2, (int)2);
            afterWords = JCasUtil.selectFollowing((JCas)jCas, WordToken.class, (AnnotationFS)arg1, (int)2);
            wordsOfArgs1 = JCasUtil.selectCovered((JCas)jCas, WordToken.class, (AnnotationFS)arg2);
            wordsOfArgs2 = JCasUtil.selectCovered((JCas)jCas, WordToken.class, (AnnotationFS)arg1);
        }
        List<Double> sum = this.getSumVector(preWords);
        features = this.addFeatures(features, sum, preWords.size(), "pre");
        sum = this.getSumVector(wordsOfArgs1);
        features = this.addFeatures(features, sum, wordsOfArgs1.size(), "arg1");
        ArrayList<WordToken> wordsBetweenArgs = new ArrayList<WordToken>(JCasUtil.selectBetween((JCas)jCas, WordToken.class, (AnnotationFS)arg1, (AnnotationFS)arg2));
        sum = this.getSumVector(wordsBetweenArgs);
        features = this.addFeatures(features, sum, wordsBetweenArgs.size(), "inBetween");
        sum = this.getSumVector(wordsOfArgs2);
        features = this.addFeatures(features, sum, wordsOfArgs2.size(), "arg2");
        sum = this.getSumVector(afterWords);
        features = this.addFeatures(features, sum, afterWords.size(), "after");
        return features;
    }

    private List<Feature> addFeatures(List<Feature> features, List<Double> sum, int size, String field) {
        if (size == 0) {
            for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
                String featureName = String.format(field + "_dim_%d", dim);
                features.add(new Feature(featureName, (Object)sum.get(dim)));
            }
        } else {
            for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
                String featureName = String.format(field + "_dim_%d", dim);
                features.add(new Feature(featureName, (Object)(sum.get(dim) / (double)size)));
            }
        }
        return features;
    }

    private List<Double> getSumVector(List<WordToken> wordsInCheck) {
        List<Double> sum = new ArrayList<Double>(Collections.nCopies(this.numberOfDimensions, 0.0));
        for (WordToken wordToken : wordsInCheck) {
            WordVector wordVector = this.words.containsKey(wordToken.getCoveredText().toLowerCase()) ? this.words.getVector(wordToken.getCoveredText().toLowerCase()) : this.words.getVector("and");
            sum = this.addVectors(sum, wordVector);
        }
        return sum;
    }

    public double computeCosineSimilarity(WordVector vector1, WordVector vector2) {
        double dotProduct = 0.0;
        double norm1 = 0.01;
        double norm2 = 0.01;
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            dotProduct += vector1.getValue(dim) * vector2.getValue(dim);
            norm1 += Math.pow(vector1.getValue(dim), 2.0);
            norm2 += Math.pow(vector2.getValue(dim), 2.0);
        }
        return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    public double computeCosineSimilarity(List<Double> vector1, List<Double> vector2) {
        double dotProduct = 0.0;
        double norm1 = 0.01;
        double norm2 = 0.01;
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            dotProduct += vector1.get(dim) * vector2.get(dim);
            norm1 += Math.pow(vector1.get(dim), 2.0);
            norm2 += Math.pow(vector2.get(dim), 2.0);
        }
        return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    public List<Double> addVectors(List<Double> vector1, WordVector vector2) {
        ArrayList<Double> sum = new ArrayList<Double>();
        for (int dim = 0; dim < this.numberOfDimensions; ++dim) {
            sum.add(vector1.get(dim) + vector2.getValue(dim));
        }
        return sum;
    }
}

