Notes on setting up JAX in an Ubuntu, ROCm, AMD GPU, Docker, and TensorFlow environment

Ubuntu 22.04

Radeon RX 6800

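To run the MAXIM model released by google-research, I needed an environment with JAX, a GPU-accelerated numerical computing library. JAX compiles through XLA, whose source lived in the TensorFlow repository at the time, which is why the build below points at a TensorFlow fork.

These are my notes on building a JAX environment that runs on ROCm inside Docker.

This assumes Docker and ROCm are already installed.

The ROCm developers have updated the Docker-based JAX build and documented the commands (https://github.com/ROCmSoftwarePlatform/jax/tree/main/build/rocm), but if you run them as-is, the build does not complete.

Below are my debugging notes.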

1. git clone

git clone https://github.com/ROCmSoftwarePlatform/jax.git

2. Inside the cloned jax folder, edit build/rocm/build_rocm.sh as follows

set -eux

ROCM_TF_FORK_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
ROCM_TF_FORK_BRANCH="develop-upstream"
rm -rf /tmp/tensorflow-upstream || true
git clone -b ${ROCM_TF_FORK_BRANCH} ${ROCM_TF_FORK_REPO} /tmp/tensorflow-upstream
if [ ! -v TENSORFLOW_ROCM_COMMIT ]; then
    echo "The TENSORFLOW_ROCM_COMMIT environment variable is not set, using top of branch"
elif [ ! -z "$TENSORFLOW_ROCM_COMMIT" ]
then
      echo "Using tensorflow-rocm at commit: $TENSORFLOW_ROCM_COMMIT"
      cd /tmp/tensorflow-upstream
      git checkout $TENSORFLOW_ROCM_COMMIT
      cd -
fi

# Comment out all three lines below (the build is run by hand inside the container in step 6)
#python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=org_tensorflow=/tmp/tensorflow-upstream
#pip3 install --force-reinstall dist/*.whl  # installs jaxlib (includes XLA)
#pip3 install --force-reinstall .  # installs jax
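If you would rather not edit build_rocm.sh by hand, the three lines can be commented out with sed. This is a minimal sketch, assuming the lines start exactly with `python3 ./build/build.py` and `pip3 install` as shown above; `comment_out_build_steps` is a hypothetical helper name, not something shipped in the repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the ROCm jax repo): comment out the
# in-image build/install lines of build_rocm.sh, as described in step 2.
comment_out_build_steps() {
  local script="$1"
  # Prefix '#' to lines that start with the build.py call or a pip3 install.
  # Lines that are already commented (start with '#') do not match, so
  # running this twice leaves the file unchanged.
  sed -i -E 's@^(python3 \./build/build\.py|pip3 install)@#\1@' "$script"
}

# Usage, from the root of the cloned jax folder:
# comment_out_build_steps build/rocm/build_rocm.sh
```

Matching on the start of the line keeps the `git clone` and `git checkout` lines of the script untouched.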

3. From the jax folder, run the Docker build command

./build/rocm/ci_build.sh --keep_image bash -c "./build/rocm/build_rocm.sh"

This creates a Docker image named jax_ci.rocm.

4. Run the container

sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax_ci.rocm:latest
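The `--device=/dev/kfd --device=/dev/dri` options only work if those device nodes exist on the host, which is the case once the amdgpu kernel driver is loaded. A minimal pre-flight check can catch a missing driver before `docker run` fails; `check_rocm_devices` is a hypothetical helper, not part of ROCm or Docker.

```shell
#!/usr/bin/env bash
# Hypothetical helper: return non-zero if any of the given device paths is
# missing, printing each missing path to stderr.
check_rocm_devices() {
  local missing=0 dev
  for dev in "$@"; do
    if [ ! -e "$dev" ]; then
      echo "missing device: $dev" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Usage on the host, before step 4:
# check_rocm_devices /dev/kfd /dev/dri || echo "ROCm device nodes not found"
```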

5. Inside the container, git clone the ROCm jax and ROCm tensorflow-upstream repositories

git clone https://github.com/ROCmSoftwarePlatform/jax.git
git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream.git

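6. Inside the jax folder, run the jax build command

# Check the installed ROCm version under /opt and set --rocm_path to match
# Check the path of the cloned tensorflow-upstream (here it is /tensorflow-upstream)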
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
  --bazel_options=--override_repository=org_tensorflow=/tensorflow-upstream

This can take a while.
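The `--rocm_path` value has to match the versioned directory ROCm installed under /opt (here `/opt/rocm-5.3.0`). As a sketch, assuming the usual `/opt/rocm-X.Y.Z` naming, a hypothetical helper `find_rocm_path` can pick the newest one instead of typing it by hand.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the newest rocm-X.Y.Z directory under a base
# directory (default /opt), or nothing if none exists.
find_rocm_path() {
  local base="${1:-/opt}" d
  for d in "$base"/rocm-*; do
    # Skips the literal pattern when nothing matches, and any non-directories.
    [ -d "$d" ] && printf '%s\n' "$d"
  done | sort -V | tail -n 1
}

# Usage inside the container, in the jax folder:
# python build/build.py --enable_rocm --rocm_path="$(find_rocm_path /opt)" \
#   --bazel_options=--override_repository=org_tensorflow=/tensorflow-upstream
```

`sort -V` (GNU coreutils) is used so that, for example, `rocm-5.10.0` sorts after `rocm-5.3.0`, which a plain lexical sort would get wrong.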

7. After the build finishes, install jaxlib and jax

pip3 install --force-reinstall dist/*.whl  # installs jaxlib (includes XLA)
pip3 install --force-reinstall .  # installs jax
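Before the longer test run in step 8, a quick way to confirm that the installed jaxlib actually sees the GPU is to ask JAX for its default backend with `jax.default_backend()`. This is a minimal sketch: `check_jax_gpu` is a hypothetical helper, and it assumes the ROCm build reports its backend as `gpu` (or `rocm`, depending on the JAX version), with `cpu` meaning the GPU was not picked up.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print JAX's default backend and succeed only if it is
# a GPU backend. The Python command can be overridden (default: python3).
check_jax_gpu() {
  local py="${1:-python3}" backend
  backend="$("$py" -c 'import jax; print(jax.default_backend())')" || return 1
  echo "JAX backend: $backend"
  # Assumption: ROCm builds report "gpu" (older JAX) or "rocm".
  case "$backend" in
    gpu|rocm) return 0 ;;
    *) return 1 ;;
  esac
}

# Usage inside the container, after step 7:
# check_jax_gpu python3
```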

8. Build test (optional)

./build/rocm/run_single_gpu

This also takes a while.

9. After that, install TensorFlow and use it

pip3 install tensorflow-rocm
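Similarly, after installing tensorflow-rocm you can check that TensorFlow sees the GPU as well. A sketch under the same assumptions: `count_tf_gpus` is a hypothetical helper that relies on `tf.config.list_physical_devices("GPU")` (available in TensorFlow 2.1 and later) and treats a count of zero as a failure.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print how many GPUs TensorFlow can see and succeed only
# if there is at least one. The Python command can be overridden (default: python3).
count_tf_gpus() {
  local py="${1:-python3}" n
  n="$("$py" -c 'import tensorflow as tf; print(len(tf.config.list_physical_devices("GPU")))')" || return 1
  echo "TensorFlow GPUs: $n"
  # Non-numeric output makes the test itself fail, which also counts as failure.
  [ "$n" -gt 0 ] 2>/dev/null
}

# Usage inside the container, after step 9:
# count_tf_gpus python3
```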

References

google-research/maxim, official repository for "MAXIM: Multi-Axis MLP for Image Processing" (CVPR 2022 Oral): https://github.com/google-research/maxim
google/jax, Composable transformations of Python+NumPy programs: https://github.com/google/jax
JAX documentation, Building from source (additional notes for building a ROCm jaxlib for AMD GPUs): https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus
ROCmSoftwarePlatform/jax: https://github.com/ROCmSoftwarePlatform/jax