PyTorch张量索引简明指南

作为数据科学家或软件工程师,你可能经常处理大型数据集和复杂的数学运算,这些运算需要高效且可扩展的计算。PyTorch 是一个流行的开源机器学习库,它通过 GPU 加速提供快速灵活的张量计算。在本文中,我们将深入研究 PyTorch 张量索引,这是一种强大的技术,可让你轻松选择和操作张量的特定元素或子集。

1、什么是 PyTorch 张量索引?

张量索引是根据张量的位置或值选择特定元素或子集的过程。PyTorch 张量索引提供了一组丰富的索引操作,使你可以使用不同的索引方案(例如整数索引、布尔索引和高级索引)选择和修改张量元素。

在 PyTorch 中,张量是一个多维数组,可以存储不同类型和大小的数值数据。可以使用一个或多个索引对张量进行索引,这些索引指定元素沿张量每个维度的位置。索引张量将返回一个包含所选元素的新张量或具有修改元素的原始张量的视图。

1、整数索引

整数索引是张量索引的最基本形式,它允许你使用张量沿每个维度的整数位置来选择特定元素。你可以使用整数索引来选择单个元素或子张量,方法是提供与所需元素的索引相对应的整数列表。

import torch

# create a 2D tensor of size (3, 4)
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# select the element at position (1, 2)
print(x[1, 2])   # output: tensor(7)

# select the sub-tensor of size (2, 3) starting at position (0, 1)
print(x[0:2, 1:4])   # output: tensor([[2, 3, 4], [6, 7, 8]])

2、布尔索引

布尔索引允许你根据布尔条件选择张量的特定元素。你可以使用布尔索引来选择满足特定条件的元素或屏蔽不满足条件的元素。

# create a 1D tensor of size 5
x = torch.tensor([1, 2, 3, 4, 5])

# select the elements that are greater than 3
print(x[x > 3])   # output: tensor([4, 5])

# mask out the elements that are not even
x[x % 2 != 0] = 0
print(x)   # output: tensor([0, 2, 0, 4, 0])

3、高级索引

高级索引允许你使用整数索引、布尔索引和其他高级索引方案的组合来选择张量的特定元素。你可以使用高级索引从张量的多个维度中选择元素,或创建具有不同形状的新张量。

import torch

# create a 3D tensor of size (2, 3, 4)
x = torch.tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]],
                  
                  [[13, 14, 15, 16],
                   [17, 18, 19, 20],
                   [21, 22, 23, 24]]])

# select the elements at positions (0, 1, 2) and (1, 2, 3)
print(x[:, (0, 1, 2), (1, 2, 3)])   # output: tensor([[ 2,  7, 12],
                                    #          [14, 19, 24]])

# create a new tensor by selecting the elements that are greater than 10
y = x[x > 10]
print(y)   # output: tensor([11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])

4、修改张量元素

PyTorch 张量索引不仅允许你选择张量元素,还允许你就地修改它们或使用修改后的元素创建新的张量。你可以使用索引为张量元素分配新值或对选定元素应用数学运算。

# create a 2D tensor of size (3, 3)
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# assign a new value to the element at position (1, 2)
x[1, 2] = 10
print(x)   # output: tensor([[ 1,  2,  3], [ 4,  5, 10], [ 7,  8,  9]])

# multiply the elements that are greater than 5 by 2
x[x > 5] *= 2
print(x)   # output: tensor([[ 1,  2,  3], [ 4,  5, 20], [14, 16, 18]])

5、结束语

PyTorch 张量索引是一种强大而灵活的技术,它使你能够使用不同的索引方案选择和修改张量的特定元素或子集。无论你处理的是大型数据集还是复杂的数学运算,PyTorch 张量索引都提供了一种简单而有效的方法来操作张量数据。

在本文中,我们介绍了 PyTorch 张量索引的基础知识,包括整数索引、布尔索引、高级索引和修改张量元素。通过掌握这些索引技术,你可以为数据科学和机器学习项目释放 PyTorch 的全部潜力。


原文链接:PyTorch Tensor Indexing: A Guide

BimAnt翻译整理,转载请标明出处