feat: add support for florence2 #4383

Draft · wants to merge 3 commits into base: main
12 changes: 10 additions & 2 deletions examples/models/core/enc_dec/convert_checkpoint.py
@@ -13,7 +13,7 @@
import safetensors
from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding,
fuse_qkv_one_layer, reshape, split)
from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
from transformers import (AutoModelForSeq2SeqLM, AutoModelForCausalLM, Blip2ForConditionalGeneration,
MBartForConditionalGeneration,
Pix2StructForConditionalGeneration,
T5ForConditionalGeneration, VisionEncoderDecoderModel)
@@ -966,6 +966,8 @@ def get_attn_module_name(component, layer, attn_type):

return weights

convert_florence2_weights_to_tllm_safetensors = convert_bart_weights_to_tllm_safetensors  # alias: Florence-2's language model is BART


def parse_pix2struct_config(args, hf_model):
# manually set q_scaling to offset attention scaling's effect.
@@ -1487,6 +1489,11 @@ def get_model(args):
elif args.model_type == "blip2":
model = Blip2ForConditionalGeneration.from_pretrained(
args.model_dir).language_model
elif args.model_type == "florence2":
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
trust_remote_code=True,
).language_model
elif args.model_type == "language_adapter":
import torch

@@ -1522,6 +1529,7 @@ def convert_checkpoint(args):
quant_algo = None

model_type = args.model_type if args.model_type != "blip2" else "t5"
model_type = model_type if model_type != "florence2" else "bart"
encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
args, model)

@@ -1705,7 +1713,7 @@ def convert(worker_rank, world_size, args, model_config, convert_args,
type=str,
default='t5',
choices=[
't5', 'nmt', 'bart', 'pix2struct', 'blip2', 'language_adapter'
't5', 'nmt', 'bart', 'florence2', 'pix2struct', 'blip2', 'language_adapter'
],
help=
'Multimodal type when this script is used for multimodal conversion.')
7 changes: 6 additions & 1 deletion examples/models/core/enc_dec/helper.py
@@ -74,7 +74,12 @@ def get_qkv_module_name(model_type):
q = "q"
k = "k"
v = "v"
elif model_type == "bart" or model_type == "nmt" or model_type == "language_adapter":
elif (
model_type == "bart"
or model_type == "florence2"
or model_type == "nmt"
or model_type == "language_adapter"
):
q = "q_proj"
k = "k_proj"
v = "v_proj"
81 changes: 81 additions & 0 deletions examples/models/core/multimodal/README.md
@@ -1132,6 +1132,87 @@ pip install -r requirements-qwen2vl.txt

Note: use `--run_profiling` for performance measurement, use `--check_accuracy` for accuracy check.

## Florence-2

[Florence-2](https://huggingface.co/microsoft/Florence-2-large) is a powerful vision foundation model designed to perform a wide range of vision and vision-language tasks using simple text prompts. Built on a sequence-to-sequence architecture, it combines the BART language model with the DaViT vision encoder to understand and generate responses for tasks like image captioning, object detection, and segmentation. Trained on the large-scale FLD-5B dataset (5.4 billion annotations across 126 million images), Florence-2 excels in both zero-shot and fine-tuned scenarios, demonstrating strong multi-task learning capabilities.
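
Before converting, it can help to sanity-check the Hugging Face checkpoint directly with `transformers`. The snippet below is a minimal sketch, not part of the official example: the local path and image are placeholders, and Florence-2's Hub code requires `trust_remote_code=True`.

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_dir = "tmp/hf_models/florence-2-large"  # placeholder: local HF checkpoint
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.float16, trust_remote_code=True).to("cuda")

image = Image.open("example.jpg").convert("RGB")  # placeholder: any RGB image
inputs = processor(text="<OD>", images=image,
                   return_tensors="pt").to("cuda", torch.float16)
generated_ids = model.generate(input_ids=inputs["input_ids"],
                               pixel_values=inputs["pixel_values"],
                               max_new_tokens=30)
print(processor.batch_decode(generated_ids, skip_special_tokens=False)[0])
```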

1. Download the Hugging Face weights and convert the original checkpoint to the TRT-LLM checkpoint format, following the example in `examples/models/core/enc_dec/README.md`.

```bash
export MODEL_NAME="florence-2-large"
export MODEL_TYPE="florence2"
export INFERENCE_PRECISION="float16"
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export BATCH_SIZE=32
export MAX_BEAM_WIDTH=1
export NUM_VISUAL_FEATURES=577

python ../enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \
--model_dir tmp/hf_models/${MODEL_NAME} \
--output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \
--tp_size ${TP_SIZE} \
--pp_size ${PP_SIZE} \
--dtype ${INFERENCE_PRECISION}
```

2. Build TRT-LLM engine from TRT-LLM checkpoint

```bash
trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \
--output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \
--kv_cache_type disabled \
--moe_plugin disable \
--max_beam_width ${MAX_BEAM_WIDTH} \
--max_batch_size ${BATCH_SIZE} \
--max_input_len 1024 \
--max_prompt_embedding_table_size $((NUM_VISUAL_FEATURES * BATCH_SIZE)) \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--bert_context_fmha_fp32_acc enable \
--gpt_attention_plugin ${INFERENCE_PRECISION} \
--remove_input_padding enable

trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \
--output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \
--kv_cache_type paged \
--moe_plugin disable \
--max_beam_width ${MAX_BEAM_WIDTH} \
--max_batch_size ${BATCH_SIZE} \
--max_input_len 1 \
--max_seq_len 1024 \
--max_encoder_input_len $((1024 * BATCH_SIZE)) \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--gpt_attention_plugin ${INFERENCE_PRECISION} \
--remove_input_padding enable
```
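
For reference, with the settings above `--max_prompt_embedding_table_size` works out to `577 * 32 = 18464`, i.e. one slot per visual feature for every image in the batch (assuming 577 visual tokens per image, as `NUM_VISUAL_FEATURES` suggests), and `--max_encoder_input_len` to `1024 * 32 = 32768`.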

3. Build the TensorRT engine for the vision encoder

```bash
python build_multimodal_engine.py \
    --model_type ${MODEL_TYPE} \
    --model_path tmp/hf_models/${MODEL_NAME} \
    --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision \
    --max_batch_size ${BATCH_SIZE}
```

The built engine is located in `tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision`.

To run the Florence-2 pipeline with batch size > 8, adjust the `--max_batch_size` argument passed to `build_multimodal_engine.py` accordingly.

4. Assemble everything into the Florence-2 pipeline

```bash
python run.py \
--max_new_tokens 30 \
--input_text "<OD>" \
--hf_model_dir tmp/hf_models/${MODEL_NAME} \
--engine_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}
```
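
The decoded output for `<OD>` is a raw task-token string. To turn it into boxes and labels you can reuse Florence-2's Hub processor; the sketch below assumes the `post_process_generation` helper shown on the model card and uses placeholder values for the image and decoded text.

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("tmp/hf_models/florence-2-large",
                                          trust_remote_code=True)
image = Image.open("example.jpg")   # placeholder: the image passed to run.py
generated_text = "<OD>..."          # placeholder: decoded output of the pipeline

# Expected to return something like {"<OD>": {"bboxes": [...], "labels": [...]}}.
parsed = processor.post_process_generation(generated_text,
                                           task="<OD>",
                                           image_size=(image.width, image.height))
print(parsed)
```
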
## Dataset Evaluation

This section explains how to evaluate datasets using our provided script, including supported models and configurations.
53 changes: 53 additions & 0 deletions examples/models/core/multimodal/build.sh
@@ -0,0 +1,53 @@
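# Note (assumption): run this script from examples/models/core/multimodal so that
# ../enc_dec/convert_checkpoint.py and build_multimodal_engine.py resolve, with the
# Hugging Face checkpoint already downloaded to /workspace/models/hf_models/${MODEL_NAME}.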
export MODEL_NAME="florence-2-large-ft"
export MODEL_TYPE="florence2"
export INFERENCE_PRECISION="float16"
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export BATCH_SIZE=32
export MAX_BEAM_WIDTH=1
export NUM_VISUAL_FEATURES=577

python ../enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \
--model_dir /workspace/models/hf_models/${MODEL_NAME} \
--output_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \
--tp_size ${TP_SIZE} \
--pp_size ${PP_SIZE} \
--dtype ${INFERENCE_PRECISION} \
--workers 1

trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \
--output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \
--kv_cache_type disabled \
--moe_plugin disable \
--max_beam_width ${MAX_BEAM_WIDTH} \
--max_batch_size ${BATCH_SIZE} \
--max_input_len 1024 \
--max_prompt_embedding_table_size $((NUM_VISUAL_FEATURES * BATCH_SIZE)) \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--bert_context_fmha_fp32_acc enable \
--gpt_attention_plugin ${INFERENCE_PRECISION} \
--remove_input_padding enable \
--workers 1

trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \
--output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \
--kv_cache_type paged \
--moe_plugin disable \
--max_beam_width ${MAX_BEAM_WIDTH} \
--max_batch_size ${BATCH_SIZE} \
--max_input_len 1 \
--max_seq_len 1024 \
--max_encoder_input_len $((1024 * BATCH_SIZE)) \
--gemm_plugin ${INFERENCE_PRECISION} \
--bert_attention_plugin ${INFERENCE_PRECISION} \
--gpt_attention_plugin ${INFERENCE_PRECISION} \
--remove_input_padding enable \
--workers 1

python build_multimodal_engine.py \
--model_type ${MODEL_TYPE} \
--model_path /workspace/models/hf_models/${MODEL_NAME} \
--output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision \
--max_batch_size ${BATCH_SIZE}
100 changes: 100 additions & 0 deletions tensorrt_llm/tools/multimodal_builder.py
@@ -144,6 +144,8 @@ def build(self):
build_qwen2_audio_engine(args)
elif args.model_type == "pixtral":
build_pixtral_engine(args)
elif args.model_type == "florence2":
build_florence2_engine(args)
else:
raise RuntimeError(f"Invalid model type {args.model_type}")

@@ -1734,3 +1736,101 @@ def forward(self, pixel_values, attention_mask):
max_batch_size=args.max_batch_size,
engine_name=f"model.engine",
dtype=torch.bfloat16)


def build_florence2_engine(args):
    processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)

raw_image = Image.new("RGB", [10, 10]) # dummy image
prompt = "<OD>"
    inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(args.device, torch.float16)
pixel_values = inputs["pixel_values"]

class Florence2VisionWrapper(torch.nn.Module):
def __init__(
self,
vision_tower,
image_projection,
image_proj_norm,
image_pos_embed,
image_feature_source,
visual_temporal_embed,
):
super().__init__()
self.vision_tower = vision_tower
self.image_proj_norm = image_proj_norm
self.image_pos_embed = image_pos_embed
self.image_projection = image_projection
self.visual_temporal_embed = visual_temporal_embed
self.image_feature_source = image_feature_source

def forward(self, pixel_values):
assert len(pixel_values.shape) == 4, f"invalid image shape {pixel_values.shape}"
batch_size, _, _, _ = pixel_values.shape
T = 1
x = self.vision_tower.forward_features_unpool(pixel_values)

# image_pos_embed
x = x.view(batch_size * T, -1, x.shape[-1])
num_tokens = x.shape[-2]
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
assert h * w == num_tokens, "Only support square feature maps for now"
x = x.view(batch_size * T, h, w, x.shape[-1])
pos_embed = self.image_pos_embed(x)
x = x + pos_embed
x = x.view(batch_size, T * h * w, x.shape[-1])

# visual_temporal_embed
visual_temporal_embed = self.visual_temporal_embed(
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
)
x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
1, T, 1, x.shape[-1]
)

x_feat_dict = {}

spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
x_feat_dict["last_frame"] = x

new_x = []
for _image_feature_source in self.image_feature_source:
if _image_feature_source not in x_feat_dict:
raise ValueError(
"invalid image feature source: {}".format(_image_feature_source)
)
new_x.append(x_feat_dict[_image_feature_source])

x = torch.cat(new_x, dim=1)

x = x @ self.image_projection
x = self.image_proj_norm(x)

return x

    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16, trust_remote_code=True)

wrapper = Florence2VisionWrapper(
model.vision_tower,
model.image_projection,
model.image_proj_norm,
model.image_pos_embed,
model.image_feature_source,
model.visual_temporal_embed,
)
wrapper.to(args.device)

export_onnx(wrapper, pixel_values, f"{args.output_dir}/onnx")
build_trt_engine(
args.model_type,
[pixel_values.shape[1], pixel_values.shape[2], pixel_values.shape[3]], # [3, H, W]
f"{args.output_dir}/onnx",
args.output_dir,
args.max_batch_size,
)