Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cosipy/phase_resolved_analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Phase-resolved analysis tools: time/phase selection, phase assignment, plotting."""
from .time_selector import TimeSelector
from .phase_selector import PhaseSelector
# phase_assigner.py is added in this same package but was not exported;
# expose PhaseAssigner so users can reach it from the package root.
from .phase_assigner import PhaseAssigner
from .plot_pulse_profile import PlotPulseProfile
70 changes: 70 additions & 0 deletions cosipy/phase_resolved_analysis/phase_assigner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import logging
from astropy.io import fits

# NOTE(review): calling basicConfig at import time in library code can override
# the host application's logging configuration — consider leaving configuration
# to the caller and keeping only getLogger here.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class PhaseAssigner:
    """
    Reads a pulsar ephemeris (.par) file and assigns a PULSE_PHASE column
    to a FITS event list by folding mission elapsed time (MET) directly.

    Phase model: phase = (MET * F0) % 1.0 — no reference epoch (T0/PEPOCH)
    offset and no frequency derivatives are applied.
    """

    def __init__(self, par_file):
        """
        Parameters
        ----------
        par_file : str
            Path to a TEMPO-style .par file containing F0 (Hz) or P0 (s).

        Raises
        ------
        ValueError
            If neither F0 nor P0 is present in the par file.
        """
        self.params = self._parse_par_file(par_file)

        # Parse F0 (frequency) or P0 (period). Values below 1.0 Hz are
        # heuristically treated as a period mislabeled as F0.
        if 'F0' in self.params:
            val = float(self.params['F0'])
            if val < 1.0:
                self.f0 = 1.0 / val
                logger.warning(f"F0 < 1.0. Assuming Period. F0={self.f0:.6f} Hz")
            else:
                self.f0 = val
        elif 'P0' in self.params:
            self.f0 = 1.0 / float(self.params['P0'])
        else:
            raise ValueError("PAR file must have F0 or P0")

        # We removed T0/Epoch logic as requested.

    def _parse_par_file(self, path):
        """
        Parse a par file into a {KEY: value-string} dict.

        Fortran-style 'D' exponents are converted to 'E' so float() can
        parse them later. Lines starting with '#' or 'C' are treated as
        comments (TEMPO convention); key-only lines are ignored.
        """
        params = {}
        with open(path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts or parts[0].startswith(('#', 'C')):
                    continue
                # Explicit length check replaces the original bare
                # `except: pass`, which silently swallowed all errors.
                if len(parts) >= 2:
                    params[parts[0].upper()] = parts[1].replace('D', 'E')
        return params

    def add_phase_column(self, input_fits, output_fits=None):
        """
        Calculate phase = (MET * F0) % 1.0 and add/overwrite the
        PULSE_PHASE column in extension 1.

        Parameters
        ----------
        input_fits : str
            Path to the input event FITS file.
        output_fits : str, optional
            Output path; defaults to overwriting ``input_fits``.

        Returns
        -------
        str
            The path the file was written to.
        """
        with fits.open(input_fits) as hdul:
            data = hdul[1].data
            header = hdul[1].header

            # 1. Get time (MET): 'TimeTags' (COSI style) or 'TIME' (OGIP style).
            try:
                times = data['TimeTags']
            except KeyError:
                times = data['TIME']

            # 2. Calculate phase — simple folding with no reference epoch.
            phase = (times * self.f0) % 1.0

            # 3. Create or overwrite the column.
            cols = data.columns
            if 'PULSE_PHASE' in cols.names:
                logger.info("Overwriting PULSE_PHASE column.")
                data['PULSE_PHASE'] = phase
                new_hdu = fits.BinTableHDU(data=data, header=header)
            else:
                logger.info("Creating PULSE_PHASE column.")
                col = fits.Column(name='PULSE_PHASE', format='D', array=phase)
                new_hdu = fits.BinTableHDU.from_columns(cols + col, header=header)

            if output_fits is None:
                output_fits = input_fits
            fits.HDUList([hdul[0], new_hdu]).writeto(output_fits, overwrite=True)
            logger.info(f"Saved: {output_fits}")
        return output_fits
147 changes: 147 additions & 0 deletions cosipy/phase_resolved_analysis/phase_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import logging
import numpy as np
import itertools
import os
from astropy.io import fits

logger = logging.getLogger(__name__)

# --- ROBUST IMPORTS (Safety Switch) ---
try:
    from cosipy.interfaces.event_selection import EventSelectorInterface
    from cosipy.interfaces import EventInterface
    from cosipy.util.iterables import itertools_batched
except (ImportError, AttributeError):
    # Minimal stand-ins so this module remains importable outside the
    # full cosipy environment (tests, standalone scripts).
    class EventInterface:
        pass

    class EventSelectorInterface:
        def select(self, events):
            raise NotImplementedError

    def itertools_batched(iterable, n):
        """Yield successive lists of at most *n* items from *iterable*."""
        iterator = iter(iterable)
        chunk = list(itertools.islice(iterator, n))
        while chunk:
            yield chunk
            chunk = list(itertools.islice(iterator, n))

class PhaseSelector(EventSelectorInterface):
    """
    Selects events based on pulsar phase.

    The window [pstart, pstop] is inclusive on both ends; if the
    normalized stop falls below the start, the window wraps around
    phase 1.0. Highly optimized for FITS files via NumPy vectorization.
    """

    def __init__(self, ephemeris_file, pstart, pstop, batch_size=10000):
        """
        Parameters
        ----------
        ephemeris_file : str
            Path to the pulsar ephemeris (stored for provenance only).
        pstart, pstop : float
            Phase window edges; pstop > 1.0 is folded back into [0, 1).
        batch_size : int, optional
            Chunk size for the generic-iterable path of :meth:`select`.
        """
        self.ephemeris_file = ephemeris_file
        self.pstart = float(pstart)
        self.pstop = float(pstop)
        self._batch_size = batch_size

    def _get_vectorized_mask(self, phases: np.ndarray) -> np.ndarray:
        """Core logic applied across an entire NumPy array instantly."""
        pstop_norm = self.pstop % 1.0 if self.pstop > 1.0 else self.pstop

        if self.pstart <= pstop_norm:
            return (phases >= self.pstart) & (phases <= pstop_norm)
        else:
            # Window wraps around phase 1.0 (e.g. [0.9, 0.1]).
            return (phases >= self.pstart) | (phases <= pstop_norm)

    def select(self, events):
        """
        Maintains pipeline compatibility.

        Returns a bool for a single EventInterface, a boolean array for a
        structured NumPy array with a PULSE_PHASE field, and a lazy
        generator of bools for any other iterable.

        BUG FIX: the original version contained ``yield`` in this function
        body, which made the whole function a generator function — the
        fast-path ``return`` statements silently produced an exhausted
        generator instead of a bool/array. The yielding loop now lives in
        a private helper so each path returns what it documents.
        """
        # Fast path for single EventInterface object.
        if isinstance(events, EventInterface):
            phase = getattr(events, 'pulse_phase', -1.0)
            if phase is None:
                return False
            return bool(self._get_vectorized_mask(np.array(phase)))

        # Fast path if events is already a structured NumPy array.
        # dtype.names is None for plain (non-structured) arrays, so guard
        # before the membership test to avoid `'x' in None` -> TypeError.
        if (isinstance(events, np.ndarray)
                and events.dtype.names is not None
                and 'PULSE_PHASE' in events.dtype.names):
            return self._get_vectorized_mask(events['PULSE_PHASE'])

        # Fallback for generic iterables of objects.
        return self._select_iter(events)

    def _select_iter(self, events):
        """Lazily yield one bool per event, batching to vectorize the mask."""
        for chunk in itertools_batched(events, self._batch_size):
            phases = np.array([getattr(e, 'pulse_phase', -1.0) for e in chunk])
            mask = self._get_vectorized_mask(phases)
            for sel in mask:
                yield bool(sel)

    def filter_events(self, events, output_fits=None, template_fits=None):
        """
        Filters events. Instantaneous execution when passed a FITS file path.

        Parameters
        ----------
        events : str | np.ndarray | iterable
            FITS path, structured array, or iterable of event objects.
        output_fits : str, optional
            If given, the selection is also written to this FITS file.
        template_fits : str, optional
            FITS file providing headers/columns when saving object lists.
        """
        # --- VECTORIZED FAST PATH FOR FITS FILES ---
        if isinstance(events, str):
            logger.info(f"Auto-loading events from FITS: {events}")
            template_fits = events

            with fits.open(events) as hdul:
                data = hdul[1].data
                # Apply mask to the entire 'PULSE_PHASE' column at once,
                # then slice the FITS array instantly.
                mask = self._get_vectorized_mask(data['PULSE_PHASE'])
                selected_data = data[mask]

            if output_fits is not None:
                self._save_fits_fast(selected_data, output_fits, template_fits)

            return selected_data

        # --- SLOW PATH FOR PYTHON OBJECT LISTS ---
        mask = list(self.select(events))
        selected_events = [e for e, m in zip(events, mask) if m]

        if output_fits is not None:
            if template_fits is None:
                logger.warning("'template_fits' missing. Cannot save.")
            else:
                self.save_fits(selected_events, output_fits, template_fits)

        return selected_events

    def _save_fits_fast(self, structured_array, output_filename, template_filename):
        """Saves a NumPy structured array directly to FITS (Orders of magnitude faster)."""
        if len(structured_array) == 0:
            logger.warning("Warning: No events to save.")
            return

        logger.info(f"Saving {len(structured_array)} events to {output_filename}...")

        try:
            with fits.open(template_filename) as hdul:
                # Drop the filtered array directly into a new HDU, reusing
                # the template's headers.
                hdu = fits.BinTableHDU(data=structured_array, header=hdul[1].header)
                hdul_new = fits.HDUList([fits.PrimaryHDU(header=hdul[0].header), hdu])
                hdul_new.writeto(output_filename, overwrite=True)
                logger.info(f"Successfully saved: {output_filename}")
        except Exception as e:
            logger.error(f"Failed to save FITS: {e}")

    def save_fits(self, events, output_filename, template_filename):
        """Legacy fallback for saving lists of Python objects.

        Each event object must carry its raw FITS row in a ``.row``
        attribute; rows are copied into a fresh table built from the
        template's column definitions.
        """
        # Auto-route to fast save if an array was passed by mistake.
        if isinstance(events, np.ndarray):
            return self._save_fits_fast(events, output_filename, template_filename)

        if not events:
            logger.warning("Warning: No events to save.")
            return

        logger.info(f"Saving {len(events)} events (Legacy Object Mode) to {output_filename}...")
        try:
            with fits.open(template_filename) as hdul:
                columns = hdul[1].columns
                rows = [e.row for e in events if hasattr(e, 'row')]

                if not rows:
                    logger.error("Error: Event objects do not contain raw FITS rows.")
                    return

                new_data = fits.FITS_rec.from_columns(columns, nrows=len(rows))
                for i, row in enumerate(rows):
                    new_data[i] = row

                hdu = fits.BinTableHDU(data=new_data, header=hdul[1].header)
                hdul_new = fits.HDUList([fits.PrimaryHDU(header=hdul[0].header), hdu])
                hdul_new.writeto(output_filename, overwrite=True)
                logger.info(f"Successfully saved: {output_filename}")
        except Exception as e:
            logger.error(f"Failed to save FITS: {e}")
109 changes: 109 additions & 0 deletions cosipy/phase_resolved_analysis/plot_pulse_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

class PlotPulseProfile:
    """
    Generates a 3-panel figure: Pulse Profile, Phaseogram, and Significance Test.
    Optimized for direct FITS table input (vectorized).

    Accepts any table-like object with a 'PULSE_PHASE' column: an astropy
    FITS_rec (which exposes ``.names``) or a plain NumPy structured array
    (which only exposes ``dtype.names``).
    """

    def __init__(self, data, n_bins=50, n_time_bins=50):
        """
        Parameters
        ----------
        data : FITS_rec | np.ndarray
            Event table containing 'PULSE_PHASE' and optionally a time
            column ('TimeTags' or 'TIME').
        n_bins : int
            Number of phase bins for the profile and phaseogram.
        n_time_bins : int
            Number of time bins for the phaseogram.
        """
        self.n_bins = n_bins
        self.n_time_bins = n_time_bins

        # --- VECTORIZED DATA EXTRACTION ---
        try:
            self.phases = np.array(data['PULSE_PHASE'])
        except (KeyError, TypeError):
            print("Error: 'PULSE_PHASE' column not found in data.")
            self.phases = np.array([])

        # Column-name lookup: FITS_rec has `.names`; plain structured
        # ndarrays only have `dtype.names`. Fall back gracefully so both
        # input types work (the original crashed on structured arrays).
        names = getattr(data, 'names', None)
        if names is None:
            names = getattr(getattr(data, 'dtype', None), 'names', None)
        names = names or ()

        # Handle various possible time column names.
        if 'TimeTags' in names:
            self.times = np.array(data['TimeTags'])
        elif 'TIME' in names:
            self.times = np.array(data['TIME'])
        else:
            self.times = np.zeros(len(self.phases))

    def plot(self, t_start_met=None):
        """
        Render the 3-panel figure and show it.

        Parameters
        ----------
        t_start_met : float, optional
            Reference MET for the time axis; defaults to the earliest
            event time.
        """
        if len(self.phases) == 0:
            print("No valid events to plot.")
            return

        # Relative time calculation.
        if t_start_met is None:
            t_start_met = np.min(self.times)

        t_elapsed = self.times - t_start_met
        duration = np.max(t_elapsed)
        if duration <= 0:
            duration = 1.0  # avoid a zero-height phaseogram axis

        # --- Plotting Setup ---
        fig = plt.figure(figsize=(14, 10))
        gs = GridSpec(2, 2, width_ratios=[1, 1.2], height_ratios=[1, 1])

        ax_prof = fig.add_subplot(gs[0, 0])
        ax_htest = fig.add_subplot(gs[1, 0])
        ax_phaseogram = fig.add_subplot(gs[:, 1])

        # --- Panel 1: Integrated Profile (Top Left) ---
        counts, edges = np.histogram(self.phases, bins=self.n_bins, range=(0, 1))
        centers = (edges[:-1] + edges[1:]) / 2

        # 2-cycle plot for better visualization of peak wrap-around.
        x_2cycle = np.concatenate([centers, centers + 1])
        y_2cycle = np.concatenate([counts, counts])

        ax_prof.step(x_2cycle, y_2cycle, where='mid', color='rebeccapurple', lw=2)
        ax_prof.set_xlim(0, 2)
        ax_prof.set_ylabel("Counts")
        ax_prof.set_xlabel("Pulse Phase")
        ax_prof.set_title(f"Integrated Profile (N={len(self.phases)})")
        ax_prof.grid(alpha=0.3)

        # --- Panel 2: Phaseogram (Right) ---
        h2d, xedges, yedges = np.histogram2d(
            self.phases, t_elapsed,
            bins=[self.n_bins, self.n_time_bins],
            range=[[0, 1], [0, duration]]
        )

        im = ax_phaseogram.imshow(h2d.T, origin='lower', aspect='auto',
                                  extent=[0, 1, 0, duration],
                                  cmap='viridis', interpolation='nearest')
        ax_phaseogram.set_xlabel("Pulse Phase")
        ax_phaseogram.set_ylabel("Time since start (s)")
        ax_phaseogram.set_title("Phaseogram")
        plt.colorbar(im, ax=ax_phaseogram, label="Counts/bin")

        # --- Panel 3: Significance (Bottom Left) ---
        if len(t_elapsed) > 1:
            sort_idx = np.argsort(t_elapsed)
            sorted_phases = self.phases[sort_idx] * 2 * np.pi
            sorted_times = t_elapsed[sort_idx]

            # Cumulative Z^2_2 statistic (2 harmonics).
            ns = np.arange(1, len(sorted_phases) + 1)

            # First harmonic (k=1)
            cum_cos1 = np.cumsum(np.cos(sorted_phases))
            cum_sin1 = np.cumsum(np.sin(sorted_phases))

            # Second harmonic (k=2)
            cum_cos2 = np.cumsum(np.cos(2 * sorted_phases))
            cum_sin2 = np.cumsum(np.sin(2 * sorted_phases))

            z2_stats = (2.0 / ns) * (cum_cos1**2 + cum_sin1**2 + cum_cos2**2 + cum_sin2**2)

            # Downsample for plotting performance if data is massive.
            step = max(1, len(z2_stats) // 2000)
            ax_htest.plot(sorted_times[::step], z2_stats[::step], '-', color='rebeccapurple', lw=1.5)

        ax_htest.set_xlabel("Time since start (s)")
        ax_htest.set_ylabel(r"Significance ($Z^2_2$)")
        ax_htest.set_title("Detection Significance")
        ax_htest.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()
Loading