Skip to content

Commit b9fd855

Browse files
committed
Sampling: Copy over iterable overrides
If an override was iterable, any modifications to the returned value would alter the reference to the global storage dict. Therefore, copy the structure if it's an iterable so any modification won't alter the original override. Also apply this for the function that checks for forced overrides. Signed-off-by: kingbri <[email protected]>
1 parent 0e9385e commit b9fd855

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

common/sampling.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pathlib
44
import yaml
5+
from copy import deepcopy
56
from loguru import logger
67
from pydantic import AliasChoices, BaseModel, Field
78
from typing import Dict, List, Optional, Union
@@ -376,14 +377,19 @@ def get_all_presets():
376377
def get_default_sampler_value(key, fallback=None):
377378
"""Gets an overridden default sampler value"""
378379

379-
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
380+
default_value = unwrap(
381+
deepcopy(overrides_container.overrides.get(key, {}).get("override")),
382+
fallback,
383+
)
384+
385+
return default_value
380386

381387

382388
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
383389
"""Forcefully applies overrides if specified by the user"""
384390

385391
for var, value in overrides_container.overrides.items():
386-
override = value.get("override")
392+
override = deepcopy(value.get("override"))
387393
original_value = getattr(params, var, None)
388394

389395
# Force takes precedence over additive

common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ def coalesce(*args):
1515

1616

1717
def prune_dict(input_dict):
18-
"""Trim out instances of None from a dictionary"""
18+
"""Trim out instances of None from a dictionary."""
1919

2020
return {k: v for k, v in input_dict.items() if v is not None}

0 commit comments

Comments
 (0)