Skip to content

Commit 869fe6c

Browse files
committed
Add DeepSpeed Example with Pytorch Operator
Signed-off-by: Syulin7 <[email protected]>
1 parent c64a5a6 commit 869fe6c

File tree

6 files changed

+931
-0
lines changed

6 files changed

+931
-0
lines changed

.github/workflows/publish-example-images.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,7 @@ jobs:
6969
platforms: linux/amd64,linux/arm64
7070
dockerfile: examples/pytorch/mnist/Dockerfile-mpi
7171
context: examples/pytorch/mnist
72+
- component-name: pytorch-deepspeed-demo
73+
platforms: linux/amd64
74+
dockerfile: examples/pytorch/deepspeed-demo/Dockerfile
75+
context: examples/pytorch/deepspeed-demo
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
FROM deepspeed/deepspeed:v072_torch112_cu117
2+
3+
RUN apt update
4+
RUN apt install -y ninja-build
5+
6+
WORKDIR /
7+
COPY requirements.txt .
8+
COPY train_bert_ds.py .
9+
10+
RUN pip install -r requirements.txt
11+
RUN mkdir -p /root/deepspeed_data
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
## Training a Masked Language Model with PyTorch and DeepSpeed
2+
3+
This folder contains an example of training a Masked Language Model with PyTorch and DeepSpeed.
4+
5+
The python script used to train BERT with PyTorch and DeepSpeed. For more information, please refer to the [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/blob/master/training/HelloDeepSpeed/README.md).
6+
7+
DeepSpeed can be deployed by different launchers such as torchrun, the deepspeed launcher, or Accelerate.
8+
See [deepspeed](https://huggingface.co/docs/transformers/main/en/deepspeed?deploy=multi-GPU&pass-config=path+to+file&multinode=torchrun#deployment).
9+
10+
This guide will show you how to deploy DeepSpeed with the `torchrun` launcher.
11+
The simplest way to quickly reproduce the following is to switch to the DeepSpeedExamples commit:
12+
```shell
13+
git clone https://github.com/microsoft/DeepSpeedExamples.git
14+
cd DeepSpeedExamples
15+
git checkout efacebb
16+
```
17+
18+
The script train_bert_ds.py is located in the DeepSpeedExamples/HelloDeepSpeed/ directory.
19+
Since the script is not launched using the deepspeed launcher, it needs to read the local_rank from the environment.
20+
The following content has been added at line 670:
21+
```
22+
local_rank = int(os.getenv('LOCAL_RANK', '-1'))
23+
```
24+
25+
### Build Image
26+
27+
The default image name and tag is `kubeflow/pytorch-deepspeed-demo:latest`.
28+
29+
```shell
30+
docker build -f Dockerfile -t kubeflow/pytorch-deepspeed-demo:latest ./
31+
```
32+
33+
### Create the PyTorchJob with DeepSpeed example
34+
35+
```shell
36+
kubectl create -f pytorch_deepspeed_demo.yaml
37+
```
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
apiVersion: "kubeflow.org/v1"
2+
kind: PyTorchJob
3+
metadata:
4+
name: pytorch-deepspeed-demo
5+
spec:
6+
pytorchReplicaSpecs:
7+
Master:
8+
replicas: 1
9+
restartPolicy: OnFailure
10+
template:
11+
spec:
12+
containers:
13+
- name: pytorch
14+
image: kubeflow/pytorch-deepspeed-demo:latest
15+
command:
16+
- torchrun
17+
- --nnodes=2
18+
- --nproc_per_node=1
19+
- /train_bert_ds.py
20+
- --checkpoint_dir
21+
- /root/deepspeed_data
22+
resources:
23+
limits:
24+
nvidia.com/gpu: 1
25+
Worker:
26+
replicas: 1
27+
restartPolicy: OnFailure
28+
template:
29+
spec:
30+
containers:
31+
- name: pytorch
32+
image: kubeflow/pytorch-deepspeed-demo:latest
33+
command:
34+
- torchrun
35+
- --nnodes=2
36+
- --nproc_per_node=1
37+
- /train_bert_ds.py
38+
- --checkpoint_dir
39+
- /root/deepspeed_data
40+
resources:
41+
limits:
42+
nvidia.com/gpu: 1
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
datasets==1.13.3
2+
transformers==4.5.1
3+
fire==0.4.0
4+
pytz==2021.1
5+
loguru==0.5.3
6+
sh==1.14.2
7+
pytest==6.2.5
8+
tqdm==4.62.3

0 commit comments

Comments
 (0)