package forge.lda.lda.inference.internal;

import forge.lda.dataset.Vocabulary;
import forge.lda.lda.LDA;
import forge.lda.lda.inference.Inference;
import forge.lda.lda.inference.InferenceProperties;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;

/* loaded from: input_file:forge/lda/lda/inference/internal/CollapsedGibbsSampler.class */
/**
 * {@link Inference} implementation that estimates an LDA model by Gibbs sampling.
 *
 * <p>Each sampling sweep visits every word position of every document, removes the
 * position's current topic from the document/topic and topic/vocabulary counts, draws a
 * new topic from a distribution whose weights are {@code theta(doc, topic) * phi(topic, vocab)},
 * and adds the new assignment back to the counts.
 *
 * <p>An instance must be prepared with one of the {@code setUp} overloads before
 * {@link #run()} or any of the query methods is used; otherwise those methods throw
 * {@link IllegalStateException}.
 */
public class CollapsedGibbsSampler implements Inference {
    /** Seed used for the initial topic assignment when none is supplied. */
    private static final long DEFAULT_SEED = 0L;

    /** Number of sampling sweeps performed by {@link #run()} when none is supplied. */
    private static final int DEFAULT_NUM_ITERATION = 100;

    /** Message used whenever the sampler is queried or run before set-up. */
    private static final String NOT_READY_MESSAGE = "instance has not set up yet";

    private LDA lda;
    private Topics topics;
    private Documents documents;
    private int numIteration;
    private boolean ready = false;

    /**
     * Prepares the sampler for the given model using the supplied properties.
     *
     * <p>A {@code null} properties object, a {@code null} seed, or a {@code null}
     * iteration count falls back to the corresponding default.
     *
     * @param lda the model to run inference for; must not be {@code null}
     * @param inferenceProperties optional seed / iteration settings, may be {@code null}
     * @throws NullPointerException if {@code lda} is {@code null}
     */
    @Override
    public void setUp(LDA lda, InferenceProperties inferenceProperties) {
        if (inferenceProperties == null) {
            setUp(lda);
            return;
        }
        prepare(lda, inferenceProperties.seed() != null
                ? inferenceProperties.seed().longValue()
                : DEFAULT_SEED);
        this.numIteration = inferenceProperties.numIteration() != null
                ? inferenceProperties.numIteration().intValue()
                : DEFAULT_NUM_ITERATION;
        this.ready = true;
    }

    /**
     * Prepares the sampler for the given model with the default seed and iteration count.
     *
     * @param lda the model to run inference for; must not be {@code null}
     * @throws NullPointerException if {@code lda} is {@code null}
     */
    @Override
    public void setUp(LDA lda) {
        prepare(lda, DEFAULT_SEED);
        this.numIteration = DEFAULT_NUM_ITERATION;
        this.ready = true;
    }

    /**
     * Shared set-up path: validates the model, builds the count structures and performs
     * the initial topic assignment.
     */
    private void prepare(LDA lda, long seed) {
        if (lda == null) {
            throw new NullPointerException("lda must not be null");
        }
        this.lda = lda;
        initialize(lda);
        initializeTopicAssignment(seed);
    }

    /** Creates fresh {@link Topics} and {@link Documents} state for the given model. */
    private void initialize(LDA lda) {
        assert lda != null;
        this.topics = new Topics(lda);
        this.documents = new Documents(lda);
    }

    /**
     * @return {@code true} once a {@code setUp} call has completed
     */
    public boolean isReady() {
        return this.ready;
    }

    /**
     * @return the number of sampling sweeps {@link #run()} will perform
     */
    public int getNumIteration() {
        return this.numIteration;
    }

    /**
     * Sets the number of sampling sweeps {@link #run()} will perform. The value is not
     * validated; a value below 1 makes {@link #run()} perform no sweeps.
     *
     * @param numIteration the number of sweeps
     */
    public void setNumIteration(int numIteration) {
        this.numIteration = numIteration;
    }

    /**
     * Runs the configured number of sampling sweeps, printing a progress line to
     * standard output before each one.
     *
     * @throws IllegalStateException if the sampler has not been set up
     */
    @Override
    public void run() {
        requireReady();
        for (int iteration = 1; iteration <= this.numIteration; iteration++) {
            System.out.println("Iteration " + iteration + ".");
            runSampling();
        }
    }

    /**
     * Performs one sweep over every word position of every document, resampling its topic.
     */
    void runSampling() {
        for (Document document : this.documents.getDocuments()) {
            for (int position = 0; position < document.getDocLength(); position++) {
                Vocabulary vocabulary = document.getVocabulary(position);

                // Take the current assignment out of the counts so that it does not
                // influence the distribution it is about to be resampled from.
                Topic oldTopic = this.topics.get(document.getTopicID(position));
                document.decrementTopicCount(oldTopic.id());
                oldTopic.decrementVocabCount(vocabulary.id());

                int newTopicID = getFullConditionalDistribution(
                        this.lda.getNumTopics(), document.id(), vocabulary.id()).sample();

                // Record the new assignment in both count structures.
                document.setTopicID(position, newTopicID);
                document.incrementTopicCount(newTopicID);
                this.topics.get(newTopicID).incrementVocabCount(vocabulary.id());
            }
        }
    }

    /**
     * Builds the distribution over topic IDs {@code 0 .. numTopics - 1} for one word
     * position, weighting each topic by {@code getTheta(docID, topic) * getPhi(topic, vocabID)}.
     *
     * <p>NOTE(review): the distribution is constructed without an explicit random
     * generator, so draws are presumably not governed by the seed passed to
     * {@code setUp} (which is only forwarded to the initial topic assignment) - confirm
     * against the Commons Math documentation if reproducible runs are required.
     *
     * @param numTopics number of topics to enumerate
     * @param docID ID of the document the word belongs to
     * @param vocabID ID of the word's vocabulary entry
     * @return the distribution to sample the new topic ID from
     */
    IntegerDistribution getFullConditionalDistribution(int numTopics, int docID, int vocabID) {
        int[] topicIDs = IntStream.range(0, numTopics).toArray();
        double[] weights = Arrays.stream(topicIDs)
                .mapToDouble(topicID -> getTheta(docID, topicID) * getPhi(topicID, vocabID))
                .toArray();
        return new EnumeratedIntegerDistribution(topicIDs, weights);
    }

    /**
     * Delegates the initial topic assignment of all documents to {@link Documents}.
     *
     * @param seed seed forwarded to the initial assignment
     */
    void initializeTopicAssignment(long seed) {
        this.documents.initializeTopicAssignment(this.topics, seed);
    }

    /**
     * Returns the document/topic count for the given pair.
     *
     * @param docID document ID, accepted range {@code 1 .. numDocs} inclusive
     * @param topicID topic ID, accepted range {@code 0 .. numTopics - 1}
     * @return the count reported by {@link Documents} for the pair
     * @throws IllegalStateException if the sampler has not been set up
     * @throws IllegalArgumentException if either ID is outside its accepted range
     */
    int getDTCount(int docID, int topicID) {
        requireReady();
        if (docID <= 0 || this.lda.getBow().getNumDocs() < docID) {
            throw new IllegalArgumentException("docID out of range: " + docID);
        }
        requireValidTopicID(topicID);
        return this.documents.getTopicCount(docID, topicID);
    }

    /**
     * Returns the topic/vocabulary count for the given pair.
     *
     * <p>Only the lower bound of {@code vocabID} is checked here.
     *
     * @param topicID topic ID, accepted range {@code 0 .. numTopics - 1}
     * @param vocabID vocabulary ID, must be positive
     * @return the count reported by the topic for the vocabulary entry
     * @throws IllegalStateException if the sampler has not been set up
     * @throws IllegalArgumentException if either ID fails the checks above
     */
    int getTVCount(int topicID, int vocabID) {
        requireReady();
        requireValidTopicID(topicID);
        if (vocabID <= 0) {
            throw new IllegalArgumentException("vocabID out of range: " + vocabID);
        }
        return this.topics.get(topicID).getVocabCount(vocabID);
    }

    /**
     * Returns the sum count of the given topic.
     *
     * @param topicID topic ID, accepted range {@code 0 .. numTopics - 1}
     * @return the sum count reported by the topic
     * @throws IllegalStateException if the sampler has not been set up
     * @throws IllegalArgumentException if {@code topicID} is outside its accepted range
     */
    int getTSumCount(int topicID) {
        // Checked first so that an un-set-up sampler fails the same way as the other
        // count accessors instead of with a NullPointerException on the model.
        requireReady();
        requireValidTopicID(topicID);
        return this.topics.get(topicID).getSumCount();
    }

    /**
     * Returns theta for the given document and topic, computed by {@link Documents}
     * from the model's alpha for that topic and the model's alpha sum.
     *
     * @param docID document ID
     * @param topicID topic ID
     * @return the theta value
     * @throws IllegalStateException if the sampler has not been set up
     */
    @Override
    public double getTheta(int docID, int topicID) {
        requireReady();
        return this.documents.getTheta(docID, topicID, this.lda.getAlpha(topicID), this.lda.getSumAlpha());
    }

    /**
     * Returns phi for the given topic and vocabulary entry, computed by {@link Topics}
     * from the model's beta.
     *
     * @param topicID topic ID
     * @param vocabID vocabulary ID
     * @return the phi value
     * @throws IllegalStateException if the sampler has not been set up
     */
    @Override
    public double getPhi(int topicID, int vocabID) {
        requireReady();
        return this.topics.getPhi(topicID, vocabID, this.lda.getBeta());
    }

    /**
     * Returns the vocabulary/phi pairs produced by {@link Topics} for the given ID,
     * using the model's vocabularies and beta.
     *
     * @param topicID the ID forwarded to {@link Topics} (presumably a topic ID - confirm
     *        against {@code Topics.getVocabsSortedByPhi})
     * @return the pairs as returned by {@link Topics}
     * @throws IllegalStateException if the sampler has not been set up
     */
    @Override
    public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
        requireReady();
        return this.topics.getVocabsSortedByPhi(topicID, this.lda.getVocabularies(), this.lda.getBeta());
    }

    /** Fails with {@link IllegalStateException} unless a {@code setUp} call has completed. */
    private void requireReady() {
        if (!this.ready) {
            throw new IllegalStateException(NOT_READY_MESSAGE);
        }
    }

    /** Fails with {@link IllegalArgumentException} unless {@code 0 <= topicID < numTopics}. */
    private void requireValidTopicID(int topicID) {
        if (topicID < 0 || this.lda.getNumTopics() <= topicID) {
            throw new IllegalArgumentException("topicID out of range: " + topicID);
        }
    }
}
