squeeze함수
squeeze함수는 차원이 1인 차원을 제거해준다. 따로 차원을 설정하지 않으면 1인 차원을 모두 제거한다. 그리고 차원을 설정해주면 그 차원만 제거한다.
Python 코드
import torch
x = torch.rand(3, 1, 20, 128)
x = x.squeeze() #[3, 1, 20, 128] -> [3, 20, 128]
주의할 점은 생각치도 못하게 batch가 1일 때 batch차원도 없애버리는 불상사가 발생할 수있다. 그래서 validation단계에서 오류가 날 수 있기 때문에 주의해서 사용해야 한다.
import torch
x = torch.rand(1, 1, 20, 128)
x = x.squeeze() # [1, 1, 20, 128] -> [20, 128]
x2 = torch.rand(1, 1, 20, 128)
x2 = x2.squeeze(dim=1) # [1, 1, 20, 128] -> [1, 20, 128]
unsqueeze함수
unsqueeze함수는 squeeze함수의 반대로 1인 차원을 생성하는 함수이다. 그래서 어느 차원에 1인 차원을 생성할 지 꼭 지정해주어야한다.
Python 코드
import torch
x = torch.rand(3, 20, 128)
x = x.unsqueeze(dim=1) #[3, 20, 128] -> [3, 1, 20, 128]
unsqueeze trick: tensor의 None indexing
None indexing은 unsqueeze operation을 대체할 수 있다.
Python 코드
import torch
inputs = torch.rand(3, 100, 100)
print(inputs[:,None].shape) # [3, 100, 100] -> [3, 1, 100, 100]
print(inputs.unsqueeze(1).shape) # [3, 100, 100] -> [3, 1, 100, 100]
- inputs[...,None]: 마지막 차원을 하나 삽입한다. (= inputs.unsqueeze(dim=-1))
- inputs[...,None,:]: 마지막 바로 전 차원을 하나 삽입한다. (= inputs.unsqueeze(dim=-2))
Reference
https://stackoverflow.com/questions/69797614/indexing-a-tensor-with-none-in-pytorch
'ML&DL > PyTorch' 카테고리의 다른 글
[PyTorch] Weight Decay (L2 penalty) (0) | 2020.09.28 |
---|---|
[PyTorch] Gradient clipping (그래디언트 클리핑) (0) | 2020.09.23 |
[PyTorch] Tensor 합치기: cat(), stack() (9) | 2020.09.16 |
[PyTorch] 시계열 데이터를 위한 RNN/LSTM/GRU 사용법과 팁 (8) | 2020.08.05 |
[PyTorch] 리눅스환경에서 특정 GPU만 이용해 Multi GPU로 학습하기 (0) | 2020.07.28 |