|
| 1 | +"""Scheduled task execution via cron expressions and heartbeat intervals. |
| 2 | +
|
| 3 | +Inspired by OpenClaw cron+heartbeat pattern and Hermes /cron command. |
| 4 | +
|
| 5 | +Usage: |
| 6 | + scheduler = CronScheduler() |
| 7 | +
|
| 8 | + scheduler.add_job(CronJob( |
| 9 | + name="daily-report", |
| 10 | + cron="0 7 * * *", |
| 11 | + persona_id="analyst-agent", |
| 12 | + message="Generate daily metrics report", |
| 13 | + )) |
| 14 | +
|
| 15 | + scheduler.add_heartbeat(HeartbeatJob( |
| 16 | + name="sensor-check", |
| 17 | + interval_seconds=300, |
| 18 | + persona_id="agronomist", |
| 19 | + message="Check all sensor readings", |
| 20 | + )) |
| 21 | +
|
| 22 | + await scheduler.start() # Runs until stopped |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import asyncio |
| 28 | +import logging |
| 29 | +from collections.abc import Awaitable, Callable |
| 30 | +from dataclasses import dataclass, field |
| 31 | +from datetime import datetime, UTC |
| 32 | +from enum import Enum |
| 33 | + |
| 34 | +logger = logging.getLogger(__name__) |
| 35 | + |
| 36 | + |
| 37 | +class JobStatus(str, Enum): |
| 38 | + PENDING = "pending" |
| 39 | + ACTIVE = "active" |
| 40 | + PAUSED = "paused" |
| 41 | + COMPLETED = "completed" |
| 42 | + |
| 43 | + |
| 44 | +@dataclass |
| 45 | +class CronJob: |
| 46 | + name: str |
| 47 | + cron: str |
| 48 | + persona_id: str |
| 49 | + message: str |
| 50 | + session_mode: str = "isolated" |
| 51 | + enabled: bool = True |
| 52 | + max_runs: int | None = None |
| 53 | + run_count: int = 0 |
| 54 | + status: JobStatus = JobStatus.PENDING |
| 55 | + last_run: datetime | None = None |
| 56 | + |
| 57 | + |
| 58 | +@dataclass |
| 59 | +class HeartbeatJob: |
| 60 | + name: str |
| 61 | + interval_seconds: int |
| 62 | + persona_id: str |
| 63 | + message: str |
| 64 | + enabled: bool = True |
| 65 | + status: JobStatus = JobStatus.PENDING |
| 66 | + last_run: datetime | None = None |
| 67 | + |
| 68 | + |
| 69 | +JobCallback = Callable[[str, str, str], Awaitable[None]] |
| 70 | + |
| 71 | + |
| 72 | +def _parse_cron_field(field_str: str, min_val: int, max_val: int) -> set[int]: |
| 73 | + """Parse a single cron field into a set of matching integers.""" |
| 74 | + values: set[int] = set() |
| 75 | + for part in field_str.split(","): |
| 76 | + if "/" in part: |
| 77 | + base, step_str = part.split("/", 1) |
| 78 | + step = int(step_str) |
| 79 | + start = min_val if base in ("*", "") else int(base) |
| 80 | + values.update(range(start, max_val + 1, step)) |
| 81 | + elif part == "*": |
| 82 | + values.update(range(min_val, max_val + 1)) |
| 83 | + elif "-" in part: |
| 84 | + lo, hi = part.split("-", 1) |
| 85 | + values.update(range(int(lo), int(hi) + 1)) |
| 86 | + else: |
| 87 | + values.add(int(part)) |
| 88 | + return values |
| 89 | + |
| 90 | + |
| 91 | +def cron_matches(cron_expr: str, dt: datetime) -> bool: |
| 92 | + """Check if a datetime matches a cron expression (minute hour dom month dow).""" |
| 93 | + parts = cron_expr.strip().split() |
| 94 | + if len(parts) != 5: |
| 95 | + raise ValueError(f"Invalid cron expression: {cron_expr!r} (need 5 fields)") |
| 96 | + |
| 97 | + minute_f, hour_f, dom_f, month_f, dow_f = parts |
| 98 | + return ( |
| 99 | + dt.minute in _parse_cron_field(minute_f, 0, 59) |
| 100 | + and dt.hour in _parse_cron_field(hour_f, 0, 23) |
| 101 | + and dt.day in _parse_cron_field(dom_f, 1, 31) |
| 102 | + and dt.month in _parse_cron_field(month_f, 1, 12) |
| 103 | + and dt.weekday() in _parse_cron_field(dow_f, 0, 6) |
| 104 | + ) |
| 105 | + |
| 106 | + |
| 107 | +class CronScheduler: |
| 108 | + """Manages cron jobs and heartbeat intervals.""" |
| 109 | + |
| 110 | + def __init__(self, callback: JobCallback | None = None) -> None: |
| 111 | + self._cron_jobs: dict[str, CronJob] = {} |
| 112 | + self._heartbeat_jobs: dict[str, HeartbeatJob] = {} |
| 113 | + self._callback = callback |
| 114 | + self._running = False |
| 115 | + self._task: asyncio.Task[None] | None = None |
| 116 | + |
| 117 | + def add_job(self, job: CronJob) -> None: |
| 118 | + self._cron_jobs[job.name] = job |
| 119 | + |
| 120 | + def remove_job(self, name: str) -> None: |
| 121 | + self._cron_jobs.pop(name, None) |
| 122 | + |
| 123 | + def add_heartbeat(self, job: HeartbeatJob) -> None: |
| 124 | + self._heartbeat_jobs[job.name] = job |
| 125 | + |
| 126 | + def remove_heartbeat(self, name: str) -> None: |
| 127 | + self._heartbeat_jobs.pop(name, None) |
| 128 | + |
| 129 | + def list_jobs(self) -> list[CronJob]: |
| 130 | + return list(self._cron_jobs.values()) |
| 131 | + |
| 132 | + def list_heartbeats(self) -> list[HeartbeatJob]: |
| 133 | + return list(self._heartbeat_jobs.values()) |
| 134 | + |
| 135 | + def get_job(self, name: str) -> CronJob | None: |
| 136 | + return self._cron_jobs.get(name) |
| 137 | + |
| 138 | + def get_heartbeat(self, name: str) -> HeartbeatJob | None: |
| 139 | + return self._heartbeat_jobs.get(name) |
| 140 | + |
| 141 | + @property |
| 142 | + def is_running(self) -> bool: |
| 143 | + return self._running |
| 144 | + |
| 145 | + async def start(self) -> None: |
| 146 | + self._running = True |
| 147 | + for job in self._cron_jobs.values(): |
| 148 | + job.status = JobStatus.ACTIVE |
| 149 | + for hb in self._heartbeat_jobs.values(): |
| 150 | + hb.status = JobStatus.ACTIVE |
| 151 | + logger.info( |
| 152 | + "CronScheduler started: %d cron jobs, %d heartbeats", |
| 153 | + len(self._cron_jobs), len(self._heartbeat_jobs), |
| 154 | + ) |
| 155 | + |
| 156 | + async def stop(self) -> None: |
| 157 | + self._running = False |
| 158 | + if self._task is not None: |
| 159 | + self._task.cancel() |
| 160 | + self._task = None |
| 161 | + for job in self._cron_jobs.values(): |
| 162 | + if job.status == JobStatus.ACTIVE: |
| 163 | + job.status = JobStatus.PAUSED |
| 164 | + for hb in self._heartbeat_jobs.values(): |
| 165 | + if hb.status == JobStatus.ACTIVE: |
| 166 | + hb.status = JobStatus.PAUSED |
| 167 | + logger.info("CronScheduler stopped") |
| 168 | + |
| 169 | + async def tick(self, now: datetime | None = None) -> list[str]: |
| 170 | + """Check and fire due jobs. Returns names of fired jobs. |
| 171 | +
|
| 172 | + Pass `now` explicitly for deterministic testing. |
| 173 | + """ |
| 174 | + if not self._running: |
| 175 | + return [] |
| 176 | + |
| 177 | + now = now or datetime.now(UTC) |
| 178 | + fired: list[str] = [] |
| 179 | + |
| 180 | + for job in self._cron_jobs.values(): |
| 181 | + if not job.enabled or job.status != JobStatus.ACTIVE: |
| 182 | + continue |
| 183 | + if job.max_runs is not None and job.run_count >= job.max_runs: |
| 184 | + job.status = JobStatus.COMPLETED |
| 185 | + continue |
| 186 | + if cron_matches(job.cron, now): |
| 187 | + await self._fire_job(job.persona_id, job.message, job.name) |
| 188 | + job.run_count += 1 |
| 189 | + job.last_run = now |
| 190 | + fired.append(job.name) |
| 191 | + |
| 192 | + for hb in self._heartbeat_jobs.values(): |
| 193 | + if not hb.enabled or hb.status != JobStatus.ACTIVE: |
| 194 | + continue |
| 195 | + if hb.last_run is None or ( |
| 196 | + (now - hb.last_run).total_seconds() >= hb.interval_seconds |
| 197 | + ): |
| 198 | + await self._fire_job(hb.persona_id, hb.message, hb.name) |
| 199 | + hb.last_run = now |
| 200 | + fired.append(hb.name) |
| 201 | + |
| 202 | + return fired |
| 203 | + |
| 204 | + async def _fire_job( |
| 205 | + self, persona_id: str, message: str, job_name: str, |
| 206 | + ) -> None: |
| 207 | + if self._callback is not None: |
| 208 | + try: |
| 209 | + await self._callback(persona_id, message, job_name) |
| 210 | + except Exception: |
| 211 | + logger.exception("Job '%s' callback failed", job_name) |
| 212 | + else: |
| 213 | + logger.info("Job '%s' fired: persona=%s", job_name, persona_id) |
0 commit comments