
Jaxlib error due to convolutions #2576

Answered by cgarciae
rdilip asked this question in Q&A

This is tricky to debug and should probably go to the JAX repo, but if it helps, this is the Dockerfile I used recently:

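```dockerfile
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04

# Drop the CUDA/NVIDIA-ML apt sources so apt-get update only uses the Ubuntu mirrors
RUN rm -f /etc/apt/sources.list.d/cuda.list && rm -f /etc/apt/sources.list.d/nvidia-ml.list
# Update and install in one layer so a stale cached index is never reused;
# noninteractive mode stops packages such as tzdata from prompting mid-build
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y python3 python3-pip python-is-python3 python3-tk git
# Quote the extra so the shell never treats [cuda] as a glob pattern
RUN pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# test runtime (docker build has no GPU access, so this may list only CPU devices;
# repeat the check with `docker run --gpus all ...` to see the GPU backend)
RUN python -c "import jax; print(jax.devices())"
```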
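Since the error in the title is about convolutions, `jax.devices()` alone may not trigger it. A minimal sketch, assuming `jax` and `numpy` are installed in the container: the hypothetical helper `conv_smoke_test` (not part of the original answer or of JAX) runs one small 2-D convolution with `jax.lax.conv_general_dilated` on the default backend and compares it with a plain NumPy loop, so a broken jaxlib/cuDNN convolution path shows up either as an exception or as a large error.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax


def conv_smoke_test(seed: int = 0) -> float:
    """Hypothetical helper: run one small 2-D convolution on JAX's default
    backend and return the max absolute difference from a NumPy reference."""
    rng = np.random.default_rng(seed)
    # NCHW input (batch=1, channels=1, 5x5) and OIHW kernel (1 out, 1 in, 3x3).
    x = rng.standard_normal((1, 1, 5, 5)).astype(np.float32)
    w = rng.standard_normal((1, 1, 3, 3)).astype(np.float32)

    # On a GPU backend XLA usually lowers this to cuDNN, which is where
    # convolution-specific jaxlib errors tend to surface.
    y = lax.conv_general_dilated(
        jnp.asarray(x),
        jnp.asarray(w),
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=lax.Precision.HIGHEST,  # keep float32 accuracy for the comparison
    )
    y = np.asarray(y)  # copies to host and waits for the computation

    # NumPy reference: XLA convolution is a cross-correlation (kernel not flipped).
    ref = np.zeros((1, 1, 3, 3), dtype=np.float32)
    for i in range(3):
        for j in range(3):
            ref[0, 0, i, j] = np.sum(x[0, 0, i:i + 3, j:j + 3] * w[0, 0])
    return float(np.max(np.abs(y - ref)))


if __name__ == "__main__":
    print("devices:", jax.devices())
    print("max abs error:", conv_smoke_test())
```

To exercise the GPU path, run it inside the container with GPU access (for example via `docker run --gpus all`) and check that `jax.devices()` lists a GPU device; a build-time run will only test the CPU backend.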

Answer selected by rdilip