本篇介绍Pytorch的两个高阶操作: where与gather。
where
使用

图片来自 Pytorch官方文档.
- 就像python中的三元运算一样,如果条件满足,选input的元素,不满足选other的元素。
| 1 | In[49]: cond = torch.rand(2,2) # 制作一个选择器 | 
为什么会有where
- 以前我们通常在for循环下以 c[0] = a[0] 来进行整段的复制,但是这些都是在cpu上完成的,想要用gpu并行,就必须摆脱这种方式。
- 使用where语句,在gpu中,高度并行。
- cond 可以由cpu生成也可以有gpu生成。
gather
| 1 | torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor | 
- 收集/查表 
- input 理解为一张表;dim决定哪个维度查找;所查的索引 
- 这样做一来可以用gpu加速,二来可以达到从 relative gather 到 global gather。 

例子:
- 以手写体数字为例,我们假设输出[4,10] 
- 每张图片取可能性最大的前3 
| 1 | In[3]: prob=torch.randn(4,10) | 

