torch.multinomial
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
Returns a tensor where each row contains num_samples indices sampled from the multinomial (a stricter definition would be multivariate; refer to torch.distributions.multinomial.Multinomial for more details) probability distribution located in the corresponding row of tensor input.
Note
The rows of input do not need to sum to one (in which case the values are used as weights), but they must be non-negative, finite, and have a non-zero sum.
Indices are ordered from left to right according to when each was sampled (the first samples are placed in the first column).
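Treating each row as weights amounts to dividing the row by its own sum. The sketch below models only that documented requirement in pure Python; `normalize_row` is a hypothetical helper, not part of PyTorch, and it assumes a row is a plain sequence of floats.

```python
import math
from typing import List, Sequence


def normalize_row(weights: Sequence[float]) -> List[float]:
    """Hypothetical helper: turn one row of unnormalized weights into probabilities.

    Enforces the documented requirements on each row of `input`:
    finite, non-negative, and a non-zero sum.
    """
    # Check finiteness first so NaN never reaches the comparisons below.
    if any(not math.isfinite(w) for w in weights):
        raise ValueError("weights must be finite")
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must have a non-zero sum")
    # Dividing by the row sum is why rows need not sum to one.
    return [w / total for w in weights]
```

For the weights used in the example, `normalize_row([0, 10, 3, 0])` yields probabilities of 0, 10/13, 3/13 and 0, so indices 0 and 3 can never be drawn.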
If input is a vector, out is a vector of size num_samples.
If input is a matrix with m rows, out is a matrix of shape (m × num_samples).
If replacement is True, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
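The shape and replacement rules above can be modeled in a few lines of pure Python. This is a sketch of the documented semantics under stated assumptions, not PyTorch's sampling kernel: `multinomial_rows` is a hypothetical name, the input is assumed to be a list of weight rows (an m-row matrix), and `random.Random` stands in for the `generator` keyword.

```python
import random
from typing import List, Optional, Sequence


def multinomial_rows(
    rows: Sequence[Sequence[float]],
    num_samples: int,
    replacement: bool = False,
    rng: Optional[random.Random] = None,
) -> List[List[int]]:
    """Hypothetical pure-Python model of the documented behavior.

    Returns one list of `num_samples` indices per row, i.e. an
    m x num_samples result for m input rows.
    """
    rng = rng or random.Random()
    out: List[List[int]] = []
    for row in rows:
        if replacement:
            # Every draw uses the full weight vector, so indices may repeat.
            picks = rng.choices(range(len(row)), weights=row, k=num_samples)
        else:
            remaining = list(row)
            picks = []
            for _ in range(num_samples):
                # Draw one index, then zero its weight so it cannot be drawn again.
                idx = rng.choices(range(len(remaining)), weights=remaining, k=1)[0]
                picks.append(idx)  # earlier draws land in earlier columns
                remaining[idx] = 0.0
        out.append(picks)
    return out
```

For example, `multinomial_rows([[0, 10, 3, 0], [1, 1, 1, 1]], 2, rng=random.Random(0))` returns two rows of two indices each; the first row is always some ordering of 1 and 2, because those are its only non-zero weights. Zeroing a weight after each draw is one simple way to express "cannot be drawn again"; PyTorch may use a different algorithm internally.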
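Note
When drawn without replacement, num_samples must not exceed the number of non-zero elements in input (or the minimum number of non-zero elements across the rows of input if it is a matrix).
- Parameters
  - input (Tensor) – the input tensor containing probabilities
  - num_samples (int) – number of samples to draw
  - replacement (bool, optional) – whether to draw with replacement or not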
- Keyword Arguments
  - generator (torch.Generator, optional) – a pseudorandom number generator for sampling
  - out (Tensor, optional) – the output tensor.
Example:
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
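The error in the example comes from asking for more samples than the row has entries. The note on sampling without replacement states a stricter condition based on the count of non-zero weights. A minimal pre-check for that condition might look like the following; `check_without_replacement` is a hypothetical helper, it assumes rows are plain lists of weights (wrap a single vector as `[weights]`), and its message is illustrative, not PyTorch's.

```python
from typing import Sequence


def check_without_replacement(rows: Sequence[Sequence[float]], num_samples: int) -> None:
    """Hypothetical pre-check for sampling without replacement.

    Raises ValueError if any row has fewer non-zero weights than num_samples.
    """
    # Only indices with a non-zero weight can ever be drawn.
    min_nonzero = min(sum(1 for w in row if w != 0) for row in rows)
    if num_samples > min_nonzero:
        raise ValueError(
            f"cannot draw {num_samples} samples without replacement: "
            f"some row has only {min_nonzero} non-zero weights"
        )
```

With the example's weights, `check_without_replacement([[0, 10, 3, 0]], 2)` passes, while asking for 3 or 5 samples raises, because only two entries are non-zero.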