@@ -412,3 +412,232 @@ def compute_class_weights(
412412 weights = torch .clamp (weights , max = max_weight )
413413
414414 return weights
415+
416+
def sample_balanced_dataset(
    data: list[dict],
    sample_size: int | float | None = None,
    balance_classes: bool = True,
    balance_strategy: str = "undersample",
    min_samples_per_class: int = 1,
    seed: int = 42,
) -> list[dict]:
    """
    Sample a subset of the dataset with optional class balancing.

    This function is useful for:
    - Quick experimentation with smaller datasets
    - Reducing training time while maintaining class representation
    - Handling class imbalance by undersampling majority classes
    - Creating balanced mini-datasets for debugging or prototyping

    Args:
        data: Full dataset list with 'score' key for each item.
        sample_size: Target sample size.
            - If float in (0.0, 1.0]: fraction of total data (e.g., 0.1 = 10%)
            - If int >= 1: absolute number of samples
            - If None: use all data (only balance if balance_classes=True)
        balance_classes: If True, attempts to balance class distribution.
            Each class will have roughly equal representation, limited by
            the smallest class size or the balance_strategy.
        balance_strategy: Strategy for balancing classes.
            - "undersample": Cap each class to target_per_class samples.
              Ensures balanced classes but may lose data from majority classes.
            - "sqrt": Use square root of original counts as weights. Reduces
              imbalance while preserving more majority class data.
            - "proportional": Maintain original distribution ratios but with
              guaranteed minimum representation per class.
        min_samples_per_class: Minimum samples to keep per class when possible
            (always capped by the actual class size, so very small classes are
            kept whole rather than causing a sampling error).
        seed: Random seed for reproducibility. Sampling uses a private
            random.Random(seed) instance, so the module-level random state is
            left untouched.

    Returns:
        Sampled dataset list (shuffled).

    Raises:
        ValueError: If sample_size is out of range or balance_strategy is
            not one of 'undersample', 'sqrt', 'proportional'.

    Example:
        >>> # Sample 10% of data with balanced classes
        >>> sampled = sample_balanced_dataset(data, sample_size=0.1, balance_classes=True)

        >>> # Sample exactly 1000 items with balanced classes
        >>> sampled = sample_balanced_dataset(data, sample_size=1000, balance_classes=True)

        >>> # Sample 20% maintaining original class distribution
        >>> sampled = sample_balanced_dataset(data, sample_size=0.2, balance_classes=False)

        >>> # Just balance classes without reducing total size
        >>> sampled = sample_balanced_dataset(data, sample_size=None, balance_classes=True)

        >>> # Use sqrt balancing for gentler rebalancing
        >>> sampled = sample_balanced_dataset(
        ...     data, sample_size=0.5, balance_classes=True, balance_strategy="sqrt"
        ... )
    """
    if not data:
        return []

    # Dedicated RNG: random.seed(seed) would clobber the global random state
    # for every other consumer of the `random` module in this process.
    rng = random.Random(seed)

    # Group data by class (score).
    class_groups: dict[int, list[dict]] = {}
    for item in data:
        class_groups.setdefault(item["score"], []).append(item)

    num_classes = len(class_groups)
    total_data = len(data)

    # Resolve sample_size into an absolute target count.
    if sample_size is None:
        target_total = total_data
    elif isinstance(sample_size, float) and 0 < sample_size <= 1.0:
        target_total = max(1, int(total_data * sample_size))
    elif isinstance(sample_size, (int, float)) and sample_size >= 1:
        target_total = max(1, min(int(sample_size), total_data))
    else:
        raise ValueError(
            f"sample_size must be float in (0, 1], int >= 1, or None. Got: {sample_size}"
        )

    # Log class distribution before sampling.
    class_counts = {cls: len(items) for cls, items in sorted(class_groups.items())}
    logger.debug(f"Original class distribution: {class_counts}")

    if not balance_classes:
        # Simple random sampling without balancing.
        if target_total >= total_data:
            sampled = data.copy()
        else:
            sampled = rng.sample(data, target_total)
        rng.shuffle(sampled)
        logger.info(
            f"Sampled {len(sampled)} items without balancing "
            f"({len(sampled) / total_data * 100:.1f}% of {total_data})"
        )
        return sampled

    # Balanced sampling.
    if balance_strategy == "undersample":
        sampled = _sample_undersampled(
            class_groups,
            target_total,
            balance_only=sample_size is None,
            min_per_class=min_samples_per_class,
            rng=rng,
        )
    elif balance_strategy == "sqrt":
        # sqrt of counts reduces imbalance while preserving more majority data.
        sqrt_counts = {cls: np.sqrt(len(items)) for cls, items in class_groups.items()}
        total_sqrt = sum(sqrt_counts.values())
        weights = {cls: count / total_sqrt for cls, count in sqrt_counts.items()}
        sampled = _sample_by_weights(
            class_groups, weights, target_total, min_samples_per_class, rng
        )
    elif balance_strategy == "proportional":
        # Keep the original distribution but guarantee minimum representation.
        weights = {cls: len(items) / total_data for cls, items in class_groups.items()}
        sampled = _sample_by_weights(
            class_groups, weights, target_total, min_samples_per_class, rng
        )
    else:
        raise ValueError(
            f"Unknown balance_strategy: {balance_strategy}. "
            f"Choose from: 'undersample', 'sqrt', 'proportional'"
        )

    rng.shuffle(sampled)

    # Log resulting distribution.
    result_counts: dict[int, int] = {}
    for item in sampled:
        score = item["score"]
        result_counts[score] = result_counts.get(score, 0) + 1
    result_counts = dict(sorted(result_counts.items()))

    logger.info(
        f"Sampled {len(sampled)} items with '{balance_strategy}' balancing "
        f"({len(sampled) / total_data * 100:.1f}% of {total_data})"
    )
    logger.debug(f"Balanced class distribution: {result_counts}")

    return sampled


def _sample_undersampled(
    class_groups: dict[int, list[dict]],
    target_total: int,
    *,
    balance_only: bool,
    min_per_class: int,
    rng: random.Random,
) -> list[dict]:
    """Cap every class at a common per-class quota; top up to target_total if budgeted.

    balance_only=True corresponds to sample_size=None: undersample every class
    to the smallest class size (true balancing, no top-up). Otherwise the
    target_total budget is split evenly across classes and any unused quota is
    refilled from the remaining (majority-class) items.
    """
    num_classes = len(class_groups)
    min_class_size = min(len(items) for items in class_groups.values())

    if balance_only:
        # Balance-only mode: undersample all classes to match the smallest.
        per_class = max(min_per_class, min_class_size)
        target_total = per_class * num_classes
    else:
        # Size-limited mode: distribute the budget evenly across classes.
        per_class = max(min_per_class, min(target_total // num_classes, min_class_size))

    sampled: list[dict] = []
    remaining_quota = target_total

    for cls in sorted(class_groups):
        items = class_groups[cls]
        n_available = len(items)
        # Take min of quota and availability, but never below the per-class
        # minimum (capped by what actually exists in the class).
        n_samples = min(per_class, n_available, remaining_quota)
        n_samples = max(n_samples, min(min_per_class, n_available))
        if n_samples > 0:
            sampled.extend(rng.sample(items, n_samples))
        remaining_quota -= n_samples

    # Second pass: if budget remains, fill from not-yet-chosen items. Dedupe by
    # object identity so no particular dict key (e.g. 'id') is required —
    # the function's contract only promises a 'score' key.
    if remaining_quota > 0 and not balance_only:
        chosen_ids = {id(item) for item in sampled}
        leftovers = [
            item
            for items in class_groups.values()
            for item in items
            if id(item) not in chosen_ids
        ]
        if leftovers:
            sampled.extend(rng.sample(leftovers, min(remaining_quota, len(leftovers))))

    return sampled


def _sample_by_weights(
    class_groups: dict[int, list[dict]],
    weights: dict[int, float],
    target_total: int,
    min_per_class: int,
    rng: random.Random,
) -> list[dict]:
    """Sample each class in proportion to its weight, then trim to target_total.

    Shared by the 'sqrt' and 'proportional' strategies, which differ only in
    how the per-class weights are computed.
    """
    sampled: list[dict] = []
    for cls in sorted(class_groups):
        items = class_groups[cls]
        n_samples = max(min_per_class, min(int(target_total * weights[cls]), len(items)))
        # Clamp to the class size: a min_per_class larger than the class would
        # otherwise make rng.sample raise ValueError.
        n_samples = min(n_samples, len(items))
        sampled.extend(rng.sample(items, n_samples))

    # The per-class minimums may overshoot the budget; trim back down.
    if len(sampled) > target_total:
        sampled = rng.sample(sampled, target_total)
    return sampled
627+
628+
def get_class_distribution(data: list[dict]) -> dict[int, int]:
    """
    Get the class distribution of a dataset.

    Args:
        data: Dataset list with 'score' key.

    Returns:
        Dictionary mapping score to count, keyed in ascending score order.
    """
    tally: dict[int, int] = {}
    for record in data:
        label = record["score"]
        tally[label] = tally.get(label, 0) + 1
    # Rebuild in sorted-key order so iteration over the result is ordered.
    return {label: tally[label] for label in sorted(tally)}
0 commit comments