본문 바로가기

ML&DL/PyTorch

[PyTorch] Tensor 합치기: cat(), stack()

 

실험이 돌아가는 동안 심심하니까 하는 포스팅. PyTorch에서 tensor를 합치는 2가지 방법이 있는데 cat과 stack이다. 두가지는 현재 차원의 수를 유지하느냐 확장하느냐의 차이가 있다. 그림과 코드를 통해 사용법을 알아보자.

Cat함수란?

cat함수는 concatenate를 해주는 함수이고 concatenate하고자 하는 차원을 증가시킨다 (차원의 갯수는 유지되고 해당 차원이 늘어난다.. 설명이 애매하네 코드 참고). concatenate하고자하는 차원을 지정해주면 그 차원으로 두 tensor의 차원을 더한 값으로 차원이 변경된다. concatenate하고자하는 dimension을 지정해주지 않으면 default=0으로 설정된다. 자주 사용하는 용도로는 network 내에서 다른 두 feature를 fusion할 때와 batch 단위로 생성된 output값을 모아 전체 데이터셋에 대한 output을 저장해야할 때 쓴다.

Cat함수의 시각화

Python 코드

import torch

batch_size, N, K = 3, 10, 256

x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]

output1 = torch.cat([x,y], dim=1) #[M, N+N, K]
output2 = torch.cat([x,y], dim=2) #[M, N, K+K]

Stack함수란?

stack함수는 지정하는 차원으로 확장하여 tensor를 쌓아주는 함수이다. (지정하는 차원에 새로운 차원이 생긴다=차원의 갯수가 증가한다) tensor를 쌓아주는 함수이기 때문에 두 tensor의 차원이 정확히 일치해야 쌓을 수 있다. stack 하고자하는 dimension을 지정해주지 않으면 default=0으로 설정된다.

Stack함수의 시각화

Python 코드

import torch

batch_size, N, K = 3, 10, 256

x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]

output = torch.stack([x,y], dim=1) #[M, 2, N, K]

Cat, Stack 함수의 활용

Cat 함수 활용: Tensor list를 한번에 tensor로 만들기

import torch

#(....중략)

out_list = []
for data in dataloader:
    out = model(data) # [batch size, latent dim]
    out_list.append(out)
output = torch.cat(out_list, 0) # [Total dataset size (batch size * # of batch, latent dim]
# same as --> output = torch.cat(out_list, dim=0) 

# 참고, numpy로 변환
output_np = output.detach().cpu().numpy()

Stack 함수 활용: Tensor list를 한번에 tensor로 만들기

import torch

#(....중략)

out_list = []
for data in dataloader:
    out = model(data) # [batch size, latent dim]
    out_list.append(out)
output = torch.stack(out_list, 0) # [Number of batch, batch size, latent dim]
all_features = []
for (data, label) in tt_dataloader:
	feat = model(data)
    all_features.extend(feat.cpu())
torch.cat(all_features, 0).numpy()

Cat, Stack 함수 없이 tensor 합쳐서 numpy array로 만들기 

features_all = np.zeros((len(test_loader.dataset.targets), L))
ind = 0
with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device)
        features = model.forward_features(inputs)
        features = features.cpu().detach().numpy()
        print(features.shape)

        features_all[ind:ind + len(features)] = features
        ind = ind + len(features)