余弦相似度

classtorch.nn.CosineSimilarity(dim=1, eps=1e-08)[源代码]

沿 计算并返回 $x_1$$x_2$之间的余弦相似度。

$\text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.$
参数
  • dim (int, 可选) – 用于计算余弦相似度的维度。默认值为1。

  • eps (float, 可选) – 一个很小的数值,用于防止除以零的情况。默认值:1e-8

形状:
  • 输入1: $(\ast_1, D, \ast_2)$,其中D位于dim位置

  • 输入2: $(\ast_1, D, \ast_2)$,与x1具有相同数量的维度,并在维度dim上与x1大小匹配。

    并且可以在其他维度上与 x1 进行广播。

  • 结果: $(\ast_1, \ast_2)$

示例:
>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
>>> output = cos(input1, input2)
本页目录