Track Truncated Mass in k_sigma Sensitivity Testing #116

@zmek

Description

What needs to be done?

Overview

Track how much probability mass is truncated when distributions are capped during convolution. This will help quantify the trade-off between k_sigma values and accuracy loss.

Approach

1. Modify _apply_cap_with_renormalization() to return truncated mass

Current signature:

def _apply_cap_with_renormalization(
    self, pmf: np.ndarray, max_support: int
) -> np.ndarray:

Proposed signature:

def _apply_cap_with_renormalization(
    self, pmf: np.ndarray, max_support: int, return_truncated_mass: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:

When return_truncated_mass=True, return (truncated_pmf, truncated_mass).

2. Add truncated mass tracking to DemandPredictor

Add a dictionary to track truncated mass per entity and flow type:

def __init__(self, k_sigma: float = 8.0):
    self.k_sigma = k_sigma
    self.cache: Dict[str, DemandPrediction] = {}
    # Track truncated mass: {(entity_id, flow_type): truncated_mass}
    self.truncated_mass: Dict[Tuple[str, str], float] = {}

3. Update call sites to track truncated mass

In predict_flow_total() and _convolve_multiple(), when applying caps:

  • Get truncated mass from _apply_cap_with_renormalization()
  • Accumulate it in self.truncated_mass[(entity_id, flow_type)]
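The accumulation at those call sites could look like the following sketch (`accumulate_truncated_mass` and the entity/flow names are hypothetical helpers for illustration; in the real code the store is `self.truncated_mass`):

```python
from typing import Dict, Tuple


def accumulate_truncated_mass(
    store: Dict[Tuple[str, str], float],
    entity_id: str,
    flow_type: str,
    mass: float,
) -> None:
    """Sum truncated mass into store under the (entity_id, flow_type) key."""
    key = (entity_id, flow_type)
    store[key] = store.get(key, 0.0) + mass


# Example: two truncations for the same entity/flow accumulate, not overwrite.
store: Dict[Tuple[str, str], float] = {}
accumulate_truncated_mass(store, "ward_a", "arrivals", 0.001)
accumulate_truncated_mass(store, "ward_a", "arrivals", 0.002)
```

Using `dict.get` with a 0.0 default keeps the call sites simple when a key is seen for the first time.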

4. Add method to get truncated mass statistics

def get_truncated_mass_stats(self) -> Dict[str, Any]:
    """Get statistics about truncated mass across all predictions."""
    if not self.truncated_mass:
        return {
            'total_truncated_mass': 0.0,
            'max_truncated_mass': 0.0,
            'num_truncations': 0,
            'by_entity': {}
        }
    
    total = sum(self.truncated_mass.values())
    max_mass = max(self.truncated_mass.values())
    
    # Group by entity_id
    by_entity = {}
    for (entity_id, flow_type), mass in self.truncated_mass.items():
        if entity_id not in by_entity:
            by_entity[entity_id] = {}
        by_entity[entity_id][flow_type] = mass
    
    return {
        'total_truncated_mass': total,
        'max_truncated_mass': max_mass,
        'num_truncations': len(self.truncated_mass),
        'by_entity': by_entity,
        'by_flow_type': {
            'arrivals': sum(m for (_, ft), m in self.truncated_mass.items() if ft == 'arrivals'),
            'departures': sum(m for (_, ft), m in self.truncated_mass.items() if ft == 'departures'),
        }
    }

5. Clear tracking between iterations

In the test, clear predictor.truncated_mass before each iteration to get fresh measurements.
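A minimal, self-contained sketch of the reset, with `DemandPredictor` stubbed down to just the tracking dict (the real class carries more state):

```python
from typing import Dict, Tuple


class DemandPredictor:
    """Stub with only the truncated-mass tracking relevant to this step."""

    def __init__(self, k_sigma: float = 8.0):
        self.k_sigma = k_sigma
        self.truncated_mass: Dict[Tuple[str, str], float] = {}


predictor = DemandPredictor(k_sigma=4.0)
predictor.truncated_mass[("ward_a", "arrivals")] = 0.002

# Before each sensitivity-test iteration, reset for fresh measurements.
predictor.truncated_mass.clear()
```

Clearing in place with `.clear()` (rather than rebinding a new dict) keeps any references other code holds to the dict valid.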

6. Update test to collect truncated mass

In measure_prediction_time(), after predictions:

# Get truncated mass statistics
truncated_stats = predictor.get_truncated_mass_stats()

# Add to metrics
metrics['truncated_mass'] = {
    'total': truncated_stats['total_truncated_mass'],
    'max': truncated_stats['max_truncated_mass'],
    'num_truncations': truncated_stats['num_truncations'],
    'by_flow_type': truncated_stats['by_flow_type'],
}

# Add per-entity truncated mass to pmf_metrics
for pmf_info in pmf_lengths:
    entity_id = pmf_info['entity_id']
    if entity_id in truncated_stats['by_entity']:
        entity_mass = truncated_stats['by_entity'][entity_id]
        pmf_info['arrivals_truncated_mass'] = entity_mass.get('arrivals', 0.0)
        pmf_info['departures_truncated_mass'] = entity_mass.get('departures', 0.0)

7. Update results summary

Add truncated mass columns to the summary table:

print(f"\n{'k_sigma':<10} {'Mean Time (s)':<15} {'Total Truncated':<15} {'Max Truncated':<15} {'Arrivals Trunc':<15} {'Departures Trunc':<15}")

Implementation Details

Key Considerations

  1. Backward Compatibility: Use optional parameter return_truncated_mass=False by default
  2. Accumulation: Truncated mass should be accumulated (summed) when multiple truncations occur for the same entity/flow
  3. Clearing: Reset truncated_mass dict at the start of each prediction iteration
  4. Flow Types: Track separately for 'arrivals' and 'departures' (net_flow doesn't get truncated)

Where Truncation Occurs

  1. predict_flow_total(): After convolving all flows for an entity
  2. _convolve_multiple(): After convolving multiple distributions (used in hierarchical aggregation)
  3. Poisson creation: When creating Poisson distributions with caps (but this is handled in Distribution.from_poisson_with_cap())

Example Output

k_sigma    Mean Time (s)  Total Truncated Max Truncated  Arrivals Trunc Departures Trunc
2.00       0.1234         0.0025          0.0005         0.0025         0.0000
4.00       0.1456         0.0012          0.0003         0.0012         0.0000
6.00       0.1678         0.0005          0.0002         0.0005         0.0000
8.00       0.1890         0.0001          0.0001         0.0001         0.0000

Benefits

  1. Quantify accuracy loss: See exactly how much probability mass is lost with different k_sigma values
  2. Identify problematic entities: Find which entities/flows experience the most truncation
  3. Optimize k_sigma: Choose k_sigma value that balances performance and accuracy
  4. Debugging: Understand where truncation is happening in the hierarchy

Metadata

Labels: enhancement (New feature or request)