본문 바로가기
Archive/ML & DL

(극복Proj-1)[Pytorch] DataLoader란?

by 다람이도토리 2021. 12. 23.

* 당분간 DL 기초 구현 극복 Project를 진행하고자 합니다.

Pytorch에서 잘 이해하지 못했던 기능들과 RNN, LSTM의 구현 및 텍스트 문제에의 적용의 극복을 통해 다음 단계로 넘어가기 위한 준비를 진행하고자 합니다.

왜, 기본 예제 이외에 추가로 무언가 하려고 하면 RNN/LSTM에서 학습이 잘 안되고 있었는지 이러한 것들을 공부하고 맨 처음 딥러닝을 접할때 어려움을 느낀 부분들을 정리하려고 합니다.
-------------------------------------------------------------------------------------------------------------------------

[1] Pytorch - DataLoader 이해하기

[참고자료]
https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html#dataloader

Pytorch DataLoader 왜 사용할까?

데이터를 1개, 1개 학습을 시키는 방법도 가능하지만 Pytorch를 활용하면 Mini-Batch 단위의 학습이 가능하다. 또한 데이터를 무작위로 섞어줄 수 있다는 장점이 있다. 또한, 데이터셋을 불러오고 관리하는 과정을 간결하게 표현 가능하기에  DataLoader의 활용은 중요할 것이다.

Fashion MNIST를 통한 DataLoader의 활용법

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

transform을 통해 쉽게 미리 텐서화를 시킬 수 있다. 이제 이를 원래대로라면, 하나하나 다 쪼개고 머ㅜ하고 라벨 붙이고..장난 아니었을 것인데, 이를 쉽게 처리 가능하다.

from torch.utils.data import DataLoader

# 하나의 data_loader 각각에는 feature과 label이 들어있다.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

미니배치 형태로 구성되며 하나의 배치 안에는 64개의 데이터가 존재한다.
data_loader에서 iteration 시키면서 데이터를 사용하게 된다고 생각하면 된다.

# 이미지와 정답(label)을 표시합니다.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[63].squeeze()
label = train_labels[63]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

 

 

그런데 이를 위한 DataSet은 어떻게 구성되는가?

MNIST는 깔끔하게 가능했다 치고, 전혀 다른 Data에서는 기본적으로 Custom Dataset을 만들어야 한다. 
즉, 이를 정의하여 DataLoader에 넘겨주어여 하는 것이다.

class CustomDataset(torch.utils.data.Dataset): 
  def __init__(self):
  # 데이터셋에서 X, Y를 기본 정의 및 전처리한다.

  def __len__(self):
  # 총 샘플의 수가 몇개인지 적어준다.

  def __getitem__(self, idx): 
  # 1개의 샘플 수를 가져온다.

Dataset을 만들때는 반드시, 위의 3가지 함수를 구현해서 진행해야 한다.

Fashion-MNIST의 case를 살펴보자.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)
        
    # 인덱스를 입력받아, 이에 해당하는 입력/출력 데이터를 tensor 형태로 return 시킨다.
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label