diff --git a/.gitignore b/.gitignore index 0fb21d1..85ac48c 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ taming checkpoints Pipfile Pipfile.lock +.venv diff --git a/config/accelerate_multi_gpu.yaml b/config/accelerate_multi_gpu.yaml new file mode 100644 index 0000000..a55899c --- /dev/null +++ b/config/accelerate_multi_gpu.yaml @@ -0,0 +1,19 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: MULTI_GPU +downcast_bf16: 'no' +dynamo_config: {} +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'no' +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/requirements-lock.txt b/requirements-lock.txt new file mode 100644 index 0000000..f908298 --- /dev/null +++ b/requirements-lock.txt @@ -0,0 +1,159 @@ +absl-py==2.1.0 +accelerate==0.34.2 +aiofiles==23.2.1 +aiohappyeyeballs==2.4.0 +aiohttp==3.10.6 +aiosignal==1.3.1 +altair==5.4.1 +annotated-types==0.7.0 +anyio==4.6.0 +asttokens==2.4.1 +attrs==24.2.0 +audioread==3.0.1 +blinker==1.8.2 +cachetools==5.5.0 +certifi==2024.8.30 +cffi==1.17.1 +charset-normalizer==3.3.2 +click==8.1.7 +comm==0.2.2 +contourpy==1.3.0 +cycler==0.12.1 +datasets==2.19.1 +debugpy==1.8.6 +decorator==5.1.1 +diffusers==0.24.0 +dill==0.3.8 +executing==2.1.0 +fastapi==0.115.0 +ffmpy==0.4.0 +filelock==3.16.1 +fonttools==4.54.1 +frozenlist==1.4.1 +fsspec==2024.3.1 +gitdb==4.0.11 +GitPython==3.1.43 +gradio==4.44.0 +gradio_client==1.3.0 +grpcio==1.66.1 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +huggingface-hub==0.25.1 +idna==3.10 +importlib_metadata==8.5.0 +importlib_resources==6.4.5 +ipykernel==6.29.5 +ipython==8.27.0 +jedi==0.19.1 +Jinja2==3.1.4 +joblib==1.4.2 +jsonschema==4.23.0 +jsonschema-specifications==2023.12.1 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +kiwisolver==1.4.7 +lazy_loader==0.4 +librosa==0.10.2.post1 +llvmlite==0.43.0 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.2 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mpmath==1.3.0 +msgpack==1.1.0 +multidict==6.1.0 +multiprocess==0.70.16 +narwhals==1.8.3 +nest-asyncio==1.6.0 +networkx==3.3 +numba==0.60.0 +numpy==2.0.2 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvtx-cu12==12.1.105 +orjson==3.10.7 +packaging==24.1 +pandas==2.2.3 +parso==0.8.4 +pexpect==4.9.0 +pillow==10.4.0 +platformdirs==4.3.6 +pooch==1.8.2 +prompt_toolkit==3.0.48 +protobuf==5.28.2 +psutil==6.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyarrow==17.0.0 +pyarrow-hotfix==0.6 +pycparser==2.22 +pydantic==2.9.2 +pydantic_core==2.23.4 +pydeck==0.9.1 +pydub==0.25.1 +Pygments==2.18.0 +pyparsing==3.1.4 +python-dateutil==2.9.0.post0 +python-multipart==0.0.10 +pytz==2024.2 +PyYAML==6.0.2 +pyzmq==26.2.0 +referencing==0.35.1 +regex==2024.9.11 +requests==2.31.0 +rich==13.8.1 +rpds-py==0.20.0 +ruff==0.6.7 +safetensors==0.4.5 +scikit-learn==1.5.2 +scipy==1.14.1 +semantic-version==2.10.0 +setuptools==75.1.0 +shellingham==1.5.4 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +soundfile==0.12.1 +soxr==0.5.0.post1 +stack-data==0.6.3 +starlette==0.38.6 +streamlit==1.38.0 +sympy==1.13.3 +tenacity==8.5.0 +tensorboard==2.17.1 +tensorboard-data-server==0.7.2 +threadpoolctl==3.5.0 +tokenizers==0.19.1 +toml==0.10.2 +tomlkit==0.12.0 +torch==2.4.1 +torchvision==0.19.1 +tornado==6.4.1 +tqdm==4.66.5 +traitlets==5.14.3 +transformers==4.44.2 +triton==3.0.0 +typer==0.12.5 +typing_extensions==4.12.2 +tzdata==2024.2 +urllib3==2.2.3 +uvicorn==0.30.6 +watchdog==4.0.2 +wcwidth==0.2.13 +websockets==12.0 +Werkzeug==3.0.4 +xxhash==3.5.0 +yarl==1.12.1 +zipp==3.20.2 diff --git a/requirements.txt b/requirements.txt index 76690a2..1858231 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ accelerate torchvision transformers requests==2.31.0 +ipykernel diff --git a/scripts/train_unet.py b/scripts/train_unet.py index 319add7..8dba4a9 100644 --- a/scripts/train_unet.py +++ b/scripts/train_unet.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration from datasets import load_dataset, load_from_disk from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, UNet2DConditionModel, UNet2DModel from diffusers.optimization import get_scheduler @@ -40,11 +41,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir logging_dir = os.path.join(output_dir, args.logging_dir) + config = ProjectConfiguration(project_dir=".", logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with="tensorboard", - logging_dir=logging_dir, + project_config=config, ) if args.dataset_name is not None: diff --git a/setup.cfg b/setup.cfg index f45d8f2..9df908d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,3 +18,4 @@ install_requires = diffusers>=0.12.0 librosa datasets>=2.9.0 + accelerate