torch.frombuffer

torch.frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) Tensor

从实现 Python 缓冲协议的对象创建一维 Tensor

跳过缓冲区中前 offset 字节,并将剩余的原始字节解释为类型为dtype的一维张量,包含count个元素。

请注意,以下任一条件必须成立:

1. count 是一个正的非零数字,并且缓冲区中的总字节数大于 offset 加上 count 乘以每个 dtype 的大小(以字节为单位)。

2. 当 count 为负时,缓冲区的长度(以字节为单位)减去 offsetdtype 大小(以字节为单位)的倍数。

返回的张量和缓冲区使用相同的内存。对张量进行的任何更改都会在缓冲区中显示,反之亦然。此外,返回的张量不能被重新调整大小。

注意

此函数会增加拥有共享内存对象的引用计数。因此,在返回的张量超出作用域之前,这些内存不会被释放。

警告

如果传递给该函数的对象实现了缓冲协议但其数据不在CPU上,则该函数的行为将无法预测,并可能导致段错误。

警告

此函数不会尝试推断dtype(因此它是不可选的)。传递与其源不同的dtype可能会导致意外行为。

参数

buffer (对象) – 一个实现了缓冲区接口的 Python 对象。

关键字参数
  • dtype (torch.dtype) - 返回的张量所需的數據類型。

  • count (int, 可选) – 指定要读取的元素数量。如果设置为负数,则会一直读取到缓冲区结束为止。默认值:-1。

  • offset (int, 可选) – 在缓冲区开始处跳过的字节数。默认值为 0。

  • requires_grad (bool, optional) – 是否启用自动求导记录返回的张量上的操作。默认值:False

示例:

>>> import array
>>> a = array.array('i', [1, 2, 3])
>>> t = torch.frombuffer(a, dtype=torch.int32)
>>> t
tensor([ 1,  2,  3])
>>> t[0] = -1
>>> a
array([-1,  2,  3])

>>> # Interprets the signed char bytes as 32-bit integers.
>>> # Each 4 signed char elements will be interpreted as
>>> # 1 signed 32-bit integer.
>>> import array
>>> a = array.array('b', [-1, 0, 0, 0])
>>> torch.frombuffer(a, dtype=torch.int32)
tensor([255], dtype=torch.int32)
本页目录