package forge.lda.lda;

import forge.lda.dataset.BagOfWords;
import forge.lda.dataset.Dataset;
import forge.lda.dataset.Vocabularies;
import forge.lda.lda.inference.Inference;
import forge.lda.lda.inference.InferenceFactory;
import forge.lda.lda.inference.InferenceMethod;
import forge.lda.lda.inference.InferenceProperties;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;

/* loaded from: input_file:forge/lda/lda/LDA.class */
/**
 * Entry point for training and querying a topic model over a {@link Dataset}.
 *
 * <p>The class wires together the hyperparameters, the dataset and an
 * {@link Inference} implementation chosen through {@link InferenceFactory}.
 * {@link #run()} must complete before {@link #getTheta(int, int)},
 * {@link #getPhi(int, int)} or {@link #computePerplexity(Dataset)} may be used.
 *
 * <p>Topic indices are 0-based: valid values are {@code 0 .. getNumTopics() - 1}.
 */
public class LDA {
    private final Hyperparameters hyperparameters;
    private final int numTopics;
    private final Dataset dataset;
    private final Inference inference;
    private final InferenceProperties properties;
    /** Set to {@code true} once {@link #run()} has finished. */
    private boolean trained;

    /**
     * Creates a model using {@link InferenceProperties#PROPERTIES_FILE_NAME} as the
     * properties file name.
     *
     * @param d               first value forwarded to {@link Hyperparameters}
     *                        (presumably the alpha prior -- confirm against Hyperparameters)
     * @param d2              second value forwarded to {@link Hyperparameters}
     *                        (presumably the beta prior -- confirm against Hyperparameters)
     * @param i               number of topics
     * @param dataset         dataset the model is trained on
     * @param inferenceMethod selects the {@link Inference} implementation
     */
    public LDA(double d, double d2, int i, Dataset dataset, InferenceMethod inferenceMethod) {
        this(d, d2, i, dataset, inferenceMethod, InferenceProperties.PROPERTIES_FILE_NAME);
    }

    /**
     * Creates a model with a fixed seed ({@code 123}) and {@code 100} inference iterations.
     *
     * @param d               first value forwarded to {@link Hyperparameters}
     * @param d2              second value forwarded to {@link Hyperparameters}
     * @param i               number of topics
     * @param dataset         dataset the model is trained on
     * @param inferenceMethod selects the {@link Inference} implementation
     * @param str             properties file name; NOTE(review): currently ignored, the
     *                        properties are always the hard-coded defaults set below
     */
    LDA(double d, double d2, int i, Dataset dataset, InferenceMethod inferenceMethod, String str) {
        this.hyperparameters = new Hyperparameters(d, d2, i);
        this.numTopics = i;
        this.dataset = dataset;
        this.inference = InferenceFactory.getInstance(inferenceMethod);
        this.properties = new InferenceProperties();
        this.trained = false;
        this.properties.setSeed(123L);
        this.properties.setNumIteration(100);
    }

    /**
     * Returns the string form of the dataset entry with the given vocabulary id.
     *
     * @param i vocabulary id
     * @return {@code dataset.get(i).toString()}
     * @throws IllegalArgumentException if {@code i} is negative or greater than the
     *                                  number of vocabularies in the dataset
     */
    public String getVocab(int i) {
        // NOTE(review): the upper bound is inclusive here; whether id == getNumVocabs()
        // is a legal id depends on Dataset's id base, which is not visible from this class.
        if (i < 0 || this.dataset.getNumVocabs() < i) {
            throw new IllegalArgumentException("vocabulary id out of range: " + i);
        }
        return this.dataset.get(i).toString();
    }

    /**
     * Sets up the inference with this model and its properties, runs it, and marks
     * the model as trained.
     */
    public void run() {
        // properties is assigned in the constructor and never null, so the
        // property-less setUp overload is not needed here.
        this.inference.setUp(this, this.properties);
        this.inference.run();
        this.trained = true;
    }

    /**
     * Returns the alpha hyperparameter of one topic.
     *
     * @param i 0-based topic index
     * @return {@code hyperparameters.alpha(i)}
     * @throws ArrayIndexOutOfBoundsException if {@code i} is not in {@code [0, getNumTopics())}
     */
    public double getAlpha(int i) {
        if (!isValidTopic(i)) {
            throw new ArrayIndexOutOfBoundsException(i);
        }
        return this.hyperparameters.alpha(i);
    }

    /** Returns {@code hyperparameters.sumAlpha()}. */
    public double getSumAlpha() {
        return this.hyperparameters.sumAlpha();
    }

    /** Returns {@code hyperparameters.beta()}. */
    public double getBeta() {
        return this.hyperparameters.beta();
    }

    /** Returns the number of topics this model was constructed with. */
    public int getNumTopics() {
        return this.numTopics;
    }

    /** Returns the bag-of-words view of the training dataset. */
    public BagOfWords getBow() {
        return this.dataset.getBow();
    }

    /**
     * Returns the inferred theta value for a document/topic pair.
     *
     * @param i  document index, checked against the training dataset's document count
     * @param i2 0-based topic index
     * @return {@code inference.getTheta(i, i2)}
     * @throws IllegalArgumentException if the document index is negative or greater than
     *                                  the number of documents, or the topic index is not
     *                                  in {@code [0, getNumTopics())}
     * @throws IllegalStateException    if {@link #run()} has not completed
     */
    public double getTheta(int i, int i2) {
        // NOTE(review): the document upper bound is inclusive; confirm Dataset's id base
        // before tightening it.
        if (i < 0 || this.dataset.getNumDocs() < i || !isValidTopic(i2)) {
            throw new IllegalArgumentException(
                    "document or topic index out of range: doc=" + i + ", topic=" + i2);
        }
        if (!this.trained) {
            throw new IllegalStateException("model has not been trained; call run() first");
        }
        return this.inference.getTheta(i, i2);
    }

    /**
     * Returns the inferred phi value for a topic/vocabulary pair.
     *
     * @param i  0-based topic index
     * @param i2 vocabulary id; only checked for being non-negative
     * @return {@code inference.getPhi(i, i2)}
     * @throws IllegalArgumentException if the topic index is not in
     *                                  {@code [0, getNumTopics())} or the vocabulary id
     *                                  is negative
     * @throws IllegalStateException    if {@link #run()} has not completed
     */
    public double getPhi(int i, int i2) {
        if (!isValidTopic(i) || i2 < 0) {
            throw new IllegalArgumentException(
                    "topic or vocabulary index out of range: topic=" + i + ", vocab=" + i2);
        }
        if (!this.trained) {
            throw new IllegalStateException("model has not been trained; call run() first");
        }
        return this.inference.getPhi(i, i2);
    }

    /** Returns the vocabularies of the training dataset. */
    public Vocabularies getVocabularies() {
        return this.dataset.getVocabularies();
    }

    /**
     * Delegates to {@code inference.getVocabsSortedByPhi(i)}.
     *
     * <p>No bounds or trained-state check is performed here; the behaviour for an
     * untrained model is whatever the inference implementation does.
     *
     * @param i topic index passed through to the inference
     * @return the list produced by the inference implementation
     */
    public List<Pair<String, Double>> getVocabsSortedByPhi(int i) {
        return this.inference.getVocabsSortedByPhi(i);
    }

    /**
     * Computes {@code exp(-L / N)}, where {@code L} is the sum over every word occurrence
     * in {@code dataset} of {@code log(sum_k theta(doc, k) * phi(k, word))} and {@code N}
     * is {@code dataset.getNumWords()}.
     *
     * <p>Document indices of {@code dataset} are looked up through
     * {@link #getTheta(int, int)}, which validates them against the training dataset.
     * If {@code dataset.getNumWords()} is zero the result is NaN.
     *
     * @param dataset dataset to evaluate
     * @return the perplexity value
     * @throws IllegalStateException    if {@link #run()} has not completed
     * @throws IllegalArgumentException if an index is rejected by {@link #getTheta(int, int)}
     *                                  or {@link #getPhi(int, int)}
     */
    public double computePerplexity(Dataset dataset) {
        // Fail fast instead of discovering the untrained state inside the innermost loop.
        if (!this.trained) {
            throw new IllegalStateException("model has not been trained; call run() first");
        }
        final int topicCount = getNumTopics();
        double logLikelihood = 0.0d;
        for (int doc = 0; doc < dataset.getNumDocs(); doc++) {
            for (Integer word : dataset.getWords(doc)) {
                final int vocabId = word.intValue();
                double wordProbability = 0.0d;
                for (int topic = 0; topic < topicCount; topic++) {
                    wordProbability += getTheta(doc, topic) * getPhi(topic, vocabId);
                }
                logLikelihood += Math.log(wordProbability);
            }
        }
        return Math.exp((-logLikelihood) / dataset.getNumWords());
    }

    /** Returns whether {@code topic} lies in the 0-based range {@code [0, numTopics)}. */
    private boolean isValidTopic(int topic) {
        return topic >= 0 && topic < this.numTopics;
    }
}
