본문 바로가기

ML&DL/PyTorch

[PyTorch] squeeze, unsqueeze함수: 차원 삭제와 차원 삽입

squeeze함수

squeeze함수는 차원이 1인 차원을 제거해준다. 따로 차원을 설정하지 않으면 1인 차원을 모두 제거한다. 그리고 차원을 설정해주면 그 차원만 제거한다.

Python 코드

import torch

x = torch.rand(3, 1, 20, 128)
x = x.squeeze() #[3, 1, 20, 128] -> [3, 20, 128]

 

주의할 점은 생각치도 못하게 batch가 1일 때 batch차원도 없애버리는 불상사가 발생할 수있다. 그래서 validation단계에서 오류가 날 수 있기 때문에 주의해서 사용해야 한다.

 

import torch

x = torch.rand(1, 1, 20, 128)
x = x.squeeze() # [1, 1, 20, 128] -> [20, 128]

x2 = torch.rand(1, 1, 20, 128)
x2 = x2.squeeze(dim=1) # [1, 1, 20, 128] -> [1, 20, 128]

unsqueeze함수

unsqueeze함수는 squeeze함수의 반대로 1인 차원을 생성하는 함수이다. 그래서 어느 차원에 1인 차원을 생성할 지 꼭 지정해주어야한다.

Python 코드

import torch

x = torch.rand(3, 20, 128)
x = x.unsqueeze(dim=1) #[3, 20, 128] -> [3, 1, 20, 128]

unsqueeze trick: tensor의 None indexing 

None indexing은 unsqueeze operation을 대체할 수 있다.

Python 코드

import torch

inputs = torch.rand(3, 100, 100)

print(inputs[:,None].shape) # [3, 100, 100] -> [3, 1, 100, 100]
print(inputs.unsqueeze(1).shape) # [3, 100, 100] -> [3, 1, 100, 100]
  • inputs[...,None]: 마지막 차원을 하나 삽입한다. (= inputs.unsqueeze(dim=-1))
  • inputs[...,None,:]: 마지막 바로 전 차원을 하나 삽입한다. (= inputs.unsqueeze(dim=-2))

Reference

https://stackoverflow.com/questions/69797614/indexing-a-tensor-with-none-in-pytorch