From 27a6e469a8daf1f0bf39fd357bca8c7913fe90c2 Mon Sep 17 00:00:00 2001 From: ducviet00 Date: Thu, 15 May 2025 17:20:23 +0700 Subject: [PATCH 1/3] add builder --- examples/models/core/multimodal/README.md | 82 +++++++++++++++++++ tensorrt_llm/tools/multimodal_builder.py | 99 +++++++++++++++++++++++ 2 files changed, 181 insertions(+) diff --git a/examples/models/core/multimodal/README.md b/examples/models/core/multimodal/README.md index 94965e832b..ed8612d5bb 100644 --- a/examples/models/core/multimodal/README.md +++ b/examples/models/core/multimodal/README.md @@ -1132,6 +1132,88 @@ pip install -r requirements-qwen2vl.txt Note: use `--run_profiling` for performance measurement, use `--check_accuracy` for accuracy check. +## Florence-2 + +[Florence-2](https://huggingface.co/microsoft/Florence-2-large) is a powerful vision foundation model designed to perform a wide range of vision and vision-language tasks using simple text prompts. Built on a sequence-to-sequence architecture, it combines the BART language model with the DaViT vision encoder to understand and generate responses for tasks like image captioning, object detection, and segmentation. Trained on the large-scale FLD-5B dataset (5.4 billion annotations across 126 million images), Florence-2 excels in both zero-shot and fine-tuned scenarios, demonstrating strong multi-task learning capabilities. + +1. Download Huggingface weights and convert original checkpoint to TRT-LLM checkpoint format + following example in `examples/models/core/enc_dec/README.md`. + + ```bash + export MODEL_NAME="florence-2-large" + export MODEL_TYPE="florence2" + export INFERENCE_PRECISION="float16" + export TP_SIZE=1 + export PP_SIZE=1 + export WORLD_SIZE=1 + export BATCH_SIZE=32 + export MAX_BEAM_WIDTH=1 + export NUM_VISUAL_FEATURES=577 + + python convert_checkpoint.py --model_type ${MODEL_TYPE} \ + --model_dir tmp/hf_models/${MODEL_NAME} \ + --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ + --use_prompt_tuning \ + --dtype ${INFERENCE_PRECISION} + ``` + +2. Build TRT-LLM engine from TRT-LLM checkpoint + + ```bash + trtllm-build --checkpoint_dir tmp/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --kv_cache_type disabled \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1024 \ + --max_prompt_embedding_table_size $((NUM_VISUAL_FEATURES * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --bert_context_fmha_fp32_acc enable \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable + + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --kv_cache_type paged \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1 \ + --max_seq_len 1024 \ + --max_encoder_input_len $((1024 * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable + ``` + +3. Build TensorRT engines for vision encoders + + ```bash + python build_multimodal_engine.py + --model_type blip2 + --model_path tmp/hf_models/${MODEL_NAME} + --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision + --max_batch_size ${BATCH_SIZE} + ``` + + The built engines are located in `tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision`. + + To run the Florence-2 pipeline with batch size > 8, change `--max_batch_size` argument to `build_multimodal_engine.py` accordingly. + +4. Assemble everything into BLIP2 pipeline + + ```bash + python run.py \ + --max_new_tokens 30 \ + --input_text "" \ + --hf_model_dir tmp/hf_models/${MODEL_NAME} \ + --engine_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION} + ``` ## Dataset Evaluation This section explains how to evaluate datasets using our provided script, including supported models and configurations. diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index 95c4b066d2..af42ca00cf 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -144,6 +144,8 @@ def build(self): build_qwen2_audio_engine(args) elif args.model_type == "pixtral": build_pixtral_engine(args) + elif args.model_type == "florence2": + build_florence2_engine(args) else: raise RuntimeError(f"Invalid model type {args.model_type}") @@ -1734,3 +1736,100 @@ def forward(self, pixel_values, attention_mask): max_batch_size=args.max_batch_size, engine_name=f"model.engine", dtype=torch.bfloat16) + + +def build_florence2_engine(args): + processor = AutoProcessor.from_pretrained(args.model_path) + + raw_image = Image.new("RGB", [10, 10]) # dummy image + prompt = "" + inputs = processor(raw_image, prompt, return_tensors="pt").to(args.device, torch.float16) + pixel_values = inputs["pixel_values"] + + class Florence2VisionWrapper(torch.nn.Module): + def __init__( + self, + vision_tower, + image_projection, + image_proj_norm, + image_pos_embed, + image_feature_source, + visual_temporal_embed, + ): + super().__init__() + self.vision_tower = vision_tower + self.image_proj_norm = image_proj_norm + self.image_pos_embed = image_pos_embed + self.image_projection = image_projection + self.visual_temporal_embed = visual_temporal_embed + self.image_feature_source = image_feature_source + + def forward(self, pixel_values): + batch_size, _, _, _ = pixel_values.shape + T = 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + + # image_pos_embed + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, "Only support square feature maps for now" + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + # visual_temporal_embed + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0] + ) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view( + 1, T, 1, x.shape[-1] + ) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict["last_frame"] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError( + "invalid image feature source: {}".format(_image_feature_source) + ) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16) + + wrapper = Florence2VisionWrapper( + model.vision_tower, + model.image_projection, + model.image_proj_norm, + model.image_pos_embed, + model.image_feature_source, + model.visual_temporal_embed, + ) + wrapper.to(args.device) + + export_onnx(wrapper, pixel_values, f"{args.output_dir}/onnx") + build_trt_engine( + args.model_type, + [pixel_values.shape[1], pixel_values.shape[2], pixel_values.shape[3]], # [3, H, W] + f"{args.output_dir}/onnx", + args.output_dir, + args.max_batch_size, + ) From cd1ee3e71e04fe882a1abd1fe480b2e1dd4bb572 Mon Sep 17 00:00:00 2001 From: ducviet00 Date: Thu, 15 May 2025 23:01:50 +0700 Subject: [PATCH 2/3] add florence2 language model conversion --- .../models/core/enc_dec/convert_checkpoint.py | 12 ++++- examples/models/core/multimodal/build.sh | 51 +++++++++++++++++++ tensorrt_llm/tools/multimodal_builder.py | 1 + 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 examples/models/core/multimodal/build.sh diff --git a/examples/models/core/enc_dec/convert_checkpoint.py b/examples/models/core/enc_dec/convert_checkpoint.py index 577e89e941..a4e7ebcd9a 100755 --- a/examples/models/core/enc_dec/convert_checkpoint.py +++ b/examples/models/core/enc_dec/convert_checkpoint.py @@ -13,7 +13,7 @@ import safetensors from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding, fuse_qkv_one_layer, reshape, split) -from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration, +from transformers import (AutoModelForSeq2SeqLM, AutoModelForCausalLM, Blip2ForConditionalGeneration, MBartForConditionalGeneration, Pix2StructForConditionalGeneration, T5ForConditionalGeneration, VisionEncoderDecoderModel) @@ -966,6 +966,8 @@ def get_attn_module_name(component, layer, attn_type): return weights +convert_florence2_weights_to_tllm_safetensors = convert_bart_weights_to_tllm_safetensors # func alias + def parse_pix2struct_config(args, hf_model): # manually set q_scaling to offset attention scaling's effect. @@ -1487,6 +1489,11 @@ def get_model(args): elif args.model_type == "blip2": model = Blip2ForConditionalGeneration.from_pretrained( args.model_dir).language_model + elif args.model_type == "blip2": + model = AutoModelForCausalLM.from_pretrained( + args.model_dir, + trust_remote_code=True, + ).language_model elif args.model_type == "language_adapter": import torch @@ -1522,6 +1529,7 @@ def convert_checkpoint(args): quant_algo = None model_type = args.model_type if args.model_type != "blip2" else "t5" + model_type = model_type if model_type != "florence2" else "bart" encoder_config, decoder_config = globals()[f'parse_{model_type}_config']( args, model) @@ -1705,7 +1713,7 @@ def convert(worker_rank, world_size, args, model_config, convert_args, type=str, default='t5', choices=[ - 't5', 'nmt', 'bart', 'pix2struct', 'blip2', 'language_adapter' + 't5', 'nmt', 'bart', 'florence2', 'pix2struct', 'blip2', 'language_adapter' ], help= 'Multimodal type when this script is used for multimodal conversion.') diff --git a/examples/models/core/multimodal/build.sh b/examples/models/core/multimodal/build.sh new file mode 100644 index 0000000000..056022d73d --- /dev/null +++ b/examples/models/core/multimodal/build.sh @@ -0,0 +1,51 @@ +export MODEL_NAME="florence-2-large" +export MODEL_TYPE="florence2" +export INFERENCE_PRECISION="float16" +export TP_SIZE=1 +export PP_SIZE=1 +export WORLD_SIZE=1 +export BATCH_SIZE=32 +export MAX_BEAM_WIDTH=1 +export NUM_VISUAL_FEATURES=577 + +python convert_checkpoint.py --model_type ${MODEL_TYPE} \ + --model_dir /workspace/models/hf_models/${MODEL_NAME} \ + --output_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ + --use_prompt_tuning \ + --dtype ${INFERENCE_PRECISION} + +trtllm-build --checkpoint_dir /workspace/models/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --kv_cache_type disabled \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1024 \ + --max_prompt_embedding_table_size $((NUM_VISUAL_FEATURES * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --bert_context_fmha_fp32_acc enable \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable + +trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --kv_cache_type paged \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1 \ + --max_seq_len 1024 \ + --max_encoder_input_len $((1024 * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable + +python build_multimodal_engine.py \ + --model_type ${MODEL_TYPE} \ + --model_path /workspace/models/hf_models/${MODEL_NAME} \ + --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision \ + --max_batch_size ${BATCH_SIZE} \ No newline at end of file diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index af42ca00cf..06b5a35251 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -1765,6 +1765,7 @@ def __init__( self.image_feature_source = image_feature_source def forward(self, pixel_values): + assert len(pixel_values.shape) == 4, f"invalid image shape {pixel_values.shape}" batch_size, _, _, _ = pixel_values.shape T = 1 x = self.vision_tower.forward_features_unpool(pixel_values) From 3cf38b464204b16f48f4712eb577997a6485de8e Mon Sep 17 00:00:00 2001 From: ducviet00 Date: Fri, 16 May 2025 00:27:52 +0700 Subject: [PATCH 3/3] chore: add florence2 in enc_dec --- .../models/core/enc_dec/convert_checkpoint.py | 2 +- examples/models/core/enc_dec/helper.py | 7 ++++++- examples/models/core/multimodal/README.md | 3 +-- examples/models/core/multimodal/build.sh | 16 +++++++++------- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/models/core/enc_dec/convert_checkpoint.py b/examples/models/core/enc_dec/convert_checkpoint.py index a4e7ebcd9a..83b7051596 100755 --- a/examples/models/core/enc_dec/convert_checkpoint.py +++ b/examples/models/core/enc_dec/convert_checkpoint.py @@ -1489,7 +1489,7 @@ def get_model(args): elif args.model_type == "blip2": model = Blip2ForConditionalGeneration.from_pretrained( args.model_dir).language_model - elif args.model_type == "blip2": + elif args.model_type == "florence2": model = AutoModelForCausalLM.from_pretrained( args.model_dir, trust_remote_code=True, diff --git a/examples/models/core/enc_dec/helper.py b/examples/models/core/enc_dec/helper.py index ed3628c1bb..6a82a680d3 100755 --- a/examples/models/core/enc_dec/helper.py +++ b/examples/models/core/enc_dec/helper.py @@ -74,7 +74,12 @@ def get_qkv_module_name(model_type): q = "q" k = "k" v = "v" - elif model_type == "bart" or model_type == "nmt" or model_type == "language_adapter": + elif ( + model_type == "bart" + or model_type == "florence2" + or model_type == "nmt" + or model_type == "language_adapter" + ): q = "q_proj" k = "k_proj" v = "v_proj" diff --git a/examples/models/core/multimodal/README.md b/examples/models/core/multimodal/README.md index ed8612d5bb..b0b5a26eb5 100644 --- a/examples/models/core/multimodal/README.md +++ b/examples/models/core/multimodal/README.md @@ -1150,12 +1150,11 @@ pip install -r requirements-qwen2vl.txt export MAX_BEAM_WIDTH=1 export NUM_VISUAL_FEATURES=577 - python convert_checkpoint.py --model_type ${MODEL_TYPE} \ + python ../enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \ --model_dir tmp/hf_models/${MODEL_NAME} \ --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ --tp_size ${TP_SIZE} \ --pp_size ${PP_SIZE} \ - --use_prompt_tuning \ --dtype ${INFERENCE_PRECISION} ``` diff --git a/examples/models/core/multimodal/build.sh b/examples/models/core/multimodal/build.sh index 056022d73d..0c0e95003f 100644 --- a/examples/models/core/multimodal/build.sh +++ b/examples/models/core/multimodal/build.sh @@ -1,4 +1,4 @@ -export MODEL_NAME="florence-2-large" +export MODEL_NAME="florence-2-large-ft" export MODEL_TYPE="florence2" export INFERENCE_PRECISION="float16" export TP_SIZE=1 @@ -8,15 +8,15 @@ export BATCH_SIZE=32 export MAX_BEAM_WIDTH=1 export NUM_VISUAL_FEATURES=577 -python convert_checkpoint.py --model_type ${MODEL_TYPE} \ +python ../enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \ --model_dir /workspace/models/hf_models/${MODEL_NAME} \ --output_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ --tp_size ${TP_SIZE} \ --pp_size ${PP_SIZE} \ - --use_prompt_tuning \ - --dtype ${INFERENCE_PRECISION} + --dtype ${INFERENCE_PRECISION} \ + --workers 1 -trtllm-build --checkpoint_dir /workspace/models/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ +trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ --kv_cache_type disabled \ --moe_plugin disable \ @@ -28,7 +28,8 @@ trtllm-build --checkpoint_dir /workspace/models/models/trt_models/${MODEL_NAME}/ --bert_attention_plugin ${INFERENCE_PRECISION} \ --bert_context_fmha_fp32_acc enable \ --gpt_attention_plugin ${INFERENCE_PRECISION} \ - --remove_input_padding enable + --remove_input_padding enable \ + --workers 1 trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ @@ -42,7 +43,8 @@ trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFER --gemm_plugin ${INFERENCE_PRECISION} \ --bert_attention_plugin ${INFERENCE_PRECISION} \ --gpt_attention_plugin ${INFERENCE_PRECISION} \ - --remove_input_padding enable + --remove_input_padding enable \ + --workers 1 python build_multimodal_engine.py \ --model_type ${MODEL_TYPE} \