Description
What happened?
I noticed that while the vLLM model handler expects inference args to be available to it, in practice this is not possible: it does not override the base ModelHandler validate_inference_args function, whose default behavior rejects any inference args. See:
https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L215
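For illustration, here is a minimal stand-in (not the actual Beam source) showing the default behavior that validate_inference_args has in base.py: any non-empty inference_args dict causes a ValueError before run_inference is ever called.

```python
from typing import Any, Dict, Optional


class ModelHandler:
    """Minimal stand-in for apache_beam.ml.inference.base.ModelHandler,
    sketching the default validate_inference_args behavior."""

    def validate_inference_args(
            self, inference_args: Optional[Dict[str, Any]]) -> None:
        # Default: reject any extra arguments. A handler that wants
        # inference_args must override this method, otherwise the
        # pipeline fails during validation.
        if inference_args:
            raise ValueError(
                'inference_args were provided, but this handler does not '
                'expect extra arguments.')


handler = ModelHandler()
handler.validate_inference_args(None)  # fine
handler.validate_inference_args({})    # fine, empty dict is falsy
try:
    handler.validate_inference_args({'temperature': 0.2})
except ValueError as e:
    print('rejected:', e)
```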
Given this, I did an audit of all our model handlers and found that most are not handling this correctly. Here are the results:
Currently handling inference args correctly by overriding validate_inference_args:
- vertex_ai_inference.py - passed through here among other places. Never dropped silently. validate_inference_args is overridden to silently pass
- tensorflow_inference.py - passed through here among other places. Never dropped silently. validate_inference_args is overridden to silently pass in multiple places
- pytorch_inference.py - passed through here among other places. Never dropped silently. validate_inference_args is overridden to silently pass in multiple places
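The pattern these three handlers follow can be sketched as below. This is illustrative code, not the actual handler source: the base class here mimics the strict default, and TorchLikeHandler mimics a handler that overrides validation to silently pass and forwards the args as kwargs to the model call.

```python
from typing import Any, Dict, Iterable, Optional


class BaseHandler:
    """Stand-in for the Beam base ModelHandler: default rejects extra args."""

    def validate_inference_args(
            self, inference_args: Optional[Dict[str, Any]]) -> None:
        if inference_args:
            raise ValueError('inference_args not supported')


class TorchLikeHandler(BaseHandler):
    """Illustrates the pattern used by e.g. pytorch_inference.py:
    override validation to silently pass, then consume the args."""

    def validate_inference_args(
            self, inference_args: Optional[Dict[str, Any]]) -> None:
        pass  # every inference arg is accepted and forwarded

    def run_inference(self, batch: Iterable[Any], model,
                      inference_args: Optional[Dict[str, Any]] = None):
        inference_args = inference_args or {}
        # `model` is any callable here; real handlers invoke the framework
        # model with these extra keyword arguments.
        return [model(x, **inference_args) for x in batch]


handler = TorchLikeHandler()
handler.validate_inference_args({'scale': 2})  # no error
print(handler.run_inference(
    [1, 2, 3], lambda x, scale=1: x * scale, {'scale': 2}))
```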
Currently trying to consume inference args, but not overriding validate_inference_args:
- gemini_inference.py - passed through here and here. Never dropped silently. validate_inference_args not overridden
- huggingface_inference.py - passed through here and here among other places. Never dropped silently. validate_inference_args not overridden
- onnx_inference.py - passed through here. Intentionally added here. Never dropped silently. validate_inference_args not overridden
- vllm_inference.py - passed through here among other places. Never dropped silently. validate_inference_args not overridden
- xgboost_inference.py - passed through here among other places. Never dropped silently. validate_inference_args not overridden
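The failure mode shared by the handlers above can be sketched as follows (again illustrative, not handler source): run_inference is written to consume inference_args, but because validate_inference_args is not overridden, validation raises first and the forwarding code is unreachable whenever args are supplied.

```python
from typing import Any, Dict, Iterable, Optional


class BaseHandler:
    """Stand-in for the Beam base ModelHandler: default rejects extra args."""

    def validate_inference_args(
            self, inference_args: Optional[Dict[str, Any]]) -> None:
        if inference_args:
            raise ValueError('inference_args not supported')


class VllmLikeHandler(BaseHandler):
    """Illustrates the bug: run_inference is ready to use inference_args,
    but the strict default validation is inherited unchanged."""

    def run_inference(self, batch: Iterable[Any], model,
                      inference_args: Optional[Dict[str, Any]] = None):
        inference_args = inference_args or {}
        return [model(x, **inference_args) for x in batch]


handler = VllmLikeHandler()
try:
    # RunInference calls validate_inference_args before run_inference,
    # so this raises and the args never reach the model.
    handler.validate_inference_args({'max_tokens': 16})
except ValueError:
    print('rejected before run_inference could use the args')
```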
Currently not handling inference args, but should:
- sklearn_inference.py - passed through here among other places. Silently dropped here and here (probably wrongly). validate_inference_args not overridden.
Currently not handling inference args correctly:
- tensorrt_inference.py - silently dropped here. This may be correct; inference args are hard to map to any meaning here.
The underlying rationale for disallowing inference args ("Because most frameworks do not need extra arguments in their predict() call, the default behavior is to error out if inference_args are present.") clearly does not hold in practice. Given that, we should:
- Change the base ModelHandler behavior to allow inference args
- Add the validation to model handlers that need it. At this time only tensorrt_inference.py needs it, and sklearn_inference.py needs to be updated to correctly consume the inference args.
- Mention this in CHANGES.md since it is a behavior change (though not breaking)
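A sketch of what the first two bullets would look like (illustrative only; class and message names are hypothetical, not the proposed patch): the base default flips to accepting inference_args, and handlers that genuinely cannot use them, like tensorrt_inference.py, opt back in to strict validation.

```python
from typing import Any, Dict, Optional


class ModelHandler:
    """Sketch of the proposed base default: accept inference_args."""

    def validate_inference_args(
            self, inference_args: Optional[Dict[str, Any]]) -> None:
        pass  # inference_args are allowed by default


class TensorRTLikeHandler(ModelHandler):
    """Sketch of a handler that cannot map inference_args to anything
    meaningful, so it restores the strict check."""

    def validate_inference_args(
            self, inference_args: Optional[Dict[str, Any]]) -> None:
        if inference_args:
            raise ValueError(
                'this handler does not support inference_args')


ModelHandler().validate_inference_args({'anything': 1})  # now allowed
try:
    TensorRTLikeHandler().validate_inference_args({'anything': 1})
except ValueError:
    print('still rejected where it should be')
```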
Issue Priority
Priority: 2 (default / most bugs should be filed as P2)
Issue Components
- Component: Python SDK
- Component: Java SDK
- Component: Go SDK
- Component: Typescript SDK
- Component: IO connector
- Component: Beam YAML
- Component: Beam examples
- Component: Beam playground
- Component: Beam katas
- Component: Website
- Component: Infrastructure
- Component: Spark Runner
- Component: Flink Runner
- Component: Samza Runner
- Component: Twister2 Runner
- Component: Hazelcast Jet Runner
- Component: Google Cloud Dataflow Runner