torch.set_default_tensor_type

torch.set_default_tensor_type(t, /)[源代码]

警告

此函数在 PyTorch 2.1 中已弃用,请使用torch.set_default_dtype()torch.set_default_device() 作为替代。

将默认的 torch.Tensor 类型设置为浮点张量类型 t。此类型也将作为在 torch.tensor() 中进行类型推断时的默认浮点类型。

默认的浮点张量类型初始为 torch.FloatTensor

参数

t (类型字符串) - 表示浮点张量的类型或其名称

示例:

>>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
torch.float32
>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
torch.float64
本页目录