Ubuntu 22.04
Radeon RX 6800
google-research에서 공개한 maxim 모델을 구동하려니
Tensorflow 기반의 JAX라는 GPU에 dependent한 연산 라이브러리 환경이 필요했다.
Docker 기반 ROCm에 동작하는 JAX 환경 구축 정리
도커와 ROCm이 설치된 것을 가정
ROCm 개발자들이 도커에서 JAX를 build를 업데이트 하고 command을 안내하였지만
그대로 돌리면 build가 완료되지 않는다.(https://github.com/ROCmSoftwarePlatform/jax/tree/main/build/rocm)
디버깅 내용을 정리
1. git clone
git clone https://github.com/ROCmSoftwarePlatform/jax.git
2. clone한 jax 폴더 내부, build/rocm/build_rocm.sh 를 아래와 같이 수정
set -eux
ROCM_TF_FORK_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
ROCM_TF_FORK_BRANCH="develop-upstream"
rm -rf /tmp/tensorflow-upstream || true
git clone -b ${ROCM_TF_FORK_BRANCH} ${ROCM_TF_FORK_REPO} /tmp/tensorflow-upstream
if [ ! -v TENSORFLOW_ROCM_COMMIT ]; then
echo "The TENSORFLOW_ROCM_COMMIT environment variable is not set, using top of branch"
elif [ ! -z "$TENSORFLOW_ROCM_COMMIT" ]
then
echo "Using tensorflow-rocm at commit: $TENSORFLOW_ROCM_COMMIT"
cd /tmp/tensorflow-upstream
git checkout $TENSORFLOW_ROCM_COMMIT
cd -
fi
# 아래 세 줄 모두 주석처리
#python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=org_tensorflow=/tmp/tensorflow-upstream
#pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
#pip3 install --force-reinstall . # installs jax
3. jax 폴더에서 docker build 명령어 실행
./build/rocm/ci_build.sh --keep_image bash -c "./build/rocm/build_rocm.sh"
jax_ci.rocm 이라는 docker images가 생성
4. 컨테이너 실행
sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax_ci.rocm:latest
5. 컨터이너 내부에 ROCm jax git 과 ROCm tensorflow upstream git clone
git clone https://github.com/ROCmSoftwarePlatform/jax.git
git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream.git
6. jax 폴더 내부에서 jax build 명령어 실행
# /opt 폴더에서 rocm 버전 확인
# clone 한 tensorflow-upstream 경로 확인
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
--bazel_options=--override_repository=org_tensorflow=/tensorflow-upstream
시간이 좀 걸릴 수 있다.
7. 빌드를 마치후 jaxlib, jax 설치
pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
pip3 install --force-reinstall . # installs jax
8. build test (optional)
./build/rocm/run_single_gpu
시간이 좀 걸림
9. 이후 tensorflow를 설치하여 사용
pip3 install tensorflow-rocm
끝
참고글
https://github.com/google-research/maxim
GitHub - google-research/maxim: [CVPR 2022 Oral] Official repository for "MAXIM: Multi-Axis MLP for Image Processing". SOTA for
[CVPR 2022 Oral] Official repository for "MAXIM: Multi-Axis MLP for Image Processing". SOTA for denoising, deblurring, deraining, dehazing, and enhancement. - GitHub - google-research/max...
github.com
GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more - GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, ve...
github.com
Building from source — JAX documentation
jax.readthedocs.io
https://github.com/ROCmSoftwarePlatform/jax
GitHub - ROCmSoftwarePlatform/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more - GitHub - ROCmSoftwarePlatform/jax: Composable transformations of Python+NumPy programs: dif...
github.com
'컴퓨터 > 머신러닝 (Machine Learning)' 카테고리의 다른 글
하드 디스크 병목으로 인한 CPU 사용량 저하 확인 (0) | 2024.05.23 |
---|---|
Yolov9 Jupyter에서 돌려보기 (1) | 2024.05.15 |
Pytorch distributed launch watchdog timeout 에러 해결 (0) | 2022.12.27 |
Super resolution 모델, HAT train 정리 (0) | 2022.12.26 |
AMD GPU MIGraphX docker 사용 정리 (0) | 2022.12.22 |