torch.set_default_dtype

torch.set_default_dtype(d, /)[源代码]

将默认的浮点数数据类型设置为 d。支持浮点数类型的输入。其他数据类型会导致 torch 抛出异常。

在初始化 PyTorch 时,默认的浮点数数据类型是 torch.float32。使用 set_default_dtype(torch.float64) 可以实现类似于 NumPy 的类型推断功能。默认的浮点数数据类型用于:

  1. 隐式确定默认的复数数据类型。当默认浮点类型为 float16 时,默认的复数数据类型是 complex32;对于 float32,默认的复数数据类型是 complex64;对于 float64,默认的复数数据类型是 complex128。而对于 bfloat16,则会抛出异常,因为没有与之对应的复数类型。

  2. 根据使用的Python浮点数或复数,推断张量的数据类型。参见下面的例子。

  3. 确定布尔值与整数张量、Python浮点数与复数之间的类型提升规则。

参数

d (torch.dtype) – 设置为默认值的浮点数据类型。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> # Python floats are now interpreted as float16
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float16
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex32
本页目录