본문 바로가기

ML&DL/PyTorch

[PyTorch] Tensor 자르기/분리하기: chunk함수

Chunk함수란?

 

chunk함수는 tensor를 쪼개는 함수이다. tensor를 몇개로 어떤 dimension으로 쪼갤지 설정해주고 사용하면된다. output = torch.chunk(input, n = (몇개로 쪼갤지 설정), dim = (어떤 차원에 적용할지))

 

 

import torch

# type 1
x = torch.chunk(x, n, dim)

# type 2
x = x.chunck(n, dim)

Chunk함수의 시각화

시각화라고 하긴 거창하지만 내 발그림으로 이해해보자. 내 상황을 좀 설명해보면 나는 마이크 여러개에서 얻은 신호중에 첫번째 신호만 꺼내서 처리하고 싶은 상황이다.

음성의 길이는 T이고 64000 sample이 존재한다. 그리고 microphone개수가 6개여서 ch(channel)이 6인 것을 확인할 수 있다. 그리고 batch는 3으로 설정했다. 그러면 총 차원은 [batch_size, ch, T] -> [3, 6, 64000]이 된다.

 

 

 

 

근데 나는 여기서 모든 batch에서 첫번째 microphone signal만 뽑아오고 싶다.

 

 

 

우리가 바꾸고 뽑고 싶은 차원이 ch 부분인 것을 알 수 있다. 그럼 바꾸고 싶은 차원을 dim이라는 parameter자리에 넣어주면 된다.

 

 

내가 바꾸고싶은 차원이 여기서는 1번째 차원인 것을 알 수 있다. 그럼 dim에는 1을 넣어주면 된다. 그리고 첫번째차원을 6등분해야 하나의 음성이 추출되기 때문에 n=6으로 설정하면 된다. 코드로 알아보자.

Python코드

랜덤한 tensor를 생성하고 chunk함수를 적용해보자. 원래는 자르는 갯수만큼의 output이 생기지만 나는 첫번째 output만 필요해서 생략했다.

 

import torch

"""
M: mini batch
nmic=ch: # of channel (# of mic)
T: nsample
"""

batch_size, nmic, nsample = 3, 6, 64000

x = torch.rand(batch_size, nmic, nsample) # [M, ch, T] = [3, 6, 64000]
mic1, mic2, mic3, mic4, mic5, mic6 = torch.chunk(x, nmic, dim=1) # [M, 1, T] = [3, 1, 64000]

 

원래는 자르는 갯수만큼의 output이 생기지만 필요한 output만 취할 수 있다.

 

import torch

"""
M: mini batch
nmic=ch: # of channel (# of mic)
T: nsample
"""

batch_size, nmic, nsample = 3, 6, 64000

x = torch.rand(batch_size, nmic, nsample) # [M, ch, T] = [3, 6, 64000]
mic1,_,_,_,_,_ = torch.chunk(x, nmic, dim=1) # [M, 1, T] = [3, 1, 64000]

 

그리고 분할하려는 dimension을 꽉 채워서 자르지 않아도 된다.

 

chunk1, chunk2 = torch.chunk(x, 2, dim=1) # [M, 1, T] = [3, 3, 64000]

Chunk함수 활용

첫번째 dimension을 기준으로 for문을 돌리는 형태가 가능하다.

 

# feats: dimension [batch, M, feat_dim]

f_axis=[]
for f in feats.chunk(M, dim=1):
	f_mean = f.mean(dim=1) #[batch, feat_dim]
    f_axis.append(normalize(f_mean, dim=1))