
I am trying to understand the top_p parameter in LangChain (nucleus sampling), but I can't seem to grasp it. As I understand it, we sort the probabilities and select a subset whose total reaches (or exceeds) p while having the fewest members possible.
For example, given:

t1 = 0.05
t2 = 0.5
t3 = 0.3
t4 = 0.15

and top_p=0.75, we would select t2 and t3, right?

If this is the case, what happens if top_p=0.001?
We need just one token, and any one of these is enough.
Do we select the biggest one (t2)? Based on my experience this makes sense: I tested top_p=0.001 on an LLM and the output was coherent. Since only one token is selected, if it were a random token with probability > 0.001, the output should be garbage.

Answer


If top_p=0.75 we would select t2 and t3, right? → YES.

If top_p=0.001? → We would select only t2.

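This is the original definition, from the paper that introduced nucleus sampling (Holtzman et al., "The Curious Case of Neural Text Degeneration"):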

The key idea is to use the shape of the probability distribution to determine the set of tokens to be sampled from. Given a distribution $P(x | x_{1:i-1})$, we define its top-$p$ vocabulary $V^{(p)} \subset V$ as the smallest set such that \begin{equation} \sum_{x \in V^{(p)}} P(x | x_{1:i-1}) \geq p. \end{equation}
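As a worked check of the definition, using only the four probabilities from the question and writing $P(t_j)$ as shorthand for $P(t_j \mid x_{1:i-1})$: no single token reaches p = 0.75, but the two most probable ones do.

$$\max_{x \in V} P(x) = P(t_2) = 0.5 < 0.75 \quad\Rightarrow\quad |V^{(0.75)}| \geq 2$$

$$P(t_2) + P(t_3) = 0.5 + 0.3 = 0.8 \geq 0.75$$

Every other pair sums to at most $P(t_2) + P(t_4) = 0.65$, so $V^{(0.75)} = \{t_2, t_3\}$. For $p = 0.001$, however, every single token satisfies the inequality (even the least likely one, $P(t_1) = 0.05$), so the definition by itself does not say which one-token set to pick.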

I would add something along the lines of "among all such smallest sets, the one whose sum is maximal is chosen", because for a small p several sets of the same size satisfy the inequality.
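This tie-breaking rule is what makes the sorted computation correct: the k most probable tokens have the largest total of any k-token set, so both the smallest qualifying size and the maximal-sum set are found by taking a prefix of the sorted list (up to ties between equal probabilities). The sketch below checks this by brute force in plain Python on the toy distribution from the question; `nucleus_brute_force` and `nucleus_sorted_prefix` are hypothetical names for illustration, not LangChain or any library API.

```python
from itertools import combinations

# Toy distribution from the question (token -> probability)
PROBS = {"t1": 0.05, "t2": 0.5, "t3": 0.3, "t4": 0.15}


def total(probs, tokens):
    return sum(probs[t] for t in tokens)


def nucleus_brute_force(probs, p):
    """Hypothetical helper: smallest set with total >= p, ties broken by the largest total."""
    for size in range(1, len(probs) + 1):
        qualifying = [c for c in combinations(probs, size) if total(probs, c) >= p]
        if qualifying:
            return set(max(qualifying, key=lambda c: total(probs, c)))
    return set(probs)  # reached only if p exceeds the total probability mass


def nucleus_sorted_prefix(probs, p):
    """Hypothetical helper: take tokens in descending order until the total reaches p."""
    chosen, running = set(), 0.0
    for token, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        chosen.add(token)
        running += prob
        if running >= p:
            break
    return chosen


for p in (0.001, 0.5, 0.75, 0.9, 0.99):
    print(p, sorted(nucleus_brute_force(PROBS, p)), sorted(nucleus_sorted_prefix(PROBS, p)))
```

Both columns agree for every p tried, and for p = 0.001 both return only t2.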

In practical terms, we can take a look at the original implementation:

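```python
import torch
# samp_probs: (batch, vocab) next-token probabilities; p: the top_p threshold
sorted_probs, sorted_indices = torch.sort(samp_probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# Flag tokens whose cumulative probability exceeds p ...
sorted_indices_to_remove = cumulative_probs > p
# ... then shift the flags right by one, so the token that crosses p is kept
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
# The most probable token is never removed (hence top_p=0.001 keeps only t2)
sorted_indices_to_remove[:, 0] = 0
# Zero out removed tokens; multinomial below accepts unnormalized weights
sorted_samp_probs = sorted_probs.clone()
sorted_samp_probs[sorted_indices_to_remove] = 0

...

sorted_next_indices = sorted_samp_probs.multinomial(1).view(-1, 1)
next_tokens = sorted_indices.gather(1, sorted_next_indices)
next_logprobs = sorted_samp_probs.gather(1, sorted_next_indices).log()
```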

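There, we can see that they first sort the probabilities and then compute the cumulative distribution to find the cut-off point: every token after the one whose cumulative probability first exceeds p is removed, and the next token is sampled from the tokens that remain.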
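To see the cut-off point on concrete numbers without installing PyTorch, here is a plain-Python sketch of the same steps for a single distribution. It is a port for illustration, not the original code: `top_p_sample` is a hypothetical name, and `random.choices` stands in for `multinomial`.

```python
import itertools
import random

# t1..t4 from the question, as a list indexed 0..3
PROBS = [0.05, 0.5, 0.3, 0.15]


def top_p_sample(probs, p, rng=random):
    """Hypothetical plain-Python port of the sort / cumsum / shift steps for one distribution."""
    # Sort token indices by probability, most probable first
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    cumulative = list(itertools.accumulate(probs[i] for i in order))
    # Drop a token if the cumulative mass *before* it already exceeds p;
    # this equals `cumulative_probs > p` shifted right by one, so index 0 is always kept
    remove = [False] + [c > p for c in cumulative[:-1]]
    kept = [i for i, r in zip(order, remove) if not r]
    # random.choices normalizes the weights itself, as multinomial does
    token = rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
    return kept, token


kept, token = top_p_sample(PROBS, 0.75)
print(kept)  # [1, 2], i.e. t2 and t3
kept, token = top_p_sample(PROBS, 0.001)
print(kept, token)  # [1] 1, i.e. always t2
```

One detail the port makes visible: because the mask uses a strict `>`, a token is also kept when the mass before it equals p exactly. With p = 0.5 the sketch keeps t2 and t3, although the definition's smallest set is {t2} alone; the excerpt above behaves the same way. With real model probabilities such exact ties are uncommon.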

noe