diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 760f092f0e2b..24b54ad3d801 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -441,28 +441,6 @@
     },
 }

-NO_FEATURE_EXTRACTOR_TASKS = set()
-NO_IMAGE_PROCESSOR_TASKS = set()
-NO_TOKENIZER_TASKS = set()
-
-# Those model configs are special, they are generic over their task, meaning
-# any tokenizer/feature_extractor might be use for a given model so we cannot
-# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
-# see if the model defines such objects or not.
-MULTI_MODEL_AUDIO_CONFIGS = {"SpeechEncoderDecoderConfig"}
-MULTI_MODEL_VISION_CONFIGS = {"VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
-for task, values in SUPPORTED_TASKS.items():
-    if values["type"] == "text":
-        NO_FEATURE_EXTRACTOR_TASKS.add(task)
-        NO_IMAGE_PROCESSOR_TASKS.add(task)
-    elif values["type"] in {"image", "video"}:
-        NO_TOKENIZER_TASKS.add(task)
-    elif values["type"] in {"audio"}:
-        NO_TOKENIZER_TASKS.add(task)
-        NO_IMAGE_PROCESSOR_TASKS.add(task)
-    elif values["type"] != "multimodal":
-        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
-
 PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)


@@ -1029,205 +1007,169 @@ def pipeline(
             **model_kwargs,
         )

-    model_config = model.config
     hub_kwargs["_commit_hash"] = model.config._commit_hash
-    load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
-    load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
-    load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
-    load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
-
-    # Check that pipeline class required loading
-    load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
-    load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
-    load_image_processor = load_image_processor and pipeline_class._load_image_processor
-    load_processor = load_processor and pipeline_class._load_processor
-
-    # If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
-    # `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
-    # vision tasks when calling `pipeline()` with `model` and only one of the `image_processor` and `feature_extractor`.
-    # TODO: we need to make `NO_IMAGE_PROCESSOR_TASKS` and `NO_FEATURE_EXTRACTOR_TASKS` more robust to avoid such issue.
-    # This block is only temporarily to make CI green.
-    if load_image_processor and load_feature_extractor:
-        load_feature_extractor = False
-
-    if (
-        tokenizer is None
-        and not load_tokenizer
-        and normalized_task not in NO_TOKENIZER_TASKS
-        # Using class name to avoid importing the real class.
-        and (
-            model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
-            or model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
-        )
-    ):
-        # This is a special category of models, that are fusions of multiple models
-        # so the model_config might not define a tokenizer, but it seems to be
-        # necessary for the task, so we're force-trying to load it.
-        load_tokenizer = True
-    if (
-        image_processor is None
-        and not load_image_processor
-        and normalized_task not in NO_IMAGE_PROCESSOR_TASKS
-        # Using class name to avoid importing the real class.
-        and model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
-    ):
-        # This is a special category of models, that are fusions of multiple models
-        # so the model_config might not define a tokenizer, but it seems to be
-        # necessary for the task, so we're force-trying to load it.
-        load_image_processor = True
-    if (
-        feature_extractor is None
-        and not load_feature_extractor
-        and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
-        # Using class name to avoid importing the real class.
-        and model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS
-    ):
-        # This is a special category of models, that are fusions of multiple models
-        # so the model_config might not define a tokenizer, but it seems to be
-        # necessary for the task, so we're force-trying to load it.
-        load_feature_extractor = True
-
-    if task in NO_TOKENIZER_TASKS:
-        # These will never require a tokenizer.
-        # the model on the other hand might have a tokenizer, but
-        # the files could be missing from the hub, instead of failing
-        # on such repos, we just force to not load it.
-        load_tokenizer = False
-
-    if task in NO_FEATURE_EXTRACTOR_TASKS:
-        load_feature_extractor = False
-    if task in NO_IMAGE_PROCESSOR_TASKS:
-        load_image_processor = False
-
-    if load_tokenizer:
-        # Try to infer tokenizer from model or config name (if provided as str)
-        if tokenizer is None:
-            if isinstance(model_name, str):
-                tokenizer = model_name
-            elif isinstance(config, str):
-                tokenizer = config
-            else:
-                # Impossible to guess what is the right tokenizer here
-                raise Exception(
-                    "Impossible to guess which tokenizer to use. "
-                    "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
-                )
-
-        # Instantiate tokenizer if needed
-        if isinstance(tokenizer, (str, tuple)):
-            if isinstance(tokenizer, tuple):
-                # For tuple we have (tokenizer name, {kwargs})
-                use_fast = tokenizer[1].pop("use_fast", use_fast)
-                tokenizer_identifier = tokenizer[0]
-                tokenizer_kwargs = tokenizer[1]
-            else:
-                tokenizer_identifier = tokenizer
-                tokenizer_kwargs = model_kwargs.copy()
-                tokenizer_kwargs.pop("torch_dtype", None)
-
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
-            )
-
-    if load_image_processor:
-        # Try to infer image processor from model or config name (if provided as str)
-        if image_processor is None:
-            if isinstance(model_name, str):
-                image_processor = model_name
-            elif isinstance(config, str):
-                image_processor = config
-            # Backward compatibility, as `feature_extractor` used to be the name
-            # for `ImageProcessor`.
-            elif feature_extractor is not None and isinstance(feature_extractor, BaseImageProcessor):
-                image_processor = feature_extractor
-            else:
-                # Impossible to guess what is the right image_processor here
-                raise Exception(
-                    "Impossible to guess which image processor to use. "
-                    "Please provide a PreTrainedImageProcessor class or a path/identifier "
-                    "to a pretrained image processor."
-                )
-
-        # Instantiate image_processor if needed
-        if isinstance(image_processor, (str, tuple)):
-            image_processor = AutoImageProcessor.from_pretrained(
-                image_processor, _from_pipeline=task, **hub_kwargs, **model_kwargs
-            )
-
-    if load_feature_extractor:
-        # Try to infer feature extractor from model or config name (if provided as str)
-        if feature_extractor is None:
-            if isinstance(model_name, str):
-                feature_extractor = model_name
-            elif isinstance(config, str):
-                feature_extractor = config
-            else:
-                # Impossible to guess what is the right feature_extractor here
-                raise Exception(
-                    "Impossible to guess which feature extractor to use. "
-                    "Please provide a PreTrainedFeatureExtractor class or a path/identifier "
-                    "to a pretrained feature extractor."
-                )
-
-        # Instantiate feature_extractor if needed
-        if isinstance(feature_extractor, (str, tuple)):
-            feature_extractor = AutoFeatureExtractor.from_pretrained(
-                feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
-            )
-
-            if (
-                feature_extractor._processor_class
-                and feature_extractor._processor_class.endswith("WithLM")
-                and isinstance(model_name, str)
-            ):
-                try:
-                    import kenlm  # to trigger `ImportError` if not installed
-                    from pyctcdecode import BeamSearchDecoderCTC
-
-                    if os.path.isdir(model_name) or os.path.isfile(model_name):
-                        decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
-                    else:
-                        language_model_glob = os.path.join(
-                            BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
-                        )
-                        alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
-                        allow_patterns = [language_model_glob, alphabet_filename]
-                        decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_patterns=allow_patterns)
-
-                    kwargs["decoder"] = decoder
-                except ImportError as e:
-                    logger.warning(f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Error: {e}")
-                    if not is_kenlm_available():
-                        logger.warning("Try to install `kenlm`: `pip install kenlm")
-
-                    if not is_pyctcdecode_available():
-                        logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")
-
-    if load_processor:
-        # Try to infer processor from model or config name (if provided as str)
-        if processor is None:
-            if isinstance(model_name, str):
-                processor = model_name
-            elif isinstance(config, str):
-                processor = config
-            else:
-                # Impossible to guess what is the right processor here
-                raise Exception(
-                    "Impossible to guess which processor to use. "
-                    "Please provide a processor instance or a path/identifier "
-                    "to a processor."
-                )
-
-        # Instantiate processor if needed
-        if isinstance(processor, (str, tuple)):
-            processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
-            if not isinstance(processor, ProcessorMixin):
-                raise TypeError(
-                    "Processor was loaded, but it is not an instance of `ProcessorMixin`. "
-                    f"Got type `{type(processor)}` instead. Please check that you specified "
-                    "correct pipeline task for the model and model has processor implemented and saved."
-                )
+    # Check which preprocessing classes the pipeline uses
+    # None values indicate optional classes that the pipeline can run without; we don't raise errors if loading fails
+    load_tokenizer = pipeline_class._load_tokenizer
+    load_feature_extractor = pipeline_class._load_feature_extractor
+    load_image_processor = pipeline_class._load_image_processor
+    load_processor = pipeline_class._load_processor
+
+    if load_tokenizer or load_tokenizer is None:
+        try:
+            # Try to infer tokenizer from model or config name (if provided as str)
+            if tokenizer is None:
+                if isinstance(model_name, str):
+                    tokenizer = model_name
+                elif isinstance(config, str):
+                    tokenizer = config
+                else:
+                    # Impossible to guess what is the right tokenizer here
+                    raise Exception(
+                        "Impossible to guess which tokenizer to use. "
+                        "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
+                    )
+
+            # Instantiate tokenizer if needed
+            if isinstance(tokenizer, (str, tuple)):
+                if isinstance(tokenizer, tuple):
+                    # For tuple we have (tokenizer name, {kwargs})
+                    use_fast = tokenizer[1].pop("use_fast", use_fast)
+                    tokenizer_identifier = tokenizer[0]
+                    tokenizer_kwargs = tokenizer[1]
+                else:
+                    tokenizer_identifier = tokenizer
+                    tokenizer_kwargs = model_kwargs.copy()
+                    tokenizer_kwargs.pop("torch_dtype", None)
+
+                tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs
+                )
+        except Exception as e:
+            if load_tokenizer:
+                raise e
+            else:
+                tokenizer = None
+
+    if load_image_processor or load_image_processor is None:
+        try:
+            # Try to infer image processor from model or config name (if provided as str)
+            if image_processor is None:
+                if isinstance(model_name, str):
+                    image_processor = model_name
+                elif isinstance(config, str):
+                    image_processor = config
+                # Backward compatibility, as `feature_extractor` used to be the name
+                # for `ImageProcessor`.
+                elif feature_extractor is not None and isinstance(feature_extractor, BaseImageProcessor):
+                    image_processor = feature_extractor
+                else:
+                    # Impossible to guess what is the right image_processor here
+                    raise Exception(
+                        "Impossible to guess which image processor to use. "
+                        "Please provide a PreTrainedImageProcessor class or a path/identifier "
+                        "to a pretrained image processor."
+                    )
+
+            # Instantiate image_processor if needed
+            if isinstance(image_processor, (str, tuple)):
+                image_processor = AutoImageProcessor.from_pretrained(
+                    image_processor, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+        except Exception as e:
+            if load_image_processor:
+                raise e
+            else:
+                image_processor = None
+
+    if load_feature_extractor or load_feature_extractor is None:
+        try:
+            # Try to infer feature extractor from model or config name (if provided as str)
+            if feature_extractor is None:
+                if isinstance(model_name, str):
+                    feature_extractor = model_name
+                elif isinstance(config, str):
+                    feature_extractor = config
+                else:
+                    # Impossible to guess what is the right feature_extractor here
+                    raise Exception(
+                        "Impossible to guess which feature extractor to use. "
+                        "Please provide a PreTrainedFeatureExtractor class or a path/identifier "
+                        "to a pretrained feature extractor."
+                    )
+
+            # Instantiate feature_extractor if needed
+            if isinstance(feature_extractor, (str, tuple)):
+                feature_extractor = AutoFeatureExtractor.from_pretrained(
+                    feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs
+                )
+
+                if (
+                    feature_extractor._processor_class
+                    and feature_extractor._processor_class.endswith("WithLM")
+                    and isinstance(model_name, str)
+                ):
+                    try:
+                        import kenlm  # to trigger `ImportError` if not installed
+                        from pyctcdecode import BeamSearchDecoderCTC
+
+                        if os.path.isdir(model_name) or os.path.isfile(model_name):
+                            decoder = BeamSearchDecoderCTC.load_from_dir(model_name)
+                        else:
+                            language_model_glob = os.path.join(
+                                BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*"
+                            )
+                            alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
+                            allow_patterns = [language_model_glob, alphabet_filename]
+                            decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_patterns=allow_patterns)
+
+                        kwargs["decoder"] = decoder
+                    except ImportError as e:
+                        logger.warning(
+                            f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Error: {e}"
+                        )
+                        if not is_kenlm_available():
+                            logger.warning("Try to install `kenlm`: `pip install kenlm`")
+
+                        if not is_pyctcdecode_available():
+                            logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode`")
+        except Exception as e:
+            if load_feature_extractor:
+                raise e
+            else:
+                feature_extractor = None
+
+    if load_processor or load_processor is None:
+        try:
+            # Try to infer processor from model or config name (if provided as str)
+            if processor is None:
+                if isinstance(model_name, str):
+                    processor = model_name
+                elif isinstance(config, str):
+                    processor = config
+                else:
+                    # Impossible to guess what is the right processor here
+                    raise Exception(
+                        "Impossible to guess which processor to use. "
+                        "Please provide a processor instance or a path/identifier "
+                        "to a processor."
+                    )
+
+            # Instantiate processor if needed
+            if isinstance(processor, (str, tuple)):
+                processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+                if not isinstance(processor, ProcessorMixin):
+                    raise TypeError(
+                        "Processor was loaded, but it is not an instance of `ProcessorMixin`. "
+                        f"Got type `{type(processor)}` instead. Please check that you specified "
+                        "correct pipeline task for the model and model has processor implemented and saved."
+                    )
+        except Exception as e:
+            if load_processor:
+                raise e
+            else:
+                processor = None

     if task == "translation" and model.config.task_specific_params:
         for key in model.config.task_specific_params:
diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py
index 1715d247e520..d7552d9419e1 100644
--- a/src/transformers/pipelines/audio_classification.py
+++ b/src/transformers/pipelines/audio_classification.py
@@ -90,6 +90,11 @@ class AudioClassificationPipeline(Pipeline):
     [huggingface.co/models](https://huggingface.co/models?filter=audio-classification).
""" + _load_processor = False + _load_image_processor = False + _load_feature_extractor = True + _load_tokenizer = False + def __init__(self, *args, **kwargs): # Only set default top_k if explicitly provided if "top_k" in kwargs and kwargs["top_k"] is None: diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index a24493767ef6..4dccc43e22e3 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -178,6 +178,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): """ _pipeline_calls_generate = True + _load_processor = False + _load_image_processor = False + _load_feature_extractor = True + _load_tokenizer = True # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( max_new_tokens=256, diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index cc32871c12de..5af0092b6f76 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -898,21 +898,15 @@ class Pipeline(_ScikitCompat, PushToHubMixin): constructor argument. If set to `True`, the output will be stored in the pickle format. """ - # Historically we have pipelines working with `tokenizer`, `feature_extractor`, and `image_processor` - # as separate processing components. While we have `processor` class that combines them, some pipelines - # might still operate with these components separately. - # With the addition of `processor` to `pipeline`, we want to avoid: - # - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately; - # - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`, - # because `processor` will load required sub-components by itself. - # Below flags allow granular control over loading components and set to be backward compatible with current - # pipelines logic. You may override these flags when creating your pipeline. For example, for - # `zero-shot-object-detection` pipeline which operates with `processor` you should set `_load_processor=True` - # and all the rest flags to `False` to avoid unnecessary loading of the components. - _load_processor = False - _load_image_processor = True - _load_feature_extractor = True - _load_tokenizer = True + # These flags should be overridden for downstream pipelines. They indicate which preprocessing classes are + # used by each pipeline. The possible values are: + # - True (the class is mandatory, raise an error if it's not present in the repo) + # - None (the class is optional; it should be loaded if present in the repo but the pipeline can work without it) + # - False (the class is never used by the pipeline and should not be loaded even if present) + _load_processor = None + _load_image_processor = None + _load_feature_extractor = None + _load_tokenizer = None # Pipelines that call `generate` have shared logic, e.g. preparing the generation config. 
     _pipeline_calls_generate = False
diff --git a/src/transformers/pipelines/depth_estimation.py b/src/transformers/pipelines/depth_estimation.py
index 50de4c1ca6df..cd9df1317105 100644
--- a/src/transformers/pipelines/depth_estimation.py
+++ b/src/transformers/pipelines/depth_estimation.py
@@ -47,6 +47,11 @@ class DepthEstimationPipeline(Pipeline):
     See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=depth-estimation).
     """

+    _load_processor = False
+    _load_image_processor = True
+    _load_feature_extractor = False
+    _load_tokenizer = False
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         requires_backends(self, "vision")
diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py
index 1aa4ab3a500f..1d3c5f2f7353 100644
--- a/src/transformers/pipelines/document_question_answering.py
+++ b/src/transformers/pipelines/document_question_answering.py
@@ -135,6 +135,10 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
     """

     _pipeline_calls_generate = True
+    _load_processor = False
+    _load_image_processor = None
+    _load_feature_extractor = None
+    _load_tokenizer = False
     # Make sure the docstring is updated when the default generation config is changed
     _default_generation_config = GenerationConfig(
         max_new_tokens=256,
diff --git a/src/transformers/pipelines/feature_extraction.py b/src/transformers/pipelines/feature_extraction.py
index 83ea91ee6ae5..9c8005d05f22 100644
--- a/src/transformers/pipelines/feature_extraction.py
+++ b/src/transformers/pipelines/feature_extraction.py
@@ -37,6 +37,11 @@ class FeatureExtractionPipeline(Pipeline):
     [huggingface.co/models](https://huggingface.co/models).
     """

+    _load_processor = False
+    _load_image_processor = False
+    _load_feature_extractor = False
+    _load_tokenizer = True
+
     def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
         if tokenize_kwargs is None:
             tokenize_kwargs = {}
diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py
index c4076fbaf19f..f5dbe71dadff 100644
--- a/src/transformers/pipelines/fill_mask.py
+++ b/src/transformers/pipelines/fill_mask.py
@@ -32,6 +32,11 @@
         Additional dictionary of keyword arguments passed along to the tokenizer.""",
 )
 class FillMaskPipeline(Pipeline):
+    _load_processor = False
+    _load_image_processor = False
+    _load_feature_extractor = False
+    _load_tokenizer = True
+
     """
     Masked language modeling prediction pipeline using any `ModelWithLMHead`. See the [masked language modeling
     examples](../task_summary#masked-language-modeling) for more information.
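The tri-state flags defined in base.py above form a small contract for pipeline authors: True means the class is mandatory, None means it is optional, and False means it is never loaded. Below is a minimal sketch of how a custom subclass would declare its needs under this scheme; the class `MyOcrPipeline` and its particular flag values are hypothetical and not part of this patch.

```python
from transformers import Pipeline


class MyOcrPipeline(Pipeline):
    # Hypothetical pipeline: image input, with checkpoints that may or may not ship a tokenizer.
    _load_processor = False          # never load, even if the repo ships one
    _load_image_processor = True     # mandatory: pipeline() raises if it cannot be loaded
    _load_feature_extractor = False  # never load
    _load_tokenizer = None           # optional: loaded when present, silently skipped otherwise

    def _sanitize_parameters(self, **kwargs):
        # No extra parameters in this sketch.
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        # The image processor is guaranteed to exist because its flag is True.
        return self.image_processor(images=inputs, return_tensors=self.framework)

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs
```

Declaring False for components a pipeline never touches also avoids pointless downloads for repos that happen to ship extra preprocessor files, which is the same motivation the old per-pipeline flag comment in base.py gave.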
diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py
index f1156054c2d7..95d525776c9a 100644
--- a/src/transformers/pipelines/image_classification.py
+++ b/src/transformers/pipelines/image_classification.py
@@ -99,6 +99,10 @@ class ImageClassificationPipeline(Pipeline):
     """

     function_to_apply: ClassificationFunction = ClassificationFunction.NONE
+    _load_processor = False
+    _load_image_processor = True
+    _load_feature_extractor = False
+    _load_tokenizer = False

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
diff --git a/src/transformers/pipelines/image_feature_extraction.py b/src/transformers/pipelines/image_feature_extraction.py
index 1333bf6fb8f2..7193a59249a6 100644
--- a/src/transformers/pipelines/image_feature_extraction.py
+++ b/src/transformers/pipelines/image_feature_extraction.py
@@ -45,6 +45,11 @@ class ImageFeatureExtractionPipeline(Pipeline):
     [huggingface.co/models](https://huggingface.co/models).
     """

+    _load_processor = False
+    _load_image_processor = True
+    _load_feature_extractor = False
+    _load_tokenizer = False
+
     def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
         preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py
index 8947dcf8cba3..3d113e776810 100644
--- a/src/transformers/pipelines/image_segmentation.py
+++ b/src/transformers/pipelines/image_segmentation.py
@@ -60,6 +60,11 @@ class ImageSegmentationPipeline(Pipeline):
     [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
     """

+    _load_processor = False
+    _load_image_processor = True
+    _load_feature_extractor = False
+    _load_tokenizer = None  # Oneformer uses it but no-one else does
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
diff --git a/src/transformers/pipelines/image_to_image.py b/src/transformers/pipelines/image_to_image.py
index a16c2b4d6ba1..52fc99bd05f1 100644
--- a/src/transformers/pipelines/image_to_image.py
+++ b/src/transformers/pipelines/image_to_image.py
@@ -67,6 +67,11 @@ class ImageToImagePipeline(Pipeline):
     See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=image-to-image).
""" + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) requires_backends(self, "vision") diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index 1ad25a7c26dc..86adf402cce6 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -72,6 +72,10 @@ class ImageToTextPipeline(Pipeline): """ _pipeline_calls_generate = True + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = True # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( max_new_tokens=256, diff --git a/src/transformers/pipelines/mask_generation.py b/src/transformers/pipelines/mask_generation.py index 31b168a6f664..76d3a5be0e97 100644 --- a/src/transformers/pipelines/mask_generation.py +++ b/src/transformers/pipelines/mask_generation.py @@ -84,6 +84,11 @@ class MaskGenerationPipeline(ChunkPipeline): See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation). """ + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = False + def __init__(self, **kwargs): super().__init__(**kwargs) requires_backends(self, "vision") diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py index 868b9ad2ea77..284c9e19079e 100644 --- a/src/transformers/pipelines/object_detection.py +++ b/src/transformers/pipelines/object_detection.py @@ -48,6 +48,11 @@ class ObjectDetectionPipeline(Pipeline): See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=object-detection). """ + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = None + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 12a0d09a00ea..2fe92d747e81 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -156,6 +156,11 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): supplied arguments. 
""" + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True + def normalize(self, item): if isinstance(item, SquadExample): return item diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py index 4e941c4b6edb..54a65ad77f70 100644 --- a/src/transformers/pipelines/table_question_answering.py +++ b/src/transformers/pipelines/table_question_answering.py @@ -122,6 +122,10 @@ class TableQuestionAnsweringPipeline(Pipeline): default_input_names = "table,query" _pipeline_calls_generate = True + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( max_new_tokens=256, diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 8800a91bbb49..c9333b760da6 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -68,6 +68,10 @@ class Text2TextGenerationPipeline(Pipeline): ```""" _pipeline_calls_generate = True + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True # Make sure the docstring is updated when the default generation config is changed (in all pipelines in this file) _default_generation_config = GenerationConfig( max_new_tokens=256, diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index 4ea0fc8b658d..367d867d0e80 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -78,6 +78,11 @@ class TextClassificationPipeline(Pipeline): [huggingface.co/models](https://huggingface.co/models?filter=text-classification). 
""" + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True + return_all_scores = False function_to_apply = ClassificationFunction.NONE diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 3de9a87c0f24..d7eb3b54f8f4 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -105,6 +105,11 @@ class TextGenerationPipeline(Pipeline): """ _pipeline_calls_generate = True + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True + # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( max_new_tokens=256, diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 79b0e9b35f3a..2c591914f0be 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -81,6 +81,11 @@ class TextToAudioPipeline(Pipeline): """ _pipeline_calls_generate = True + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True + # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( max_new_tokens=256, diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index e976ee9b1874..de6107c56b31 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -133,6 +133,11 @@ class TokenClassificationPipeline(ChunkPipeline): default_input_names = "sequences" + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True + def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs): super().__init__(*args, **kwargs) self.check_model_type( diff --git a/src/transformers/pipelines/video_classification.py b/src/transformers/pipelines/video_classification.py index 761e21512875..0d601ca8f1e5 100644 --- a/src/transformers/pipelines/video_classification.py +++ b/src/transformers/pipelines/video_classification.py @@ -51,6 +51,11 @@ class VideoClassificationPipeline(Pipeline): [huggingface.co/models](https://huggingface.co/models?filter=video-classification). """ + _load_processor = True + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) requires_backends(self, "av") diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py index 1df9986ce4db..9be1b97ccd57 100644 --- a/src/transformers/pipelines/visual_question_answering.py +++ b/src/transformers/pipelines/visual_question_answering.py @@ -57,6 +57,11 @@ class VisualQuestionAnsweringPipeline(Pipeline): [huggingface.co/models](https://huggingface.co/models?filter=visual-question-answering). 
""" + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = True + _pipeline_calls_generate = True # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( diff --git a/src/transformers/pipelines/zero_shot_audio_classification.py b/src/transformers/pipelines/zero_shot_audio_classification.py index f554dfb5fc0d..ed988f70b120 100644 --- a/src/transformers/pipelines/zero_shot_audio_classification.py +++ b/src/transformers/pipelines/zero_shot_audio_classification.py @@ -60,6 +60,11 @@ class ZeroShotAudioClassificationPipeline(Pipeline): [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-audio-classification). """ + _load_processor = False + _load_image_processor = False + _load_feature_extractor = True + _load_tokenizer = True + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index c23f1c544cb7..b571a7896b72 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -82,6 +82,11 @@ class ZeroShotClassificationPipeline(ChunkPipeline): of available models on [huggingface.co/models](https://huggingface.co/models?search=nli). """ + _load_processor = False + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = True + def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs): self._args_parser = args_parser super().__init__(*args, **kwargs) diff --git a/src/transformers/pipelines/zero_shot_image_classification.py b/src/transformers/pipelines/zero_shot_image_classification.py index 2629185ba314..cb3c06fd238f 100644 --- a/src/transformers/pipelines/zero_shot_image_classification.py +++ b/src/transformers/pipelines/zero_shot_image_classification.py @@ -64,6 +64,11 @@ class ZeroShotImageClassificationPipeline(Pipeline): [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-image-classification). """ + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = True + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py index 024b90356319..46426fed3923 100644 --- a/src/transformers/pipelines/zero_shot_object_detection.py +++ b/src/transformers/pipelines/zero_shot_object_detection.py @@ -53,6 +53,11 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline): [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection). """ + _load_processor = False + _load_image_processor = True + _load_feature_extractor = False + _load_tokenizer = True + def __init__(self, **kwargs): super().__init__(**kwargs)