Skip to content

Commit

Permalink
update to latest accelerate
Browse files Browse the repository at this point in the history
  • Loading branch information
teticio committed Sep 25, 2024
1 parent 15823d7 commit 5d6d837
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ taming
checkpoints
Pipfile
Pipfile.lock
.venv
19 changes: 19 additions & 0 deletions config/accelerate_multi_gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
downcast_bf16: 'no'
dynamo_config: {}
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
159 changes: 159 additions & 0 deletions requirements-lock.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
absl-py==2.1.0
accelerate==0.34.2
aiofiles==23.2.1
aiohappyeyeballs==2.4.0
aiohttp==3.10.6
aiosignal==1.3.1
altair==5.4.1
annotated-types==0.7.0
anyio==4.6.0
asttokens==2.4.1
attrs==24.2.0
audioread==3.0.1
blinker==1.8.2
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
datasets==2.19.1
debugpy==1.8.6
decorator==5.1.1
diffusers==0.24.0
dill==0.3.8
executing==2.1.0
fastapi==0.115.0
ffmpy==0.4.0
filelock==3.16.1
fonttools==4.54.1
frozenlist==1.4.1
fsspec==2024.3.1
gitdb==4.0.11
GitPython==3.1.43
gradio==4.44.0
gradio_client==1.3.0
grpcio==1.66.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
huggingface-hub==0.25.1
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.4.5
ipykernel==6.29.5
ipython==8.27.0
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.3
jupyter_core==5.7.2
kiwisolver==1.4.7
lazy_loader==0.4
librosa==0.10.2.post1
llvmlite==0.43.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
narwhals==1.8.3
nest-asyncio==1.6.0
networkx==3.3
numba==0.60.0
numpy==2.0.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
orjson==3.10.7
packaging==24.1
pandas==2.2.3
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.3.6
pooch==1.8.2
prompt_toolkit==3.0.48
protobuf==5.28.2
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pyarrow==17.0.0
pyarrow-hotfix==0.6
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.4
python-dateutil==2.9.0.post0
python-multipart==0.0.10
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
referencing==0.35.1
regex==2024.9.11
requests==2.31.0
rich==13.8.1
rpds-py==0.20.0
ruff==0.6.7
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
semantic-version==2.10.0
setuptools==75.1.0
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
soundfile==0.12.1
soxr==0.5.0.post1
stack-data==0.6.3
starlette==0.38.6
streamlit==1.38.0
sympy==1.13.3
tenacity==8.5.0
tensorboard==2.17.1
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
tokenizers==0.19.1
toml==0.10.2
tomlkit==0.12.0
torch==2.4.1
torchvision==0.19.1
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.44.2
triton==3.0.0
typer==0.12.5
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
uvicorn==0.30.6
watchdog==4.0.2
wcwidth==0.2.13
websockets==12.0
Werkzeug==3.0.4
xxhash==3.5.0
yarl==1.12.1
zipp==3.20.2
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ accelerate
torchvision
transformers
requests==2.31.0
ipykernel
4 changes: 3 additions & 1 deletion scripts/train_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset, load_from_disk
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.optimization import get_scheduler
Expand Down Expand Up @@ -40,11 +41,12 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args):
output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
logging_dir = os.path.join(output_dir, args.logging_dir)
config = ProjectConfiguration(project_dir=".", logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
project_config=config,
)

if args.dataset_name is not None:
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ install_requires =
diffusers>=0.12.0
librosa
datasets>=2.9.0
accelerate

0 comments on commit 5d6d837

Please sign in to comment.