
Commit b06ef8f

New Function: sample_balanced_dataset() in dataset.py
```python
sample_balanced_dataset(
    data,
    sample_size=0.1,                 # 10% of data (or absolute count like 1000)
    balance_classes=True,            # Enable class balancing
    balance_strategy="undersample",  # Strategy for balancing
    seed=42,
)
```

Three balancing strategies:

| Strategy | Description | Use Case |
|--------------|-------------------------------------|---------------------------------------------|
| undersample | Caps all classes to equal counts | Maximum class balance |
| sqrt | Uses √(count) weighting | Gentler balancing, keeps more majority data |
| proportional | Maintains ratios with min guarantee | Light balancing with minimum representation |

Config parameters (added to `DataConfig`):

- `sample_size`: Float (0-1) for percentage or int for absolute count
- `balance_classes`: Enable/disable class balancing
- `balance_strategy`: Choose balancing algorithm

CLI arguments:

```bash
# Sample 10% with balanced classes
python scripts/train.py --config configs/cpu_training.yaml \
    --sample-size 0.1 --balance-classes

# Sample 500 items with sqrt balancing
python scripts/train.py --config configs/cpu_training.yaml \
    --sample-size 500 --balance-classes --balance-strategy sqrt
```

Example results:

```
Original:    6997 samples - {1: 39, 2: 69, ... 7: 1741, ...}
5% balanced:  349 samples - {1: 38, 2: 38, 3: 38, 4: 40, 5: 38, 6: 39, 7: 41, 8: 39, 9: 38}
```

The class weights become nearly equal (~1.0) after balancing, which can help training on imbalanced datasets.
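As a rough illustration of how the three strategies divide a sampling budget, here is a standalone arithmetic sketch. The toy counts are loosely modeled on the example distribution above; this is not the repo's implementation, just the per-class target math each strategy implies.

```python
import math

# Toy imbalanced distribution (class -> count), loosely modeled on the
# example above; the budget is a 5% sample of the total.
counts = {1: 39, 2: 69, 7: 1741}
total = sum(counts.values())   # 1849
budget = int(total * 0.05)     # 92

# undersample: split the budget evenly, capped by the smallest class
per_class = min(budget // len(counts), min(counts.values()))
undersample = {c: min(per_class, n) for c, n in counts.items()}

# sqrt: allocate the budget proportionally to sqrt(count)
sqrt_w = {c: math.sqrt(n) for c, n in counts.items()}
w_total = sum(sqrt_w.values())
sqrt_targets = {c: min(int(budget * w / w_total), counts[c]) for c, w in sqrt_w.items()}

# proportional: keep original ratios, with a floor of one sample per class
prop = {c: max(1, min(int(budget * n / total), n)) for c, n in counts.items()}

print(undersample)    # {1: 30, 2: 30, 7: 30}
print(sqrt_targets)   # {1: 10, 2: 13, 7: 68}
print(prop)           # {1: 1, 2: 3, 7: 86}
```

Note how `sqrt` sits between the other two: the majority class keeps most of the budget, but far less than its raw share.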
1 parent 50acb8f commit b06ef8f

7 files changed

Lines changed: 333 additions & 5 deletions


.github/workflows/tests.yml

Lines changed: 10 additions & 0 deletions
```diff
@@ -32,6 +32,16 @@ jobs:
         python-version: ['3.12']

     steps:
+      - name: Free up disk space
+        if: runner.os == 'Linux'
+        run: |
+          sudo apt-get clean
+          sudo apt-get autoclean
+          sudo apt-get autoremove -y
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf /usr/local/lib/android
+          rm -rf ~/.cache/pip ~/.cache/pipenv ~/.npm
+          df -h
       - name: Checkout repository
         uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
```

beauty_scorer/config.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -129,6 +129,25 @@ class DataConfig(BaseModel):
         default=(0.229, 0.224, 0.225),
         description="Normalization std (ImageNet default)",
     )
+    # Sampling configuration
+    sample_size: float | int | None = Field(
+        default=None,
+        description=(
+            "Dataset sample size. Float (0-1) = fraction of data, "
+            "int >= 1 = absolute count, None = use all data"
+        ),
+    )
+    balance_classes: bool = Field(
+        default=False,
+        description="Balance class distribution when sampling",
+    )
+    balance_strategy: Literal["undersample", "sqrt", "proportional"] = Field(
+        default="undersample",
+        description=(
+            "Strategy for balancing: 'undersample' caps each class equally, "
+            "'sqrt' uses square root weighting, 'proportional' maintains ratios"
+        ),
+    )

     @field_validator("image_size", "face_size", mode="before")
     @classmethod
```
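The `sample_size` contract documented above can be mirrored by a small resolver. `resolve_sample_size` below is a hypothetical helper, not part of the diff, sketching how a float fraction, an absolute count, or `None` might map onto a concrete sample count:

```python
def resolve_sample_size(sample_size, total: int) -> int:
    """Hypothetical helper mirroring the documented DataConfig.sample_size
    contract: float in (0, 1] is a fraction of `total`, int >= 1 is an
    absolute count (capped at `total`), None means use all data."""
    if sample_size is None:
        return total
    if isinstance(sample_size, float) and 0 < sample_size <= 1.0:
        return max(1, int(total * sample_size))
    if isinstance(sample_size, (int, float)) and sample_size >= 1:
        return max(1, min(int(sample_size), total))
    raise ValueError(f"invalid sample_size: {sample_size!r}")

print(resolve_sample_size(0.1, 6997))   # 699
print(resolve_sample_size(500, 6997))   # 500
print(resolve_sample_size(None, 6997))  # 6997
```

The float branch is checked first, so `1.0` reads as "100% of the data" rather than "one sample", matching the branching order in the new `sample_balanced_dataset()`.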

beauty_scorer/data/dataset.py

Lines changed: 229 additions & 0 deletions
```diff
@@ -412,3 +412,232 @@ def compute_class_weights(
     weights = torch.clamp(weights, max=max_weight)

     return weights
+
+
+def sample_balanced_dataset(
+    data: list[dict],
+    sample_size: int | float | None = None,
+    balance_classes: bool = True,
+    balance_strategy: str = "undersample",
+    min_samples_per_class: int = 1,
+    seed: int = 42,
+) -> list[dict]:
+    """
+    Sample a subset of the dataset with optional class balancing.
+
+    This function is useful for:
+    - Quick experimentation with smaller datasets
+    - Reducing training time while maintaining class representation
+    - Handling class imbalance by undersampling majority classes
+    - Creating balanced mini-datasets for debugging or prototyping
+
+    Args:
+        data: Full dataset list with 'score' key for each item.
+        sample_size: Target sample size.
+            - If float in (0.0, 1.0]: fraction of total data (e.g., 0.1 = 10%)
+            - If int >= 1: absolute number of samples
+            - If None: use all data (only balance if balance_classes=True)
+        balance_classes: If True, attempts to balance class distribution.
+            Each class will have roughly equal representation, limited by
+            the smallest class size or the balance_strategy.
+        balance_strategy: Strategy for balancing classes.
+            - "undersample": Cap each class to target_per_class samples.
+              Ensures balanced classes but may lose data from majority classes.
+            - "sqrt": Use square root of original counts as weights. Reduces
+              imbalance while preserving more majority class data.
+            - "proportional": Maintain original distribution ratios but with
+              guaranteed minimum representation per class.
+        min_samples_per_class: Minimum samples to keep per class when possible.
+            Ensures very small classes aren't completely dropped.
+        seed: Random seed for reproducibility.
+
+    Returns:
+        Sampled dataset list.
+
+    Example:
+        >>> # Sample 10% of data with balanced classes
+        >>> sampled = sample_balanced_dataset(data, sample_size=0.1, balance_classes=True)
+
+        >>> # Sample exactly 1000 items with balanced classes
+        >>> sampled = sample_balanced_dataset(data, sample_size=1000, balance_classes=True)
+
+        >>> # Sample 20% maintaining original class distribution
+        >>> sampled = sample_balanced_dataset(data, sample_size=0.2, balance_classes=False)
+
+        >>> # Just balance classes without reducing total size
+        >>> sampled = sample_balanced_dataset(data, sample_size=None, balance_classes=True)
+
+        >>> # Use sqrt balancing for gentler rebalancing
+        >>> sampled = sample_balanced_dataset(
+        ...     data, sample_size=0.5, balance_classes=True, balance_strategy="sqrt"
+        ... )
+    """
+    if not data:
+        return []
+
+    random.seed(seed)
+
+    # Group data by class (score)
+    class_groups: dict[int, list[dict]] = {}
+    for item in data:
+        score = item["score"]
+        if score not in class_groups:
+            class_groups[score] = []
+        class_groups[score].append(item)
+
+    num_classes = len(class_groups)
+    total_data = len(data)
+
+    # Determine target total samples
+    if sample_size is None:
+        target_total = total_data
+    elif isinstance(sample_size, float) and 0 < sample_size <= 1.0:
+        target_total = max(1, int(total_data * sample_size))
+    elif isinstance(sample_size, (int, float)) and sample_size >= 1:
+        target_total = max(1, min(int(sample_size), total_data))
+    else:
+        raise ValueError(
+            f"sample_size must be float in (0, 1], int >= 1, or None. Got: {sample_size}"
+        )
+
+    # Log class distribution before sampling
+    class_counts = {cls: len(items) for cls, items in sorted(class_groups.items())}
+    logger.debug(f"Original class distribution: {class_counts}")
+
+    if not balance_classes:
+        # Simple random sampling without balancing
+        if target_total >= total_data:
+            sampled = data.copy()
+        else:
+            sampled = random.sample(data, target_total)
+        random.shuffle(sampled)
+        logger.info(
+            f"Sampled {len(sampled)} items without balancing "
+            f"({len(sampled)/total_data*100:.1f}% of {total_data})"
+        )
+        return sampled
+
+    # Balanced sampling
+    sampled: list[dict] = []
+
+    if balance_strategy == "undersample":
+        # Find the smallest class size for true balancing
+        min_class_size = min(len(items) for items in class_groups.values())
+
+        # Determine target per class based on mode:
+        # - If sample_size specified: balance within the budget (target_total / num_classes)
+        # - If sample_size=None (balance only): use min class size for true balancing
+        if sample_size is None:
+            # Balance-only mode: undersample all classes to match smallest
+            target_per_class = max(min_samples_per_class, min_class_size)
+            # Also cap total to balanced amount
+            target_total = target_per_class * num_classes
+        else:
+            # Size-limited mode: distribute budget evenly
+            ideal_per_class = target_total // num_classes
+            target_per_class = max(min_samples_per_class, min(ideal_per_class, min_class_size))
+
+        # Sample up to target_per_class from each class
+        remaining_quota = target_total
+        sorted_classes = sorted(class_groups.keys())
+
+        for cls in sorted_classes:
+            items = class_groups[cls]
+            n_available = len(items)
+
+            # Take min of target and available
+            n_samples = min(target_per_class, n_available, remaining_quota)
+            n_samples = max(n_samples, min(min_samples_per_class, n_available))
+
+            if n_samples > 0:
+                sampled.extend(random.sample(items, n_samples))
+                remaining_quota -= n_samples
+
+        # Second pass: if we have remaining quota (only when sample_size was specified),
+        # fill from larger classes to hit target
+        if remaining_quota > 0 and sample_size is not None:
+            sampled_ids = {item["id"] for item in sampled}
+            remaining_items = []
+            for items in class_groups.values():
+                for item in items:
+                    if item["id"] not in sampled_ids:
+                        remaining_items.append(item)
+
+            if remaining_items:
+                extra = random.sample(remaining_items, min(remaining_quota, len(remaining_items)))
+                sampled.extend(extra)
+
+    elif balance_strategy == "sqrt":
+        # Use square root of counts to determine sampling weights
+        # This reduces imbalance while preserving more majority class data
+        sqrt_counts = {cls: np.sqrt(len(items)) for cls, items in class_groups.items()}
+        total_sqrt = sum(sqrt_counts.values())
+
+        # Calculate target samples per class based on sqrt weights
+        for cls in sorted(class_groups.keys()):
+            items = class_groups[cls]
+            weight = sqrt_counts[cls] / total_sqrt
+            n_samples = max(
+                min_samples_per_class,
+                min(int(target_total * weight), len(items)),
+            )
+            sampled.extend(random.sample(items, n_samples))
+
+        # Adjust to hit target
+        if len(sampled) > target_total:
+            sampled = random.sample(sampled, target_total)
+
+    elif balance_strategy == "proportional":
+        # Maintain original distribution but ensure minimum representation
+        for cls in sorted(class_groups.keys()):
+            items = class_groups[cls]
+            proportion = len(items) / total_data
+            n_samples = max(
+                min_samples_per_class,
+                min(int(target_total * proportion), len(items)),
+            )
+            sampled.extend(random.sample(items, n_samples))
+
+        # Adjust to hit target
+        if len(sampled) > target_total:
+            sampled = random.sample(sampled, target_total)
+
+    else:
+        raise ValueError(
+            f"Unknown balance_strategy: {balance_strategy}. "
+            f"Choose from: 'undersample', 'sqrt', 'proportional'"
+        )
+
+    random.shuffle(sampled)
+
+    # Log resulting distribution
+    result_counts: dict[int, int] = {}
+    for item in sampled:
+        score = item["score"]
+        result_counts[score] = result_counts.get(score, 0) + 1
+    result_counts = dict(sorted(result_counts.items()))
+
+    logger.info(
+        f"Sampled {len(sampled)} items with '{balance_strategy}' balancing "
+        f"({len(sampled)/total_data*100:.1f}% of {total_data})"
+    )
+    logger.debug(f"Balanced class distribution: {result_counts}")
+
+    return sampled
+
+
+def get_class_distribution(data: list[dict]) -> dict[int, int]:
+    """
+    Get the class distribution of a dataset.
+
+    Args:
+        data: Dataset list with 'score' key.
+
+    Returns:
+        Dictionary mapping score to count.
+    """
+    distribution: dict[int, int] = {}
+    for item in data:
+        score = item["score"]
+        distribution[score] = distribution.get(score, 0) + 1
+    return dict(sorted(distribution.items()))
```
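To see the "undersample" strategy end to end on synthetic data, here is a minimal standalone re-implementation of just that path. It is illustrative only: the real function above also handles sample-size budgets, quotas, and the other two strategies, and seeds the module-level RNG instead of a local one.

```python
import random

def undersample_balanced(data, seed=42):
    """Minimal sketch of the 'undersample' strategy: cap every class
    at the smallest class's size, then shuffle."""
    rng = random.Random(seed)
    groups = {}
    for item in data:
        groups.setdefault(item["score"], []).append(item)
    target = min(len(g) for g in groups.values())
    sampled = []
    for cls in sorted(groups):
        sampled.extend(rng.sample(groups[cls], target))
    rng.shuffle(sampled)
    return sampled

# Synthetic 3-class dataset: 5 / 20 / 100 items
data = [{"id": i, "score": s} for s, n in [(1, 5), (2, 20), (3, 100)] for i in range(n)]
balanced = undersample_balanced(data)
print(len(balanced))  # 15: every class capped at the smallest class (5)
```

Using a local `random.Random(seed)` also makes the sketch reproducible without touching global RNG state, which is a trade-off worth noting against the diff's `random.seed(seed)` call.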

configs/advanced_model.yaml

Lines changed: 8 additions & 4 deletions
```diff
@@ -38,7 +38,7 @@ training:
   pin_memory: true

 data:
-  image_size: [384, 384]  # Higher resolution for ViT
+  image_size: [448, 448]  # Higher resolution for ViT
   face_size: [224, 224]
   max_photos: 9
   augmentation: true
@@ -48,16 +48,20 @@ data:
   face_confidence_threshold: 0.5
   normalize_mean: [0.485, 0.456, 0.406]
   normalize_std: [0.229, 0.224, 0.225]
+  # Dataset sampling (optional - for quick experiments or class balancing)
+  # sample_size: 0.1  # Use 10% of data (float 0-1) or absolute count (int)
+  # balance_classes: true  # Balance class distribution when sampling
+  # balance_strategy: undersample  # 'undersample', 'sqrt', or 'proportional'

 paths:
-  dataset_file: datasets/beauty_dataset.csv
-  dataset_base_path: datasets/images
+  dataset_file: data/dataset.csv
+  dataset_base_path: /opt/SP/DATA/beauty_dataset
   output_dir: outputs/advanced
   checkpoint_path:
   export_dir: exports

 logging:
-  log_level: INFO
+  log_level: DEBUG
   log_dir: logs/advanced
   use_tensorboard: true
   use_wandb: true
```

configs/cpu_training.yaml

Lines changed: 4 additions & 0 deletions
```diff
@@ -48,6 +48,10 @@ data:
   face_confidence_threshold: 0.5
   normalize_mean: [0.485, 0.456, 0.406]
   normalize_std: [0.229, 0.224, 0.225]
+  # Dataset sampling (optional - for quick experiments or class balancing)
+  # sample_size: 0.1  # Use 10% of data (float 0-1) or absolute count (int)
+  # balance_classes: true  # Balance class distribution when sampling
+  # balance_strategy: undersample  # 'undersample', 'sqrt', or 'proportional'

 paths:
   dataset_file: data/dataset.csv
```

configs/default.yaml

Lines changed: 5 additions & 1 deletion
```diff
@@ -48,9 +48,13 @@ data:
   face_confidence_threshold: 0.3
   normalize_mean: [0.485, 0.456, 0.406]
   normalize_std: [0.229, 0.224, 0.225]
+  # Dataset sampling (optional - for quick experiments or class balancing)
+  # sample_size: 0.1  # Use 10% of data (float 0-1) or absolute count (int)
+  # balance_classes: true  # Balance class distribution when sampling
+  # balance_strategy: undersample  # 'undersample', 'sqrt', or 'proportional'

 paths:
-  dataset_file: datasets/beauty_dataset.csv
+  dataset_file: data/dataset.csv
   dataset_base_path: /opt/SP/DATA/beauty_dataset
   output_dir: outputs
   checkpoint_path:
```
