torch.multinomial
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)) → LongTensor
-
返回一个张量,其中每一行包含从对应的输入张量
input
的行中的多项式分布(更严格的定义是多元分布,参见torch.distributions.multinomial.Multinomial
获取更多详情)中抽取的num_samples
个索引。注意
input
的各行不必总和为一(在这种情况下,我们会将这些值作为权重使用),但是它们必须是非负数、有限值,并且总和不能为零。索引按样本时间顺序从左到右排列(第一个样本放置在第一列)。
如果
input
是一个向量,那么out
就是一个大小为num_samples
的向量。如果
input
是一个具有 m 行的矩阵,那么out
将是一个形状为 $(m \times \text{num\_samples})$ 的矩阵。如果 replacement 为
True
,则抽样时允许重复。如果不是,这些样本将不会被放回。也就是说,一旦为某一行抽取出一个样本索引,这个索引就再也不能用于同一行。
注意
当不放回抽取时,
num_samples
必须小于input
中非零元素的数量(如果是矩阵,则必须小于每行中非零元素的最小数量)。- 参数
- 关键字参数
-
-
generator (
torch.Generator
, 可选) – 用于样本采集的伪随机数生成器 -
out (Tensor, 可选) – 指定输出张量。
-
示例:
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])