Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
# pick the first item of the list as best guess (it's almost always a list of length 1 anyway)
distribution_name = pkg_name if pkg_name in distributions else distributions[0]
package_version = importlib.metadata.version(distribution_name)
except importlib.metadata.PackageNotFoundError:
except (importlib.metadata.PackageNotFoundError, KeyError):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what is the issue here that requires this new KeyError?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When i run pytest -rA tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3ModelTest::test_generate_with_quant_cache, it will fail and return error:

pkg_name = 'optimum.quanto', return_version = True

    def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
        """Check if `pkg_name` exist, and optionally try to get its version"""
        spec = importlib.util.find_spec(pkg_name)
        package_exists = spec is not None
        package_version = "N/A"
        if package_exists and return_version:
            try:
                # importlib.metadata works with the distribution package, which may be different from the import
                # name (e.g. `PIL` is the import name, but `pillow` is the distribution name)
>               distributions = PACKAGE_DISTRIBUTION_MAPPING[pkg_name]
E               KeyError: 'optimum.quanto'

src/transformers/utils/import_utils.py:56: KeyError

Here is to fix this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kaixuanliu it shows 'optimum': ['optimum-quanto', 'optimum-onnx', 'optimum'], in PACKAGE_DISTRIBUTION_MAPPING, so we may need another logic to search optimum-quanto distribution in optimum package, rather than using your current logic, because the package actually there, just we find it in a wrong way in this case. @ydshieh , WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check ,thank you for explaining the issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, when we catch the keyError exception, it will load optimum-quanto in package = importlib.import_module(pkg_name), where pkg_name is optimum.quanto, and it can be imported properly.

# If we cannot find the metadata (because of editable install for example), try to import directly.
# Note that this branch will almost never be run, so we do not import packages for nothing here
package = importlib.import_module(pkg_name)
Expand Down
28 changes: 23 additions & 5 deletions tests/models/video_llama_3/test_modeling_video_llama_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_torch_available,
)
from transformers.testing_utils import (
Expectations,
backend_empty_cache,
require_flash_attn,
require_torch,
Expand Down Expand Up @@ -831,7 +832,14 @@ def test_small_model_integration_test(self):
torch.testing.assert_close(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=1e-4, rtol=1e-4)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False, repetition_penalty=None)
EXPECTED_DECODED_TEXT = "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress"
# fmt: off
EXPECTED_DECODED_TEXT = Expectations(
{
("cuda", None): "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress",
("xpu", None): "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress",
}
).get_expectation()
# fmt: on

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
Expand Down Expand Up @@ -874,11 +882,21 @@ def test_small_model_integration_test_batch_wo_image(self):

# it should not matter whether two images are the same size or not
output = model.generate(**inputs, max_new_tokens=20, do_sample=False, repetition_penalty=None)
# fmt: off
EXPECTED_DECODED_TEXT = Expectations(
{
("cuda", None): [
"user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress",
"user\nWhat is relativity?\nassistant\nRelativity is a scientific theory that describes the relationship between space and time. It was first proposed by",
],
("xpu", None): [
"user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress",
"user\nWhat is relativity?\nassistant\nRelativity is a scientific theory that describes the relationship between space and time. It was first proposed by",
],
}
).get_expectation()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above

# fmt: on

EXPECTED_DECODED_TEXT = [
"user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress",
"user\nWhat is relativity?\nassistant\nRelativity is a scientific theory that describes the relationship between space and time. It was first proposed by",
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
Expand Down