I have a simple decoder-only transformer model trained on a dataset of sentences for next-word prediction. The model captures a probability distribution $P_{\theta}(\mathbf{a})$ over a sentence $\mathbf{a}$ via the chain rule of probability:
$$P_\theta(\mathbf{a}) = P_\theta(a_1,\ a_2,\ \dots ,\ a_N) = \prod_{k=1}^{N}P_\theta(a_{k}|a_{<k}),$$
where $a_i$ is the $i$-th word and $a_{<k}$ denotes the sentence up to, but not including, the $k$-th word. The conditional distributions $P_\theta(a_{k}|a_{<k})$ are obtained auto-regressively, i.e. by feeding $a_{<k}$ into the model for each $k$ up to the sentence length $N$ and reading off the output (conditional) probability distribution over the next word in the sentence.
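For reference, here is a minimal sketch of how I evaluate $\log P_\theta(\mathbf{a})$ for a given sentence. It assumes a PyTorch-style decoder where `model(input_ids)` returns next-token logits of shape `(batch, seq_len, vocab_size)`; the names `model`, `token_ids` and the helper itself are just placeholders for my setup:

```python
import torch
import torch.nn.functional as F

def sentence_log_prob(model, token_ids):
    """Sum of log P_theta(a_k | a_<k) for k = 2..N via the chain rule.

    Assumes `model(input_ids)` returns next-token logits of shape
    (batch, seq_len, vocab_size). P(a_1) would need a BOS token (or a
    separate prior) and is omitted here for simplicity.
    """
    input_ids = token_ids[:-1].unsqueeze(0)   # a_1 ... a_{N-1}, shape (1, N-1)
    targets = token_ids[1:].unsqueeze(0)      # a_2 ... a_N,     shape (1, N-1)
    logits = model(input_ids)                 # (1, N-1, vocab_size)
    log_probs = F.log_softmax(logits, dim=-1)
    # pick out log P(a_k | a_<k) at each position
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum()
```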
Goal:
Extract the probability distribution of a set $\{M\}$ of words (a sub-sentence) within a sentence of length $N$, given the rest of the sentence as context: $P(a_{\{M\}} | a_{\{N\}/\{M\}})$, where I denote "the rest of the sentence" by $a_{\{N\}/\{M\}}$.
Now, for sets $\{M\}$ that are "not interrupted" (containing all words up to a certain point in the sentence), this reduces to evaluating the chain rule up to that word, e.g. $P(a_1, a_2) = P(a_2 | a_1)\cdot P(a_1)$. However, for sets $\{M\}$ that are interrupted this is not the case, e.g. $P(a_1, a_3|a_2)$ for $N = 3$.
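To make the interrupted case concrete: for $N = 3$ and $\{M\} = \{1, 3\}$, what I am after is
$$P(a_1, a_3 \mid a_2) = \frac{P(a_1, a_2, a_3)}{\sum_{a_1',\, a_3'} P(a_1', a_2, a_3')} = \frac{P_\theta(a_1)\,P_\theta(a_2\mid a_1)\,P_\theta(a_3\mid a_1, a_2)}{\sum_{a_1',\, a_3'} P_\theta(a_1')\,P_\theta(a_2\mid a_1')\,P_\theta(a_3'\mid a_1', a_2)},$$
which no longer collapses to a prefix of the chain-rule product, because the conditioning word $a_2$ sits between the words of interest.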
Since the vocabulary is large, it is not feasible to enumerate all possible sentences $\mathbf{a}$, as the cost is exponential in $N$. For small systems, however, this is possible, and that is the regime I am interested in. The goal is to obtain these probability distributions for "sub-sentences" and to constrain them (in a way given by the context of the problem) by adding a term to the loss function.
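For such a small system, here is a brute-force sketch of what I mean. It reuses the hypothetical `sentence_log_prob` helper from above and assumes the vocabulary and $\{M\}$ are small enough to enumerate:

```python
import itertools
import torch

def exact_conditional(model, sentence, m_positions, vocab_size):
    """Exact P(a_M | a_rest) by enumerating every assignment of the positions in {M}.

    `sentence` is a 1D LongTensor of token ids, `m_positions` the indices in {M}.
    Cost is vocab_size ** len(m_positions), so only feasible for toy problems.
    """
    assignments, log_joints = [], []
    for assignment in itertools.product(range(vocab_size), repeat=len(m_positions)):
        candidate = sentence.clone()
        for pos, tok in zip(m_positions, assignment):
            candidate[pos] = tok
        assignments.append(assignment)
        log_joints.append(sentence_log_prob(model, candidate))
    # P(a_M | a_rest) = P(a) / sum over all assignments of a_M with the rest fixed
    probs = torch.softmax(torch.stack(log_joints), dim=0)
    return dict(zip(assignments, probs))
```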
My attempt:
One can approximate the distribution using Monte Carlo: draw a large enough set of sentences from the model, keep those whose words outside $\{M\}$ match the given context, and count the occurrences of each sub-sentence at the positions in $\{M\}$. That way one indeed obtains an estimate of $P(a_{\{M\}} | a_{\{N\}/\{M\}})$.
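A rough sketch of this counting estimate (the `sample_sentence` helper for ancestral sampling is hypothetical, and the rejection step on the context makes this very wasteful):

```python
from collections import Counter

def mc_conditional(model, sentence, m_positions, num_samples):
    """Monte Carlo estimate of P(a_M | a_rest): sample whole sentences from the
    model, keep those whose context positions match the given sentence, and
    count the sub-sentences appearing at the positions in {M}."""
    rest_positions = [i for i in range(len(sentence)) if i not in m_positions]
    counts, kept = Counter(), 0
    for _ in range(num_samples):
        # hypothetical helper: ancestral sampling of a sentence of fixed length
        sample = sample_sentence(model, length=len(sentence))
        if all(sample[i] == sentence[i].item() for i in rest_positions):
            counts[tuple(sample[i] for i in m_positions)] += 1
            kept += 1
    return {sub: c / kept for sub, c in counts.items()} if kept else {}
```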
The problem with this approach is that it is not differentiable: sampling from the model is inherently discrete, which prevents gradients from flowing through the model for back-propagation. I am aware of differentiable relaxations like the Gumbel-Softmax, but it feels like there should be a simpler solution that works directly with the conditional probability distributions.
Update:
We know that: $$P(a_{\{M\}}|a_{\{N\}/\{M\}}) = \frac{P(a_1,\dots,a_N)}{\sum_{a_{\{M\}}} P(a_1,\dots,a_N)} = \frac{\prod_{k=1}^{N}P_\theta(a_{k}|a_{<k})}{\sum_{a_{\{M\}}}\prod_{k=1}^{N}P_\theta(a_{k}|a_{<k})},$$ where the sum in the denominator runs over all possible values of the sub-sentence $a_{\{M\}}$ with the context $a_{\{N\}/\{M\}}$ held fixed, and every factor is a conditional distribution output by the model (normalized by design). Since we cannot evaluate the denominator over all assignments of $a_{\{M\}}$, we only score a sparse set of sampled candidates, so the distribution we obtain is not guaranteed to be normalized. After normalizing over the scored candidates, this yields an estimate of the subsystem probability distribution.
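In code, this amounts to re-scoring a set of candidate sub-sentences against the fixed context and renormalizing; since everything is a softmax over chain-rule log-probabilities, it stays differentiable. A sketch under the same assumptions as above (reusing the hypothetical `sentence_log_prob` helper):

```python
import torch

def estimated_conditional(model, sentence, m_positions, candidate_subsentences):
    """Estimate P(a_M | a_rest): score each candidate assignment of the positions
    in {M} with the full chain-rule product (context held fixed) and renormalize
    over the candidates actually scored. Differentiable w.r.t. model parameters."""
    log_joints = []
    for cand in candidate_subsentences:   # each cand: tuple of token ids, one per position in {M}
        filled = sentence.clone()
        for pos, tok in zip(m_positions, cand):
            filled[pos] = tok
        log_joints.append(sentence_log_prob(model, filled))
    # softmax renormalizes over the scored candidates only; the estimate becomes
    # exact if the candidate set is exhaustive
    return torch.softmax(torch.stack(log_joints), dim=0)
```

The returned tensor can then be used directly in an extra loss term, since gradients flow back through the chain-rule log-probabilities into the model.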