torch.frombuffer
- torch.frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) → Tensor
-
从实现 Python 缓冲协议的对象创建一维
Tensor
。跳过缓冲区中前
offset
字节,并将剩余的原始字节解释为类型为dtype
的一维张量,包含count
个元素。请注意,以下任一条件必须成立:
1.
count
是一个正的非零数字,并且缓冲区中的总字节数大于offset
加上count
乘以每个dtype
的大小(以字节为单位)。2. 当
count
为负时,缓冲区的长度(以字节为单位)减去offset
是dtype
大小(以字节为单位)的倍数。返回的张量和缓冲区使用相同的内存。对张量进行的任何更改都会在缓冲区中显示,反之亦然。此外,返回的张量不能被重新调整大小。
注意
此函数会增加拥有共享内存对象的引用计数。因此,在返回的张量超出作用域之前,这些内存不会被释放。
警告
如果传递给该函数的对象实现了缓冲协议但其数据不在CPU上,则该函数的行为将无法预测,并可能导致段错误。
- 参数
-
buffer (对象) – 一个实现了缓冲区接口的 Python 对象。
- 关键字参数
-
-
dtype (
torch.dtype
) - 返回的张量所需的數據類型。 -
count (int, 可选) – 指定要读取的元素数量。如果设置为负数,则会一直读取到缓冲区结束为止。默认值:-1。
-
offset (int, 可选) – 在缓冲区开始处跳过的字节数。默认值为 0。
-
requires_grad (bool, optional) – 是否启用自动求导记录返回的张量上的操作。默认值:
False
。
-
示例:
>>> import array >>> a = array.array('i', [1, 2, 3]) >>> t = torch.frombuffer(a, dtype=torch.int32) >>> t tensor([ 1, 2, 3]) >>> t[0] = -1 >>> a array([-1, 2, 3]) >>> # Interprets the signed char bytes as 32-bit integers. >>> # Each 4 signed char elements will be interpreted as >>> # 1 signed 32-bit integer. >>> import array >>> a = array.array('b', [-1, 0, 0, 0]) >>> torch.frombuffer(a, dtype=torch.int32) tensor([255], dtype=torch.int32)