diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..4776211 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ea55c1..660e901 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,6 +59,10 @@ jobs: run: | & $env:VENV_PYTHON -m ruff check . + - name: Run source security lint + run: | + & $env:VENV_PYTHON -m ruff check src --select S + - name: Check Ruff formatting run: | & $env:VENV_PYTHON -m ruff format --check . @@ -71,6 +75,11 @@ jobs: run: | & $env:VENV_PYTHON -m pyright + - name: Run dependency security audit + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.12' + run: | + & $env:VENV_PYTHON -m pip_audit --skip-editable + - name: Run tests run: | & $env:VENV_PYTHON -m pytest -q diff --git a/.gitignore b/.gitignore index de9b256..ebc3123 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ __pycache__/ node_modules/ playwright-report/ test-results/ +session.log* *.log .DS_Store Thumbs.db diff --git a/CHANGELOG.md b/CHANGELOG.md index 902720b..d47b4b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,21 +4,188 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +## [1.0.1] - 2026-05-14 + +### Security + +- The local editor server now requires a per-session API token for `/api/*` + requests, rejects non-JSON POST bodies, blocks untrusted `Host`/`Origin` + headers, and refuses non-loopback binds unless `allow_remote=True` is set + explicitly. +- Editor API routes now reject excessive tensor-network payloads and oversized + template parameters before running expensive validation, rendering, + contraction analysis, code generation, or subnetwork operations. +- CI now runs a dependency vulnerability audit with `pip-audit` as part of the + development dependency set. +- Bundled PrismJS assets were updated to 1.30.0, and the editor server now + emits a nonce-based Content Security Policy plus additional browser defense + headers. +- Live Python import prompts and docs now state that live import should only be + used with trusted local Python files, because it executes code in a + subprocess with the active Python environment. +- CI now runs Ruff's Bandit security rules against `src`, and Dependabot tracks + Python and GitHub Actions dependency updates. +- Added `SECURITY.md` with private reporting guidance, a maintainer disclosure + checklist, and a PrismJS advisory draft for releases that bundled PrismJS + 1.29.0. + +### Changed + +- CLI help now includes a top-level command argument quick reference and + descriptions for previously unlabeled command options and positionals. +- Academic SVG/PNG/PDF exports now inherit the active editor theme for figure + text/background colors; light themes use white PDF backgrounds, while SVG and + PNG exports can preserve transparent backgrounds. + +## [1.0.0] - 2026-05-02 + +### Added + +- The editor now supports an explicit UI launch mode across the CLI and Python + API: browser by default, `pywebview` with the optional `desktop` extra, or a + server-only mode that prints the local URL without opening a window. + +### Changed + +- PyPI trove classifiers and README now state **Production/Stable** readiness + (replacing the previous Beta development-status marker). +- Publishing polish: the README no longer advertises a removed `png` extra, and + `MANIFEST.in` no longer carries redundant exclusions for non-package + directories, which keeps `python -m build` quieter. +- `pywebview` editor launches now open their native window maximized by + default, so the desktop mode starts with the same roomy workspace users + usually expect from the browser flow. +- `pywebview` exports now open a native `Save As` dialog and write the selected + file from Python, so desktop-mode JSON, Python, and academic exports no + longer disappear into the embedded browser backend's implicit download + folder. +- `pywebview` export actions now detect the native save API lazily at export + time instead of only during page startup, so the desktop `Save As` dialog + still appears even when the webview bridge finishes attaching just after the + editor UI initializes. +- `pywebview` export actions now detect text and binary save capabilities + independently, so desktop exports still use the native `Save As` dialog even + when an embedded backend exposes only one of the two save methods at first. +- The editor bootstrap now starts immediately when the document is already in + `interactive` or `complete`, which fixes `pywebview` windows that could show + the shell markup without wiring toolbar actions, canvas interactions, or the + template bootstrap if `DOMContentLoaded` had already fired. +- Windows `pywebview` launches now reuse the packaged + [`favicon.ico`](src/tensor_network_editor/app/static/favicon.ico) for the + native window icon instead of inheriting the default Python executable icon. +- `pywebview` desktop launches now treat the native window-icon hook as best + effort, so backends that do not expose a `before_show` event still open + correctly instead of crashing during startup. +- `pywebview` desktop launches now also tolerate backends with partial window + event hooks, so missing `closed` callbacks no longer crash the editor during + startup. +- Local `EditorServer` startup now waits until a real loopback asset request can + be served before reporting readiness, which stabilizes rapid restart cycles + in tests and makes `_on_server_ready` URLs immediately usable. +- `EditorServer.stop()` is now safe even if it runs before `start()`, so early + cleanup paths no longer risk hanging while waiting for a serve loop that + never began. +- Repeated `EditorServer` startups now reuse the shared static-asset cache + without forcing an immediate full rescan of the asset tree every time, which + trims bursty local startup overhead while still refreshing changed assets + shortly afterward. +- Editor undo/redo snapshots now keep benchmark-mode session history lighter by + stripping inactive scheme view snapshots and ephemeral compare-modal state, + while the active scheme still restores its exact contraction layouts. +- Test cleanup scripts now remove `session.log*` artifacts, and the repository + ignores those rotating session logs explicitly. +- Shared HTTP test helpers now give bundled editor assets more time to load, + which reduces intermittent timeout failures when the local test server is + under load. +- Removed a few unused internal helpers from logging, periodic-mode utilities, + rendering, and einsum code generation, and deduplicated built-in template + defaults so the catalog now keeps each template's default parameters in one + shared definition. +- Periodic code generation now routes linear, grid, and tree modes through one + shared internal dispatcher, and the grid/tree array helpers reuse shared cell + preparation utilities instead of repeating the same setup in each backend. +- Static rendering helpers and the `/api/render` route now share more of their + internal option parsing, validation, and response assembly logic instead of + repeating the same flow per export format. +- Internal built-in template builders are now split by family with shared + construction primitives, while the existing template catalog and public + template APIs keep the same behavior and registration order. +- Large static renders now reuse connected-component geometry and connected + index direction lookups instead of recomputing the same layout heuristics for + every free index, which substantially reduces hot-path SVG/PNG/TikZ latency. +- CLI `edit` now exposes `--ui {browser,pywebview,server}` while keeping + `--no-browser` as a compatibility alias for the server-only mode. + +## [0.5.0] - 2026-04-30 + ### Changed - The browser editor's `Info` help panel now mentions the full current export set (`PNG`, `SVG`, `PDF`, `TikZ/LaTeX`, `Graphviz/DOT`, and `Mermaid`) and clarifies that recommended startup flows can use built-in templates, session templates, and reusable subnetwork fragments. +- Browser-editor `For`-mode code generation now keeps the commented + `TNE_SPEC_B64` round-trip metadata at the end of the generated Python source + instead of the beginning, and the editor only includes it when the new + `Metadata` checkbox is enabled in the `Code` panel. +- Browser-editor `For` mode no longer disables template settings or + selection-based `Extract`, `To Library`, and `To Template` actions just for + being in a periodic editor view; those actions now stay available for normal + tensors and only reject virtual boundary cells such as `next`, `previous`, + grid side cells, or tree parent/child placeholders. ### Added - Static exports now include a `Mermaid` flowchart renderer for documentation workflows, with matching support in the Python API, CLI `render` subcommand, and browser editor export menu. +- The editor `Reflow` popover now offers simple horizontal and vertical + alignment controls plus a 90° clockwise rotation action that also rotates the + selected tensor ports to keep their orientation consistent. +- Static geometric exports now choose shape-aware directions for free indices in + linear, circular, and 2D-grid layouts, use a stable local fallback for + irregular layouts, and draw dangling stubs with a length of two tensor radii. ### Fixed +- Linear-periodic `For`-mode validation now rejects carry plans that the code + generator cannot realize, so the editor no longer reports some + multi-boundary manual schemes as valid during analysis only to fail later + when generating Python code. +- Linear-periodic `For`-mode carry code generation now keeps non-interface + labels from the previous payload distinct from the current cell's local + labels, so valid manual schemes with repeated index names across cells no + longer collapse accidentally during periodic carry simulation. +- Linear-periodic `For`-mode `tensorkrowch` carry generation now keeps local + open edges on stable current-cell edge objects while exporting repeated + carry interfaces from the materialized result node, so periodic helpers no + longer hand later loop iterations a stale leaf edge. +- Linear-periodic `For`-mode `tensorkrowch` carry helpers now reattach + intermediate contraction results before later manual steps reuse them, so + valid periodic plans no longer lose the shared inter-cell edge during + multi-step cell contractions. +- Linear-periodic `For`-mode `tensorkrowch` carry helpers now materialize + shared edges with `reattach_edges(override=True)` instead of relying on + `network.reset()`, so repeated periodic iterations keep their inter-cell + bond visible to later `contract_between` steps. +- Normal `tensorkrowch` manual code generation no longer injects + `reattach_edges(...)` between ordinary contraction steps, so non-`For` + exports keep the simpler node structure that the standard visualizer already + handles correctly. +- Contraction-scene tensor layering now follows the current visible operands + instead of only the base spec tensor list, so selecting or dragging derived + result tensors in `single`/`contract` keeps their free ports visible above + overlapping front tensors. +- Static exports now keep free-index directions aligned with the network's real + on-canvas orientation, so vertical, diagonal, and rotated-grid layouts no + longer get reinterpreted as axis-aligned during `SVG`, `PNG`, `PDF`, and + `TikZ/LaTeX` rendering. +- `Ctrl/Cmd+Enter` now closes the editor with info reliably from contraction + planner views, preview states, `For` mode, and benchmark mode by registering + the global shortcut listener in capture mode. +- Browser-editor academic exports now invalidate the serialized-spec cache + after layout moves and rotations, so exported figures follow the current + canvas geometry instead of occasionally reusing stale pre-reflow positions. - Mermaid export now renders free indices as labeled dangling-edge terminals instead of boxed open-index nodes, so the flowchart output reads more like a tensor-network leg. diff --git a/CITATION.cff b/CITATION.cff index 9e7717d..177e7b1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,12 +5,12 @@ type: software authors: - family-names: "Mata Ali" given-names: "Alejandro" -version: "0.4.0" -date-released: "2026-04-25" +version: "1.0.1" +date-released: "2026-05-14" repository-code: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" url: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" license: "MIT" -abstract: "A local Python package and browser editor for drawing tensor networks, saving versioned JSON designs, and generating readable Python code for tensor-network backends." +abstract: "A production-ready local Python package and browser editor for drawing tensor networks, saving versioned JSON designs, and generating readable Python code for tensor-network backends." keywords: - "tensor networks" - "scientific computing" diff --git a/MANIFEST.in b/MANIFEST.in index 426ccfa..dd3e848 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,5 +3,3 @@ include THIRD_PARTY_LICENSES include README.md recursive-include src/tensor_network_editor/app/static *.css *.html *.ico *.js include src/tensor_network_editor/py.typed -recursive-exclude docs/images * -recursive-exclude tests * diff --git a/README.md b/README.md index 2ebb101..fa882f9 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-blue)](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor) [![Windows%20%7C%20Linux](https://img.shields.io/badge/platform-Windows%20%7C%20Linux-0A7BBB)](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor/actions/workflows/ci.yml) [![MIT License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) +[![Stability](https://img.shields.io/badge/stability-production--ready-brightgreen)](https://pypi.org/project/tensor-network-editor/) `tensor-network-editor` is a local Python package for drawing tensor networks, saving them as versioned JSON, and generating readable Python code for several @@ -89,10 +90,10 @@ offline use, and generated code you can inspect. ## Why This Project -- Draw tensor-network diagrams in a local browser session. +- Draw tensor-network diagrams in a local browser or `pywebview` desktop session. - Save and reload backend-independent JSON designs. -- Recover the previous local browser session from a project draft if the tab is - closed before you save. +- Recover the previous local editor session from a project draft if the window + or tab is closed before you save. - Generate code for `tensornetwork`, `quimb`, `tensorkrowch`, `einsum_numpy`, and `einsum_torch`. - Render designs to static SVG, TikZ/LaTeX, Graphviz/DOT, or Mermaid from Python, the @@ -135,9 +136,9 @@ offline use, and generated code you can inspect. - Get structural analysis with FLOP and MAC cost summaries. - Use the package from the CLI or directly from Python. -The editor opens in your browser, but the server runs locally on your own -machine. No Node runtime or cloud service is needed for normal use. A future -desktop wrapper such as `pywebview` may sit on top of this local flow, but the +The editor server runs locally on your own machine. By default it opens in your +browser, and you can also ask for a native `pywebview` window with the optional +`desktop` extra. No Node runtime or cloud service is needed for normal use. The browser-served editor remains the core interface and compatibility target. ## Minimal Installation @@ -176,7 +177,20 @@ tensor-network-editor edit ``` This command starts a local server and waits until you press `Done` or -`Cancel` in the browser session. +`Cancel` in the editor session. + +Open the same local editor in a native `pywebview` window: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + +Start only the local server and open the printed URL yourself: + +```bash +tensor-network-editor edit --ui server +``` Pick a color theme when you launch the editor: @@ -205,8 +219,7 @@ tensor-network-editor doctor my_network.json tensor-network-editor doctor my_network.json --format json ``` -Render one saved design as SVG, PDF, TikZ/LaTeX, Graphviz/DOT, Mermaid, or with the -optional `png` extra, PNG: +Render one saved design as SVG, PDF, TikZ/LaTeX, Graphviz/DOT, Mermaid, or PNG: ```bash tensor-network-editor render my_network.json --format svg --output figure.svg @@ -369,6 +382,9 @@ whole load immediately. scripts or imports already resolvable from the active `.venv`. If a Python file depends on sibling modules or path-sensitive imports, prefer the Python API or CLI with the real file path. +- Only use live import with local Python files you trust. Live import executes + the file in a subprocess with the active Python environment, so trusted code + can still read or write local files. - Tensor values in the visual editor support portable built-in initializers, dtype choices, JSON-friendly complex scalars, and external `.npy`, `.npz`, and `.pt` data references. Symbolic expressions are not supported yet. @@ -386,6 +402,7 @@ whole load immediately. - Source code: [github.com/DOKOS-TAYOS/Tensor-Network-Editor](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor) - Changelog: [CHANGELOG.md](CHANGELOG.md) +- Security policy: [SECURITY.md](SECURITY.md) - Example script: [examples/basic_usage.py](examples/basic_usage.py) - Issue tracker: [github.com/DOKOS-TAYOS/Tensor-Network-Editor/issues](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor/issues) - License: [LICENSE](LICENSE) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..b28987b --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,76 @@ +# Security Policy + +## Reporting a Vulnerability + +Please use GitHub private vulnerability reporting for security issues in this +repository when it is available: + +https://github.com/DOKOS-TAYOS/Tensor-Network-Editor/security/advisories/new + +Do not open a public issue with exploit details, proof-of-concept payloads, or +private environment information. If private reporting is unavailable, open a +public issue asking for a preferred security contact without including the +technical details. + +Useful reports include: + +- affected version or commit +- operating system and Python version +- whether the browser editor, CLI, or Python API is involved +- concise reproduction steps +- expected impact, if known + +## Maintainer Disclosure Checklist + +For a confirmed issue: + +1. Prepare the fix privately or in a normal pull request when the details are + already public. +2. Publish the patched release before publishing the advisory, unless users + need immediate mitigation guidance. +3. Create or update a GitHub Security Advisory with affected versions, patched + versions, severity, impact, workarounds, and references. +4. Mention the fix in `CHANGELOG.md` and the release notes. +5. Consider yanking affected PyPI releases only when discouraging new installs + of those exact releases is safer than leaving them available. Prefer a clear + yank reason that points users to the patched release. + +In short: publish the patched release before publishing the advisory when users +do not need immediate mitigation guidance. + +## PrismJS Advisory Draft + +Use this when publishing the bundled PrismJS update as a repository advisory +for `tensor-network-editor`. This is not a new PrismJS vulnerability; it is a +vendored dependency advisory that points to the upstream issue. + +- Title: Bundled PrismJS before 1.30.0 in the browser-based editor +- Related upstream advisory: CVE-2024-53382 / GHSA-x7hr-w5r2-h6wg +- Affected package: `tensor-network-editor` +- Affected versions: releases that bundle PrismJS 1.29.0 in + `src/tensor_network_editor/app/static/vendor/` +- Patched version: the first release that bundles PrismJS 1.30.0 or later +- Severity: Moderate, matching the upstream PrismJS advisory unless new project + evidence shows a different impact + +Suggested impact text: + +```text +Tensor Network Editor bundled PrismJS 1.29.0 for syntax highlighting in the +browser-based editor. PrismJS versions before 1.30.0 are affected by +CVE-2024-53382 / GHSA-x7hr-w5r2-h6wg. + +Installing or importing the Python package alone does not execute PrismJS. The +affected code path is the browser-based editor. Risk is higher if the local +editor is exposed beyond localhost, or if untrusted HTML-like content can reach +the editor UI. +``` + +Suggested recommendation text: + +```text +Upgrade to the patched Tensor Network Editor release. If you cannot upgrade +immediately, avoid exposing the local editor outside trusted loopback/local +workflows and avoid opening untrusted designs or Python-derived content in the +browser editor. +``` diff --git a/THIRD_PARTY_LICENSES b/THIRD_PARTY_LICENSES index 090d73b..04688bf 100644 --- a/THIRD_PARTY_LICENSES +++ b/THIRD_PARTY_LICENSES @@ -8,6 +8,8 @@ Runtime pip-installed dependencies are not bundled into this source distribution or wheel. Required installs (`matplotlib`, `opt_einsum`) and optional extras such as `numpy`, `torch`, `quimb`, `tensornetwork`, `tensorkrowch`, or `pywebview` remain covered by their own licenses and notices. +Development tools installed through the `dev` extra are also not bundled and +remain covered by their own upstream licenses and notices. Runtime dependency notice ------------------------- @@ -28,6 +30,21 @@ Matplotlib is a required runtime dependency for academic SVG/PNG/PDF rendering. vendored into this repository. Its own distribution carries the authoritative license text and notices. +Development dependency notice +----------------------------- + +pip-audit is a development and CI dependency used to scan Python dependencies +for known vulnerabilities. + +- Package: pip-audit +- Version range used by this project: `>=2.7` +- Project: https://pypi.org/project/pip-audit/ +- License: Apache Software License +- Notice handling: + pip-audit is installed as an external development dependency rather than + being vendored into this repository. Its own distribution carries the + authoritative license text and notices. + 1. Cytoscape.js - Bundled file: `src/tensor_network_editor/app/static/vendor/cytoscape.min.js` @@ -45,7 +62,7 @@ Matplotlib is a required runtime dependency for academic SVG/PNG/PDF rendering. - Bundled files: `src/tensor_network_editor/app/static/vendor/prism-core.min.js` `src/tensor_network_editor/app/static/vendor/prism-python.min.js` - - Version: 1.29.0 + - Version: 1.30.0 - Project: https://prismjs.com/ - Upstream repository: https://github.com/PrismJS/prism - Copyright: diff --git a/docs/api.md b/docs/api.md index 4d30ca1..04f630e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -26,6 +26,7 @@ The package exposes the main functions and models at the top level: from tensor_network_editor import ( EngineName, EditorThemeName, + EditorUiMode, NetworkBuilder, NetworkSpec, PythonLoadOptions, @@ -69,7 +70,7 @@ Useful public modules: | Module | Use it for | | --- | --- | -| `tensor_network_editor.editor` | `EditorLaunchOptions`, `EditorThemeName`, and `open_editor(...)` | +| `tensor_network_editor.editor` | `EditorLaunchOptions`, `EditorThemeName`, `EditorUiMode`, and `open_editor(...)` | | `tensor_network_editor.builder` | fluent `NetworkBuilder`, `TensorHandle`, and `IndexHandle` helpers | | `tensor_network_editor.io` | JSON/Python loading, saving, `serialize_spec(...)`, and `SCHEMA_VERSION` | | `tensor_network_editor.models` | data classes, result models, enums, and periodic-mode types | @@ -88,8 +89,7 @@ user-facing API. ## Launch the Editor -Use `open_editor(...)` when you want a local browser editing session from -Python. +Use `open_editor(...)` when you want a local editing session from Python. ```python from tensor_network_editor import EngineName @@ -101,7 +101,7 @@ def main() -> None: options=EditorLaunchOptions( default_engine=EngineName.EINSUM_NUMPY, theme="light", - open_browser=True, + ui_mode="browser", ), ) @@ -122,7 +122,8 @@ Main parameters: - `options.default_collection_format`: initial tensor collection layout - `options.theme`: initial color theme, one of `dark`, `light`, `contrast`, `colorblind`, or `shiny` -- `options.open_browser`: open the browser automatically +- `options.ui_mode`: choose `browser`, `pywebview`, or `server` +- `options.open_browser`: legacy browser/server compatibility flag - `options.host`: local host address, default `127.0.0.1` - `options.port`: local port, default `0` so the OS chooses one - `options.print_code`: print generated code after confirmation @@ -136,6 +137,12 @@ Main parameters: leave it unset to use the project-local default under `.tensor-network-editor/drafts/` +Install the optional desktop extra before using `options.ui_mode="pywebview"`: + +```bash +python -m pip install "tensor-network-editor[desktop]" +``` + Return value: - `None` when the user cancels @@ -266,6 +273,9 @@ Important details: active Python interpreter, supports live `quimb` and `tensornetwork` objects, and accepts `object_name="..."` when several compatible globals exist +- Only use live import with local Python files you trust. Live import executes + the file in a subprocess with the active Python environment, so trusted code + can still read or write local files. - `PythonLoadOptions.reconstruction_level="simple"` rebuilds only the portable network structure: tensors, inferable connections, and portable tensor-data payloads - `PythonLoadOptions.reconstruction_level="best_available"` is currently only supported diff --git a/docs/cli.md b/docs/cli.md index eae10be..93f1d6d 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -26,7 +26,7 @@ reusable-subnetwork catalogs. ## Launch the Editor -Start the local browser editor: +Start the local editor: ```bash tensor-network-editor edit @@ -42,6 +42,8 @@ Useful options: tensor-network-editor edit --load my_network.json tensor-network-editor edit --engine quimb tensor-network-editor edit --theme light +tensor-network-editor edit --ui pywebview +tensor-network-editor edit --ui server tensor-network-editor edit --save-code generated_network.py tensor-network-editor edit --print-code tensor-network-editor edit --no-browser @@ -53,8 +55,19 @@ You can combine them: tensor-network-editor edit --load my_network.json --engine quimb --save-code generated_network.py ``` -Use `--no-browser` when you want to start the local server but open the printed -URL manually. +By default, `edit` opens the local URL in your browser. + +Use `--ui pywebview` when you want the same local editor inside a native +desktop window and you have installed the optional desktop extra: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + +Use `--ui server` when you want to start the local server but open the printed +URL manually. `--no-browser` remains as a compatibility alias for the same +server-only mode. Use `--theme` to choose the editor colors at startup. Available themes are `dark`, `light`, `contrast`, `colorblind`, and `shiny`; `dark` is the default. @@ -185,7 +198,7 @@ from tensor_network_editor.editor import EditorLaunchOptions, open_editor open_editor( options=EditorLaunchOptions( - open_browser=False, + ui_mode="server", log_file_path="tne-editor.log", log_file_max_bytes=10_485_760, log_file_backup_count=5, @@ -215,6 +228,10 @@ If `--python-import-mode live` is used on generated source and the live import fails because the backend package is missing, the loader falls back to the static generated-source parser and reports the fallback as a warning. +Only use live import with local Python files you trust. Live import executes +the file in a subprocess with the active Python environment, so trusted code can +still read or write local files. + ## Headless Commands Headless commands work without opening the visual editor: diff --git a/docs/extended_guide.md b/docs/extended_guide.md index 8d0f6d2..0466d33 100644 --- a/docs/extended_guide.md +++ b/docs/extended_guide.md @@ -53,9 +53,9 @@ the design and target another backend later. The editor itself runs locally. The package starts a Python HTTP server on your machine, opens a browser tab by default, and waits until you press `Done` or -`Cancel`. Normal use does not require Node.js or a cloud service. A future -desktop wrapper such as `pywebview` can sit on top of the same local server, -but the browser-served editor remains the primary supported surface. +`Cancel`. You can also ask for a native `pywebview` window with the optional +desktop extra. Normal use does not require Node.js or a cloud service, and the +browser-served editor remains the primary supported surface. ## Choosing The Right Tool @@ -178,6 +178,13 @@ Use `--no-browser` when automatic browser opening is blocked, when you work over SSH, or when you want to copy the printed local URL into a browser manually. +Open the same local editor in a native desktop window: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + From Python: ```python @@ -190,7 +197,7 @@ def main() -> None: options=EditorLaunchOptions( default_engine=EngineName.EINSUM_NUMPY, theme="light", - open_browser=True, + ui_mode="browser", ) ) if result is None: @@ -818,6 +825,10 @@ Generated exports provide the richest round-trip. External static profiles and live imports are intentionally conservative and do not recover editor layout, groups, notes, or manual contraction plans. +Only use live import with local Python files you trust. Live import executes +the file in a subprocess with the active Python environment, so trusted code can +still read or write local files. + If live import is requested for generated source and the backend import fails because the backend package is missing, the loader can fall back to the static generated-source parser and report a warning. diff --git a/docs/getting-started.md b/docs/getting-started.md index c5f2f60..ed96613 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -52,13 +52,22 @@ What happens: - the command waits until you press `Done` or `Cancel` - `Done` returns the final design and generated code for the selected engine +If you prefer a native desktop window instead of the browser, install the +optional desktop extra and run: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + If your environment cannot open a browser automatically, use: ```bash -tensor-network-editor edit --no-browser +tensor-network-editor edit --ui server ``` -Then open the local URL printed in the terminal. +Then open the local URL printed in the terminal. `--no-browser` still works as +a compatibility alias for the same server-only mode. You can also choose the editor colors when the session starts: diff --git a/docs/installation.md b/docs/installation.md index 68d4913..dc40b28 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -80,8 +80,10 @@ python -m pip install "tensor-network-editor[desktop]" ``` The `desktop` extra installs `pywebview` for environments that want a desktop -webview dependency available. The standard documented workflow is still the -local browser editor. +webview dependency available. After installing it, you can launch the editor in +its own native window with `tensor-network-editor edit --ui pywebview` or +`EditorLaunchOptions(ui_mode="pywebview")`. The standard default workflow is +still the local browser editor. You can combine extras: @@ -227,6 +229,13 @@ tensor-network-editor edit If the browser does not open, see [troubleshooting.md#the-browser-did-not-open-automatically](troubleshooting.md#the-browser-did-not-open-automatically). +If you prefer a native desktop window instead of the browser, install the +desktop extra and launch: + +```bash +tensor-network-editor edit --ui pywebview +``` + ## Cleanup Scripts The repository includes cleanup scripts for generated local artifacts: diff --git a/docs/superpowers/plans/2026-04-29-mermaid-export.md b/docs/superpowers/plans/2026-04-29-mermaid-export.md deleted file mode 100644 index 0091557..0000000 --- a/docs/superpowers/plans/2026-04-29-mermaid-export.md +++ /dev/null @@ -1,321 +0,0 @@ -# Mermaid Export Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Add a Mermaid export format that users can generate from Python, the CLI, and the browser editor for documentation-friendly tensor-network diagrams. - -**Architecture:** Extend the existing static rendering family with a new text renderer in `rendering.py`, then thread the new format through the `/api/render` backend route, the CLI `render` subcommand, and the browser export UI. Reuse `DotRenderOptions` for label toggles, keep the Mermaid output structure-oriented rather than geometry-oriented, and degrade gracefully for notes and groups. - -**Tech Stack:** Python 3.12, typed package exports, argparse CLI, browser editor JavaScript, pytest, pyright, ruff - ---- - -### Task 1: Add the core Mermaid renderer with TDD - -**Files:** -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\rendering.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\__init__.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_rendering.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_api.py` - -- [ ] **Step 1: Write the failing renderer tests** - -Add focused tests to `tests/test_rendering.py` for: - -```python -def test_render_spec_mermaid_returns_flowchart_for_normal_network() -> None: - mermaid = render_spec_mermaid(build_sample_spec()) - - assert mermaid.startswith("flowchart LR\n") - assert 'tensor_tensor_a["A"]' in mermaid - assert 'tensor_tensor_b["B"]' in mermaid - assert 'tensor_tensor_a <-->|"bond_x / x=3"| tensor_tensor_b' in mermaid - - -def test_render_spec_mermaid_can_hide_tensor_index_and_bond_labels() -> None: - mermaid = render_spec_mermaid( - build_sample_spec(), - options=DotRenderOptions( - show_tensor_labels=False, - show_index_labels=False, - show_edge_labels=False, - ), - ) - - assert 'tensor_tensor_a["tensor_a"]' in mermaid - assert 'tensor_tensor_b["tensor_b"]' in mermaid - assert 'bond_x' not in mermaid - assert 'x=3' not in mermaid - - -def test_render_spec_mermaid_includes_hyperedges_groups_and_notes() -> None: - spec = build_three_tensor_hyperedge_spec() - mermaid = render_spec_mermaid(spec) - - assert "subgraph group_demo [Demo Group]" in mermaid - assert 'hyperedge_h["shared_h"]' in mermaid - assert '%% Note: Check the contraction order' in mermaid - - -def test_render_spec_mermaid_writes_output_path(tmp_path: Path) -> None: - output_path = tmp_path / "network.mmd" - - mermaid = render_spec_mermaid(build_sample_spec(), output_path=output_path) - - assert output_path.read_text(encoding="utf-8") == mermaid -``` - -Add a public API coverage check in `tests/test_api.py` similar to the existing static render exports: - -```python -assert tensor_network_editor.render_spec_mermaid is render_spec_mermaid -``` - -- [ ] **Step 2: Run the renderer tests to verify they fail** - -Run: `.\.venv\Scripts\python -m pytest tests\test_rendering.py -k mermaid` - -Expected: FAIL because `render_spec_mermaid` does not exist yet. - -- [ ] **Step 3: Write the minimal Mermaid renderer** - -Implement in `src/tensor_network_editor/rendering.py`: - -```python -def render_spec_mermaid( - spec: NetworkSpec, - *, - options: DotRenderOptions | None = None, - output_path: StrPath | None = None, -) -> str: - ... -``` - -Add a small renderer class and helpers that: -- emit `flowchart LR` -- create safe Mermaid ids from existing stable ids -- render tensors, pairwise edges, open indices, hyperedge hubs, groups, and note comments -- reuse existing DOT label logic where possible -- write UTF-8 text when `output_path` is provided - -Then export it from `__all__` and the lazy exports in `src/tensor_network_editor/__init__.py`. - -- [ ] **Step 4: Run the renderer tests to verify they pass** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_rendering.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_api.py -k render_spec_mermaid` - -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/tensor_network_editor/rendering.py src/tensor_network_editor/__init__.py tests/test_rendering.py tests/test_api.py -git commit -m "Add Mermaid renderer" -``` - -### Task 2: Integrate Mermaid into the backend route and CLI with TDD - -**Files:** -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\cli.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\routes.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\internal\cli\_cli_handlers.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\internal\cli\_cli_parser.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_cli.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_app_routes.py` - -- [ ] **Step 1: Write the failing integration tests** - -Add route coverage in `tests/test_app_routes.py`: - -```python -def test_render_route_returns_mermaid_export(editor_server: EditorServer) -> None: - spec = build_sample_spec() - serialized_spec = {"schema_version": SCHEMA_VERSION, "network": spec.to_dict()} - - payload = request_json( - f"{editor_server.base_url}/api/render", - method="POST", - payload={"format": "mermaid", "spec": serialized_spec}, - ) - - assert payload["format"] == "mermaid" - assert payload["content_type"] == "text/plain;charset=utf-8" - assert payload["text"].startswith("flowchart LR\n") -``` - -Add CLI coverage in `tests/test_cli.py`: - -```python -def test_render_subcommand_writes_mermaid_output(sample_spec: NetworkSpec) -> None: - with ( - patch("tensor_network_editor.cli.load_spec", return_value=sample_spec), - patch("tensor_network_editor.cli.render_spec_mermaid", return_value="flowchart LR\n"), - ): - exit_code = main( - ["render", "saved-network.json", "--format", "mermaid", "--output", "graph.mmd"] - ) - - assert exit_code == 0 - - -def test_render_subcommand_prints_mermaid_when_no_output( - sample_spec: NetworkSpec, - capsys: pytest.CaptureFixture[str], -) -> None: - with ( - patch("tensor_network_editor.cli.load_spec", return_value=sample_spec), - patch("tensor_network_editor.cli.render_spec_mermaid", return_value="flowchart LR\n"), - ): - exit_code = main(["render", "saved-network.json", "--format", "mermaid"]) - - assert exit_code == 0 - assert capsys.readouterr().out == "flowchart LR\n\n" -``` - -- [ ] **Step 2: Run the integration tests to verify they fail** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_routes.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_cli.py -k mermaid` - -Expected: FAIL because the route and CLI do not accept `mermaid` yet. - -- [ ] **Step 3: Wire Mermaid into the route and CLI** - -Add `render_spec_mermaid` imports and format branches in: -- `src/tensor_network_editor/cli.py` -- `src/tensor_network_editor/internal/cli/_cli_handlers.py` -- `src/tensor_network_editor/internal/cli/_cli_parser.py` -- `src/tensor_network_editor/app/routes.py` - -Use: - -```python -content_type = "text/plain;charset=utf-8" -``` - -Use `.mmd` as the expected file extension in user-facing messages and downloads. - -- [ ] **Step 4: Run the integration tests to verify they pass** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_routes.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_cli.py -k mermaid` - -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/tensor_network_editor/cli.py src/tensor_network_editor/app/routes.py src/tensor_network_editor/internal/cli/_cli_handlers.py src/tensor_network_editor/internal/cli/_cli_parser.py tests/test_cli.py tests/test_app_routes.py -git commit -m "Integrate Mermaid export into CLI and API" -``` - -### Task 3: Add Mermaid to the browser editor and docs with TDD - -**Files:** -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\index.html` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\js\core\dom.js` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\js\shell\editorShellBindings.js` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\js\session\sessionEditorFlows.js` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\README.md` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\CHANGELOG.md` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_app_assets.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_frontend_architecture.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_frontend_runtime.py` - -- [ ] **Step 1: Write the failing editor and docs tests** - -Add asset checks in `tests/test_app_assets.py` for: -- `id="export-mermaid-menu-item"` -- `` -- Mermaid listed in help text when export formats are enumerated -- DOM wiring for `exportMermaidMenuItem` - -Add runtime coverage in `tests/test_frontend_runtime.py` similar to the existing academic export flow: - -```javascript -await flows.downloadExportAs("mermaid"); -``` - -Then assert: -- one `renderSpec` call with `payload.format === "mermaid"` -- one text download for `draft_demo.mmd` -- `contentType === "text/plain;charset=utf-8"` - -- [ ] **Step 2: Run the editor tests to verify they fail** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_assets.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_frontend_runtime.py -k mermaid` - -Expected: FAIL because the editor does not expose Mermaid yet. - -- [ ] **Step 3: Implement the editor export wiring** - -Update: -- `index.html` to add a Mermaid export menu item and selector option -- `dom.js` to expose `exportMermaidMenuItem` -- `editorShellBindings.js` to bind `downloadExportAs("mermaid")` -- `sessionEditorFlows.js` to add Mermaid to `exportDetails`, use `.mmd`, and route it through text download - -Also update `README.md` and `CHANGELOG.md` to mention Mermaid export support. - -- [ ] **Step 4: Run the editor tests to verify they pass** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_assets.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_frontend_runtime.py -k mermaid` - -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/tensor_network_editor/app/static/index.html src/tensor_network_editor/app/static/js/core/dom.js src/tensor_network_editor/app/static/js/shell/editorShellBindings.js src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js README.md CHANGELOG.md tests/test_app_assets.py tests/test_frontend_architecture.py tests/test_frontend_runtime.py -git commit -m "Add Mermaid export to browser editor" -``` - -### Task 4: Final verification and cleanup - -**Files:** -- Verify only - -- [ ] **Step 1: Run focused pytest coverage** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_rendering.py` -- `.\.venv\Scripts\python -m pytest tests\test_api.py -k render_spec_mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_cli.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_app_routes.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_app_assets.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_frontend_runtime.py -k mermaid` - -Expected: PASS - -- [ ] **Step 2: Run Python quality checks** - -Run: -- `.\.venv\Scripts\python -m ruff check . --fix` -- `.\.venv\Scripts\python -m ruff format .` - -Expected: PASS - -- [ ] **Step 3: Run type checking and note pre-existing failures separately** - -Run: -- `.\.venv\Scripts\python -m pyright` - -Expected: either PASS or only the already-known unrelated failures in: -- `tests/test_app_routes.py` -- `tests/test_app_server.py` -- `tests/test_session.py` - -- [ ] **Step 4: Commit any final cleanup** - -```bash -git add README.md CHANGELOG.md src tests -git commit -m "Polish Mermaid export integration" -``` diff --git a/docs/superpowers/specs/2026-04-29-mermaid-export-design.md b/docs/superpowers/specs/2026-04-29-mermaid-export-design.md deleted file mode 100644 index bae4812..0000000 --- a/docs/superpowers/specs/2026-04-29-mermaid-export-design.md +++ /dev/null @@ -1,174 +0,0 @@ -# Mermaid Export Design - -## Summary - -Add a new static export format, `mermaid`, for tensor-network diagrams. -The goal is to produce a portable text representation that users can paste -directly into GitHub, Markdown documents, or Mermaid-enabled documentation -tools. - -This export is documentation-oriented. It should preserve graph structure and -labels, not editor geometry or visual styling. - -## Goals - -- Add a first-class `mermaid` export alongside `svg`, `png`, `pdf`, `tikz`, - and `dot`. -- Keep the API, CLI, and browser editor export flows consistent. -- Reuse the existing export label toggles for tensor names, index names, and - bond names. -- Generate Mermaid that is robust and easy to paste into Markdown. - -## Non-Goals - -- Preserve canvas positions, exact layout, colors, or node sizes. -- Reproduce the editor appearance inside Mermaid. -- Add a separate `markdown` export format in v1. -- Fail the export because Mermaid cannot express some detail exactly. - -## Recommended Approach - -Implement a new renderer function: - -- `render_spec_mermaid(spec: NetworkSpec, *, options: DotRenderOptions | None = None, output_path: StrPath | None = None) -> str` - -`DotRenderOptions` is the best fit for v1 because Mermaid is also a -text-oriented graph export and needs the same label visibility controls as -`dot`. - -The renderer should emit a complete Mermaid diagram using: - -```text -flowchart LR -``` - -This direction matches the current left-to-right mental model already used in -`dot`. - -## Representation Rules - -### Tensors - -- Each tensor becomes one Mermaid node. -- If `show_tensor_labels` is true, use the tensor name as the visible label. -- If `show_tensor_labels` is false, keep the node but use a minimal fallback - label based on the tensor id so the graph remains valid and readable. - -### Pairwise edges - -- Each standard edge becomes one Mermaid connection between the two tensor - nodes. -- The edge label should follow the current `dot` behavior: - - show bond name and index label when both are enabled - - show only bond name when only bond names are enabled - - show only index name and dimension when only index names are enabled - - show no label when both are disabled - -### Open indices - -- Each open index becomes a terminal Mermaid node connected to its tensor. -- The node label should reuse the same label logic already used by `dot` for - open indices. - -### Hyperedges - -- Each hyperedge becomes a synthetic hub node connected to all endpoint - tensors. -- The hub label should use the hyperedge name when bond labels are enabled. -- Endpoint edge labels should reuse the current `dot` hyperedge endpoint label - logic when index labels are enabled. - -### Groups - -- Each group should become a Mermaid `subgraph`. -- The renderer should place the member tensor nodes inside that `subgraph`. -- If Mermaid cannot faithfully reflect complex overlap or crossing semantics, - the export should still succeed with a simple `subgraph` structure. - -### Notes - -- Notes should not become positioned visual nodes in v1. -- Export each note as a Mermaid comment line: - -```text -%% Note: Check the contraction order -``` - -This keeps note content available for documentation without forcing awkward -diagram geometry. - -## Escaping and Identifiers - -- Mermaid node ids must be generated from safe internal identifiers, not from - raw labels. -- Visible labels must be escaped conservatively so quotes, brackets, newlines, - and punctuation do not break the diagram. -- Reuse existing conservative escaping ideas from `dot` and `tikz`, but keep - Mermaid-specific syntax rules separate in small helper functions. - -## API and Integration - -### Python API - -- Export `render_spec_mermaid` from `tensor_network_editor.rendering`. -- Re-export it from `tensor_network_editor.__init__`. - -### CLI - -- Extend `render --format` with `mermaid`. -- Print to stdout when `--output` is omitted. -- Use `.mmd` as the recommended output extension. -- Label the success message as `Mermaid`. - -### Browser editor - -- Add `Mermaid` to the export menu and the export format selector. -- Route it through the same `/api/render` flow as `tikz` and `dot`. -- Download it as text with the `.mmd` extension. - -### Backend route - -- Extend `/api/render` to accept `format == "mermaid"`. -- Return: - - `format: "mermaid"` - - `text: ` - - `content_type: "text/plain;charset=utf-8"` - -## Error Handling - -- If the spec is valid, Mermaid export should succeed. -- Unsupported visual details must degrade gracefully instead of raising. -- Rendering should only fail for the same categories already used elsewhere, - such as invalid payloads or invalid specs. - -## Testing - -Add focused tests for: - -- `render_spec_mermaid` basic output for a normal network -- label toggle behavior for tensor, index, and bond labels -- open indices and hyperedges -- group and note emission -- escaping of special characters -- API route `/api/render` with `format="mermaid"` -- CLI `render --format mermaid` -- editor menu and export selector wiring -- frontend download flow and output filename extension - -## Documentation Updates - -- Add Mermaid export to `README.md`. -- Add Mermaid export to the editor help text if that text enumerates supported - export formats. -- Add a short `CHANGELOG.md` entry when implementation lands. - -## Rollout Notes - -The first version should stay intentionally simple: - -- structure first -- labels second -- visual fidelity out of scope - -This keeps the renderer predictable, testable, and useful for documentation -without turning Mermaid into a second layout engine. diff --git a/docs/user-guide.md b/docs/user-guide.md index 003616e..35c8748 100644 --- a/docs/user-guide.md +++ b/docs/user-guide.md @@ -158,11 +158,11 @@ Python: ```python from tensor_network_editor.editor import EditorLaunchOptions, open_editor -open_editor(options=EditorLaunchOptions(theme="contrast")) +open_editor(options=EditorLaunchOptions(theme="contrast", ui_mode="browser")) ``` Available themes are `dark`, `light`, `contrast`, `colorblind`, and `shiny`. -The choice only affects the browser editor appearance; saved network JSON and +The choice only affects the editor appearance; saved network JSON and recoverable drafts keep the same model data. ## Templates diff --git a/pyproject.toml b/pyproject.toml index 6ae9a8c..3e5073f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "tensor-network-editor" dynamic = ["version"] -description = "Local visual editor for tensor networks: versioned JSON diagrams and Python code for einsum and optional backends." +description = "Production-ready local visual editor for tensor networks: versioned JSON diagrams and Python code for einsum and optional backends." readme = "README.md" requires-python = ">=3.11" license = "MIT" @@ -38,7 +38,7 @@ keywords = [ "visualization", ] classifiers = [ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", "Intended Audience :: Science/Research", "Operating System :: OS Independent", @@ -66,6 +66,7 @@ torch = ["torch>=2.0"] dev = [ "build>=1.2", "mypy>=1.10", + "pip-audit>=2.7", "pyright>=1.1", "pytest>=8.2", "ruff>=0.6", diff --git a/scripts/clean.bat b/scripts/clean.bat index 31fc285..7b46220 100644 --- a/scripts/clean.bat +++ b/scripts/clean.bat @@ -22,6 +22,7 @@ call :remove_glob_dirs_warn ".\pytest-cache-files-*" call :remove_glob_files ".\.coverage" call :remove_glob_files ".\.coverage.*" call :remove_glob_files ".\coverage.xml" +call :remove_glob_files ".\session.log*" call :remove_dir "__pycache__" call :remove_named_dirs ".\src" "__pycache__" diff --git a/scripts/clean.sh b/scripts/clean.sh index 7d2fffa..0ce26b6 100644 --- a/scripts/clean.sh +++ b/scripts/clean.sh @@ -98,6 +98,7 @@ remove_glob_dirs_warn "./pytest-cache-files-*" remove_file_pattern "./.coverage" remove_file_pattern "./.coverage.*" remove_file_pattern "./coverage.xml" +remove_file_pattern "./session.log*" remove_dir "__pycache__" remove_named_dirs "./src" "__pycache__" diff --git a/src/tensor_network_editor/__init__.py b/src/tensor_network_editor/__init__.py index a06978e..6ec5d09 100644 --- a/src/tensor_network_editor/__init__.py +++ b/src/tensor_network_editor/__init__.py @@ -17,7 +17,7 @@ from .analysis import analyze_contraction, analyze_spec from .builder import IndexHandle, NetworkBuilder, TensorHandle from .canonicalization import canonicalize_spec - from .editor import EditorLaunchOptions, EditorThemeName, open_editor + from .editor import EditorLaunchOptions, EditorThemeName, EditorUiMode, open_editor from .internal.diffing._diffing import diff_specs, semantic_diff_specs from .io import PythonLoadOptions, load_python_spec, load_spec, save_spec from .linting import lint_spec @@ -76,6 +76,7 @@ "EdgeSpec", "EditorLaunchOptions", "EditorThemeName", + "EditorUiMode", "EditorResult", "EngineName", "DotRenderOptions", @@ -130,6 +131,7 @@ "EdgeSpec": ".models", "EditorLaunchOptions": ".editor", "EditorThemeName": ".editor", + "EditorUiMode": ".editor", "EditorResult": ".models", "EngineName": ".models", "DotRenderOptions": ".rendering", diff --git a/src/tensor_network_editor/_public_codegen.py b/src/tensor_network_editor/_public_codegen.py index 21f52ff..66f566f 100644 --- a/src/tensor_network_editor/_public_codegen.py +++ b/src/tensor_network_editor/_public_codegen.py @@ -27,6 +27,7 @@ def generate_code( *, engine: EngineIdentifier, collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST, + include_roundtrip_metadata: bool = True, output_path: StrPath | None = None, print_code: bool = False, external_data_base_path: StrPath | None = None, @@ -43,7 +44,10 @@ def generate_code( external_data_base_path=external_data_base_path, ) result = _generate_code( - codegen_spec, engine, collection_format=collection_format + codegen_spec, + engine, + collection_format=collection_format, + include_roundtrip_metadata=include_roundtrip_metadata, ) if print_code: log_branch(LOGGER, "Printing generated code to stdout") diff --git a/src/tensor_network_editor/_version.py b/src/tensor_network_editor/_version.py index 3a58b17..efd8da9 100644 --- a/src/tensor_network_editor/_version.py +++ b/src/tensor_network_editor/_version.py @@ -4,4 +4,4 @@ from typing import Final -__version__: Final[str] = "0.4.0" +__version__: Final[str] = "1.0.1" diff --git a/src/tensor_network_editor/app/_analysis_services.py b/src/tensor_network_editor/app/_analysis_services.py index 0822cf4..ad42cf4 100644 --- a/src/tensor_network_editor/app/_analysis_services.py +++ b/src/tensor_network_editor/app/_analysis_services.py @@ -15,6 +15,7 @@ from ..internal.analysis._contraction_analysis import _analyze_validated_contraction from ..internal.analysis._contraction_analysis_types import ContractionAnalysisResult from ..models import NetworkSpec, ValidationIssue +from ._limits import enforce_spec_api_limits LOGGER = logging.getLogger(__name__) @@ -32,6 +33,7 @@ def analyze_serialized_contraction( emit_start=False, ) as success_context: spec = deserialize_spec_fn(serialized_spec) + enforce_spec_api_limits(spec) issues = validate_spec_fn(spec) if issues: log_branch( diff --git a/src/tensor_network_editor/app/_limits.py b/src/tensor_network_editor/app/_limits.py new file mode 100644 index 0000000..208298f --- /dev/null +++ b/src/tensor_network_editor/app/_limits.py @@ -0,0 +1,180 @@ +"""Complexity limits for local editor API payloads.""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass + +from ..internal.models._model_periodic import LinearPeriodicCellSpec +from ..internal.templates._template_catalog import TemplateParameters +from ..models import NetworkSpec, TensorSpec + +MAX_API_TENSORS = 512 +MAX_API_INDICES = 4096 +MAX_API_CONNECTIONS = 4096 +MAX_API_TENSOR_RANK = 64 +MAX_API_INDEX_DIMENSION = 1_000_000 +MAX_API_TENSOR_CARDINALITY = 10_000_000 +MAX_API_TEMPLATE_LINEAR_GRAPH_SIZE = 512 +MAX_API_TEMPLATE_GRID_SIDE_LENGTH = 32 +MAX_API_TEMPLATE_TREE_DEPTH = 10 +MAX_API_TEMPLATE_DIMENSION = 4096 +_GRID_TEMPLATE_NAMES = frozenset({"peps_2x2", "pepo"}) +_TREE_TEMPLATE_NAMES = frozenset({"mera", "ttn"}) + + +@dataclass(slots=True) +class _SpecComplexity: + """Accumulated size information for one editor API payload.""" + + tensor_count: int = 0 + index_count: int = 0 + connection_count: int = 0 + + +def enforce_spec_api_limits(spec: NetworkSpec) -> None: + """Reject a network spec that is too expensive for the local HTTP API.""" + complexity = _SpecComplexity() + for tensors, edge_count in _iter_spec_parts(spec): + complexity.tensor_count += len(tensors) + complexity.connection_count += edge_count + for tensor in tensors: + _enforce_tensor_api_limits(tensor) + complexity.index_count += len(tensor.indices) + + complexity.connection_count += sum( + len(hyperedge.endpoints) for hyperedge in spec.hyperedges + ) + _enforce_count_limit( + name="tensors", + count=complexity.tensor_count, + limit=MAX_API_TENSORS, + ) + _enforce_count_limit( + name="indices", + count=complexity.index_count, + limit=MAX_API_INDICES, + ) + _enforce_count_limit( + name="connections", + count=complexity.connection_count, + limit=MAX_API_CONNECTIONS, + ) + + +def enforce_template_api_limits( + template_name: str, + parameters: TemplateParameters | None, +) -> None: + """Reject built-in template parameters that would create huge payloads.""" + if parameters is None: + return + graph_size_limit = _template_graph_size_limit(template_name) + if parameters.graph_size is not None and parameters.graph_size > graph_size_limit: + raise ValueError( + "Template parameter 'graph_size' " + f"is {parameters.graph_size}, above the API limit of {graph_size_limit}." + ) + _enforce_optional_template_dimension( + parameters.bond_dimension, + field_name="bond_dimension", + ) + _enforce_optional_template_dimension( + parameters.physical_dimension, + field_name="physical_dimension", + ) + + +def _iter_spec_parts(spec: NetworkSpec) -> Iterator[tuple[list[TensorSpec], int]]: + """Yield tensor and edge collections stored in a spec payload.""" + yield spec.tensors, len(spec.edges) + if spec.linear_periodic_chain is not None: + for cell in ( + spec.linear_periodic_chain.initial_cell, + spec.linear_periodic_chain.periodic_cell, + spec.linear_periodic_chain.final_cell, + ): + yield from _iter_cell_parts(cell) + if spec.grid_periodic_grid is not None: + for cell in ( + spec.grid_periodic_grid.top_left_cell, + spec.grid_periodic_grid.top_cell, + spec.grid_periodic_grid.top_right_cell, + spec.grid_periodic_grid.left_cell, + spec.grid_periodic_grid.center_cell, + spec.grid_periodic_grid.right_cell, + spec.grid_periodic_grid.bottom_left_cell, + spec.grid_periodic_grid.bottom_cell, + spec.grid_periodic_grid.bottom_right_cell, + ): + yield from _iter_cell_parts(cell) + if spec.tree_periodic_tree is not None: + for cell in ( + spec.tree_periodic_tree.root_cell, + spec.tree_periodic_tree.branch_cell, + spec.tree_periodic_tree.leaf_cell, + ): + yield from _iter_cell_parts(cell) + + +def _iter_cell_parts( + cell: LinearPeriodicCellSpec, +) -> Iterator[tuple[list[TensorSpec], int]]: + """Yield tensor and edge collections stored in one periodic cell.""" + yield cell.tensors, len(cell.edges) + + +def _enforce_tensor_api_limits(tensor: TensorSpec) -> None: + """Reject one tensor whose local shape is too expensive.""" + rank = len(tensor.indices) + if rank > MAX_API_TENSOR_RANK: + raise ValueError( + f"Tensor '{tensor.name}' has rank {rank}, " + f"above the API limit of {MAX_API_TENSOR_RANK}." + ) + cardinality = 1 + for index in tensor.indices: + if index.dimension > MAX_API_INDEX_DIMENSION: + raise ValueError( + f"Index '{index.name}' on tensor '{tensor.name}' has dimension " + f"{index.dimension}, above the API limit of {MAX_API_INDEX_DIMENSION}." + ) + if index.dimension > 0: + cardinality *= index.dimension + if cardinality > MAX_API_TENSOR_CARDINALITY: + raise ValueError( + f"Tensor '{tensor.name}' spans {cardinality} elements, " + f"above the API limit of {MAX_API_TENSOR_CARDINALITY}." + ) + + +def _enforce_count_limit(*, name: str, count: int, limit: int) -> None: + """Reject one aggregate count when it exceeds its API limit.""" + if count <= limit: + return + raise ValueError( + f"Network contains {count} {name}, above the API limit of {limit}." + ) + + +def _enforce_optional_template_dimension( + value: int | None, + *, + field_name: str, +) -> None: + """Reject template dimensions that would produce very large tensors.""" + if value is None or value <= MAX_API_TEMPLATE_DIMENSION: + return + raise ValueError( + f"Template parameter '{field_name}' is {value}, " + f"above the API limit of {MAX_API_TEMPLATE_DIMENSION}." + ) + + +def _template_graph_size_limit(template_name: str) -> int: + """Return the graph-size limit appropriate for one template family.""" + if template_name in _GRID_TEMPLATE_NAMES: + return MAX_API_TEMPLATE_GRID_SIDE_LENGTH + if template_name in _TREE_TEMPLATE_NAMES: + return MAX_API_TEMPLATE_TREE_DEPTH + return MAX_API_TEMPLATE_LINEAR_GRAPH_SIZE diff --git a/src/tensor_network_editor/app/_protocol.py b/src/tensor_network_editor/app/_protocol.py index 9e281a9..8cb4ef6 100644 --- a/src/tensor_network_editor/app/_protocol.py +++ b/src/tensor_network_editor/app/_protocol.py @@ -43,6 +43,7 @@ class CodegenRequest: serialized_spec: JsonDict engine: EngineIdentifier collection_format: TensorCollectionFormat + include_roundtrip_metadata: bool @dataclass(slots=True, frozen=True) @@ -269,6 +270,9 @@ def parse_codegen_request( serialized_spec=require_serialized_spec(payload), engine=resolve_engine(payload, default_engine), collection_format=resolve_collection_format(payload, default_collection_format), + include_roundtrip_metadata=require_boolean( + payload, "include_roundtrip_metadata", default=True + ), ) @@ -408,6 +412,16 @@ def bad_request_response(message: str) -> JsonResponse: return HTTPStatus.BAD_REQUEST, {"ok": False, "message": message} +def forbidden_response(message: str) -> JsonResponse: + """Return a standard forbidden JSON response.""" + return HTTPStatus.FORBIDDEN, {"ok": False, "message": message} + + +def unsupported_media_type_response(message: str) -> JsonResponse: + """Return a standard unsupported-media-type JSON response.""" + return HTTPStatus.UNSUPPORTED_MEDIA_TYPE, {"ok": False, "message": message} + + def not_found_response() -> JsonResponse: """Return a standard not-found JSON response.""" return HTTPStatus.NOT_FOUND, {"ok": False, "message": "Not found."} diff --git a/src/tensor_network_editor/app/_session_requests.py b/src/tensor_network_editor/app/_session_requests.py index f1a6ac4..cd20cd0 100644 --- a/src/tensor_network_editor/app/_session_requests.py +++ b/src/tensor_network_editor/app/_session_requests.py @@ -17,6 +17,7 @@ EngineIdentifier, TensorCollectionFormat, ) +from ._limits import enforce_spec_api_limits if TYPE_CHECKING: from .session import EditorSession @@ -30,6 +31,7 @@ def generate_session_request( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> CodegenResult: """Generate preview code for one editor request.""" with log_operation( @@ -38,6 +40,7 @@ def generate_session_request( context={"engine": engine_name_to_text(engine)}, ): spec = deserialize_spec(serialized_spec) + enforce_spec_api_limits(spec) log_branch( LOGGER, "Deserialized preview spec", @@ -47,6 +50,7 @@ def generate_session_request( spec, engine, collection_format=_resolve_collection_format(session, collection_format), + include_roundtrip_metadata=include_roundtrip_metadata, validate=False, ) @@ -56,6 +60,7 @@ def complete_session_request( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> EditorResult: """Finalize a session request and optionally print or save generated code.""" with log_operation( @@ -66,6 +71,7 @@ def complete_session_request( context={"engine": engine_name_to_text(engine)}, ): spec = deserialize_spec(serialized_spec) + enforce_spec_api_limits(spec) log_branch( LOGGER, "Deserialized completion spec", @@ -75,6 +81,7 @@ def complete_session_request( spec, engine, collection_format=_resolve_collection_format(session, collection_format), + include_roundtrip_metadata=include_roundtrip_metadata, validate=False, ) if session.print_code: diff --git a/src/tensor_network_editor/app/_subnetwork_library_services.py b/src/tensor_network_editor/app/_subnetwork_library_services.py index ecc1716..09a3e08 100644 --- a/src/tensor_network_editor/app/_subnetwork_library_services.py +++ b/src/tensor_network_editor/app/_subnetwork_library_services.py @@ -14,6 +14,7 @@ ) from ..models import CanvasPosition, NetworkSpec from ._bootstrap_payloads import build_subnetwork_catalog_payload +from ._limits import enforce_spec_api_limits from ._protocol import JsonDict if TYPE_CHECKING: @@ -44,7 +45,9 @@ def save_serialized_subnetwork_to_library( LOGGER, "Reusable subnetwork save", context=context ) as success_context: spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) saved_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids) + enforce_spec_api_limits(saved_spec) session.save_project_subnetwork( subnetwork_name, saved_spec, @@ -138,9 +141,11 @@ def prepare_saved_subnetwork_for_insertion( context=context, ) as success_context: spec = session.build_saved_subnetwork(subnetwork_name) + enforce_spec_api_limits(spec) prepared_spec = prepare_subnetwork_for_insertion( spec, target_center=target_center, ) + enforce_spec_api_limits(prepared_spec) success_context.update(summarize_spec_counts(prepared_spec)) return prepared_spec diff --git a/src/tensor_network_editor/app/_subnetwork_services.py b/src/tensor_network_editor/app/_subnetwork_services.py index 5893417..5007258 100644 --- a/src/tensor_network_editor/app/_subnetwork_services.py +++ b/src/tensor_network_editor/app/_subnetwork_services.py @@ -12,6 +12,7 @@ prepare_subnetwork_for_insertion, ) from ..models import CanvasPosition, NetworkSpec +from ._limits import enforce_spec_api_limits LOGGER = logging.getLogger(__name__) @@ -28,7 +29,9 @@ def extract_serialized_subnetwork( context={"tensor_id_count": len(tensor_ids)}, ): spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) extracted_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids) + enforce_spec_api_limits(extracted_spec) log_branch( LOGGER, "Extracted transient reusable subnetwork", @@ -45,10 +48,12 @@ def prepare_serialized_subnetwork_for_insertion( """Deserialize one payload and prepare it for editor insertion.""" with log_operation(LOGGER, "Transient subnetwork insertion preparation"): spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) prepared_spec = prepare_subnetwork_for_insertion( spec, target_center=target_center, ) + enforce_spec_api_limits(prepared_spec) log_branch( LOGGER, "Prepared transient subnetwork for insertion", diff --git a/src/tensor_network_editor/app/_template_services.py b/src/tensor_network_editor/app/_template_services.py index 924d27c..56b5187 100644 --- a/src/tensor_network_editor/app/_template_services.py +++ b/src/tensor_network_editor/app/_template_services.py @@ -16,6 +16,7 @@ parse_template_parameters, ) from ._bootstrap_payloads import build_template_catalog_payload +from ._limits import enforce_spec_api_limits, enforce_template_api_limits from ._protocol import JsonDict if TYPE_CHECKING: @@ -36,6 +37,7 @@ def build_template_from_payload( if session.has_project_template(template_name): log_branch(LOGGER, "Loading template from the project catalog") spec = session.build_project_template(template_name) + enforce_spec_api_limits(spec) success_context.update(summarize_spec_counts(spec)) success_context["status"] = "project" return spec @@ -43,7 +45,9 @@ def build_template_from_payload( template_name, raw_parameters, ) + enforce_template_api_limits(template_name, parameters) spec = build_template_spec(template_name, parameters) + enforce_spec_api_limits(spec) success_context.update(summarize_spec_counts(spec)) success_context["status"] = "global" return spec @@ -68,7 +72,9 @@ def promote_serialized_subnetwork_to_template( LOGGER, "Template promotion", context=context ) as success_context: spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) promoted_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids) + enforce_spec_api_limits(promoted_spec) promoted_spec.name = session.build_project_template_display_name(template_name) session.save_project_template( template_name, diff --git a/src/tensor_network_editor/app/routes.py b/src/tensor_network_editor/app/routes.py index 69cfbd5..ece031c 100644 --- a/src/tensor_network_editor/app/routes.py +++ b/src/tensor_network_editor/app/routes.py @@ -7,8 +7,9 @@ from collections.abc import Callable from dataclasses import dataclass from http import HTTPStatus -from typing import Literal, cast +from typing import Literal, TypedDict, cast +from .._themes import DEFAULT_EDITOR_THEME, EditorThemeName, normalize_editor_theme from ..errors import ( CodeGenerationError, PackageIOError, @@ -40,6 +41,7 @@ from ..types import JSONValue from ..validation import validate_spec from ._drafts import clear_project_draft, load_project_draft, save_project_draft +from ._limits import enforce_spec_api_limits from ._protocol import ( JsonDict, JsonResponse, @@ -91,6 +93,20 @@ _MAX_FRONTEND_CLIENT_LOG_EVENTS = 200 _MAX_FRONTEND_CLIENT_LOG_MESSAGE_LENGTH = 400 _MAX_FRONTEND_CLIENT_LOG_CONTEXT_VALUE_LENGTH = 200 +_RenderFormat = Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"] + + +class _ImageExportThemeOverride(TypedDict, total=False): + """Theme override fields supported by image render options.""" + + background: str + edge_stroke: str + group_stroke: str + hyperedge_stroke: str + index_fill: str + muted_text_fill: str + note_fill: str + text_fill: str @dataclass(slots=True, frozen=True) @@ -102,6 +118,67 @@ class _FrontendClientLogEvent: context: dict[str, object] +@dataclass(slots=True, frozen=True) +class _RenderLabelOptions: + """Shared label-visibility flags for academic render routes.""" + + show_tensor_labels: bool + show_index_labels: bool + show_edge_labels: bool + + +_IMAGE_EXPORT_THEME_OVERRIDES: dict[EditorThemeName, _ImageExportThemeOverride] = { + "dark": { + "background": "#0b0d12", + "edge_stroke": "#7e8aa3", + "index_fill": "#d7ae68", + "group_stroke": "#8f7cf7", + "note_fill": "#252b34", + "text_fill": "#f2f5f8", + "muted_text_fill": "#c6d3e6", + }, + "light": { + "background": "#ffffff", + "edge_stroke": "#64748b", + "index_fill": "#b45309", + "group_stroke": "#6d28d9", + "note_fill": "#ffffff", + "text_fill": "#172033", + "muted_text_fill": "#475569", + }, + "contrast": { + "background": "#000000", + "edge_stroke": "#ffffff", + "index_fill": "#ffff00", + "hyperedge_stroke": "#ff5f5f", + "group_stroke": "#ff00ff", + "note_fill": "#101010", + "text_fill": "#ffffff", + "muted_text_fill": "#ffffff", + }, + "colorblind": { + "background": "#ffffff", + "edge_stroke": "#5b5b5b", + "index_fill": "#e69f00", + "hyperedge_stroke": "#d55e00", + "group_stroke": "#cc79a7", + "note_fill": "#ffffff", + "text_fill": "#202124", + "muted_text_fill": "#5b5b5b", + }, + "shiny": { + "background": "#070915", + "edge_stroke": "#94a3b8", + "index_fill": "#facc15", + "hyperedge_stroke": "#fb7185", + "group_stroke": "#e879f9", + "note_fill": "#11152c", + "text_fill": "#f8fafc", + "muted_text_fill": "#c4b5fd", + }, +} + + def _route_context( session: EditorSession | None, route: str, @@ -222,6 +299,15 @@ def handle_validate(session: EditorSession, payload: JsonDict) -> JsonResponse: level=logging.WARNING, ) return bad_request_response("Missing 'spec' or 'python_code' payload.") + try: + enforce_spec_api_limits(spec) + except ValueError as exc: + log_branch( + LOGGER, + f"Validation request exceeded API limits: {exc}", + level=logging.WARNING, + ) + return bad_request_response(str(exc)) issues = validate_spec(spec) if issues: log_branch( @@ -327,109 +413,24 @@ def handle_render(session: EditorSession, payload: JsonDict) -> JsonResponse: """Render the current editor payload to an academic text format.""" del session with log_operation( - LOGGER, "Render route", context={"route": "/api/render"} + LOGGER, "Render route", context=_route_context(None, "/api/render") ) as success_context: try: render_format = _resolve_render_format(payload) serialized_spec = require_serialized_spec(payload) spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) + label_options = _resolve_render_label_options(payload) + render_theme = _resolve_render_theme(payload) success_context["format"] = render_format + success_context["theme"] = render_theme success_context.update(summarize_spec_counts(spec)) - svg_options = SvgRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), + response_payload = _build_render_response( + render_format, + spec, + label_options, + theme=render_theme, ) - if render_format == "tikz": - text = render_spec_tikz( - spec, - options=TikzRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), - ), - ) - content_type = "text/x-tex;charset=utf-8" - response_payload: JsonDict = { - "format": render_format, - "text": text, - "content_type": content_type, - } - elif render_format == "dot": - text = render_spec_dot( - spec, - options=DotRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), - ), - ) - content_type = "text/vnd.graphviz;charset=utf-8" - response_payload = { - "format": render_format, - "text": text, - "content_type": content_type, - } - elif render_format == "mermaid": - text = render_spec_mermaid( - spec, - options=DotRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), - ), - ) - content_type = "text/plain;charset=utf-8" - response_payload = { - "format": render_format, - "text": text, - "content_type": content_type, - } - elif render_format == "svg": - text = render_spec_svg(spec, options=svg_options) - response_payload = { - "format": render_format, - "text": text, - "content_type": "image/svg+xml;charset=utf-8", - } - elif render_format == "png": - binary = render_spec_png(spec, options=svg_options) - response_payload = { - "format": render_format, - "base64": base64.b64encode(binary).decode("ascii"), - "content_type": "image/png", - } - else: - binary = render_spec_pdf(spec, options=svg_options) - response_payload = { - "format": render_format, - "base64": base64.b64encode(binary).decode("ascii"), - "content_type": "application/pdf", - } except ValueError as exc: return bad_request_response(str(exc)) except SerializationError as exc: @@ -521,6 +522,13 @@ def handle_analyze_contraction( level=logging.WARNING, ) return bad_request_response(str(exc)) + except ValueError as exc: + log_branch( + LOGGER, + f"Contraction analysis request exceeded API limits: {exc}", + level=logging.WARNING, + ) + return bad_request_response(str(exc)) except SpecValidationError as exc: log_branch( LOGGER, @@ -790,22 +798,166 @@ def _serialize_generate_result(result: CodegenResult) -> JsonDict: def _resolve_render_format( payload: JsonDict, -) -> Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"]: +) -> _RenderFormat: raw_format = payload.get("format") if not isinstance(raw_format, str) or not raw_format.strip(): raise ValueError("Missing 'format' payload.") normalized_format = raw_format.strip().lower() if normalized_format in {"tikz", "dot", "mermaid", "svg", "png", "pdf"}: - return cast( - Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"], - normalized_format, - ) + return cast(_RenderFormat, normalized_format) raise ValueError( "Unsupported render format " f"'{raw_format}'. Expected 'tikz', 'dot', 'mermaid', 'svg', 'png', or 'pdf'." ) +def _resolve_render_label_options(payload: JsonDict) -> _RenderLabelOptions: + """Return shared render-label visibility flags for one request payload.""" + return _RenderLabelOptions( + show_tensor_labels=require_boolean(payload, "show_tensor_names", default=True), + show_index_labels=require_boolean(payload, "show_index_names", default=True), + show_edge_labels=require_boolean(payload, "show_bond_names", default=True), + ) + + +def _resolve_render_theme(payload: JsonDict) -> EditorThemeName: + """Return the editor theme requested for one render payload.""" + raw_theme = payload.get("theme") + if raw_theme is None: + return DEFAULT_EDITOR_THEME + if not isinstance(raw_theme, str): + raise ValueError("'theme' must be a string when provided.") + return normalize_editor_theme(raw_theme) + + +def _svg_render_options( + label_options: _RenderLabelOptions, + *, + render_format: _RenderFormat, + theme: EditorThemeName, +) -> SvgRenderOptions: + """Return SVG/PNG/PDF render options derived from shared label flags.""" + return SvgRenderOptions( + show_tensor_labels=label_options.show_tensor_labels, + show_index_labels=label_options.show_index_labels, + show_edge_labels=label_options.show_edge_labels, + transparent_background=render_format in {"svg", "png"}, + **_IMAGE_EXPORT_THEME_OVERRIDES[theme], + ) + + +def _tikz_render_options(label_options: _RenderLabelOptions) -> TikzRenderOptions: + """Return TikZ render options derived from shared label flags.""" + return TikzRenderOptions( + show_tensor_labels=label_options.show_tensor_labels, + show_index_labels=label_options.show_index_labels, + show_edge_labels=label_options.show_edge_labels, + ) + + +def _dot_render_options(label_options: _RenderLabelOptions) -> DotRenderOptions: + """Return DOT/Mermaid render options derived from shared label flags.""" + return DotRenderOptions( + show_tensor_labels=label_options.show_tensor_labels, + show_index_labels=label_options.show_index_labels, + show_edge_labels=label_options.show_edge_labels, + ) + + +def _build_text_render_response( + render_format: _RenderFormat, + text: str, + *, + content_type: str, +) -> JsonDict: + """Return one text-based render response payload.""" + return { + "format": render_format, + "text": text, + "content_type": content_type, + } + + +def _build_binary_render_response( + render_format: _RenderFormat, + binary: bytes, + *, + content_type: str, +) -> JsonDict: + """Return one binary render response payload encoded for JSON transport.""" + return { + "format": render_format, + "base64": base64.b64encode(binary).decode("ascii"), + "content_type": content_type, + } + + +def _build_render_response( + render_format: _RenderFormat, + spec: NetworkSpec, + label_options: _RenderLabelOptions, + *, + theme: EditorThemeName = DEFAULT_EDITOR_THEME, +) -> JsonDict: + """Return the serialized academic render payload for one format request.""" + if render_format == "tikz": + return _build_text_render_response( + render_format, + render_spec_tikz(spec, options=_tikz_render_options(label_options)), + content_type="text/x-tex;charset=utf-8", + ) + if render_format == "dot": + return _build_text_render_response( + render_format, + render_spec_dot(spec, options=_dot_render_options(label_options)), + content_type="text/vnd.graphviz;charset=utf-8", + ) + if render_format == "mermaid": + return _build_text_render_response( + render_format, + render_spec_mermaid(spec, options=_dot_render_options(label_options)), + content_type="text/plain;charset=utf-8", + ) + if render_format == "svg": + return _build_text_render_response( + render_format, + render_spec_svg( + spec, + options=_svg_render_options( + label_options, + render_format=render_format, + theme=theme, + ), + ), + content_type="image/svg+xml;charset=utf-8", + ) + if render_format == "png": + return _build_binary_render_response( + render_format, + render_spec_png( + spec, + options=_svg_render_options( + label_options, + render_format=render_format, + theme=theme, + ), + ), + content_type="image/png", + ) + return _build_binary_render_response( + render_format, + render_spec_pdf( + spec, + options=_svg_render_options( + label_options, + render_format=render_format, + theme=theme, + ), + ), + content_type="application/pdf", + ) + + def _serialize_complete_result(result: EditorResult) -> JsonDict: """Serialize a complete-route editor result.""" return serialize_editor_result(result) @@ -833,6 +985,7 @@ def _handle_session_codegen_request( request.serialized_spec, request.engine, request.collection_format, + request.include_roundtrip_metadata, ) return ok_response(_serialize_generate_result(generate_result)) if operation == "complete": @@ -840,11 +993,14 @@ def _handle_session_codegen_request( request.serialized_spec, request.engine, request.collection_format, + request.include_roundtrip_metadata, ) return ok_response(_serialize_complete_result(complete_result)) raise ValueError(f"Unsupported code generation operation '{operation}'.") except SerializationError as exc: return bad_request_response(str(exc)) + except ValueError as exc: + return bad_request_response(str(exc)) except CodeGenerationError as exc: return bad_request_response(str(exc)) except PackageIOError as exc: diff --git a/src/tensor_network_editor/app/server.py b/src/tensor_network_editor/app/server.py index c0018e3..87a5ab7 100644 --- a/src/tensor_network_editor/app/server.py +++ b/src/tensor_network_editor/app/server.py @@ -2,10 +2,14 @@ from __future__ import annotations +import hmac +import ipaddress import json import logging import mimetypes +import secrets import threading +import time from collections.abc import Callable from dataclasses import dataclass from http import HTTPStatus @@ -14,6 +18,7 @@ from pathlib import Path from typing import Protocol, TypeAlias, cast from urllib.parse import urlparse +from urllib.request import urlopen from ..internal._logging import ( bind_log_context, @@ -27,17 +32,25 @@ JsonDict, JsonResponse, bad_request_response, + forbidden_response, internal_server_error_response, not_found_response, read_json, + unsupported_media_type_response, ) from .session import EditorSession LOGGER = logging.getLogger(__name__) _SERVE_FOREVER_POLL_INTERVAL_SECONDS: float = 0.05 +_STARTUP_READY_TIMEOUT_SECONDS: float = 5.0 +_STARTUP_READY_POLL_INTERVAL_SECONDS: float = 0.01 +_STARTUP_READY_REQUEST_TIMEOUT_SECONDS: float = 0.2 +_RESPONSE_WRITE_CHUNK_SIZE_BYTES: int = 64 * 1024 +_STATIC_ASSET_CACHE_VALIDATION_INTERVAL_SECONDS: float = 0.5 _MAX_REQUEST_BODY_BYTES: int = 1_048_576 _STATIC_ASSET_CACHE_LOCK = threading.Lock() _STATIC_ASSET_CACHE_BY_ROOT: dict[Path, _StaticAssetCache] = {} +_STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT: dict[Path, float] = {} _UNEXPECTED_INTERNAL_ERROR_MESSAGE = "Unexpected internal error." _UNEXPECTED_INTERNAL_ERROR_GUIDANCE = ( "Try again. If the problem continues, check the terminal output for this " @@ -46,6 +59,13 @@ _QUIET_MISSING_STATIC_ASSET_PATHS: frozenset[str] = frozenset({"/favicon.ico"}) _ScannedStaticAssetFile: TypeAlias = tuple[Path, str, int, int] _RUNTIME_CONFIG_PLACEHOLDER = "__TNE_RUNTIME_CONFIG__" +_CSP_NONCE_PLACEHOLDER = "__TNE_CSP_NONCE__" +_API_TOKEN_HEADER = "X-TNE-Session-Token" # noqa: S105, RUF100 - header name. +_EXPECTED_JSON_CONTENT_TYPE = "application/json" +_PERMISSIONS_POLICY_HEADER = ( + "accelerometer=(), camera=(), geolocation=(), gyroscope=(), " + "magnetometer=(), microphone=(), payment=(), usb=()" +) class SupportsReadBytes(Protocol): @@ -86,6 +106,75 @@ def _read_request_body_bytes(reader: SupportsReadBytes, content_length: int) -> return b"".join(chunks) +def _is_loopback_host_name(host_name: str) -> bool: + """Return whether a hostname literal is safe for local-only editor serving.""" + normalized_host = host_name.strip().strip("[]").rstrip(".").lower() + if normalized_host in {"localhost"} or normalized_host.endswith(".localhost"): + return True + if "%" in normalized_host: + normalized_host = normalized_host.split("%", 1)[0] + try: + address = ipaddress.ip_address(normalized_host) + except ValueError: + return False + return address.is_loopback + + +def _validate_bind_host(host: str, *, allow_remote: bool) -> None: + """Reject non-loopback bind hosts unless remote serving is explicit.""" + if allow_remote or _is_loopback_host_name(host): + return + raise ValueError( + "Refusing to bind the editor server to a non-loopback host. " + "Use allow_remote=True only when you intentionally expose this local API." + ) + + +def _host_name_from_header(host_header: str | None) -> str | None: + """Extract the hostname portion from one HTTP Host header.""" + if host_header is None: + return None + value = host_header.strip() + if not value: + return None + if value.startswith("["): + end_index = value.find("]") + if end_index <= 1: + return None + return value[1:end_index] + if value.count(":") == 1: + host_name, port_text = value.rsplit(":", 1) + if port_text.isdigit(): + return host_name + return value + + +def _is_trusted_host_header(host_header: str | None, *, allow_remote: bool) -> bool: + """Return whether one Host header is acceptable for this server.""" + if allow_remote: + return bool(host_header and host_header.strip()) + host_name = _host_name_from_header(host_header) + return host_name is not None and _is_loopback_host_name(host_name) + + +def _is_trusted_origin_header(origin_header: str | None, *, allow_remote: bool) -> bool: + """Return whether one optional Origin header is acceptable for API writes.""" + if origin_header is None: + return True + parsed_origin = urlparse(origin_header) + if parsed_origin.scheme not in {"http", "https"}: + return False + return _is_trusted_host_header(parsed_origin.netloc, allow_remote=allow_remote) + + +def _is_json_content_type(content_type: str | None) -> bool: + """Return whether one Content-Type header identifies a JSON request body.""" + if content_type is None: + return False + media_type = content_type.split(";", 1)[0].strip().lower() + return media_type == _EXPECTED_JSON_CONTENT_TYPE + + @dataclass(slots=True, frozen=True) class _BinaryResponse: """Internal response container for pre-encoded bytes.""" @@ -198,11 +287,14 @@ def _build_static_asset_cache( def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: """Return a shared static asset cache for one editor static directory.""" resolved_static_dir = static_dir.resolve() - scanned_files = _scan_static_asset_files(resolved_static_dir) - current_signature = _build_static_asset_source_signature(scanned_files) with _STATIC_ASSET_CACHE_LOCK: + validation_started_at = time.monotonic() cache = _STATIC_ASSET_CACHE_BY_ROOT.get(resolved_static_dir) + last_validated_at = _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.get( + resolved_static_dir + ) if cache is None: + scanned_files = _scan_static_asset_files(resolved_static_dir) with log_operation( LOGGER, "Static asset cache build", @@ -213,9 +305,29 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: scanned_files=scanned_files, ) _STATIC_ASSET_CACHE_BY_ROOT[resolved_static_dir] = cache + _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = ( + validation_started_at + ) success_context["after"] = cache.asset_version success_context["asset_count"] = len(cache.body_by_relative_path) return cache + if ( + last_validated_at is not None + and validation_started_at - last_validated_at + < _STATIC_ASSET_CACHE_VALIDATION_INTERVAL_SECONDS + ): + log_branch( + LOGGER, + "Static asset cache reused", + context={ + "path": resolved_static_dir, + "after": cache.asset_version, + "asset_count": len(cache.body_by_relative_path), + }, + ) + return cache + scanned_files = _scan_static_asset_files(resolved_static_dir) + current_signature = _build_static_asset_source_signature(scanned_files) if cache.source_signature != current_signature: with log_operation( LOGGER, @@ -230,11 +342,17 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: scanned_files=scanned_files, ) _STATIC_ASSET_CACHE_BY_ROOT[resolved_static_dir] = refreshed_cache + _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = ( + validation_started_at + ) success_context["after"] = refreshed_cache.asset_version success_context["asset_count"] = len( refreshed_cache.body_by_relative_path ) return refreshed_cache + _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = ( + validation_started_at + ) log_branch( LOGGER, "Static asset cache reused", @@ -247,29 +365,63 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: return cache -def _build_frontend_runtime_config_payload(session: EditorSession) -> JsonDict: +def _build_frontend_runtime_config_payload( + session: EditorSession, *, api_token: str +) -> JsonDict: """Return the runtime configuration embedded into the editor HTML page.""" return { "session_id": session.session_id, + "api_token": api_token, "frontend_logging": build_frontend_logging_payload(session), } -def _serialize_frontend_runtime_config(session: EditorSession) -> str: +def _serialize_frontend_runtime_config( + session: EditorSession, *, api_token: str +) -> str: """Serialize one session runtime config safely for an inline JSON script.""" - return json.dumps(_build_frontend_runtime_config_payload(session)).replace( - " bytes: +def _render_session_index_body( + index_body: bytes, + session: EditorSession, + *, + api_token: str, + csp_nonce: str, +) -> bytes: """Return the per-session editor HTML body with embedded runtime config.""" return index_body.replace( _RUNTIME_CONFIG_PLACEHOLDER.encode("utf-8"), - _serialize_frontend_runtime_config(session).encode("utf-8"), + _serialize_frontend_runtime_config(session, api_token=api_token).encode( + "utf-8" + ), + ).replace( + _CSP_NONCE_PLACEHOLDER.encode("utf-8"), + csp_nonce.encode("utf-8"), ) +def _build_content_security_policy(*, csp_nonce: str) -> str: + """Return the editor CSP that permits only trusted local assets.""" + directives = [ + "default-src 'self'", + "base-uri 'none'", + "object-src 'none'", + "frame-ancestors 'none'", + "form-action 'none'", + "connect-src 'self'", + "img-src 'self' data: blob:", + f"script-src 'self' 'nonce-{csp_nonce}'", + "style-src 'self' 'unsafe-inline'", + "font-src 'self' data:", + "worker-src 'self' blob:", + ] + return "; ".join(directives) + + def _unexpected_internal_error_response(session_id: str) -> JsonResponse: """Return an actionable but safe error payload for unexpected failures.""" return internal_server_error_response( @@ -288,7 +440,13 @@ class EditorServer: """Serve the browser app and JSON API for one editor session.""" def __init__( - self, session: EditorSession, host: str = "127.0.0.1", port: int = 0 + self, + session: EditorSession, + host: str = "127.0.0.1", + port: int = 0, + *, + allow_remote: bool = False, + api_token: str | None = None, ) -> None: """Initialize the threaded local editor server. @@ -296,19 +454,33 @@ def __init__( session: Shared editor session state served by this HTTP server. host: Local host interface to bind. port: Local port to bind. Use ``0`` for an ephemeral port. + allow_remote: Whether non-loopback bind hosts are allowed. + api_token: Optional pre-generated API token for tests. """ + _validate_bind_host(host, allow_remote=allow_remote) self.session = session self.session_id = session.session_id self.host = host self.port = port + self.allow_remote = allow_remote + self.api_token = api_token or secrets.token_urlsafe(32) + if not self.api_token.strip(): + raise ValueError("Editor API token cannot be empty.") + self._csp_nonce = secrets.token_urlsafe(16) + self._content_security_policy = _build_content_security_policy( + csp_nonce=self._csp_nonce + ) self._static_dir = Path(__file__).resolve().parent / "static" self._static_asset_cache = _get_static_asset_cache(self._static_dir) self._index_body = _render_session_index_body( self._static_asset_cache.index_body, session, + api_token=self.api_token, + csp_nonce=self._csp_nonce, ) self._server = ThreadingHTTPServer((host, port), self._build_handler()) self._thread = threading.Thread(target=self._serve_forever, daemon=True) + self._serve_forever_ready = threading.Event() @property def base_url(self) -> str: @@ -322,6 +494,11 @@ def base_url(self) -> str: def start(self) -> None: """Start serving requests in a background thread.""" self._thread.start() + try: + self._wait_until_ready() + except Exception: + self._cleanup_failed_start() + raise log_branch( LOGGER, f"Editor server started at {self.base_url}", @@ -331,9 +508,7 @@ def start(self) -> None: def stop(self) -> None: """Stop the server and wait for the worker thread to exit.""" - self._server.shutdown() - self._server.server_close() - self._thread.join(timeout=5) + self._stop_server_worker() log_branch( LOGGER, "Editor server stopped", @@ -343,8 +518,61 @@ def stop(self) -> None: def _serve_forever(self) -> None: """Serve requests with a short shutdown polling interval.""" + self._serve_forever_ready.set() self._server.serve_forever(poll_interval=_SERVE_FOREVER_POLL_INTERVAL_SECONDS) + def _wait_until_ready(self) -> None: + """Block until loopback requests can read one fully served asset.""" + deadline = time.monotonic() + _STARTUP_READY_TIMEOUT_SECONDS + if not self._serve_forever_ready.wait(timeout=_STARTUP_READY_TIMEOUT_SECONDS): + raise RuntimeError( + "Editor server did not enter the serving loop before the startup timeout elapsed." + ) + + last_error: OSError | None = None + while True: + remaining_seconds = deadline - time.monotonic() + if remaining_seconds <= 0: + break + request_timeout_seconds = min( + _STARTUP_READY_REQUEST_TIMEOUT_SECONDS, + remaining_seconds, + ) + try: + self._probe_loopback_readiness(request_timeout_seconds) + except OSError as exc: + last_error = exc + time.sleep(min(_STARTUP_READY_POLL_INTERVAL_SECONDS, remaining_seconds)) + continue + return + + if last_error is None: + raise RuntimeError( + "Editor server readiness probe timed out before any loopback request succeeded." + ) + raise RuntimeError( + "Editor server did not become ready to serve loopback requests before the startup timeout elapsed." + ) from last_error + + def _probe_loopback_readiness(self, timeout_seconds: float) -> None: + """Read one small static asset to verify the server serves full responses.""" + with urlopen( # noqa: S310, RUF100 - probes this loopback server. + f"{self.base_url}/favicon.ico", timeout=timeout_seconds + ) as response: + response.read() + + def _stop_server_worker(self) -> None: + """Best-effort shutdown that is safe before the serve loop starts.""" + if self._thread.is_alive() and self._serve_forever_ready.is_set(): + self._server.shutdown() + self._server.server_close() + if self._thread.ident is not None: + self._thread.join(timeout=5) + + def _cleanup_failed_start(self) -> None: + """Best-effort cleanup when startup fails after allocating the server socket.""" + self._stop_server_worker() + def _build_handler(self) -> type[BaseHTTPRequestHandler]: """Build the request-handler class bound to this server instance.""" session = self.session @@ -352,6 +580,9 @@ def _build_handler(self) -> type[BaseHTTPRequestHandler]: static_dir = self._static_dir static_asset_cache = self._static_asset_cache index_body = self._index_body + api_token = self.api_token + allow_remote = self.allow_remote + content_security_policy = self._content_security_policy def build_index_response() -> _BinaryResponse: """Return the cached main HTML page for this editor session.""" @@ -442,6 +673,10 @@ def do_GET(self) -> None: """Handle one HTTP GET request for assets or bootstrap data.""" parsed = urlparse(self.path) with bind_log_context(session=session_id, route=parsed.path): + if self._reject_untrusted_host(): + return + if self._reject_invalid_api_token(parsed.path): + return try: with log_operation(LOGGER, "Route request"): response = self._dispatch_get(parsed.path) @@ -458,6 +693,14 @@ def do_POST(self) -> None: """Handle one HTTP POST request for the editor JSON API.""" parsed = urlparse(self.path) with bind_log_context(session=session_id, route=parsed.path): + if self._reject_untrusted_host(): + return + if self._reject_untrusted_origin(): + return + if self._reject_invalid_api_token(parsed.path): + return + if self._reject_unsupported_content_type(): + return try: with log_operation(LOGGER, "Route request"): try: @@ -513,6 +756,85 @@ def _dispatch_post(self, path: str, payload: JsonDict) -> JsonResponse: LOGGER.debug(format_log_message(f"Unknown POST path: {path}")) return not_found_response() + def _reject_untrusted_host(self) -> bool: + """Write a forbidden response when the Host header is not local.""" + if _is_trusted_host_header( + self.headers.get("Host"), + allow_remote=allow_remote, + ): + return False + LOGGER.warning( + format_log_message( + "Rejected request with untrusted Host header", + context={"host": self.headers.get("Host")}, + ), + ) + self._prepare_rejected_request_connection() + self._write_response(forbidden_response("Untrusted Host header.")) + return True + + def _reject_untrusted_origin(self) -> bool: + """Write a forbidden response when the Origin header is not local.""" + if _is_trusted_origin_header( + self.headers.get("Origin"), + allow_remote=allow_remote, + ): + return False + LOGGER.warning( + format_log_message( + "Rejected request with untrusted Origin header", + context={"origin": self.headers.get("Origin")}, + ), + ) + self._prepare_rejected_request_connection() + self._write_response(forbidden_response("Untrusted Origin header.")) + return True + + def _reject_invalid_api_token(self, path: str) -> bool: + """Write a forbidden response when an API request lacks the token.""" + if not path.startswith("/api/"): + return False + header_value = self.headers.get(_API_TOKEN_HEADER) + if header_value is not None and hmac.compare_digest( + header_value, + api_token, + ): + return False + LOGGER.warning( + format_log_message( + "Rejected API request with invalid session token" + ), + ) + self._prepare_rejected_request_connection() + self._write_response( + forbidden_response("Invalid editor session token.") + ) + return True + + def _reject_unsupported_content_type(self) -> bool: + """Write an unsupported-media response for non-JSON API writes.""" + if _is_json_content_type(self.headers.get("Content-Type")): + return False + LOGGER.warning( + format_log_message( + "Rejected API request with unsupported Content-Type", + context={"content_type": self.headers.get("Content-Type")}, + ), + ) + self._prepare_rejected_request_connection() + self._write_response( + unsupported_media_type_response( + "Expected Content-Type 'application/json'." + ) + ) + return True + + def _prepare_rejected_request_connection(self) -> None: + """Drain rejected POST bodies before closing the connection.""" + if self.command == "POST": + self._drain_pending_request_body() + self.close_connection = True + def _static_response( self, request_path: str ) -> JsonResponse | _BinaryResponse: @@ -594,11 +916,23 @@ def _write_bytes(self, status: int, body: bytes, content_type: str) -> None: self.send_response(status) self.send_header("Content-Type", content_type) self.send_header("Content-Length", str(len(body))) + self.send_header("X-Content-Type-Options", "nosniff") + self.send_header("Referrer-Policy", "no-referrer") + self.send_header("X-Frame-Options", "DENY") + self.send_header("Content-Security-Policy", content_security_policy) + self.send_header("Permissions-Policy", _PERMISSIONS_POLICY_HEADER) + self.send_header("Cross-Origin-Resource-Policy", "same-origin") if self.close_connection: self.send_header("Connection", "close") self._write_no_cache_headers() self.end_headers() - self.wfile.write(body) + body_view = memoryview(body) + for offset in range( + 0, len(body_view), _RESPONSE_WRITE_CHUNK_SIZE_BYTES + ): + next_offset = offset + _RESPONSE_WRITE_CHUNK_SIZE_BYTES + self.wfile.write(body_view[offset:next_offset]) + self.wfile.flush() def _write_no_cache_headers(self) -> None: """Emit headers that disable browser and intermediary caching.""" diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index 37f002f..8159ba3 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -6,7 +6,9 @@ import signal import threading import webbrowser +from base64 import b64decode from collections.abc import Callable, Mapping, Sequence +from importlib import import_module from pathlib import Path from types import FrameType from typing import Any, Literal @@ -47,6 +49,7 @@ LOGGER = logging.getLogger(__name__) SignalHandler = Callable[[int, FrameType | None], Any] +SessionUiMode = Literal["browser", "pywebview", "server"] def _print_editor_url(base_url: str) -> None: @@ -64,6 +67,178 @@ def _print_browser_open_fallback_message(base_url: str) -> None: _print_editor_url(base_url) +def _import_pywebview() -> Any: + """Import the optional pywebview module on demand.""" + try: + return import_module("webview") + except ModuleNotFoundError as exc: + raise RuntimeError( + "pywebview mode requires the optional desktop extra. Install it with " + 'python -m pip install "tensor-network-editor[desktop]".' + ) from exc + + +def _resolve_pywebview_icon_path() -> Path: + """Return the packaged desktop icon used for the native pywebview window.""" + return Path(__file__).resolve().parent / "static" / "favicon.ico" + + +def _apply_pywebview_native_window_icon(window: Any) -> None: + """Apply the packaged icon to the native pywebview window when supported.""" + icon_path = _resolve_pywebview_icon_path() + if not icon_path.is_file(): + return + native_window = getattr(window, "native", None) + if native_window is None: + return + try: + from System.Drawing import Icon as DrawingIcon # type: ignore[import-not-found] + except Exception: + return + try: + native_window.Icon = DrawingIcon(str(icon_path)) + if hasattr(native_window, "ShowIcon"): + native_window.ShowIcon = True + except Exception as exc: + log_branch( + LOGGER, + "Could not apply the native pywebview window icon", + level=logging.WARNING, + context={ + "icon_path": str(icon_path), + "error": str(exc), + }, + ) + + +class _PywebviewExportApi: + """Expose native save-file helpers to the embedded pywebview frontend.""" + + def __init__(self, pywebview_module: Any) -> None: + """Store the imported pywebview module for later dialog calls.""" + self._pywebview_module = pywebview_module + self._window: Any | None = None + + def bind_window(self, window: Any) -> None: + """Attach the created pywebview window once it exists.""" + self._window = window + + def save_text_file( + self, + filename: str, + text: str, + content_type: str = "text/plain;charset=utf-8", + ) -> bool: + """Prompt for a path and write one UTF-8 text file.""" + del content_type + output_path = self._select_output_path(filename) + if output_path is None: + return False + output_path.write_text(text, encoding="utf-8") + return True + + def save_binary_file( + self, + filename: str, + base64_payload: str, + content_type: str = "application/octet-stream", + ) -> bool: + """Prompt for a path and write one decoded binary export file.""" + del content_type + output_path = self._select_output_path(filename) + if output_path is None: + return False + output_path.write_bytes(b64decode(base64_payload)) + return True + + def _select_output_path(self, filename: str) -> Path | None: + """Ask pywebview for a target save path and normalize the response.""" + if self._window is None: + raise RuntimeError("pywebview export API is not bound to a window.") + dialog_result = self._window.create_file_dialog( + self._pywebview_module.SAVE_DIALOG, + save_filename=filename, + file_types=self._build_file_types(filename), + ) + if dialog_result is None: + return None + if isinstance(dialog_result, str): + return Path(dialog_result) + if isinstance(dialog_result, Sequence) and dialog_result: + first_entry = dialog_result[0] + if isinstance(first_entry, str) and first_entry: + return Path(first_entry) + return None + + def _build_file_types(self, filename: str) -> tuple[str, ...]: + """Build a compact pywebview filter tuple from one filename.""" + suffix = Path(filename).suffix.lower() + if not suffix: + return () + label = { + ".dot": "DOT", + ".json": "JSON", + ".mmd": "Mermaid", + ".pdf": "PDF", + ".png": "PNG", + ".py": "Python", + ".svg": "SVG", + ".tex": "LaTeX", + }.get(suffix, suffix.removeprefix(".").upper()) + return (f"{label} (*{suffix})",) + + +def _run_pywebview_session( + session: EditorSession, base_url: str +) -> EditorResult | None: + """Open the local editor in a pywebview window and wait for the result.""" + if threading.current_thread() is not threading.main_thread(): + raise RuntimeError("pywebview mode must be launched from the main thread.") + + try: + pywebview = _import_pywebview() + except ModuleNotFoundError as exc: + raise RuntimeError( + "pywebview mode requires the optional desktop extra. Install it with " + 'python -m pip install "tensor-network-editor[desktop]".' + ) from exc + pywebview_export_api = _PywebviewExportApi(pywebview) + pywebview_window = pywebview.create_window( + "Tensor Network Editor", + base_url, + maximized=True, + js_api=pywebview_export_api, + ) + pywebview_export_api.bind_window(pywebview_window) + window_events = getattr(pywebview_window, "events", None) + before_show_event = getattr(window_events, "before_show", None) + if before_show_event is not None: + before_show_event += lambda: _apply_pywebview_native_window_icon( + pywebview_window + ) + else: + _apply_pywebview_native_window_icon(pywebview_window) + + def _handle_window_closed(*_args: object) -> None: + """Cancel the editor session when the native window is closed.""" + if not session.is_finished(): + session.cancel() + + def _wait_for_session_and_close_window(window: Any) -> None: + """Close the native window after the editor session finishes.""" + wait_for_editor_result(session) + try: + window.destroy() + except Exception: + return None + + closed_event = getattr(window_events, "closed", None) + if closed_event is not None: + closed_event += _handle_window_closed + pywebview.start(_wait_for_session_and_close_window, pywebview_window) + return wait_for_editor_result(session) + + class EditorSession: """Mutable session state shared between the HTTP server and the caller.""" @@ -290,6 +465,7 @@ def generate( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> CodegenResult: """Generate preview code without finalizing the session.""" with log_operation( @@ -306,6 +482,7 @@ def generate( serialized_spec, engine, collection_format, + include_roundtrip_metadata, ) def complete( @@ -313,6 +490,7 @@ def complete( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> EditorResult: """Finalize the session and store the resulting editor output.""" with log_operation( @@ -335,6 +513,7 @@ def complete( serialized_spec, engine, collection_format, + include_roundtrip_metadata, ) with self._lock: if self._finished_event.is_set() and self._result is not None: @@ -390,8 +569,10 @@ def launch_editor_session( default_engine: EngineIdentifier = EngineName.TENSORKROWCH, default_collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST, theme: EditorThemeName = DEFAULT_EDITOR_THEME, + ui_mode: SessionUiMode | None = None, open_browser: bool = True, host: str = "127.0.0.1", + allow_remote: bool = False, port: int = 0, print_code: bool = False, code_path: StrPath | None = None, @@ -412,8 +593,10 @@ def launch_editor_session( default_collection_format: Initial tensor collection layout for generated code. theme: Visual theme selected for this editor session. + ui_mode: Explicit UI launch mode for the editor session. open_browser: Whether to ask the system browser to open the local URL. host: Local host interface to bind. + allow_remote: Whether non-loopback bind hosts are allowed. port: Local port to bind. Use ``0`` for an ephemeral port. print_code: Whether to print generated code after confirmation. code_path: Optional output path for generated code after confirmation. @@ -437,6 +620,7 @@ def launch_editor_session( Raises: KeyboardInterrupt: If the session is interrupted from the main thread. """ + from ..editor import resolve_editor_ui_mode from .server import EditorServer active_logging_runtime = get_active_logging_runtime() @@ -467,7 +651,21 @@ def launch_editor_session( shared_subnetwork_catalog_path=shared_subnetwork_catalog_path, draft_path=draft_path, ) - server = EditorServer(session=session, host=host, port=port) + server = EditorServer( + session=session, + host=host, + port=port, + allow_remote=allow_remote, + ) + effective_ui_mode = resolve_editor_ui_mode( + ui_mode=ui_mode, + open_browser=open_browser, + ) + if ( + effective_ui_mode == "pywebview" + and threading.current_thread() is not threading.main_thread() + ): + raise RuntimeError("pywebview mode must be launched from the main thread.") previous_sigint_handler: SignalHandler | int | None = None server_started = False @@ -481,6 +679,7 @@ def launch_editor_session( "session": session.session_id, "engine": engine_name_to_text(default_engine), "mode": theme, + "ui_mode": effective_ui_mode, }, ): if threading.current_thread() is threading.main_thread(): @@ -497,9 +696,11 @@ def _handle_sigint(_signum: int, _frame: FrameType | None) -> None: server_started = True if _on_server_ready is not None: _on_server_ready(server.base_url) - should_print_editor_url = not open_browser + if effective_ui_mode == "pywebview": + return _run_pywebview_session(session, server.base_url) + should_print_editor_url = effective_ui_mode == "server" should_print_browser_fallback_message = False - if open_browser: + if effective_ui_mode == "browser": try: with log_operation( LOGGER, diff --git a/src/tensor_network_editor/app/static/app.css b/src/tensor_network_editor/app/static/app.css index 6d294f5..ca7a1ee 100644 --- a/src/tensor_network_editor/app/static/app.css +++ b/src/tensor_network_editor/app/static/app.css @@ -3369,6 +3369,31 @@ textarea[disabled] { align-items: center; } +.code-metadata-toggle { + display: inline-flex; + align-items: center; + gap: 0.45rem; + min-height: var(--canvas-control-height); + padding: 0 0.8rem; + border: 1px solid var(--border-subtle); + border-radius: 999px; + background: var(--surface-subtle); + color: var(--muted); + font-size: 0.78rem; + font-weight: 600; + line-height: 1.1; + cursor: pointer; + user-select: none; +} + +.code-metadata-toggle input { + margin: 0; +} + +.code-metadata-toggle[hidden] { + display: none; +} + .code-format-picker { position: relative; display: inline-flex; diff --git a/src/tensor_network_editor/app/static/index.html b/src/tensor_network_editor/app/static/index.html index 610db92..7ebb159 100644 --- a/src/tensor_network_editor/app/static/index.html +++ b/src/tensor_network_editor/app/static/index.html @@ -788,6 +788,30 @@

> Auto layout + + +

/> +

Generated code - - + ', + re.DOTALL, +) +_SESSION_TOKEN_BY_ORIGIN: dict[str, str | None] = {} + def request_json( url: str, @@ -28,6 +40,8 @@ def request_json_with_status( method: str = "GET", payload: dict[str, Any] | None = None, raw_body: bytes | None = None, + session_token: str | None = None, + include_session_token: bool = True, timeout: float = 5.0, ) -> tuple[int, dict[str, Any]]: data = None @@ -40,6 +54,12 @@ def request_json_with_status( elif raw_body is not None: data = raw_body headers["Content-Type"] = "application/json" + if include_session_token: + resolved_session_token = ( + session_token if session_token is not None else _session_token_for_url(url) + ) + if resolved_session_token: + headers["X-TNE-Session-Token"] = resolved_session_token request = Request(url=url, method=method, data=data, headers=headers) try: with urlopen(request, timeout=timeout) as response: @@ -48,13 +68,86 @@ def request_json_with_status( return exc.code, json.loads(exc.read().decode("utf-8")) +def _session_token_for_url(url: str) -> str | None: + """Read the embedded editor API token for a local test server URL.""" + origin = _origin_for_url(url) + if origin is None: + return None + if origin in _SESSION_TOKEN_BY_ORIGIN: + return _SESSION_TOKEN_BY_ORIGIN[origin] + try: + with urlopen(f"{origin}/", timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response: + html = response.read().decode("utf-8") + except OSError: + _SESSION_TOKEN_BY_ORIGIN[origin] = None + return None + match = _RUNTIME_CONFIG_RE.search(html) + if match is None: + _SESSION_TOKEN_BY_ORIGIN[origin] = None + return None + try: + payload = json.loads(match.group(1)) + except json.JSONDecodeError: + _SESSION_TOKEN_BY_ORIGIN[origin] = None + return None + token = payload.get("api_token") if isinstance(payload, dict) else None + _SESSION_TOKEN_BY_ORIGIN[origin] = token if isinstance(token, str) else None + return _SESSION_TOKEN_BY_ORIGIN[origin] + + +def _origin_for_url(url: str) -> str | None: + """Return the scheme/authority origin for an absolute URL.""" + parsed = urlsplit(url) + if not parsed.scheme or not parsed.netloc: + return None + return f"{parsed.scheme}://{parsed.netloc}" + + +def _read_asset_response(url: str) -> tuple[bytes, dict[str, str]]: + """Read one asset request with retries for transient local-server hiccups.""" + last_error: OSError | None = None + for attempt_index in range(_ASSET_REQUEST_RETRY_COUNT): + try: + with urlopen(url, timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response: + body = response.read() + headers = {key: value for key, value in response.headers.items()} + return body, headers + except OSError as exc: + last_error = exc + if attempt_index + 1 >= _ASSET_REQUEST_RETRY_COUNT: + raise + time.sleep(_ASSET_REQUEST_RETRY_DELAY_SECONDS) + if last_error is not None: + raise last_error + raise RuntimeError("Asset request retry loop ended unexpectedly.") + + def request_text(url: str) -> str: - with urlopen(url, timeout=5) as response: - return cast(str, response.read().decode("utf-8")) + body, _headers = _read_asset_response(url) + return cast(str, body.decode("utf-8")) def request_with_headers(url: str) -> tuple[str, dict[str, str]]: - with urlopen(url, timeout=5) as response: - body = response.read().decode("utf-8") - headers = {key: value for key, value in response.headers.items()} - return body, headers + body, headers = _read_asset_response(url) + return body.decode("utf-8"), headers + + +def request_headers(url: str) -> dict[str, str]: + last_error: OSError | None = None + for attempt_index in range(_ASSET_REQUEST_RETRY_COUNT): + try: + with urlopen(url, timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response: + return {key: value for key, value in response.headers.items()} + except OSError as exc: + last_error = exc + if attempt_index + 1 >= _ASSET_REQUEST_RETRY_COUNT: + raise + time.sleep(_ASSET_REQUEST_RETRY_DELAY_SECONDS) + if last_error is not None: + raise last_error + raise RuntimeError("Asset header request retry loop ended unexpectedly.") + + +def request_bytes(url: str) -> bytes: + body, _headers = _read_asset_response(url) + return body diff --git a/tests/codegen/test_common.py b/tests/codegen/test_common.py index 5179f30..0dc4e62 100644 --- a/tests/codegen/test_common.py +++ b/tests/codegen/test_common.py @@ -16,6 +16,8 @@ ) from tensor_network_editor.models import ( CanvasPosition, + CodegenResult, + EngineName, NetworkSpec, TensorCollectionFormat, TensorSpec, @@ -158,3 +160,71 @@ def test_render_helper_function_lines_indents_rendered_sections() -> None: ) assert helper_lines == ["def build_cell(slot_index: int) -> dict[str, object]:"] + + +def test_dispatch_periodic_codegen_routes_supported_backends_and_roundtrip() -> None: + from tensor_network_editor.codegen.modes._periodic_codegen import ( + dispatch_periodic_codegen, + ) + + seen_calls: list[tuple[str, str]] = [] + + def render_array(payload: str) -> CodegenResult: + seen_calls.append(("array", payload)) + return CodegenResult(engine=EngineName.EINSUM_NUMPY, code="array_result = 1\n") + + def render_graph(payload: str) -> CodegenResult: + seen_calls.append(("graph", payload)) + return CodegenResult(engine=EngineName.TENSORNETWORK, code="graph_result = 1\n") + + spec = build_three_tensor_hyperedge_spec() + + array_result = dispatch_periodic_codegen( + spec=spec, + payload="array-payload", + missing_payload_message="missing payload", + unsupported_backend_label="periodic", + engine=EngineName.EINSUM_NUMPY, + include_roundtrip_metadata=True, + array_renderer=render_array, + graph_renderer=render_graph, + ) + graph_result = dispatch_periodic_codegen( + spec=spec, + payload="graph-payload", + missing_payload_message="missing payload", + unsupported_backend_label="periodic", + engine=EngineName.TENSORNETWORK, + include_roundtrip_metadata=False, + array_renderer=render_array, + graph_renderer=render_graph, + ) + + assert seen_calls == [("array", "array-payload"), ("graph", "graph-payload")] + assert "# TNE_SPEC_B64:" in array_result.code + assert graph_result.code == "graph_result = 1\n" + + +def test_dispatch_periodic_codegen_rejects_missing_payload() -> None: + from tensor_network_editor.codegen.modes._periodic_codegen import ( + dispatch_periodic_codegen, + ) + from tensor_network_editor.errors import CodeGenerationError + + with pytest.raises(CodeGenerationError, match="grid payload"): + dispatch_periodic_codegen( + spec=NetworkSpec(name="missing payload"), + payload=None, + missing_payload_message="Grid periodic code generation requires a grid payload.", + unsupported_backend_label="grid periodic", + engine=EngineName.EINSUM_NUMPY, + include_roundtrip_metadata=False, + array_renderer=lambda payload: CodegenResult( + engine=EngineName.EINSUM_NUMPY, + code=f"{payload}\n", + ), + graph_renderer=lambda payload: CodegenResult( + engine=EngineName.TENSORNETWORK, + code=f"{payload}\n", + ), + ) diff --git a/tests/codegen/test_generators.py b/tests/codegen/test_generators.py index 5876d2b..c3832d6 100644 --- a/tests/codegen/test_generators.py +++ b/tests/codegen/test_generators.py @@ -1,6 +1,8 @@ from __future__ import annotations +import sys from collections.abc import Callable +from types import ModuleType, SimpleNamespace from unittest.mock import patch import pytest @@ -13,6 +15,7 @@ from tensor_network_editor.errors import CodeGenerationError from tensor_network_editor.models import ( CanvasPosition, + ContractionStepSpec, EdgeEndpointRef, EdgeSpec, EngineName, @@ -34,6 +37,7 @@ build_outer_product_plan_spec, build_sample_spec, build_sample_spec_without_plan, + build_three_tensor_complete_plan_spec, build_three_tensor_hyperedge_spec, build_three_tensor_spec, build_three_tensor_spec_without_plan, @@ -263,6 +267,511 @@ def _execute_generated_code( return namespace +class _FakeTensorKrowchEdge: + """Minimal edge object for generated-code regression tests.""" + + def __init__( + self, + node: _FakeTensorKrowchNode, + axis_name: str, + *, + origin: tuple[str, str] | None = None, + ) -> None: + self.node1 = node + self.axis1 = SimpleNamespace(name=axis_name) + self.node2: _FakeTensorKrowchNode | None = None + self.axis2: SimpleNamespace | None = None + self.origin = origin or (node.name, axis_name) + + @classmethod + def from_endpoints( + cls, + *, + node1: _FakeTensorKrowchNode, + axis1_name: str, + node2: _FakeTensorKrowchNode | None = None, + axis2_name: str | None = None, + origin: tuple[str, str] | None = None, + ) -> _FakeTensorKrowchEdge: + """Build one edge with explicit endpoint ownership.""" + edge = cls(node1, axis1_name, origin=origin) + if node2 is not None and axis2_name is not None: + edge.attach_second(node2, axis2_name) + return edge + + def attach_second( + self, + node: _FakeTensorKrowchNode, + axis_name: str, + ) -> None: + self.node2 = node + self.axis2 = SimpleNamespace(name=axis_name) + + def replace_endpoint( + self, + old_node: _FakeTensorKrowchNode, + new_node: _FakeTensorKrowchNode, + new_axis_name: str, + ) -> None: + if self.node1 is old_node: + self.node1 = new_node + self.axis1 = SimpleNamespace(name=new_axis_name) + return + if self.node2 is old_node: + self.node2 = new_node + self.axis2 = SimpleNamespace(name=new_axis_name) + + def is_dangling(self) -> bool: + return self.node2 is None + + def axis_name_for_node( + self, + node: _FakeTensorKrowchNode, + ) -> SimpleNamespace: + """Return the endpoint axis metadata for ``node``.""" + if self.node1 is node: + return self.axis1 + assert self.node2 is node + assert self.axis2 is not None + return self.axis2 + + +class _FakeTensorKrowchNode: + """Minimal node object for generated-code regression tests.""" + + def __init__( + self, + *, + tensor: object, + axes_names: tuple[str, ...], + name: str, + network: object, + ) -> None: + del tensor, network + self.name = name + self.edges_by_axis_name = { + axis_name: _FakeTensorKrowchEdge(self, axis_name) + for axis_name in axes_names + } + self.pending_edges_by_axis_name: dict[str, _FakeTensorKrowchEdge] = {} + self.pending_edge_owner_by_axis_name: dict[str, _FakeTensorKrowchNode] = {} + self.axis_is_node1_by_axis_name = {axis_name: True for axis_name in axes_names} + + def __getitem__(self, axis_name: str) -> _FakeTensorKrowchEdge: + if axis_name in self.edges_by_axis_name: + return self.edges_by_axis_name[axis_name] + return self.pending_edges_by_axis_name[axis_name] + + def reattach_edges(self, override: bool = False) -> None: + for axis_name, edge in list(self.pending_edges_by_axis_name.items()): + owner = self.pending_edge_owner_by_axis_name.pop(axis_name) + owner_is_node1 = edge.node1 is owner + if owner_is_node1: + other_node = edge.node2 + other_axis_name = None if edge.axis2 is None else edge.axis2.name + else: + other_node = edge.node1 + other_axis_name = edge.axis1.name + if override: + if owner_is_node1: + edge.node1 = self + edge.axis1 = SimpleNamespace(name=axis_name) + else: + edge.node2 = self + edge.axis2 = SimpleNamespace(name=axis_name) + self.edges_by_axis_name[axis_name] = edge + else: + if owner_is_node1: + self.edges_by_axis_name[axis_name] = ( + _FakeTensorKrowchEdge.from_endpoints( + node1=self, + axis1_name=axis_name, + node2=other_node, + axis2_name=other_axis_name, + origin=edge.origin, + ) + ) + else: + assert other_node is not None + assert other_axis_name is not None + self.edges_by_axis_name[axis_name] = ( + _FakeTensorKrowchEdge.from_endpoints( + node1=other_node, + axis1_name=other_axis_name, + node2=self, + axis2_name=axis_name, + origin=edge.origin, + ) + ) + self.axis_is_node1_by_axis_name[axis_name] = owner_is_node1 + self.pending_edges_by_axis_name = {} + + +class _FakeTensorKrowchModule(ModuleType): + """Tiny ``tensorkrowch`` double that exposes fragile axis ordering.""" + + def __init__(self) -> None: + super().__init__("tensorkrowch") + self.Node = _FakeTensorKrowchNode + self.TensorNetwork = _fake_tensorkrowch_network_factory + + @staticmethod + def connect( + left_edge: _FakeTensorKrowchEdge, + right_edge: _FakeTensorKrowchEdge, + ) -> _FakeTensorKrowchEdge: + left_edge.attach_second(right_edge.node1, right_edge.axis1.name) + right_edge.node1.edges_by_axis_name[right_edge.axis1.name] = left_edge + right_edge.node1.axis_is_node1_by_axis_name[right_edge.axis1.name] = False + return left_edge + + @staticmethod + def contract_between( + left_node: _FakeTensorKrowchNode, + right_node: _FakeTensorKrowchNode, + ) -> _FakeTensorKrowchNode: + left_edges = set(left_node.edges_by_axis_name.values()) + right_edges = set(right_node.edges_by_axis_name.values()) + if not left_edges.intersection(right_edges): + raise ValueError( + f"No batch edges or shared edges between nodes {left_node.name} and {right_node.name} found" + ) + shared_edges = left_edges.intersection(right_edges) + surviving_edges_with_owner = [ + (edge, left_node) + for edge in left_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + [ + (edge, right_node) + for edge in right_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + surviving_axis_names = _deduplicate_fake_tensorkrowch_axis_names( + tuple( + edge.axis_name_for_node(owner).name + for edge, owner in surviving_edges_with_owner + ) + ) + result = _FakeTensorKrowchNode( + tensor=None, + axes_names=surviving_axis_names, + name=f"{left_node.name}_{right_node.name}", + network=None, + ) + result.edges_by_axis_name = {} + result.pending_edges_by_axis_name = {} + result.pending_edge_owner_by_axis_name = {} + result.axis_is_node1_by_axis_name = {} + for axis_name, (edge, owner) in zip( + surviving_axis_names, + surviving_edges_with_owner, + strict=True, + ): + result.pending_edges_by_axis_name[axis_name] = edge + result.pending_edge_owner_by_axis_name[axis_name] = owner + result.axis_is_node1_by_axis_name[axis_name] = edge.node1 is owner + return result + + +class _FakeTorchModule(ModuleType): + """Tiny ``torch`` double for generated-code regression tests.""" + + float32: object + + def __init__(self) -> None: + super().__init__("torch") + self.float32 = object() + + @staticmethod + def zeros( + shape: tuple[int, ...], + dtype: object | None = None, + ) -> tuple[tuple[int, ...], object | None]: + return (shape, dtype) + + +def _deduplicate_fake_tensorkrowch_axis_names( + axis_names: tuple[str, ...], +) -> tuple[str, ...]: + """Mirror TensorKrowch suffixing for exact duplicate surviving axes.""" + base_names = [ + axis_name.rsplit("_", 1)[0] + if axis_name.rsplit("_", 1)[-1].isdigit() + else axis_name + for axis_name in axis_names + ] + result: list[str] = [] + counts: dict[str, int] = {} + for axis_name in base_names: + index = counts.get(axis_name, 0) + counts[axis_name] = index + 1 + if base_names.count(axis_name) == 1: + result.append(axis_name) + else: + result.append(f"{axis_name}_{index}") + return tuple(result) + + +def _fake_tensorkrowch_network_factory() -> object: + """Return a placeholder TensorNetwork instance for generated code.""" + return SimpleNamespace(reset=lambda: None) + + +class _ResetAwareFakeTensorKrowchNetwork: + """Minimal network object that can resync inherited resultant edges.""" + + def __init__(self) -> None: + self.nodes: list[_ResetAwareFakeTensorKrowchNode] = [] + + def register(self, node: _ResetAwareFakeTensorKrowchNode) -> None: + self.nodes.append(node) + + def reset(self) -> None: + for node in self.nodes: + node.reset_inherited_edges() + + +class _ResetAwareFakeTensorKrowchEdge: + """Edge double that hides inherited-result connections until reset.""" + + def __init__( + self, + node: _ResetAwareFakeTensorKrowchNode, + axis_name: str, + ) -> None: + self.node1 = node + self.axis1 = SimpleNamespace(name=axis_name) + self.node2: _ResetAwareFakeTensorKrowchNode | None = None + self.axis2: SimpleNamespace | None = None + self.origin = (node.name, axis_name) + self.inherited_source_by_result_node: dict[ + _ResetAwareFakeTensorKrowchNode, + tuple[_ResetAwareFakeTensorKrowchNode, str], + ] = {} + + def attach_second( + self, + node: _ResetAwareFakeTensorKrowchNode, + axis_name: str, + ) -> None: + self.node2 = node + self.axis2 = SimpleNamespace(name=axis_name) + + def replace_endpoint( + self, + old_node: _ResetAwareFakeTensorKrowchNode, + new_node: _ResetAwareFakeTensorKrowchNode, + new_axis_name: str, + ) -> None: + if self.node1 is old_node: + self.inherited_source_by_result_node[new_node] = ( + old_node, + self.axis1.name, + ) + self.node1 = new_node + self.axis1 = SimpleNamespace(name=new_axis_name) + self._stale_other_resultant_endpoints(excluded_result_node=new_node) + return + if self.node2 is old_node: + assert self.axis2 is not None + self.inherited_source_by_result_node[new_node] = ( + old_node, + self.axis2.name, + ) + self.node2 = new_node + self.axis2 = SimpleNamespace(name=new_axis_name) + self._stale_other_resultant_endpoints(excluded_result_node=new_node) + + def materialize_leaf_endpoint_for_resultant( + self, + node: _ResetAwareFakeTensorKrowchNode, + ) -> None: + source = self.inherited_source_by_result_node.get(node) + if source is None: + return + source_node, source_axis_name = source + if self.node1 is node: + self.node1 = source_node + self.axis1 = SimpleNamespace(name=source_axis_name) + return + if self.node2 is node: + self.node2 = source_node + self.axis2 = SimpleNamespace(name=source_axis_name) + + def restore_resultant_endpoint( + self, + node: _ResetAwareFakeTensorKrowchNode, + axis_name: str, + ) -> None: + source = self.inherited_source_by_result_node.get(node) + if source is None: + return + source_node, source_axis_name = source + if self.node1 is source_node and self.axis1.name == source_axis_name: + self.node1 = node + self.axis1 = SimpleNamespace(name=axis_name) + return + if ( + self.node2 is source_node + and self.axis2 is not None + and self.axis2.name == source_axis_name + ): + self.node2 = node + self.axis2 = SimpleNamespace(name=axis_name) + + def _stale_other_resultant_endpoints( + self, + *, + excluded_result_node: _ResetAwareFakeTensorKrowchNode, + ) -> None: + """Hide this edge from other inherited-result views until reset.""" + for result_node in tuple(self.inherited_source_by_result_node): + if result_node is excluded_result_node: + continue + if self.node1 is result_node or self.node2 is result_node: + self.materialize_leaf_endpoint_for_resultant(result_node) + + def is_dangling(self) -> bool: + return self.node2 is None + + def axis_name_for_node( + self, + node: _ResetAwareFakeTensorKrowchNode, + ) -> SimpleNamespace: + if self.node1 is node: + return self.axis1 + assert self.node2 is node + assert self.axis2 is not None + return self.axis2 + + def connects_nodes( + self, + left_node: _ResetAwareFakeTensorKrowchNode, + right_node: _ResetAwareFakeTensorKrowchNode, + ) -> bool: + return (self.node1 is left_node and self.node2 is right_node) or ( + self.node1 is right_node and self.node2 is left_node + ) + + +class _ResetAwareFakeTensorKrowchNode: + """Node double that tracks resultant-edge visibility across resets.""" + + def __init__( + self, + *, + tensor: object, + axes_names: tuple[str, ...], + name: str, + network: _ResetAwareFakeTensorKrowchNetwork | None, + ) -> None: + del tensor + self.name = name + self.network = network + self.is_resultant = False + self.edges_by_axis_name = { + axis_name: _ResetAwareFakeTensorKrowchEdge(self, axis_name) + for axis_name in axes_names + } + self.pending_edges_by_axis_name: dict[ + str, + _ResetAwareFakeTensorKrowchEdge, + ] = {} + if network is not None: + network.register(self) + + def __getitem__(self, axis_name: str) -> _ResetAwareFakeTensorKrowchEdge: + if axis_name in self.edges_by_axis_name: + return self.edges_by_axis_name[axis_name] + return self.pending_edges_by_axis_name[axis_name] + + def reattach_edges(self) -> None: + self.edges_by_axis_name.update(self.pending_edges_by_axis_name) + self.pending_edges_by_axis_name = {} + + def reset_inherited_edges(self) -> None: + for axis_name, edge in self.edges_by_axis_name.items(): + edge.restore_resultant_endpoint(self, axis_name) + for axis_name, edge in self.pending_edges_by_axis_name.items(): + edge.restore_resultant_endpoint(self, axis_name) + + +class _ResetAwareFakeTensorKrowchModule(ModuleType): + """TensorKrowch double that requires ``network.reset()`` for inherited edges.""" + + def __init__(self) -> None: + super().__init__("tensorkrowch") + self.Node = _ResetAwareFakeTensorKrowchNode + self.TensorNetwork = _reset_aware_fake_tensorkrowch_network_factory + + @staticmethod + def connect( + left_edge: _ResetAwareFakeTensorKrowchEdge, + right_edge: _ResetAwareFakeTensorKrowchEdge, + ) -> _ResetAwareFakeTensorKrowchEdge: + if left_edge.is_dangling() and left_edge.node1.is_resultant: + left_edge.materialize_leaf_endpoint_for_resultant(left_edge.node1) + left_edge.attach_second(right_edge.node1, right_edge.axis1.name) + right_edge.node1.edges_by_axis_name[right_edge.axis1.name] = left_edge + return left_edge + + @staticmethod + def contract_between( + left_node: _ResetAwareFakeTensorKrowchNode, + right_node: _ResetAwareFakeTensorKrowchNode, + ) -> _ResetAwareFakeTensorKrowchNode: + left_edges = set(left_node.edges_by_axis_name.values()) + right_edges = set(right_node.edges_by_axis_name.values()) + shared_edges = { + edge + for edge in left_edges.intersection(right_edges) + if edge.connects_nodes(left_node, right_node) + } + if not shared_edges: + raise ValueError( + f"No batch edges or shared edges between nodes {left_node.name} and {right_node.name} found" + ) + surviving_edges_with_owner = [ + (edge, right_node) + for edge in right_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + [ + (edge, left_node) + for edge in left_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + surviving_axis_names = _deduplicate_fake_tensorkrowch_axis_names( + tuple( + edge.axis_name_for_node(owner).name + for edge, owner in surviving_edges_with_owner + ) + ) + result = _ResetAwareFakeTensorKrowchNode( + tensor=None, + axes_names=surviving_axis_names, + name=f"{left_node.name}_{right_node.name}", + network=left_node.network, + ) + result.is_resultant = True + result.edges_by_axis_name = {} + result.pending_edges_by_axis_name = {} + for axis_name, (edge, owner) in zip( + surviving_axis_names, + surviving_edges_with_owner, + strict=True, + ): + edge.replace_endpoint(owner, result, axis_name) + result.pending_edges_by_axis_name[axis_name] = edge + return result + + +def _reset_aware_fake_tensorkrowch_network_factory() -> ( + _ResetAwareFakeTensorKrowchNetwork +): + """Return a fake network that models inherited-edge reset semantics.""" + return _ResetAwareFakeTensorKrowchNetwork() + + @pytest.mark.parametrize( ("engine", "expected_snippets"), [ @@ -813,6 +1322,20 @@ def test_periodic_generate_code_emits_roundtrip_metadata_marker() -> None: assert "# TNE_SPEC_B64:" in result.code assert "# Tensor Network Editor linear periodic mode" in result.code + assert result.code.index( + "# Tensor Network Editor linear periodic mode" + ) < result.code.index("# TNE_SPEC_B64:") + + +def test_periodic_generate_code_can_skip_roundtrip_metadata_marker() -> None: + result = generate_code( + build_linear_periodic_chain_spec(), + engine=EngineName.EINSUM_NUMPY, + include_roundtrip_metadata=False, + ) + + assert "# TNE_SPEC_B64:" not in result.code + assert "# Tensor Network Editor linear periodic mode" in result.code @pytest.mark.parametrize("engine", list(EngineName)) @@ -856,6 +1379,16 @@ def test_generate_code_respects_manual_plan_steps( assert "result = results_list[-1]" in result.code +def test_tensorkrowch_normal_codegen_does_not_emit_reattach_edges() -> None: + result = generate_code( + build_three_tensor_complete_plan_spec(), + engine=EngineName.TENSORKROWCH, + ) + + assert "results_list.append(tk.contract_between(" in result.code + assert "reattach_edges(" not in result.code + + @pytest.mark.parametrize("engine", list(EngineName)) def test_generate_code_keeps_partial_manual_plan_as_prefix( engine: EngineName, @@ -1119,6 +1652,120 @@ def test_linear_periodic_carry_codegen_labels_shared_for_sections( assert "previous_payload: dict[str, object]" in result.code +def test_linear_periodic_carry_tensorkrowch_codegen_tracks_boundary_edges_without_axis_order_assumptions() -> ( + None +): + result = generate_code( + build_linear_periodic_carry_chain_spec(), + engine=EngineName.TENSORKROWCH, + ) + fake_torch = _FakeTorchModule() + fake_tensorkrowch = _FakeTensorKrowchModule() + + with patch.dict( + sys.modules, + { + "torch": fake_torch, + "tensorkrowch": fake_tensorkrowch, + }, + ): + namespace = _execute_generated_code(result.code, n=3) + + open_edges = namespace["open_edges"] + assert isinstance(open_edges, list) + assert len(open_edges) == 4 + assert [edge.origin for edge in open_edges] == [ + ("Initial", "phys"), + ("PeriodicLeft", "phys_l"), + ("PeriodicRight", "phys_r"), + ("Final", "phys"), + ] + + +def test_linear_periodic_carry_tensorkrowch_codegen_executes_when_periodic_cell_contracts_local_pair_before_previous_payload() -> ( + None +): + spec = build_linear_periodic_carry_chain_spec() + assert spec.linear_periodic_chain is not None + assert spec.linear_periodic_chain.periodic_cell.contraction_plan is not None + spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [ + ContractionStepSpec( + id="periodic_contract_internal_first", + left_operand_id="periodic_left_tensor", + right_operand_id="periodic_right_tensor", + ), + ContractionStepSpec( + id="periodic_consume_previous_second", + left_operand_id="periodic_contract_internal_first", + right_operand_id="__linear_previous__", + ), + ContractionStepSpec( + id="periodic_carry_last", + left_operand_id="periodic_consume_previous_second", + right_operand_id="__linear_next__", + ), + ] + result = generate_code(spec, engine=EngineName.TENSORKROWCH) + fake_torch = _FakeTorchModule() + fake_tensorkrowch = _FakeTensorKrowchModule() + + with patch.dict( + sys.modules, + { + "torch": fake_torch, + "tensorkrowch": fake_tensorkrowch, + }, + ): + namespace = _execute_generated_code(result.code, n=5) + + assert "result" in namespace + assert "open_edges" in namespace + + +def test_linear_periodic_carry_tensorkrowch_codegen_materializes_result_edges_with_override() -> ( + None +): + spec = build_linear_periodic_carry_chain_spec() + assert spec.linear_periodic_chain is not None + assert spec.linear_periodic_chain.periodic_cell.contraction_plan is not None + spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [ + ContractionStepSpec( + id="periodic_contract_internal_first", + left_operand_id="periodic_left_tensor", + right_operand_id="periodic_right_tensor", + ), + ContractionStepSpec( + id="periodic_consume_previous_second", + left_operand_id="periodic_contract_internal_first", + right_operand_id="__linear_previous__", + ), + ContractionStepSpec( + id="periodic_carry_last", + left_operand_id="periodic_consume_previous_second", + right_operand_id="__linear_next__", + ), + ] + result = generate_code(spec, engine=EngineName.TENSORKROWCH) + fake_torch = _FakeTorchModule() + fake_tensorkrowch = _FakeTensorKrowchModule() + + with patch.dict( + sys.modules, + { + "torch": fake_torch, + "tensorkrowch": fake_tensorkrowch, + }, + ): + namespace = _execute_generated_code(result.code, n=5) + + assert "reattach_edges(override=True)" in result.code + assert "network.reset()" not in result.code + assert "open_edges.extend([tracked_edge_0, tracked_edge_1])" in result.code + assert "outgoing_interface = [results_list[-1]['right']]" in result.code + assert "result" in namespace + assert "open_edges" in namespace + + @pytest.mark.parametrize("engine", list(EngineName)) def test_linear_periodic_codegen_does_not_stringify_manual_blocks( engine: EngineName, diff --git a/tests/codegen/test_grid_periodic_internals.py b/tests/codegen/test_grid_periodic_internals.py index 6b31c71..1856e30 100644 --- a/tests/codegen/test_grid_periodic_internals.py +++ b/tests/codegen/test_grid_periodic_internals.py @@ -1,6 +1,9 @@ from __future__ import annotations -from tensor_network_editor.models import GridPeriodicCellName +from tensor_network_editor.models import ( + GridPeriodicCellName, + TensorCollectionFormat, +) from tests.factories import build_grid_periodic_grid_spec @@ -72,3 +75,35 @@ def test_grid_periodic_internal_helpers_keep_shared_labels_and_main_flow() -> No "result = network_nodes[0] if len(network_nodes) == 1 else None" ] assert "output_labels.extend(bottom_right_cell['open_labels'])" in einsum_main_lines + + +def test_grid_periodic_array_shared_helpers_build_context_and_sections() -> None: + from tensor_network_editor.codegen.modes._grid_periodic.array_shared import ( + build_grid_array_cell_context, + render_grid_array_tensor_sections, + ) + + grid = build_grid_periodic_grid_spec().grid_periodic_grid + assert grid is not None + + context = build_grid_array_cell_context( + grid=grid, + cell_name=GridPeriodicCellName.TOP_LEFT, + collection_format=TensorCollectionFormat.LIST, + ) + tensor_collection_lines, tensor_construction_lines = ( + render_grid_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: f"value_{tensor.variable_name}" + for tensor in context.prepared.tensors + }, + ) + ) + + assert context.collection_name == "tensors" + assert context.prepared.tensors + assert context.interface_index_ids + assert tensor_collection_lines == ["tensors = []"] + assert any(line.startswith("# Tensor ") for line in tensor_construction_lines) + assert any("tensors.append(value_" in line for line in tensor_construction_lines) diff --git a/tests/codegen/test_linear_periodic_internals.py b/tests/codegen/test_linear_periodic_internals.py index 60846ab..8de0540 100644 --- a/tests/codegen/test_linear_periodic_internals.py +++ b/tests/codegen/test_linear_periodic_internals.py @@ -4,9 +4,17 @@ from tensor_network_editor.errors import CodeGenerationError from tensor_network_editor.models import ( + CanvasPosition, + ContractionPlanSpec, ContractionStepSpec, + EdgeEndpointRef, + EdgeSpec, EngineName, + IndexSpec, LinearPeriodicCellName, + LinearPeriodicCellSpec, + LinearPeriodicTensorRole, + TensorSpec, ) from tests.factories import build_linear_periodic_carry_chain_spec @@ -175,6 +183,203 @@ def test_simulate_carry_cell_rejects_next_step_that_is_not_final() -> None: ) +def test_simulate_carry_cell_accepts_previous_payload_labels_that_only_collide_by_name() -> ( + None +): + from tensor_network_editor.codegen.modes._linear_periodic.carry import ( + _CarryOperandState, + _CarryPayloadState, + _simulate_carry_cell, + ) + + periodic_cell = LinearPeriodicCellSpec( + tensors=[ + TensorSpec( + id="periodic_previous_boundary", + name="Previous cell", + position=CanvasPosition(x=-100.0, y=140.0), + linear_periodic_role=LinearPeriodicTensorRole.PREVIOUS, + indices=[ + IndexSpec( + id="periodic_previous_slot_1", name="slot_1", dimension=2 + ), + IndexSpec( + id="periodic_previous_slot_2", name="slot_2", dimension=2 + ), + ], + ), + TensorSpec( + id="periodic_next_boundary", + name="Next cell", + position=CanvasPosition(x=540.0, y=140.0), + linear_periodic_role=LinearPeriodicTensorRole.NEXT, + indices=[ + IndexSpec(id="periodic_next_slot_1", name="slot_1", dimension=2), + IndexSpec(id="periodic_next_slot_2", name="slot_2", dimension=2), + ], + ), + TensorSpec( + id="tensor_a1", + name="A1", + position=CanvasPosition(x=-255.0, y=363.0), + indices=[ + IndexSpec(id="a1_right", name="right", dimension=3), + IndexSpec(id="a1_phys", name="phys", dimension=2), + ], + ), + TensorSpec( + id="tensor_a2", + name="A2", + position=CanvasPosition(x=65.0, y=363.0), + indices=[ + IndexSpec(id="a2_left", name="left", dimension=3), + IndexSpec(id="a2_right", name="right", dimension=3), + IndexSpec(id="a2_phys", name="phys", dimension=2), + ], + ), + TensorSpec( + id="tensor_a3", + name="A3", + position=CanvasPosition(x=385.0, y=363.0), + indices=[ + IndexSpec(id="a3_left", name="left", dimension=3), + IndexSpec(id="a3_right", name="right", dimension=3), + IndexSpec(id="a3_phys", name="phys", dimension=2), + ], + ), + TensorSpec( + id="tensor_a4", + name="A4", + position=CanvasPosition(x=705.0, y=363.0), + indices=[ + IndexSpec(id="a4_left", name="left", dimension=3), + IndexSpec(id="a4_phys", name="phys", dimension=2), + ], + ), + ], + edges=[ + EdgeSpec( + id="edge_a1_a2", + name="edge-0-1", + left=EdgeEndpointRef(tensor_id="tensor_a1", index_id="a1_right"), + right=EdgeEndpointRef(tensor_id="tensor_a2", index_id="a2_left"), + ), + EdgeSpec( + id="edge_a2_a3", + name="edge-1-2", + left=EdgeEndpointRef(tensor_id="tensor_a2", index_id="a2_right"), + right=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_left"), + ), + EdgeSpec( + id="edge_a3_a4", + name="edge-2-3", + left=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_right"), + right=EdgeEndpointRef(tensor_id="tensor_a4", index_id="a4_left"), + ), + EdgeSpec( + id="edge_previous_a1", + name="bond1", + left=EdgeEndpointRef( + tensor_id="tensor_a1", + index_id="a1_phys", + ), + right=EdgeEndpointRef( + tensor_id="periodic_previous_boundary", + index_id="periodic_previous_slot_1", + ), + ), + EdgeSpec( + id="edge_previous_a2", + name="bond2", + left=EdgeEndpointRef( + tensor_id="periodic_previous_boundary", + index_id="periodic_previous_slot_2", + ), + right=EdgeEndpointRef( + tensor_id="tensor_a2", + index_id="a2_phys", + ), + ), + EdgeSpec( + id="edge_a3_next", + name="bond3", + left=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_phys"), + right=EdgeEndpointRef( + tensor_id="periodic_next_boundary", + index_id="periodic_next_slot_1", + ), + ), + EdgeSpec( + id="edge_a4_next", + name="bond4", + left=EdgeEndpointRef( + tensor_id="periodic_next_boundary", + index_id="periodic_next_slot_2", + ), + right=EdgeEndpointRef(tensor_id="tensor_a4", index_id="a4_phys"), + ), + ], + contraction_plan=ContractionPlanSpec( + id="periodic_plan", + name="Manual path", + steps=[ + ContractionStepSpec( + id="step_contract_right", + left_operand_id="tensor_a4", + right_operand_id="tensor_a3", + ), + ContractionStepSpec( + id="step_from_previous", + left_operand_id="__linear_previous__", + right_operand_id="tensor_a2", + ), + ContractionStepSpec( + id="step_merge", + left_operand_id="step_from_previous", + right_operand_id="step_contract_right", + ), + ContractionStepSpec( + id="step_to_next", + left_operand_id="step_merge", + right_operand_id="__linear_next__", + ), + ], + ), + ) + previous_payload_state = _CarryPayloadState( + interface_operand_ids=("payload_left", "payload_right"), + interface_labels=("a1_phys", "a2_phys"), + operand_states={ + "payload_left": _CarryOperandState( + labels=("payload_edge", "a1_phys"), + axis_names=("left_payload", "slot_1"), + dimensions=(3, 2), + ), + "payload_right": _CarryOperandState( + labels=("a4_phys", "a3_phys", "payload_edge", "a2_phys"), + axis_names=("carry_0", "carry_1", "bridge", "slot_2"), + dimensions=(2, 2, 3, 2), + ), + }, + ) + + simulation = _simulate_carry_cell( + cell=periodic_cell, + cell_name=LinearPeriodicCellName.PERIODIC, + previous_payload_state=previous_payload_state, + engine=EngineName.TENSORKROWCH, + ) + + assert simulation.carry_operand_id == "step_merge" + assert simulation.outgoing_interface_operand_ids == ("step_merge", "step_merge") + assert ( + simulation.remaining_operand_states["step_merge"].labels.count("a3_phys") == 1 + ) + assert ( + simulation.remaining_operand_states["step_merge"].labels.count("a4_phys") == 1 + ) + + def test_build_carry_simulation_context_collects_interface_state() -> None: from tensor_network_editor.codegen.modes._linear_periodic.carry import ( _build_carry_simulation_context, diff --git a/tests/codegen/test_tree_periodic_internals.py b/tests/codegen/test_tree_periodic_internals.py index 164efbd..384beec 100644 --- a/tests/codegen/test_tree_periodic_internals.py +++ b/tests/codegen/test_tree_periodic_internals.py @@ -124,3 +124,36 @@ def test_tree_periodic_array_helpers_keep_child_interfaces_and_backend_tensor_bu assert "np.zeros(" in numpy_helper_body assert "torch.zeros(" in torch_helper_body assert "np.zeros(" not in torch_helper_body + + +def test_tree_periodic_array_shared_helpers_build_context_and_sections() -> None: + from tensor_network_editor.codegen.modes._tree_periodic.array_shared import ( + build_tree_array_cell_context, + render_tree_array_tensor_sections, + ) + + tree = build_tree_periodic_tree_spec().tree_periodic_tree + assert tree is not None + + context = build_tree_array_cell_context( + tree=tree, + cell_name=TreePeriodicCellName.ROOT, + collection_format=TensorCollectionFormat.LIST, + ) + tensor_collection_lines, tensor_construction_lines = ( + render_tree_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: f"value_{tensor.variable_name}" + for tensor in context.prepared.tensors + }, + ) + ) + + assert context.collection_name == "tensors" + assert context.parent_ports == () + assert tuple(context.child_ports_by_index) == tuple(range(tree.branching_factor)) + assert context.interface_index_ids + assert tensor_collection_lines == ["tensors = []"] + assert any(line.startswith("# Tensor ") for line in tensor_construction_lines) + assert any("tensors.append(value_" in line for line in tensor_construction_lines) diff --git a/tests/test_api.py b/tests/test_api.py index 2c8d150..baec4f0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,7 +10,7 @@ import tensor_network_editor from tensor_network_editor import generate_code as _generate_code -from tensor_network_editor.editor import EditorLaunchOptions, open_editor +from tensor_network_editor.editor import EditorLaunchOptions, EditorUiMode, open_editor from tensor_network_editor.errors import ( CodeGenerationError, PackageIOError, @@ -152,6 +152,7 @@ def test_package_root_exports_supported_public_api() -> None: "EdgeSpec", "EditorLaunchOptions", "EditorThemeName", + "EditorUiMode", "EditorResult", "EngineName", "DotRenderOptions", @@ -233,8 +234,10 @@ def test_editor_launch_options_defaults_match_public_contract() -> None: assert options.default_engine is EngineName.TENSORKROWCH assert options.default_collection_format is TensorCollectionFormat.LIST assert options.theme == "dark" + assert options.ui_mode is None assert options.open_browser is True assert options.host == "127.0.0.1" + assert options.allow_remote is False assert options.port == 0 assert options.print_code is False assert options.code_path is None @@ -248,6 +251,40 @@ def test_editor_launch_options_rejects_unknown_theme() -> None: EditorLaunchOptions(theme="sepia") # type: ignore[arg-type] +def test_editor_launch_options_rejects_non_loopback_host_without_remote_opt_in() -> ( + None +): + with pytest.raises(ValueError, match="non-loopback"): + EditorLaunchOptions(host="0.0.0.0") + + +def test_editor_launch_options_allows_non_loopback_host_with_remote_opt_in() -> None: + options = EditorLaunchOptions(host="0.0.0.0", allow_remote=True) + + assert options.host == "0.0.0.0" + assert options.allow_remote is True + + +def test_editor_ui_mode_type_alias_matches_public_contract() -> None: + assert EditorUiMode == Literal["browser", "pywebview", "server"] + + +@pytest.mark.parametrize( + ("ui_mode", "open_browser", "expected_message"), + [ + ("browser", False, "ui_mode='browser' requires open_browser=True"), + ("server", True, "ui_mode='server' requires open_browser=False"), + ], +) +def test_editor_launch_options_rejects_conflicting_browser_flags( + ui_mode: EditorUiMode, + open_browser: bool, + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=expected_message): + EditorLaunchOptions(ui_mode=ui_mode, open_browser=open_browser) + + def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> None: launch_result = object() @@ -261,8 +298,10 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N default_engine=EngineName.EINSUM_NUMPY, default_collection_format=TensorCollectionFormat.DICT, theme="colorblind", + ui_mode="pywebview", open_browser=False, host="0.0.0.0", + allow_remote=True, port=8123, print_code=True, code_path="generated.py", @@ -281,8 +320,10 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N default_engine=EngineName.EINSUM_NUMPY, default_collection_format=TensorCollectionFormat.DICT, theme="colorblind", + ui_mode="pywebview", open_browser=False, host="0.0.0.0", + allow_remote=True, port=8123, print_code=True, code_path="generated.py", diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py index 4c47bd1..8d80460 100644 --- a/tests/test_app_assets.py +++ b/tests/test_app_assets.py @@ -8,7 +8,12 @@ import pytest from tensor_network_editor.app.server import EditorServer -from tests.app_support import request_text, request_with_headers +from tests.app_support import ( + request_bytes, + request_headers, + request_text, + request_with_headers, +) def request_runtime_bundle(editor_server: EditorServer, *relative_paths: str) -> str: @@ -109,6 +114,36 @@ def test_root_serves_editor_shell_with_versioned_module_entry( assert headers["Content-Type"].startswith("text/html") +def test_root_serves_editor_shell_with_csp_nonce_and_defensive_headers( + editor_server: EditorServer, +) -> None: + html, headers = request_with_headers(f"{editor_server.base_url}/") + + content_security_policy = headers["Content-Security-Policy"] + nonce_match = re.search( + r"(?:^|;\s*)script-src 'self' 'nonce-([^']+)';", + content_security_policy, + ) + + assert nonce_match is not None + nonce = nonce_match.group(1) + assert nonce + assert "'unsafe-inline'" not in nonce_match.group(0) + assert ( + f'", + encoding="utf-8", + ) + asset_path.write_text("console.log('first');", encoding="utf-8") + monotonic_time = 100.0 + + scan_calls: list[Path] = [] + original_scan = app_server._scan_static_asset_files + + def recording_scan(path: Path) -> list[tuple[Path, str, int, int]]: + scan_calls.append(path.resolve()) + return original_scan(path) + + monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time) + monkeypatch.setattr(app_server, "_scan_static_asset_files", recording_scan) + app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) + + first_cache = app_server._get_static_asset_cache(static_dir) + second_cache = app_server._get_static_asset_cache(static_dir) + + assert first_cache is second_cache + assert scan_calls == [resolved_static_dir] + + def test_static_asset_cache_logs_build_and_reuse( tmp_path: Path, caplog: pytest.LogCaptureFixture, @@ -255,6 +418,9 @@ def test_static_asset_cache_logs_build_and_reuse( ) asset_path.write_text("console.log('first');", encoding="utf-8") app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) with caplog.at_level(logging.DEBUG, logger="tensor_network_editor"): first_cache = app_server._get_static_asset_cache(static_dir) @@ -270,6 +436,7 @@ def test_static_asset_cache_logs_build_and_reuse( def test_static_asset_cache_logs_refresh_with_version_context( tmp_path: Path, caplog: pytest.LogCaptureFixture, + monkeypatch: pytest.MonkeyPatch, ) -> None: static_dir = tmp_path / "static" asset_path = static_dir / "js" / "app.js" @@ -280,7 +447,13 @@ def test_static_asset_cache_logs_refresh_with_version_context( encoding="utf-8", ) asset_path.write_text("console.log('first');", encoding="utf-8") + monotonic_time = 100.0 + + monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time) app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) first_cache = app_server._get_static_asset_cache(static_dir) asset_path.write_text("console.log('second');", encoding="utf-8") @@ -289,6 +462,7 @@ def test_static_asset_cache_logs_refresh_with_version_context( + 1_000_000_000 ) os.utime(asset_path, ns=(future_timestamp_ns, future_timestamp_ns)) + monotonic_time += 1.0 with caplog.at_level(logging.DEBUG, logger="tensor_network_editor"): refreshed_cache = app_server._get_static_asset_cache(static_dir) diff --git a/tests/test_app_support.py b/tests/test_app_support.py new file mode 100644 index 0000000..fd8881b --- /dev/null +++ b/tests/test_app_support.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from unittest.mock import patch + +from tests import app_support + + +class _FakeResponse: + def __init__(self, body: str) -> None: + self._body = body.encode("utf-8") + self.status = 200 + self.headers = {"Cache-Control": "no-store"} + + def __enter__(self) -> _FakeResponse: + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + del exc_type, exc, traceback + return None + + def read(self) -> bytes: + return self._body + + +def test_request_text_uses_shared_asset_timeout() -> None: + recorded_timeout: list[float] = [] + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + recorded_timeout.append(timeout) + assert url == "http://example.test/" + return _FakeResponse("body") + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + body = app_support.request_text("http://example.test/") + + assert body == "body" + assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS] + + +def test_request_with_headers_uses_shared_asset_timeout() -> None: + recorded_timeout: list[float] = [] + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + recorded_timeout.append(timeout) + assert url == "http://example.test/app.css" + return _FakeResponse("css") + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + body, headers = app_support.request_with_headers("http://example.test/app.css") + + assert body == "css" + assert headers == {"Cache-Control": "no-store"} + assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS] + + +def test_request_headers_uses_shared_asset_timeout_without_reading_body() -> None: + recorded_timeout: list[float] = [] + response = _FakeResponse("body") + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + recorded_timeout.append(timeout) + assert url == "http://example.test/vendor.js" + return response + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + headers = app_support.request_headers("http://example.test/vendor.js") + + assert headers == {"Cache-Control": "no-store"} + assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS] + + +def test_read_asset_response_retries_transient_os_errors() -> None: + attempts = 0 + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + nonlocal attempts + attempts += 1 + assert url == "http://example.test/retry.js" + assert timeout == app_support._ASSET_REQUEST_TIMEOUT_SECONDS + if attempts < 3: + raise TimeoutError("temporary timeout") + return _FakeResponse("ok") + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + body, headers = app_support._read_asset_response("http://example.test/retry.js") + + assert body == b"ok" + assert headers == {"Cache-Control": "no-store"} + assert attempts == 3 + + +def test_request_bytes_uses_shared_asset_fetcher() -> None: + with patch( + "tests.app_support._read_asset_response", + return_value=(b"icon", {"Content-Type": "image/x-icon"}), + ) as read_asset_response_mock: + body = app_support.request_bytes("http://example.test/favicon.ico") + + assert body == b"icon" + read_asset_response_mock.assert_called_once_with("http://example.test/favicon.ico") diff --git a/tests/test_browser_smoke.py b/tests/test_browser_smoke.py index 5abb60e..0bdb758 100644 --- a/tests/test_browser_smoke.py +++ b/tests/test_browser_smoke.py @@ -6,12 +6,13 @@ import time from pathlib import Path from typing import Any -from urllib.request import urlopen +from urllib.request import Request, urlopen import pytest from tensor_network_editor.app.server import EditorServer from tensor_network_editor.app.session import EditorSession +from tests.app_support import _session_token_for_url pytestmark = pytest.mark.browser @@ -46,10 +47,22 @@ def _import_playwright_sync_api() -> Any: def _request_json(url: str) -> dict[str, Any]: """Read one JSON response from a local editor server URL.""" - with urlopen(url, timeout=1) as response: + request = Request(url) + session_token = _session_token_for_url(url) + if session_token: + request.add_header("X-TNE-Session-Token", session_token) + with urlopen(request, timeout=1) as response: return json.load(response) +def test_browser_smoke_json_helper_sends_session_token( + editor_server: EditorServer, +) -> None: + payload = _request_json(f"{editor_server.base_url}/api/bootstrap") + + assert payload["app_metadata"]["version"] + + def _wait_for_recoverable_draft_name( draft_url: str, expected_name: str, diff --git a/tests/test_cli.py b/tests/test_cli.py index 6b6276f..f09b447 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -148,6 +148,10 @@ def empty_lint_report(_spec: NetworkSpec) -> LintReport: return LintReport() +def compact_help_text(help_text: str) -> str: + return " ".join(help_text.split()) + + def test_main_requires_a_subcommand(capsys: pytest.CaptureFixture[str]) -> None: with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: exit_code = main([]) @@ -201,6 +205,49 @@ def test_global_python_import_arguments_are_accepted_before_subcommand() -> None assert parsed_args.command == "edit" +def test_top_level_help_includes_command_argument_quick_reference() -> None: + parser = build_command_parser() + + help_text = compact_help_text(parser.format_help()) + + assert "Command argument quick reference:" in help_text + assert ( + "tensor-network-editor export PATH --engine ENGINE [--output FILE]" in help_text + ) + assert "tensor-network-editor template build TEMPLATE_NAME [options]" in help_text + assert "Run 'tensor-network-editor --help'" in help_text + + +def test_export_help_describes_required_arguments( + capsys: pytest.CaptureFixture[str], +) -> None: + exit_code = main(["export", "--help"]) + + assert exit_code == 0 + help_text = capsys.readouterr().out + help_text = compact_help_text(help_text) + assert ( + "Saved JSON design or supported generated Python file to export." in help_text + ) + assert "Backend used for generated Python code." in help_text + assert "Write generated code to a file instead of stdout." in help_text + + +def test_template_build_help_describes_template_options( + capsys: pytest.CaptureFixture[str], +) -> None: + exit_code = main(["template", "build", "--help"]) + + assert exit_code == 0 + help_text = capsys.readouterr().out + help_text = compact_help_text(help_text) + assert "Built-in template name to instantiate." in help_text + assert ( + "Override the graph size parameter when the template supports it." in help_text + ) + assert "Choose text or JSON output." in help_text + + def test_cli_modules_pass_targeted_mypy_check() -> None: result = subprocess.run( [ @@ -253,6 +300,45 @@ def test_edit_subcommand_passes_explicit_log_file_path() -> None: ) +def test_edit_subcommand_accepts_explicit_browser_ui_mode() -> None: + with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: + exit_code = main(["edit", "--ui", "pywebview"]) + + assert exit_code == 0 + open_editor_mock.assert_called_once_with( + spec=None, + options=EditorLaunchOptions( + ui_mode="pywebview", + open_browser=False, + ), + ) + + +def test_edit_subcommand_ui_server_matches_no_browser_alias() -> None: + with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: + exit_code = main(["edit", "--ui", "server"]) + + assert exit_code == 0 + open_editor_mock.assert_called_once_with( + spec=None, + options=EditorLaunchOptions( + ui_mode="server", + open_browser=False, + ), + ) + + +def test_edit_subcommand_rejects_ui_and_no_browser_combination( + capsys: pytest.CaptureFixture[str], +) -> None: + with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: + exit_code = main(["edit", "--ui", "browser", "--no-browser"]) + + assert exit_code == 2 + open_editor_mock.assert_not_called() + assert "cannot combine --ui with --no-browser" in capsys.readouterr().err + + def test_edit_subcommand_passes_explicit_log_rotation_settings() -> None: with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: exit_code = main( diff --git a/tests/test_frontend_architecture.py b/tests/test_frontend_architecture.py index 3c4f190..94a2d0f 100644 --- a/tests/test_frontend_architecture.py +++ b/tests/test_frontend_architecture.py @@ -1197,6 +1197,7 @@ def test_editor_services_route_session_requests_through_explicit_dependencies( await sessionService.generateCode({{ engine: "quimb", collectionFormat: "dict", + includeRoundtripMetadata: true, spec: {{ schema_version: 4, network: {{ id: "network_demo" }} }}, }}); await sessionService.renderSpec({{ @@ -1218,6 +1219,9 @@ def test_editor_services_route_session_requests_through_explicit_dependencies( if (calls[1].payload.collection_format !== "dict") {{ throw new Error(`Expected collection_format=dict, received ${{calls[1].payload.collection_format}}.`); }} + if (calls[1].payload.include_roundtrip_metadata !== true) {{ + throw new Error(`Expected include_roundtrip_metadata=true, received ${{calls[1].payload.include_roundtrip_metadata}}.`); + }} if (calls[2].path !== "/api/render") {{ throw new Error(`Unexpected render path: ${{calls[2].path}}`); }} @@ -3817,11 +3821,18 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( state.selectedEngine = engine; storeCalls.push({{ step: "setSelectedEngine", engine }}); }}, - setSelectedCollectionFormat(collectionFormat) {{ - state.selectedCollectionFormat = collectionFormat; - storeCalls.push({{ step: "setSelectedCollectionFormat", collectionFormat }}); - }}, - }}; + setSelectedCollectionFormat(collectionFormat) {{ + state.selectedCollectionFormat = collectionFormat; + storeCalls.push({{ step: "setSelectedCollectionFormat", collectionFormat }}); + }}, + setIncludeRoundtripMetadata(includeRoundtripMetadata) {{ + state.includeRoundtripMetadata = Boolean(includeRoundtripMetadata); + storeCalls.push({{ + step: "setIncludeRoundtripMetadata", + includeRoundtripMetadata: state.includeRoundtripMetadata, + }}); + }}, + }}; const flowEvents = []; const bootstrapFlow = bootstrapFlowModule.createEditorBootstrapFlow({{ state, @@ -3921,8 +3932,8 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( windowRef: {{ innerWidth: 800, innerHeight: 600, - addEventListener(type, handler) {{ - windowListeners.push(type); + addEventListener(type, handler, options) {{ + windowListeners.push({{ type, options }}); }}, }}, }}); @@ -4007,12 +4018,29 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( generatedCodeModal: getButton("generated-code-modal"), generatedCodeModalBackdrop: getButton("generated-code-modal-backdrop"), generatedCodeModalCloseButton: getButton("generated-code-modal-close-button"), + codegenRoundtripMetadataField: getButton("codegen-roundtrip-metadata-field"), + codegenRoundtripMetadataCheckbox: {{ + checked: false, + listeners: {{}}, + addEventListener(type, handler) {{ this.listeners[type] = handler; }}, + change(event) {{ + this.checked = Boolean(event?.target?.checked); + this.listeners.change?.(event); + }}, + }}, templateSelectField: getButton("template-select-field"), engineSelectField: getButton("engine-select-field"), collectionFormatSelectField: getButton("collection-format-select-field"), templateSelect: {{ value: "mps", - addEventListener(type, handler) {{ this[type] = handler; }}, + listeners: {{}}, + addEventListener(type, handler) {{ this.listeners[type] = handler; }}, + mousedown(event) {{ this.listeners.mousedown?.(event); }}, + change(event) {{ this.listeners.change?.(event); }}, + blur() {{ + flowEvents.push("templateSelect.blur"); + this.listeners.blur?.({{ target: this }}); + }}, }}, templateSettingsButton: getButton("template-settings-button"), templateSettingsPopover: getButton("template-settings-popover"), @@ -4026,11 +4054,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( editSessionTemplateMenuItem: getButton("edit-session-template-menu-item"), openSubnetworkLibraryMenuItem: getButton("open-subnetwork-library-menu-item"), reflowImportedButton: getButton("reflow-imported-button"), - reflowAlignLeftButton: getButton("reflow-align-left-button"), - reflowAlignRightButton: getButton("reflow-align-right-button"), - reflowAlignTopButton: getButton("reflow-align-top-button"), - reflowAlignMiddleButton: getButton("reflow-align-middle-button"), - reflowAlignBottomButton: getButton("reflow-align-bottom-button"), + reflowAlignHorizontalButton: getButton("reflow-align-horizontal-button"), + reflowAlignVerticalButton: getButton("reflow-align-vertical-button"), + reflowRotateSelectionButton: getButton("reflow-rotate-selection-button"), reflowIndicesLeftButton: getButton("reflow-indices-left-button"), reflowIndicesRightButton: getButton("reflow-indices-right-button"), reflowIndicesTopButton: getButton("reflow-indices-top-button"), @@ -4198,8 +4224,8 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( dom, documentRef: tooltipDocument, windowRef: {{ - addEventListener(type, handler) {{ - windowListeners.push(type); + addEventListener(type, handler, options) {{ + windowListeners.push({{ type, options }}); }}, }}, actions: shellActions, @@ -4211,6 +4237,7 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( getButton("expand-generated-code-button").click(); dom.generatedCodeModalBackdrop.click(); dom.generatedCodeModalCloseButton.click(); + dom.codegenRoundtripMetadataCheckbox.change({{ target: {{ checked: true }} }}); dom.engineSelect.change({{ target: {{ value: "cotengra" }} }}); dom.fileMenuButton.click(); dom.exportSubmenuShell.mouseenter(); @@ -4230,6 +4257,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( dom.templateSettingsButton.click(); dom.reflowImportedButton.click(); dom.reflowAutoLayoutButton.click(); + dom.reflowAlignHorizontalButton.click(); + dom.reflowAlignVerticalButton.click(); + dom.reflowRotateSelectionButton.click(); dom.reflowArrangeGridButton.click(); dom.reflowIndicesResetButton.click(); dom.templateManagerCloseButton.click(); @@ -4249,6 +4279,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( if (dom.templateSelectField.attributes["data-expanded"] !== "false") {{ throw new Error("Expected template select change to collapse the disclosure indicator."); }} + if (!flowEvents.includes("templateSelect.blur")) {{ + throw new Error("Expected template selection changes to blur the dropdown so keyboard shortcuts do not stay trapped in the select."); + }} dom.engineSelect.mousedown({{ target: dom.engineSelect }}); if (dom.engineSelectField.attributes["data-expanded"] !== "true") {{ throw new Error("Expected engine select mouse down to mark the disclosure as expanded."); @@ -4265,6 +4298,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( if (dom.collectionFormatSelectField.attributes["data-expanded"] !== "false") {{ throw new Error("Expected collection format select change to collapse the disclosure indicator."); }} + if (state.includeRoundtripMetadata !== true) {{ + throw new Error(`Expected metadata checkbox changes to update state, received ${{state.includeRoundtripMetadata}}.`); + }} if (!flowEvents.includes("generateCode")) {{ throw new Error(`Expected toolbar generate binding to invoke the injected action, received ${{JSON.stringify(flowEvents)}}.`); }} @@ -4330,6 +4366,15 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( if (!flowEvents.includes("applyReflowLayoutAction:auto")) {{ throw new Error(`Expected the Auto layout action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); }} + if (!flowEvents.includes("applyReflowLayoutAction:align-horizontal")) {{ + throw new Error(`Expected the horizontal alignment action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); + }} + if (!flowEvents.includes("applyReflowLayoutAction:align-vertical")) {{ + throw new Error(`Expected the vertical alignment action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); + }} + if (!flowEvents.includes("applyReflowLayoutAction:rotate-90")) {{ + throw new Error(`Expected the rotate action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); + }} if (!flowEvents.includes("applyReflowLayoutAction:grid")) {{ throw new Error(`Expected the Reflow popover actions to dispatch the requested layout, received ${{JSON.stringify(flowEvents)}}.`); }} @@ -4415,6 +4460,14 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( ) {{ throw new Error("Expected the Code tab to expose its tooltip description."); }} + const keydownBinding = windowListeners.find( + (entry) => entry && entry.type === "keydown" + ); + if (!keydownBinding || keydownBinding.options !== true) {{ + throw new Error( + `Expected the global keydown shortcut listener to register in capture mode, received ${{JSON.stringify(windowListeners)}}.` + ); + }} """, ) @@ -4705,8 +4758,12 @@ def test_editor_shell_helper_modules_expose_explicit_ui_and_invalidation_adapter throw new Error("Session UI confirm adapter should forward the injected result."); }} await sessionUi.copyText("result = 1"); - sessionUi.downloadText("demo.json", "{{}}", "application/json"); - sessionUi.downloadBlob("demo.py", {{ type: "text/x-python" }}); + await Promise.resolve( + sessionUi.downloadText("demo.json", "{{}}", "application/json") + ); + await Promise.resolve( + sessionUi.downloadBlob("demo.py", {{ type: "text/x-python" }}) + ); sessionUi.closeWindow(); if (!uiEvents.some((event) => event.kind === "copy" && event.text === "result = 1")) {{ throw new Error(`Expected injected copy adapter to run, received ${{JSON.stringify(uiEvents)}}.`); @@ -4876,6 +4933,478 @@ def test_editor_shell_helper_modules_expose_explicit_ui_and_invalidation_adapter ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_use_pywebview_save_api_when_available( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const calls = []; + class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + + async arrayBuffer() {{ + const firstPart = this.parts[0]; + if (!(firstPart instanceof Uint8Array)) {{ + throw new Error("Expected the test blob to receive Uint8Array content."); + }} + return firstPart.buffer.slice( + firstPart.byteOffset, + firstPart.byteOffset + firstPart.byteLength + ); + }} + }} + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef: {{ + pywebview: {{ + api: {{ + async save_text_file(filename, text, contentType) {{ + calls.push({{ type: "text", filename, text, contentType }}); + return true; + }}, + async save_binary_file(filename, base64Payload, contentType) {{ + calls.push({{ type: "binary", filename, base64Payload, contentType }}); + return true; + }}, + }}, + }}, + }}, + blobCtor: FakeBlob, + }}); + + const textSaved = await sessionUi.downloadText( + "demo.json", + "{{\\"ok\\":true}}", + "application/json;charset=utf-8" + ); + const binarySaved = await sessionUi.downloadBlob( + "demo.pdf", + new FakeBlob([Uint8Array.from([0, 1, 2, 255])], {{ type: "application/pdf" }}) + ); + + if (textSaved !== true || binarySaved !== true) {{ + throw new Error(`Expected pywebview saves to resolve true, received ${{JSON.stringify({{ textSaved, binarySaved }})}}.`); + }} + const textCall = calls.find((entry) => entry.type === "text"); + const binaryCall = calls.find((entry) => entry.type === "binary"); + if (!textCall || textCall.filename !== "demo.json") {{ + throw new Error(`Expected text export to use the pywebview API, received ${{JSON.stringify(calls)}}.`); + }} + if ( + !binaryCall || + binaryCall.filename !== "demo.pdf" || + binaryCall.base64Payload !== "AAEC/w==" + ) {{ + throw new Error(`Expected binary export to send base64 bytes through the pywebview API, received ${{JSON.stringify(calls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The pywebview session-ui adapter runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_detect_pywebview_save_api_added_after_creation( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_late_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const windowRef = {{}}; + const uiCalls = []; + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef, + documentRef: {{ + createElement() {{ + uiCalls.push({{ type: "web-download" }}); + return {{ + click() {{ + uiCalls.push({{ type: "web-download-click" }}); + }}, + }}; + }}, + }}, + urlRef: {{ + createObjectURL() {{ + return "blob:test"; + }}, + revokeObjectURL() {{ + return undefined; + }}, + }}, + blobCtor: class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + }}, + }}); + + windowRef.pywebview = {{ + api: {{ + async save_text_file(filename, text, contentType) {{ + uiCalls.push({{ type: "pywebview", filename, text, contentType }}); + return true; + }}, + async save_binary_file() {{ + throw new Error("Unexpected binary save in text export test."); + }}, + }}, + }}; + + await sessionUi.downloadText( + "late.json", + "{{\\"late\\": true}}", + "application/json;charset=utf-8" + ); + + if (!uiCalls.some((entry) => entry.type === "pywebview")) {{ + throw new Error(`Expected late pywebview injection to be honored, received ${{JSON.stringify(uiCalls)}}.`); + }} + if (uiCalls.some((entry) => entry.type === "web-download")) {{ + throw new Error(`Expected pywebview save path instead of web download fallback, received ${{JSON.stringify(uiCalls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The late pywebview session-ui adapter runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_use_partial_pywebview_text_api_when_available( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_partial_text_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const calls = []; + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef: {{ + pywebview: {{ + api: {{ + async save_text_file(filename, text, contentType) {{ + calls.push({{ type: "text", filename, text, contentType }}); + return true; + }}, + }}, + }}, + }}, + documentRef: {{ + createElement() {{ + calls.push({{ type: "web-download" }}); + return {{ + click() {{ + calls.push({{ type: "web-download-click" }}); + }}, + }}; + }}, + }}, + urlRef: {{ + createObjectURL() {{ + calls.push({{ type: "object-url" }}); + return "blob:test"; + }}, + revokeObjectURL() {{ + return undefined; + }}, + }}, + blobCtor: class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + }}, + }}); + + const saved = await sessionUi.downloadText( + "partial.json", + "{{\\"partial\\": true}}", + "application/json;charset=utf-8" + ); + + if (saved !== true) {{ + throw new Error(`Expected the partial pywebview text save to resolve true, received ${{saved}}.`); + }} + if (!calls.some((entry) => entry.type === "text")) {{ + throw new Error(`Expected downloadText() to use save_text_file(), received ${{JSON.stringify(calls)}}.`); + }} + if (calls.some((entry) => entry.type === "web-download" || entry.type === "object-url")) {{ + throw new Error(`Expected no web-download fallback when save_text_file() exists, received ${{JSON.stringify(calls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The partial pywebview text-save runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_use_partial_pywebview_binary_api_when_available( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_partial_binary_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const calls = []; + class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + + async arrayBuffer() {{ + const firstPart = this.parts[0]; + if (!(firstPart instanceof Uint8Array)) {{ + throw new Error("Expected Uint8Array content in the binary export test blob."); + }} + return firstPart.buffer.slice( + firstPart.byteOffset, + firstPart.byteOffset + firstPart.byteLength + ); + }} + }} + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef: {{ + pywebview: {{ + api: {{ + async save_binary_file(filename, base64Payload, contentType) {{ + calls.push({{ type: "binary", filename, base64Payload, contentType }}); + return true; + }}, + }}, + }}, + }}, + documentRef: {{ + createElement() {{ + calls.push({{ type: "web-download" }}); + return {{ + click() {{ + calls.push({{ type: "web-download-click" }}); + }}, + }}; + }}, + }}, + urlRef: {{ + createObjectURL() {{ + calls.push({{ type: "object-url" }}); + return "blob:test"; + }}, + revokeObjectURL() {{ + return undefined; + }}, + }}, + blobCtor: FakeBlob, + }}); + + const saved = await sessionUi.downloadBlob( + "partial.pdf", + new FakeBlob([Uint8Array.from([0, 1, 2, 255])], {{ type: "application/pdf" }}) + ); + + if (saved !== true) {{ + throw new Error(`Expected the partial pywebview binary save to resolve true, received ${{saved}}.`); + }} + const binaryCall = calls.find((entry) => entry.type === "binary"); + if (!binaryCall || binaryCall.base64Payload !== "AAEC/w==") {{ + throw new Error(`Expected downloadBlob() to use save_binary_file(), received ${{JSON.stringify(calls)}}.`); + }} + if (calls.some((entry) => entry.type === "web-download" || entry.type === "object-url")) {{ + throw new Error(`Expected no web-download fallback when save_binary_file() exists, received ${{JSON.stringify(calls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The partial pywebview binary-save runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_start_editor_bootstraps_immediately_when_dom_is_already_ready( + tmp_path: Path, +) -> None: + bootstrap_source = ( + REPO_ROOT + / "src" + / "tensor_network_editor" + / "app" + / "static" + / "js" + / "bootstrap.js" + ).read_text(encoding="utf-8") + (tmp_path / "shell").mkdir(parents=True, exist_ok=True) + (tmp_path / "bootstrap.js").write_text(bootstrap_source, encoding="utf-8") + (tmp_path / "shell" / "editorBootstrapFlow.js").write_text( + """ + export function createEditorBootstrapFlow() { + return { + async bootstrap() { + globalThis.__bootstrapCalls.push("bootstrap"); + return {}; + }, + }; + } + """, + encoding="utf-8", + ) + (tmp_path / "shell" / "shellActions.js").write_text( + """ + export function createShellActions() { + return { + setStatus(message, level = "info") { + globalThis.__bootstrapCalls.push(`status:${level}:${message}`); + }, + }; + } + """, + encoding="utf-8", + ) + (tmp_path / "shell" / "editorShellBindings.js").write_text( + """ + export function createEditorShellBindings() { + return { + attachToolbarHandlers() { + globalThis.__bootstrapCalls.push("attachToolbarHandlers"); + }, + }; + } + """, + encoding="utf-8", + ) + (tmp_path / "shell" / "shortcutTooltip.js").write_text( + """ + export function createShortcutTooltip() { + return { + attachShortcutTooltipHandlers() {}, + }; + } + """, + encoding="utf-8", + ) + script_path = _write_runtime_script( + tmp_path, + "bootstrap_dom_ready.mjs", + """ + globalThis.__bootstrapCalls = []; + const bootstrapUrl = new URL("./bootstrap.js", import.meta.url).href; + const bootstrapModule = await import(bootstrapUrl); + + const documentRef = { + readyState: "complete", + addEventListener(type, handler) { + globalThis.__bootstrapCalls.push(`listener:${type}`); + this.listener = handler; + }, + }; + const ctx = { + state: {}, + store: {}, + window: { + confirm() { + return false; + }, + }, + document: documentRef, + services: { session: {} }, + logger: null, + constants: { REDO_SHORTCUT_LABEL: "Ctrl+Shift+Z" }, + }; + + bootstrapModule.startEditor(ctx); + await Promise.resolve(); + + if (!globalThis.__bootstrapCalls.includes("attachToolbarHandlers")) { + throw new Error(`Expected toolbar handlers to attach immediately, received ${JSON.stringify(globalThis.__bootstrapCalls)}.`); + } + if (!globalThis.__bootstrapCalls.includes("bootstrap")) { + throw new Error(`Expected bootstrap to run immediately for a ready document, received ${JSON.stringify(globalThis.__bootstrapCalls)}.`); + } + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The bootstrap DOM-ready runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_benchmark_helper_modules_build_comparison_rows_and_history_state( tmp_path: Path, @@ -4992,7 +5521,21 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( const historyEvents = []; const historyState = {{ - spec: {{ id: "network_demo" }}, + spec: {{ + id: "network_demo", + contraction_plan: {{ + id: "scheme_beta", + name: "Beta", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "tensor_a" }}], + }}, + ], + metadata: {{}}, + }}, + }}, tensorOrder: ["tensor_a"], undoStack: [], redoStack: [], @@ -5014,13 +5557,47 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( benchmarkSession: {{ enabled: true, activePosition: 2, - originalPlan: {{ id: "original_plan", name: "Original", steps: [], metadata: {{}} }}, + originalPlan: {{ + id: "original_plan", + name: "Original", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "original_tensor" }}], + }}, + ], + metadata: {{}}, + }}, schemes: [ - {{ id: "scheme_alpha", name: "Alpha", steps: [], metadata: {{}} }}, - {{ id: "scheme_beta", name: "Beta", steps: [], metadata: {{}} }}, + {{ + id: "scheme_alpha", + name: "Alpha", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "alpha_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + {{ + id: "scheme_beta", + name: "Beta", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "beta_tensor" }}], + }}, + ], + metadata: {{}}, + }}, ], compareModal: {{ open: true, + tableModel: {{ rows: [{{ scheme_id: "scheme_alpha" }}] }}, rows: [{{ scheme_id: "scheme_alpha" }}], activeRequestId: 7, }}, @@ -5050,19 +5627,80 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( if (!snapshot.benchmarkSession || snapshot.benchmarkSession.activePosition !== 2) {{ throw new Error(`Expected history snapshots to capture benchmark session state, received ${{JSON.stringify(snapshot)}}.`); }} + if (snapshot.benchmarkSession.compareModal.open || snapshot.benchmarkSession.compareModal.activeRequestId !== 0) {{ + throw new Error(`Expected history snapshots to reset ephemeral benchmark compare state, received ${{JSON.stringify(snapshot.benchmarkSession.compareModal)}}.`); + }} + if (snapshot.benchmarkSession.compareModal.rows.length !== 0 || snapshot.benchmarkSession.compareModal.tableModel !== null) {{ + throw new Error(`Expected history snapshots to strip compare rows and table models, received ${{JSON.stringify(snapshot.benchmarkSession.compareModal)}}.`); + }} + if (snapshot.benchmarkSession.originalPlan.view_snapshots.length !== 0) {{ + throw new Error(`Expected history snapshots to strip original-plan view snapshots, received ${{JSON.stringify(snapshot.benchmarkSession.originalPlan)}}.`); + }} + if (snapshot.benchmarkSession.schemes.some((scheme) => scheme.view_snapshots.length !== 0)) {{ + throw new Error(`Expected history snapshots to strip inactive benchmark view snapshots, received ${{JSON.stringify(snapshot.benchmarkSession.schemes)}}.`); + }} + if (snapshot.spec.contraction_plan.view_snapshots.length !== 1) {{ + throw new Error(`Expected the active scheme view snapshots to stay in the main spec snapshot, received ${{JSON.stringify(snapshot.spec.contraction_plan)}}.`); + }} historySupport.restoreHistorySnapshot({{ - spec: {{ id: "restored_network" }}, + spec: {{ + id: "restored_network", + contraction_plan: {{ + id: "scheme_restored", + name: "Restored", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 1, + operand_layouts: [{{ operand_id: "restored_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + }}, tensorOrder: ["tensor_b"], benchmarkSession: {{ enabled: true, activePosition: 1, - originalPlan: null, - schemes: [{{ id: "scheme_restored", name: "Restored", steps: [], metadata: {{}} }}], + originalPlan: {{ + id: "restored_original", + name: "Restored Original", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "restored_original_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + schemes: [ + {{ + id: "scheme_restored", + name: "Restored", + steps: [], + view_snapshots: [], + metadata: {{}}, + }}, + {{ + id: "scheme_inactive", + name: "Inactive", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "inactive_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + ], compareModal: {{ - open: false, - rows: [], - activeRequestId: 0, + open: true, + tableModel: {{ rows: [{{ scheme_id: "scheme_restored" }}] }}, + rows: [{{ scheme_id: "scheme_restored" }}], + activeRequestId: 9, }}, }}, }}); @@ -5070,6 +5708,21 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( if (!historyState.benchmarkSession || historyState.benchmarkSession.activePosition !== 1) {{ throw new Error(`Expected history restore to recover benchmark session state, received ${{JSON.stringify(historyState.benchmarkSession)}}.`); }} + if (historyState.benchmarkSession.compareModal.open || historyState.benchmarkSession.compareModal.rows.length !== 0 || historyState.benchmarkSession.compareModal.tableModel !== null) {{ + throw new Error(`Expected history restore to keep benchmark compare state ephemeral, received ${{JSON.stringify(historyState.benchmarkSession.compareModal)}}.`); + }} + if (historyState.benchmarkSession.originalPlan.view_snapshots.length !== 0) {{ + throw new Error(`Expected history restore to keep original-plan snapshots lazy, received ${{JSON.stringify(historyState.benchmarkSession.originalPlan)}}.`); + }} + if (historyState.benchmarkSession.schemes[1].view_snapshots.length !== 0) {{ + throw new Error(`Expected inactive benchmark schemes to stay lightweight after restore, received ${{JSON.stringify(historyState.benchmarkSession.schemes)}}.`); + }} + if (historyState.benchmarkSession.schemes[0] !== historyState.spec.contraction_plan) {{ + throw new Error("Expected history restore to re-link the active benchmark scheme to the restored contraction plan."); + }} + if (historyState.spec.contraction_plan.view_snapshots.length !== 1) {{ + throw new Error(`Expected the restored active scheme to keep its exact view snapshots, received ${{JSON.stringify(historyState.spec.contraction_plan)}}.`); + }} """, ) diff --git a/tests/test_frontend_runtime.py b/tests/test_frontend_runtime.py index 6a3f9dd..24e9a13 100644 --- a/tests/test_frontend_runtime.py +++ b/tests/test_frontend_runtime.py @@ -986,6 +986,198 @@ def test_api_service_logs_request_lifecycle_with_frontend_logger( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_api_service_sends_session_token_header( + tmp_path: Path, +) -> None: + script_path = tmp_path / "api_session_token_header.mjs" + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const apiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "services" / "api.js")!r}).href; + const apiModule = await import(apiUrl); + const calls = []; + + function headerValue(headers, name) {{ + if (headers && typeof headers.get === "function") {{ + return headers.get(name); + }} + return headers?.[name] || headers?.[name.toLowerCase()] || null; + }} + + globalThis.fetch = async (path, options = {{}}) => {{ + calls.push({{ path, options }}); + return new Response(JSON.stringify({{ ok: true }}), {{ + status: 200, + headers: {{ "Content-Type": "application/json" }}, + }}); + }}; + + await apiModule.apiGet("/api/bootstrap", {{ + apiToken: "secret-token", + }}); + await apiModule.apiPost("/api/cancel", {{}}, {{ + apiToken: "secret-token", + }}); + + if (calls.length !== 2) {{ + throw new Error(`Expected two calls, received ${{calls.length}}.`); + }} + for (const call of calls) {{ + const token = headerValue(call.options.headers, "X-TNE-Session-Token"); + if (token !== "secret-token") {{ + throw new Error(`Missing session token header: ${{JSON.stringify(call)}}`); + }} + }} + const contentType = headerValue(calls[1].options.headers, "Content-Type"); + if (contentType !== "application/json") {{ + throw new Error(`Missing JSON content type: ${{JSON.stringify(calls[1])}}`); + }} + """ + ), + encoding="utf-8", + ) + + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The api session token header script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_editor_context_passes_runtime_api_token_to_requests( + tmp_path: Path, +) -> None: + script_path = tmp_path / "editor_context_api_token.mjs" + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const contextUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "core" / "editorContext.js")!r}).href; + const contextModule = await import(contextUrl); + const calls = []; + const documentRef = {{ + getElementById() {{ + return null; + }}, + querySelector() {{ + return null; + }}, + }}; + + function headerValue(headers, name) {{ + if (headers && typeof headers.get === "function") {{ + return headers.get(name); + }} + return headers?.[name] || headers?.[name.toLowerCase()] || null; + }} + + globalThis.fetch = async (path, options = {{}}) => {{ + calls.push({{ path, options }}); + return new Response(JSON.stringify({{ ok: true }}), {{ + status: 200, + headers: {{ "Content-Type": "application/json" }}, + }}); + }}; + + const ctx = contextModule.createEditorContext({{ + window: {{}}, + document: documentRef, + cytoscape: null, + runtimeConfig: {{ apiToken: "runtime-secret" }}, + }}); + await ctx.apiPost("/api/cancel", {{}}); + + const token = headerValue(calls[0]?.options?.headers, "X-TNE-Session-Token"); + if (token !== "runtime-secret") {{ + throw new Error(`Missing context token header: ${{JSON.stringify(calls)}}`); + }} + """ + ), + encoding="utf-8", + ) + + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The editor context API token script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_runtime_config_reader_normalizes_api_token( + tmp_path: Path, +) -> None: + script_path = tmp_path / "runtime_config_api_token.mjs" + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const loggerUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "core" / "frontendLogger.js")!r}).href; + const loggerModule = await import(loggerUrl); + const documentRef = {{ + getElementById(id) {{ + if (id !== "tne-runtime-config") {{ + return null; + }} + return {{ + textContent: JSON.stringify({{ + session_id: "session-1", + api_token: "embedded-token", + frontend_logging: {{ enabled: false }}, + }}), + }}; + }}, + }}; + + const config = loggerModule.readFrontendRuntimeConfig({{ documentRef }}); + if (config.sessionId !== "session-1") {{ + throw new Error(`Unexpected session id: ${{JSON.stringify(config)}}`); + }} + if (config.apiToken !== "embedded-token") {{ + throw new Error(`Unexpected API token: ${{JSON.stringify(config)}}`); + }} + """ + ), + encoding="utf-8", + ) + + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The runtime config API token script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_frontend_logger_persists_batched_logs_without_api_recursion( tmp_path: Path, @@ -7022,33 +7214,631 @@ def _write_port_layering_runtime_regression_script(tmp_path: Path) -> Path: }}, }}); - const model = builder(); - const zIndexFor = (elementId) => model.descriptorsById[elementId].data.zIndex; - const frontTensorZIndex = zIndexFor("tensor_front"); + const model = builder(); + const zIndexFor = (elementId) => model.descriptorsById[elementId].data.zIndex; + const frontTensorZIndex = zIndexFor("tensor_front"); + + if (!(zIndexFor("tensor_back") < zIndexFor("back_open"))) {{ + throw new Error("An open port should still sit above its owning tensor."); + }} + if (!(zIndexFor("back_open") < frontTensorZIndex)) {{ + throw new Error( + `An open port from a rear tensor should not cover a front tensor: open=${{zIndexFor("back_open")}}, front=${{frontTensorZIndex}}.` + ); + }} + if (!(zIndexFor("back_connected") > frontTensorZIndex)) {{ + throw new Error( + `A connected port should stay above tensors so connections remain visible: connected=${{zIndexFor("back_connected")}}, front=${{frontTensorZIndex}}.` + ); + }} + + state.selectionIds = ["tensor_back"]; + const selectedModel = builder(); + const selectedZIndexFor = (elementId) => + selectedModel.descriptorsById[elementId].data.zIndex; + if (!(selectedZIndexFor("back_open") > selectedZIndexFor("tensor_front"))) {{ + throw new Error( + `A selected tensor should keep its open ports visible above front tensors: open=${{selectedZIndexFor("back_open")}}, front=${{selectedZIndexFor("tensor_front")}}.` + ); + }} + """ + ), + encoding="utf-8", + ) + return script_path + + +def _write_contraction_scene_port_layering_runtime_regression_script( + tmp_path: Path, +) -> Path: + script_path = tmp_path / "contraction_scene_port_layering_runtime_regression.mjs" + geometry_module_path = ( + REPO_ROOT + / "src" + / "tensor_network_editor" + / "app" + / "static" + / "js" + / "utils/utilitiesGeometry.js" + ) + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const geometryUrl = pathToFileURL({str(geometry_module_path)!r}).href; + const {{ createUtilityGeometryBindings }} = await import(geometryUrl); + + function createFakeElement(id, initialZIndex) {{ + let zIndex = initialZIndex; + return {{ + length: 1, + id() {{ + return id; + }}, + data(name, value) {{ + if (value === undefined) {{ + return name === "zIndex" ? zIndex : undefined; + }} + if (name === "zIndex") {{ + zIndex = value; + }} + return undefined; + }}, + }}; + }} + + const visibleDerivedTensor = {{ + id: "scene-step-ab", + name: "A-B", + position: {{ x: 100, y: 100 }}, + size: {{ width: 160, height: 84 }}, + indices: [ + {{ + id: "scene-step-ab_open", + name: "open", + dimension: 2, + offset: {{ x: 38, y: 0 }}, + metadata: {{}}, + }}, + ], + isDerived: true, + sourceTensorIds: ["tensor_a", "tensor_b"], + metadata: {{}}, + }}; + const visibleFrontTensor = {{ + id: "tensor_front", + name: "Front", + position: {{ x: 130, y: 100 }}, + size: {{ width: 140, height: 84 }}, + indices: [ + {{ + id: "front_open", + name: "front", + dimension: 2, + offset: {{ x: -38, y: 0 }}, + metadata: {{}}, + }}, + ], + isDerived: false, + sourceTensorIds: ["tensor_front"], + metadata: {{}}, + }}; + + const elementMap = new Map([ + ["scene-step-ab", createFakeElement("scene-step-ab", 10)], + ["scene-step-ab_open", createFakeElement("scene-step-ab_open", 10.2)], + ["tensor_front", createFakeElement("tensor_front", 11)], + ["front_open", createFakeElement("front_open", 11.2)], + ]); + + const state = {{ + activeTensorDrag: null, + cy: {{ + getElementById(id) {{ + return elementMap.get(id) || {{ length: 0, data() {{ return undefined; }} }}; + }}, + edges() {{ + return []; + }}, + }}, + pendingIndexId: null, + selectionIds: ["scene-step-ab"], + spec: {{ + tensors: [ + {{ + id: "tensor_a", + indices: [], + position: {{ x: 0, y: 0 }}, + }}, + {{ + id: "tensor_b", + indices: [], + position: {{ x: 0, y: 0 }}, + }}, + {{ + id: "tensor_front", + indices: visibleFrontTensor.indices, + position: visibleFrontTensor.position, + }}, + ], + }}, + tensorOrder: [], + tensorRankById: {{}}, + }}; + const runtime = {{ + asFiniteNumber(value, fallbackValue) {{ + return Number.isFinite(value) ? value : fallbackValue; + }}, + findConnectionByIndexId() {{ + return null; + }}, + findEdgeByIndexId() {{ + return null; + }}, + findHyperedgeByIndexId() {{ + return null; + }}, + findTensorById(tensorId) {{ + return ( + state.spec.tensors.find((tensor) => tensor.id === tensorId) || null + ); + }}, + getVisibleTensors() {{ + return [visibleDerivedTensor, visibleFrontTensor]; + }}, + indexLabelNodeId(indexId) {{ + return `${{indexId}}__label`; + }}, + }}; + + const geometry = createUtilityGeometryBindings({{ + ctx: {{ state }}, + state, + constants: {{ + TENSOR_WIDTH: 140, + TENSOR_HEIGHT: 84, + MIN_TENSOR_WIDTH: 96, + MIN_TENSOR_HEIGHT: 60, + INDEX_RADIUS: 10, + INDEX_PADDING: 6, + }}, + runtime, + }}); + Object.assign(runtime, geometry); + + geometry.applyTensorLayerData(); + + const selectedOpenZIndex = elementMap.get("scene-step-ab_open").data("zIndex"); + const frontTensorZIndex = elementMap.get("tensor_front").data("zIndex"); + + if (!(selectedOpenZIndex > frontTensorZIndex)) {{ + throw new Error( + `A selected derived contraction tensor should keep its open ports visible above front tensors: open=${{selectedOpenZIndex}}, front=${{frontTensorZIndex}}.` + ); + }} + """ + ), + encoding="utf-8", + ) + return script_path + + +def _write_contraction_scene_base_tensor_port_layering_runtime_regression_script( + tmp_path: Path, +) -> Path: + script_path = ( + tmp_path / "contraction_scene_base_tensor_port_layering_runtime_regression.mjs" + ) + _copy_runtime_bundle( + tmp_path, + { + "state.runtime.mjs": "state/state.js", + "utilities.runtime.mjs": "utils/utilities.js", + "historySelection.runtime.mjs": "graph/historySelection.js", + "contractionScene.runtime.mjs": "graph/contractionScene.js", + }, + _RUNTIME_EDITOR_SUPPORT_MODULES, + ) + script_path.write_text( + textwrap.dedent( + """ + import { pathToFileURL } from "node:url"; + + function createClassList() { + return { + add() {}, + remove() {}, + toggle() {}, + }; + } + + function createButton() { + return { + disabled: false, + classList: createClassList(), + addEventListener() {}, + focus() {}, + }; + } + + function createSpec() { + return { + id: "network_manual_anchor", + name: "manual-anchor", + tensors: [ + { + id: "tensor_a", + name: "A", + position: { x: 120, y: 140 }, + size: { width: 140, height: 84 }, + metadata: {}, + indices: [ + { + id: "tensor_a_left", + name: "left", + dimension: 2, + offset: { x: -38, y: 0 }, + metadata: {}, + }, + { + id: "tensor_a_bond", + name: "bond", + dimension: 3, + offset: { x: 38, y: 0 }, + metadata: {}, + }, + ], + }, + { + id: "tensor_b", + name: "B", + position: { x: 360, y: 220 }, + size: { width: 140, height: 84 }, + metadata: {}, + indices: [ + { + id: "tensor_b_bond", + name: "bond", + dimension: 3, + offset: { x: -38, y: 0 }, + metadata: {}, + }, + { + id: "tensor_b_right", + name: "carry", + dimension: 5, + offset: { x: 38, y: 0 }, + metadata: {}, + }, + ], + }, + { + id: "tensor_c", + name: "C", + position: { x: 620, y: 300 }, + size: { width: 140, height: 84 }, + metadata: {}, + indices: [ + { + id: "tensor_c_left", + name: "carry", + dimension: 5, + offset: { x: -38, y: 0 }, + metadata: {}, + }, + { + id: "tensor_c_right", + name: "right", + dimension: 7, + offset: { x: 38, y: 0 }, + metadata: {}, + }, + ], + }, + ], + groups: [], + edges: [ + { + id: "edge_ab", + name: "bond_ab", + left: { tensor_id: "tensor_a", index_id: "tensor_a_bond" }, + right: { tensor_id: "tensor_b", index_id: "tensor_b_bond" }, + metadata: {}, + }, + { + id: "edge_bc", + name: "bond_bc", + left: { tensor_id: "tensor_b", index_id: "tensor_b_right" }, + right: { tensor_id: "tensor_c", index_id: "tensor_c_left" }, + metadata: {}, + }, + ], + notes: [], + contraction_plan: { + id: "plan_chain", + name: "Chain path", + steps: [ + { + id: "step_ab", + left_operand_id: "tensor_a", + right_operand_id: "tensor_b", + }, + ], + }, + metadata: {}, + }; + } + + function createFakeElement(id) { + let zIndex = null; + const classes = new Set(); + let selected = false; + return { + length: 1, + id() { + return id; + }, + data(name, value) { + if (value === undefined) { + return name === "zIndex" ? zIndex : undefined; + } + if (name === "zIndex") { + zIndex = value; + } + return undefined; + }, + select() { + selected = true; + }, + unselect() { + selected = false; + }, + addClass(className) { + classes.add(className); + }, + removeClass(className) { + classes.delete(className); + }, + hasClass(className) { + return classes.has(className); + }, + isSelected() { + return selected; + }, + position() {}, + selectable() {}, + grabbable() {}, + }; + } + + const baseUrl = new URL("./", import.meta.url); + const [stateModule, utilitiesModule, historyModule, contractionSceneModule] = + await Promise.all([ + import(new URL("./state.runtime.mjs", baseUrl).href), + import(new URL("./utilities.runtime.mjs", baseUrl).href), + import(new URL("./historySelection.runtime.mjs", baseUrl).href), + import(new URL("./contractionScene.runtime.mjs", baseUrl).href), + ]); + + const { createInitialState } = stateModule; + const { registerUtilities } = utilitiesModule; + const { registerHistorySelection } = historyModule; + const { registerContractionScene } = contractionSceneModule; + + const ctx = { + state: createInitialState(), + constants: { + TENSOR_WIDTH: 140, + TENSOR_HEIGHT: 84, + MIN_TENSOR_WIDTH: 96, + MIN_TENSOR_HEIGHT: 60, + INDEX_RADIUS: 10, + INDEX_PADDING: 6, + NOTE_WIDTH: 220, + NOTE_HEIGHT: 120, + NOTE_MIN_WIDTH: 120, + NOTE_MIN_HEIGHT: 90, + HISTORY_LIMIT: 100, + REDO_SHORTCUT_LABEL: "Ctrl+Shift+Z", + DEFAULT_INDEX_SLOTS: [ + { x: -38, y: 0 }, + { x: 38, y: 0 }, + { x: 0, y: -24 }, + { x: 0, y: 24 }, + ], + }, + dom: { + workspace: {}, + statusMessage: { textContent: "", classList: createClassList() }, + propertiesPanel: { innerHTML: "" }, + generatedCode: { value: "" }, + engineSelect: { options: [], value: "tensornetwork" }, + collectionFormatSelect: { options: [], value: "list" }, + exportFormatSelect: { value: "py" }, + addNoteButton: createButton(), + connectButton: { classList: createClassList() }, + loadInput: {}, + undoButton: createButton(), + redoButton: createButton(), + exportButton: createButton(), + toggleLinearPeriodicButton: { classList: createClassList() }, + linearPeriodicPreviousCellButton: createButton(), + linearPeriodicCellLabel: { textContent: "" }, + linearPeriodicNextCellButton: createButton(), + templateSelect: { value: "" }, + templateParameterPanel: { hidden: true }, + templateGraphSizeLabel: { textContent: "" }, + templateGraphSizeInput: { value: "2", min: "1" }, + templateBondDimensionInput: { value: "3", min: "1" }, + templatePhysicalDimensionInput: { value: "2", min: "1" }, + insertTemplateButton: createButton(), + createGroupButton: createButton(), + helpButton: createButton(), + helpModal: { classList: createClassList() }, + helpBackdrop: createButton(), + helpCloseButton: createButton(), + canvasShell: { + getBoundingClientRect() { + return { left: 0, top: 0, width: 1000, height: 800 }; + }, + }, + groupLayer: {}, + resizeLayer: {}, + notesLayer: {}, + selectionBox: {}, + minimapCanvas: {}, + sidebar: {}, + plannerPanel: { + innerHTML: "", + querySelectorAll() { + return []; + }, + }, + generateButton: createButton(), + }, + apiGet: async () => null, + apiPost: async () => null, + window: { + structuredClone: globalThis.structuredClone, + crypto: globalThis.crypto, + setTimeout, + clearTimeout, + confirm: () => true, + }, + document: { + activeElement: null, + createElement() { + return { + value: "", + textContent: "", + selected: false, + appendChild() {}, + click() {}, + }; + }, + getElementById() { + return createButton(); + }, + querySelectorAll() { + return []; + }, + }, + cytoscape: null, + tensorWidth: (tensor) => tensor?.size?.width ?? 140, + tensorHeight: (tensor) => tensor?.size?.height ?? 84, + render: () => {}, + renderOverlayDecorations: () => {}, + renderMinimap: () => {}, + renderPlanner: () => {}, + renderSidebarTabs: () => {}, + refreshContractionAnalysis: () => {}, + syncPendingInteractionClasses: () => {}, + setActiveSidebarTab: () => {}, + updateToolbarState: () => {}, + captureEditableFocus: () => null, + restoreEditableFocus: () => {}, + }; + + registerUtilities(ctx); + registerContractionScene(ctx); + registerHistorySelection(ctx); + + ctx.state.selectedEngine = "tensornetwork"; + ctx.state.selectedCollectionFormat = "list"; + ctx.state.spec = ctx.normalizeSpec(createSpec()); + + const scene = ctx.buildContractionScene(); + if (!scene) { + throw new Error("Expected a contraction scene after the manual step."); + } + const elementMap = new Map(); + scene.tensors.forEach((tensor) => { + elementMap.set(tensor.id, createFakeElement(tensor.id)); + tensor.indices.forEach((index) => { + elementMap.set(index.id, createFakeElement(index.id)); + elementMap.set(`${index.id}__label`, createFakeElement(`${index.id}__label`)); + }); + }); + + ctx.state.cy = { + batch(action) { + action(); + }, + getElementById(id) { + return ( + elementMap.get(id) || { + length: 0, + data() { + return undefined; + }, + select() {}, + unselect() {}, + addClass() {}, + removeClass() {}, + position() {}, + selectable() {}, + grabbable() {}, + } + ); + }, + edges() { + return { + forEach() {}, + }; + }, + $(selector) { + if (selector === ":selected") { + return { + forEach(callback) { + elementMap.forEach((element) => { + if (element.isSelected()) { + callback(element); + } + }); + }, + }; + } + if (selector === ".is-selection-highlight") { + return { + forEach(callback) { + elementMap.forEach((element) => { + if (element.hasClass("is-selection-highlight")) { + callback(element); + } + }); + }, + }; + } + return { + forEach() {}, + }; + }, + }; + + ctx.bringTensorToFront("tensor_c"); + ctx.setSelection(["tensor_c"], { primaryId: "tensor_c" }); - if (!(zIndexFor("tensor_back") < zIndexFor("back_open"))) {{ - throw new Error("An open port should still sit above its owning tensor."); - }} - if (!(zIndexFor("back_open") < frontTensorZIndex)) {{ - throw new Error( - `An open port from a rear tensor should not cover a front tensor: open=${{zIndexFor("back_open")}}, front=${{frontTensorZIndex}}.` - ); - }} - if (!(zIndexFor("back_connected") > frontTensorZIndex)) {{ + const selectedBaseTensor = scene.operandMap.tensor_c; + const selectedOpenPort = selectedBaseTensor.indices.find( + (index) => index.name === "right" + ); + const derivedTensor = scene.tensors.find((tensor) => tensor.isDerived); + + if (!ctx.state.tensorOrder.includes(derivedTensor.id)) { throw new Error( - `A connected port should stay above tensors so connections remain visible: connected=${{zIndexFor("back_connected")}}, front=${{frontTensorZIndex}}.` + `Expected tensor layering order to track visible contraction operands, received ${JSON.stringify(ctx.state.tensorOrder)}.` ); - }} - - state.selectionIds = ["tensor_back"]; - const selectedModel = builder(); - const selectedZIndexFor = (elementId) => - selectedModel.descriptorsById[elementId].data.zIndex; - if (!(selectedZIndexFor("back_open") > selectedZIndexFor("tensor_front"))) {{ + } + const selectedOpenPortZIndex = elementMap + .get(selectedOpenPort.id) + .data("zIndex"); + const derivedTensorZIndex = elementMap.get(derivedTensor.id).data("zIndex"); + if (!(selectedOpenPortZIndex > derivedTensorZIndex)) { throw new Error( - `A selected tensor should keep its open ports visible above front tensors: open=${{selectedZIndexFor("back_open")}}, front=${{selectedZIndexFor("tensor_front")}}.` + `A selected base tensor in contraction view should keep its free port visible above derived front tensors: open=${selectedOpenPortZIndex}, derived=${derivedTensorZIndex}.` ); - }} + } """ ), encoding="utf-8", @@ -9574,6 +10364,15 @@ def _write_metadata_properties_runtime_regression_script(tmp_path: Path) -> Path if (!propertiesPanel.innerHTML.includes('id="add-index-to-selection-button"')) { throw new Error("Mixed selections should keep the bulk Add index action when editable tensors remain."); } + if (/id="extract-selection-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) { + throw new Error("Mixed selections with editable tensors should keep Extract enabled."); + } + if (/id="save-selection-subnetwork-library-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) { + throw new Error("Mixed selections with editable tensors should keep To Library enabled."); + } + if (/id="promote-selection-template-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) { + throw new Error("Mixed selections with editable tensors should keep To Template enabled."); + } document.getElementById("add-index-to-selection-button").click(); const editableTensorAfter = ctx.state.spec.tensors.find( (candidate) => candidate.id === "tensor_a" @@ -11602,6 +12401,52 @@ def test_graph_model_layers_open_ports_below_front_tensors( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_contraction_scene_selection_keeps_derived_open_ports_visible( + tmp_path: Path, +) -> None: + script_path = _write_contraction_scene_port_layering_runtime_regression_script( + tmp_path + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The contraction-scene port layering runtime regression script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_contraction_scene_selection_keeps_base_tensor_open_ports_visible( + tmp_path: Path, +) -> None: + script_path = ( + _write_contraction_scene_base_tensor_port_layering_runtime_regression_script( + tmp_path + ) + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The contraction-scene base-tensor port layering runtime regression script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_copy_shortcut_prefers_native_text_selection_over_graph_copy( tmp_path: Path, @@ -12120,11 +12965,9 @@ def _write_utility_runtime_contract_script(tmp_path: Path) -> Path: }), parentElement: reflowLayoutShell, }, - reflowAlignLeftButton: createButton(), - reflowAlignRightButton: createButton(), - reflowAlignTopButton: createButton(), - reflowAlignMiddleButton: createButton(), - reflowAlignBottomButton: createButton(), + reflowAlignHorizontalButton: createButton(), + reflowAlignVerticalButton: createButton(), + reflowRotateSelectionButton: createButton(), reflowIndicesLeftButton: createButton(), reflowIndicesRightButton: createButton(), reflowIndicesTopButton: createButton(), @@ -12361,6 +13204,12 @@ def _write_utility_runtime_contract_script(tmp_path: Path) -> Path: if (ctx.dom.reflowAutoLayoutButton.disabled) { throw new Error("Auto layout should stay enabled when the whole graph can be arranged."); } + runtime.isLinearPeriodicMode = () => true; + runtime.updateToolbarState(); + if (ctx.dom.templateSettingsButton.disabled) { + throw new Error("Template settings should stay enabled in For mode because they only affect future insertions."); + } + runtime.isLinearPeriodicMode = () => false; ctx.state.selectionIds = ["tensor_a"]; runtime.isBenchmarkMode = () => true; runtime.getBenchmarkSession = () => ({ @@ -13236,6 +14085,7 @@ def _write_interaction_session_dependency_injection_runtime_script( generatedCode: "", selectedEngine: "quimb", selectedCollectionFormat: "dict", + includeRoundtripMetadata: true, templateDefinitions: {}, availableTemplates: [], templateCatalogWarnings: [], @@ -13332,6 +14182,9 @@ def _write_interaction_session_dependency_injection_runtime_script( if (generateCall.payload.engine !== "quimb" || generateCall.payload.collectionFormat !== "dict") { throw new Error(`Unexpected generate payload: ${JSON.stringify(generateCall.payload)}.`); } + if (generateCall.payload.includeRoundtripMetadata !== true) { + throw new Error(`Expected includeRoundtripMetadata=true in the injected generate payload, received ${JSON.stringify(generateCall.payload)}.`); + } if (dom.generatedCode.value.trim() !== "result = 1") { throw new Error(`Expected injected preview sync to receive stripped code, received ${dom.generatedCode.value}.`); } @@ -13391,6 +14244,7 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path: spec: { name: "draft demo" }, generatedCode: "", editorFinished: false, + selectedTheme: "light", draftAutosaveReady: true, draftAutosaveTimer: null, draftAutosaveDirty: false, @@ -13575,8 +14429,7 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path: throw new Error(`Expected draft-save flow logging, received ${JSON.stringify(flowLog)}.`); } - flows.saveDesign(); - await Promise.resolve(); + await flows.saveDesign(); if (!calls.some((entry) => entry.type === "clearDraft")) { throw new Error(`Expected explicit JSON save to clear the draft, received ${JSON.stringify(calls)}.`); } @@ -13649,6 +14502,13 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path: ) { throw new Error(`Academic exports should persist view snapshots, received ${JSON.stringify(calls)}.`); } + if ( + svgRenderCall.payload.theme !== "light" || + pngRenderCall.payload.theme !== "light" || + pdfRenderCall.payload.theme !== "light" + ) { + throw new Error(`SVG/PNG/PDF exports should include the active theme, received ${JSON.stringify(calls)}.`); + } if (!svgDownloadCall || svgDownloadCall.contentType !== "image/svg+xml;charset=utf-8") { throw new Error(`Expected SVG export to download a .svg file, received ${JSON.stringify(calls)}.`); } @@ -13885,6 +14745,130 @@ def _write_session_editor_png_fallback_runtime_script(tmp_path: Path) -> Path: return script_path +def _write_session_editor_save_cancelled_runtime_script(tmp_path: Path) -> Path: + script_path = tmp_path / "session_editor_save_cancelled.mjs" + _copy_js_modules(tmp_path, _SESSION_EDITOR_FLOWS_DEPENDENCY_MODULES) + + script_path.write_text( + textwrap.dedent( + """ + const baseUrl = new URL("./", import.meta.url); + const { createSessionEditorFlows } = await import( + new URL("./session/sessionEditorFlows.js", baseUrl).href + ); + + const calls = []; + const flowLog = []; + const state = { + spec: { name: "draft demo" }, + generatedCode: "", + editorFinished: false, + draftAutosaveReady: true, + draftAutosaveTimer: null, + draftAutosaveDirty: false, + draftAutosaveSaving: false, + }; + + const flows = createSessionEditorFlows({ + dom: { + exportFormatSelect: { value: "json" }, + generatedCode: { value: "" }, + loadInput: { value: "" }, + }, + state, + logger: { + startOperation(name, context = {}) { + flowLog.push({ type: "start", name, context }); + return { + finish(nextContext = {}) { + flowLog.push({ type: "finish", name, context: nextContext }); + }, + fail(error, nextContext = {}) { + flowLog.push({ + type: "fail", + name, + message: error.message, + context: nextContext, + }); + }, + }; + }, + }, + store: { + setGeneratedCode() {}, + setEditorFinished() {}, + }, + selectors: { + getSelectedEngine: () => "quimb", + getSelectedCollectionFormat: () => "dict", + }, + services: { + session: { + async clearDraft() { + calls.push({ type: "clearDraft" }); + return { ok: true }; + }, + }, + }, + commands: { + syncGeneratedCodePreview() {}, + }, + sessionUi: { + async downloadText(filename, text, contentType) { + calls.push({ type: "downloadText", filename, text, contentType }); + return false; + }, + closeWindow() {}, + schedule() { + return 0; + }, + }, + actions: { + serializeCurrentSpec({ persistViewSnapshots }) { + return { + schema_version: 2, + persistViewSnapshots, + network: { id: "network_draft", name: "draft demo" }, + }; + }, + sanitizeFilename: (value) => value.replace(/\\s+/g, "_"), + setStatus(message, level = "info") { + calls.push({ type: "status", message, level }); + }, + }, + }); + + await flows.saveDesign(); + + if (calls.some((entry) => entry.type === "clearDraft")) { + throw new Error(`Cancelling the save dialog should not clear the draft, received ${JSON.stringify(calls)}.`); + } + const cancelStatus = calls.find( + (entry) => + entry.type === "status" && + entry.level === "info" && + entry.message === "Design save cancelled." + ); + if (!cancelStatus) { + throw new Error(`Expected a friendly cancellation status, received ${JSON.stringify(calls)}.`); + } + if ( + !flowLog.some( + (entry) => + entry.type === "finish" && + entry.name === "Save design" && + entry.context.outcome === "cancelled" + ) + ) { + throw new Error(`Expected cancelled save-design flow logging, received ${JSON.stringify(flowLog)}.`); + } + """ + ), + encoding="utf-8", + ) + return script_path + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_session_editor_flows_fall_back_to_svg_when_png_render_fails( tmp_path: Path, @@ -13905,6 +14889,26 @@ def test_session_editor_flows_fall_back_to_svg_when_png_render_fails( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_editor_flows_report_save_cancelled_without_clearing_draft( + tmp_path: Path, +) -> None: + script_path = _write_session_editor_save_cancelled_runtime_script(tmp_path) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The session-editor save-cancelled runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + def _write_tensor_initializer_parsing_runtime_script(tmp_path: Path) -> Path: script_path = tmp_path / "tensor_initializer_parsing.mjs" _copy_js_modules( @@ -14176,6 +15180,12 @@ def _write_session_editor_live_python_import_runtime_script(tmp_path: Path) -> P if (confirmMessages.length !== 3) { throw new Error(`Expected the Python load flow to ask about live execution every time, received ${JSON.stringify(confirmMessages)}.`); } + if (!confirmMessages.every((message) => message.includes("Only continue for local Python files you trust"))) { + throw new Error(`Expected every live-import prompt to warn about trusted local files, received ${JSON.stringify(confirmMessages)}.`); + } + if (!confirmMessages.every((message) => message.includes("can read and write files"))) { + throw new Error(`Expected every live-import prompt to describe local execution risk, received ${JSON.stringify(confirmMessages)}.`); + } if (promptMessages.length !== 2) { throw new Error(`Expected object-name prompts only for live imports, received ${JSON.stringify(promptMessages)}.`); } @@ -14257,6 +15267,7 @@ def _write_editor_session_service_validate_python_runtime_script( await service.renderSpec({ format: "dot", spec: { schema_version: 2, network: { id: "network_draft" } }, + theme: "light", }); await service.clearDraft(); @@ -14305,6 +15316,9 @@ def _write_editor_session_service_validate_python_runtime_script( if (apiCalls[4].payload.format !== "dot" || apiCalls[4].payload.spec.network.id !== "network_draft") { throw new Error(`Expected renderSpec to keep format and spec payloads, received ${JSON.stringify(apiCalls[4])}.`); } + if (apiCalls[4].payload.theme !== "light") { + throw new Error(`Expected renderSpec to include the current theme, received ${JSON.stringify(apiCalls[4])}.`); + } if (apiCalls[5].path !== "/api/draft/clear" || apiCalls[5].method !== "POST") { throw new Error(`Expected clearDraft to POST /api/draft/clear, received ${JSON.stringify(apiCalls[5])}.`); } @@ -15069,6 +16083,93 @@ def _write_layout_subnetwork_runtime_regression_script(tmp_path: Path) -> Path: } } + ctx.state.selectionIds = ["tensor_a", "tensor_b"]; + ctx.state.primarySelectionId = "tensor_b"; + ctx.state.spec.tensors[0].position = { x: 100, y: 100 }; + ctx.state.spec.tensors[1].position = { x: 260, y: 220 }; + ctx.state.spec.tensors[0].indices[0].offset = { x: 20, y: -10 }; + ctx.state.spec.tensors[1].indices[0].offset = { x: 16, y: -8 }; + ctx.state.spec.tensors[1].indices[1].offset = { x: 20, y: 10 }; + ctx.applyReflowLayoutAction("align-horizontal"); + const horizontalAlignmentYs = ctx.state.spec.tensors + .slice(0, 2) + .map((tensor) => tensor.position.y); + if (!horizontalAlignmentYs.every((value) => value === horizontalAlignmentYs[0])) { + throw new Error( + `Horizontal alignment should align tensor centers on the y axis, received ${horizontalAlignmentYs.join(", ")}.` + ); + } + ctx.applyReflowLayoutAction("align-vertical"); + const verticalAlignmentXs = ctx.state.spec.tensors + .slice(0, 2) + .map((tensor) => tensor.position.x); + if (!verticalAlignmentXs.every((value) => value === verticalAlignmentXs[0])) { + throw new Error( + `Vertical alignment should align tensor centers on the x axis, received ${verticalAlignmentXs.join(", ")}.` + ); + } + + ctx.state.spec.tensors[0].position = { x: 100, y: 100 }; + ctx.state.spec.tensors[1].position = { x: 260, y: 220 }; + ctx.state.spec.tensors[0].indices[0].offset = { x: 20, y: -10 }; + ctx.state.spec.tensors[1].indices[0].offset = { x: 16, y: -8 }; + ctx.state.spec.tensors[1].indices[1].offset = { x: 20, y: 10 }; + ctx.serializeCurrentSpec(); + ctx.applyReflowLayoutAction("rotate-90"); + const tensorARotated = ctx.findTensorById("tensor_a"); + const tensorBRotated = ctx.findTensorById("tensor_b"); + if (tensorARotated.position.x !== 240 || tensorARotated.position.y !== 80) { + throw new Error( + `Rotate 90 should move tensor A clockwise around the selection center, received ${JSON.stringify(tensorARotated.position)}.` + ); + } + if (tensorBRotated.position.x !== 120 || tensorBRotated.position.y !== 240) { + throw new Error( + `Rotate 90 should move tensor B clockwise around the selection center, received ${JSON.stringify(tensorBRotated.position)}.` + ); + } + if (JSON.stringify(tensorARotated.indices[0].offset) !== JSON.stringify({ x: 10, y: 20 })) { + throw new Error( + `Rotate 90 should rotate tensor A ports, received ${JSON.stringify(tensorARotated.indices[0].offset)}.` + ); + } + if (JSON.stringify(tensorBRotated.indices[0].offset) !== JSON.stringify({ x: 8, y: 16 })) { + throw new Error( + `Rotate 90 should rotate tensor B first port, received ${JSON.stringify(tensorBRotated.indices[0].offset)}.` + ); + } + if (JSON.stringify(tensorBRotated.indices[1].offset) !== JSON.stringify({ x: -10, y: 20 })) { + throw new Error( + `Rotate 90 should rotate tensor B second port, received ${JSON.stringify(tensorBRotated.indices[1].offset)}.` + ); + } + if (ctx.state.selectionIds.join(",") !== "tensor_a,tensor_b") { + throw new Error("Rotate 90 should preserve the selected tensors."); + } + const serializedAfterRotate = ctx.serializeCurrentSpec(); + const serializedTensorA = serializedAfterRotate.network.tensors.find( + (tensor) => tensor.id === "tensor_a" + ); + const serializedTensorB = serializedAfterRotate.network.tensors.find( + (tensor) => tensor.id === "tensor_b" + ); + if ( + serializedTensorA.position.x !== 240 || serializedTensorA.position.y !== 80 + ) { + throw new Error( + `serializeCurrentSpec should invalidate its cache after layout changes for tensor A, received ${JSON.stringify(serializedTensorA.position)}.` + ); + } + if ( + serializedTensorB.position.x !== 120 || serializedTensorB.position.y !== 240 + ) { + throw new Error( + `serializeCurrentSpec should invalidate its cache after layout changes for tensor B, received ${JSON.stringify(serializedTensorB.position)}.` + ); + } + + ctx.state.selectionIds = ["tensor_a", "tensor_b", "tensor_c"]; + ctx.state.primarySelectionId = "tensor_c"; ctx.state.spec.tensors[0].position.x = 100; ctx.state.spec.tensors[1].position.x = 260; ctx.state.spec.tensors[2].position.x = 460; diff --git a/tests/test_models_validation.py b/tests/test_models_validation.py index ee76551..183270d 100644 --- a/tests/test_models_validation.py +++ b/tests/test_models_validation.py @@ -836,6 +836,39 @@ def test_validate_spec_accepts_linear_periodic_partial_carry_chain() -> None: assert validate_spec(build_linear_periodic_partial_carry_chain_spec()) == [] +def test_validate_spec_rejects_linear_periodic_previous_step_that_merges_multiple_payload_operands() -> ( + None +): + spec = build_linear_periodic_partial_carry_chain_spec() + assert spec.linear_periodic_chain is not None + periodic_cell = spec.linear_periodic_chain.periodic_cell + assert periodic_cell.contraction_plan is not None + periodic_cell.contraction_plan.steps = [ + ContractionStepSpec( + id="merge_previous_locals", + left_operand_id="periodic_previous_left_tensor", + right_operand_id="periodic_previous_right_tensor", + ), + ContractionStepSpec( + id="consume_previous_payload", + left_operand_id="__linear_previous__", + right_operand_id="merge_previous_locals", + ), + ContractionStepSpec( + id="carry_next_left", + left_operand_id="periodic_next_left_tensor", + right_operand_id="__linear_next__", + ), + ] + + issue = find_issue(validate_spec(spec), "linear-periodic-carry-codegen") + + assert issue.path == ( + "linear_periodic_chain.periodic_cell.contraction_plan.steps.consume_previous_payload" + ) + assert "one previous carry operand per step" in issue.message + + def test_build_carry_validation_context_internal_helper_collects_interface_state() -> ( None ): diff --git a/tests/test_packaging.py b/tests/test_packaging.py index db4bc28..5a9b1ad 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -2,6 +2,7 @@ import json import os +import re import subprocess import sys import tomllib @@ -103,6 +104,99 @@ def test_project_metadata_declares_required_matplotlib_dependency_and_backend_ex assert "png" not in optional_dependencies +def test_project_metadata_and_ci_enable_dependency_audits() -> None: + pyproject_path = Path.cwd() / "pyproject.toml" + ci_path = Path.cwd() / ".github" / "workflows" / "ci.yml" + + payload = tomllib.loads(pyproject_path.read_text(encoding="utf-8")) + dev_dependencies = payload["project"]["optional-dependencies"]["dev"] + ci_text = ci_path.read_text(encoding="utf-8") + + assert "pip-audit>=2.7" in dev_dependencies + assert "Run dependency security audit" in ci_text + assert "-m pip_audit" in ci_text + + +def test_ci_runs_source_security_lint_and_dependabot_tracks_updates() -> None: + ci_text = (Path.cwd() / ".github" / "workflows" / "ci.yml").read_text( + encoding="utf-8" + ) + dependabot_text = (Path.cwd() / ".github" / "dependabot.yml").read_text( + encoding="utf-8" + ) + + assert "Run source security lint" in ci_text + assert "-m ruff check src --select S" in ci_text + assert 'package-ecosystem: "pip"' in dependabot_text + assert 'package-ecosystem: "github-actions"' in dependabot_text + assert 'directory: "/"' in dependabot_text + + +def test_bundled_prism_version_stays_patched_for_cve_2024_53382() -> None: + third_party_text = (Path.cwd() / "THIRD_PARTY_LICENSES").read_text(encoding="utf-8") + version_match = re.search( + r"2\. PrismJS[\s\S]*?- Version: (\d+)\.(\d+)\.(\d+)", + third_party_text, + ) + + assert version_match is not None + version = tuple(int(part) for part in version_match.groups()) + assert version >= (1, 30, 0) + + +def test_live_python_import_docs_warn_to_use_only_trusted_files() -> None: + expected_warning = "Only use live import with local Python files you trust." + docs_paths = [ + Path.cwd() / "README.md", + Path.cwd() / "docs" / "api.md", + Path.cwd() / "docs" / "cli.md", + Path.cwd() / "docs" / "extended_guide.md", + ] + + for docs_path in docs_paths: + docs_text = docs_path.read_text(encoding="utf-8") + assert expected_warning in docs_text + + +def test_security_policy_documents_private_reporting_and_prism_advisory() -> None: + security_text = (Path.cwd() / "SECURITY.md").read_text(encoding="utf-8") + readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8") + + assert "GitHub private vulnerability reporting" in security_text + assert "Do not open a public issue with exploit details" in security_text + assert "CVE-2024-53382" in security_text + assert "GHSA-x7hr-w5r2-h6wg" in security_text + assert "Bundled PrismJS before 1.30.0" in security_text + assert "browser-based editor" in security_text + assert ( + "Installing or importing the Python package alone does not execute PrismJS" + in (security_text) + ) + assert "publish the patched release before publishing the advisory" in security_text + assert "Security policy: [SECURITY.md](SECURITY.md)" in readme_text + + +def test_docs_do_not_advertise_removed_png_extra() -> None: + readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8") + installation_text = (Path.cwd() / "docs" / "installation.md").read_text( + encoding="utf-8" + ) + + assert "tensor-network-editor[png]" not in readme_text + assert "optional `png` extra" not in readme_text + assert "tensor-network-editor[png]" not in installation_text + + +def test_manifest_omits_redundant_non_package_exclusions() -> None: + manifest_text = (Path.cwd() / "MANIFEST.in").read_text(encoding="utf-8") + + assert "docs/images" not in manifest_text + assert "prune tests" not in manifest_text + assert "tests" not in manifest_text + assert "recursive-exclude docs/images *" not in manifest_text + assert "recursive-exclude tests *" not in manifest_text + + def test_third_party_notices_describe_bundled_asset_scope() -> None: third_party_text = (Path.cwd() / "THIRD_PARTY_LICENSES").read_text(encoding="utf-8") readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8") @@ -115,6 +209,9 @@ def test_third_party_notices_describe_bundled_asset_scope() -> None: assert "Runtime pip-installed dependencies are not bundled" in third_party_text assert "Package: Matplotlib" in third_party_text assert "License: Matplotlib license" in third_party_text + assert "Development dependency notice" in third_party_text + assert "Package: pip-audit" in third_party_text + assert "License: Apache Software License" in third_party_text assert "THIRD_PARTY_LICENSES" in readme_text diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f0daa36..cd64cec 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -88,6 +88,7 @@ def test_parse_codegen_request_uses_defaults_when_optional_fields_are_missing( serialized_spec=serialized_sample_spec, engine=EngineName.EINSUM_TORCH, collection_format=TensorCollectionFormat.DICT, + include_roundtrip_metadata=True, ) @@ -111,6 +112,7 @@ def test_parse_codegen_request_honors_explicit_engine_and_collection_format( serialized_spec=serialized_sample_spec, engine=EngineName.QUIMB, collection_format=TensorCollectionFormat.MATRIX, + include_roundtrip_metadata=True, ) diff --git a/tests/test_rendering.py b/tests/test_rendering.py index c684740..75001cc 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -1,17 +1,26 @@ from __future__ import annotations import re +from math import hypot from pathlib import Path from typing import Any from xml.etree import ElementTree as ET import pytest -from tensor_network_editor.models import NetworkSpec +from tensor_network_editor.models import ( + CanvasPosition, + EdgeEndpointRef, + EdgeSpec, + IndexSpec, + NetworkSpec, + TensorSpec, +) from tensor_network_editor.rendering import ( DotRenderOptions, SvgRenderOptions, TikzRenderOptions, + _number, _SvgRenderer, render_spec_dot, render_spec_mermaid, @@ -19,7 +28,12 @@ render_spec_svg, render_spec_tikz, ) -from tests.factories import build_sample_spec, build_three_tensor_hyperedge_spec +from tensor_network_editor.templates import TemplateParameters, build_template_spec +from tests.factories import ( + build_sample_spec, + build_three_tensor_hyperedge_spec, + build_three_tensor_spec, +) def _build_colored_parallel_edge_spec() -> NetworkSpec: @@ -107,6 +121,309 @@ def _build_three_parallel_edge_spec() -> NetworkSpec: return spec +def _build_cycle_spec() -> NetworkSpec: + return NetworkSpec( + id="network_cycle", + name="cycle", + tensors=[ + TensorSpec( + id="tensor_a", + name="A", + position=CanvasPosition(x=120.0, y=120.0), + indices=[ + IndexSpec(id="tensor_a_free", name="fa", dimension=2), + IndexSpec(id="tensor_a_ab", name="ab", dimension=3), + IndexSpec(id="tensor_a_da", name="da", dimension=5), + ], + ), + TensorSpec( + id="tensor_b", + name="B", + position=CanvasPosition(x=280.0, y=120.0), + indices=[ + IndexSpec(id="tensor_b_free", name="fb", dimension=2), + IndexSpec(id="tensor_b_ab", name="ab", dimension=3), + IndexSpec(id="tensor_b_bc", name="bc", dimension=7), + ], + ), + TensorSpec( + id="tensor_c", + name="C", + position=CanvasPosition(x=280.0, y=280.0), + indices=[ + IndexSpec(id="tensor_c_free", name="fc", dimension=2), + IndexSpec(id="tensor_c_bc", name="bc", dimension=7), + IndexSpec(id="tensor_c_cd", name="cd", dimension=11), + ], + ), + TensorSpec( + id="tensor_d", + name="D", + position=CanvasPosition(x=120.0, y=280.0), + indices=[ + IndexSpec(id="tensor_d_free", name="fd", dimension=2), + IndexSpec(id="tensor_d_cd", name="cd", dimension=11), + IndexSpec(id="tensor_d_da", name="da", dimension=5), + ], + ), + ], + edges=[ + EdgeSpec( + id="edge_ab", + name="ab", + left=EdgeEndpointRef(tensor_id="tensor_a", index_id="tensor_a_ab"), + right=EdgeEndpointRef(tensor_id="tensor_b", index_id="tensor_b_ab"), + ), + EdgeSpec( + id="edge_bc", + name="bc", + left=EdgeEndpointRef(tensor_id="tensor_b", index_id="tensor_b_bc"), + right=EdgeEndpointRef(tensor_id="tensor_c", index_id="tensor_c_bc"), + ), + EdgeSpec( + id="edge_cd", + name="cd", + left=EdgeEndpointRef(tensor_id="tensor_c", index_id="tensor_c_cd"), + right=EdgeEndpointRef(tensor_id="tensor_d", index_id="tensor_d_cd"), + ), + EdgeSpec( + id="edge_da", + name="da", + left=EdgeEndpointRef(tensor_id="tensor_d", index_id="tensor_d_da"), + right=EdgeEndpointRef(tensor_id="tensor_a", index_id="tensor_a_da"), + ), + ], + ) + + +def _build_grid_export_spec() -> NetworkSpec: + tensors: list[TensorSpec] = [] + edges: list[EdgeSpec] = [] + for row_index in range(3): + for column_index in range(3): + tensor_id = f"tensor_{row_index}_{column_index}" + indices = [ + IndexSpec( + id=f"{tensor_id}_free", + name=f"f_{row_index}_{column_index}", + dimension=2, + ) + ] + if column_index < 2: + indices.append( + IndexSpec( + id=f"{tensor_id}_right", + name=f"h_{row_index}_{column_index}", + dimension=3, + ) + ) + if column_index > 0: + indices.append( + IndexSpec( + id=f"{tensor_id}_left", + name=f"h_{row_index}_{column_index - 1}", + dimension=3, + ) + ) + if row_index < 2: + indices.append( + IndexSpec( + id=f"{tensor_id}_down", + name=f"v_{row_index}_{column_index}", + dimension=5, + ) + ) + if row_index > 0: + indices.append( + IndexSpec( + id=f"{tensor_id}_up", + name=f"v_{row_index - 1}_{column_index}", + dimension=5, + ) + ) + tensors.append( + TensorSpec( + id=tensor_id, + name=f"T{row_index}{column_index}", + position=CanvasPosition( + x=120.0 + 140.0 * column_index, + y=120.0 + 140.0 * row_index, + ), + indices=indices, + ) + ) + for row_index in range(3): + for column_index in range(2): + left_tensor_id = f"tensor_{row_index}_{column_index}" + right_tensor_id = f"tensor_{row_index}_{column_index + 1}" + edge_name = f"h_{row_index}_{column_index}" + edges.append( + EdgeSpec( + id=f"edge_{edge_name}", + name=edge_name, + left=EdgeEndpointRef( + tensor_id=left_tensor_id, + index_id=f"{left_tensor_id}_right", + ), + right=EdgeEndpointRef( + tensor_id=right_tensor_id, + index_id=f"{right_tensor_id}_left", + ), + ) + ) + for row_index in range(2): + for column_index in range(3): + top_tensor_id = f"tensor_{row_index}_{column_index}" + bottom_tensor_id = f"tensor_{row_index + 1}_{column_index}" + edge_name = f"v_{row_index}_{column_index}" + edges.append( + EdgeSpec( + id=f"edge_{edge_name}", + name=edge_name, + left=EdgeEndpointRef( + tensor_id=top_tensor_id, + index_id=f"{top_tensor_id}_down", + ), + right=EdgeEndpointRef( + tensor_id=bottom_tensor_id, + index_id=f"{bottom_tensor_id}_up", + ), + ) + ) + return NetworkSpec( + id="network_grid_export", + name="grid-export", + tensors=tensors, + edges=edges, + ) + + +def _build_vertical_three_tensor_spec() -> NetworkSpec: + spec = build_three_tensor_spec() + spec.tensors[0].position = CanvasPosition(x=240.0, y=80.0) + spec.tensors[1].position = CanvasPosition(x=240.0, y=240.0) + spec.tensors[2].position = CanvasPosition(x=240.0, y=400.0) + return spec + + +def _build_vertical_three_tensor_named_hint_spec() -> NetworkSpec: + spec = _build_vertical_three_tensor_spec() + spec.tensors[0].indices[0].name = "up" + return spec + + +def _build_diagonal_three_tensor_spec() -> NetworkSpec: + spec = build_three_tensor_spec() + spec.tensors[0].position = CanvasPosition(x=80.0, y=80.0) + spec.tensors[1].position = CanvasPosition(x=240.0, y=240.0) + spec.tensors[2].position = CanvasPosition(x=400.0, y=400.0) + return spec + + +def _build_rotated_grid_export_spec() -> NetworkSpec: + spec = _build_grid_export_spec() + center = CanvasPosition(x=240.0, y=240.0) + column_step = CanvasPosition(x=100.0, y=100.0) + row_step = CanvasPosition(x=-100.0, y=100.0) + for tensor in spec.tensors: + _, row_text, column_text = tensor.id.split("_") + row_index = int(row_text) + column_index = int(column_text) + tensor.position = CanvasPosition( + x=center.x + + (column_index - 1) * column_step.x + + (row_index - 1) * row_step.x, + y=center.y + + (column_index - 1) * column_step.y + + (row_index - 1) * row_step.y, + ) + return spec + + +def _build_vertical_mpo_export_spec() -> NetworkSpec: + spec = build_template_spec( + "mpo", + TemplateParameters( + graph_size=4, + bond_dimension=3, + physical_dimension=2, + boundary_condition="open", + j=1.0, + h=1.0, + ), + ) + for tensor_index, tensor in enumerate(spec.tensors): + tensor.position = CanvasPosition(x=240.0, y=80.0 + tensor_index * 160.0) + return spec + + +def _build_generic_export_spec() -> NetworkSpec: + return NetworkSpec( + id="network_generic_export", + name="generic-export", + tensors=[ + TensorSpec( + id="tensor_center", + name="Center", + position=CanvasPosition(x=220.0, y=200.0), + indices=[ + IndexSpec(id="tensor_center_free", name="free", dimension=2), + IndexSpec(id="tensor_center_right", name="r", dimension=3), + IndexSpec(id="tensor_center_down", name="d", dimension=5), + ], + ), + TensorSpec( + id="tensor_right", + name="Right", + position=CanvasPosition(x=360.0, y=180.0), + indices=[ + IndexSpec(id="tensor_right_left", name="r", dimension=3), + ], + ), + TensorSpec( + id="tensor_down", + name="Down", + position=CanvasPosition(x=260.0, y=340.0), + indices=[ + IndexSpec(id="tensor_down_up", name="d", dimension=5), + ], + ), + ], + edges=[ + EdgeSpec( + id="edge_center_right", + name="r", + left=EdgeEndpointRef( + tensor_id="tensor_center", index_id="tensor_center_right" + ), + right=EdgeEndpointRef( + tensor_id="tensor_right", index_id="tensor_right_left" + ), + ), + EdgeSpec( + id="edge_center_down", + name="d", + left=EdgeEndpointRef( + tensor_id="tensor_center", index_id="tensor_center_down" + ), + right=EdgeEndpointRef( + tensor_id="tensor_down", index_id="tensor_down_up" + ), + ), + ], + ) + + +def _dot(left: CanvasPosition, right: CanvasPosition) -> float: + return left.x * right.x + left.y * right.y + + +def _normalize(vector: CanvasPosition) -> CanvasPosition: + magnitude = hypot(vector.x, vector.y) + assert magnitude > 1e-9 + return CanvasPosition(x=vector.x / magnitude, y=vector.y / magnitude) + + def _svg_text_content(svg: str) -> list[str]: root = ET.fromstring(svg) text_nodes = root.findall(".//{http://www.w3.org/2000/svg}text") @@ -166,6 +483,146 @@ def test_academic_svg_and_tikz_exports_use_tensor_circles_and_dangling_ports() - assert r"\draw[tne open index]" in tikz +def test_export_geometry_prefers_perpendicular_free_index_directions_for_linear_chain() -> ( + None +): + spec = build_three_tensor_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + left_tensor = spec.tensors[0] + left_index = left_tensor.indices[0] + direction = renderer._index_direction(left_tensor, left_index) + source = renderer.connection_point(left_tensor, left_index) + target = renderer.open_index_endpoint(left_tensor, left_index) + + assert abs(direction.x) < 0.25 + assert abs(direction.y) > 0.9 + assert hypot(target.x - source.x, target.y - source.y) == pytest.approx( + 2.0 * renderer.tensor_radius(left_tensor) + ) + + +def test_export_geometry_respects_vertical_linear_chain_orientation() -> None: + spec = _build_vertical_three_tensor_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + first_tensor = spec.tensors[0] + free_index = first_tensor.indices[0] + direction = renderer._index_direction(first_tensor, free_index) + + assert abs(direction.x) > 0.9 + assert abs(direction.y) < 0.25 + + +def test_export_geometry_prefers_linear_component_orientation_over_named_hints() -> ( + None +): + spec = _build_vertical_three_tensor_named_hint_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + first_tensor = spec.tensors[0] + free_index = first_tensor.indices[0] + direction = renderer._index_direction(first_tensor, free_index) + + assert abs(direction.x) > 0.9 + assert abs(direction.y) < 0.25 + + +def test_export_geometry_respects_diagonal_linear_chain_orientation() -> None: + spec = _build_diagonal_three_tensor_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + first_tensor = spec.tensors[0] + free_index = first_tensor.indices[0] + direction = renderer._index_direction(first_tensor, free_index) + chain_axis = _normalize(CanvasPosition(x=1.0, y=1.0)) + diagonal_perpendicular = _normalize(CanvasPosition(x=-1.0, y=1.0)) + + assert abs(_dot(direction, chain_axis)) < 0.25 + assert abs(_dot(direction, diagonal_perpendicular)) > 0.9 + + +def test_export_geometry_prefers_vertical_mpo_component_orientation_over_index_names() -> ( + None +): + spec = _build_vertical_mpo_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + first_tensor = spec.tensors[0] + bra_index = next(index for index in first_tensor.indices if index.name == "bra") + ket_index = next(index for index in first_tensor.indices if index.name == "ket") + bra_direction = renderer._index_direction(first_tensor, bra_index) + ket_direction = renderer._index_direction(first_tensor, ket_index) + + assert abs(bra_direction.x) > 0.9 + assert abs(ket_direction.x) > 0.9 + assert abs(bra_direction.y) < 0.25 + assert abs(ket_direction.y) < 0.25 + assert _dot(bra_direction, ket_direction) < -0.85 + + +def test_export_geometry_points_cycle_free_indices_outward() -> None: + spec = _build_cycle_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + cycle_center = CanvasPosition(x=200.0, y=200.0) + + for tensor in spec.tensors: + free_index = tensor.indices[0] + direction = renderer._index_direction(tensor, free_index) + radial = _normalize( + CanvasPosition( + x=tensor.position.x - cycle_center.x, + y=tensor.position.y - cycle_center.y, + ) + ) + assert _dot(direction, radial) > 0.85 + + +def test_export_geometry_points_grid_boundary_free_indices_outward() -> None: + spec = _build_grid_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + expectations = { + "tensor_0_1": CanvasPosition(x=0.0, y=-1.0), + "tensor_1_0": CanvasPosition(x=-1.0, y=0.0), + "tensor_1_2": CanvasPosition(x=1.0, y=0.0), + "tensor_2_1": CanvasPosition(x=0.0, y=1.0), + } + + for tensor_id, expected_direction in expectations.items(): + tensor = next(tensor for tensor in spec.tensors if tensor.id == tensor_id) + free_index = tensor.indices[0] + direction = renderer._index_direction(tensor, free_index) + assert _dot(direction, expected_direction) > 0.85 + + +def test_export_geometry_points_rotated_grid_boundary_free_indices_outward() -> None: + spec = _build_rotated_grid_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + expectations = { + "tensor_0_1": _normalize(CanvasPosition(x=1.0, y=-1.0)), + "tensor_1_0": _normalize(CanvasPosition(x=-1.0, y=-1.0)), + "tensor_1_2": _normalize(CanvasPosition(x=1.0, y=1.0)), + "tensor_2_1": _normalize(CanvasPosition(x=-1.0, y=1.0)), + } + + for tensor_id, expected_direction in expectations.items(): + tensor = next(tensor for tensor in spec.tensors if tensor.id == tensor_id) + free_index = tensor.indices[0] + direction = renderer._index_direction(tensor, free_index) + assert _dot(direction, expected_direction) > 0.85 + + +def test_export_geometry_generic_free_indices_point_away_from_local_neighbors() -> None: + spec = _build_generic_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + center_tensor = spec.tensors[0] + free_index = center_tensor.indices[0] + + direction = renderer._index_direction(center_tensor, free_index) + away_from_neighbors = _normalize(CanvasPosition(x=-180.0, y=-120.0)) + + assert _dot(direction, away_from_neighbors) > 0.75 + + def test_academic_svg_tikz_and_dot_preserve_entity_colors_and_parallel_edges() -> None: spec = _build_colored_parallel_edge_spec() @@ -283,12 +740,20 @@ def test_academic_parallel_edges_curve_far_enough_to_separate_three_bonds() -> N def test_academic_edges_reach_tensor_centers_in_svg_and_tikz() -> None: spec = _assign_demo_index_offsets() - edge_render_infos = _SvgRenderer(spec, SvgRenderOptions())._edge_render_infos() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + edge_render_infos = renderer._edge_render_infos() + bounds = renderer._compute_bounds(edge_render_infos) tikz = render_spec_tikz(spec) assert edge_render_infos[0].source == spec.tensors[0].position assert edge_render_infos[0].target == spec.tensors[1].position - assert "(150, 116) -- (390, 116)" in tikz + expected_segment = ( + f"({_number(edge_render_infos[0].source.x - bounds.x1)}, " + f"{_number(bounds.y2 - edge_render_infos[0].source.y)}) -- " + f"({_number(edge_render_infos[0].target.x - bounds.x1)}, " + f"{_number(bounds.y2 - edge_render_infos[0].target.y)})" + ) + assert expected_segment in tikz def test_academic_svg_renderer_can_hide_tensor_index_and_bond_labels() -> None: @@ -335,6 +800,20 @@ def test_render_spec_svg_writes_output_path(tmp_path: Path) -> None: assert output_path.read_text(encoding="utf-8") == svg +def test_render_spec_svg_omits_solid_background_when_transparent() -> None: + pytest.importorskip("matplotlib") + + svg = render_spec_svg( + build_sample_spec(), + options=SvgRenderOptions( + background="#abcdef", + transparent_background=True, + ), + ) + + assert "#abcdef" not in svg + + def test_render_spec_svg_reuses_edge_geometry_within_one_render( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -361,6 +840,46 @@ def counting_edge_render_infos(self: Any) -> list[Any]: assert edge_render_info_call_count == 1 +def test_render_spec_svg_reuses_component_axis_geometry_within_one_render( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import tensor_network_editor.rendering as rendering_module + + pytest.importorskip("matplotlib") + spec = build_template_spec( + "mps", + TemplateParameters( + graph_size=12, + bond_dimension=3, + physical_dimension=2, + boundary_condition="open", + initial_state="zeros", + ), + ) + component_primary_axis_call_count = 0 + original_component_primary_axis = ( + rendering_module._SvgRenderer._component_primary_axis + ) + + def counting_component_primary_axis( + self: Any, + component_tensors: list[TensorSpec], + ) -> CanvasPosition: + nonlocal component_primary_axis_call_count + component_primary_axis_call_count += 1 + return original_component_primary_axis(self, component_tensors) + + monkeypatch.setattr( + rendering_module._SvgRenderer, + "_component_primary_axis", + counting_component_primary_axis, + ) + + render_spec_svg(spec) + + assert component_primary_axis_call_count == 1 + + def test_render_spec_svg_keeps_labels_as_svg_text_elements() -> None: pytest.importorskip("matplotlib") @@ -549,6 +1068,110 @@ def reject_matplotlib_modules() -> tuple[object, object, object, object]: rendering_module.render_spec_pdf(build_sample_spec()) +def test_load_matplotlib_modules_memoizes_imports( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import tensor_network_editor.rendering as rendering_module + + if hasattr(rendering_module._load_matplotlib_modules, "cache_clear"): + rendering_module._load_matplotlib_modules.cache_clear() + import_call_counts: dict[str, int] = {} + original_import_module = rendering_module.import_module + + def counting_import_module(name: str) -> Any: + import_call_counts[name] = import_call_counts.get(name, 0) + 1 + return original_import_module(name) + + monkeypatch.setattr( + rendering_module, + "import_module", + counting_import_module, + ) + + first_modules = rendering_module._load_matplotlib_modules() + second_modules = rendering_module._load_matplotlib_modules() + + assert second_modules == first_modules + assert import_call_counts == { + "matplotlib": 1, + "matplotlib.pyplot": 1, + "matplotlib.patches": 1, + "matplotlib.path": 1, + } + + +def test_validate_positive_render_scale_normalizes_and_rejects_invalid_values() -> None: + import tensor_network_editor.rendering as rendering_module + + assert rendering_module._validate_positive_render_scale( + 2, + description="PNG render scale", + ) == pytest.approx(2.0) + assert rendering_module._validate_positive_render_scale( + 1.5, + description="TikZ render scale", + ) == pytest.approx(1.5) + + for invalid_scale in (True, 0, -1, float("inf"), float("nan"), "2"): + with pytest.raises( + ValueError, + match="PNG render scale must be a positive finite number.", + ): + rendering_module._validate_positive_render_scale( + invalid_scale, + description="PNG render scale", + ) + + +def test_render_spec_output_validates_renders_and_writes_output( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + import tensor_network_editor.rendering as rendering_module + + spec = build_sample_spec() + validated_spec = build_three_tensor_spec() + output_path = tmp_path / "network.svg" + calls: dict[str, Any] = {} + + def fake_validate(received_spec: NetworkSpec) -> NetworkSpec: + calls["validate"] = received_spec + return validated_spec + + def fake_render( + received_spec: NetworkSpec, + received_options: SvgRenderOptions, + ) -> str: + calls["render"] = (received_spec, received_options) + return "" + + def fake_write( + path: Path, + content: str, + *, + description: str, + ) -> None: + calls["write"] = (path, content, description) + + monkeypatch.setattr(rendering_module, "ensure_valid_spec", fake_validate) + options = SvgRenderOptions(show_tensor_labels=False) + + rendered = rendering_module._render_spec_output( + spec, + format_name="svg", + options=options, + output_path=output_path, + description="SVG network rendering", + render=fake_render, + writer=fake_write, + ) + + assert rendered == "" + assert calls["validate"] is spec + assert calls["render"] == (validated_spec, options) + assert calls["write"] == (output_path, "", "SVG network rendering") + + def test_render_spec_png_returns_png_bytes_and_writes_output_path( tmp_path: Path, ) -> None: @@ -563,6 +1186,19 @@ def test_render_spec_png_returns_png_bytes_and_writes_output_path( assert output_path.read_bytes() == png_bytes +def test_render_spec_png_uses_alpha_channel_when_transparent() -> None: + pytest.importorskip("matplotlib") + from tensor_network_editor.rendering import render_spec_png + + png_bytes = render_spec_png( + build_sample_spec(), + options=SvgRenderOptions(transparent_background=True), + ) + + assert png_bytes[12:16] == b"IHDR" + assert png_bytes[25] == 6 + + def test_render_spec_pdf_returns_pdf_bytes_and_writes_output_path( tmp_path: Path, ) -> None: diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 703dc40..3848faa 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -68,6 +68,9 @@ def seed_generated_artifacts(root: Path) -> None: root / ".coverage", root / ".coverage.unit", root / "coverage.xml", + root / "session.log", + root / "session.log.1", + root / "session.log.7", ] for file_path in files_to_create: file_path.write_text("temporary", encoding="utf-8") @@ -91,6 +94,9 @@ def assert_cleanup_removed_artifacts(root: Path) -> None: root / ".coverage", root / ".coverage.unit", root / "coverage.xml", + root / "session.log", + root / "session.log.1", + root / "session.log.7", ] for path in removed_paths: assert not path.exists() diff --git a/tests/test_session.py b/tests/test_session.py index 0a5ed5c..53dc03c 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,6 +3,7 @@ import logging import signal import threading +from base64 import b64encode from collections.abc import Iterator from importlib import import_module from pathlib import Path @@ -15,6 +16,7 @@ from tensor_network_editor.app._protocol import JsonDict from tensor_network_editor.app.session import ( EditorSession, + _PywebviewExportApi, build_blank_network_spec, wait_for_editor_result, ) @@ -628,6 +630,643 @@ class FakeThread: assert "http://127.0.0.1:43210" in captured +def test_launch_editor_session_pywebview_requires_main_thread( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + class FakeThread: + name = "worker" + + monkeypatch.setattr( + session_module.threading, "current_thread", lambda: FakeThread() + ) + + with pytest.raises(RuntimeError, match="pywebview mode must be launched"): + session_module.launch_editor_session(ui_mode="pywebview") + + +def test_launch_editor_session_pywebview_missing_dependency_raises_clear_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr( + session_module, + "_import_pywebview", + lambda: (_ for _ in ()).throw(ModuleNotFoundError("No module named 'webview'")), + ) + + with pytest.raises(RuntimeError, match="tensor-network-editor\\[desktop\\]"): + session_module.launch_editor_session(ui_mode="pywebview") + + +def test_launch_editor_session_pywebview_closes_window_after_completion( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.before_show = FakeEventHook() + self.closed = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + self.destroy_calls = 0 + + def destroy(self) -> None: + self.destroy_calls += 1 + + class FakePywebview: + def __init__(self) -> None: + self.created_urls: list[str] = [] + self.created_maximized: list[bool] = [] + self.created_js_apis: list[object] = [] + self.window = FakeWindow() + self.start_calls = 0 + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + assert title == "Tensor Network Editor" + self.created_urls.append(url) + self.created_maximized.append(maximized) + self.created_js_apis.append(js_api) + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + self.start_calls += 1 + self.window.events.before_show.fire() + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + main_thread = FakeMainThread() + pywebview = FakePywebview() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: pywebview) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert pywebview.created_urls == ["http://127.0.0.1:43210"] + assert pywebview.created_maximized == [True] + assert len(pywebview.created_js_apis) == 1 + assert isinstance(pywebview.created_js_apis[0], _PywebviewExportApi) + assert pywebview.start_calls == 1 + assert pywebview.window.destroy_calls == 1 + + +def test_launch_editor_session_pywebview_applies_native_icon_before_show( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.before_show = FakeEventHook() + self.closed = FakeEventHook() + + class FakeNativeWindow: + Icon = None + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + self.native = FakeNativeWindow() + + def destroy(self) -> None: + return None + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + self.window.events.before_show.fire() + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + applied_windows: list[object] = [] + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: FakePywebview()) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + monkeypatch.setattr( + session_module, + "_apply_pywebview_native_window_icon", + lambda window: applied_windows.append(window), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert len(applied_windows) == 1 + assert isinstance(applied_windows[0], FakeWindow) + + +def test_launch_editor_session_pywebview_applies_native_icon_without_before_show( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.closed = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + + def destroy(self) -> None: + return None + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + applied_windows: list[object] = [] + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: FakePywebview()) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + monkeypatch.setattr( + session_module, + "_apply_pywebview_native_window_icon", + lambda window: applied_windows.append(window), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert len(applied_windows) == 1 + assert isinstance(applied_windows[0], FakeWindow) + + +def test_launch_editor_session_pywebview_tolerates_missing_closed_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.before_show = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + self.destroy_calls = 0 + + def destroy(self) -> None: + self.destroy_calls += 1 + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + self.window.events.before_show.fire() + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + applied_windows: list[object] = [] + main_thread = FakeMainThread() + pywebview = FakePywebview() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: pywebview) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + monkeypatch.setattr( + session_module, + "_apply_pywebview_native_window_icon", + lambda window: applied_windows.append(window), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert applied_windows == [pywebview.window] + assert pywebview.window.destroy_calls == 1 + + +def test_launch_editor_session_pywebview_window_close_cancels_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.before_show = FakeEventHook() + self.closed = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + + def destroy(self) -> None: + return None + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + del callback, window + self.window.events.closed.fire() + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr( + session_module, + "_import_pywebview", + lambda: FakePywebview(), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is None + + +def test_pywebview_export_api_writes_text_file_to_selected_path( + tmp_path: Path, +) -> None: + output_path = tmp_path / "demo.json" + + class FakePywebview: + SAVE_DIALOG = object() + + class FakeWindow: + def __init__(self) -> None: + self.dialog_calls: list[dict[str, object]] = [] + + def create_file_dialog( + self, + dialog_type: object, + *, + save_filename: str, + file_types: tuple[str, ...], + ) -> tuple[str]: + self.dialog_calls.append( + { + "dialog_type": dialog_type, + "save_filename": save_filename, + "file_types": file_types, + } + ) + return (str(output_path),) + + api = _PywebviewExportApi(FakePywebview()) + window = FakeWindow() + api.bind_window(window) + + saved = api.save_text_file( + "demo.json", + '{\n "ok": true\n}\n', + "application/json;charset=utf-8", + ) + + assert saved is True + assert output_path.read_text(encoding="utf-8") == '{\n "ok": true\n}\n' + assert window.dialog_calls == [ + { + "dialog_type": FakePywebview.SAVE_DIALOG, + "save_filename": "demo.json", + "file_types": ("JSON (*.json)",), + } + ] + + +def test_pywebview_export_api_returns_false_when_save_dialog_is_cancelled( + tmp_path: Path, +) -> None: + output_path = tmp_path / "demo.json" + + class FakePywebview: + SAVE_DIALOG = object() + + class FakeWindow: + def create_file_dialog( + self, + dialog_type: object, + *, + save_filename: str, + file_types: tuple[str, ...], + ) -> tuple[str, ...]: + del dialog_type, save_filename, file_types + return () + + api = _PywebviewExportApi(FakePywebview()) + api.bind_window(FakeWindow()) + + saved = api.save_text_file( + "demo.json", + '{"ok": true}', + "application/json;charset=utf-8", + ) + + assert saved is False + assert output_path.exists() is False + + +def test_pywebview_export_api_writes_binary_file_to_selected_path( + tmp_path: Path, +) -> None: + output_path = tmp_path / "demo.pdf" + binary_payload = b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n" + + class FakePywebview: + SAVE_DIALOG = object() + + class FakeWindow: + def create_file_dialog( + self, + dialog_type: object, + *, + save_filename: str, + file_types: tuple[str, ...], + ) -> tuple[str]: + del dialog_type, save_filename, file_types + return (str(output_path),) + + api = _PywebviewExportApi(FakePywebview()) + api.bind_window(FakeWindow()) + + saved = api.save_binary_file( + "demo.pdf", + b64encode(binary_payload).decode("ascii"), + "application/pdf", + ) + + assert saved is True + assert output_path.read_bytes() == binary_payload + + def test_open_editor_passes_template_catalog_path( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/test_template_catalog_internal.py b/tests/test_template_catalog_internal.py index 1967cb3..e1e31b5 100644 --- a/tests/test_template_catalog_internal.py +++ b/tests/test_template_catalog_internal.py @@ -1,5 +1,7 @@ from __future__ import annotations +import importlib + import pytest from tensor_network_editor.internal.models._model_tensor_data import TensorDataMode @@ -8,6 +10,8 @@ build_template, ) from tensor_network_editor.internal.templates._template_catalog import ( + _reset_template_registry_for_tests, + get_template_builder, get_template_definition, list_template_names, serialize_template_definitions, @@ -108,6 +112,66 @@ def test_template_builders_internal_dispatches_to_specific_builder() -> None: assert len(spec.tensors) == 5 +def test_template_builder_facade_reexports_family_modules() -> None: + try: + linear_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_linear" + ) + grid_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_grid" + ) + tree_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_tree" + ) + except ModuleNotFoundError as exc: + pytest.fail(f"Expected split template-builder modules to exist: {exc}") + + _reset_template_registry_for_tests() + + assert _build_linear_chain_template is linear_module._build_linear_chain_template + assert get_template_builder("mps").__module__ == linear_module.__name__ + assert get_template_builder("peps_2x2").__module__ == grid_module.__name__ + assert get_template_builder("mera").__module__ == tree_module.__name__ + + +def test_template_builder_common_module_exposes_shared_primitives() -> None: + try: + common_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_common" + ) + except ModuleNotFoundError as exc: + pytest.fail(f"Expected shared template-builder primitives module: {exc}") + + left_tensor = common_module._make_tensor( + "tensor_left", + "Left", + 10.0, + 20.0, + [("right", 3, (58.0, 0.0))], + ) + right_tensor = common_module._make_tensor( + "tensor_right", + "Right", + 40.0, + 20.0, + [("left", 3, (-58.0, 0.0))], + ) + edge = common_module._make_edge( + "edge_left_right", + left_tensor, + "right", + right_tensor, + "left", + ) + + assert left_tensor.indices[0].id == "tensor_left_right" + assert right_tensor.indices[0].id == "tensor_right_left" + assert edge.left.tensor_id == "tensor_left" + assert edge.left.index_id == "tensor_left_right" + assert edge.right.tensor_id == "tensor_right" + assert edge.right.index_id == "tensor_right_left" + + def test_linear_chain_template_helper_reuses_catalog_metadata() -> None: spec = _build_linear_chain_template( "mpo",