diff --git a/examples/models/core/enc_dec/convert_checkpoint.py b/examples/models/core/enc_dec/convert_checkpoint.py index 577e89e941..83b7051596 100755 --- a/examples/models/core/enc_dec/convert_checkpoint.py +++ b/examples/models/core/enc_dec/convert_checkpoint.py @@ -13,7 +13,7 @@ import safetensors from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding, fuse_qkv_one_layer, reshape, split) -from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration, +from transformers import (AutoModelForSeq2SeqLM, AutoModelForCausalLM, Blip2ForConditionalGeneration, MBartForConditionalGeneration, Pix2StructForConditionalGeneration, T5ForConditionalGeneration, VisionEncoderDecoderModel) @@ -966,6 +966,8 @@ def get_attn_module_name(component, layer, attn_type): return weights +convert_florence2_weights_to_tllm_safetensors = convert_bart_weights_to_tllm_safetensors # func alias + def parse_pix2struct_config(args, hf_model): # manually set q_scaling to offset attention scaling's effect. @@ -1487,6 +1489,11 @@ def get_model(args): elif args.model_type == "blip2": model = Blip2ForConditionalGeneration.from_pretrained( args.model_dir).language_model + elif args.model_type == "florence2": + model = AutoModelForCausalLM.from_pretrained( + args.model_dir, + trust_remote_code=True, + ).language_model elif args.model_type == "language_adapter": import torch @@ -1522,6 +1529,7 @@ def convert_checkpoint(args): quant_algo = None model_type = args.model_type if args.model_type != "blip2" else "t5" + model_type = model_type if model_type != "florence2" else "bart" encoder_config, decoder_config = globals()[f'parse_{model_type}_config']( args, model) @@ -1705,7 +1713,7 @@ def convert(worker_rank, world_size, args, model_config, convert_args, type=str, default='t5', choices=[ - 't5', 'nmt', 'bart', 'pix2struct', 'blip2', 'language_adapter' + 't5', 'nmt', 'bart', 'florence2', 'pix2struct', 'blip2', 'language_adapter' ], help= 'Multimodal type when this script is used for multimodal conversion.') diff --git a/examples/models/core/enc_dec/helper.py b/examples/models/core/enc_dec/helper.py index ed3628c1bb..6a82a680d3 100755 --- a/examples/models/core/enc_dec/helper.py +++ b/examples/models/core/enc_dec/helper.py @@ -74,7 +74,12 @@ def get_qkv_module_name(model_type): q = "q" k = "k" v = "v" - elif model_type == "bart" or model_type == "nmt" or model_type == "language_adapter": + elif ( + model_type == "bart" + or model_type == "florence2" + or model_type == "nmt" + or model_type == "language_adapter" + ): q = "q_proj" k = "k_proj" v = "v_proj" diff --git a/examples/models/core/multimodal/README.md b/examples/models/core/multimodal/README.md index 94965e832b..b0b5a26eb5 100644 --- a/examples/models/core/multimodal/README.md +++ b/examples/models/core/multimodal/README.md @@ -1132,6 +1132,87 @@ pip install -r requirements-qwen2vl.txt Note: use `--run_profiling` for performance measurement, use `--check_accuracy` for accuracy check. +## Florence-2 + +[Florence-2](https://huggingface.co/microsoft/Florence-2-large) is a powerful vision foundation model designed to perform a wide range of vision and vision-language tasks using simple text prompts. Built on a sequence-to-sequence architecture, it combines the BART language model with the DaViT vision encoder to understand and generate responses for tasks like image captioning, object detection, and segmentation. Trained on the large-scale FLD-5B dataset (5.4 billion annotations across 126 million images), Florence-2 excels in both zero-shot and fine-tuned scenarios, demonstrating strong multi-task learning capabilities. + +1. Download Huggingface weights and convert original checkpoint to TRT-LLM checkpoint format + following example in `examples/models/core/enc_dec/README.md`. + + ```bash + export MODEL_NAME="florence-2-large" + export MODEL_TYPE="florence2" + export INFERENCE_PRECISION="float16" + export TP_SIZE=1 + export PP_SIZE=1 + export WORLD_SIZE=1 + export BATCH_SIZE=32 + export MAX_BEAM_WIDTH=1 + export NUM_VISUAL_FEATURES=577 + + python ../enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \ + --model_dir tmp/hf_models/${MODEL_NAME} \ + --output_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ + --dtype ${INFERENCE_PRECISION} + ``` + +2. Build TRT-LLM engine from TRT-LLM checkpoint + + ```bash + trtllm-build --checkpoint_dir tmp/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --kv_cache_type disabled \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1024 \ + --max_prompt_embedding_table_size $((NUM_VISUAL_FEATURES * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --bert_context_fmha_fp32_acc enable \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable + + trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --kv_cache_type paged \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1 \ + --max_seq_len 1024 \ + --max_encoder_input_len $((1024 * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable + ``` + +3. Build TensorRT engines for vision encoders + + ```bash + python build_multimodal_engine.py + --model_type blip2 + --model_path tmp/hf_models/${MODEL_NAME} + --output_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision + --max_batch_size ${BATCH_SIZE} + ``` + + The built engines are located in `tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision`. + + To run the Florence-2 pipeline with batch size > 8, change `--max_batch_size` argument to `build_multimodal_engine.py` accordingly. + +4. Assemble everything into BLIP2 pipeline + + ```bash + python run.py \ + --max_new_tokens 30 \ + --input_text "" \ + --hf_model_dir tmp/hf_models/${MODEL_NAME} \ + --engine_dir tmp/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION} + ``` ## Dataset Evaluation This section explains how to evaluate datasets using our provided script, including supported models and configurations. diff --git a/examples/models/core/multimodal/build.sh b/examples/models/core/multimodal/build.sh new file mode 100644 index 0000000000..0c0e95003f --- /dev/null +++ b/examples/models/core/multimodal/build.sh @@ -0,0 +1,53 @@ +export MODEL_NAME="florence-2-large-ft" +export MODEL_TYPE="florence2" +export INFERENCE_PRECISION="float16" +export TP_SIZE=1 +export PP_SIZE=1 +export WORLD_SIZE=1 +export BATCH_SIZE=32 +export MAX_BEAM_WIDTH=1 +export NUM_VISUAL_FEATURES=577 + +python ../enc_dec/convert_checkpoint.py --model_type ${MODEL_TYPE} \ + --model_dir /workspace/models/hf_models/${MODEL_NAME} \ + --output_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ + --dtype ${INFERENCE_PRECISION} \ + --workers 1 + +trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/encoder \ + --kv_cache_type disabled \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1024 \ + --max_prompt_embedding_table_size $((NUM_VISUAL_FEATURES * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --bert_context_fmha_fp32_acc enable \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable \ + --workers 1 + +trtllm-build --checkpoint_dir /workspace/models/trt_models/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/decoder \ + --kv_cache_type paged \ + --moe_plugin disable \ + --max_beam_width ${MAX_BEAM_WIDTH} \ + --max_batch_size ${BATCH_SIZE} \ + --max_input_len 1 \ + --max_seq_len 1024 \ + --max_encoder_input_len $((1024 * BATCH_SIZE)) \ + --gemm_plugin ${INFERENCE_PRECISION} \ + --bert_attention_plugin ${INFERENCE_PRECISION} \ + --gpt_attention_plugin ${INFERENCE_PRECISION} \ + --remove_input_padding enable \ + --workers 1 + +python build_multimodal_engine.py \ + --model_type ${MODEL_TYPE} \ + --model_path /workspace/models/hf_models/${MODEL_NAME} \ + --output_dir /workspace/models/trt_engines/${MODEL_NAME}/${INFERENCE_PRECISION}/vision \ + --max_batch_size ${BATCH_SIZE} \ No newline at end of file diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index 95c4b066d2..06b5a35251 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -144,6 +144,8 @@ def build(self): build_qwen2_audio_engine(args) elif args.model_type == "pixtral": build_pixtral_engine(args) + elif args.model_type == "florence2": + build_florence2_engine(args) else: raise RuntimeError(f"Invalid model type {args.model_type}") @@ -1734,3 +1736,101 @@ def forward(self, pixel_values, attention_mask): max_batch_size=args.max_batch_size, engine_name=f"model.engine", dtype=torch.bfloat16) + + +def build_florence2_engine(args): + processor = AutoProcessor.from_pretrained(args.model_path) + + raw_image = Image.new("RGB", [10, 10]) # dummy image + prompt = "" + inputs = processor(raw_image, prompt, return_tensors="pt").to(args.device, torch.float16) + pixel_values = inputs["pixel_values"] + + class Florence2VisionWrapper(torch.nn.Module): + def __init__( + self, + vision_tower, + image_projection, + image_proj_norm, + image_pos_embed, + image_feature_source, + visual_temporal_embed, + ): + super().__init__() + self.vision_tower = vision_tower + self.image_proj_norm = image_proj_norm + self.image_pos_embed = image_pos_embed + self.image_projection = image_projection + self.visual_temporal_embed = visual_temporal_embed + self.image_feature_source = image_feature_source + + def forward(self, pixel_values): + assert len(pixel_values.shape) == 4, f"invalid image shape {pixel_values.shape}" + batch_size, _, _, _ = pixel_values.shape + T = 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + + # image_pos_embed + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, "Only support square feature maps for now" + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + # visual_temporal_embed + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0] + ) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view( + 1, T, 1, x.shape[-1] + ) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict["last_frame"] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError( + "invalid image feature source: {}".format(_image_feature_source) + ) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16) + + wrapper = Florence2VisionWrapper( + model.vision_tower, + model.image_projection, + model.image_proj_norm, + model.image_pos_embed, + model.image_feature_source, + model.visual_temporal_embed, + ) + wrapper.to(args.device) + + export_onnx(wrapper, pixel_values, f"{args.output_dir}/onnx") + build_trt_engine( + args.model_type, + [pixel_values.shape[1], pixel_values.shape[2], pixel_values.shape[3]], # [3, H, W] + f"{args.output_dir}/onnx", + args.output_dir, + args.max_batch_size, + )