What needs to be done?
Overview
Track how much probability mass is truncated when distributions are capped during the convolution process. This will help understand the trade-off between k_sigma values and accuracy loss.
Approach
1. Modify _apply_cap_with_renormalization() to return truncated mass
Current signature:
def _apply_cap_with_renormalization(
    self, pmf: np.ndarray, max_support: int
) -> np.ndarray:
Proposed signature:
def _apply_cap_with_renormalization(
    self, pmf: np.ndarray, max_support: int, return_truncated_mass: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
When return_truncated_mass=True, return (truncated_pmf, truncated_mass).
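A minimal sketch of the proposed behaviour, written as a standalone function rather than the real method. It assumes that `max_support` is the number of leading PMF entries to keep and that the truncated mass is measured before renormalization; neither detail is confirmed by the existing implementation.

```python
from typing import Tuple, Union

import numpy as np


def apply_cap_with_renormalization(
    pmf: np.ndarray, max_support: int, return_truncated_mass: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
    """Keep the first `max_support` entries of `pmf` and renormalize them.

    Assumption: `max_support` counts entries kept, not the largest index.
    """
    if len(pmf) <= max_support:
        # Nothing is cut off, so no mass is lost.
        capped, truncated_mass = pmf, 0.0
    else:
        # Mass in the discarded tail, measured before renormalization.
        truncated_mass = float(pmf[max_support:].sum())
        kept = pmf[:max_support]
        kept_sum = kept.sum()
        # Guard against an all-zero head to avoid dividing by zero.
        capped = kept / kept_sum if kept_sum > 0 else kept
    if return_truncated_mass:
        return capped, truncated_mass
    return capped
```

Existing callers that pass only `pmf` and `max_support` keep receiving a bare array, which is what makes the default of `False` backward compatible.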
2. Add truncated mass tracking to DemandPredictor
Add a dictionary to track truncated mass per entity and flow type:
def __init__(self, k_sigma: float = 8.0):
    self.k_sigma = k_sigma
    self.cache: Dict[str, DemandPrediction] = {}
    # Track truncated mass: {(entity_id, flow_type): truncated_mass}
    self.truncated_mass: Dict[Tuple[str, str], float] = {}
3. Update call sites to track truncated mass
In predict_flow_total() and _convolve_multiple(), when applying caps:
- Get truncated mass from _apply_cap_with_renormalization()
- Accumulate it in self.truncated_mass[(entity_id, flow_type)]
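The accumulation step at each call site could look roughly like the sketch below. `TruncationTracker` and `record` are hypothetical names used only for illustration; in the real code the dictionary would live on `DemandPredictor` as described in step 2.

```python
from typing import Dict, Tuple


class TruncationTracker:
    """Hypothetical stand-in for the tracking state on DemandPredictor."""

    def __init__(self) -> None:
        self.truncated_mass: Dict[Tuple[str, str], float] = {}

    def record(self, entity_id: str, flow_type: str, mass: float) -> None:
        """Add `mass` to the running total for (entity_id, flow_type)."""
        if mass <= 0.0:
            # Skip caps that removed nothing so the dict only lists real losses.
            return
        key = (entity_id, flow_type)
        self.truncated_mass[key] = self.truncated_mass.get(key, 0.0) + mass


tracker = TruncationTracker()
tracker.record("entity_1", "arrivals", 0.001)
tracker.record("entity_1", "arrivals", 0.002)  # second cap on the same flow
tracker.record("entity_1", "departures", 0.0)  # no loss, not stored
```

Skipping zero-mass entries is a design choice: it means `num_truncations` in the statistics counts only entity/flow pairs that actually lost mass, not every cap that was applied.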
4. Add method to get truncated mass statistics
def get_truncated_mass_stats(self) -> Dict[str, Any]:
    """Get statistics about truncated mass across all predictions."""
    if not self.truncated_mass:
        # Return the same keys as the populated case so callers can index safely
        return {
            'total_truncated_mass': 0.0,
            'max_truncated_mass': 0.0,
            'num_truncations': 0,
            'by_entity': {},
            'by_flow_type': {'arrivals': 0.0, 'departures': 0.0},
        }
    total = sum(self.truncated_mass.values())
    max_mass = max(self.truncated_mass.values())
    # Group by entity_id
    by_entity = {}
    for (entity_id, flow_type), mass in self.truncated_mass.items():
        if entity_id not in by_entity:
            by_entity[entity_id] = {}
        by_entity[entity_id][flow_type] = mass
    return {
        'total_truncated_mass': total,
        'max_truncated_mass': max_mass,
        'num_truncations': len(self.truncated_mass),
        'by_entity': by_entity,
        'by_flow_type': {
            'arrivals': sum(m for (_, ft), m in self.truncated_mass.items() if ft == 'arrivals'),
            'departures': sum(m for (_, ft), m in self.truncated_mass.items() if ft == 'departures'),
        }
    }
5. Clear tracking between iterations
In the test, clear predictor.truncated_mass before each iteration to get fresh measurements.
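A small sketch of why the reset matters. `StubPredictor` and `run_iterations` are hypothetical stand-ins for the real predictor and test loop; the stub simply records a fixed loss on every prediction.

```python
from typing import Dict, List, Tuple


class StubPredictor:
    """Hypothetical predictor that loses 0.001 of arrivals mass per run."""

    def __init__(self) -> None:
        self.truncated_mass: Dict[Tuple[str, str], float] = {}

    def predict(self) -> None:
        key = ("entity_1", "arrivals")
        self.truncated_mass[key] = self.truncated_mass.get(key, 0.0) + 0.001


def run_iterations(predictor: StubPredictor, n: int, reset: bool = True) -> List[float]:
    """Return the total truncated mass observed after each iteration."""
    totals = []
    for _ in range(n):
        if reset:
            # Empty the dict in place so each iteration is measured on its own.
            predictor.truncated_mass.clear()
        predictor.predict()
        totals.append(sum(predictor.truncated_mass.values()))
    return totals


fresh = run_iterations(StubPredictor(), 3)               # same total every time
stale = run_iterations(StubPredictor(), 3, reset=False)  # totals keep growing
```

Without the reset, later iterations report the mass accumulated by all earlier ones, which would inflate the totals in the summary table.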
6. Update test to collect truncated mass
In measure_prediction_time(), after predictions:
# Get truncated mass statistics
truncated_stats = predictor.get_truncated_mass_stats()
# Add to metrics
metrics['truncated_mass'] = {
    'total': truncated_stats['total_truncated_mass'],
    'max': truncated_stats['max_truncated_mass'],
    'num_truncations': truncated_stats['num_truncations'],
    'by_flow_type': truncated_stats['by_flow_type'],
}
# Add per-entity truncated mass to pmf_metrics
for pmf_info in pmf_lengths:
    # Entities with no recorded truncation get 0.0 for both flow types
    entity_mass = truncated_stats['by_entity'].get(pmf_info['entity_id'], {})
    pmf_info['arrivals_truncated_mass'] = entity_mass.get('arrivals', 0.0)
    pmf_info['departures_truncated_mass'] = entity_mass.get('departures', 0.0)
7. Update results summary
Add truncated mass columns to the summary table:
print(f"\n{'k_sigma':<10} {'Mean Time (s)':<15} {'Total Truncated':<15} {'Max Truncated':<15} {'Arrivals Trunc':<15} {'Departures Trunc':<15}")
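One possible way to print the data rows under that header is sketched below. The `results` structure, including the `k_sigma` and `mean_time` keys, is an assumption modelled on the `metrics['truncated_mass']` dictionary from step 6. The sketch formats truncated mass in scientific notation because very small losses would all print as 0.0000 with four fixed decimals.

```python
from typing import Any, Dict, List

# (column title, width); widths are illustrative and at least as wide as each title.
COLUMNS = [
    ("k_sigma", 10),
    ("Mean Time (s)", 15),
    ("Total Truncated", 17),
    ("Max Truncated", 15),
    ("Arrivals Trunc", 16),
    ("Departures Trunc", 16),
]


def format_summary(results: List[Dict[str, Any]]) -> str:
    """Build the summary table as a string, one row per k_sigma value."""
    lines = [" ".join(f"{name:<{width}}" for name, width in COLUMNS)]
    for r in results:
        tm = r["truncated_mass"]
        values = [
            f"{r['k_sigma']:.2f}",
            f"{r['mean_time']:.4f}",
            f"{tm['total']:.2e}",
            f"{tm['max']:.2e}",
            f"{tm['by_flow_type']['arrivals']:.2e}",
            f"{tm['by_flow_type']['departures']:.2e}",
        ]
        lines.append(" ".join(f"{v:<{w}}" for v, (_, w) in zip(values, COLUMNS)))
    return "\n".join(lines)


# Placeholder values, not measured results.
example = [{
    "k_sigma": 4.0,
    "mean_time": 0.1456,
    "truncated_mass": {
        "total": 0.0005,
        "max": 0.0002,
        "by_flow_type": {"arrivals": 0.0005, "departures": 0.0},
    },
}]
table = format_summary(example)
```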
Implementation Details
Key Considerations
- Backward Compatibility: Use optional parameter return_truncated_mass=False by default
- Accumulation: Truncated mass should be accumulated (summed) when multiple truncations occur for the same entity/flow
- Clearing: Reset truncated_mass dict at the start of each prediction iteration
- Flow Types: Track separately for 'arrivals' and 'departures' (net_flow doesn't get truncated)
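One point worth keeping in mind about summing: because each cap renormalizes its output, every later truncated mass is a fraction of an already renormalized distribution. Under that reading, the share of the original mass that survives a chain of caps is the product of the (1 - m_i) terms, and the plain sum is an upper bound on the compounded loss. The sketch below compares the two; the function names are illustrative and not part of the codebase.

```python
from typing import Iterable


def summed_loss(masses: Iterable[float]) -> float:
    """Loss as proposed in this issue: add up the per-cap truncated masses."""
    return sum(masses)


def compounded_loss(masses: Iterable[float]) -> float:
    """Loss when each mass is a fraction of what survived the previous caps."""
    retained = 1.0
    for m in masses:
        retained *= 1.0 - m
    return 1.0 - retained


per_cap = [0.001, 0.002]          # two successive caps on the same flow
simple = summed_loss(per_cap)      # about 0.003
exact = compounded_loss(per_cap)   # about 0.002998
```

For the small masses expected here the two figures are nearly identical, so summing stays a reasonable and slightly conservative choice.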
Where Truncation Occurs
- predict_flow_total(): After convolving all flows for an entity
- _convolve_multiple(): After convolving multiple distributions (used in hierarchical aggregation)
- Poisson creation: When creating Poisson distributions with caps (but this is handled in Distribution.from_poisson_with_cap())
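To give a feel for the size of the truncated mass at the Poisson stage, the sketch below computes the tail that a cap would cut off. It assumes the cap is placed at ceil(mean + k_sigma * sqrt(mean)), which is a guess at the rule and may differ from what `Distribution.from_poisson_with_cap()` actually does.

```python
import math


def poisson_truncated_mass(lam: float, k_sigma: float) -> float:
    """Probability mass of Poisson(lam) above an assumed k_sigma cap."""
    # Assumed cap rule: mean plus k_sigma standard deviations, rounded up.
    cap = math.ceil(lam + k_sigma * math.sqrt(lam))
    term = math.exp(-lam)  # P(X = 0); underflows for very large lam
    kept = term
    for k in range(1, cap + 1):
        term *= lam / k    # P(X = k) from P(X = k - 1)
        kept += term
    # Clamp tiny negative values caused by floating-point rounding.
    return max(0.0, 1.0 - kept)


tail_2 = poisson_truncated_mass(25.0, 2.0)
tail_4 = poisson_truncated_mass(25.0, 4.0)
tail_8 = poisson_truncated_mass(25.0, 8.0)
```

Under this assumed rule the truncated mass shrinks quickly as k_sigma grows, which is the accuracy side of the trade-off this issue aims to measure.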
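Example Output
The values below are placeholders that show the layout only. A larger k_sigma widens the cap, so a real run should show truncated mass falling as k_sigma rises.
k_sigma    Mean Time (s)   Total Truncated   Max Truncated   Arrivals Trunc   Departures Trunc
2.00       0.1234          0.0001            0.0001          0.0001           0.0000
4.00       0.1456          0.0005            0.0002          0.0005           0.0000
6.00       0.1678          0.0012            0.0003          0.0012           0.0000
8.00       0.1890          0.0025            0.0005          0.0025           0.0000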
Benefits
- Quantify accuracy loss: See exactly how much probability mass is lost with different k_sigma values
- Identify problematic entities: Find which entities/flows experience the most truncation
- Optimize k_sigma: Choose k_sigma value that balances performance and accuracy
- Debugging: Understand where truncation is happening in the hierarchy