本篇介绍Pytorch 的索引与切片。
索引
| 1 | In[3]: a = torch.rand(4,3,28,28) | 
切片
- 顾头不顾尾
| 1 | In[7]: a.shape | 
步长
- 顾头不顾尾 + 步长
- start : end : step
- 对于步长为1的,通常就省略了。
| 1 | a[:,:,0:28,0:28:2].shape # 隔点采样 | 
具体的索引
- .index_select(dim, indices)- dim为维度,indices是索引序号
- 这里的indeces必须是tensor ,不能直接是一个list
 
| 1 | In[17]: a.shape | 
...
- ...表示任意多维度,根据实际的shape来推断。
- 当有 ...出现时,右边的索引理解为最右边
- 为什么会有它,没有它的话,存在这样一种情况 a[0,: ,: ,: ,: ,: ,: ,: ,: ,: ,2] 只对最后一个维度做了限度,这个向量的维度又很高,以前的方式就不太方便了。
| 1 | In[23]: a.shape | 
使用mask来索引
- .masked_select()
- 求掩码位置原来的元素大小
- 缺点:会把数据,默认打平(flatten),
| 1 | In[31]: x = torch.randn(3,4) | 
使用打平(flatten)后的序列
- torch.take(src, torch.tensor([index]))
- 打平后,按照index来取对应位置的元素
| 1 | In[39]: src = torch.tensor([[4,3,5],[6,7,8]]) # 先打平成1维的,共6列 | 

