Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
PVirie committed Jan 11, 2025
2 parents e604095 + b9bf4bd commit 7dfa095
Show file tree
Hide file tree
Showing 16 changed files with 927 additions and 578 deletions.
40 changes: 36 additions & 4 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
],
"connect": {
"host": "127.0.0.1",
"port": 5678
"port": 43690
}
},
{
Expand Down Expand Up @@ -53,7 +53,7 @@
],
"connect": {
"host": "127.0.0.1",
"port": 5678
"port": 43690
}
},
{
Expand Down Expand Up @@ -81,7 +81,39 @@
],
"connect": {
"host": "127.0.0.1",
"port": 5678
"port": 43690
}
},
{
"name": "torch-gpu",
"type": "debugpy",
"request": "attach",
"preLaunchTask": "Current file: torch-gpu",
"pathMappings": [
{
"localRoot": "${workspaceFolder}/core",
"remoteRoot": "/app/core"
},
{
"localRoot": "${workspaceFolder}/jax_onehot",
"remoteRoot": "/app/jax_onehot"
},
{
"localRoot": "${workspaceFolder}/llm",
"remoteRoot": "/app/llm"
},
{
"localRoot": "${workspaceFolder}/humn",
"remoteRoot": "/app/humn"
},
{
"localRoot": "${workspaceFolder}/tasks",
"remoteRoot": "/app/tasks"
}
],
"connect": {
"host": "127.0.0.1",
"port": 43690
}
},
{
Expand Down Expand Up @@ -113,7 +145,7 @@
],
"connect": {
"host": "127.0.0.1",
"port": 5678
"port": 43690
}
}
]
Expand Down
38 changes: 34 additions & 4 deletions .vscode/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
{
"label": "Current file: clear (internal)",
"type": "shell",
"command": "docker compose -f docker_compose.yaml --profile jax-cpu run --rm -d --build --service-ports jax-cpu-service python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${relativeFile} --clear",
"command": "docker compose -f docker_compose.yaml --profile jax-cpu run --rm -d --build --service-ports jax-cpu-service python3 -m debugpy --listen 0.0.0.0:43690 --wait-for-client ${relativeFile} --clear",
"dependsOn": ["docker-compose-stop: cpu"],
"isBackground": true,
"problemMatcher": [
Expand All @@ -34,7 +34,7 @@
{
"label": "Current file: cpu (internal)",
"type": "shell",
"command": "docker compose -f docker_compose.yaml --profile jax-cpu run --rm -d --build --service-ports jax-cpu-service python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${relativeFile}",
"command": "docker compose -f docker_compose.yaml --profile jax-cpu run --rm -d --build --service-ports jax-cpu-service python3 -m debugpy --listen 0.0.0.0:43690 --wait-for-client ${relativeFile}",
"dependsOn": ["docker-compose-stop: cpu"],
"isBackground": true,
"problemMatcher": [
Expand Down Expand Up @@ -64,7 +64,7 @@
{
"label": "Current file: gpu (internal)",
"type": "shell",
"command": "docker compose -f docker_compose.yaml --profile jax-gpu run --rm -d --build --service-ports jax-gpu-service python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${relativeFile}",
"command": "docker compose -f docker_compose.yaml --profile jax-gpu run --rm -d --build --service-ports jax-gpu-service python3 -m debugpy --listen 0.0.0.0:43690 --wait-for-client ${relativeFile}",
"dependsOn": ["docker-compose-stop: gpu"],
"isBackground": true,
"problemMatcher": [
Expand Down Expand Up @@ -94,7 +94,7 @@
{
"label": "Current file: torch-cpu (internal)",
"type": "shell",
"command": "docker compose -f docker_compose.yaml --profile torch-cpu run --rm -d --build --service-ports torch-cpu-service python3 -m debugpy --listen 0.0.0.0:5678 --wait-for-client ${relativeFile}",
"command": "docker compose -f docker_compose.yaml --profile torch-cpu run --rm -d --build --service-ports torch-cpu-service python3 -m debugpy --listen 0.0.0.0:43690 --wait-for-client ${relativeFile}",
"dependsOn": ["docker-compose-stop: torch-cpu"],
"isBackground": true,
"problemMatcher": [
Expand All @@ -114,6 +114,36 @@
"command": "sleep 3",
"dependsOn": ["Current file: torch-cpu (internal)"],
"isBackground": false
},
{
"label": "docker-compose-stop: torch-gpu",
"type": "shell",
"command": "docker compose -f docker_compose.yaml down",
"isBackground": true
},
{
"label": "Current file: torch-gpu (internal)",
"type": "shell",
"command": "docker compose -f docker_compose.yaml --profile torch-gpu run --rm -d --build --service-ports torch-gpu-service python3 -m debugpy --listen 0.0.0.0:43690 --wait-for-client ${relativeFile}",
"dependsOn": ["docker-compose-stop: torch-gpu"],
"isBackground": true,
"problemMatcher": [
{
"pattern": [{ "regexp": ".", "file": 1, "location": 2, "message": 3 }],
"background": {
"activeOnStart": true,
"beginsPattern": "^(Building py-service)$",
"endsPattern": "^(Creating|Recreating|Starting) (py-container) ... (done)$"
}
}
]
},
{
"label": "Current file: torch-gpu",
"type": "shell",
"command": "sleep 3",
"dependsOn": ["Current file: torch-gpu (internal)"],
"isBackground": false
}
]
}
20 changes: 0 additions & 20 deletions Dockerfile

This file was deleted.

9 changes: 5 additions & 4 deletions Dockerfile.nogpu → Dockerfile.jax-cpu
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
FROM python:3.10
FROM ubuntu:24.04

RUN apt update && apt install python3-pip git python3-venv -y
RUN apt install -y libgl1-mesa-glx libosmesa6
RUN apt update
RUN apt install -y python3 python3-pip git python3-venv
RUN apt install -y libosmesa6-dev

# create virtual environment, the correct way https://pythonspeed.com/articles/activate-virtualenv-dockerfile/
ENV VIRTUAL_ENV=/app/venv
Expand All @@ -16,4 +17,4 @@ RUN pip3 install --upgrade jax[cpu]==0.4.31 flax clu ott-jax
COPY ./requirements.txt /app/requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:5678", "tasks/benchmark.py"]
CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:43690", "tasks/benchmark.py"]
20 changes: 20 additions & 0 deletions Dockerfile.jax-gpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
FROM ubuntu:24.04

# Combine update+install in ONE layer (a standalone `apt update` layer goes
# stale in cache), use apt-get (apt's CLI is not script-stable), skip
# recommended packages, and remove the list cache in the same layer so it
# never persists in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        git \
        libosmesa6-dev \
        python3 \
        python3-pip \
        python3-venv \
    && rm -rf /var/lib/apt/lists/*

# create virtual environment, the correct way https://pythonspeed.com/articles/activate-virtualenv-dockerfile/
ENV VIRTUAL_ENV=/app/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

WORKDIR /app

# install jax with cuda support; the extra is quoted so the shell cannot
# treat "[cuda12]" as a glob pattern, and --no-cache-dir keeps the pip
# download cache out of the layer.
RUN pip3 install --no-cache-dir --upgrade pip
RUN pip3 install --no-cache-dir --upgrade "jax[cuda12]==0.4.38" flax ott-jax
# Copy only the manifest before installing so this dependency layer stays
# cached when unrelated files change.
COPY ./requirements.txt /app/requirements.txt
RUN pip3 install --no-cache-dir --upgrade -r requirements.txt

# debugpy listens on 43690 so the .vscode/launch.json attach configs can connect.
CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:43690", "tasks/benchmark.py"]
7 changes: 4 additions & 3 deletions Dockerfile.pure_python → Dockerfile.pure-python
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
FROM python:3.10
FROM ubuntu:24.04

RUN apt update && apt install python3-pip git python3-venv -y
RUN apt update
RUN apt install -y python3 python3-pip git python3-venv

# create virtual environment, the correct way https://pythonspeed.com/articles/activate-virtualenv-dockerfile/
ENV VIRTUAL_ENV=/app/venv
Expand All @@ -14,4 +15,4 @@ RUN pip3 install --upgrade pip
COPY ./requirements.txt /app/requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:5678", "tasks/benchmark.py"]
CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:43690", "tasks/benchmark.py"]
7 changes: 4 additions & 3 deletions Dockerfile.nogpu_torch → Dockerfile.torch-cpu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
FROM python:3.10
FROM ubuntu:24.04

RUN apt update && apt install python3-pip git python3-venv -y
RUN apt update
RUN apt install -y python3 python3-pip git python3-venv

# create virtual environment, the correct way https://pythonspeed.com/articles/activate-virtualenv-dockerfile/
ENV VIRTUAL_ENV=/app/venv
Expand All @@ -16,4 +17,4 @@ RUN pip3 install openai sentencepiece transformers==4.44.2
COPY ./requirements.txt /app/requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:5678", "tasks/benchmark.py"]
CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:43690", "tasks/benchmark.py"]
20 changes: 20 additions & 0 deletions Dockerfile.torch-gpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
FROM ubuntu:24.04

# Combine update+install in ONE layer (a standalone `apt update` layer goes
# stale in cache), use apt-get (apt's CLI is not script-stable), skip
# recommended packages, and remove the list cache in the same layer so it
# never persists in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
        git \
        python3 \
        python3-pip \
        python3-venv \
    && rm -rf /var/lib/apt/lists/*

# create virtual environment, the correct way https://pythonspeed.com/articles/activate-virtualenv-dockerfile/
ENV VIRTUAL_ENV=/app/venv
RUN python3 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

WORKDIR /app

# install torch with cuda support; --no-cache-dir keeps the (very large)
# pip download cache out of the image layers.
RUN pip3 install --no-cache-dir --upgrade pip
# NOTE(review): torch/torchvision/torchaudio are unpinned, so this build is
# not reproducible — consider pinning exact versions like transformers below.
RUN pip3 install --no-cache-dir torch torchvision torchaudio
RUN pip3 install --no-cache-dir openai sentencepiece transformers==4.44.2
# Copy only the manifest before installing so this dependency layer stays
# cached when unrelated files change.
COPY ./requirements.txt /app/requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt

# debugpy listens on 43690 so the .vscode/launch.json attach configs can connect.
CMD ["python3", "-m", "debugpy", "--listen", "0.0.0.0:43690", "tasks/benchmark.py"]
25 changes: 14 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,21 @@ If you want to purely install the python code, you can follow the steps in the d
- launch `jax-cpu` for jax in cpu environment.
- launch `jax-gpu` for jax in gpu environment.
- launch `torch-cpu` for torch in cpu environment.
- launch `torch-gpu` for torch in gpu environment.
- Running on Windows
- The relative path in Windows that passes to docker has invalid path separators. _Always use POSIX path separators_ when passing `{path to file}` parameter when running `run_manual.sh` script. Or simply create a new configuration in `.vscode/launch.json` with the hard coded configuration you wish to run with the POSIX path separators.
| Experiment | Task | Description | Valid configurations (pick one) | File (--flags) | Required env vars |
| ------------------ | ------------------ | -------------------------------------------------------------------------------------- | ------------------------------- | ------------------------------- | ---------------------------- |
| **Simple graph** | Train and test | Train the model to learn simple graph tasks. | `jax-gpu`, `jax-cpu` | `tasks/simple.py` | - |
| | Clear weight | Clear the weight in the model. (Or simply delete the weight direction in `./artifacts` | `jax-gpu`, `jax-cpu` | `tasks/simple.py --clear` | - |
| **RL: cart pole** | Train and test | Train the model to learn to control the cart pole. | `jax-gpu`, `jax-cpu` | `tasks/rl_cart_pole.py` | - |
| | Clear weight | Clear the weight in the model. (Or simply delete the weight direction in `./artifacts` | `jax-gpu`, `jax-cpu` | `tasks/rl_cart_pole.py --clear` | - |
| **Language model** | Prepare | Prepare data for the the language model hierarchical guide model. | `torch-cpu` | `tasks/lm_data_prepare.py` | `HF_TOKEN`, `OPENAI_API_KEY` |
| | Train hierarchy | Train the language model hierarchical guide model. | `jax-gpu`, `jax-cpu` | `tasks/lm_guide_train.py` | - |
| | Generate hierarchy | Generate the language model hierarchical guide model. | `jax-gpu`, `jax-cpu` | `tasks/lm_guide_inference.py` | - |
| | Interpret | Given the hierarchy guide, print out the text generation. | `torch-cpu` | `tasks/lm_data_interpret` | `HF_TOKEN` |
| Experiment | Task | Description | Valid configurations (pick one) | File (--flags) | Required env vars |
| ------------------ | ------------------ | -------------------------------------------------------------------------------------- | ---------------------------------------------- | ------------------------------- | ---------------------------- |
| **Benchmark** | Benchmark devices | Run the benchmark to compare the performance of the devices. | `jax-gpu`, `jax-cpu`, `torch-cpu`, `torch-gpu` | `tasks/benchmark.py` | - |
| **Simple graph** | Train and test | Train the model to learn simple graph tasks. | `jax-gpu`, `jax-cpu` | `tasks/simple.py` | - |
|                    | Clear weight       | Clear the weight in the model. (Or simply delete the weight directory in `./artifacts`) | `jax-gpu`, `jax-cpu`                           | `tasks/simple.py --clear`       | -                            |
| **RL: cart pole** | Train and test | Train the model to learn to control the cart pole. | `jax-gpu`, `jax-cpu` | `tasks/rl_cart_pole.py` | - |
|                    | Clear weight       | Clear the weight in the model. (Or simply delete the weight directory in `./artifacts`) | `jax-gpu`, `jax-cpu`                           | `tasks/rl_cart_pole.py --clear` | -                            |
| **Language model** | Prepare            | Prepare data for the language model hierarchical guide model.                          | `torch-cpu`                                    | `tasks/lm_data_prepare.py`      | `HF_TOKEN`, `OPENAI_API_KEY` |
| | Train hierarchy | Train the language model hierarchical guide model. | `jax-gpu`, `jax-cpu` | `tasks/lm_guide_train.py` | - |
| | Generate hierarchy | Generate the language model hierarchical guide model. | `jax-gpu`, `jax-cpu` | `tasks/lm_guide_inference.py` | - |
| | Interpret | Given the hierarchy guide, print out the text generation. | `torch-cpu` | `tasks/lm_data_interpret` | `HF_TOKEN` |
## To do
Expand Down Expand Up @@ -101,5 +103,6 @@ If you want to purely install the python code, you can follow the steps in the d
### Code
- [ ] Interruptible training
- [x] Interruptible training
- [x] Torch GPU
- [ ] Use flax nnx
Loading

0 comments on commit 7dfa095

Please sign in to comment.