본문 바로가기

컴퓨터/머신러닝 (Machine Learning)

Super Resolution EDT 사용하기

Super resolution 모델 중 하나인 EDT를 사용한 내용을 정리한다.

 

1. 깃 repository 다운로드 및 package 설치

 

EDT 모델은 Python 3.7 버전 이상, pytorch 1.4 버전 이상에서 돌아간다. (적절하게 설치한다)

git clone https://github.com/fenglinglwb/EDT.git

를 입력하여 git 다운로드 후,

다운받은 EDT 폴더에서 requirement 에 있는 package 들을 설치한다.

pip install -r requirements.txt

 

2. test_sample.py 코드 수정

 

EDT 모델은 test_sample.py 코드를 돌려서 사용가능한데,

git에 업로드된 코드에 오류가 있어 수정을 해야한다.

nano test_sample.py

를 입력하여 코드를 뜯어서

134번째 줄에 img_name = f.split('/')[-1] 로 입력된 부분을

img_name = lq_path.split('/')[-1] 로 수정한다. 그리고 저장.

 

 

3. pre-trained 모델 저장

 

공식 깃헙 (GitHub - fenglinglwb/EDT: On Efficient Transformer-Based Image Pre-training for Low-Level Vision) 의

3번 설명에 잘 나와있는데, EDT를 활용할 각 task (Super-resolution, Denoising, Deraining)에 따라,

 

 

one drive 링크에 가서 모델을 다운 받는다.

이후 EDT 폴더 안에 pretrained 폴더를 하나 만든 뒤 다운받은 모델을 넣어준다.

 

4. 모델 돌리기

 

모델을 돌리기 위한 코드의 기본 포멧은 아래와 같다.

python test_sample.py --config config_path --model model_path --input input_folder --output output_folder

각 파라미터들을 설명하자면,

--config 는 3번에서 저장한 pre-trained 모델과 같은 이름의 .py 파일 경로

--model 에는 3번에서 저장한 pre-trained 모델의 경로

--input EDT를 돌릴 데이터가 저장된 (사진이 담긴) 폴더의 경로

--output EDT 작업이 처리된 데이터가 저장될 폴더의 경로

 

공식 깃헙에서 제공하는 예제 코드는 아래와 같다.

python test_sample.py --config configs/SRx2_EDTT_Div2kFlickr2K.py --model pretrained/SRx2_EDTT_Div2kFlickr2K.pth --input test_sets/SR/Set5/LR/x2

 

5. Jupyter 에서 EDT 작업 loop 돌리기

 

여러 폴더에 구분된 사진들을 EDT에 폴더별로 처리하는 코드이다.

folder_list = ['folder_1', 'folder_2', 'folder_3']

for folder in folder_list:
    config_path = 'configs/SRx2_EDTB_Div2kFlickr2K__SRx2_ImageNetFull.py'
    model_path = 'pretrained/SRx2_EDTB_Div2kFlickr2K__SRx2_ImageNetFull.pth'
    input_path = '/Data/data/data_gsun/inference_SR/EDT/' + folder
    output_path = '/Data/data/data_gsun/results_SR/EDT/' + folder
    
    !python3 test_sample.py --config {config_path} --model {model_path} --input {input_path} --output {output_path}

jupyter에서 bash 명령어를 실행하기 위해서는 !를 붙여준다.

파이썬 argument는 bracket ({) 으로 연결한다.