Skip to content

Commit

Permalink
replace oneshot_device with device_map
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs committed Feb 19, 2025
1 parent 4a34d0c commit cd893e5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 30 deletions.
14 changes: 11 additions & 3 deletions src/llmcompressor/args/model_arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import Dict, Optional, Union

import torch


@dataclass
Expand Down Expand Up @@ -80,8 +82,11 @@ class ModelArguments:
metadata={"help": "Whether to compress sparse models during save"},
)
oneshot_device: Optional[str] = field(
default="cuda:0",
metadata={"help": "Device to run oneshot calibration on"},
default=None,
metadata={
"help": "This field is deprecated, please use `device_map` instead. "
"Device to run oneshot calibration on"
},
)
model_revision: str = field(
default="main",
Expand All @@ -90,3 +95,6 @@ class ModelArguments:
"(can be a branch name, tag name or commit id)"
},
)
device_map: Dict[str, Union[int, str, torch.device]] = field(
default="auto", metadata={}
)
17 changes: 0 additions & 17 deletions src/llmcompressor/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"initialize_recipe",
"save_model_and_recipe",
"copy_python_files_from_model_cache",
"fallback_to_cpu",
"parse_dtype",
"get_session_model",
"get_completed_stages",
Expand Down Expand Up @@ -87,22 +86,6 @@ def save_model_and_recipe(
copy_python_files_from_model_cache(model, save_path)


def fallback_to_cpu(device: str) -> str:
    """
    Return the requested device, substituting "cpu" when a CUDA device is
    requested but CUDA is not available.

    :param device: device string to check, e.g. "cuda:0" or "cpu"
    :return: "cpu" if `device` contains "cuda" and `torch.cuda.is_available()`
        is False, otherwise `device` unchanged
    """
    # Only downgrade when a CUDA device was explicitly requested; any other
    # device string is passed through untouched.
    if "cuda" in device and not torch.cuda.is_available():
        logger.warning(
            f"Requested {device} but CUDA is not available, falling back to CPU"
        )
        return "cpu"

    return device


def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype:
"""
:param dtype_arg: dtype or string to parse
Expand Down
22 changes: 12 additions & 10 deletions src/llmcompressor/transformers/finetune/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
)
from llmcompressor.core import pre_initialize_structure, reset_session
from llmcompressor.pytorch.model_load.helpers import (
fallback_to_cpu,
get_session_model,
initialize_recipe,
parse_dtype,
Expand Down Expand Up @@ -150,10 +149,18 @@ def parse_args(**kwargs):
# raise deprecation warnings
if data_args.remove_columns is not None:
warnings.warn(
"`remove_columns` argument is depreciated. When tokenizing datasets, all "
"`remove_columns` argument is deprecated. When tokenizing datasets, all "
"columns which are invalid inputs the tokenizer will be removed",
DeprecationWarning,
)
if model_args.oneshot_device is not None:
warnings.warn(
"`oneshot_device` argument is deprecated. Please use `device_map` instead",
DeprecationWarning,
)
device_map_default = ModelArguments.__dataclass_fields__["device_map"].default
if model_args.device_map == device_map_default:
model_args.device_map = model_args.oneshot_device

# silently assign tokenizer to processor
if model_args.tokenizer:
Expand Down Expand Up @@ -233,25 +240,20 @@ def initialize_model_from_path(
else model_args.model_name_or_path
)

# Fallback to CPU if GPU requested and not available
model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)

# Trainer handles device assignment for FSDP and training, don't do mapping here
# if running oneshot outside of FSDP, apply user device settings

fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"

device_map = model_args.oneshot_device
if not fsdp_enabled and training_args is not None and training_args.do_train:
device_map = "auto"
logger.log("Detected FSDP training, setting device map to `auto`")
model_args.device_map = "auto"

model_kwargs = {
"config": config,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"torch_dtype": parse_dtype(model_args.precision),
"device_map": device_map,
"device_map": model_args.device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}

Expand Down

0 comments on commit cd893e5

Please sign in to comment.