Chunk함수란?
chunk함수는 tensor를 쪼개는 함수이다. tensor를 몇개로 어떤 dimension으로 쪼갤지 설정해주고 사용하면된다. output = torch.chunk(input, n = (몇개로 쪼갤지 설정), dim = (어떤 차원에 적용할지))
import torch
# type 1
x = torch.chunk(x, n, dim)
# type 2
x = x.chunck(n, dim)
Chunk함수의 시각화
시각화라고 하긴 거창하지만 내 발그림으로 이해해보자. 내 상황을 좀 설명해보면 나는 마이크 여러개에서 얻은 신호중에 첫번째 신호만 꺼내서 처리하고 싶은 상황이다.
음성의 길이는 T이고 64000 sample이 존재한다. 그리고 microphone개수가 6개여서 ch(channel)이 6인 것을 확인할 수 있다. 그리고 batch는 3으로 설정했다. 그러면 총 차원은 [batch_size, ch, T] -> [3, 6, 64000]이 된다.
근데 나는 여기서 모든 batch에서 첫번째 microphone signal만 뽑아오고 싶다.
우리가 바꾸고 뽑고 싶은 차원이 ch 부분인 것을 알 수 있다. 그럼 바꾸고 싶은 차원을 dim이라는 parameter자리에 넣어주면 된다.
내가 바꾸고싶은 차원이 여기서는 1번째 차원인 것을 알 수 있다. 그럼 dim에는 1을 넣어주면 된다. 그리고 첫번째차원을 6등분해야 하나의 음성이 추출되기 때문에 n=6으로 설정하면 된다. 코드로 알아보자.
Python코드
랜덤한 tensor를 생성하고 chunk함수를 적용해보자. 원래는 자르는 갯수만큼의 output이 생기지만 나는 첫번째 output만 필요해서 생략했다.
import torch
"""
M: mini batch
nmic=ch: # of channel (# of mic)
T: nsample
"""
batch_size, nmic, nsample = 3, 6, 64000
x = torch.rand(batch_size, nmic, nsample) # [M, ch, T] = [3, 6, 64000]
mic1, mic2, mic3, mic4, mic5, mic6 = torch.chunk(x, nmic, dim=1) # [M, 1, T] = [3, 1, 64000]
원래는 자르는 갯수만큼의 output이 생기지만 필요한 output만 취할 수 있다.
import torch
"""
M: mini batch
nmic=ch: # of channel (# of mic)
T: nsample
"""
batch_size, nmic, nsample = 3, 6, 64000
x = torch.rand(batch_size, nmic, nsample) # [M, ch, T] = [3, 6, 64000]
mic1,_,_,_,_,_ = torch.chunk(x, nmic, dim=1) # [M, 1, T] = [3, 1, 64000]
그리고 분할하려는 dimension을 꽉 채워서 자르지 않아도 된다.
chunk1, chunk2 = torch.chunk(x, 2, dim=1) # [M, 1, T] = [3, 3, 64000]
Chunk함수 활용
첫번째 dimension을 기준으로 for문을 돌리는 형태가 가능하다.
# feats: dimension [batch, M, feat_dim]
f_axis=[]
for f in feats.chunk(M, dim=1):
f_mean = f.mean(dim=1) #[batch, feat_dim]
f_axis.append(normalize(f_mean, dim=1))
'ML&DL > PyTorch' 카테고리의 다른 글
[PyTorch] 시계열 데이터를 위한 다양한 Normalization기법 (BatchNorm1d, GroupNorm 사용법) (2) | 2020.07.28 |
---|---|
[PyTorch] 시계열 데이터를 위한 1D convolution과 1x1 convolution (11) | 2020.07.08 |
[PyTorch] torch.nn.KLDivLoss() 사용법과 예제 (1) | 2020.07.07 |
[PyTorch] numpy에서 tensor로 변환: Tensor, from_numpy함수의 차이/tensor에서 numpy로 변환: numpy함수 (0) | 2020.06.22 |
[PyTorch] view, reshape, transpose, permute함수의 차이와 contiguous의 의미 (3) | 2020.06.11 |