[Bug]: Model Handlers not handling inference args correctly #37093

@damccorm

Description

What happened?

I noticed that while the vLLM model handler expects inference args to be available to it, in practice this is impossible because it does not override the base ModelHandler validate_inference_args function, which rejects any inference args. See:

https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L215
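
For reference, the base class's check behaves roughly like this (a simplified, illustrative sketch of the pattern, not a verbatim copy of the Beam source):

```python
from typing import Any, Dict, Optional


class ModelHandler:
    """Simplified sketch of the base ModelHandler's inference-args check."""

    def validate_inference_args(
        self, inference_args: Optional[Dict[str, Any]]) -> None:
        # The base class assumes most frameworks take no extra arguments in
        # their predict() call, so any non-empty inference_args is an error.
        if inference_args:
            raise ValueError(
                'inference_args were provided, but should be None because '
                'this framework does not expect extra arguments on '
                'prediction calls.')
```

Any handler that wants to accept inference args must therefore override this method; otherwise the pipeline fails before the args ever reach the framework.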

Given this, I did an audit of all our model handlers and found that most are not handling this correctly. Here are the results:

Currently handling inference args correctly by overriding validate_inference_args:

  • vertex_ai_inference.py - passed through here among other places. Never dropped silently. validate_inference_args is overridden to silently pass
  • tensorflow_inference.py - passed through here among other places. Never dropped silently. validate_inference_args is overridden to silently pass in multiple places
  • pytorch_inference.py - passed through here among other places. Never dropped silently. validate_inference_args is overridden to silently pass in multiple places
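
The overrides in the handlers above follow a simple no-op pattern, along these lines (a hedged sketch; `BaseModelHandler` and `FrameworkModelHandler` are stand-in names, not actual Beam classes):

```python
from typing import Any, Dict, Optional


class BaseModelHandler:
    """Stand-in for the base ModelHandler, which rejects inference args."""

    def validate_inference_args(
        self, inference_args: Optional[Dict[str, Any]]) -> None:
        if inference_args:
            raise ValueError('inference_args not supported by this handler.')


class FrameworkModelHandler(BaseModelHandler):
    """Stand-in for a handler whose framework accepts extra predict kwargs."""

    def validate_inference_args(
        self, inference_args: Optional[Dict[str, Any]]) -> None:
        # Silently accept any inference_args; they are forwarded to the
        # framework's predict() call during run_inference.
        pass
```

The handlers in the next group forward inference args the same way but skip this override, so the base class rejects the args before they can be used.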

Currently trying to consume inference args, but not overriding validate_inference_args:

  • gemini_inference.py - passed through here and here. Never dropped silently. validate_inference_args not overridden
  • huggingface_inference.py - passed through here and here among other places. Never dropped silently. validate_inference_args not overridden
  • onnx_inference.py - passed through here. Intentionally added to here. Never dropped silently. validate_inference_args not overridden
  • vllm_inference.py - passed through here among other places. Never dropped silently. validate_inference_args not overridden
  • xgboost_inference.py - passed through here among other places. Never dropped silently. validate_inference_args not overridden

Currently not handling inference args, but should:

  • sklearn_inference.py - passed through here among other places. Silently dropped here and here (probably wrongly); validate_inference_args not overridden.

Currently not handling inference args correctly:

  • tensorrt_inference.py - silently dropped here. This may be correct; inference args are hard to map to anything meaningful here.

The underlying rationale for disallowing inference args ("Because most frameworks do not need extra arguments in their predict() call, the default behavior is to error out if inference_args are present.") clearly does not hold in practice. Given that, we should:

  1. Change the base ModelHandler behavior to allow inference args.
  2. Add validation to the model handlers that need it. At this time only tensorrt_inference.py needs it, and sklearn_inference.py needs to be updated to correctly consume the inference args.
  3. Mention this in CHANGES.md, since it is a behavior change (though not a breaking one).
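
Steps 1 and 2 could look roughly like this (a sketch of the proposed direction, under the assumptions above; class names are illustrative, not the actual Beam implementation):

```python
from typing import Any, Dict, Optional


class ModelHandler:
    """Sketch of the proposed base behavior: allow inference args."""

    def validate_inference_args(
        self, inference_args: Optional[Dict[str, Any]]) -> None:
        # Proposed default: accept inference_args and let each subclass
        # decide how (or whether) to consume them.
        pass


class TensorRTModelHandler(ModelHandler):
    """Sketch: a handler that cannot use inference args keeps the strict
    check by overriding the now-permissive base method."""

    def validate_inference_args(
        self, inference_args: Optional[Dict[str, Any]]) -> None:
        if inference_args:
            raise ValueError(
                'This handler does not accept inference_args.')
```

This inverts the current default: instead of every args-consuming handler needing an override, only the handlers that must reject args carry one.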

Issue Priority

Priority: 2 (default / most bugs should be filed as P2)

Issue Components

  • Component: Python SDK
  • Component: Java SDK
  • Component: Go SDK
  • Component: Typescript SDK
  • Component: IO connector
  • Component: Beam YAML
  • Component: Beam examples
  • Component: Beam playground
  • Component: Beam katas
  • Component: Website
  • Component: Infrastructure
  • Component: Spark Runner
  • Component: Flink Runner
  • Component: Samza Runner
  • Component: Twister2 Runner
  • Component: Hazelcast Jet Runner
  • Component: Google Cloud Dataflow Runner
