Improve throughput of computing embeddings with BetterTransformer #15

Open

wants to merge 6 commits into main

Conversation

oliverholworthy
Member

Improve throughput of computing embeddings with BetterTransformer in SentenceTransformerModel.

This improved throughput by about 6x for me when processing the FIQA BEIR dataset (57k documents).
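For reference, the conversion optimum provides is essentially a one-line swap of the wrapped Hugging Face module. A minimal sketch of the idea, assuming the first module of a plain sentence-transformers model exposes the underlying transformer as auto_model (the actual wiring in SentenceTransformerModel may differ):

from sentence_transformers import SentenceTransformer
from optimum.bettertransformer import BetterTransformer

# Load the sentence-transformers model and replace its wrapped Hugging Face
# module with the BetterTransformer version, which uses fused attention kernels.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cuda")
model[0].auto_model = BetterTransformer.transform(model[0].auto_model)

embeddings = model.encode(["Example sentence"], batch_size=256)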

@oliverholworthy oliverholworthy added the enhancement New feature or request label Oct 24, 2023
@oliverholworthy oliverholworthy self-assigned this Oct 24, 2023
@edknv
Contributor

edknv commented Oct 24, 2023

I tested this out (in the multi-gpu case), and I'm not seeing any improvements with BetterTransformer. I'm wondering if it's because of the warnings I'm seeing in the logs:

The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.

I'm running

import time

import crossfit as cf  # assumed import: the snippet references `cf` without showing its imports

if __name__ == "__main__":
    torch_mem = 40
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    dataset = "quora"

    start = time.time()

    model = cf.SentenceTransformerModel(model_name, max_mem_gb=torch_mem)

    with cf.Distributed(rmm_pool_size=f"{torch_mem}GB", n_workers=2):

        cf.embed(
            dataset,
            model=model,
            vector_search=False,
            sorted_data_loader=True,
            overwrite=True,
        )

    print("total time", time.time() - start)

With BetterTransformer (this PR):

Embedding quora item (10 parts)...
The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.
The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.
GPU: 1, Part: 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:04<00:00, 12607.24it/s]
GPU: 0, Part: 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:04<00:00, 11390.25it/s]
GPU: 1, Part: 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:03<00:00, 14565.94it/s]
GPU: 0, Part: 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:04<00:00, 12687.12it/s]
GPU: 1, Part: 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:04<00:00, 12502.03it/s]
GPU: 0, Part: 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:03<00:00, 13335.50it/s]
GPU: 1, Part: 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:03<00:00, 15409.16it/s]
GPU: 0, Part: 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:03<00:00, 14536.23it/s]
GPU: 1, Part: 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:03<00:00, 15349.98it/s]
GPU: 0, Part: 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52294/52294 [00:03<00:00, 15842.22it/s]
total time 37.14634895324707

Without BetterTransformer (main):

Embedding quora item (10 parts)...
GPU: 1, Part: 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 21401.21it/s]
GPU: 0, Part: 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 20236.92it/s]
GPU: 1, Part: 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 21444.10it/s]
GPU: 0, Part: 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 19240.12it/s]
GPU: 1, Part: 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 21451.30it/s]
GPU: 0, Part: 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 20436.06it/s]
GPU: 1, Part: 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 20392.74it/s]
GPU: 0, Part: 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:02<00:00, 20567.16it/s]
GPU: 0, Part: 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52294/52294 [00:02<00:00, 20359.12it/s]
GPU: 1, Part: 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 52293/52293 [00:03<00:00, 16523.75it/s]
total time 30.128920316696167

@oliverholworthy
Member Author

not seeing any improvements with BetterTransformer

That could be because the Flash Attention kernel is not available in your environment.

Try running the following, which explicitly activates the flash attention kernel. If you get the error RuntimeError: No available kernel, you might need to update your environment to a more recent version of torch.

import torch
from transformers import AutoModel, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").to("cuda")
# convert the model to BetterTransformer
model = BetterTransformer.transform(model.to(torch.float16))

input_text = "Example sentence"
inputs = {k: v.cuda() for k, v in tokenizer(input_text, return_tensors="pt").items()}

# Force the flash attention kernel only, so a missing kernel surfaces as
# "RuntimeError: No available kernel" instead of silently falling back.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model(**inputs)

@edknv
Contributor

edknv commented Oct 24, 2023

Try running the following, which explicitly activates the flash attention kernel. If you get the error RuntimeError: No available kernel, you might need to update your environment to a more recent version of torch.

I didn't get any errors, so I assume the flash attention kernel is available. I'm running everything in nvcr.io/nvidia/pytorch:23.09-py3.

@oliverholworthy
Member Author

Hmm, I've been trying things out with intfloat/e5-large-unsupervised, which is a larger model. I wonder why this doesn't make much of a difference for smaller models like sentence-transformers/all-MiniLM-L6-v2.

@edknv
Contributor

edknv commented Oct 25, 2023

Hmm, I've been trying things out with intfloat/e5-large-unsupervised, which is a larger model. I wonder why this doesn't make much of a difference for smaller models like sentence-transformers/all-MiniLM-L6-v2.

Yeah, it looks like the model size has something to do with it. I do see ~2x improvements with large models like e5-large-unsupervised, but with e5-small-unsupervised it's (slightly) worse than the baseline. Just guessing, but maybe smaller models -> fewer attention layers -> minimal boost, vs. larger models -> more attention layers -> bigger boost? And maybe for the smaller models, the overhead of model conversion, etc. eats up any performance gain.
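A quick way to sanity-check that hypothesis outside the embedding pipeline is a micro-benchmark of a single forward pass, using the same transformers/optimum APIs as the snippet above (model names and batch size are illustrative):

import time

import torch
from transformers import AutoModel, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

def bench(model, inputs, n_iters=50):
    # Warm up, then time n_iters forward passes with the GPU synchronized.
    with torch.no_grad():
        for _ in range(5):
            model(**inputs)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_iters):
            model(**inputs)
        torch.cuda.synchronize()
    return (time.time() - start) / n_iters

for name in ["intfloat/e5-small-unsupervised", "intfloat/e5-large-unsupervised"]:
    tokenizer = AutoTokenizer.from_pretrained(name)
    inputs = tokenizer(["Example sentence"] * 256, padding=True, return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}

    baseline = AutoModel.from_pretrained(name).to("cuda", torch.float16).eval()
    converted = BetterTransformer.transform(
        AutoModel.from_pretrained(name).to("cuda", torch.float16).eval()
    )
    print(name, "baseline:", bench(baseline, inputs), "bettertransformer:", bench(converted, inputs))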

@marcromeyn
Contributor

This confirms my suspicion that which serving mechanism is most efficient depends on various properties of the particular model. I wonder if we could implement some functionality similar to how we tried to estimate the memory consumption of a model by calling it with various batch/seq-len combinations: take the model and some synthetic data, record the latency of all the batch-prediction techniques we offer, and pick the fastest.
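A rough sketch of what that could look like (a hypothetical helper, not an existing crossfit API): run each candidate batch-prediction callable on a synthetic batch, record its average latency, and keep the fastest.

import time

import torch

def pick_fastest(candidates, synthetic_batch, n_iters=20):
    """candidates: mapping of name -> callable running one forward pass on a batch."""
    timings = {}
    for name, forward in candidates.items():
        for _ in range(3):  # warm-up iterations
            forward(synthetic_batch)
        torch.cuda.synchronize()  # assumes the candidates run on a GPU
        start = time.time()
        for _ in range(n_iters):
            forward(synthetic_batch)
        torch.cuda.synchronize()
        timings[name] = (time.time() - start) / n_iters
    best = min(timings, key=timings.get)
    return best, timings

# Example usage (the callables and the synthetic batch are supplied by the caller):
# best, timings = pick_fastest({"baseline": plain_forward, "bt": bt_forward}, batch)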
