@@ -10,8 +10,10 @@ ARG PYTORCH_BRANCH="3a585126"
10
10
ARG PYTORCH_VISION_BRANCH="v0.19.1"
11
11
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
12
12
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
13
- ARG FA_BRANCH="b7d29fb"
14
- ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
13
+ ARG FA_BRANCH="1a7f4dfa"
14
+ ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
15
+ ARG AITER_BRANCH="0508c8df"
16
+ ARG AITER_REPO="https://github.com/ROCm/aiter.git"
15
17
16
18
FROM ${BASE_IMAGE} AS base
17
19
@@ -108,11 +110,26 @@ RUN git clone ${FA_REPO}
108
110
RUN cd flash-attention \
109
111
&& git checkout ${FA_BRANCH} \
110
112
&& git submodule update --init \
111
- && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
113
+ && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
112
114
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
113
115
&& cp /app/vision/dist/*.whl /app/install \
114
116
&& cp /app/flash-attention/dist/*.whl /app/install
115
117
118
+ FROM base AS build_aiter
119
+ ARG AITER_BRANCH
120
+ ARG AITER_REPO
121
+ COPY requirements-rocm.txt /app
122
+ COPY requirements-common.txt /app
123
+ RUN pip install -r requirements-rocm.txt
124
+ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
125
+ pip install /install/*.whl
126
+ RUN git clone --recursive ${AITER_REPO}
127
+ RUN cd aiter \
128
+ && git checkout ${AITER_BRANCH} \
129
+ && pip install -r requirements.txt \
130
+ && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist
131
+ RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
132
+
116
133
FROM base AS final
117
134
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
118
135
dpkg -i /install/*deb \
@@ -128,6 +145,8 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
128
145
pip install /install/*.whl
129
146
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
130
147
pip install /install/*.whl
148
+ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
149
+ pip install /install/*.whl
131
150
132
151
ARG BASE_IMAGE
133
152
ARG HIPBLASLT_BRANCH
@@ -155,4 +174,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
155
174
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
156
175
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
157
176
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
158
- && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
177
+ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
178
+ && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
0 commit comments