파이썬3 노트

torchtext로 전처리하기 (1) name, dirname, urls, cls.download()

Jonchann 2019. 4. 19. 19:53

나는 주로 pytorch로 구현을 하기 때문에 전처리를 torchtext로 같이 많이 한다.

사용하면서 배운 점이나 알게된 점, 기억해야 하는 점을 몇 개에 걸쳐서 적을 것이다.




torchtext에는 데이터를 불러오기 위한 방법이 2 가지 있다.

하나는 내가 직접 path를 지정해주는 것, 두 번째는 파일을 다운로드 할 수 있도록 하는 것.

따라서 내 코드를 다른 사람이 사용하거나 내 pc가 아닌 pc로 돌릴 때 내가 갖고 있는 데이터를 같은 path대로 준비하지 않아도 사용할 수 있다는 것이다.


그러기 위해서는 name, dirname, urls라는 변수를 class 초반에 지정해 줄 필요가 있다.

예를 들면, 아래와 같다.

import torch
from torchtext.data import Dataset


class TorchText(Dataset):
    name = ''
    dirname = ''
    urls = [
         'url_1',
         'url_2',
    ]
    
    def __init__(self, datum, fields, filter_pred=None):
        ...
        super(TorchText, self).__init__(datum, fields, filter_pred)


그 후 splits라는 함수(pytorch의 nn.Module의 forward같은)에서 cls.download(path)를 이용해 urls 속 path들을 다운로드하도록 적으면 된다. 이 때 name으로 지정한 이름의 폴더 속으로 데이터가 다운로드 된다.


변수 이름이 꼭 name, dirname, urls여야 하는 이유는 torchtext의 download함수 코드를 보면 알 수 있다.

(-> https://github.com/pytorch/text/blob/master/torchtext/data/dataset.py)

먼저 name과 urls이다.

    @classmethod
    def download(cls, root, check=None):
        """Download and unzip an online archive (.zip, .gz, or .tgz).

        Arguments:
            root (str): Folder to download data to.
            check (str or None): Folder whose existence indicates
                that the dataset has already been downloaded, or
                None to check the existence of root/{cls.name}.

        Returns:
            str: Path to extracted dataset.
        """
        path = os.path.join(root, cls.name)
        check = path if check is None else check
        if not os.path.isdir(check):
            for url in cls.urls:
                if isinstance(url, tuple):
                    url, filename = url
                else:
                    filename = os.path.basename(url)
                zpath = os.path.join(path, filename)
                if not os.path.isfile(zpath):
                    if not os.path.exists(os.path.dirname(zpath)):
                        os.makedirs(os.path.dirname(zpath))
                   ...

        return os.path.join(path, cls.dirname)

path = 부분에서 cls.name을 root와 묶어서 폴더 path를 생성하기 때문에 우리도 변수의 이름을 name이라 해야 한다. dirname의 경우에는 내가 제대로 이해하고 있다면 name이 지정한 폴더의 상위 폴더를 지칭하고 있는 것이기 때문에 더이상 폴더를 만들고 싶지 않다면 ''로 놔두고 splits()함수 속에 지정하는 path를 상위폴더로 사용하면 된다.

이와 마찬가지로 cls.urls에서 url을 가져오도록 하고 있기 때문에 우리는 다운로드 해야 하는 데이터 path를 class안에 반드시 urls라는 리스트로 묶어두어야 한다.