본문 바로가기
A.I./구현

PyTorch Dataloader에서 원본 파일명을 알고 싶다.

by 채소장사 2020. 11. 25.

 이미지 분류를 위한 PyTorch의 사용 시에, 입력 데이터를 준비하는 가장 흔한 방법은 다음과 같다.

  1. 각 레이블 별로 이미지를 서로 다른 디렉토리에 저장한다. 
  2. torchvision.datasets.ImageFolder( )를 이용하여, 폴더 구조로부터 데이터셋을 생성한다.
    이 과정에서 필요한 transform을 수행하고, 텐서로 변환하며 정규화를 수행할 수 있다.
  3. torch.utils.data.DataLoader( )를 사용하여 생성된 데이터셋으로부터 데이터를 로드할 수 있다.
    이 때, 원하는 크기의 배치 단위로 데이터를 로드하거나, 순서가 무작위로 섞이도록(shuffle) 할 수 있다.

 이렇게 처리된 입력 데이터는 변환된 텐서와 레이블이 조합된 형태로 메모리에 로드되어 있다. 그런데 훈련된 모델에 의해 예측된 결과에 따라서, 원본 이미지를 이동시키거나 삭제해야 되는 추가 작업을 해야할 필요가 있다면 현재 사용하는 입력 데이터의 경로나 파일명을 알고 싶을 수 있다. 

 이 글은 DataLoader에서 shuffle의 사용 여부에 따라 원본 파일명을 찾아내는 방법을 간단하게 정리한 포스트이다.

 

1. DataLoader에서 shuffle을 사용하지 않을 때

 DataLoader에서 데이터의 순서를 무작위로 섞지 않고 사용한다면, 추가적인 작업을 하지 않아도 입력 이미지의 원래 경로 및 파일명을 아는 것이 가능하다.

dataset = torchvision.datasets.ImageFolder(<folder_path>, transforms=<transforms>)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=<batch_size>, shuffle=False)

 위의 코드 블록처럼, <folder_path>의 경로에 정리된 folder structure의 이미지들이 있고, 사용자가 지정한 <transforms>를 적용한 데이터셋에서 <batch_size>크기로 shuffle하지 않는 데이터로더를 만든 경우를 생각한다.

 이렇게 생성된 데이터로더에서 dataloader.dataset.samples는 데이터셋의 모든 파일에 대한 (파일경로, 레이블)의 튜플들로 이뤄진 리스트를 가리킨다. 편의를 위해서 파일 경로들로만 이뤄진 리스트를 뽑아서 사용해 보았다.

allFiles, _ = map(list, zip(*dataloader.dataset.samples))

for i, (inputs, labels) in enumerate(dataloader):
	inputs = inputs.to(device)
    labels = labels.to(device)
    
    for j in range(inputs.size()[0]):
    	print("현재 파일에 대한 파일경로")
        print(allFiles[ i * <batch_size> + j ])

 

2. DataLoader에서 shuffle을 사용하고 싶을 때

 앞의 방법에서 계속 데이터로더가 shuffle=False 이어야한다고 강조했던 까닭은 데이터로더를 통해 데이터 셔플을 수행할 때, DataLoader.dataset.samples에 담긴 리스트폴더 내의 파일 순서대로 유지되어서 셔플되지 않기 때문이다.

 따라서 셔플을 수행한 데이터로더에서도 원본 파일 경로를 알고 싶다면, 데이터로더의 샘플이 파일 경로를 가지고 있도록 변형한 Custom Dataset을 구성할 필요가 있다. Custom Dataset을 구성한다는 뜻은 자신이 사용할 의도에 맞게 PyTorch에서 사용할 데이터셋을 직접 구성한다는 의미이므로, 원본 파일명도 받아올 방법이 있음을 쉽게 생각해볼 수 있다.

 Custom Image Dataset을 만들기 위해서는 우선 추상 클래스인 torch.utils.data.Dataset을 상속해야 한다. Custom Image Dataset을 위해 구현할 subclass는 주어진 인덱스에 대한 데이터 샘플을 가져오는 방법을 정의하는 __getitem__()반드시 override해야한다. 또 데이터셋의 크기를 반환하는 __len__()을 선택적으로 override할 수 있다.

 이 포스트는 간단한 Custom Dataset 구현을 위해 아래의 글을 참고하였다.

PyTorch 공식 예제 : Writing Custon Datasets, DataLoaders and Transforms
한글 번역 : 사용자 정의 Dataset, DataLoader, Transforms 작성하기

 다만 참고 포스트에서 사용한 scikit-image 라이브러리의 io 메서드 대신에 PIL의 Image를 사용하였다. 이는 사용 상의 큰 차이가 있는 것은 아니지만 내부적으로 PIL을 사용하는 PyTorch의 구현을 생각해서 바꿔보았다. 또 같은 이유로 scikit-image의 transform를 사용하고 이를 위한 transform클래스를 생성하는 것보다 torchvison.transforms를 사용하는 것이 자연스러워 transforms 모듈의 사용으로 변경하였다.

class CustomDataset(torch.utils.data.Dataset):
	def __init__(self, root_dir, transforms=None):
    	self.root_dir = root_dir
        self.classes = os.listdir(self.root_dir)
        self.transforms = transforms
        self.data = []
        self.labels = []
        
        for idx, cls in enumerate(self.classes):
        	cls_dir = os.path.join(self.root_dir, cls)
            for img in glob(os.path.join(cls_dir, '*.jpg')):
            	self.data.append(img)
                self.labels.append(idx)
                
    def __getitem__(self, idx):
    	img_path, label = self.data[idx], self.labels[idx]
        img = PIL.Image.open(img_path)
        
        if self.transforms:
        	img = self.transforms(img)
        
        sample = {'image':img, 'label':label, 'filename':img_path}
        
        return sample
    
    def __len__(self):
    	return len(self.data)

 이제 Custom Dataset 클래스로부터 데이터셋을 생성하고, 이로부터 데이터로더를 만든다면 기존의 task는 그대로 진행하면서 필요에 따라 입력 이미지의 원본 파일명을 알 수 있다.

custom_dataset = CustomDataset(root_dir=<folder_path>, transforms=<transforms>)
dataloader = torch.utils.data.DataLoader(custom_dataset, batch_size=<batch_size>,
										 shuffle = True)

for i, (sample) in enumerate(dataloader):
	inputs = sample['image'].to(device)
    labels = sample['label'].to(device)
    names = sample['filename']
    
    for j in range(inputs.size()[0]):
    	print("현재 파일에 대한 파일 경로")
        print(names[j])

 

댓글