torch.multinomial

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None))LongTensor

返回一个张量,其中每一行包含从对应的输入张量input的行中的多项式分布(更严格的定义是多元分布,参见torch.distributions.multinomial.Multinomial 获取更多详情)中抽取的num_samples个索引。

注意

input 的各行不必总和为一(在这种情况下,我们会将这些值作为权重使用),但是它们必须是非负数、有限值,并且总和不能为零。

索引按样本时间顺序从左到右排列(第一个样本放置在第一列)。

如果 input 是一个向量,那么 out 就是一个大小为 num_samples 的向量。

如果 input 是一个具有 m 行的矩阵,那么 out 将是一个形状为 $(m \times \text{num\_samples})$ 的矩阵。

如果 replacement 为 True,则抽样时允许重复。

如果不是,这些样本将不会被放回。也就是说,一旦为某一行抽取出一个样本索引,这个索引就再也不能用于同一行。

注意

当不放回抽取时,num_samples 必须小于 input 中非零元素的数量(如果是矩阵,则必须小于每行中非零元素的最小数量)。

参数
  • input (Tensor) – 包含概率值的输入张量

  • num_samples (int) – 抽取的样本数目

  • replacement (bool, 可选) – 是否进行有放回的抽取

关键字参数
  • generator (torch.Generator, 可选) – 用于样本采集的伪随机数生成器

  • out (Tensor, 可选) – 指定输出张量。

示例:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])
本页目录