`
- - `content_type: "text/plain;charset=utf-8"`
-
-## Error Handling
-
-- If the spec is valid, Mermaid export should succeed.
-- Unsupported visual details must degrade gracefully instead of raising.
-- Rendering should only fail for the same categories already used elsewhere,
- such as invalid payloads or invalid specs.
-
-## Testing
-
-Add focused tests for:
-
-- `render_spec_mermaid` basic output for a normal network
-- label toggle behavior for tensor, index, and bond labels
-- open indices and hyperedges
-- group and note emission
-- escaping of special characters
-- API route `/api/render` with `format="mermaid"`
-- CLI `render --format mermaid`
-- editor menu and export selector wiring
-- frontend download flow and output filename extension
-
-## Documentation Updates
-
-- Add Mermaid export to `README.md`.
-- Add Mermaid export to the editor help text if that text enumerates supported
- export formats.
-- Add a short `CHANGELOG.md` entry when implementation lands.
-
-## Rollout Notes
-
-The first version should stay intentionally simple:
-
-- structure first
-- labels second
-- visual fidelity out of scope
-
-This keeps the renderer predictable, testable, and useful for documentation
-without turning Mermaid into a second layout engine.
diff --git a/docs/user-guide.md b/docs/user-guide.md
index 003616e..35c8748 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -158,11 +158,11 @@ Python:
```python
from tensor_network_editor.editor import EditorLaunchOptions, open_editor
-open_editor(options=EditorLaunchOptions(theme="contrast"))
+open_editor(options=EditorLaunchOptions(theme="contrast", ui_mode="browser"))
```
Available themes are `dark`, `light`, `contrast`, `colorblind`, and `shiny`.
-The choice only affects the browser editor appearance; saved network JSON and
+The choice only affects the editor appearance; saved network JSON and
recoverable drafts keep the same model data.
## Templates
diff --git a/pyproject.toml b/pyproject.toml
index 6ae9a8c..3e5073f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "tensor-network-editor"
dynamic = ["version"]
-description = "Local visual editor for tensor networks: versioned JSON diagrams and Python code for einsum and optional backends."
+description = "Production-ready local visual editor for tensor networks: versioned JSON diagrams and Python code for einsum and optional backends."
readme = "README.md"
requires-python = ">=3.11"
license = "MIT"
@@ -38,7 +38,7 @@ keywords = [
"visualization",
]
classifiers = [
- "Development Status :: 4 - Beta",
+ "Development Status :: 5 - Production/Stable",
"Environment :: Web Environment",
"Intended Audience :: Science/Research",
"Operating System :: OS Independent",
@@ -66,6 +66,7 @@ torch = ["torch>=2.0"]
dev = [
"build>=1.2",
"mypy>=1.10",
+ "pip-audit>=2.7",
"pyright>=1.1",
"pytest>=8.2",
"ruff>=0.6",
diff --git a/scripts/clean.bat b/scripts/clean.bat
index 31fc285..7b46220 100644
--- a/scripts/clean.bat
+++ b/scripts/clean.bat
@@ -22,6 +22,7 @@ call :remove_glob_dirs_warn ".\pytest-cache-files-*"
call :remove_glob_files ".\.coverage"
call :remove_glob_files ".\.coverage.*"
call :remove_glob_files ".\coverage.xml"
+call :remove_glob_files ".\session.log*"
call :remove_dir "__pycache__"
call :remove_named_dirs ".\src" "__pycache__"
diff --git a/scripts/clean.sh b/scripts/clean.sh
index 7d2fffa..0ce26b6 100644
--- a/scripts/clean.sh
+++ b/scripts/clean.sh
@@ -98,6 +98,7 @@ remove_glob_dirs_warn "./pytest-cache-files-*"
remove_file_pattern "./.coverage"
remove_file_pattern "./.coverage.*"
remove_file_pattern "./coverage.xml"
+remove_file_pattern "./session.log*"
remove_dir "__pycache__"
remove_named_dirs "./src" "__pycache__"
diff --git a/src/tensor_network_editor/__init__.py b/src/tensor_network_editor/__init__.py
index a06978e..6ec5d09 100644
--- a/src/tensor_network_editor/__init__.py
+++ b/src/tensor_network_editor/__init__.py
@@ -17,7 +17,7 @@
from .analysis import analyze_contraction, analyze_spec
from .builder import IndexHandle, NetworkBuilder, TensorHandle
from .canonicalization import canonicalize_spec
- from .editor import EditorLaunchOptions, EditorThemeName, open_editor
+ from .editor import EditorLaunchOptions, EditorThemeName, EditorUiMode, open_editor
from .internal.diffing._diffing import diff_specs, semantic_diff_specs
from .io import PythonLoadOptions, load_python_spec, load_spec, save_spec
from .linting import lint_spec
@@ -76,6 +76,7 @@
"EdgeSpec",
"EditorLaunchOptions",
"EditorThemeName",
+ "EditorUiMode",
"EditorResult",
"EngineName",
"DotRenderOptions",
@@ -130,6 +131,7 @@
"EdgeSpec": ".models",
"EditorLaunchOptions": ".editor",
"EditorThemeName": ".editor",
+ "EditorUiMode": ".editor",
"EditorResult": ".models",
"EngineName": ".models",
"DotRenderOptions": ".rendering",
diff --git a/src/tensor_network_editor/_public_codegen.py b/src/tensor_network_editor/_public_codegen.py
index 21f52ff..66f566f 100644
--- a/src/tensor_network_editor/_public_codegen.py
+++ b/src/tensor_network_editor/_public_codegen.py
@@ -27,6 +27,7 @@ def generate_code(
*,
engine: EngineIdentifier,
collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST,
+ include_roundtrip_metadata: bool = True,
output_path: StrPath | None = None,
print_code: bool = False,
external_data_base_path: StrPath | None = None,
@@ -43,7 +44,10 @@ def generate_code(
external_data_base_path=external_data_base_path,
)
result = _generate_code(
- codegen_spec, engine, collection_format=collection_format
+ codegen_spec,
+ engine,
+ collection_format=collection_format,
+ include_roundtrip_metadata=include_roundtrip_metadata,
)
if print_code:
log_branch(LOGGER, "Printing generated code to stdout")
diff --git a/src/tensor_network_editor/_version.py b/src/tensor_network_editor/_version.py
index 3a58b17..efd8da9 100644
--- a/src/tensor_network_editor/_version.py
+++ b/src/tensor_network_editor/_version.py
@@ -4,4 +4,4 @@
from typing import Final
-__version__: Final[str] = "0.4.0"
+__version__: Final[str] = "1.0.1"
diff --git a/src/tensor_network_editor/app/_analysis_services.py b/src/tensor_network_editor/app/_analysis_services.py
index 0822cf4..ad42cf4 100644
--- a/src/tensor_network_editor/app/_analysis_services.py
+++ b/src/tensor_network_editor/app/_analysis_services.py
@@ -15,6 +15,7 @@
from ..internal.analysis._contraction_analysis import _analyze_validated_contraction
from ..internal.analysis._contraction_analysis_types import ContractionAnalysisResult
from ..models import NetworkSpec, ValidationIssue
+from ._limits import enforce_spec_api_limits
LOGGER = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ def analyze_serialized_contraction(
emit_start=False,
) as success_context:
spec = deserialize_spec_fn(serialized_spec)
+ enforce_spec_api_limits(spec)
issues = validate_spec_fn(spec)
if issues:
log_branch(
diff --git a/src/tensor_network_editor/app/_limits.py b/src/tensor_network_editor/app/_limits.py
new file mode 100644
index 0000000..208298f
--- /dev/null
+++ b/src/tensor_network_editor/app/_limits.py
@@ -0,0 +1,180 @@
+"""Complexity limits for local editor API payloads."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from dataclasses import dataclass
+
+from ..internal.models._model_periodic import LinearPeriodicCellSpec
+from ..internal.templates._template_catalog import TemplateParameters
+from ..models import NetworkSpec, TensorSpec
+
+MAX_API_TENSORS = 512
+MAX_API_INDICES = 4096
+MAX_API_CONNECTIONS = 4096
+MAX_API_TENSOR_RANK = 64
+MAX_API_INDEX_DIMENSION = 1_000_000
+MAX_API_TENSOR_CARDINALITY = 10_000_000
+MAX_API_TEMPLATE_LINEAR_GRAPH_SIZE = 512
+MAX_API_TEMPLATE_GRID_SIDE_LENGTH = 32
+MAX_API_TEMPLATE_TREE_DEPTH = 10
+MAX_API_TEMPLATE_DIMENSION = 4096
+_GRID_TEMPLATE_NAMES = frozenset({"peps_2x2", "pepo"})
+_TREE_TEMPLATE_NAMES = frozenset({"mera", "ttn"})
+
+
+@dataclass(slots=True)
+class _SpecComplexity:
+ """Accumulated size information for one editor API payload."""
+
+ tensor_count: int = 0
+ index_count: int = 0
+ connection_count: int = 0
+
+
+def enforce_spec_api_limits(spec: NetworkSpec) -> None:
+ """Reject a network spec that is too expensive for the local HTTP API."""
+ complexity = _SpecComplexity()
+ for tensors, edge_count in _iter_spec_parts(spec):
+ complexity.tensor_count += len(tensors)
+ complexity.connection_count += edge_count
+ for tensor in tensors:
+ _enforce_tensor_api_limits(tensor)
+ complexity.index_count += len(tensor.indices)
+
+ complexity.connection_count += sum(
+ len(hyperedge.endpoints) for hyperedge in spec.hyperedges
+ )
+ _enforce_count_limit(
+ name="tensors",
+ count=complexity.tensor_count,
+ limit=MAX_API_TENSORS,
+ )
+ _enforce_count_limit(
+ name="indices",
+ count=complexity.index_count,
+ limit=MAX_API_INDICES,
+ )
+ _enforce_count_limit(
+ name="connections",
+ count=complexity.connection_count,
+ limit=MAX_API_CONNECTIONS,
+ )
+
+
+def enforce_template_api_limits(
+ template_name: str,
+ parameters: TemplateParameters | None,
+) -> None:
+ """Reject built-in template parameters that would create huge payloads."""
+ if parameters is None:
+ return
+ graph_size_limit = _template_graph_size_limit(template_name)
+ if parameters.graph_size is not None and parameters.graph_size > graph_size_limit:
+ raise ValueError(
+ "Template parameter 'graph_size' "
+ f"is {parameters.graph_size}, above the API limit of {graph_size_limit}."
+ )
+ _enforce_optional_template_dimension(
+ parameters.bond_dimension,
+ field_name="bond_dimension",
+ )
+ _enforce_optional_template_dimension(
+ parameters.physical_dimension,
+ field_name="physical_dimension",
+ )
+
+
+def _iter_spec_parts(spec: NetworkSpec) -> Iterator[tuple[list[TensorSpec], int]]:
+ """Yield tensor and edge collections stored in a spec payload."""
+ yield spec.tensors, len(spec.edges)
+ if spec.linear_periodic_chain is not None:
+ for cell in (
+ spec.linear_periodic_chain.initial_cell,
+ spec.linear_periodic_chain.periodic_cell,
+ spec.linear_periodic_chain.final_cell,
+ ):
+ yield from _iter_cell_parts(cell)
+ if spec.grid_periodic_grid is not None:
+ for cell in (
+ spec.grid_periodic_grid.top_left_cell,
+ spec.grid_periodic_grid.top_cell,
+ spec.grid_periodic_grid.top_right_cell,
+ spec.grid_periodic_grid.left_cell,
+ spec.grid_periodic_grid.center_cell,
+ spec.grid_periodic_grid.right_cell,
+ spec.grid_periodic_grid.bottom_left_cell,
+ spec.grid_periodic_grid.bottom_cell,
+ spec.grid_periodic_grid.bottom_right_cell,
+ ):
+ yield from _iter_cell_parts(cell)
+ if spec.tree_periodic_tree is not None:
+ for cell in (
+ spec.tree_periodic_tree.root_cell,
+ spec.tree_periodic_tree.branch_cell,
+ spec.tree_periodic_tree.leaf_cell,
+ ):
+ yield from _iter_cell_parts(cell)
+
+
+def _iter_cell_parts(
+ cell: LinearPeriodicCellSpec,
+) -> Iterator[tuple[list[TensorSpec], int]]:
+ """Yield tensor and edge collections stored in one periodic cell."""
+ yield cell.tensors, len(cell.edges)
+
+
+def _enforce_tensor_api_limits(tensor: TensorSpec) -> None:
+ """Reject one tensor whose local shape is too expensive."""
+ rank = len(tensor.indices)
+ if rank > MAX_API_TENSOR_RANK:
+ raise ValueError(
+ f"Tensor '{tensor.name}' has rank {rank}, "
+ f"above the API limit of {MAX_API_TENSOR_RANK}."
+ )
+ cardinality = 1
+ for index in tensor.indices:
+ if index.dimension > MAX_API_INDEX_DIMENSION:
+ raise ValueError(
+ f"Index '{index.name}' on tensor '{tensor.name}' has dimension "
+ f"{index.dimension}, above the API limit of {MAX_API_INDEX_DIMENSION}."
+ )
+ if index.dimension > 0:
+ cardinality *= index.dimension
+ if cardinality > MAX_API_TENSOR_CARDINALITY:
+ raise ValueError(
+ f"Tensor '{tensor.name}' spans {cardinality} elements, "
+ f"above the API limit of {MAX_API_TENSOR_CARDINALITY}."
+ )
+
+
+def _enforce_count_limit(*, name: str, count: int, limit: int) -> None:
+ """Reject one aggregate count when it exceeds its API limit."""
+ if count <= limit:
+ return
+ raise ValueError(
+ f"Network contains {count} {name}, above the API limit of {limit}."
+ )
+
+
+def _enforce_optional_template_dimension(
+ value: int | None,
+ *,
+ field_name: str,
+) -> None:
+ """Reject template dimensions that would produce very large tensors."""
+ if value is None or value <= MAX_API_TEMPLATE_DIMENSION:
+ return
+ raise ValueError(
+ f"Template parameter '{field_name}' is {value}, "
+ f"above the API limit of {MAX_API_TEMPLATE_DIMENSION}."
+ )
+
+
+def _template_graph_size_limit(template_name: str) -> int:
+ """Return the graph-size limit appropriate for one template family."""
+ if template_name in _GRID_TEMPLATE_NAMES:
+ return MAX_API_TEMPLATE_GRID_SIDE_LENGTH
+ if template_name in _TREE_TEMPLATE_NAMES:
+ return MAX_API_TEMPLATE_TREE_DEPTH
+ return MAX_API_TEMPLATE_LINEAR_GRAPH_SIZE
diff --git a/src/tensor_network_editor/app/_protocol.py b/src/tensor_network_editor/app/_protocol.py
index 9e281a9..8cb4ef6 100644
--- a/src/tensor_network_editor/app/_protocol.py
+++ b/src/tensor_network_editor/app/_protocol.py
@@ -43,6 +43,7 @@ class CodegenRequest:
serialized_spec: JsonDict
engine: EngineIdentifier
collection_format: TensorCollectionFormat
+ include_roundtrip_metadata: bool
@dataclass(slots=True, frozen=True)
@@ -269,6 +270,9 @@ def parse_codegen_request(
serialized_spec=require_serialized_spec(payload),
engine=resolve_engine(payload, default_engine),
collection_format=resolve_collection_format(payload, default_collection_format),
+ include_roundtrip_metadata=require_boolean(
+ payload, "include_roundtrip_metadata", default=True
+ ),
)
@@ -408,6 +412,16 @@ def bad_request_response(message: str) -> JsonResponse:
return HTTPStatus.BAD_REQUEST, {"ok": False, "message": message}
+def forbidden_response(message: str) -> JsonResponse:
+ """Return a standard forbidden JSON response."""
+ return HTTPStatus.FORBIDDEN, {"ok": False, "message": message}
+
+
+def unsupported_media_type_response(message: str) -> JsonResponse:
+ """Return a standard unsupported-media-type JSON response."""
+ return HTTPStatus.UNSUPPORTED_MEDIA_TYPE, {"ok": False, "message": message}
+
+
def not_found_response() -> JsonResponse:
"""Return a standard not-found JSON response."""
return HTTPStatus.NOT_FOUND, {"ok": False, "message": "Not found."}
diff --git a/src/tensor_network_editor/app/_session_requests.py b/src/tensor_network_editor/app/_session_requests.py
index f1a6ac4..cd20cd0 100644
--- a/src/tensor_network_editor/app/_session_requests.py
+++ b/src/tensor_network_editor/app/_session_requests.py
@@ -17,6 +17,7 @@
EngineIdentifier,
TensorCollectionFormat,
)
+from ._limits import enforce_spec_api_limits
if TYPE_CHECKING:
from .session import EditorSession
@@ -30,6 +31,7 @@ def generate_session_request(
serialized_spec: Mapping[str, object],
engine: EngineIdentifier,
collection_format: TensorCollectionFormat | None = None,
+ include_roundtrip_metadata: bool = True,
) -> CodegenResult:
"""Generate preview code for one editor request."""
with log_operation(
@@ -38,6 +40,7 @@ def generate_session_request(
context={"engine": engine_name_to_text(engine)},
):
spec = deserialize_spec(serialized_spec)
+ enforce_spec_api_limits(spec)
log_branch(
LOGGER,
"Deserialized preview spec",
@@ -47,6 +50,7 @@ def generate_session_request(
spec,
engine,
collection_format=_resolve_collection_format(session, collection_format),
+ include_roundtrip_metadata=include_roundtrip_metadata,
validate=False,
)
@@ -56,6 +60,7 @@ def complete_session_request(
serialized_spec: Mapping[str, object],
engine: EngineIdentifier,
collection_format: TensorCollectionFormat | None = None,
+ include_roundtrip_metadata: bool = True,
) -> EditorResult:
"""Finalize a session request and optionally print or save generated code."""
with log_operation(
@@ -66,6 +71,7 @@ def complete_session_request(
context={"engine": engine_name_to_text(engine)},
):
spec = deserialize_spec(serialized_spec)
+ enforce_spec_api_limits(spec)
log_branch(
LOGGER,
"Deserialized completion spec",
@@ -75,6 +81,7 @@ def complete_session_request(
spec,
engine,
collection_format=_resolve_collection_format(session, collection_format),
+ include_roundtrip_metadata=include_roundtrip_metadata,
validate=False,
)
if session.print_code:
diff --git a/src/tensor_network_editor/app/_subnetwork_library_services.py b/src/tensor_network_editor/app/_subnetwork_library_services.py
index ecc1716..09a3e08 100644
--- a/src/tensor_network_editor/app/_subnetwork_library_services.py
+++ b/src/tensor_network_editor/app/_subnetwork_library_services.py
@@ -14,6 +14,7 @@
)
from ..models import CanvasPosition, NetworkSpec
from ._bootstrap_payloads import build_subnetwork_catalog_payload
+from ._limits import enforce_spec_api_limits
from ._protocol import JsonDict
if TYPE_CHECKING:
@@ -44,7 +45,9 @@ def save_serialized_subnetwork_to_library(
LOGGER, "Reusable subnetwork save", context=context
) as success_context:
spec = deserialize_spec(serialized_spec, validate=False)
+ enforce_spec_api_limits(spec)
saved_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids)
+ enforce_spec_api_limits(saved_spec)
session.save_project_subnetwork(
subnetwork_name,
saved_spec,
@@ -138,9 +141,11 @@ def prepare_saved_subnetwork_for_insertion(
context=context,
) as success_context:
spec = session.build_saved_subnetwork(subnetwork_name)
+ enforce_spec_api_limits(spec)
prepared_spec = prepare_subnetwork_for_insertion(
spec,
target_center=target_center,
)
+ enforce_spec_api_limits(prepared_spec)
success_context.update(summarize_spec_counts(prepared_spec))
return prepared_spec
diff --git a/src/tensor_network_editor/app/_subnetwork_services.py b/src/tensor_network_editor/app/_subnetwork_services.py
index 5893417..5007258 100644
--- a/src/tensor_network_editor/app/_subnetwork_services.py
+++ b/src/tensor_network_editor/app/_subnetwork_services.py
@@ -12,6 +12,7 @@
prepare_subnetwork_for_insertion,
)
from ..models import CanvasPosition, NetworkSpec
+from ._limits import enforce_spec_api_limits
LOGGER = logging.getLogger(__name__)
@@ -28,7 +29,9 @@ def extract_serialized_subnetwork(
context={"tensor_id_count": len(tensor_ids)},
):
spec = deserialize_spec(serialized_spec, validate=False)
+ enforce_spec_api_limits(spec)
extracted_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids)
+ enforce_spec_api_limits(extracted_spec)
log_branch(
LOGGER,
"Extracted transient reusable subnetwork",
@@ -45,10 +48,12 @@ def prepare_serialized_subnetwork_for_insertion(
"""Deserialize one payload and prepare it for editor insertion."""
with log_operation(LOGGER, "Transient subnetwork insertion preparation"):
spec = deserialize_spec(serialized_spec, validate=False)
+ enforce_spec_api_limits(spec)
prepared_spec = prepare_subnetwork_for_insertion(
spec,
target_center=target_center,
)
+ enforce_spec_api_limits(prepared_spec)
log_branch(
LOGGER,
"Prepared transient subnetwork for insertion",
diff --git a/src/tensor_network_editor/app/_template_services.py b/src/tensor_network_editor/app/_template_services.py
index 924d27c..56b5187 100644
--- a/src/tensor_network_editor/app/_template_services.py
+++ b/src/tensor_network_editor/app/_template_services.py
@@ -16,6 +16,7 @@
parse_template_parameters,
)
from ._bootstrap_payloads import build_template_catalog_payload
+from ._limits import enforce_spec_api_limits, enforce_template_api_limits
from ._protocol import JsonDict
if TYPE_CHECKING:
@@ -36,6 +37,7 @@ def build_template_from_payload(
if session.has_project_template(template_name):
log_branch(LOGGER, "Loading template from the project catalog")
spec = session.build_project_template(template_name)
+ enforce_spec_api_limits(spec)
success_context.update(summarize_spec_counts(spec))
success_context["status"] = "project"
return spec
@@ -43,7 +45,9 @@ def build_template_from_payload(
template_name,
raw_parameters,
)
+ enforce_template_api_limits(template_name, parameters)
spec = build_template_spec(template_name, parameters)
+ enforce_spec_api_limits(spec)
success_context.update(summarize_spec_counts(spec))
success_context["status"] = "global"
return spec
@@ -68,7 +72,9 @@ def promote_serialized_subnetwork_to_template(
LOGGER, "Template promotion", context=context
) as success_context:
spec = deserialize_spec(serialized_spec, validate=False)
+ enforce_spec_api_limits(spec)
promoted_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids)
+ enforce_spec_api_limits(promoted_spec)
promoted_spec.name = session.build_project_template_display_name(template_name)
session.save_project_template(
template_name,
diff --git a/src/tensor_network_editor/app/routes.py b/src/tensor_network_editor/app/routes.py
index 69cfbd5..ece031c 100644
--- a/src/tensor_network_editor/app/routes.py
+++ b/src/tensor_network_editor/app/routes.py
@@ -7,8 +7,9 @@
from collections.abc import Callable
from dataclasses import dataclass
from http import HTTPStatus
-from typing import Literal, cast
+from typing import Literal, TypedDict, cast
+from .._themes import DEFAULT_EDITOR_THEME, EditorThemeName, normalize_editor_theme
from ..errors import (
CodeGenerationError,
PackageIOError,
@@ -40,6 +41,7 @@
from ..types import JSONValue
from ..validation import validate_spec
from ._drafts import clear_project_draft, load_project_draft, save_project_draft
+from ._limits import enforce_spec_api_limits
from ._protocol import (
JsonDict,
JsonResponse,
@@ -91,6 +93,20 @@
_MAX_FRONTEND_CLIENT_LOG_EVENTS = 200
_MAX_FRONTEND_CLIENT_LOG_MESSAGE_LENGTH = 400
_MAX_FRONTEND_CLIENT_LOG_CONTEXT_VALUE_LENGTH = 200
+_RenderFormat = Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"]
+
+
+class _ImageExportThemeOverride(TypedDict, total=False):
+ """Theme override fields supported by image render options."""
+
+ background: str
+ edge_stroke: str
+ group_stroke: str
+ hyperedge_stroke: str
+ index_fill: str
+ muted_text_fill: str
+ note_fill: str
+ text_fill: str
@dataclass(slots=True, frozen=True)
@@ -102,6 +118,67 @@ class _FrontendClientLogEvent:
context: dict[str, object]
+@dataclass(slots=True, frozen=True)
+class _RenderLabelOptions:
+ """Shared label-visibility flags for academic render routes."""
+
+ show_tensor_labels: bool
+ show_index_labels: bool
+ show_edge_labels: bool
+
+
+_IMAGE_EXPORT_THEME_OVERRIDES: dict[EditorThemeName, _ImageExportThemeOverride] = {
+ "dark": {
+ "background": "#0b0d12",
+ "edge_stroke": "#7e8aa3",
+ "index_fill": "#d7ae68",
+ "group_stroke": "#8f7cf7",
+ "note_fill": "#252b34",
+ "text_fill": "#f2f5f8",
+ "muted_text_fill": "#c6d3e6",
+ },
+ "light": {
+ "background": "#ffffff",
+ "edge_stroke": "#64748b",
+ "index_fill": "#b45309",
+ "group_stroke": "#6d28d9",
+ "note_fill": "#ffffff",
+ "text_fill": "#172033",
+ "muted_text_fill": "#475569",
+ },
+ "contrast": {
+ "background": "#000000",
+ "edge_stroke": "#ffffff",
+ "index_fill": "#ffff00",
+ "hyperedge_stroke": "#ff5f5f",
+ "group_stroke": "#ff00ff",
+ "note_fill": "#101010",
+ "text_fill": "#ffffff",
+ "muted_text_fill": "#ffffff",
+ },
+ "colorblind": {
+ "background": "#ffffff",
+ "edge_stroke": "#5b5b5b",
+ "index_fill": "#e69f00",
+ "hyperedge_stroke": "#d55e00",
+ "group_stroke": "#cc79a7",
+ "note_fill": "#ffffff",
+ "text_fill": "#202124",
+ "muted_text_fill": "#5b5b5b",
+ },
+ "shiny": {
+ "background": "#070915",
+ "edge_stroke": "#94a3b8",
+ "index_fill": "#facc15",
+ "hyperedge_stroke": "#fb7185",
+ "group_stroke": "#e879f9",
+ "note_fill": "#11152c",
+ "text_fill": "#f8fafc",
+ "muted_text_fill": "#c4b5fd",
+ },
+}
+
+
def _route_context(
session: EditorSession | None,
route: str,
@@ -222,6 +299,15 @@ def handle_validate(session: EditorSession, payload: JsonDict) -> JsonResponse:
level=logging.WARNING,
)
return bad_request_response("Missing 'spec' or 'python_code' payload.")
+ try:
+ enforce_spec_api_limits(spec)
+ except ValueError as exc:
+ log_branch(
+ LOGGER,
+ f"Validation request exceeded API limits: {exc}",
+ level=logging.WARNING,
+ )
+ return bad_request_response(str(exc))
issues = validate_spec(spec)
if issues:
log_branch(
@@ -327,109 +413,24 @@ def handle_render(session: EditorSession, payload: JsonDict) -> JsonResponse:
"""Render the current editor payload to an academic text format."""
del session
with log_operation(
- LOGGER, "Render route", context={"route": "/api/render"}
+ LOGGER, "Render route", context=_route_context(None, "/api/render")
) as success_context:
try:
render_format = _resolve_render_format(payload)
serialized_spec = require_serialized_spec(payload)
spec = deserialize_spec(serialized_spec, validate=False)
+ enforce_spec_api_limits(spec)
+ label_options = _resolve_render_label_options(payload)
+ render_theme = _resolve_render_theme(payload)
success_context["format"] = render_format
+ success_context["theme"] = render_theme
success_context.update(summarize_spec_counts(spec))
- svg_options = SvgRenderOptions(
- show_tensor_labels=require_boolean(
- payload, "show_tensor_names", default=True
- ),
- show_index_labels=require_boolean(
- payload, "show_index_names", default=True
- ),
- show_edge_labels=require_boolean(
- payload, "show_bond_names", default=True
- ),
+ response_payload = _build_render_response(
+ render_format,
+ spec,
+ label_options,
+ theme=render_theme,
)
- if render_format == "tikz":
- text = render_spec_tikz(
- spec,
- options=TikzRenderOptions(
- show_tensor_labels=require_boolean(
- payload, "show_tensor_names", default=True
- ),
- show_index_labels=require_boolean(
- payload, "show_index_names", default=True
- ),
- show_edge_labels=require_boolean(
- payload, "show_bond_names", default=True
- ),
- ),
- )
- content_type = "text/x-tex;charset=utf-8"
- response_payload: JsonDict = {
- "format": render_format,
- "text": text,
- "content_type": content_type,
- }
- elif render_format == "dot":
- text = render_spec_dot(
- spec,
- options=DotRenderOptions(
- show_tensor_labels=require_boolean(
- payload, "show_tensor_names", default=True
- ),
- show_index_labels=require_boolean(
- payload, "show_index_names", default=True
- ),
- show_edge_labels=require_boolean(
- payload, "show_bond_names", default=True
- ),
- ),
- )
- content_type = "text/vnd.graphviz;charset=utf-8"
- response_payload = {
- "format": render_format,
- "text": text,
- "content_type": content_type,
- }
- elif render_format == "mermaid":
- text = render_spec_mermaid(
- spec,
- options=DotRenderOptions(
- show_tensor_labels=require_boolean(
- payload, "show_tensor_names", default=True
- ),
- show_index_labels=require_boolean(
- payload, "show_index_names", default=True
- ),
- show_edge_labels=require_boolean(
- payload, "show_bond_names", default=True
- ),
- ),
- )
- content_type = "text/plain;charset=utf-8"
- response_payload = {
- "format": render_format,
- "text": text,
- "content_type": content_type,
- }
- elif render_format == "svg":
- text = render_spec_svg(spec, options=svg_options)
- response_payload = {
- "format": render_format,
- "text": text,
- "content_type": "image/svg+xml;charset=utf-8",
- }
- elif render_format == "png":
- binary = render_spec_png(spec, options=svg_options)
- response_payload = {
- "format": render_format,
- "base64": base64.b64encode(binary).decode("ascii"),
- "content_type": "image/png",
- }
- else:
- binary = render_spec_pdf(spec, options=svg_options)
- response_payload = {
- "format": render_format,
- "base64": base64.b64encode(binary).decode("ascii"),
- "content_type": "application/pdf",
- }
except ValueError as exc:
return bad_request_response(str(exc))
except SerializationError as exc:
@@ -521,6 +522,13 @@ def handle_analyze_contraction(
level=logging.WARNING,
)
return bad_request_response(str(exc))
+ except ValueError as exc:
+ log_branch(
+ LOGGER,
+ f"Contraction analysis request exceeded API limits: {exc}",
+ level=logging.WARNING,
+ )
+ return bad_request_response(str(exc))
except SpecValidationError as exc:
log_branch(
LOGGER,
@@ -790,22 +798,166 @@ def _serialize_generate_result(result: CodegenResult) -> JsonDict:
def _resolve_render_format(
payload: JsonDict,
-) -> Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"]:
+) -> _RenderFormat:
raw_format = payload.get("format")
if not isinstance(raw_format, str) or not raw_format.strip():
raise ValueError("Missing 'format' payload.")
normalized_format = raw_format.strip().lower()
if normalized_format in {"tikz", "dot", "mermaid", "svg", "png", "pdf"}:
- return cast(
- Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"],
- normalized_format,
- )
+ return cast(_RenderFormat, normalized_format)
raise ValueError(
"Unsupported render format "
f"'{raw_format}'. Expected 'tikz', 'dot', 'mermaid', 'svg', 'png', or 'pdf'."
)
+def _resolve_render_label_options(payload: JsonDict) -> _RenderLabelOptions:
+ """Return shared render-label visibility flags for one request payload."""
+ return _RenderLabelOptions(
+ show_tensor_labels=require_boolean(payload, "show_tensor_names", default=True),
+ show_index_labels=require_boolean(payload, "show_index_names", default=True),
+ show_edge_labels=require_boolean(payload, "show_bond_names", default=True),
+ )
+
+
+def _resolve_render_theme(payload: JsonDict) -> EditorThemeName:
+ """Return the editor theme requested for one render payload."""
+ raw_theme = payload.get("theme")
+ if raw_theme is None:
+ return DEFAULT_EDITOR_THEME
+ if not isinstance(raw_theme, str):
+ raise ValueError("'theme' must be a string when provided.")
+ return normalize_editor_theme(raw_theme)
+
+
+def _svg_render_options(
+ label_options: _RenderLabelOptions,
+ *,
+ render_format: _RenderFormat,
+ theme: EditorThemeName,
+) -> SvgRenderOptions:
+ """Return SVG/PNG/PDF render options derived from shared label flags."""
+ return SvgRenderOptions(
+ show_tensor_labels=label_options.show_tensor_labels,
+ show_index_labels=label_options.show_index_labels,
+ show_edge_labels=label_options.show_edge_labels,
+ transparent_background=render_format in {"svg", "png"},
+ **_IMAGE_EXPORT_THEME_OVERRIDES[theme],
+ )
+
+
+def _tikz_render_options(label_options: _RenderLabelOptions) -> TikzRenderOptions:
+ """Return TikZ render options derived from shared label flags."""
+ return TikzRenderOptions(
+ show_tensor_labels=label_options.show_tensor_labels,
+ show_index_labels=label_options.show_index_labels,
+ show_edge_labels=label_options.show_edge_labels,
+ )
+
+
+def _dot_render_options(label_options: _RenderLabelOptions) -> DotRenderOptions:
+ """Return DOT/Mermaid render options derived from shared label flags."""
+ return DotRenderOptions(
+ show_tensor_labels=label_options.show_tensor_labels,
+ show_index_labels=label_options.show_index_labels,
+ show_edge_labels=label_options.show_edge_labels,
+ )
+
+
+def _build_text_render_response(
+ render_format: _RenderFormat,
+ text: str,
+ *,
+ content_type: str,
+) -> JsonDict:
+ """Return one text-based render response payload."""
+ return {
+ "format": render_format,
+ "text": text,
+ "content_type": content_type,
+ }
+
+
+def _build_binary_render_response(
+ render_format: _RenderFormat,
+ binary: bytes,
+ *,
+ content_type: str,
+) -> JsonDict:
+ """Return one binary render response payload encoded for JSON transport."""
+ return {
+ "format": render_format,
+ "base64": base64.b64encode(binary).decode("ascii"),
+ "content_type": content_type,
+ }
+
+
+def _build_render_response(
+ render_format: _RenderFormat,
+ spec: NetworkSpec,
+ label_options: _RenderLabelOptions,
+ *,
+ theme: EditorThemeName = DEFAULT_EDITOR_THEME,
+) -> JsonDict:
+ """Return the serialized academic render payload for one format request."""
+ if render_format == "tikz":
+ return _build_text_render_response(
+ render_format,
+ render_spec_tikz(spec, options=_tikz_render_options(label_options)),
+ content_type="text/x-tex;charset=utf-8",
+ )
+ if render_format == "dot":
+ return _build_text_render_response(
+ render_format,
+ render_spec_dot(spec, options=_dot_render_options(label_options)),
+ content_type="text/vnd.graphviz;charset=utf-8",
+ )
+ if render_format == "mermaid":
+ return _build_text_render_response(
+ render_format,
+ render_spec_mermaid(spec, options=_dot_render_options(label_options)),
+ content_type="text/plain;charset=utf-8",
+ )
+ if render_format == "svg":
+ return _build_text_render_response(
+ render_format,
+ render_spec_svg(
+ spec,
+ options=_svg_render_options(
+ label_options,
+ render_format=render_format,
+ theme=theme,
+ ),
+ ),
+ content_type="image/svg+xml;charset=utf-8",
+ )
+ if render_format == "png":
+ return _build_binary_render_response(
+ render_format,
+ render_spec_png(
+ spec,
+ options=_svg_render_options(
+ label_options,
+ render_format=render_format,
+ theme=theme,
+ ),
+ ),
+ content_type="image/png",
+ )
+ return _build_binary_render_response(
+ render_format,
+ render_spec_pdf(
+ spec,
+ options=_svg_render_options(
+ label_options,
+ render_format=render_format,
+ theme=theme,
+ ),
+ ),
+ content_type="application/pdf",
+ )
+
+
def _serialize_complete_result(result: EditorResult) -> JsonDict:
"""Serialize a complete-route editor result."""
return serialize_editor_result(result)
@@ -833,6 +985,7 @@ def _handle_session_codegen_request(
request.serialized_spec,
request.engine,
request.collection_format,
+ request.include_roundtrip_metadata,
)
return ok_response(_serialize_generate_result(generate_result))
if operation == "complete":
@@ -840,11 +993,14 @@ def _handle_session_codegen_request(
request.serialized_spec,
request.engine,
request.collection_format,
+ request.include_roundtrip_metadata,
)
return ok_response(_serialize_complete_result(complete_result))
raise ValueError(f"Unsupported code generation operation '{operation}'.")
except SerializationError as exc:
return bad_request_response(str(exc))
+ except ValueError as exc:
+ return bad_request_response(str(exc))
except CodeGenerationError as exc:
return bad_request_response(str(exc))
except PackageIOError as exc:
diff --git a/src/tensor_network_editor/app/server.py b/src/tensor_network_editor/app/server.py
index c0018e3..87a5ab7 100644
--- a/src/tensor_network_editor/app/server.py
+++ b/src/tensor_network_editor/app/server.py
@@ -2,10 +2,14 @@
from __future__ import annotations
+import hmac
+import ipaddress
import json
import logging
import mimetypes
+import secrets
import threading
+import time
from collections.abc import Callable
from dataclasses import dataclass
from http import HTTPStatus
@@ -14,6 +18,7 @@
from pathlib import Path
from typing import Protocol, TypeAlias, cast
from urllib.parse import urlparse
+from urllib.request import urlopen
from ..internal._logging import (
bind_log_context,
@@ -27,17 +32,25 @@
JsonDict,
JsonResponse,
bad_request_response,
+ forbidden_response,
internal_server_error_response,
not_found_response,
read_json,
+ unsupported_media_type_response,
)
from .session import EditorSession
LOGGER = logging.getLogger(__name__)
_SERVE_FOREVER_POLL_INTERVAL_SECONDS: float = 0.05
+_STARTUP_READY_TIMEOUT_SECONDS: float = 5.0
+_STARTUP_READY_POLL_INTERVAL_SECONDS: float = 0.01
+_STARTUP_READY_REQUEST_TIMEOUT_SECONDS: float = 0.2
+_RESPONSE_WRITE_CHUNK_SIZE_BYTES: int = 64 * 1024
+_STATIC_ASSET_CACHE_VALIDATION_INTERVAL_SECONDS: float = 0.5
_MAX_REQUEST_BODY_BYTES: int = 1_048_576
_STATIC_ASSET_CACHE_LOCK = threading.Lock()
_STATIC_ASSET_CACHE_BY_ROOT: dict[Path, _StaticAssetCache] = {}
+_STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT: dict[Path, float] = {}
_UNEXPECTED_INTERNAL_ERROR_MESSAGE = "Unexpected internal error."
_UNEXPECTED_INTERNAL_ERROR_GUIDANCE = (
"Try again. If the problem continues, check the terminal output for this "
@@ -46,6 +59,13 @@
_QUIET_MISSING_STATIC_ASSET_PATHS: frozenset[str] = frozenset({"/favicon.ico"})
_ScannedStaticAssetFile: TypeAlias = tuple[Path, str, int, int]
_RUNTIME_CONFIG_PLACEHOLDER = "__TNE_RUNTIME_CONFIG__"
+_CSP_NONCE_PLACEHOLDER = "__TNE_CSP_NONCE__"
+_API_TOKEN_HEADER = "X-TNE-Session-Token" # noqa: S105, RUF100 - header name.
+_EXPECTED_JSON_CONTENT_TYPE = "application/json"
+_PERMISSIONS_POLICY_HEADER = (
+ "accelerometer=(), camera=(), geolocation=(), gyroscope=(), "
+ "magnetometer=(), microphone=(), payment=(), usb=()"
+)
class SupportsReadBytes(Protocol):
@@ -86,6 +106,75 @@ def _read_request_body_bytes(reader: SupportsReadBytes, content_length: int) ->
return b"".join(chunks)
+def _is_loopback_host_name(host_name: str) -> bool:
+ """Return whether a hostname literal is safe for local-only editor serving."""
+ normalized_host = host_name.strip().strip("[]").rstrip(".").lower()
+ if normalized_host in {"localhost"} or normalized_host.endswith(".localhost"):
+ return True
+ if "%" in normalized_host:
+ normalized_host = normalized_host.split("%", 1)[0]
+ try:
+ address = ipaddress.ip_address(normalized_host)
+ except ValueError:
+ return False
+ return address.is_loopback
+
+
+def _validate_bind_host(host: str, *, allow_remote: bool) -> None:
+ """Reject non-loopback bind hosts unless remote serving is explicit."""
+ if allow_remote or _is_loopback_host_name(host):
+ return
+ raise ValueError(
+ "Refusing to bind the editor server to a non-loopback host. "
+ "Use allow_remote=True only when you intentionally expose this local API."
+ )
+
+
+def _host_name_from_header(host_header: str | None) -> str | None:
+ """Extract the hostname portion from one HTTP Host header."""
+ if host_header is None:
+ return None
+ value = host_header.strip()
+ if not value:
+ return None
+ if value.startswith("["):
+ end_index = value.find("]")
+ if end_index <= 1:
+ return None
+ return value[1:end_index]
+ if value.count(":") == 1:
+ host_name, port_text = value.rsplit(":", 1)
+ if port_text.isdigit():
+ return host_name
+ return value
+
+
+def _is_trusted_host_header(host_header: str | None, *, allow_remote: bool) -> bool:
+ """Return whether one Host header is acceptable for this server."""
+ if allow_remote:
+ return bool(host_header and host_header.strip())
+ host_name = _host_name_from_header(host_header)
+ return host_name is not None and _is_loopback_host_name(host_name)
+
+
+def _is_trusted_origin_header(origin_header: str | None, *, allow_remote: bool) -> bool:
+ """Return whether one optional Origin header is acceptable for API writes."""
+ if origin_header is None:
+ return True
+ parsed_origin = urlparse(origin_header)
+ if parsed_origin.scheme not in {"http", "https"}:
+ return False
+ return _is_trusted_host_header(parsed_origin.netloc, allow_remote=allow_remote)
+
+
+def _is_json_content_type(content_type: str | None) -> bool:
+ """Return whether one Content-Type header identifies a JSON request body."""
+ if content_type is None:
+ return False
+ media_type = content_type.split(";", 1)[0].strip().lower()
+ return media_type == _EXPECTED_JSON_CONTENT_TYPE
+
+
@dataclass(slots=True, frozen=True)
class _BinaryResponse:
"""Internal response container for pre-encoded bytes."""
@@ -198,11 +287,14 @@ def _build_static_asset_cache(
def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache:
"""Return a shared static asset cache for one editor static directory."""
resolved_static_dir = static_dir.resolve()
- scanned_files = _scan_static_asset_files(resolved_static_dir)
- current_signature = _build_static_asset_source_signature(scanned_files)
with _STATIC_ASSET_CACHE_LOCK:
+ validation_started_at = time.monotonic()
cache = _STATIC_ASSET_CACHE_BY_ROOT.get(resolved_static_dir)
+ last_validated_at = _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.get(
+ resolved_static_dir
+ )
if cache is None:
+ scanned_files = _scan_static_asset_files(resolved_static_dir)
with log_operation(
LOGGER,
"Static asset cache build",
@@ -213,9 +305,29 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache:
scanned_files=scanned_files,
)
_STATIC_ASSET_CACHE_BY_ROOT[resolved_static_dir] = cache
+ _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = (
+ validation_started_at
+ )
success_context["after"] = cache.asset_version
success_context["asset_count"] = len(cache.body_by_relative_path)
return cache
+ if (
+ last_validated_at is not None
+ and validation_started_at - last_validated_at
+ < _STATIC_ASSET_CACHE_VALIDATION_INTERVAL_SECONDS
+ ):
+ log_branch(
+ LOGGER,
+ "Static asset cache reused",
+ context={
+ "path": resolved_static_dir,
+ "after": cache.asset_version,
+ "asset_count": len(cache.body_by_relative_path),
+ },
+ )
+ return cache
+ scanned_files = _scan_static_asset_files(resolved_static_dir)
+ current_signature = _build_static_asset_source_signature(scanned_files)
if cache.source_signature != current_signature:
with log_operation(
LOGGER,
@@ -230,11 +342,17 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache:
scanned_files=scanned_files,
)
_STATIC_ASSET_CACHE_BY_ROOT[resolved_static_dir] = refreshed_cache
+ _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = (
+ validation_started_at
+ )
success_context["after"] = refreshed_cache.asset_version
success_context["asset_count"] = len(
refreshed_cache.body_by_relative_path
)
return refreshed_cache
+ _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = (
+ validation_started_at
+ )
log_branch(
LOGGER,
"Static asset cache reused",
@@ -247,29 +365,63 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache:
return cache
-def _build_frontend_runtime_config_payload(session: EditorSession) -> JsonDict:
+def _build_frontend_runtime_config_payload(
+ session: EditorSession, *, api_token: str
+) -> JsonDict:
"""Return the runtime configuration embedded into the editor HTML page."""
return {
"session_id": session.session_id,
+ "api_token": api_token,
"frontend_logging": build_frontend_logging_payload(session),
}
-def _serialize_frontend_runtime_config(session: EditorSession) -> str:
+def _serialize_frontend_runtime_config(
+ session: EditorSession, *, api_token: str
+) -> str:
"""Serialize one session runtime config safely for an inline JSON script."""
- return json.dumps(_build_frontend_runtime_config_payload(session)).replace(
- "", "<\\/"
- )
+ return json.dumps(
+ _build_frontend_runtime_config_payload(session, api_token=api_token)
+ ).replace("", "<\\/")
-def _render_session_index_body(index_body: bytes, session: EditorSession) -> bytes:
+def _render_session_index_body(
+ index_body: bytes,
+ session: EditorSession,
+ *,
+ api_token: str,
+ csp_nonce: str,
+) -> bytes:
"""Return the per-session editor HTML body with embedded runtime config."""
return index_body.replace(
_RUNTIME_CONFIG_PLACEHOLDER.encode("utf-8"),
- _serialize_frontend_runtime_config(session).encode("utf-8"),
+ _serialize_frontend_runtime_config(session, api_token=api_token).encode(
+ "utf-8"
+ ),
+ ).replace(
+ _CSP_NONCE_PLACEHOLDER.encode("utf-8"),
+ csp_nonce.encode("utf-8"),
)
+def _build_content_security_policy(*, csp_nonce: str) -> str:
+ """Return the editor CSP that permits only trusted local assets."""
+ directives = [
+ "default-src 'self'",
+ "base-uri 'none'",
+ "object-src 'none'",
+ "frame-ancestors 'none'",
+ "form-action 'none'",
+ "connect-src 'self'",
+ "img-src 'self' data: blob:",
+ f"script-src 'self' 'nonce-{csp_nonce}'",
+ "style-src 'self' 'unsafe-inline'",
+ "font-src 'self' data:",
+ "worker-src 'self' blob:",
+ ]
+ return "; ".join(directives)
+
+
def _unexpected_internal_error_response(session_id: str) -> JsonResponse:
"""Return an actionable but safe error payload for unexpected failures."""
return internal_server_error_response(
@@ -288,7 +440,13 @@ class EditorServer:
"""Serve the browser app and JSON API for one editor session."""
def __init__(
- self, session: EditorSession, host: str = "127.0.0.1", port: int = 0
+ self,
+ session: EditorSession,
+ host: str = "127.0.0.1",
+ port: int = 0,
+ *,
+ allow_remote: bool = False,
+ api_token: str | None = None,
) -> None:
"""Initialize the threaded local editor server.
@@ -296,19 +454,33 @@ def __init__(
session: Shared editor session state served by this HTTP server.
host: Local host interface to bind.
port: Local port to bind. Use ``0`` for an ephemeral port.
+ allow_remote: Whether non-loopback bind hosts are allowed.
+ api_token: Optional pre-generated API token for tests.
"""
+ _validate_bind_host(host, allow_remote=allow_remote)
self.session = session
self.session_id = session.session_id
self.host = host
self.port = port
+ self.allow_remote = allow_remote
+ self.api_token = api_token or secrets.token_urlsafe(32)
+ if not self.api_token.strip():
+ raise ValueError("Editor API token cannot be empty.")
+ self._csp_nonce = secrets.token_urlsafe(16)
+ self._content_security_policy = _build_content_security_policy(
+ csp_nonce=self._csp_nonce
+ )
self._static_dir = Path(__file__).resolve().parent / "static"
self._static_asset_cache = _get_static_asset_cache(self._static_dir)
self._index_body = _render_session_index_body(
self._static_asset_cache.index_body,
session,
+ api_token=self.api_token,
+ csp_nonce=self._csp_nonce,
)
self._server = ThreadingHTTPServer((host, port), self._build_handler())
self._thread = threading.Thread(target=self._serve_forever, daemon=True)
+ self._serve_forever_ready = threading.Event()
@property
def base_url(self) -> str:
@@ -322,6 +494,11 @@ def base_url(self) -> str:
def start(self) -> None:
"""Start serving requests in a background thread."""
self._thread.start()
+ try:
+ self._wait_until_ready()
+ except Exception:
+ self._cleanup_failed_start()
+ raise
log_branch(
LOGGER,
f"Editor server started at {self.base_url}",
@@ -331,9 +508,7 @@ def start(self) -> None:
def stop(self) -> None:
"""Stop the server and wait for the worker thread to exit."""
- self._server.shutdown()
- self._server.server_close()
- self._thread.join(timeout=5)
+ self._stop_server_worker()
log_branch(
LOGGER,
"Editor server stopped",
@@ -343,8 +518,61 @@ def stop(self) -> None:
def _serve_forever(self) -> None:
"""Serve requests with a short shutdown polling interval."""
+ self._serve_forever_ready.set()
self._server.serve_forever(poll_interval=_SERVE_FOREVER_POLL_INTERVAL_SECONDS)
+ def _wait_until_ready(self) -> None:
+ """Block until loopback requests can read one fully served asset."""
+ deadline = time.monotonic() + _STARTUP_READY_TIMEOUT_SECONDS
+ if not self._serve_forever_ready.wait(timeout=_STARTUP_READY_TIMEOUT_SECONDS):
+ raise RuntimeError(
+ "Editor server did not enter the serving loop before the startup timeout elapsed."
+ )
+
+ last_error: OSError | None = None
+ while True:
+ remaining_seconds = deadline - time.monotonic()
+ if remaining_seconds <= 0:
+ break
+ request_timeout_seconds = min(
+ _STARTUP_READY_REQUEST_TIMEOUT_SECONDS,
+ remaining_seconds,
+ )
+ try:
+ self._probe_loopback_readiness(request_timeout_seconds)
+ except OSError as exc:
+ last_error = exc
+ time.sleep(min(_STARTUP_READY_POLL_INTERVAL_SECONDS, remaining_seconds))
+ continue
+ return
+
+ if last_error is None:
+ raise RuntimeError(
+ "Editor server readiness probe timed out before any loopback request succeeded."
+ )
+ raise RuntimeError(
+ "Editor server did not become ready to serve loopback requests before the startup timeout elapsed."
+ ) from last_error
+
+ def _probe_loopback_readiness(self, timeout_seconds: float) -> None:
+ """Read one small static asset to verify the server serves full responses."""
+ with urlopen( # noqa: S310, RUF100 - probes this loopback server.
+ f"{self.base_url}/favicon.ico", timeout=timeout_seconds
+ ) as response:
+ response.read()
+
+ def _stop_server_worker(self) -> None:
+ """Best-effort shutdown that is safe before the serve loop starts."""
+ if self._thread.is_alive() and self._serve_forever_ready.is_set():
+ self._server.shutdown()
+ self._server.server_close()
+ if self._thread.ident is not None:
+ self._thread.join(timeout=5)
+
+ def _cleanup_failed_start(self) -> None:
+ """Best-effort cleanup when startup fails after allocating the server socket."""
+ self._stop_server_worker()
+
def _build_handler(self) -> type[BaseHTTPRequestHandler]:
"""Build the request-handler class bound to this server instance."""
session = self.session
@@ -352,6 +580,9 @@ def _build_handler(self) -> type[BaseHTTPRequestHandler]:
static_dir = self._static_dir
static_asset_cache = self._static_asset_cache
index_body = self._index_body
+ api_token = self.api_token
+ allow_remote = self.allow_remote
+ content_security_policy = self._content_security_policy
def build_index_response() -> _BinaryResponse:
"""Return the cached main HTML page for this editor session."""
@@ -442,6 +673,10 @@ def do_GET(self) -> None:
"""Handle one HTTP GET request for assets or bootstrap data."""
parsed = urlparse(self.path)
with bind_log_context(session=session_id, route=parsed.path):
+ if self._reject_untrusted_host():
+ return
+ if self._reject_invalid_api_token(parsed.path):
+ return
try:
with log_operation(LOGGER, "Route request"):
response = self._dispatch_get(parsed.path)
@@ -458,6 +693,14 @@ def do_POST(self) -> None:
"""Handle one HTTP POST request for the editor JSON API."""
parsed = urlparse(self.path)
with bind_log_context(session=session_id, route=parsed.path):
+ if self._reject_untrusted_host():
+ return
+ if self._reject_untrusted_origin():
+ return
+ if self._reject_invalid_api_token(parsed.path):
+ return
+ if self._reject_unsupported_content_type():
+ return
try:
with log_operation(LOGGER, "Route request"):
try:
@@ -513,6 +756,85 @@ def _dispatch_post(self, path: str, payload: JsonDict) -> JsonResponse:
LOGGER.debug(format_log_message(f"Unknown POST path: {path}"))
return not_found_response()
+ def _reject_untrusted_host(self) -> bool:
+ """Write a forbidden response when the Host header is not local."""
+ if _is_trusted_host_header(
+ self.headers.get("Host"),
+ allow_remote=allow_remote,
+ ):
+ return False
+ LOGGER.warning(
+ format_log_message(
+ "Rejected request with untrusted Host header",
+ context={"host": self.headers.get("Host")},
+ ),
+ )
+ self._prepare_rejected_request_connection()
+ self._write_response(forbidden_response("Untrusted Host header."))
+ return True
+
+ def _reject_untrusted_origin(self) -> bool:
+ """Write a forbidden response when the Origin header is not local."""
+ if _is_trusted_origin_header(
+ self.headers.get("Origin"),
+ allow_remote=allow_remote,
+ ):
+ return False
+ LOGGER.warning(
+ format_log_message(
+ "Rejected request with untrusted Origin header",
+ context={"origin": self.headers.get("Origin")},
+ ),
+ )
+ self._prepare_rejected_request_connection()
+ self._write_response(forbidden_response("Untrusted Origin header."))
+ return True
+
+ def _reject_invalid_api_token(self, path: str) -> bool:
+ """Write a forbidden response when an API request lacks the token."""
+ if not path.startswith("/api/"):
+ return False
+ header_value = self.headers.get(_API_TOKEN_HEADER)
+ if header_value is not None and hmac.compare_digest(
+ header_value,
+ api_token,
+ ):
+ return False
+ LOGGER.warning(
+ format_log_message(
+ "Rejected API request with invalid session token"
+ ),
+ )
+ self._prepare_rejected_request_connection()
+ self._write_response(
+ forbidden_response("Invalid editor session token.")
+ )
+ return True
+
+ def _reject_unsupported_content_type(self) -> bool:
+ """Write an unsupported-media response for non-JSON API writes."""
+ if _is_json_content_type(self.headers.get("Content-Type")):
+ return False
+ LOGGER.warning(
+ format_log_message(
+ "Rejected API request with unsupported Content-Type",
+ context={"content_type": self.headers.get("Content-Type")},
+ ),
+ )
+ self._prepare_rejected_request_connection()
+ self._write_response(
+ unsupported_media_type_response(
+ "Expected Content-Type 'application/json'."
+ )
+ )
+ return True
+
+ def _prepare_rejected_request_connection(self) -> None:
+ """Drain rejected POST bodies before closing the connection."""
+ if self.command == "POST":
+ self._drain_pending_request_body()
+ self.close_connection = True
+
def _static_response(
self, request_path: str
) -> JsonResponse | _BinaryResponse:
@@ -594,11 +916,23 @@ def _write_bytes(self, status: int, body: bytes, content_type: str) -> None:
self.send_response(status)
self.send_header("Content-Type", content_type)
self.send_header("Content-Length", str(len(body)))
+ self.send_header("X-Content-Type-Options", "nosniff")
+ self.send_header("Referrer-Policy", "no-referrer")
+ self.send_header("X-Frame-Options", "DENY")
+ self.send_header("Content-Security-Policy", content_security_policy)
+ self.send_header("Permissions-Policy", _PERMISSIONS_POLICY_HEADER)
+ self.send_header("Cross-Origin-Resource-Policy", "same-origin")
if self.close_connection:
self.send_header("Connection", "close")
self._write_no_cache_headers()
self.end_headers()
- self.wfile.write(body)
+ body_view = memoryview(body)
+ for offset in range(
+ 0, len(body_view), _RESPONSE_WRITE_CHUNK_SIZE_BYTES
+ ):
+ next_offset = offset + _RESPONSE_WRITE_CHUNK_SIZE_BYTES
+ self.wfile.write(body_view[offset:next_offset])
+ self.wfile.flush()
def _write_no_cache_headers(self) -> None:
"""Emit headers that disable browser and intermediary caching."""
diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py
index 37f002f..8159ba3 100644
--- a/src/tensor_network_editor/app/session.py
+++ b/src/tensor_network_editor/app/session.py
@@ -6,7 +6,9 @@
import signal
import threading
import webbrowser
+from base64 import b64decode
from collections.abc import Callable, Mapping, Sequence
+from importlib import import_module
from pathlib import Path
from types import FrameType
from typing import Any, Literal
@@ -47,6 +49,7 @@
LOGGER = logging.getLogger(__name__)
SignalHandler = Callable[[int, FrameType | None], Any]
+SessionUiMode = Literal["browser", "pywebview", "server"]
def _print_editor_url(base_url: str) -> None:
@@ -64,6 +67,178 @@ def _print_browser_open_fallback_message(base_url: str) -> None:
_print_editor_url(base_url)
+def _import_pywebview() -> Any:
+ """Import the optional pywebview module on demand."""
+ try:
+ return import_module("webview")
+ except ModuleNotFoundError as exc:
+ raise RuntimeError(
+ "pywebview mode requires the optional desktop extra. Install it with "
+ 'python -m pip install "tensor-network-editor[desktop]".'
+ ) from exc
+
+
+def _resolve_pywebview_icon_path() -> Path:
+ """Return the packaged desktop icon used for the native pywebview window."""
+ return Path(__file__).resolve().parent / "static" / "favicon.ico"
+
+
+def _apply_pywebview_native_window_icon(window: Any) -> None:
+ """Apply the packaged icon to the native pywebview window when supported."""
+ icon_path = _resolve_pywebview_icon_path()
+ if not icon_path.is_file():
+ return
+ native_window = getattr(window, "native", None)
+ if native_window is None:
+ return
+ try:
+ from System.Drawing import Icon as DrawingIcon # type: ignore[import-not-found]
+ except Exception:
+ return
+ try:
+ native_window.Icon = DrawingIcon(str(icon_path))
+ if hasattr(native_window, "ShowIcon"):
+ native_window.ShowIcon = True
+ except Exception as exc:
+ log_branch(
+ LOGGER,
+ "Could not apply the native pywebview window icon",
+ level=logging.WARNING,
+ context={
+ "icon_path": str(icon_path),
+ "error": str(exc),
+ },
+ )
+
+
+class _PywebviewExportApi:
+ """Expose native save-file helpers to the embedded pywebview frontend."""
+
+ def __init__(self, pywebview_module: Any) -> None:
+ """Store the imported pywebview module for later dialog calls."""
+ self._pywebview_module = pywebview_module
+ self._window: Any | None = None
+
+ def bind_window(self, window: Any) -> None:
+ """Attach the created pywebview window once it exists."""
+ self._window = window
+
+ def save_text_file(
+ self,
+ filename: str,
+ text: str,
+ content_type: str = "text/plain;charset=utf-8",
+ ) -> bool:
+ """Prompt for a path and write one UTF-8 text file."""
+ del content_type
+ output_path = self._select_output_path(filename)
+ if output_path is None:
+ return False
+ output_path.write_text(text, encoding="utf-8")
+ return True
+
+ def save_binary_file(
+ self,
+ filename: str,
+ base64_payload: str,
+ content_type: str = "application/octet-stream",
+ ) -> bool:
+ """Prompt for a path and write one decoded binary export file."""
+ del content_type
+ output_path = self._select_output_path(filename)
+ if output_path is None:
+ return False
+ output_path.write_bytes(b64decode(base64_payload))
+ return True
+
+ def _select_output_path(self, filename: str) -> Path | None:
+ """Ask pywebview for a target save path and normalize the response."""
+ if self._window is None:
+ raise RuntimeError("pywebview export API is not bound to a window.")
+ dialog_result = self._window.create_file_dialog(
+ self._pywebview_module.SAVE_DIALOG,
+ save_filename=filename,
+ file_types=self._build_file_types(filename),
+ )
+ if dialog_result is None:
+ return None
+ if isinstance(dialog_result, str):
+ return Path(dialog_result)
+ if isinstance(dialog_result, Sequence) and dialog_result:
+ first_entry = dialog_result[0]
+ if isinstance(first_entry, str) and first_entry:
+ return Path(first_entry)
+ return None
+
+ def _build_file_types(self, filename: str) -> tuple[str, ...]:
+ """Build a compact pywebview filter tuple from one filename."""
+ suffix = Path(filename).suffix.lower()
+ if not suffix:
+ return ()
+ label = {
+ ".dot": "DOT",
+ ".json": "JSON",
+ ".mmd": "Mermaid",
+ ".pdf": "PDF",
+ ".png": "PNG",
+ ".py": "Python",
+ ".svg": "SVG",
+ ".tex": "LaTeX",
+ }.get(suffix, suffix.removeprefix(".").upper())
+ return (f"{label} (*{suffix})",)
+
+
+def _run_pywebview_session(
+ session: EditorSession, base_url: str
+) -> EditorResult | None:
+ """Open the local editor in a pywebview window and wait for the result."""
+ if threading.current_thread() is not threading.main_thread():
+ raise RuntimeError("pywebview mode must be launched from the main thread.")
+
+ try:
+ pywebview = _import_pywebview()
+ except ModuleNotFoundError as exc:
+ raise RuntimeError(
+ "pywebview mode requires the optional desktop extra. Install it with "
+ 'python -m pip install "tensor-network-editor[desktop]".'
+ ) from exc
+ pywebview_export_api = _PywebviewExportApi(pywebview)
+ pywebview_window = pywebview.create_window(
+ "Tensor Network Editor",
+ base_url,
+ maximized=True,
+ js_api=pywebview_export_api,
+ )
+ pywebview_export_api.bind_window(pywebview_window)
+ window_events = getattr(pywebview_window, "events", None)
+ before_show_event = getattr(window_events, "before_show", None)
+ if before_show_event is not None:
+ before_show_event += lambda: _apply_pywebview_native_window_icon(
+ pywebview_window
+ )
+ else:
+ _apply_pywebview_native_window_icon(pywebview_window)
+
+ def _handle_window_closed(*_args: object) -> None:
+ """Cancel the editor session when the native window is closed."""
+ if not session.is_finished():
+ session.cancel()
+
+ def _wait_for_session_and_close_window(window: Any) -> None:
+ """Close the native window after the editor session finishes."""
+ wait_for_editor_result(session)
+ try:
+ window.destroy()
+ except Exception:
+ return None
+
+ closed_event = getattr(window_events, "closed", None)
+ if closed_event is not None:
+ closed_event += _handle_window_closed
+ pywebview.start(_wait_for_session_and_close_window, pywebview_window)
+ return wait_for_editor_result(session)
+
+
class EditorSession:
"""Mutable session state shared between the HTTP server and the caller."""
@@ -290,6 +465,7 @@ def generate(
serialized_spec: Mapping[str, object],
engine: EngineIdentifier,
collection_format: TensorCollectionFormat | None = None,
+ include_roundtrip_metadata: bool = True,
) -> CodegenResult:
"""Generate preview code without finalizing the session."""
with log_operation(
@@ -306,6 +482,7 @@ def generate(
serialized_spec,
engine,
collection_format,
+ include_roundtrip_metadata,
)
def complete(
@@ -313,6 +490,7 @@ def complete(
serialized_spec: Mapping[str, object],
engine: EngineIdentifier,
collection_format: TensorCollectionFormat | None = None,
+ include_roundtrip_metadata: bool = True,
) -> EditorResult:
"""Finalize the session and store the resulting editor output."""
with log_operation(
@@ -335,6 +513,7 @@ def complete(
serialized_spec,
engine,
collection_format,
+ include_roundtrip_metadata,
)
with self._lock:
if self._finished_event.is_set() and self._result is not None:
@@ -390,8 +569,10 @@ def launch_editor_session(
default_engine: EngineIdentifier = EngineName.TENSORKROWCH,
default_collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST,
theme: EditorThemeName = DEFAULT_EDITOR_THEME,
+ ui_mode: SessionUiMode | None = None,
open_browser: bool = True,
host: str = "127.0.0.1",
+ allow_remote: bool = False,
port: int = 0,
print_code: bool = False,
code_path: StrPath | None = None,
@@ -412,8 +593,10 @@ def launch_editor_session(
default_collection_format: Initial tensor collection layout for
generated code.
theme: Visual theme selected for this editor session.
+ ui_mode: Explicit UI launch mode for the editor session.
open_browser: Whether to ask the system browser to open the local URL.
host: Local host interface to bind.
+ allow_remote: Whether non-loopback bind hosts are allowed.
port: Local port to bind. Use ``0`` for an ephemeral port.
print_code: Whether to print generated code after confirmation.
code_path: Optional output path for generated code after confirmation.
@@ -437,6 +620,7 @@ def launch_editor_session(
Raises:
KeyboardInterrupt: If the session is interrupted from the main thread.
"""
+ from ..editor import resolve_editor_ui_mode
from .server import EditorServer
active_logging_runtime = get_active_logging_runtime()
@@ -467,7 +651,21 @@ def launch_editor_session(
shared_subnetwork_catalog_path=shared_subnetwork_catalog_path,
draft_path=draft_path,
)
- server = EditorServer(session=session, host=host, port=port)
+ server = EditorServer(
+ session=session,
+ host=host,
+ port=port,
+ allow_remote=allow_remote,
+ )
+ effective_ui_mode = resolve_editor_ui_mode(
+ ui_mode=ui_mode,
+ open_browser=open_browser,
+ )
+ if (
+ effective_ui_mode == "pywebview"
+ and threading.current_thread() is not threading.main_thread()
+ ):
+ raise RuntimeError("pywebview mode must be launched from the main thread.")
previous_sigint_handler: SignalHandler | int | None = None
server_started = False
@@ -481,6 +679,7 @@ def launch_editor_session(
"session": session.session_id,
"engine": engine_name_to_text(default_engine),
"mode": theme,
+ "ui_mode": effective_ui_mode,
},
):
if threading.current_thread() is threading.main_thread():
@@ -497,9 +696,11 @@ def _handle_sigint(_signum: int, _frame: FrameType | None) -> None:
server_started = True
if _on_server_ready is not None:
_on_server_ready(server.base_url)
- should_print_editor_url = not open_browser
+ if effective_ui_mode == "pywebview":
+ return _run_pywebview_session(session, server.base_url)
+ should_print_editor_url = effective_ui_mode == "server"
should_print_browser_fallback_message = False
- if open_browser:
+ if effective_ui_mode == "browser":
try:
with log_operation(
LOGGER,
diff --git a/src/tensor_network_editor/app/static/app.css b/src/tensor_network_editor/app/static/app.css
index 6d294f5..ca7a1ee 100644
--- a/src/tensor_network_editor/app/static/app.css
+++ b/src/tensor_network_editor/app/static/app.css
@@ -3369,6 +3369,31 @@ textarea[disabled] {
align-items: center;
}
+.code-metadata-toggle {
+ display: inline-flex;
+ align-items: center;
+ gap: 0.45rem;
+ min-height: var(--canvas-control-height);
+ padding: 0 0.8rem;
+ border: 1px solid var(--border-subtle);
+ border-radius: 999px;
+ background: var(--surface-subtle);
+ color: var(--muted);
+ font-size: 0.78rem;
+ font-weight: 600;
+ line-height: 1.1;
+ cursor: pointer;
+ user-select: none;
+}
+
+.code-metadata-toggle input {
+ margin: 0;
+}
+
+.code-metadata-toggle[hidden] {
+ display: none;
+}
+
.code-format-picker {
position: relative;
display: inline-flex;
diff --git a/src/tensor_network_editor/app/static/index.html b/src/tensor_network_editor/app/static/index.html
index 610db92..7ebb159 100644
--- a/src/tensor_network_editor/app/static/index.html
+++ b/src/tensor_network_editor/app/static/index.html
@@ -788,6 +788,30 @@ Generated code
-
-
+ ',
+ re.DOTALL,
+)
+_SESSION_TOKEN_BY_ORIGIN: dict[str, str | None] = {}
+
def request_json(
url: str,
@@ -28,6 +40,8 @@ def request_json_with_status(
method: str = "GET",
payload: dict[str, Any] | None = None,
raw_body: bytes | None = None,
+ session_token: str | None = None,
+ include_session_token: bool = True,
timeout: float = 5.0,
) -> tuple[int, dict[str, Any]]:
data = None
@@ -40,6 +54,12 @@ def request_json_with_status(
elif raw_body is not None:
data = raw_body
headers["Content-Type"] = "application/json"
+ if include_session_token:
+ resolved_session_token = (
+ session_token if session_token is not None else _session_token_for_url(url)
+ )
+ if resolved_session_token:
+ headers["X-TNE-Session-Token"] = resolved_session_token
request = Request(url=url, method=method, data=data, headers=headers)
try:
with urlopen(request, timeout=timeout) as response:
@@ -48,13 +68,86 @@ def request_json_with_status(
return exc.code, json.loads(exc.read().decode("utf-8"))
+def _session_token_for_url(url: str) -> str | None:
+ """Read the embedded editor API token for a local test server URL."""
+ origin = _origin_for_url(url)
+ if origin is None:
+ return None
+ if origin in _SESSION_TOKEN_BY_ORIGIN:
+ return _SESSION_TOKEN_BY_ORIGIN[origin]
+ try:
+ with urlopen(f"{origin}/", timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response:
+ html = response.read().decode("utf-8")
+ except OSError:
+ _SESSION_TOKEN_BY_ORIGIN[origin] = None
+ return None
+ match = _RUNTIME_CONFIG_RE.search(html)
+ if match is None:
+ _SESSION_TOKEN_BY_ORIGIN[origin] = None
+ return None
+ try:
+ payload = json.loads(match.group(1))
+ except json.JSONDecodeError:
+ _SESSION_TOKEN_BY_ORIGIN[origin] = None
+ return None
+ token = payload.get("api_token") if isinstance(payload, dict) else None
+ _SESSION_TOKEN_BY_ORIGIN[origin] = token if isinstance(token, str) else None
+ return _SESSION_TOKEN_BY_ORIGIN[origin]
+
+
+def _origin_for_url(url: str) -> str | None:
+ """Return the scheme/authority origin for an absolute URL."""
+ parsed = urlsplit(url)
+ if not parsed.scheme or not parsed.netloc:
+ return None
+ return f"{parsed.scheme}://{parsed.netloc}"
+
+
+def _read_asset_response(url: str) -> tuple[bytes, dict[str, str]]:
+ """Read one asset request with retries for transient local-server hiccups."""
+ last_error: OSError | None = None
+ for attempt_index in range(_ASSET_REQUEST_RETRY_COUNT):
+ try:
+ with urlopen(url, timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response:
+ body = response.read()
+ headers = {key: value for key, value in response.headers.items()}
+ return body, headers
+ except OSError as exc:
+ last_error = exc
+ if attempt_index + 1 >= _ASSET_REQUEST_RETRY_COUNT:
+ raise
+ time.sleep(_ASSET_REQUEST_RETRY_DELAY_SECONDS)
+ if last_error is not None:
+ raise last_error
+ raise RuntimeError("Asset request retry loop ended unexpectedly.")
+
+
def request_text(url: str) -> str:
- with urlopen(url, timeout=5) as response:
- return cast(str, response.read().decode("utf-8"))
+ body, _headers = _read_asset_response(url)
+ return cast(str, body.decode("utf-8"))
def request_with_headers(url: str) -> tuple[str, dict[str, str]]:
- with urlopen(url, timeout=5) as response:
- body = response.read().decode("utf-8")
- headers = {key: value for key, value in response.headers.items()}
- return body, headers
+ body, headers = _read_asset_response(url)
+ return body.decode("utf-8"), headers
+
+
+def request_headers(url: str) -> dict[str, str]:
+ last_error: OSError | None = None
+ for attempt_index in range(_ASSET_REQUEST_RETRY_COUNT):
+ try:
+ with urlopen(url, timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response:
+ return {key: value for key, value in response.headers.items()}
+ except OSError as exc:
+ last_error = exc
+ if attempt_index + 1 >= _ASSET_REQUEST_RETRY_COUNT:
+ raise
+ time.sleep(_ASSET_REQUEST_RETRY_DELAY_SECONDS)
+ if last_error is not None:
+ raise last_error
+ raise RuntimeError("Asset header request retry loop ended unexpectedly.")
+
+
+def request_bytes(url: str) -> bytes:
+ body, _headers = _read_asset_response(url)
+ return body
diff --git a/tests/codegen/test_common.py b/tests/codegen/test_common.py
index 5179f30..0dc4e62 100644
--- a/tests/codegen/test_common.py
+++ b/tests/codegen/test_common.py
@@ -16,6 +16,8 @@
)
from tensor_network_editor.models import (
CanvasPosition,
+ CodegenResult,
+ EngineName,
NetworkSpec,
TensorCollectionFormat,
TensorSpec,
@@ -158,3 +160,71 @@ def test_render_helper_function_lines_indents_rendered_sections() -> None:
)
assert helper_lines == ["def build_cell(slot_index: int) -> dict[str, object]:"]
+
+
+def test_dispatch_periodic_codegen_routes_supported_backends_and_roundtrip() -> None:
+ from tensor_network_editor.codegen.modes._periodic_codegen import (
+ dispatch_periodic_codegen,
+ )
+
+ seen_calls: list[tuple[str, str]] = []
+
+ def render_array(payload: str) -> CodegenResult:
+ seen_calls.append(("array", payload))
+ return CodegenResult(engine=EngineName.EINSUM_NUMPY, code="array_result = 1\n")
+
+ def render_graph(payload: str) -> CodegenResult:
+ seen_calls.append(("graph", payload))
+ return CodegenResult(engine=EngineName.TENSORNETWORK, code="graph_result = 1\n")
+
+ spec = build_three_tensor_hyperedge_spec()
+
+ array_result = dispatch_periodic_codegen(
+ spec=spec,
+ payload="array-payload",
+ missing_payload_message="missing payload",
+ unsupported_backend_label="periodic",
+ engine=EngineName.EINSUM_NUMPY,
+ include_roundtrip_metadata=True,
+ array_renderer=render_array,
+ graph_renderer=render_graph,
+ )
+ graph_result = dispatch_periodic_codegen(
+ spec=spec,
+ payload="graph-payload",
+ missing_payload_message="missing payload",
+ unsupported_backend_label="periodic",
+ engine=EngineName.TENSORNETWORK,
+ include_roundtrip_metadata=False,
+ array_renderer=render_array,
+ graph_renderer=render_graph,
+ )
+
+ assert seen_calls == [("array", "array-payload"), ("graph", "graph-payload")]
+ assert "# TNE_SPEC_B64:" in array_result.code
+ assert graph_result.code == "graph_result = 1\n"
+
+
+def test_dispatch_periodic_codegen_rejects_missing_payload() -> None:
+ from tensor_network_editor.codegen.modes._periodic_codegen import (
+ dispatch_periodic_codegen,
+ )
+ from tensor_network_editor.errors import CodeGenerationError
+
+ with pytest.raises(CodeGenerationError, match="grid payload"):
+ dispatch_periodic_codegen(
+ spec=NetworkSpec(name="missing payload"),
+ payload=None,
+ missing_payload_message="Grid periodic code generation requires a grid payload.",
+ unsupported_backend_label="grid periodic",
+ engine=EngineName.EINSUM_NUMPY,
+ include_roundtrip_metadata=False,
+ array_renderer=lambda payload: CodegenResult(
+ engine=EngineName.EINSUM_NUMPY,
+ code=f"{payload}\n",
+ ),
+ graph_renderer=lambda payload: CodegenResult(
+ engine=EngineName.TENSORNETWORK,
+ code=f"{payload}\n",
+ ),
+ )
diff --git a/tests/codegen/test_generators.py b/tests/codegen/test_generators.py
index 5876d2b..c3832d6 100644
--- a/tests/codegen/test_generators.py
+++ b/tests/codegen/test_generators.py
@@ -1,6 +1,8 @@
from __future__ import annotations
+import sys
from collections.abc import Callable
+from types import ModuleType, SimpleNamespace
from unittest.mock import patch
import pytest
@@ -13,6 +15,7 @@
from tensor_network_editor.errors import CodeGenerationError
from tensor_network_editor.models import (
CanvasPosition,
+ ContractionStepSpec,
EdgeEndpointRef,
EdgeSpec,
EngineName,
@@ -34,6 +37,7 @@
build_outer_product_plan_spec,
build_sample_spec,
build_sample_spec_without_plan,
+ build_three_tensor_complete_plan_spec,
build_three_tensor_hyperedge_spec,
build_three_tensor_spec,
build_three_tensor_spec_without_plan,
@@ -263,6 +267,511 @@ def _execute_generated_code(
return namespace
+class _FakeTensorKrowchEdge:
+ """Minimal edge object for generated-code regression tests."""
+
+ def __init__(
+ self,
+ node: _FakeTensorKrowchNode,
+ axis_name: str,
+ *,
+ origin: tuple[str, str] | None = None,
+ ) -> None:
+ self.node1 = node
+ self.axis1 = SimpleNamespace(name=axis_name)
+ self.node2: _FakeTensorKrowchNode | None = None
+ self.axis2: SimpleNamespace | None = None
+ self.origin = origin or (node.name, axis_name)
+
+ @classmethod
+ def from_endpoints(
+ cls,
+ *,
+ node1: _FakeTensorKrowchNode,
+ axis1_name: str,
+ node2: _FakeTensorKrowchNode | None = None,
+ axis2_name: str | None = None,
+ origin: tuple[str, str] | None = None,
+ ) -> _FakeTensorKrowchEdge:
+ """Build one edge with explicit endpoint ownership."""
+ edge = cls(node1, axis1_name, origin=origin)
+ if node2 is not None and axis2_name is not None:
+ edge.attach_second(node2, axis2_name)
+ return edge
+
+ def attach_second(
+ self,
+ node: _FakeTensorKrowchNode,
+ axis_name: str,
+ ) -> None:
+ self.node2 = node
+ self.axis2 = SimpleNamespace(name=axis_name)
+
+ def replace_endpoint(
+ self,
+ old_node: _FakeTensorKrowchNode,
+ new_node: _FakeTensorKrowchNode,
+ new_axis_name: str,
+ ) -> None:
+ if self.node1 is old_node:
+ self.node1 = new_node
+ self.axis1 = SimpleNamespace(name=new_axis_name)
+ return
+ if self.node2 is old_node:
+ self.node2 = new_node
+ self.axis2 = SimpleNamespace(name=new_axis_name)
+
+ def is_dangling(self) -> bool:
+ return self.node2 is None
+
+ def axis_name_for_node(
+ self,
+ node: _FakeTensorKrowchNode,
+ ) -> SimpleNamespace:
+ """Return the endpoint axis metadata for ``node``."""
+ if self.node1 is node:
+ return self.axis1
+ assert self.node2 is node
+ assert self.axis2 is not None
+ return self.axis2
+
+
+class _FakeTensorKrowchNode:
+ """Minimal node object for generated-code regression tests."""
+
+ def __init__(
+ self,
+ *,
+ tensor: object,
+ axes_names: tuple[str, ...],
+ name: str,
+ network: object,
+ ) -> None:
+ del tensor, network
+ self.name = name
+ self.edges_by_axis_name = {
+ axis_name: _FakeTensorKrowchEdge(self, axis_name)
+ for axis_name in axes_names
+ }
+ self.pending_edges_by_axis_name: dict[str, _FakeTensorKrowchEdge] = {}
+ self.pending_edge_owner_by_axis_name: dict[str, _FakeTensorKrowchNode] = {}
+ self.axis_is_node1_by_axis_name = {axis_name: True for axis_name in axes_names}
+
+ def __getitem__(self, axis_name: str) -> _FakeTensorKrowchEdge:
+ if axis_name in self.edges_by_axis_name:
+ return self.edges_by_axis_name[axis_name]
+ return self.pending_edges_by_axis_name[axis_name]
+
+ def reattach_edges(self, override: bool = False) -> None:
+ for axis_name, edge in list(self.pending_edges_by_axis_name.items()):
+ owner = self.pending_edge_owner_by_axis_name.pop(axis_name)
+ owner_is_node1 = edge.node1 is owner
+ if owner_is_node1:
+ other_node = edge.node2
+ other_axis_name = None if edge.axis2 is None else edge.axis2.name
+ else:
+ other_node = edge.node1
+ other_axis_name = edge.axis1.name
+ if override:
+ if owner_is_node1:
+ edge.node1 = self
+ edge.axis1 = SimpleNamespace(name=axis_name)
+ else:
+ edge.node2 = self
+ edge.axis2 = SimpleNamespace(name=axis_name)
+ self.edges_by_axis_name[axis_name] = edge
+ else:
+ if owner_is_node1:
+ self.edges_by_axis_name[axis_name] = (
+ _FakeTensorKrowchEdge.from_endpoints(
+ node1=self,
+ axis1_name=axis_name,
+ node2=other_node,
+ axis2_name=other_axis_name,
+ origin=edge.origin,
+ )
+ )
+ else:
+ assert other_node is not None
+ assert other_axis_name is not None
+ self.edges_by_axis_name[axis_name] = (
+ _FakeTensorKrowchEdge.from_endpoints(
+ node1=other_node,
+ axis1_name=other_axis_name,
+ node2=self,
+ axis2_name=axis_name,
+ origin=edge.origin,
+ )
+ )
+ self.axis_is_node1_by_axis_name[axis_name] = owner_is_node1
+ self.pending_edges_by_axis_name = {}
+
+
+class _FakeTensorKrowchModule(ModuleType):
+ """Tiny ``tensorkrowch`` double that exposes fragile axis ordering."""
+
+ def __init__(self) -> None:
+ super().__init__("tensorkrowch")
+ self.Node = _FakeTensorKrowchNode
+ self.TensorNetwork = _fake_tensorkrowch_network_factory
+
+ @staticmethod
+ def connect(
+ left_edge: _FakeTensorKrowchEdge,
+ right_edge: _FakeTensorKrowchEdge,
+ ) -> _FakeTensorKrowchEdge:
+ left_edge.attach_second(right_edge.node1, right_edge.axis1.name)
+ right_edge.node1.edges_by_axis_name[right_edge.axis1.name] = left_edge
+ right_edge.node1.axis_is_node1_by_axis_name[right_edge.axis1.name] = False
+ return left_edge
+
+ @staticmethod
+ def contract_between(
+ left_node: _FakeTensorKrowchNode,
+ right_node: _FakeTensorKrowchNode,
+ ) -> _FakeTensorKrowchNode:
+ left_edges = set(left_node.edges_by_axis_name.values())
+ right_edges = set(right_node.edges_by_axis_name.values())
+ if not left_edges.intersection(right_edges):
+ raise ValueError(
+ f"No batch edges or shared edges between nodes {left_node.name} and {right_node.name} found"
+ )
+ shared_edges = left_edges.intersection(right_edges)
+ surviving_edges_with_owner = [
+ (edge, left_node)
+ for edge in left_node.edges_by_axis_name.values()
+ if edge not in shared_edges
+ ] + [
+ (edge, right_node)
+ for edge in right_node.edges_by_axis_name.values()
+ if edge not in shared_edges
+ ]
+ surviving_axis_names = _deduplicate_fake_tensorkrowch_axis_names(
+ tuple(
+ edge.axis_name_for_node(owner).name
+ for edge, owner in surviving_edges_with_owner
+ )
+ )
+ result = _FakeTensorKrowchNode(
+ tensor=None,
+ axes_names=surviving_axis_names,
+ name=f"{left_node.name}_{right_node.name}",
+ network=None,
+ )
+ result.edges_by_axis_name = {}
+ result.pending_edges_by_axis_name = {}
+ result.pending_edge_owner_by_axis_name = {}
+ result.axis_is_node1_by_axis_name = {}
+ for axis_name, (edge, owner) in zip(
+ surviving_axis_names,
+ surviving_edges_with_owner,
+ strict=True,
+ ):
+ result.pending_edges_by_axis_name[axis_name] = edge
+ result.pending_edge_owner_by_axis_name[axis_name] = owner
+ result.axis_is_node1_by_axis_name[axis_name] = edge.node1 is owner
+ return result
+
+
+class _FakeTorchModule(ModuleType):
+ """Tiny ``torch`` double for generated-code regression tests."""
+
+ float32: object
+
+ def __init__(self) -> None:
+ super().__init__("torch")
+ self.float32 = object()
+
+ @staticmethod
+ def zeros(
+ shape: tuple[int, ...],
+ dtype: object | None = None,
+ ) -> tuple[tuple[int, ...], object | None]:
+ return (shape, dtype)
+
+
+def _deduplicate_fake_tensorkrowch_axis_names(
+ axis_names: tuple[str, ...],
+) -> tuple[str, ...]:
+ """Mirror TensorKrowch suffixing for exact duplicate surviving axes."""
+ base_names = [
+ axis_name.rsplit("_", 1)[0]
+ if axis_name.rsplit("_", 1)[-1].isdigit()
+ else axis_name
+ for axis_name in axis_names
+ ]
+ result: list[str] = []
+ counts: dict[str, int] = {}
+ for axis_name in base_names:
+ index = counts.get(axis_name, 0)
+ counts[axis_name] = index + 1
+ if base_names.count(axis_name) == 1:
+ result.append(axis_name)
+ else:
+ result.append(f"{axis_name}_{index}")
+ return tuple(result)
+
+
+def _fake_tensorkrowch_network_factory() -> object:
+ """Return a placeholder TensorNetwork instance for generated code."""
+ return SimpleNamespace(reset=lambda: None)
+
+
+class _ResetAwareFakeTensorKrowchNetwork:
+ """Minimal network object that can resync inherited resultant edges."""
+
+ def __init__(self) -> None:
+ self.nodes: list[_ResetAwareFakeTensorKrowchNode] = []
+
+ def register(self, node: _ResetAwareFakeTensorKrowchNode) -> None:
+ self.nodes.append(node)
+
+ def reset(self) -> None:
+ for node in self.nodes:
+ node.reset_inherited_edges()
+
+
+class _ResetAwareFakeTensorKrowchEdge:
+ """Edge double that hides inherited-result connections until reset."""
+
+ def __init__(
+ self,
+ node: _ResetAwareFakeTensorKrowchNode,
+ axis_name: str,
+ ) -> None:
+ self.node1 = node
+ self.axis1 = SimpleNamespace(name=axis_name)
+ self.node2: _ResetAwareFakeTensorKrowchNode | None = None
+ self.axis2: SimpleNamespace | None = None
+ self.origin = (node.name, axis_name)
+ self.inherited_source_by_result_node: dict[
+ _ResetAwareFakeTensorKrowchNode,
+ tuple[_ResetAwareFakeTensorKrowchNode, str],
+ ] = {}
+
+ def attach_second(
+ self,
+ node: _ResetAwareFakeTensorKrowchNode,
+ axis_name: str,
+ ) -> None:
+ self.node2 = node
+ self.axis2 = SimpleNamespace(name=axis_name)
+
+ def replace_endpoint(
+ self,
+ old_node: _ResetAwareFakeTensorKrowchNode,
+ new_node: _ResetAwareFakeTensorKrowchNode,
+ new_axis_name: str,
+ ) -> None:
+ if self.node1 is old_node:
+ self.inherited_source_by_result_node[new_node] = (
+ old_node,
+ self.axis1.name,
+ )
+ self.node1 = new_node
+ self.axis1 = SimpleNamespace(name=new_axis_name)
+ self._stale_other_resultant_endpoints(excluded_result_node=new_node)
+ return
+ if self.node2 is old_node:
+ assert self.axis2 is not None
+ self.inherited_source_by_result_node[new_node] = (
+ old_node,
+ self.axis2.name,
+ )
+ self.node2 = new_node
+ self.axis2 = SimpleNamespace(name=new_axis_name)
+ self._stale_other_resultant_endpoints(excluded_result_node=new_node)
+
+ def materialize_leaf_endpoint_for_resultant(
+ self,
+ node: _ResetAwareFakeTensorKrowchNode,
+ ) -> None:
+ source = self.inherited_source_by_result_node.get(node)
+ if source is None:
+ return
+ source_node, source_axis_name = source
+ if self.node1 is node:
+ self.node1 = source_node
+ self.axis1 = SimpleNamespace(name=source_axis_name)
+ return
+ if self.node2 is node:
+ self.node2 = source_node
+ self.axis2 = SimpleNamespace(name=source_axis_name)
+
+ def restore_resultant_endpoint(
+ self,
+ node: _ResetAwareFakeTensorKrowchNode,
+ axis_name: str,
+ ) -> None:
+ source = self.inherited_source_by_result_node.get(node)
+ if source is None:
+ return
+ source_node, source_axis_name = source
+ if self.node1 is source_node and self.axis1.name == source_axis_name:
+ self.node1 = node
+ self.axis1 = SimpleNamespace(name=axis_name)
+ return
+ if (
+ self.node2 is source_node
+ and self.axis2 is not None
+ and self.axis2.name == source_axis_name
+ ):
+ self.node2 = node
+ self.axis2 = SimpleNamespace(name=axis_name)
+
+ def _stale_other_resultant_endpoints(
+ self,
+ *,
+ excluded_result_node: _ResetAwareFakeTensorKrowchNode,
+ ) -> None:
+ """Hide this edge from other inherited-result views until reset."""
+ for result_node in tuple(self.inherited_source_by_result_node):
+ if result_node is excluded_result_node:
+ continue
+ if self.node1 is result_node or self.node2 is result_node:
+ self.materialize_leaf_endpoint_for_resultant(result_node)
+
+ def is_dangling(self) -> bool:
+ return self.node2 is None
+
+ def axis_name_for_node(
+ self,
+ node: _ResetAwareFakeTensorKrowchNode,
+ ) -> SimpleNamespace:
+ if self.node1 is node:
+ return self.axis1
+ assert self.node2 is node
+ assert self.axis2 is not None
+ return self.axis2
+
+ def connects_nodes(
+ self,
+ left_node: _ResetAwareFakeTensorKrowchNode,
+ right_node: _ResetAwareFakeTensorKrowchNode,
+ ) -> bool:
+ return (self.node1 is left_node and self.node2 is right_node) or (
+ self.node1 is right_node and self.node2 is left_node
+ )
+
+
+class _ResetAwareFakeTensorKrowchNode:
+ """Node double that tracks resultant-edge visibility across resets."""
+
+ def __init__(
+ self,
+ *,
+ tensor: object,
+ axes_names: tuple[str, ...],
+ name: str,
+ network: _ResetAwareFakeTensorKrowchNetwork | None,
+ ) -> None:
+ del tensor
+ self.name = name
+ self.network = network
+ self.is_resultant = False
+ self.edges_by_axis_name = {
+ axis_name: _ResetAwareFakeTensorKrowchEdge(self, axis_name)
+ for axis_name in axes_names
+ }
+ self.pending_edges_by_axis_name: dict[
+ str,
+ _ResetAwareFakeTensorKrowchEdge,
+ ] = {}
+ if network is not None:
+ network.register(self)
+
+ def __getitem__(self, axis_name: str) -> _ResetAwareFakeTensorKrowchEdge:
+ if axis_name in self.edges_by_axis_name:
+ return self.edges_by_axis_name[axis_name]
+ return self.pending_edges_by_axis_name[axis_name]
+
+ def reattach_edges(self) -> None:
+ self.edges_by_axis_name.update(self.pending_edges_by_axis_name)
+ self.pending_edges_by_axis_name = {}
+
+ def reset_inherited_edges(self) -> None:
+ for axis_name, edge in self.edges_by_axis_name.items():
+ edge.restore_resultant_endpoint(self, axis_name)
+ for axis_name, edge in self.pending_edges_by_axis_name.items():
+ edge.restore_resultant_endpoint(self, axis_name)
+
+
+class _ResetAwareFakeTensorKrowchModule(ModuleType):
+ """TensorKrowch double that requires ``network.reset()`` for inherited edges."""
+
+ def __init__(self) -> None:
+ super().__init__("tensorkrowch")
+ self.Node = _ResetAwareFakeTensorKrowchNode
+ self.TensorNetwork = _reset_aware_fake_tensorkrowch_network_factory
+
+ @staticmethod
+ def connect(
+ left_edge: _ResetAwareFakeTensorKrowchEdge,
+ right_edge: _ResetAwareFakeTensorKrowchEdge,
+ ) -> _ResetAwareFakeTensorKrowchEdge:
+ if left_edge.is_dangling() and left_edge.node1.is_resultant:
+ left_edge.materialize_leaf_endpoint_for_resultant(left_edge.node1)
+ left_edge.attach_second(right_edge.node1, right_edge.axis1.name)
+ right_edge.node1.edges_by_axis_name[right_edge.axis1.name] = left_edge
+ return left_edge
+
+ @staticmethod
+ def contract_between(
+ left_node: _ResetAwareFakeTensorKrowchNode,
+ right_node: _ResetAwareFakeTensorKrowchNode,
+ ) -> _ResetAwareFakeTensorKrowchNode:
+ left_edges = set(left_node.edges_by_axis_name.values())
+ right_edges = set(right_node.edges_by_axis_name.values())
+ shared_edges = {
+ edge
+ for edge in left_edges.intersection(right_edges)
+ if edge.connects_nodes(left_node, right_node)
+ }
+ if not shared_edges:
+ raise ValueError(
+ f"No batch edges or shared edges between nodes {left_node.name} and {right_node.name} found"
+ )
+ surviving_edges_with_owner = [
+ (edge, right_node)
+ for edge in right_node.edges_by_axis_name.values()
+ if edge not in shared_edges
+ ] + [
+ (edge, left_node)
+ for edge in left_node.edges_by_axis_name.values()
+ if edge not in shared_edges
+ ]
+ surviving_axis_names = _deduplicate_fake_tensorkrowch_axis_names(
+ tuple(
+ edge.axis_name_for_node(owner).name
+ for edge, owner in surviving_edges_with_owner
+ )
+ )
+ result = _ResetAwareFakeTensorKrowchNode(
+ tensor=None,
+ axes_names=surviving_axis_names,
+ name=f"{left_node.name}_{right_node.name}",
+ network=left_node.network,
+ )
+ result.is_resultant = True
+ result.edges_by_axis_name = {}
+ result.pending_edges_by_axis_name = {}
+ for axis_name, (edge, owner) in zip(
+ surviving_axis_names,
+ surviving_edges_with_owner,
+ strict=True,
+ ):
+ edge.replace_endpoint(owner, result, axis_name)
+ result.pending_edges_by_axis_name[axis_name] = edge
+ return result
+
+
+def _reset_aware_fake_tensorkrowch_network_factory() -> (
+ _ResetAwareFakeTensorKrowchNetwork
+):
+ """Return a fake network that models inherited-edge reset semantics."""
+ return _ResetAwareFakeTensorKrowchNetwork()
+
+
@pytest.mark.parametrize(
("engine", "expected_snippets"),
[
@@ -813,6 +1322,20 @@ def test_periodic_generate_code_emits_roundtrip_metadata_marker() -> None:
assert "# TNE_SPEC_B64:" in result.code
assert "# Tensor Network Editor linear periodic mode" in result.code
+ assert result.code.index(
+ "# Tensor Network Editor linear periodic mode"
+ ) < result.code.index("# TNE_SPEC_B64:")
+
+
+def test_periodic_generate_code_can_skip_roundtrip_metadata_marker() -> None:
+ result = generate_code(
+ build_linear_periodic_chain_spec(),
+ engine=EngineName.EINSUM_NUMPY,
+ include_roundtrip_metadata=False,
+ )
+
+ assert "# TNE_SPEC_B64:" not in result.code
+ assert "# Tensor Network Editor linear periodic mode" in result.code
@pytest.mark.parametrize("engine", list(EngineName))
@@ -856,6 +1379,16 @@ def test_generate_code_respects_manual_plan_steps(
assert "result = results_list[-1]" in result.code
+def test_tensorkrowch_normal_codegen_does_not_emit_reattach_edges() -> None:
+ result = generate_code(
+ build_three_tensor_complete_plan_spec(),
+ engine=EngineName.TENSORKROWCH,
+ )
+
+ assert "results_list.append(tk.contract_between(" in result.code
+ assert "reattach_edges(" not in result.code
+
+
@pytest.mark.parametrize("engine", list(EngineName))
def test_generate_code_keeps_partial_manual_plan_as_prefix(
engine: EngineName,
@@ -1119,6 +1652,120 @@ def test_linear_periodic_carry_codegen_labels_shared_for_sections(
assert "previous_payload: dict[str, object]" in result.code
+def test_linear_periodic_carry_tensorkrowch_codegen_tracks_boundary_edges_without_axis_order_assumptions() -> (
+ None
+):
+ result = generate_code(
+ build_linear_periodic_carry_chain_spec(),
+ engine=EngineName.TENSORKROWCH,
+ )
+ fake_torch = _FakeTorchModule()
+ fake_tensorkrowch = _FakeTensorKrowchModule()
+
+ with patch.dict(
+ sys.modules,
+ {
+ "torch": fake_torch,
+ "tensorkrowch": fake_tensorkrowch,
+ },
+ ):
+ namespace = _execute_generated_code(result.code, n=3)
+
+ open_edges = namespace["open_edges"]
+ assert isinstance(open_edges, list)
+ assert len(open_edges) == 4
+ assert [edge.origin for edge in open_edges] == [
+ ("Initial", "phys"),
+ ("PeriodicLeft", "phys_l"),
+ ("PeriodicRight", "phys_r"),
+ ("Final", "phys"),
+ ]
+
+
+def test_linear_periodic_carry_tensorkrowch_codegen_executes_when_periodic_cell_contracts_local_pair_before_previous_payload() -> (
+ None
+):
+ spec = build_linear_periodic_carry_chain_spec()
+ assert spec.linear_periodic_chain is not None
+ assert spec.linear_periodic_chain.periodic_cell.contraction_plan is not None
+ spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [
+ ContractionStepSpec(
+ id="periodic_contract_internal_first",
+ left_operand_id="periodic_left_tensor",
+ right_operand_id="periodic_right_tensor",
+ ),
+ ContractionStepSpec(
+ id="periodic_consume_previous_second",
+ left_operand_id="periodic_contract_internal_first",
+ right_operand_id="__linear_previous__",
+ ),
+ ContractionStepSpec(
+ id="periodic_carry_last",
+ left_operand_id="periodic_consume_previous_second",
+ right_operand_id="__linear_next__",
+ ),
+ ]
+ result = generate_code(spec, engine=EngineName.TENSORKROWCH)
+ fake_torch = _FakeTorchModule()
+ fake_tensorkrowch = _FakeTensorKrowchModule()
+
+ with patch.dict(
+ sys.modules,
+ {
+ "torch": fake_torch,
+ "tensorkrowch": fake_tensorkrowch,
+ },
+ ):
+ namespace = _execute_generated_code(result.code, n=5)
+
+ assert "result" in namespace
+ assert "open_edges" in namespace
+
+
+def test_linear_periodic_carry_tensorkrowch_codegen_materializes_result_edges_with_override() -> (
+ None
+):
+ spec = build_linear_periodic_carry_chain_spec()
+ assert spec.linear_periodic_chain is not None
+ assert spec.linear_periodic_chain.periodic_cell.contraction_plan is not None
+ spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [
+ ContractionStepSpec(
+ id="periodic_contract_internal_first",
+ left_operand_id="periodic_left_tensor",
+ right_operand_id="periodic_right_tensor",
+ ),
+ ContractionStepSpec(
+ id="periodic_consume_previous_second",
+ left_operand_id="periodic_contract_internal_first",
+ right_operand_id="__linear_previous__",
+ ),
+ ContractionStepSpec(
+ id="periodic_carry_last",
+ left_operand_id="periodic_consume_previous_second",
+ right_operand_id="__linear_next__",
+ ),
+ ]
+ result = generate_code(spec, engine=EngineName.TENSORKROWCH)
+ fake_torch = _FakeTorchModule()
+ fake_tensorkrowch = _FakeTensorKrowchModule()
+
+ with patch.dict(
+ sys.modules,
+ {
+ "torch": fake_torch,
+ "tensorkrowch": fake_tensorkrowch,
+ },
+ ):
+ namespace = _execute_generated_code(result.code, n=5)
+
+ assert "reattach_edges(override=True)" in result.code
+ assert "network.reset()" not in result.code
+ assert "open_edges.extend([tracked_edge_0, tracked_edge_1])" in result.code
+ assert "outgoing_interface = [results_list[-1]['right']]" in result.code
+ assert "result" in namespace
+ assert "open_edges" in namespace
+
+
@pytest.mark.parametrize("engine", list(EngineName))
def test_linear_periodic_codegen_does_not_stringify_manual_blocks(
engine: EngineName,
diff --git a/tests/codegen/test_grid_periodic_internals.py b/tests/codegen/test_grid_periodic_internals.py
index 6b31c71..1856e30 100644
--- a/tests/codegen/test_grid_periodic_internals.py
+++ b/tests/codegen/test_grid_periodic_internals.py
@@ -1,6 +1,9 @@
from __future__ import annotations
-from tensor_network_editor.models import GridPeriodicCellName
+from tensor_network_editor.models import (
+ GridPeriodicCellName,
+ TensorCollectionFormat,
+)
from tests.factories import build_grid_periodic_grid_spec
@@ -72,3 +75,35 @@ def test_grid_periodic_internal_helpers_keep_shared_labels_and_main_flow() -> No
"result = network_nodes[0] if len(network_nodes) == 1 else None"
]
assert "output_labels.extend(bottom_right_cell['open_labels'])" in einsum_main_lines
+
+
+def test_grid_periodic_array_shared_helpers_build_context_and_sections() -> None:
+ from tensor_network_editor.codegen.modes._grid_periodic.array_shared import (
+ build_grid_array_cell_context,
+ render_grid_array_tensor_sections,
+ )
+
+ grid = build_grid_periodic_grid_spec().grid_periodic_grid
+ assert grid is not None
+
+ context = build_grid_array_cell_context(
+ grid=grid,
+ cell_name=GridPeriodicCellName.TOP_LEFT,
+ collection_format=TensorCollectionFormat.LIST,
+ )
+ tensor_collection_lines, tensor_construction_lines = (
+ render_grid_array_tensor_sections(
+ context=context,
+ tensor_value_by_id={
+ tensor.spec.id: f"value_{tensor.variable_name}"
+ for tensor in context.prepared.tensors
+ },
+ )
+ )
+
+ assert context.collection_name == "tensors"
+ assert context.prepared.tensors
+ assert context.interface_index_ids
+ assert tensor_collection_lines == ["tensors = []"]
+ assert any(line.startswith("# Tensor ") for line in tensor_construction_lines)
+ assert any("tensors.append(value_" in line for line in tensor_construction_lines)
diff --git a/tests/codegen/test_linear_periodic_internals.py b/tests/codegen/test_linear_periodic_internals.py
index 60846ab..8de0540 100644
--- a/tests/codegen/test_linear_periodic_internals.py
+++ b/tests/codegen/test_linear_periodic_internals.py
@@ -4,9 +4,17 @@
from tensor_network_editor.errors import CodeGenerationError
from tensor_network_editor.models import (
+ CanvasPosition,
+ ContractionPlanSpec,
ContractionStepSpec,
+ EdgeEndpointRef,
+ EdgeSpec,
EngineName,
+ IndexSpec,
LinearPeriodicCellName,
+ LinearPeriodicCellSpec,
+ LinearPeriodicTensorRole,
+ TensorSpec,
)
from tests.factories import build_linear_periodic_carry_chain_spec
@@ -175,6 +183,203 @@ def test_simulate_carry_cell_rejects_next_step_that_is_not_final() -> None:
)
+def test_simulate_carry_cell_accepts_previous_payload_labels_that_only_collide_by_name() -> (
+ None
+):
+ from tensor_network_editor.codegen.modes._linear_periodic.carry import (
+ _CarryOperandState,
+ _CarryPayloadState,
+ _simulate_carry_cell,
+ )
+
+ periodic_cell = LinearPeriodicCellSpec(
+ tensors=[
+ TensorSpec(
+ id="periodic_previous_boundary",
+ name="Previous cell",
+ position=CanvasPosition(x=-100.0, y=140.0),
+ linear_periodic_role=LinearPeriodicTensorRole.PREVIOUS,
+ indices=[
+ IndexSpec(
+ id="periodic_previous_slot_1", name="slot_1", dimension=2
+ ),
+ IndexSpec(
+ id="periodic_previous_slot_2", name="slot_2", dimension=2
+ ),
+ ],
+ ),
+ TensorSpec(
+ id="periodic_next_boundary",
+ name="Next cell",
+ position=CanvasPosition(x=540.0, y=140.0),
+ linear_periodic_role=LinearPeriodicTensorRole.NEXT,
+ indices=[
+ IndexSpec(id="periodic_next_slot_1", name="slot_1", dimension=2),
+ IndexSpec(id="periodic_next_slot_2", name="slot_2", dimension=2),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_a1",
+ name="A1",
+ position=CanvasPosition(x=-255.0, y=363.0),
+ indices=[
+ IndexSpec(id="a1_right", name="right", dimension=3),
+ IndexSpec(id="a1_phys", name="phys", dimension=2),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_a2",
+ name="A2",
+ position=CanvasPosition(x=65.0, y=363.0),
+ indices=[
+ IndexSpec(id="a2_left", name="left", dimension=3),
+ IndexSpec(id="a2_right", name="right", dimension=3),
+ IndexSpec(id="a2_phys", name="phys", dimension=2),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_a3",
+ name="A3",
+ position=CanvasPosition(x=385.0, y=363.0),
+ indices=[
+ IndexSpec(id="a3_left", name="left", dimension=3),
+ IndexSpec(id="a3_right", name="right", dimension=3),
+ IndexSpec(id="a3_phys", name="phys", dimension=2),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_a4",
+ name="A4",
+ position=CanvasPosition(x=705.0, y=363.0),
+ indices=[
+ IndexSpec(id="a4_left", name="left", dimension=3),
+ IndexSpec(id="a4_phys", name="phys", dimension=2),
+ ],
+ ),
+ ],
+ edges=[
+ EdgeSpec(
+ id="edge_a1_a2",
+ name="edge-0-1",
+ left=EdgeEndpointRef(tensor_id="tensor_a1", index_id="a1_right"),
+ right=EdgeEndpointRef(tensor_id="tensor_a2", index_id="a2_left"),
+ ),
+ EdgeSpec(
+ id="edge_a2_a3",
+ name="edge-1-2",
+ left=EdgeEndpointRef(tensor_id="tensor_a2", index_id="a2_right"),
+ right=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_left"),
+ ),
+ EdgeSpec(
+ id="edge_a3_a4",
+ name="edge-2-3",
+ left=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_right"),
+ right=EdgeEndpointRef(tensor_id="tensor_a4", index_id="a4_left"),
+ ),
+ EdgeSpec(
+ id="edge_previous_a1",
+ name="bond1",
+ left=EdgeEndpointRef(
+ tensor_id="tensor_a1",
+ index_id="a1_phys",
+ ),
+ right=EdgeEndpointRef(
+ tensor_id="periodic_previous_boundary",
+ index_id="periodic_previous_slot_1",
+ ),
+ ),
+ EdgeSpec(
+ id="edge_previous_a2",
+ name="bond2",
+ left=EdgeEndpointRef(
+ tensor_id="periodic_previous_boundary",
+ index_id="periodic_previous_slot_2",
+ ),
+ right=EdgeEndpointRef(
+ tensor_id="tensor_a2",
+ index_id="a2_phys",
+ ),
+ ),
+ EdgeSpec(
+ id="edge_a3_next",
+ name="bond3",
+ left=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_phys"),
+ right=EdgeEndpointRef(
+ tensor_id="periodic_next_boundary",
+ index_id="periodic_next_slot_1",
+ ),
+ ),
+ EdgeSpec(
+ id="edge_a4_next",
+ name="bond4",
+ left=EdgeEndpointRef(
+ tensor_id="periodic_next_boundary",
+ index_id="periodic_next_slot_2",
+ ),
+ right=EdgeEndpointRef(tensor_id="tensor_a4", index_id="a4_phys"),
+ ),
+ ],
+ contraction_plan=ContractionPlanSpec(
+ id="periodic_plan",
+ name="Manual path",
+ steps=[
+ ContractionStepSpec(
+ id="step_contract_right",
+ left_operand_id="tensor_a4",
+ right_operand_id="tensor_a3",
+ ),
+ ContractionStepSpec(
+ id="step_from_previous",
+ left_operand_id="__linear_previous__",
+ right_operand_id="tensor_a2",
+ ),
+ ContractionStepSpec(
+ id="step_merge",
+ left_operand_id="step_from_previous",
+ right_operand_id="step_contract_right",
+ ),
+ ContractionStepSpec(
+ id="step_to_next",
+ left_operand_id="step_merge",
+ right_operand_id="__linear_next__",
+ ),
+ ],
+ ),
+ )
+ previous_payload_state = _CarryPayloadState(
+ interface_operand_ids=("payload_left", "payload_right"),
+ interface_labels=("a1_phys", "a2_phys"),
+ operand_states={
+ "payload_left": _CarryOperandState(
+ labels=("payload_edge", "a1_phys"),
+ axis_names=("left_payload", "slot_1"),
+ dimensions=(3, 2),
+ ),
+ "payload_right": _CarryOperandState(
+ labels=("a4_phys", "a3_phys", "payload_edge", "a2_phys"),
+ axis_names=("carry_0", "carry_1", "bridge", "slot_2"),
+ dimensions=(2, 2, 3, 2),
+ ),
+ },
+ )
+
+ simulation = _simulate_carry_cell(
+ cell=periodic_cell,
+ cell_name=LinearPeriodicCellName.PERIODIC,
+ previous_payload_state=previous_payload_state,
+ engine=EngineName.TENSORKROWCH,
+ )
+
+ assert simulation.carry_operand_id == "step_merge"
+ assert simulation.outgoing_interface_operand_ids == ("step_merge", "step_merge")
+ assert (
+ simulation.remaining_operand_states["step_merge"].labels.count("a3_phys") == 1
+ )
+ assert (
+ simulation.remaining_operand_states["step_merge"].labels.count("a4_phys") == 1
+ )
+
+
def test_build_carry_simulation_context_collects_interface_state() -> None:
from tensor_network_editor.codegen.modes._linear_periodic.carry import (
_build_carry_simulation_context,
diff --git a/tests/codegen/test_tree_periodic_internals.py b/tests/codegen/test_tree_periodic_internals.py
index 164efbd..384beec 100644
--- a/tests/codegen/test_tree_periodic_internals.py
+++ b/tests/codegen/test_tree_periodic_internals.py
@@ -124,3 +124,36 @@ def test_tree_periodic_array_helpers_keep_child_interfaces_and_backend_tensor_bu
assert "np.zeros(" in numpy_helper_body
assert "torch.zeros(" in torch_helper_body
assert "np.zeros(" not in torch_helper_body
+
+
+def test_tree_periodic_array_shared_helpers_build_context_and_sections() -> None:
+ from tensor_network_editor.codegen.modes._tree_periodic.array_shared import (
+ build_tree_array_cell_context,
+ render_tree_array_tensor_sections,
+ )
+
+ tree = build_tree_periodic_tree_spec().tree_periodic_tree
+ assert tree is not None
+
+ context = build_tree_array_cell_context(
+ tree=tree,
+ cell_name=TreePeriodicCellName.ROOT,
+ collection_format=TensorCollectionFormat.LIST,
+ )
+ tensor_collection_lines, tensor_construction_lines = (
+ render_tree_array_tensor_sections(
+ context=context,
+ tensor_value_by_id={
+ tensor.spec.id: f"value_{tensor.variable_name}"
+ for tensor in context.prepared.tensors
+ },
+ )
+ )
+
+ assert context.collection_name == "tensors"
+ assert context.parent_ports == ()
+ assert tuple(context.child_ports_by_index) == tuple(range(tree.branching_factor))
+ assert context.interface_index_ids
+ assert tensor_collection_lines == ["tensors = []"]
+ assert any(line.startswith("# Tensor ") for line in tensor_construction_lines)
+ assert any("tensors.append(value_" in line for line in tensor_construction_lines)
diff --git a/tests/test_api.py b/tests/test_api.py
index 2c8d150..baec4f0 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -10,7 +10,7 @@
import tensor_network_editor
from tensor_network_editor import generate_code as _generate_code
-from tensor_network_editor.editor import EditorLaunchOptions, open_editor
+from tensor_network_editor.editor import EditorLaunchOptions, EditorUiMode, open_editor
from tensor_network_editor.errors import (
CodeGenerationError,
PackageIOError,
@@ -152,6 +152,7 @@ def test_package_root_exports_supported_public_api() -> None:
"EdgeSpec",
"EditorLaunchOptions",
"EditorThemeName",
+ "EditorUiMode",
"EditorResult",
"EngineName",
"DotRenderOptions",
@@ -233,8 +234,10 @@ def test_editor_launch_options_defaults_match_public_contract() -> None:
assert options.default_engine is EngineName.TENSORKROWCH
assert options.default_collection_format is TensorCollectionFormat.LIST
assert options.theme == "dark"
+ assert options.ui_mode is None
assert options.open_browser is True
assert options.host == "127.0.0.1"
+ assert options.allow_remote is False
assert options.port == 0
assert options.print_code is False
assert options.code_path is None
@@ -248,6 +251,40 @@ def test_editor_launch_options_rejects_unknown_theme() -> None:
EditorLaunchOptions(theme="sepia") # type: ignore[arg-type]
+def test_editor_launch_options_rejects_non_loopback_host_without_remote_opt_in() -> (
+ None
+):
+ with pytest.raises(ValueError, match="non-loopback"):
+ EditorLaunchOptions(host="0.0.0.0")
+
+
+def test_editor_launch_options_allows_non_loopback_host_with_remote_opt_in() -> None:
+ options = EditorLaunchOptions(host="0.0.0.0", allow_remote=True)
+
+ assert options.host == "0.0.0.0"
+ assert options.allow_remote is True
+
+
+def test_editor_ui_mode_type_alias_matches_public_contract() -> None:
+ assert EditorUiMode == Literal["browser", "pywebview", "server"]
+
+
+@pytest.mark.parametrize(
+ ("ui_mode", "open_browser", "expected_message"),
+ [
+ ("browser", False, "ui_mode='browser' requires open_browser=True"),
+ ("server", True, "ui_mode='server' requires open_browser=False"),
+ ],
+)
+def test_editor_launch_options_rejects_conflicting_browser_flags(
+ ui_mode: EditorUiMode,
+ open_browser: bool,
+ expected_message: str,
+) -> None:
+ with pytest.raises(ValueError, match=expected_message):
+ EditorLaunchOptions(ui_mode=ui_mode, open_browser=open_browser)
+
+
def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> None:
launch_result = object()
@@ -261,8 +298,10 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N
default_engine=EngineName.EINSUM_NUMPY,
default_collection_format=TensorCollectionFormat.DICT,
theme="colorblind",
+ ui_mode="pywebview",
open_browser=False,
host="0.0.0.0",
+ allow_remote=True,
port=8123,
print_code=True,
code_path="generated.py",
@@ -281,8 +320,10 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N
default_engine=EngineName.EINSUM_NUMPY,
default_collection_format=TensorCollectionFormat.DICT,
theme="colorblind",
+ ui_mode="pywebview",
open_browser=False,
host="0.0.0.0",
+ allow_remote=True,
port=8123,
print_code=True,
code_path="generated.py",
diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py
index 4c47bd1..8d80460 100644
--- a/tests/test_app_assets.py
+++ b/tests/test_app_assets.py
@@ -8,7 +8,12 @@
import pytest
from tensor_network_editor.app.server import EditorServer
-from tests.app_support import request_text, request_with_headers
+from tests.app_support import (
+ request_bytes,
+ request_headers,
+ request_text,
+ request_with_headers,
+)
def request_runtime_bundle(editor_server: EditorServer, *relative_paths: str) -> str:
@@ -109,6 +114,36 @@ def test_root_serves_editor_shell_with_versioned_module_entry(
assert headers["Content-Type"].startswith("text/html")
+def test_root_serves_editor_shell_with_csp_nonce_and_defensive_headers(
+ editor_server: EditorServer,
+) -> None:
+ html, headers = request_with_headers(f"{editor_server.base_url}/")
+
+ content_security_policy = headers["Content-Security-Policy"]
+ nonce_match = re.search(
+ r"(?:^|;\s*)script-src 'self' 'nonce-([^']+)';",
+ content_security_policy,
+ )
+
+ assert nonce_match is not None
+ nonce = nonce_match.group(1)
+ assert nonce
+ assert "'unsafe-inline'" not in nonce_match.group(0)
+ assert (
+ f'",
+ encoding="utf-8",
+ )
+ asset_path.write_text("console.log('first');", encoding="utf-8")
+ monotonic_time = 100.0
+
+ scan_calls: list[Path] = []
+ original_scan = app_server._scan_static_asset_files
+
+ def recording_scan(path: Path) -> list[tuple[Path, str, int, int]]:
+ scan_calls.append(path.resolve())
+ return original_scan(path)
+
+ monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time)
+ monkeypatch.setattr(app_server, "_scan_static_asset_files", recording_scan)
+ app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None)
+ app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop(
+ resolved_static_dir, None
+ )
+
+ first_cache = app_server._get_static_asset_cache(static_dir)
+ second_cache = app_server._get_static_asset_cache(static_dir)
+
+ assert first_cache is second_cache
+ assert scan_calls == [resolved_static_dir]
+
+
def test_static_asset_cache_logs_build_and_reuse(
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
@@ -255,6 +418,9 @@ def test_static_asset_cache_logs_build_and_reuse(
)
asset_path.write_text("console.log('first');", encoding="utf-8")
app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None)
+ app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop(
+ resolved_static_dir, None
+ )
with caplog.at_level(logging.DEBUG, logger="tensor_network_editor"):
first_cache = app_server._get_static_asset_cache(static_dir)
@@ -270,6 +436,7 @@ def test_static_asset_cache_logs_build_and_reuse(
def test_static_asset_cache_logs_refresh_with_version_context(
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
+ monkeypatch: pytest.MonkeyPatch,
) -> None:
static_dir = tmp_path / "static"
asset_path = static_dir / "js" / "app.js"
@@ -280,7 +447,13 @@ def test_static_asset_cache_logs_refresh_with_version_context(
encoding="utf-8",
)
asset_path.write_text("console.log('first');", encoding="utf-8")
+ monotonic_time = 100.0
+
+ monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time)
app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None)
+ app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop(
+ resolved_static_dir, None
+ )
first_cache = app_server._get_static_asset_cache(static_dir)
asset_path.write_text("console.log('second');", encoding="utf-8")
@@ -289,6 +462,7 @@ def test_static_asset_cache_logs_refresh_with_version_context(
+ 1_000_000_000
)
os.utime(asset_path, ns=(future_timestamp_ns, future_timestamp_ns))
+ monotonic_time += 1.0
with caplog.at_level(logging.DEBUG, logger="tensor_network_editor"):
refreshed_cache = app_server._get_static_asset_cache(static_dir)
diff --git a/tests/test_app_support.py b/tests/test_app_support.py
new file mode 100644
index 0000000..fd8881b
--- /dev/null
+++ b/tests/test_app_support.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from tests import app_support
+
+
+class _FakeResponse:
+ def __init__(self, body: str) -> None:
+ self._body = body.encode("utf-8")
+ self.status = 200
+ self.headers = {"Cache-Control": "no-store"}
+
+ def __enter__(self) -> _FakeResponse:
+ return self
+
+ def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:
+ del exc_type, exc, traceback
+ return None
+
+ def read(self) -> bytes:
+ return self._body
+
+
+def test_request_text_uses_shared_asset_timeout() -> None:
+ recorded_timeout: list[float] = []
+
+ def fake_urlopen(url: str, timeout: float) -> _FakeResponse:
+ recorded_timeout.append(timeout)
+ assert url == "http://example.test/"
+ return _FakeResponse("body")
+
+ with patch("tests.app_support.urlopen", side_effect=fake_urlopen):
+ body = app_support.request_text("http://example.test/")
+
+ assert body == "body"
+ assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS]
+
+
+def test_request_with_headers_uses_shared_asset_timeout() -> None:
+ recorded_timeout: list[float] = []
+
+ def fake_urlopen(url: str, timeout: float) -> _FakeResponse:
+ recorded_timeout.append(timeout)
+ assert url == "http://example.test/app.css"
+ return _FakeResponse("css")
+
+ with patch("tests.app_support.urlopen", side_effect=fake_urlopen):
+ body, headers = app_support.request_with_headers("http://example.test/app.css")
+
+ assert body == "css"
+ assert headers == {"Cache-Control": "no-store"}
+ assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS]
+
+
+def test_request_headers_uses_shared_asset_timeout_without_reading_body() -> None:
+ recorded_timeout: list[float] = []
+ response = _FakeResponse("body")
+
+ def fake_urlopen(url: str, timeout: float) -> _FakeResponse:
+ recorded_timeout.append(timeout)
+ assert url == "http://example.test/vendor.js"
+ return response
+
+ with patch("tests.app_support.urlopen", side_effect=fake_urlopen):
+ headers = app_support.request_headers("http://example.test/vendor.js")
+
+ assert headers == {"Cache-Control": "no-store"}
+ assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS]
+
+
+def test_read_asset_response_retries_transient_os_errors() -> None:
+ attempts = 0
+
+ def fake_urlopen(url: str, timeout: float) -> _FakeResponse:
+ nonlocal attempts
+ attempts += 1
+ assert url == "http://example.test/retry.js"
+ assert timeout == app_support._ASSET_REQUEST_TIMEOUT_SECONDS
+ if attempts < 3:
+ raise TimeoutError("temporary timeout")
+ return _FakeResponse("ok")
+
+ with patch("tests.app_support.urlopen", side_effect=fake_urlopen):
+ body, headers = app_support._read_asset_response("http://example.test/retry.js")
+
+ assert body == b"ok"
+ assert headers == {"Cache-Control": "no-store"}
+ assert attempts == 3
+
+
+def test_request_bytes_uses_shared_asset_fetcher() -> None:
+ with patch(
+ "tests.app_support._read_asset_response",
+ return_value=(b"icon", {"Content-Type": "image/x-icon"}),
+ ) as read_asset_response_mock:
+ body = app_support.request_bytes("http://example.test/favicon.ico")
+
+ assert body == b"icon"
+ read_asset_response_mock.assert_called_once_with("http://example.test/favicon.ico")
diff --git a/tests/test_browser_smoke.py b/tests/test_browser_smoke.py
index 5abb60e..0bdb758 100644
--- a/tests/test_browser_smoke.py
+++ b/tests/test_browser_smoke.py
@@ -6,12 +6,13 @@
import time
from pathlib import Path
from typing import Any
-from urllib.request import urlopen
+from urllib.request import Request, urlopen
import pytest
from tensor_network_editor.app.server import EditorServer
from tensor_network_editor.app.session import EditorSession
+from tests.app_support import _session_token_for_url
pytestmark = pytest.mark.browser
@@ -46,10 +47,22 @@ def _import_playwright_sync_api() -> Any:
def _request_json(url: str) -> dict[str, Any]:
"""Read one JSON response from a local editor server URL."""
- with urlopen(url, timeout=1) as response:
+ request = Request(url)
+ session_token = _session_token_for_url(url)
+ if session_token:
+ request.add_header("X-TNE-Session-Token", session_token)
+ with urlopen(request, timeout=1) as response:
return json.load(response)
+def test_browser_smoke_json_helper_sends_session_token(
+ editor_server: EditorServer,
+) -> None:
+ payload = _request_json(f"{editor_server.base_url}/api/bootstrap")
+
+ assert payload["app_metadata"]["version"]
+
+
def _wait_for_recoverable_draft_name(
draft_url: str,
expected_name: str,
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 6b6276f..f09b447 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -148,6 +148,10 @@ def empty_lint_report(_spec: NetworkSpec) -> LintReport:
return LintReport()
+def compact_help_text(help_text: str) -> str:
+ return " ".join(help_text.split())
+
+
def test_main_requires_a_subcommand(capsys: pytest.CaptureFixture[str]) -> None:
with patch("tensor_network_editor.cli.open_editor") as open_editor_mock:
exit_code = main([])
@@ -201,6 +205,49 @@ def test_global_python_import_arguments_are_accepted_before_subcommand() -> None
assert parsed_args.command == "edit"
+def test_top_level_help_includes_command_argument_quick_reference() -> None:
+ parser = build_command_parser()
+
+ help_text = compact_help_text(parser.format_help())
+
+ assert "Command argument quick reference:" in help_text
+ assert (
+ "tensor-network-editor export PATH --engine ENGINE [--output FILE]" in help_text
+ )
+ assert "tensor-network-editor template build TEMPLATE_NAME [options]" in help_text
+ assert "Run 'tensor-network-editor --help'" in help_text
+
+
+def test_export_help_describes_required_arguments(
+ capsys: pytest.CaptureFixture[str],
+) -> None:
+ exit_code = main(["export", "--help"])
+
+ assert exit_code == 0
+ help_text = capsys.readouterr().out
+ help_text = compact_help_text(help_text)
+ assert (
+ "Saved JSON design or supported generated Python file to export." in help_text
+ )
+ assert "Backend used for generated Python code." in help_text
+ assert "Write generated code to a file instead of stdout." in help_text
+
+
+def test_template_build_help_describes_template_options(
+ capsys: pytest.CaptureFixture[str],
+) -> None:
+ exit_code = main(["template", "build", "--help"])
+
+ assert exit_code == 0
+ help_text = capsys.readouterr().out
+ help_text = compact_help_text(help_text)
+ assert "Built-in template name to instantiate." in help_text
+ assert (
+ "Override the graph size parameter when the template supports it." in help_text
+ )
+ assert "Choose text or JSON output." in help_text
+
+
def test_cli_modules_pass_targeted_mypy_check() -> None:
result = subprocess.run(
[
@@ -253,6 +300,45 @@ def test_edit_subcommand_passes_explicit_log_file_path() -> None:
)
+def test_edit_subcommand_accepts_explicit_browser_ui_mode() -> None:
+ with patch("tensor_network_editor.cli.open_editor") as open_editor_mock:
+ exit_code = main(["edit", "--ui", "pywebview"])
+
+ assert exit_code == 0
+ open_editor_mock.assert_called_once_with(
+ spec=None,
+ options=EditorLaunchOptions(
+ ui_mode="pywebview",
+ open_browser=False,
+ ),
+ )
+
+
+def test_edit_subcommand_ui_server_matches_no_browser_alias() -> None:
+ with patch("tensor_network_editor.cli.open_editor") as open_editor_mock:
+ exit_code = main(["edit", "--ui", "server"])
+
+ assert exit_code == 0
+ open_editor_mock.assert_called_once_with(
+ spec=None,
+ options=EditorLaunchOptions(
+ ui_mode="server",
+ open_browser=False,
+ ),
+ )
+
+
+def test_edit_subcommand_rejects_ui_and_no_browser_combination(
+ capsys: pytest.CaptureFixture[str],
+) -> None:
+ with patch("tensor_network_editor.cli.open_editor") as open_editor_mock:
+ exit_code = main(["edit", "--ui", "browser", "--no-browser"])
+
+ assert exit_code == 2
+ open_editor_mock.assert_not_called()
+ assert "cannot combine --ui with --no-browser" in capsys.readouterr().err
+
+
def test_edit_subcommand_passes_explicit_log_rotation_settings() -> None:
with patch("tensor_network_editor.cli.open_editor") as open_editor_mock:
exit_code = main(
diff --git a/tests/test_frontend_architecture.py b/tests/test_frontend_architecture.py
index 3c4f190..94a2d0f 100644
--- a/tests/test_frontend_architecture.py
+++ b/tests/test_frontend_architecture.py
@@ -1197,6 +1197,7 @@ def test_editor_services_route_session_requests_through_explicit_dependencies(
await sessionService.generateCode({{
engine: "quimb",
collectionFormat: "dict",
+ includeRoundtripMetadata: true,
spec: {{ schema_version: 4, network: {{ id: "network_demo" }} }},
}});
await sessionService.renderSpec({{
@@ -1218,6 +1219,9 @@ def test_editor_services_route_session_requests_through_explicit_dependencies(
if (calls[1].payload.collection_format !== "dict") {{
throw new Error(`Expected collection_format=dict, received ${{calls[1].payload.collection_format}}.`);
}}
+ if (calls[1].payload.include_roundtrip_metadata !== true) {{
+ throw new Error(`Expected include_roundtrip_metadata=true, received ${{calls[1].payload.include_roundtrip_metadata}}.`);
+ }}
if (calls[2].path !== "/api/render") {{
throw new Error(`Unexpected render path: ${{calls[2].path}}`);
}}
@@ -3817,11 +3821,18 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
state.selectedEngine = engine;
storeCalls.push({{ step: "setSelectedEngine", engine }});
}},
- setSelectedCollectionFormat(collectionFormat) {{
- state.selectedCollectionFormat = collectionFormat;
- storeCalls.push({{ step: "setSelectedCollectionFormat", collectionFormat }});
- }},
- }};
+ setSelectedCollectionFormat(collectionFormat) {{
+ state.selectedCollectionFormat = collectionFormat;
+ storeCalls.push({{ step: "setSelectedCollectionFormat", collectionFormat }});
+ }},
+ setIncludeRoundtripMetadata(includeRoundtripMetadata) {{
+ state.includeRoundtripMetadata = Boolean(includeRoundtripMetadata);
+ storeCalls.push({{
+ step: "setIncludeRoundtripMetadata",
+ includeRoundtripMetadata: state.includeRoundtripMetadata,
+ }});
+ }},
+ }};
const flowEvents = [];
const bootstrapFlow = bootstrapFlowModule.createEditorBootstrapFlow({{
state,
@@ -3921,8 +3932,8 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
windowRef: {{
innerWidth: 800,
innerHeight: 600,
- addEventListener(type, handler) {{
- windowListeners.push(type);
+ addEventListener(type, handler, options) {{
+ windowListeners.push({{ type, options }});
}},
}},
}});
@@ -4007,12 +4018,29 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
generatedCodeModal: getButton("generated-code-modal"),
generatedCodeModalBackdrop: getButton("generated-code-modal-backdrop"),
generatedCodeModalCloseButton: getButton("generated-code-modal-close-button"),
+ codegenRoundtripMetadataField: getButton("codegen-roundtrip-metadata-field"),
+ codegenRoundtripMetadataCheckbox: {{
+ checked: false,
+ listeners: {{}},
+ addEventListener(type, handler) {{ this.listeners[type] = handler; }},
+ change(event) {{
+ this.checked = Boolean(event?.target?.checked);
+ this.listeners.change?.(event);
+ }},
+ }},
templateSelectField: getButton("template-select-field"),
engineSelectField: getButton("engine-select-field"),
collectionFormatSelectField: getButton("collection-format-select-field"),
templateSelect: {{
value: "mps",
- addEventListener(type, handler) {{ this[type] = handler; }},
+ listeners: {{}},
+ addEventListener(type, handler) {{ this.listeners[type] = handler; }},
+ mousedown(event) {{ this.listeners.mousedown?.(event); }},
+ change(event) {{ this.listeners.change?.(event); }},
+ blur() {{
+ flowEvents.push("templateSelect.blur");
+ this.listeners.blur?.({{ target: this }});
+ }},
}},
templateSettingsButton: getButton("template-settings-button"),
templateSettingsPopover: getButton("template-settings-popover"),
@@ -4026,11 +4054,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
editSessionTemplateMenuItem: getButton("edit-session-template-menu-item"),
openSubnetworkLibraryMenuItem: getButton("open-subnetwork-library-menu-item"),
reflowImportedButton: getButton("reflow-imported-button"),
- reflowAlignLeftButton: getButton("reflow-align-left-button"),
- reflowAlignRightButton: getButton("reflow-align-right-button"),
- reflowAlignTopButton: getButton("reflow-align-top-button"),
- reflowAlignMiddleButton: getButton("reflow-align-middle-button"),
- reflowAlignBottomButton: getButton("reflow-align-bottom-button"),
+ reflowAlignHorizontalButton: getButton("reflow-align-horizontal-button"),
+ reflowAlignVerticalButton: getButton("reflow-align-vertical-button"),
+ reflowRotateSelectionButton: getButton("reflow-rotate-selection-button"),
reflowIndicesLeftButton: getButton("reflow-indices-left-button"),
reflowIndicesRightButton: getButton("reflow-indices-right-button"),
reflowIndicesTopButton: getButton("reflow-indices-top-button"),
@@ -4198,8 +4224,8 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
dom,
documentRef: tooltipDocument,
windowRef: {{
- addEventListener(type, handler) {{
- windowListeners.push(type);
+ addEventListener(type, handler, options) {{
+ windowListeners.push({{ type, options }});
}},
}},
actions: shellActions,
@@ -4211,6 +4237,7 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
getButton("expand-generated-code-button").click();
dom.generatedCodeModalBackdrop.click();
dom.generatedCodeModalCloseButton.click();
+ dom.codegenRoundtripMetadataCheckbox.change({{ target: {{ checked: true }} }});
dom.engineSelect.change({{ target: {{ value: "cotengra" }} }});
dom.fileMenuButton.click();
dom.exportSubmenuShell.mouseenter();
@@ -4230,6 +4257,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
dom.templateSettingsButton.click();
dom.reflowImportedButton.click();
dom.reflowAutoLayoutButton.click();
+ dom.reflowAlignHorizontalButton.click();
+ dom.reflowAlignVerticalButton.click();
+ dom.reflowRotateSelectionButton.click();
dom.reflowArrangeGridButton.click();
dom.reflowIndicesResetButton.click();
dom.templateManagerCloseButton.click();
@@ -4249,6 +4279,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
if (dom.templateSelectField.attributes["data-expanded"] !== "false") {{
throw new Error("Expected template select change to collapse the disclosure indicator.");
}}
+ if (!flowEvents.includes("templateSelect.blur")) {{
+ throw new Error("Expected template selection changes to blur the dropdown so keyboard shortcuts do not stay trapped in the select.");
+ }}
dom.engineSelect.mousedown({{ target: dom.engineSelect }});
if (dom.engineSelectField.attributes["data-expanded"] !== "true") {{
throw new Error("Expected engine select mouse down to mark the disclosure as expanded.");
@@ -4265,6 +4298,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
if (dom.collectionFormatSelectField.attributes["data-expanded"] !== "false") {{
throw new Error("Expected collection format select change to collapse the disclosure indicator.");
}}
+ if (state.includeRoundtripMetadata !== true) {{
+ throw new Error(`Expected metadata checkbox changes to update state, received ${{state.includeRoundtripMetadata}}.`);
+ }}
if (!flowEvents.includes("generateCode")) {{
throw new Error(`Expected toolbar generate binding to invoke the injected action, received ${{JSON.stringify(flowEvents)}}.`);
}}
@@ -4330,6 +4366,15 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
if (!flowEvents.includes("applyReflowLayoutAction:auto")) {{
throw new Error(`Expected the Auto layout action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`);
}}
+ if (!flowEvents.includes("applyReflowLayoutAction:align-horizontal")) {{
+ throw new Error(`Expected the horizontal alignment action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`);
+ }}
+ if (!flowEvents.includes("applyReflowLayoutAction:align-vertical")) {{
+ throw new Error(`Expected the vertical alignment action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`);
+ }}
+ if (!flowEvents.includes("applyReflowLayoutAction:rotate-90")) {{
+ throw new Error(`Expected the rotate action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`);
+ }}
if (!flowEvents.includes("applyReflowLayoutAction:grid")) {{
throw new Error(`Expected the Reflow popover actions to dispatch the requested layout, received ${{JSON.stringify(flowEvents)}}.`);
}}
@@ -4415,6 +4460,14 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings(
) {{
throw new Error("Expected the Code tab to expose its tooltip description.");
}}
+ const keydownBinding = windowListeners.find(
+ (entry) => entry && entry.type === "keydown"
+ );
+ if (!keydownBinding || keydownBinding.options !== true) {{
+ throw new Error(
+ `Expected the global keydown shortcut listener to register in capture mode, received ${{JSON.stringify(windowListeners)}}.`
+ );
+ }}
""",
)
@@ -4705,8 +4758,12 @@ def test_editor_shell_helper_modules_expose_explicit_ui_and_invalidation_adapter
throw new Error("Session UI confirm adapter should forward the injected result.");
}}
await sessionUi.copyText("result = 1");
- sessionUi.downloadText("demo.json", "{{}}", "application/json");
- sessionUi.downloadBlob("demo.py", {{ type: "text/x-python" }});
+ await Promise.resolve(
+ sessionUi.downloadText("demo.json", "{{}}", "application/json")
+ );
+ await Promise.resolve(
+ sessionUi.downloadBlob("demo.py", {{ type: "text/x-python" }})
+ );
sessionUi.closeWindow();
if (!uiEvents.some((event) => event.kind === "copy" && event.text === "result = 1")) {{
throw new Error(`Expected injected copy adapter to run, received ${{JSON.stringify(uiEvents)}}.`);
@@ -4876,6 +4933,478 @@ def test_editor_shell_helper_modules_expose_explicit_ui_and_invalidation_adapter
)
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_session_ui_adapters_use_pywebview_save_api_when_available(
+ tmp_path: Path,
+) -> None:
+ script_path = _write_runtime_script(
+ tmp_path,
+ "session_ui_pywebview_save.mjs",
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href;
+ const sessionUiModule = await import(sessionUiUrl);
+
+ const calls = [];
+ class FakeBlob {{
+ constructor(parts, options = {{}}) {{
+ this.parts = parts;
+ this.type = options.type || "";
+ }}
+
+ async arrayBuffer() {{
+ const firstPart = this.parts[0];
+ if (!(firstPart instanceof Uint8Array)) {{
+ throw new Error("Expected the test blob to receive Uint8Array content.");
+ }}
+ return firstPart.buffer.slice(
+ firstPart.byteOffset,
+ firstPart.byteOffset + firstPart.byteLength
+ );
+ }}
+ }}
+ const sessionUi = sessionUiModule.createSessionUiAdapters({{
+ windowRef: {{
+ pywebview: {{
+ api: {{
+ async save_text_file(filename, text, contentType) {{
+ calls.push({{ type: "text", filename, text, contentType }});
+ return true;
+ }},
+ async save_binary_file(filename, base64Payload, contentType) {{
+ calls.push({{ type: "binary", filename, base64Payload, contentType }});
+ return true;
+ }},
+ }},
+ }},
+ }},
+ blobCtor: FakeBlob,
+ }});
+
+ const textSaved = await sessionUi.downloadText(
+ "demo.json",
+ "{{\\"ok\\":true}}",
+ "application/json;charset=utf-8"
+ );
+ const binarySaved = await sessionUi.downloadBlob(
+ "demo.pdf",
+ new FakeBlob([Uint8Array.from([0, 1, 2, 255])], {{ type: "application/pdf" }})
+ );
+
+ if (textSaved !== true || binarySaved !== true) {{
+ throw new Error(`Expected pywebview saves to resolve true, received ${{JSON.stringify({{ textSaved, binarySaved }})}}.`);
+ }}
+ const textCall = calls.find((entry) => entry.type === "text");
+ const binaryCall = calls.find((entry) => entry.type === "binary");
+ if (!textCall || textCall.filename !== "demo.json") {{
+ throw new Error(`Expected text export to use the pywebview API, received ${{JSON.stringify(calls)}}.`);
+ }}
+ if (
+ !binaryCall ||
+ binaryCall.filename !== "demo.pdf" ||
+ binaryCall.base64Payload !== "AAEC/w=="
+ ) {{
+ throw new Error(`Expected binary export to send base64 bytes through the pywebview API, received ${{JSON.stringify(calls)}}.`);
+ }}
+ """,
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The pywebview session-ui adapter runtime script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_session_ui_adapters_detect_pywebview_save_api_added_after_creation(
+ tmp_path: Path,
+) -> None:
+ script_path = _write_runtime_script(
+ tmp_path,
+ "session_ui_pywebview_late_save.mjs",
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href;
+ const sessionUiModule = await import(sessionUiUrl);
+
+ const windowRef = {{}};
+ const uiCalls = [];
+ const sessionUi = sessionUiModule.createSessionUiAdapters({{
+ windowRef,
+ documentRef: {{
+ createElement() {{
+ uiCalls.push({{ type: "web-download" }});
+ return {{
+ click() {{
+ uiCalls.push({{ type: "web-download-click" }});
+ }},
+ }};
+ }},
+ }},
+ urlRef: {{
+ createObjectURL() {{
+ return "blob:test";
+ }},
+ revokeObjectURL() {{
+ return undefined;
+ }},
+ }},
+ blobCtor: class FakeBlob {{
+ constructor(parts, options = {{}}) {{
+ this.parts = parts;
+ this.type = options.type || "";
+ }}
+ }},
+ }});
+
+ windowRef.pywebview = {{
+ api: {{
+ async save_text_file(filename, text, contentType) {{
+ uiCalls.push({{ type: "pywebview", filename, text, contentType }});
+ return true;
+ }},
+ async save_binary_file() {{
+ throw new Error("Unexpected binary save in text export test.");
+ }},
+ }},
+ }};
+
+ await sessionUi.downloadText(
+ "late.json",
+ "{{\\"late\\": true}}",
+ "application/json;charset=utf-8"
+ );
+
+ if (!uiCalls.some((entry) => entry.type === "pywebview")) {{
+ throw new Error(`Expected late pywebview injection to be honored, received ${{JSON.stringify(uiCalls)}}.`);
+ }}
+ if (uiCalls.some((entry) => entry.type === "web-download")) {{
+ throw new Error(`Expected pywebview save path instead of web download fallback, received ${{JSON.stringify(uiCalls)}}.`);
+ }}
+ """,
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The late pywebview session-ui adapter runtime script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_session_ui_adapters_use_partial_pywebview_text_api_when_available(
+ tmp_path: Path,
+) -> None:
+ script_path = _write_runtime_script(
+ tmp_path,
+ "session_ui_pywebview_partial_text_save.mjs",
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href;
+ const sessionUiModule = await import(sessionUiUrl);
+
+ const calls = [];
+ const sessionUi = sessionUiModule.createSessionUiAdapters({{
+ windowRef: {{
+ pywebview: {{
+ api: {{
+ async save_text_file(filename, text, contentType) {{
+ calls.push({{ type: "text", filename, text, contentType }});
+ return true;
+ }},
+ }},
+ }},
+ }},
+ documentRef: {{
+ createElement() {{
+ calls.push({{ type: "web-download" }});
+ return {{
+ click() {{
+ calls.push({{ type: "web-download-click" }});
+ }},
+ }};
+ }},
+ }},
+ urlRef: {{
+ createObjectURL() {{
+ calls.push({{ type: "object-url" }});
+ return "blob:test";
+ }},
+ revokeObjectURL() {{
+ return undefined;
+ }},
+ }},
+ blobCtor: class FakeBlob {{
+ constructor(parts, options = {{}}) {{
+ this.parts = parts;
+ this.type = options.type || "";
+ }}
+ }},
+ }});
+
+ const saved = await sessionUi.downloadText(
+ "partial.json",
+ "{{\\"partial\\": true}}",
+ "application/json;charset=utf-8"
+ );
+
+ if (saved !== true) {{
+ throw new Error(`Expected the partial pywebview text save to resolve true, received ${{saved}}.`);
+ }}
+ if (!calls.some((entry) => entry.type === "text")) {{
+ throw new Error(`Expected downloadText() to use save_text_file(), received ${{JSON.stringify(calls)}}.`);
+ }}
+ if (calls.some((entry) => entry.type === "web-download" || entry.type === "object-url")) {{
+ throw new Error(`Expected no web-download fallback when save_text_file() exists, received ${{JSON.stringify(calls)}}.`);
+ }}
+ """,
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The partial pywebview text-save runtime script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_session_ui_adapters_use_partial_pywebview_binary_api_when_available(
+ tmp_path: Path,
+) -> None:
+ script_path = _write_runtime_script(
+ tmp_path,
+ "session_ui_pywebview_partial_binary_save.mjs",
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href;
+ const sessionUiModule = await import(sessionUiUrl);
+
+ const calls = [];
+ class FakeBlob {{
+ constructor(parts, options = {{}}) {{
+ this.parts = parts;
+ this.type = options.type || "";
+ }}
+
+ async arrayBuffer() {{
+ const firstPart = this.parts[0];
+ if (!(firstPart instanceof Uint8Array)) {{
+ throw new Error("Expected Uint8Array content in the binary export test blob.");
+ }}
+ return firstPart.buffer.slice(
+ firstPart.byteOffset,
+ firstPart.byteOffset + firstPart.byteLength
+ );
+ }}
+ }}
+ const sessionUi = sessionUiModule.createSessionUiAdapters({{
+ windowRef: {{
+ pywebview: {{
+ api: {{
+ async save_binary_file(filename, base64Payload, contentType) {{
+ calls.push({{ type: "binary", filename, base64Payload, contentType }});
+ return true;
+ }},
+ }},
+ }},
+ }},
+ documentRef: {{
+ createElement() {{
+ calls.push({{ type: "web-download" }});
+ return {{
+ click() {{
+ calls.push({{ type: "web-download-click" }});
+ }},
+ }};
+ }},
+ }},
+ urlRef: {{
+ createObjectURL() {{
+ calls.push({{ type: "object-url" }});
+ return "blob:test";
+ }},
+ revokeObjectURL() {{
+ return undefined;
+ }},
+ }},
+ blobCtor: FakeBlob,
+ }});
+
+ const saved = await sessionUi.downloadBlob(
+ "partial.pdf",
+ new FakeBlob([Uint8Array.from([0, 1, 2, 255])], {{ type: "application/pdf" }})
+ );
+
+ if (saved !== true) {{
+ throw new Error(`Expected the partial pywebview binary save to resolve true, received ${{saved}}.`);
+ }}
+ const binaryCall = calls.find((entry) => entry.type === "binary");
+ if (!binaryCall || binaryCall.base64Payload !== "AAEC/w==") {{
+ throw new Error(`Expected downloadBlob() to use save_binary_file(), received ${{JSON.stringify(calls)}}.`);
+ }}
+ if (calls.some((entry) => entry.type === "web-download" || entry.type === "object-url")) {{
+ throw new Error(`Expected no web-download fallback when save_binary_file() exists, received ${{JSON.stringify(calls)}}.`);
+ }}
+ """,
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The partial pywebview binary-save runtime script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_start_editor_bootstraps_immediately_when_dom_is_already_ready(
+ tmp_path: Path,
+) -> None:
+ bootstrap_source = (
+ REPO_ROOT
+ / "src"
+ / "tensor_network_editor"
+ / "app"
+ / "static"
+ / "js"
+ / "bootstrap.js"
+ ).read_text(encoding="utf-8")
+ (tmp_path / "shell").mkdir(parents=True, exist_ok=True)
+ (tmp_path / "bootstrap.js").write_text(bootstrap_source, encoding="utf-8")
+ (tmp_path / "shell" / "editorBootstrapFlow.js").write_text(
+ """
+ export function createEditorBootstrapFlow() {
+ return {
+ async bootstrap() {
+ globalThis.__bootstrapCalls.push("bootstrap");
+ return {};
+ },
+ };
+ }
+ """,
+ encoding="utf-8",
+ )
+ (tmp_path / "shell" / "shellActions.js").write_text(
+ """
+ export function createShellActions() {
+ return {
+ setStatus(message, level = "info") {
+ globalThis.__bootstrapCalls.push(`status:${level}:${message}`);
+ },
+ };
+ }
+ """,
+ encoding="utf-8",
+ )
+ (tmp_path / "shell" / "editorShellBindings.js").write_text(
+ """
+ export function createEditorShellBindings() {
+ return {
+ attachToolbarHandlers() {
+ globalThis.__bootstrapCalls.push("attachToolbarHandlers");
+ },
+ };
+ }
+ """,
+ encoding="utf-8",
+ )
+ (tmp_path / "shell" / "shortcutTooltip.js").write_text(
+ """
+ export function createShortcutTooltip() {
+ return {
+ attachShortcutTooltipHandlers() {},
+ };
+ }
+ """,
+ encoding="utf-8",
+ )
+ script_path = _write_runtime_script(
+ tmp_path,
+ "bootstrap_dom_ready.mjs",
+ """
+ globalThis.__bootstrapCalls = [];
+ const bootstrapUrl = new URL("./bootstrap.js", import.meta.url).href;
+ const bootstrapModule = await import(bootstrapUrl);
+
+ const documentRef = {
+ readyState: "complete",
+ addEventListener(type, handler) {
+ globalThis.__bootstrapCalls.push(`listener:${type}`);
+ this.listener = handler;
+ },
+ };
+ const ctx = {
+ state: {},
+ store: {},
+ window: {
+ confirm() {
+ return false;
+ },
+ },
+ document: documentRef,
+ services: { session: {} },
+ logger: null,
+ constants: { REDO_SHORTCUT_LABEL: "Ctrl+Shift+Z" },
+ };
+
+ bootstrapModule.startEditor(ctx);
+ await Promise.resolve();
+
+ if (!globalThis.__bootstrapCalls.includes("attachToolbarHandlers")) {
+ throw new Error(`Expected toolbar handlers to attach immediately, received ${JSON.stringify(globalThis.__bootstrapCalls)}.`);
+ }
+ if (!globalThis.__bootstrapCalls.includes("bootstrap")) {
+ throw new Error(`Expected bootstrap to run immediately for a ready document, received ${JSON.stringify(globalThis.__bootstrapCalls)}.`);
+ }
+ """,
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The bootstrap DOM-ready runtime script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
def test_benchmark_helper_modules_build_comparison_rows_and_history_state(
tmp_path: Path,
@@ -4992,7 +5521,21 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state(
const historyEvents = [];
const historyState = {{
- spec: {{ id: "network_demo" }},
+ spec: {{
+ id: "network_demo",
+ contraction_plan: {{
+ id: "scheme_beta",
+ name: "Beta",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 0,
+ operand_layouts: [{{ operand_id: "tensor_a" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
+ }},
tensorOrder: ["tensor_a"],
undoStack: [],
redoStack: [],
@@ -5014,13 +5557,47 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state(
benchmarkSession: {{
enabled: true,
activePosition: 2,
- originalPlan: {{ id: "original_plan", name: "Original", steps: [], metadata: {{}} }},
+ originalPlan: {{
+ id: "original_plan",
+ name: "Original",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 0,
+ operand_layouts: [{{ operand_id: "original_tensor" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
schemes: [
- {{ id: "scheme_alpha", name: "Alpha", steps: [], metadata: {{}} }},
- {{ id: "scheme_beta", name: "Beta", steps: [], metadata: {{}} }},
+ {{
+ id: "scheme_alpha",
+ name: "Alpha",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 0,
+ operand_layouts: [{{ operand_id: "alpha_tensor" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
+ {{
+ id: "scheme_beta",
+ name: "Beta",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 0,
+ operand_layouts: [{{ operand_id: "beta_tensor" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
],
compareModal: {{
open: true,
+ tableModel: {{ rows: [{{ scheme_id: "scheme_alpha" }}] }},
rows: [{{ scheme_id: "scheme_alpha" }}],
activeRequestId: 7,
}},
@@ -5050,19 +5627,80 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state(
if (!snapshot.benchmarkSession || snapshot.benchmarkSession.activePosition !== 2) {{
throw new Error(`Expected history snapshots to capture benchmark session state, received ${{JSON.stringify(snapshot)}}.`);
}}
+ if (snapshot.benchmarkSession.compareModal.open || snapshot.benchmarkSession.compareModal.activeRequestId !== 0) {{
+ throw new Error(`Expected history snapshots to reset ephemeral benchmark compare state, received ${{JSON.stringify(snapshot.benchmarkSession.compareModal)}}.`);
+ }}
+ if (snapshot.benchmarkSession.compareModal.rows.length !== 0 || snapshot.benchmarkSession.compareModal.tableModel !== null) {{
+ throw new Error(`Expected history snapshots to strip compare rows and table models, received ${{JSON.stringify(snapshot.benchmarkSession.compareModal)}}.`);
+ }}
+ if (snapshot.benchmarkSession.originalPlan.view_snapshots.length !== 0) {{
+ throw new Error(`Expected history snapshots to strip original-plan view snapshots, received ${{JSON.stringify(snapshot.benchmarkSession.originalPlan)}}.`);
+ }}
+ if (snapshot.benchmarkSession.schemes.some((scheme) => scheme.view_snapshots.length !== 0)) {{
+ throw new Error(`Expected history snapshots to strip inactive benchmark view snapshots, received ${{JSON.stringify(snapshot.benchmarkSession.schemes)}}.`);
+ }}
+ if (snapshot.spec.contraction_plan.view_snapshots.length !== 1) {{
+ throw new Error(`Expected the active scheme view snapshots to stay in the main spec snapshot, received ${{JSON.stringify(snapshot.spec.contraction_plan)}}.`);
+ }}
historySupport.restoreHistorySnapshot({{
- spec: {{ id: "restored_network" }},
+ spec: {{
+ id: "restored_network",
+ contraction_plan: {{
+ id: "scheme_restored",
+ name: "Restored",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 1,
+ operand_layouts: [{{ operand_id: "restored_tensor" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
+ }},
tensorOrder: ["tensor_b"],
benchmarkSession: {{
enabled: true,
activePosition: 1,
- originalPlan: null,
- schemes: [{{ id: "scheme_restored", name: "Restored", steps: [], metadata: {{}} }}],
+ originalPlan: {{
+ id: "restored_original",
+ name: "Restored Original",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 0,
+ operand_layouts: [{{ operand_id: "restored_original_tensor" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
+ schemes: [
+ {{
+ id: "scheme_restored",
+ name: "Restored",
+ steps: [],
+ view_snapshots: [],
+ metadata: {{}},
+ }},
+ {{
+ id: "scheme_inactive",
+ name: "Inactive",
+ steps: [],
+ view_snapshots: [
+ {{
+ applied_step_count: 0,
+ operand_layouts: [{{ operand_id: "inactive_tensor" }}],
+ }},
+ ],
+ metadata: {{}},
+ }},
+ ],
compareModal: {{
- open: false,
- rows: [],
- activeRequestId: 0,
+ open: true,
+ tableModel: {{ rows: [{{ scheme_id: "scheme_restored" }}] }},
+ rows: [{{ scheme_id: "scheme_restored" }}],
+ activeRequestId: 9,
}},
}},
}});
@@ -5070,6 +5708,21 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state(
if (!historyState.benchmarkSession || historyState.benchmarkSession.activePosition !== 1) {{
throw new Error(`Expected history restore to recover benchmark session state, received ${{JSON.stringify(historyState.benchmarkSession)}}.`);
}}
+ if (historyState.benchmarkSession.compareModal.open || historyState.benchmarkSession.compareModal.rows.length !== 0 || historyState.benchmarkSession.compareModal.tableModel !== null) {{
+ throw new Error(`Expected history restore to keep benchmark compare state ephemeral, received ${{JSON.stringify(historyState.benchmarkSession.compareModal)}}.`);
+ }}
+ if (historyState.benchmarkSession.originalPlan.view_snapshots.length !== 0) {{
+ throw new Error(`Expected history restore to keep original-plan snapshots lazy, received ${{JSON.stringify(historyState.benchmarkSession.originalPlan)}}.`);
+ }}
+ if (historyState.benchmarkSession.schemes[1].view_snapshots.length !== 0) {{
+ throw new Error(`Expected inactive benchmark schemes to stay lightweight after restore, received ${{JSON.stringify(historyState.benchmarkSession.schemes)}}.`);
+ }}
+ if (historyState.benchmarkSession.schemes[0] !== historyState.spec.contraction_plan) {{
+ throw new Error("Expected history restore to re-link the active benchmark scheme to the restored contraction plan.");
+ }}
+ if (historyState.spec.contraction_plan.view_snapshots.length !== 1) {{
+ throw new Error(`Expected the restored active scheme to keep its exact view snapshots, received ${{JSON.stringify(historyState.spec.contraction_plan)}}.`);
+ }}
""",
)
diff --git a/tests/test_frontend_runtime.py b/tests/test_frontend_runtime.py
index 6a3f9dd..24e9a13 100644
--- a/tests/test_frontend_runtime.py
+++ b/tests/test_frontend_runtime.py
@@ -986,6 +986,198 @@ def test_api_service_logs_request_lifecycle_with_frontend_logger(
)
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_api_service_sends_session_token_header(
+ tmp_path: Path,
+) -> None:
+ script_path = tmp_path / "api_session_token_header.mjs"
+ script_path.write_text(
+ textwrap.dedent(
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const apiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "services" / "api.js")!r}).href;
+ const apiModule = await import(apiUrl);
+ const calls = [];
+
+ function headerValue(headers, name) {{
+ if (headers && typeof headers.get === "function") {{
+ return headers.get(name);
+ }}
+ return headers?.[name] || headers?.[name.toLowerCase()] || null;
+ }}
+
+ globalThis.fetch = async (path, options = {{}}) => {{
+ calls.push({{ path, options }});
+ return new Response(JSON.stringify({{ ok: true }}), {{
+ status: 200,
+ headers: {{ "Content-Type": "application/json" }},
+ }});
+ }};
+
+ await apiModule.apiGet("/api/bootstrap", {{
+ apiToken: "secret-token",
+ }});
+ await apiModule.apiPost("/api/cancel", {{}}, {{
+ apiToken: "secret-token",
+ }});
+
+ if (calls.length !== 2) {{
+ throw new Error(`Expected two calls, received ${{calls.length}}.`);
+ }}
+ for (const call of calls) {{
+ const token = headerValue(call.options.headers, "X-TNE-Session-Token");
+ if (token !== "secret-token") {{
+ throw new Error(`Missing session token header: ${{JSON.stringify(call)}}`);
+ }}
+ }}
+ const contentType = headerValue(calls[1].options.headers, "Content-Type");
+ if (contentType !== "application/json") {{
+ throw new Error(`Missing JSON content type: ${{JSON.stringify(calls[1])}}`);
+ }}
+ """
+ ),
+ encoding="utf-8",
+ )
+
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The api session token header script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_editor_context_passes_runtime_api_token_to_requests(
+ tmp_path: Path,
+) -> None:
+ script_path = tmp_path / "editor_context_api_token.mjs"
+ script_path.write_text(
+ textwrap.dedent(
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const contextUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "core" / "editorContext.js")!r}).href;
+ const contextModule = await import(contextUrl);
+ const calls = [];
+ const documentRef = {{
+ getElementById() {{
+ return null;
+ }},
+ querySelector() {{
+ return null;
+ }},
+ }};
+
+ function headerValue(headers, name) {{
+ if (headers && typeof headers.get === "function") {{
+ return headers.get(name);
+ }}
+ return headers?.[name] || headers?.[name.toLowerCase()] || null;
+ }}
+
+ globalThis.fetch = async (path, options = {{}}) => {{
+ calls.push({{ path, options }});
+ return new Response(JSON.stringify({{ ok: true }}), {{
+ status: 200,
+ headers: {{ "Content-Type": "application/json" }},
+ }});
+ }};
+
+ const ctx = contextModule.createEditorContext({{
+ window: {{}},
+ document: documentRef,
+ cytoscape: null,
+ runtimeConfig: {{ apiToken: "runtime-secret" }},
+ }});
+ await ctx.apiPost("/api/cancel", {{}});
+
+ const token = headerValue(calls[0]?.options?.headers, "X-TNE-Session-Token");
+ if (token !== "runtime-secret") {{
+ throw new Error(`Missing context token header: ${{JSON.stringify(calls)}}`);
+ }}
+ """
+ ),
+ encoding="utf-8",
+ )
+
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The editor context API token script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_runtime_config_reader_normalizes_api_token(
+ tmp_path: Path,
+) -> None:
+ script_path = tmp_path / "runtime_config_api_token.mjs"
+ script_path.write_text(
+ textwrap.dedent(
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const loggerUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "core" / "frontendLogger.js")!r}).href;
+ const loggerModule = await import(loggerUrl);
+ const documentRef = {{
+ getElementById(id) {{
+ if (id !== "tne-runtime-config") {{
+ return null;
+ }}
+ return {{
+ textContent: JSON.stringify({{
+ session_id: "session-1",
+ api_token: "embedded-token",
+ frontend_logging: {{ enabled: false }},
+ }}),
+ }};
+ }},
+ }};
+
+ const config = loggerModule.readFrontendRuntimeConfig({{ documentRef }});
+ if (config.sessionId !== "session-1") {{
+ throw new Error(`Unexpected session id: ${{JSON.stringify(config)}}`);
+ }}
+ if (config.apiToken !== "embedded-token") {{
+ throw new Error(`Unexpected API token: ${{JSON.stringify(config)}}`);
+ }}
+ """
+ ),
+ encoding="utf-8",
+ )
+
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The runtime config API token script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
def test_frontend_logger_persists_batched_logs_without_api_recursion(
tmp_path: Path,
@@ -7022,33 +7214,631 @@ def _write_port_layering_runtime_regression_script(tmp_path: Path) -> Path:
}},
}});
- const model = builder();
- const zIndexFor = (elementId) => model.descriptorsById[elementId].data.zIndex;
- const frontTensorZIndex = zIndexFor("tensor_front");
+ const model = builder();
+ const zIndexFor = (elementId) => model.descriptorsById[elementId].data.zIndex;
+ const frontTensorZIndex = zIndexFor("tensor_front");
+
+ if (!(zIndexFor("tensor_back") < zIndexFor("back_open"))) {{
+ throw new Error("An open port should still sit above its owning tensor.");
+ }}
+ if (!(zIndexFor("back_open") < frontTensorZIndex)) {{
+ throw new Error(
+ `An open port from a rear tensor should not cover a front tensor: open=${{zIndexFor("back_open")}}, front=${{frontTensorZIndex}}.`
+ );
+ }}
+ if (!(zIndexFor("back_connected") > frontTensorZIndex)) {{
+ throw new Error(
+ `A connected port should stay above tensors so connections remain visible: connected=${{zIndexFor("back_connected")}}, front=${{frontTensorZIndex}}.`
+ );
+ }}
+
+ state.selectionIds = ["tensor_back"];
+ const selectedModel = builder();
+ const selectedZIndexFor = (elementId) =>
+ selectedModel.descriptorsById[elementId].data.zIndex;
+ if (!(selectedZIndexFor("back_open") > selectedZIndexFor("tensor_front"))) {{
+ throw new Error(
+ `A selected tensor should keep its open ports visible above front tensors: open=${{selectedZIndexFor("back_open")}}, front=${{selectedZIndexFor("tensor_front")}}.`
+ );
+ }}
+ """
+ ),
+ encoding="utf-8",
+ )
+ return script_path
+
+
+def _write_contraction_scene_port_layering_runtime_regression_script(
+ tmp_path: Path,
+) -> Path:
+ script_path = tmp_path / "contraction_scene_port_layering_runtime_regression.mjs"
+ geometry_module_path = (
+ REPO_ROOT
+ / "src"
+ / "tensor_network_editor"
+ / "app"
+ / "static"
+ / "js"
+ / "utils/utilitiesGeometry.js"
+ )
+ script_path.write_text(
+ textwrap.dedent(
+ f"""
+ import {{ pathToFileURL }} from "node:url";
+
+ const geometryUrl = pathToFileURL({str(geometry_module_path)!r}).href;
+ const {{ createUtilityGeometryBindings }} = await import(geometryUrl);
+
+ function createFakeElement(id, initialZIndex) {{
+ let zIndex = initialZIndex;
+ return {{
+ length: 1,
+ id() {{
+ return id;
+ }},
+ data(name, value) {{
+ if (value === undefined) {{
+ return name === "zIndex" ? zIndex : undefined;
+ }}
+ if (name === "zIndex") {{
+ zIndex = value;
+ }}
+ return undefined;
+ }},
+ }};
+ }}
+
+ const visibleDerivedTensor = {{
+ id: "scene-step-ab",
+ name: "A-B",
+ position: {{ x: 100, y: 100 }},
+ size: {{ width: 160, height: 84 }},
+ indices: [
+ {{
+ id: "scene-step-ab_open",
+ name: "open",
+ dimension: 2,
+ offset: {{ x: 38, y: 0 }},
+ metadata: {{}},
+ }},
+ ],
+ isDerived: true,
+ sourceTensorIds: ["tensor_a", "tensor_b"],
+ metadata: {{}},
+ }};
+ const visibleFrontTensor = {{
+ id: "tensor_front",
+ name: "Front",
+ position: {{ x: 130, y: 100 }},
+ size: {{ width: 140, height: 84 }},
+ indices: [
+ {{
+ id: "front_open",
+ name: "front",
+ dimension: 2,
+ offset: {{ x: -38, y: 0 }},
+ metadata: {{}},
+ }},
+ ],
+ isDerived: false,
+ sourceTensorIds: ["tensor_front"],
+ metadata: {{}},
+ }};
+
+ const elementMap = new Map([
+ ["scene-step-ab", createFakeElement("scene-step-ab", 10)],
+ ["scene-step-ab_open", createFakeElement("scene-step-ab_open", 10.2)],
+ ["tensor_front", createFakeElement("tensor_front", 11)],
+ ["front_open", createFakeElement("front_open", 11.2)],
+ ]);
+
+ const state = {{
+ activeTensorDrag: null,
+ cy: {{
+ getElementById(id) {{
+ return elementMap.get(id) || {{ length: 0, data() {{ return undefined; }} }};
+ }},
+ edges() {{
+ return [];
+ }},
+ }},
+ pendingIndexId: null,
+ selectionIds: ["scene-step-ab"],
+ spec: {{
+ tensors: [
+ {{
+ id: "tensor_a",
+ indices: [],
+ position: {{ x: 0, y: 0 }},
+ }},
+ {{
+ id: "tensor_b",
+ indices: [],
+ position: {{ x: 0, y: 0 }},
+ }},
+ {{
+ id: "tensor_front",
+ indices: visibleFrontTensor.indices,
+ position: visibleFrontTensor.position,
+ }},
+ ],
+ }},
+ tensorOrder: [],
+ tensorRankById: {{}},
+ }};
+ const runtime = {{
+ asFiniteNumber(value, fallbackValue) {{
+ return Number.isFinite(value) ? value : fallbackValue;
+ }},
+ findConnectionByIndexId() {{
+ return null;
+ }},
+ findEdgeByIndexId() {{
+ return null;
+ }},
+ findHyperedgeByIndexId() {{
+ return null;
+ }},
+ findTensorById(tensorId) {{
+ return (
+ state.spec.tensors.find((tensor) => tensor.id === tensorId) || null
+ );
+ }},
+ getVisibleTensors() {{
+ return [visibleDerivedTensor, visibleFrontTensor];
+ }},
+ indexLabelNodeId(indexId) {{
+ return `${{indexId}}__label`;
+ }},
+ }};
+
+ const geometry = createUtilityGeometryBindings({{
+ ctx: {{ state }},
+ state,
+ constants: {{
+ TENSOR_WIDTH: 140,
+ TENSOR_HEIGHT: 84,
+ MIN_TENSOR_WIDTH: 96,
+ MIN_TENSOR_HEIGHT: 60,
+ INDEX_RADIUS: 10,
+ INDEX_PADDING: 6,
+ }},
+ runtime,
+ }});
+ Object.assign(runtime, geometry);
+
+ geometry.applyTensorLayerData();
+
+ const selectedOpenZIndex = elementMap.get("scene-step-ab_open").data("zIndex");
+ const frontTensorZIndex = elementMap.get("tensor_front").data("zIndex");
+
+ if (!(selectedOpenZIndex > frontTensorZIndex)) {{
+ throw new Error(
+ `A selected derived contraction tensor should keep its open ports visible above front tensors: open=${{selectedOpenZIndex}}, front=${{frontTensorZIndex}}.`
+ );
+ }}
+ """
+ ),
+ encoding="utf-8",
+ )
+ return script_path
+
+
+def _write_contraction_scene_base_tensor_port_layering_runtime_regression_script(
+ tmp_path: Path,
+) -> Path:
+ script_path = (
+ tmp_path / "contraction_scene_base_tensor_port_layering_runtime_regression.mjs"
+ )
+ _copy_runtime_bundle(
+ tmp_path,
+ {
+ "state.runtime.mjs": "state/state.js",
+ "utilities.runtime.mjs": "utils/utilities.js",
+ "historySelection.runtime.mjs": "graph/historySelection.js",
+ "contractionScene.runtime.mjs": "graph/contractionScene.js",
+ },
+ _RUNTIME_EDITOR_SUPPORT_MODULES,
+ )
+ script_path.write_text(
+ textwrap.dedent(
+ """
+ import { pathToFileURL } from "node:url";
+
+ function createClassList() {
+ return {
+ add() {},
+ remove() {},
+ toggle() {},
+ };
+ }
+
+ function createButton() {
+ return {
+ disabled: false,
+ classList: createClassList(),
+ addEventListener() {},
+ focus() {},
+ };
+ }
+
+ function createSpec() {
+ return {
+ id: "network_manual_anchor",
+ name: "manual-anchor",
+ tensors: [
+ {
+ id: "tensor_a",
+ name: "A",
+ position: { x: 120, y: 140 },
+ size: { width: 140, height: 84 },
+ metadata: {},
+ indices: [
+ {
+ id: "tensor_a_left",
+ name: "left",
+ dimension: 2,
+ offset: { x: -38, y: 0 },
+ metadata: {},
+ },
+ {
+ id: "tensor_a_bond",
+ name: "bond",
+ dimension: 3,
+ offset: { x: 38, y: 0 },
+ metadata: {},
+ },
+ ],
+ },
+ {
+ id: "tensor_b",
+ name: "B",
+ position: { x: 360, y: 220 },
+ size: { width: 140, height: 84 },
+ metadata: {},
+ indices: [
+ {
+ id: "tensor_b_bond",
+ name: "bond",
+ dimension: 3,
+ offset: { x: -38, y: 0 },
+ metadata: {},
+ },
+ {
+ id: "tensor_b_right",
+ name: "carry",
+ dimension: 5,
+ offset: { x: 38, y: 0 },
+ metadata: {},
+ },
+ ],
+ },
+ {
+ id: "tensor_c",
+ name: "C",
+ position: { x: 620, y: 300 },
+ size: { width: 140, height: 84 },
+ metadata: {},
+ indices: [
+ {
+ id: "tensor_c_left",
+ name: "carry",
+ dimension: 5,
+ offset: { x: -38, y: 0 },
+ metadata: {},
+ },
+ {
+ id: "tensor_c_right",
+ name: "right",
+ dimension: 7,
+ offset: { x: 38, y: 0 },
+ metadata: {},
+ },
+ ],
+ },
+ ],
+ groups: [],
+ edges: [
+ {
+ id: "edge_ab",
+ name: "bond_ab",
+ left: { tensor_id: "tensor_a", index_id: "tensor_a_bond" },
+ right: { tensor_id: "tensor_b", index_id: "tensor_b_bond" },
+ metadata: {},
+ },
+ {
+ id: "edge_bc",
+ name: "bond_bc",
+ left: { tensor_id: "tensor_b", index_id: "tensor_b_right" },
+ right: { tensor_id: "tensor_c", index_id: "tensor_c_left" },
+ metadata: {},
+ },
+ ],
+ notes: [],
+ contraction_plan: {
+ id: "plan_chain",
+ name: "Chain path",
+ steps: [
+ {
+ id: "step_ab",
+ left_operand_id: "tensor_a",
+ right_operand_id: "tensor_b",
+ },
+ ],
+ },
+ metadata: {},
+ };
+ }
+
+ function createFakeElement(id) {
+ let zIndex = null;
+ const classes = new Set();
+ let selected = false;
+ return {
+ length: 1,
+ id() {
+ return id;
+ },
+ data(name, value) {
+ if (value === undefined) {
+ return name === "zIndex" ? zIndex : undefined;
+ }
+ if (name === "zIndex") {
+ zIndex = value;
+ }
+ return undefined;
+ },
+ select() {
+ selected = true;
+ },
+ unselect() {
+ selected = false;
+ },
+ addClass(className) {
+ classes.add(className);
+ },
+ removeClass(className) {
+ classes.delete(className);
+ },
+ hasClass(className) {
+ return classes.has(className);
+ },
+ isSelected() {
+ return selected;
+ },
+ position() {},
+ selectable() {},
+ grabbable() {},
+ };
+ }
+
+ const baseUrl = new URL("./", import.meta.url);
+ const [stateModule, utilitiesModule, historyModule, contractionSceneModule] =
+ await Promise.all([
+ import(new URL("./state.runtime.mjs", baseUrl).href),
+ import(new URL("./utilities.runtime.mjs", baseUrl).href),
+ import(new URL("./historySelection.runtime.mjs", baseUrl).href),
+ import(new URL("./contractionScene.runtime.mjs", baseUrl).href),
+ ]);
+
+ const { createInitialState } = stateModule;
+ const { registerUtilities } = utilitiesModule;
+ const { registerHistorySelection } = historyModule;
+ const { registerContractionScene } = contractionSceneModule;
+
+ const ctx = {
+ state: createInitialState(),
+ constants: {
+ TENSOR_WIDTH: 140,
+ TENSOR_HEIGHT: 84,
+ MIN_TENSOR_WIDTH: 96,
+ MIN_TENSOR_HEIGHT: 60,
+ INDEX_RADIUS: 10,
+ INDEX_PADDING: 6,
+ NOTE_WIDTH: 220,
+ NOTE_HEIGHT: 120,
+ NOTE_MIN_WIDTH: 120,
+ NOTE_MIN_HEIGHT: 90,
+ HISTORY_LIMIT: 100,
+ REDO_SHORTCUT_LABEL: "Ctrl+Shift+Z",
+ DEFAULT_INDEX_SLOTS: [
+ { x: -38, y: 0 },
+ { x: 38, y: 0 },
+ { x: 0, y: -24 },
+ { x: 0, y: 24 },
+ ],
+ },
+ dom: {
+ workspace: {},
+ statusMessage: { textContent: "", classList: createClassList() },
+ propertiesPanel: { innerHTML: "" },
+ generatedCode: { value: "" },
+ engineSelect: { options: [], value: "tensornetwork" },
+ collectionFormatSelect: { options: [], value: "list" },
+ exportFormatSelect: { value: "py" },
+ addNoteButton: createButton(),
+ connectButton: { classList: createClassList() },
+ loadInput: {},
+ undoButton: createButton(),
+ redoButton: createButton(),
+ exportButton: createButton(),
+ toggleLinearPeriodicButton: { classList: createClassList() },
+ linearPeriodicPreviousCellButton: createButton(),
+ linearPeriodicCellLabel: { textContent: "" },
+ linearPeriodicNextCellButton: createButton(),
+ templateSelect: { value: "" },
+ templateParameterPanel: { hidden: true },
+ templateGraphSizeLabel: { textContent: "" },
+ templateGraphSizeInput: { value: "2", min: "1" },
+ templateBondDimensionInput: { value: "3", min: "1" },
+ templatePhysicalDimensionInput: { value: "2", min: "1" },
+ insertTemplateButton: createButton(),
+ createGroupButton: createButton(),
+ helpButton: createButton(),
+ helpModal: { classList: createClassList() },
+ helpBackdrop: createButton(),
+ helpCloseButton: createButton(),
+ canvasShell: {
+ getBoundingClientRect() {
+ return { left: 0, top: 0, width: 1000, height: 800 };
+ },
+ },
+ groupLayer: {},
+ resizeLayer: {},
+ notesLayer: {},
+ selectionBox: {},
+ minimapCanvas: {},
+ sidebar: {},
+ plannerPanel: {
+ innerHTML: "",
+ querySelectorAll() {
+ return [];
+ },
+ },
+ generateButton: createButton(),
+ },
+ apiGet: async () => null,
+ apiPost: async () => null,
+ window: {
+ structuredClone: globalThis.structuredClone,
+ crypto: globalThis.crypto,
+ setTimeout,
+ clearTimeout,
+ confirm: () => true,
+ },
+ document: {
+ activeElement: null,
+ createElement() {
+ return {
+ value: "",
+ textContent: "",
+ selected: false,
+ appendChild() {},
+ click() {},
+ };
+ },
+ getElementById() {
+ return createButton();
+ },
+ querySelectorAll() {
+ return [];
+ },
+ },
+ cytoscape: null,
+ tensorWidth: (tensor) => tensor?.size?.width ?? 140,
+ tensorHeight: (tensor) => tensor?.size?.height ?? 84,
+ render: () => {},
+ renderOverlayDecorations: () => {},
+ renderMinimap: () => {},
+ renderPlanner: () => {},
+ renderSidebarTabs: () => {},
+ refreshContractionAnalysis: () => {},
+ syncPendingInteractionClasses: () => {},
+ setActiveSidebarTab: () => {},
+ updateToolbarState: () => {},
+ captureEditableFocus: () => null,
+ restoreEditableFocus: () => {},
+ };
+
+ registerUtilities(ctx);
+ registerContractionScene(ctx);
+ registerHistorySelection(ctx);
+
+ ctx.state.selectedEngine = "tensornetwork";
+ ctx.state.selectedCollectionFormat = "list";
+ ctx.state.spec = ctx.normalizeSpec(createSpec());
+
+ const scene = ctx.buildContractionScene();
+ if (!scene) {
+ throw new Error("Expected a contraction scene after the manual step.");
+ }
+ const elementMap = new Map();
+ scene.tensors.forEach((tensor) => {
+ elementMap.set(tensor.id, createFakeElement(tensor.id));
+ tensor.indices.forEach((index) => {
+ elementMap.set(index.id, createFakeElement(index.id));
+ elementMap.set(`${index.id}__label`, createFakeElement(`${index.id}__label`));
+ });
+ });
+
+ ctx.state.cy = {
+ batch(action) {
+ action();
+ },
+ getElementById(id) {
+ return (
+ elementMap.get(id) || {
+ length: 0,
+ data() {
+ return undefined;
+ },
+ select() {},
+ unselect() {},
+ addClass() {},
+ removeClass() {},
+ position() {},
+ selectable() {},
+ grabbable() {},
+ }
+ );
+ },
+ edges() {
+ return {
+ forEach() {},
+ };
+ },
+ $(selector) {
+ if (selector === ":selected") {
+ return {
+ forEach(callback) {
+ elementMap.forEach((element) => {
+ if (element.isSelected()) {
+ callback(element);
+ }
+ });
+ },
+ };
+ }
+ if (selector === ".is-selection-highlight") {
+ return {
+ forEach(callback) {
+ elementMap.forEach((element) => {
+ if (element.hasClass("is-selection-highlight")) {
+ callback(element);
+ }
+ });
+ },
+ };
+ }
+ return {
+ forEach() {},
+ };
+ },
+ };
+
+ ctx.bringTensorToFront("tensor_c");
+ ctx.setSelection(["tensor_c"], { primaryId: "tensor_c" });
- if (!(zIndexFor("tensor_back") < zIndexFor("back_open"))) {{
- throw new Error("An open port should still sit above its owning tensor.");
- }}
- if (!(zIndexFor("back_open") < frontTensorZIndex)) {{
- throw new Error(
- `An open port from a rear tensor should not cover a front tensor: open=${{zIndexFor("back_open")}}, front=${{frontTensorZIndex}}.`
- );
- }}
- if (!(zIndexFor("back_connected") > frontTensorZIndex)) {{
+ const selectedBaseTensor = scene.operandMap.tensor_c;
+ const selectedOpenPort = selectedBaseTensor.indices.find(
+ (index) => index.name === "right"
+ );
+ const derivedTensor = scene.tensors.find((tensor) => tensor.isDerived);
+
+ if (!ctx.state.tensorOrder.includes(derivedTensor.id)) {
throw new Error(
- `A connected port should stay above tensors so connections remain visible: connected=${{zIndexFor("back_connected")}}, front=${{frontTensorZIndex}}.`
+ `Expected tensor layering order to track visible contraction operands, received ${JSON.stringify(ctx.state.tensorOrder)}.`
);
- }}
-
- state.selectionIds = ["tensor_back"];
- const selectedModel = builder();
- const selectedZIndexFor = (elementId) =>
- selectedModel.descriptorsById[elementId].data.zIndex;
- if (!(selectedZIndexFor("back_open") > selectedZIndexFor("tensor_front"))) {{
+ }
+ const selectedOpenPortZIndex = elementMap
+ .get(selectedOpenPort.id)
+ .data("zIndex");
+ const derivedTensorZIndex = elementMap.get(derivedTensor.id).data("zIndex");
+ if (!(selectedOpenPortZIndex > derivedTensorZIndex)) {
throw new Error(
- `A selected tensor should keep its open ports visible above front tensors: open=${{selectedZIndexFor("back_open")}}, front=${{selectedZIndexFor("tensor_front")}}.`
+ `A selected base tensor in contraction view should keep its free port visible above derived front tensors: open=${selectedOpenPortZIndex}, derived=${derivedTensorZIndex}.`
);
- }}
+ }
"""
),
encoding="utf-8",
@@ -9574,6 +10364,15 @@ def _write_metadata_properties_runtime_regression_script(tmp_path: Path) -> Path
if (!propertiesPanel.innerHTML.includes('id="add-index-to-selection-button"')) {
throw new Error("Mixed selections should keep the bulk Add index action when editable tensors remain.");
}
+ if (/id="extract-selection-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) {
+ throw new Error("Mixed selections with editable tensors should keep Extract enabled.");
+ }
+ if (/id="save-selection-subnetwork-library-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) {
+ throw new Error("Mixed selections with editable tensors should keep To Library enabled.");
+ }
+ if (/id="promote-selection-template-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) {
+ throw new Error("Mixed selections with editable tensors should keep To Template enabled.");
+ }
document.getElementById("add-index-to-selection-button").click();
const editableTensorAfter = ctx.state.spec.tensors.find(
(candidate) => candidate.id === "tensor_a"
@@ -11602,6 +12401,52 @@ def test_graph_model_layers_open_ports_below_front_tensors(
)
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_contraction_scene_selection_keeps_derived_open_ports_visible(
+ tmp_path: Path,
+) -> None:
+ script_path = _write_contraction_scene_port_layering_runtime_regression_script(
+ tmp_path
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The contraction-scene port layering runtime regression script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_contraction_scene_selection_keeps_base_tensor_open_ports_visible(
+ tmp_path: Path,
+) -> None:
+ script_path = (
+ _write_contraction_scene_base_tensor_port_layering_runtime_regression_script(
+ tmp_path
+ )
+ )
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The contraction-scene base-tensor port layering runtime regression script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
def test_copy_shortcut_prefers_native_text_selection_over_graph_copy(
tmp_path: Path,
@@ -12120,11 +12965,9 @@ def _write_utility_runtime_contract_script(tmp_path: Path) -> Path:
}),
parentElement: reflowLayoutShell,
},
- reflowAlignLeftButton: createButton(),
- reflowAlignRightButton: createButton(),
- reflowAlignTopButton: createButton(),
- reflowAlignMiddleButton: createButton(),
- reflowAlignBottomButton: createButton(),
+ reflowAlignHorizontalButton: createButton(),
+ reflowAlignVerticalButton: createButton(),
+ reflowRotateSelectionButton: createButton(),
reflowIndicesLeftButton: createButton(),
reflowIndicesRightButton: createButton(),
reflowIndicesTopButton: createButton(),
@@ -12361,6 +13204,12 @@ def _write_utility_runtime_contract_script(tmp_path: Path) -> Path:
if (ctx.dom.reflowAutoLayoutButton.disabled) {
throw new Error("Auto layout should stay enabled when the whole graph can be arranged.");
}
+ runtime.isLinearPeriodicMode = () => true;
+ runtime.updateToolbarState();
+ if (ctx.dom.templateSettingsButton.disabled) {
+ throw new Error("Template settings should stay enabled in For mode because they only affect future insertions.");
+ }
+ runtime.isLinearPeriodicMode = () => false;
ctx.state.selectionIds = ["tensor_a"];
runtime.isBenchmarkMode = () => true;
runtime.getBenchmarkSession = () => ({
@@ -13236,6 +14085,7 @@ def _write_interaction_session_dependency_injection_runtime_script(
generatedCode: "",
selectedEngine: "quimb",
selectedCollectionFormat: "dict",
+ includeRoundtripMetadata: true,
templateDefinitions: {},
availableTemplates: [],
templateCatalogWarnings: [],
@@ -13332,6 +14182,9 @@ def _write_interaction_session_dependency_injection_runtime_script(
if (generateCall.payload.engine !== "quimb" || generateCall.payload.collectionFormat !== "dict") {
throw new Error(`Unexpected generate payload: ${JSON.stringify(generateCall.payload)}.`);
}
+ if (generateCall.payload.includeRoundtripMetadata !== true) {
+ throw new Error(`Expected includeRoundtripMetadata=true in the injected generate payload, received ${JSON.stringify(generateCall.payload)}.`);
+ }
if (dom.generatedCode.value.trim() !== "result = 1") {
throw new Error(`Expected injected preview sync to receive stripped code, received ${dom.generatedCode.value}.`);
}
@@ -13391,6 +14244,7 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path:
spec: { name: "draft demo" },
generatedCode: "",
editorFinished: false,
+ selectedTheme: "light",
draftAutosaveReady: true,
draftAutosaveTimer: null,
draftAutosaveDirty: false,
@@ -13575,8 +14429,7 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path:
throw new Error(`Expected draft-save flow logging, received ${JSON.stringify(flowLog)}.`);
}
- flows.saveDesign();
- await Promise.resolve();
+ await flows.saveDesign();
if (!calls.some((entry) => entry.type === "clearDraft")) {
throw new Error(`Expected explicit JSON save to clear the draft, received ${JSON.stringify(calls)}.`);
}
@@ -13649,6 +14502,13 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path:
) {
throw new Error(`Academic exports should persist view snapshots, received ${JSON.stringify(calls)}.`);
}
+ if (
+ svgRenderCall.payload.theme !== "light" ||
+ pngRenderCall.payload.theme !== "light" ||
+ pdfRenderCall.payload.theme !== "light"
+ ) {
+ throw new Error(`SVG/PNG/PDF exports should include the active theme, received ${JSON.stringify(calls)}.`);
+ }
if (!svgDownloadCall || svgDownloadCall.contentType !== "image/svg+xml;charset=utf-8") {
throw new Error(`Expected SVG export to download a .svg file, received ${JSON.stringify(calls)}.`);
}
@@ -13885,6 +14745,130 @@ def _write_session_editor_png_fallback_runtime_script(tmp_path: Path) -> Path:
return script_path
+def _write_session_editor_save_cancelled_runtime_script(tmp_path: Path) -> Path:
+ script_path = tmp_path / "session_editor_save_cancelled.mjs"
+ _copy_js_modules(tmp_path, _SESSION_EDITOR_FLOWS_DEPENDENCY_MODULES)
+
+ script_path.write_text(
+ textwrap.dedent(
+ """
+ const baseUrl = new URL("./", import.meta.url);
+ const { createSessionEditorFlows } = await import(
+ new URL("./session/sessionEditorFlows.js", baseUrl).href
+ );
+
+ const calls = [];
+ const flowLog = [];
+ const state = {
+ spec: { name: "draft demo" },
+ generatedCode: "",
+ editorFinished: false,
+ draftAutosaveReady: true,
+ draftAutosaveTimer: null,
+ draftAutosaveDirty: false,
+ draftAutosaveSaving: false,
+ };
+
+ const flows = createSessionEditorFlows({
+ dom: {
+ exportFormatSelect: { value: "json" },
+ generatedCode: { value: "" },
+ loadInput: { value: "" },
+ },
+ state,
+ logger: {
+ startOperation(name, context = {}) {
+ flowLog.push({ type: "start", name, context });
+ return {
+ finish(nextContext = {}) {
+ flowLog.push({ type: "finish", name, context: nextContext });
+ },
+ fail(error, nextContext = {}) {
+ flowLog.push({
+ type: "fail",
+ name,
+ message: error.message,
+ context: nextContext,
+ });
+ },
+ };
+ },
+ },
+ store: {
+ setGeneratedCode() {},
+ setEditorFinished() {},
+ },
+ selectors: {
+ getSelectedEngine: () => "quimb",
+ getSelectedCollectionFormat: () => "dict",
+ },
+ services: {
+ session: {
+ async clearDraft() {
+ calls.push({ type: "clearDraft" });
+ return { ok: true };
+ },
+ },
+ },
+ commands: {
+ syncGeneratedCodePreview() {},
+ },
+ sessionUi: {
+ async downloadText(filename, text, contentType) {
+ calls.push({ type: "downloadText", filename, text, contentType });
+ return false;
+ },
+ closeWindow() {},
+ schedule() {
+ return 0;
+ },
+ },
+ actions: {
+ serializeCurrentSpec({ persistViewSnapshots }) {
+ return {
+ schema_version: 2,
+ persistViewSnapshots,
+ network: { id: "network_draft", name: "draft demo" },
+ };
+ },
+ sanitizeFilename: (value) => value.replace(/\\s+/g, "_"),
+ setStatus(message, level = "info") {
+ calls.push({ type: "status", message, level });
+ },
+ },
+ });
+
+ await flows.saveDesign();
+
+ if (calls.some((entry) => entry.type === "clearDraft")) {
+ throw new Error(`Cancelling the save dialog should not clear the draft, received ${JSON.stringify(calls)}.`);
+ }
+ const cancelStatus = calls.find(
+ (entry) =>
+ entry.type === "status" &&
+ entry.level === "info" &&
+ entry.message === "Design save cancelled."
+ );
+ if (!cancelStatus) {
+ throw new Error(`Expected a friendly cancellation status, received ${JSON.stringify(calls)}.`);
+ }
+ if (
+ !flowLog.some(
+ (entry) =>
+ entry.type === "finish" &&
+ entry.name === "Save design" &&
+ entry.context.outcome === "cancelled"
+ )
+ ) {
+ throw new Error(`Expected cancelled save-design flow logging, received ${JSON.stringify(flowLog)}.`);
+ }
+ """
+ ),
+ encoding="utf-8",
+ )
+ return script_path
+
+
@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
def test_session_editor_flows_fall_back_to_svg_when_png_render_fails(
tmp_path: Path,
@@ -13905,6 +14889,26 @@ def test_session_editor_flows_fall_back_to_svg_when_png_render_fails(
)
+@pytest.mark.skipif(shutil.which("node") is None, reason="node is required")
+def test_session_editor_flows_report_save_cancelled_without_clearing_draft(
+ tmp_path: Path,
+) -> None:
+ script_path = _write_session_editor_save_cancelled_runtime_script(tmp_path)
+ completed_process = subprocess.run(
+ ["node", str(script_path)],
+ cwd=REPO_ROOT,
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ assert completed_process.returncode == 0, (
+ "The session-editor save-cancelled runtime script failed.\n"
+ f"STDOUT:\n{completed_process.stdout}\n"
+ f"STDERR:\n{completed_process.stderr}"
+ )
+
+
def _write_tensor_initializer_parsing_runtime_script(tmp_path: Path) -> Path:
script_path = tmp_path / "tensor_initializer_parsing.mjs"
_copy_js_modules(
@@ -14176,6 +15180,12 @@ def _write_session_editor_live_python_import_runtime_script(tmp_path: Path) -> P
if (confirmMessages.length !== 3) {
throw new Error(`Expected the Python load flow to ask about live execution every time, received ${JSON.stringify(confirmMessages)}.`);
}
+ if (!confirmMessages.every((message) => message.includes("Only continue for local Python files you trust"))) {
+ throw new Error(`Expected every live-import prompt to warn about trusted local files, received ${JSON.stringify(confirmMessages)}.`);
+ }
+ if (!confirmMessages.every((message) => message.includes("can read and write files"))) {
+ throw new Error(`Expected every live-import prompt to describe local execution risk, received ${JSON.stringify(confirmMessages)}.`);
+ }
if (promptMessages.length !== 2) {
throw new Error(`Expected object-name prompts only for live imports, received ${JSON.stringify(promptMessages)}.`);
}
@@ -14257,6 +15267,7 @@ def _write_editor_session_service_validate_python_runtime_script(
await service.renderSpec({
format: "dot",
spec: { schema_version: 2, network: { id: "network_draft" } },
+ theme: "light",
});
await service.clearDraft();
@@ -14305,6 +15316,9 @@ def _write_editor_session_service_validate_python_runtime_script(
if (apiCalls[4].payload.format !== "dot" || apiCalls[4].payload.spec.network.id !== "network_draft") {
throw new Error(`Expected renderSpec to keep format and spec payloads, received ${JSON.stringify(apiCalls[4])}.`);
}
+ if (apiCalls[4].payload.theme !== "light") {
+ throw new Error(`Expected renderSpec to include the current theme, received ${JSON.stringify(apiCalls[4])}.`);
+ }
if (apiCalls[5].path !== "/api/draft/clear" || apiCalls[5].method !== "POST") {
throw new Error(`Expected clearDraft to POST /api/draft/clear, received ${JSON.stringify(apiCalls[5])}.`);
}
@@ -15069,6 +16083,93 @@ def _write_layout_subnetwork_runtime_regression_script(tmp_path: Path) -> Path:
}
}
+ ctx.state.selectionIds = ["tensor_a", "tensor_b"];
+ ctx.state.primarySelectionId = "tensor_b";
+ ctx.state.spec.tensors[0].position = { x: 100, y: 100 };
+ ctx.state.spec.tensors[1].position = { x: 260, y: 220 };
+ ctx.state.spec.tensors[0].indices[0].offset = { x: 20, y: -10 };
+ ctx.state.spec.tensors[1].indices[0].offset = { x: 16, y: -8 };
+ ctx.state.spec.tensors[1].indices[1].offset = { x: 20, y: 10 };
+ ctx.applyReflowLayoutAction("align-horizontal");
+ const horizontalAlignmentYs = ctx.state.spec.tensors
+ .slice(0, 2)
+ .map((tensor) => tensor.position.y);
+ if (!horizontalAlignmentYs.every((value) => value === horizontalAlignmentYs[0])) {
+ throw new Error(
+ `Horizontal alignment should align tensor centers on the y axis, received ${horizontalAlignmentYs.join(", ")}.`
+ );
+ }
+ ctx.applyReflowLayoutAction("align-vertical");
+ const verticalAlignmentXs = ctx.state.spec.tensors
+ .slice(0, 2)
+ .map((tensor) => tensor.position.x);
+ if (!verticalAlignmentXs.every((value) => value === verticalAlignmentXs[0])) {
+ throw new Error(
+ `Vertical alignment should align tensor centers on the x axis, received ${verticalAlignmentXs.join(", ")}.`
+ );
+ }
+
+ ctx.state.spec.tensors[0].position = { x: 100, y: 100 };
+ ctx.state.spec.tensors[1].position = { x: 260, y: 220 };
+ ctx.state.spec.tensors[0].indices[0].offset = { x: 20, y: -10 };
+ ctx.state.spec.tensors[1].indices[0].offset = { x: 16, y: -8 };
+ ctx.state.spec.tensors[1].indices[1].offset = { x: 20, y: 10 };
+ ctx.serializeCurrentSpec();
+ ctx.applyReflowLayoutAction("rotate-90");
+ const tensorARotated = ctx.findTensorById("tensor_a");
+ const tensorBRotated = ctx.findTensorById("tensor_b");
+ if (tensorARotated.position.x !== 240 || tensorARotated.position.y !== 80) {
+ throw new Error(
+ `Rotate 90 should move tensor A clockwise around the selection center, received ${JSON.stringify(tensorARotated.position)}.`
+ );
+ }
+ if (tensorBRotated.position.x !== 120 || tensorBRotated.position.y !== 240) {
+ throw new Error(
+ `Rotate 90 should move tensor B clockwise around the selection center, received ${JSON.stringify(tensorBRotated.position)}.`
+ );
+ }
+ if (JSON.stringify(tensorARotated.indices[0].offset) !== JSON.stringify({ x: 10, y: 20 })) {
+ throw new Error(
+ `Rotate 90 should rotate tensor A ports, received ${JSON.stringify(tensorARotated.indices[0].offset)}.`
+ );
+ }
+ if (JSON.stringify(tensorBRotated.indices[0].offset) !== JSON.stringify({ x: 8, y: 16 })) {
+ throw new Error(
+ `Rotate 90 should rotate tensor B first port, received ${JSON.stringify(tensorBRotated.indices[0].offset)}.`
+ );
+ }
+ if (JSON.stringify(tensorBRotated.indices[1].offset) !== JSON.stringify({ x: -10, y: 20 })) {
+ throw new Error(
+ `Rotate 90 should rotate tensor B second port, received ${JSON.stringify(tensorBRotated.indices[1].offset)}.`
+ );
+ }
+ if (ctx.state.selectionIds.join(",") !== "tensor_a,tensor_b") {
+ throw new Error("Rotate 90 should preserve the selected tensors.");
+ }
+ const serializedAfterRotate = ctx.serializeCurrentSpec();
+ const serializedTensorA = serializedAfterRotate.network.tensors.find(
+ (tensor) => tensor.id === "tensor_a"
+ );
+ const serializedTensorB = serializedAfterRotate.network.tensors.find(
+ (tensor) => tensor.id === "tensor_b"
+ );
+ if (
+ serializedTensorA.position.x !== 240 || serializedTensorA.position.y !== 80
+ ) {
+ throw new Error(
+ `serializeCurrentSpec should invalidate its cache after layout changes for tensor A, received ${JSON.stringify(serializedTensorA.position)}.`
+ );
+ }
+ if (
+ serializedTensorB.position.x !== 120 || serializedTensorB.position.y !== 240
+ ) {
+ throw new Error(
+ `serializeCurrentSpec should invalidate its cache after layout changes for tensor B, received ${JSON.stringify(serializedTensorB.position)}.`
+ );
+ }
+
+ ctx.state.selectionIds = ["tensor_a", "tensor_b", "tensor_c"];
+ ctx.state.primarySelectionId = "tensor_c";
ctx.state.spec.tensors[0].position.x = 100;
ctx.state.spec.tensors[1].position.x = 260;
ctx.state.spec.tensors[2].position.x = 460;
diff --git a/tests/test_models_validation.py b/tests/test_models_validation.py
index ee76551..183270d 100644
--- a/tests/test_models_validation.py
+++ b/tests/test_models_validation.py
@@ -836,6 +836,39 @@ def test_validate_spec_accepts_linear_periodic_partial_carry_chain() -> None:
assert validate_spec(build_linear_periodic_partial_carry_chain_spec()) == []
+def test_validate_spec_rejects_linear_periodic_previous_step_that_merges_multiple_payload_operands() -> (
+ None
+):
+ spec = build_linear_periodic_partial_carry_chain_spec()
+ assert spec.linear_periodic_chain is not None
+ periodic_cell = spec.linear_periodic_chain.periodic_cell
+ assert periodic_cell.contraction_plan is not None
+ periodic_cell.contraction_plan.steps = [
+ ContractionStepSpec(
+ id="merge_previous_locals",
+ left_operand_id="periodic_previous_left_tensor",
+ right_operand_id="periodic_previous_right_tensor",
+ ),
+ ContractionStepSpec(
+ id="consume_previous_payload",
+ left_operand_id="__linear_previous__",
+ right_operand_id="merge_previous_locals",
+ ),
+ ContractionStepSpec(
+ id="carry_next_left",
+ left_operand_id="periodic_next_left_tensor",
+ right_operand_id="__linear_next__",
+ ),
+ ]
+
+ issue = find_issue(validate_spec(spec), "linear-periodic-carry-codegen")
+
+ assert issue.path == (
+ "linear_periodic_chain.periodic_cell.contraction_plan.steps.consume_previous_payload"
+ )
+ assert "one previous carry operand per step" in issue.message
+
+
def test_build_carry_validation_context_internal_helper_collects_interface_state() -> (
None
):
diff --git a/tests/test_packaging.py b/tests/test_packaging.py
index db4bc28..5a9b1ad 100644
--- a/tests/test_packaging.py
+++ b/tests/test_packaging.py
@@ -2,6 +2,7 @@
import json
import os
+import re
import subprocess
import sys
import tomllib
@@ -103,6 +104,99 @@ def test_project_metadata_declares_required_matplotlib_dependency_and_backend_ex
assert "png" not in optional_dependencies
+def test_project_metadata_and_ci_enable_dependency_audits() -> None:
+ pyproject_path = Path.cwd() / "pyproject.toml"
+ ci_path = Path.cwd() / ".github" / "workflows" / "ci.yml"
+
+ payload = tomllib.loads(pyproject_path.read_text(encoding="utf-8"))
+ dev_dependencies = payload["project"]["optional-dependencies"]["dev"]
+ ci_text = ci_path.read_text(encoding="utf-8")
+
+ assert "pip-audit>=2.7" in dev_dependencies
+ assert "Run dependency security audit" in ci_text
+ assert "-m pip_audit" in ci_text
+
+
+def test_ci_runs_source_security_lint_and_dependabot_tracks_updates() -> None:
+ ci_text = (Path.cwd() / ".github" / "workflows" / "ci.yml").read_text(
+ encoding="utf-8"
+ )
+ dependabot_text = (Path.cwd() / ".github" / "dependabot.yml").read_text(
+ encoding="utf-8"
+ )
+
+ assert "Run source security lint" in ci_text
+ assert "-m ruff check src --select S" in ci_text
+ assert 'package-ecosystem: "pip"' in dependabot_text
+ assert 'package-ecosystem: "github-actions"' in dependabot_text
+ assert 'directory: "/"' in dependabot_text
+
+
+def test_bundled_prism_version_stays_patched_for_cve_2024_53382() -> None:
+ third_party_text = (Path.cwd() / "THIRD_PARTY_LICENSES").read_text(encoding="utf-8")
+ version_match = re.search(
+ r"2\. PrismJS[\s\S]*?- Version: (\d+)\.(\d+)\.(\d+)",
+ third_party_text,
+ )
+
+ assert version_match is not None
+ version = tuple(int(part) for part in version_match.groups())
+ assert version >= (1, 30, 0)
+
+
+def test_live_python_import_docs_warn_to_use_only_trusted_files() -> None:
+ expected_warning = "Only use live import with local Python files you trust."
+ docs_paths = [
+ Path.cwd() / "README.md",
+ Path.cwd() / "docs" / "api.md",
+ Path.cwd() / "docs" / "cli.md",
+ Path.cwd() / "docs" / "extended_guide.md",
+ ]
+
+ for docs_path in docs_paths:
+ docs_text = docs_path.read_text(encoding="utf-8")
+ assert expected_warning in docs_text
+
+
+def test_security_policy_documents_private_reporting_and_prism_advisory() -> None:
+ security_text = (Path.cwd() / "SECURITY.md").read_text(encoding="utf-8")
+ readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8")
+
+ assert "GitHub private vulnerability reporting" in security_text
+ assert "Do not open a public issue with exploit details" in security_text
+ assert "CVE-2024-53382" in security_text
+ assert "GHSA-x7hr-w5r2-h6wg" in security_text
+ assert "Bundled PrismJS before 1.30.0" in security_text
+ assert "browser-based editor" in security_text
+ assert (
+ "Installing or importing the Python package alone does not execute PrismJS"
+ in (security_text)
+ )
+ assert "publish the patched release before publishing the advisory" in security_text
+ assert "Security policy: [SECURITY.md](SECURITY.md)" in readme_text
+
+
+def test_docs_do_not_advertise_removed_png_extra() -> None:
+ readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8")
+ installation_text = (Path.cwd() / "docs" / "installation.md").read_text(
+ encoding="utf-8"
+ )
+
+ assert "tensor-network-editor[png]" not in readme_text
+ assert "optional `png` extra" not in readme_text
+ assert "tensor-network-editor[png]" not in installation_text
+
+
+def test_manifest_omits_redundant_non_package_exclusions() -> None:
+ manifest_text = (Path.cwd() / "MANIFEST.in").read_text(encoding="utf-8")
+
+ assert "docs/images" not in manifest_text
+ assert "prune tests" not in manifest_text
+ assert "tests" not in manifest_text
+ assert "recursive-exclude docs/images *" not in manifest_text
+ assert "recursive-exclude tests *" not in manifest_text
+
+
def test_third_party_notices_describe_bundled_asset_scope() -> None:
third_party_text = (Path.cwd() / "THIRD_PARTY_LICENSES").read_text(encoding="utf-8")
readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8")
@@ -115,6 +209,9 @@ def test_third_party_notices_describe_bundled_asset_scope() -> None:
assert "Runtime pip-installed dependencies are not bundled" in third_party_text
assert "Package: Matplotlib" in third_party_text
assert "License: Matplotlib license" in third_party_text
+ assert "Development dependency notice" in third_party_text
+ assert "Package: pip-audit" in third_party_text
+ assert "License: Apache Software License" in third_party_text
assert "THIRD_PARTY_LICENSES" in readme_text
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index f0daa36..cd64cec 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -88,6 +88,7 @@ def test_parse_codegen_request_uses_defaults_when_optional_fields_are_missing(
serialized_spec=serialized_sample_spec,
engine=EngineName.EINSUM_TORCH,
collection_format=TensorCollectionFormat.DICT,
+ include_roundtrip_metadata=True,
)
@@ -111,6 +112,7 @@ def test_parse_codegen_request_honors_explicit_engine_and_collection_format(
serialized_spec=serialized_sample_spec,
engine=EngineName.QUIMB,
collection_format=TensorCollectionFormat.MATRIX,
+ include_roundtrip_metadata=True,
)
diff --git a/tests/test_rendering.py b/tests/test_rendering.py
index c684740..75001cc 100644
--- a/tests/test_rendering.py
+++ b/tests/test_rendering.py
@@ -1,17 +1,26 @@
from __future__ import annotations
import re
+from math import hypot
from pathlib import Path
from typing import Any
from xml.etree import ElementTree as ET
import pytest
-from tensor_network_editor.models import NetworkSpec
+from tensor_network_editor.models import (
+ CanvasPosition,
+ EdgeEndpointRef,
+ EdgeSpec,
+ IndexSpec,
+ NetworkSpec,
+ TensorSpec,
+)
from tensor_network_editor.rendering import (
DotRenderOptions,
SvgRenderOptions,
TikzRenderOptions,
+ _number,
_SvgRenderer,
render_spec_dot,
render_spec_mermaid,
@@ -19,7 +28,12 @@
render_spec_svg,
render_spec_tikz,
)
-from tests.factories import build_sample_spec, build_three_tensor_hyperedge_spec
+from tensor_network_editor.templates import TemplateParameters, build_template_spec
+from tests.factories import (
+ build_sample_spec,
+ build_three_tensor_hyperedge_spec,
+ build_three_tensor_spec,
+)
def _build_colored_parallel_edge_spec() -> NetworkSpec:
@@ -107,6 +121,309 @@ def _build_three_parallel_edge_spec() -> NetworkSpec:
return spec
+def _build_cycle_spec() -> NetworkSpec:
+ return NetworkSpec(
+ id="network_cycle",
+ name="cycle",
+ tensors=[
+ TensorSpec(
+ id="tensor_a",
+ name="A",
+ position=CanvasPosition(x=120.0, y=120.0),
+ indices=[
+ IndexSpec(id="tensor_a_free", name="fa", dimension=2),
+ IndexSpec(id="tensor_a_ab", name="ab", dimension=3),
+ IndexSpec(id="tensor_a_da", name="da", dimension=5),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_b",
+ name="B",
+ position=CanvasPosition(x=280.0, y=120.0),
+ indices=[
+ IndexSpec(id="tensor_b_free", name="fb", dimension=2),
+ IndexSpec(id="tensor_b_ab", name="ab", dimension=3),
+ IndexSpec(id="tensor_b_bc", name="bc", dimension=7),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_c",
+ name="C",
+ position=CanvasPosition(x=280.0, y=280.0),
+ indices=[
+ IndexSpec(id="tensor_c_free", name="fc", dimension=2),
+ IndexSpec(id="tensor_c_bc", name="bc", dimension=7),
+ IndexSpec(id="tensor_c_cd", name="cd", dimension=11),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_d",
+ name="D",
+ position=CanvasPosition(x=120.0, y=280.0),
+ indices=[
+ IndexSpec(id="tensor_d_free", name="fd", dimension=2),
+ IndexSpec(id="tensor_d_cd", name="cd", dimension=11),
+ IndexSpec(id="tensor_d_da", name="da", dimension=5),
+ ],
+ ),
+ ],
+ edges=[
+ EdgeSpec(
+ id="edge_ab",
+ name="ab",
+ left=EdgeEndpointRef(tensor_id="tensor_a", index_id="tensor_a_ab"),
+ right=EdgeEndpointRef(tensor_id="tensor_b", index_id="tensor_b_ab"),
+ ),
+ EdgeSpec(
+ id="edge_bc",
+ name="bc",
+ left=EdgeEndpointRef(tensor_id="tensor_b", index_id="tensor_b_bc"),
+ right=EdgeEndpointRef(tensor_id="tensor_c", index_id="tensor_c_bc"),
+ ),
+ EdgeSpec(
+ id="edge_cd",
+ name="cd",
+ left=EdgeEndpointRef(tensor_id="tensor_c", index_id="tensor_c_cd"),
+ right=EdgeEndpointRef(tensor_id="tensor_d", index_id="tensor_d_cd"),
+ ),
+ EdgeSpec(
+ id="edge_da",
+ name="da",
+ left=EdgeEndpointRef(tensor_id="tensor_d", index_id="tensor_d_da"),
+ right=EdgeEndpointRef(tensor_id="tensor_a", index_id="tensor_a_da"),
+ ),
+ ],
+ )
+
+
+def _build_grid_export_spec() -> NetworkSpec:
+ tensors: list[TensorSpec] = []
+ edges: list[EdgeSpec] = []
+ for row_index in range(3):
+ for column_index in range(3):
+ tensor_id = f"tensor_{row_index}_{column_index}"
+ indices = [
+ IndexSpec(
+ id=f"{tensor_id}_free",
+ name=f"f_{row_index}_{column_index}",
+ dimension=2,
+ )
+ ]
+ if column_index < 2:
+ indices.append(
+ IndexSpec(
+ id=f"{tensor_id}_right",
+ name=f"h_{row_index}_{column_index}",
+ dimension=3,
+ )
+ )
+ if column_index > 0:
+ indices.append(
+ IndexSpec(
+ id=f"{tensor_id}_left",
+ name=f"h_{row_index}_{column_index - 1}",
+ dimension=3,
+ )
+ )
+ if row_index < 2:
+ indices.append(
+ IndexSpec(
+ id=f"{tensor_id}_down",
+ name=f"v_{row_index}_{column_index}",
+ dimension=5,
+ )
+ )
+ if row_index > 0:
+ indices.append(
+ IndexSpec(
+ id=f"{tensor_id}_up",
+ name=f"v_{row_index - 1}_{column_index}",
+ dimension=5,
+ )
+ )
+ tensors.append(
+ TensorSpec(
+ id=tensor_id,
+ name=f"T{row_index}{column_index}",
+ position=CanvasPosition(
+ x=120.0 + 140.0 * column_index,
+ y=120.0 + 140.0 * row_index,
+ ),
+ indices=indices,
+ )
+ )
+ for row_index in range(3):
+ for column_index in range(2):
+ left_tensor_id = f"tensor_{row_index}_{column_index}"
+ right_tensor_id = f"tensor_{row_index}_{column_index + 1}"
+ edge_name = f"h_{row_index}_{column_index}"
+ edges.append(
+ EdgeSpec(
+ id=f"edge_{edge_name}",
+ name=edge_name,
+ left=EdgeEndpointRef(
+ tensor_id=left_tensor_id,
+ index_id=f"{left_tensor_id}_right",
+ ),
+ right=EdgeEndpointRef(
+ tensor_id=right_tensor_id,
+ index_id=f"{right_tensor_id}_left",
+ ),
+ )
+ )
+ for row_index in range(2):
+ for column_index in range(3):
+ top_tensor_id = f"tensor_{row_index}_{column_index}"
+ bottom_tensor_id = f"tensor_{row_index + 1}_{column_index}"
+ edge_name = f"v_{row_index}_{column_index}"
+ edges.append(
+ EdgeSpec(
+ id=f"edge_{edge_name}",
+ name=edge_name,
+ left=EdgeEndpointRef(
+ tensor_id=top_tensor_id,
+ index_id=f"{top_tensor_id}_down",
+ ),
+ right=EdgeEndpointRef(
+ tensor_id=bottom_tensor_id,
+ index_id=f"{bottom_tensor_id}_up",
+ ),
+ )
+ )
+ return NetworkSpec(
+ id="network_grid_export",
+ name="grid-export",
+ tensors=tensors,
+ edges=edges,
+ )
+
+
+def _build_vertical_three_tensor_spec() -> NetworkSpec:
+ spec = build_three_tensor_spec()
+ spec.tensors[0].position = CanvasPosition(x=240.0, y=80.0)
+ spec.tensors[1].position = CanvasPosition(x=240.0, y=240.0)
+ spec.tensors[2].position = CanvasPosition(x=240.0, y=400.0)
+ return spec
+
+
+def _build_vertical_three_tensor_named_hint_spec() -> NetworkSpec:
+ spec = _build_vertical_three_tensor_spec()
+ spec.tensors[0].indices[0].name = "up"
+ return spec
+
+
+def _build_diagonal_three_tensor_spec() -> NetworkSpec:
+ spec = build_three_tensor_spec()
+ spec.tensors[0].position = CanvasPosition(x=80.0, y=80.0)
+ spec.tensors[1].position = CanvasPosition(x=240.0, y=240.0)
+ spec.tensors[2].position = CanvasPosition(x=400.0, y=400.0)
+ return spec
+
+
+def _build_rotated_grid_export_spec() -> NetworkSpec:
+ spec = _build_grid_export_spec()
+ center = CanvasPosition(x=240.0, y=240.0)
+ column_step = CanvasPosition(x=100.0, y=100.0)
+ row_step = CanvasPosition(x=-100.0, y=100.0)
+ for tensor in spec.tensors:
+ _, row_text, column_text = tensor.id.split("_")
+ row_index = int(row_text)
+ column_index = int(column_text)
+ tensor.position = CanvasPosition(
+ x=center.x
+ + (column_index - 1) * column_step.x
+ + (row_index - 1) * row_step.x,
+ y=center.y
+ + (column_index - 1) * column_step.y
+ + (row_index - 1) * row_step.y,
+ )
+ return spec
+
+
+def _build_vertical_mpo_export_spec() -> NetworkSpec:
+ spec = build_template_spec(
+ "mpo",
+ TemplateParameters(
+ graph_size=4,
+ bond_dimension=3,
+ physical_dimension=2,
+ boundary_condition="open",
+ j=1.0,
+ h=1.0,
+ ),
+ )
+ for tensor_index, tensor in enumerate(spec.tensors):
+ tensor.position = CanvasPosition(x=240.0, y=80.0 + tensor_index * 160.0)
+ return spec
+
+
+def _build_generic_export_spec() -> NetworkSpec:
+ return NetworkSpec(
+ id="network_generic_export",
+ name="generic-export",
+ tensors=[
+ TensorSpec(
+ id="tensor_center",
+ name="Center",
+ position=CanvasPosition(x=220.0, y=200.0),
+ indices=[
+ IndexSpec(id="tensor_center_free", name="free", dimension=2),
+ IndexSpec(id="tensor_center_right", name="r", dimension=3),
+ IndexSpec(id="tensor_center_down", name="d", dimension=5),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_right",
+ name="Right",
+ position=CanvasPosition(x=360.0, y=180.0),
+ indices=[
+ IndexSpec(id="tensor_right_left", name="r", dimension=3),
+ ],
+ ),
+ TensorSpec(
+ id="tensor_down",
+ name="Down",
+ position=CanvasPosition(x=260.0, y=340.0),
+ indices=[
+ IndexSpec(id="tensor_down_up", name="d", dimension=5),
+ ],
+ ),
+ ],
+ edges=[
+ EdgeSpec(
+ id="edge_center_right",
+ name="r",
+ left=EdgeEndpointRef(
+ tensor_id="tensor_center", index_id="tensor_center_right"
+ ),
+ right=EdgeEndpointRef(
+ tensor_id="tensor_right", index_id="tensor_right_left"
+ ),
+ ),
+ EdgeSpec(
+ id="edge_center_down",
+ name="d",
+ left=EdgeEndpointRef(
+ tensor_id="tensor_center", index_id="tensor_center_down"
+ ),
+ right=EdgeEndpointRef(
+ tensor_id="tensor_down", index_id="tensor_down_up"
+ ),
+ ),
+ ],
+ )
+
+
+def _dot(left: CanvasPosition, right: CanvasPosition) -> float:
+ return left.x * right.x + left.y * right.y
+
+
+def _normalize(vector: CanvasPosition) -> CanvasPosition:
+ magnitude = hypot(vector.x, vector.y)
+ assert magnitude > 1e-9
+ return CanvasPosition(x=vector.x / magnitude, y=vector.y / magnitude)
+
+
def _svg_text_content(svg: str) -> list[str]:
root = ET.fromstring(svg)
text_nodes = root.findall(".//{http://www.w3.org/2000/svg}text")
@@ -166,6 +483,146 @@ def test_academic_svg_and_tikz_exports_use_tensor_circles_and_dangling_ports() -
assert r"\draw[tne open index]" in tikz
+def test_export_geometry_prefers_perpendicular_free_index_directions_for_linear_chain() -> (
+ None
+):
+ spec = build_three_tensor_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+
+ left_tensor = spec.tensors[0]
+ left_index = left_tensor.indices[0]
+ direction = renderer._index_direction(left_tensor, left_index)
+ source = renderer.connection_point(left_tensor, left_index)
+ target = renderer.open_index_endpoint(left_tensor, left_index)
+
+ assert abs(direction.x) < 0.25
+ assert abs(direction.y) > 0.9
+ assert hypot(target.x - source.x, target.y - source.y) == pytest.approx(
+ 2.0 * renderer.tensor_radius(left_tensor)
+ )
+
+
+def test_export_geometry_respects_vertical_linear_chain_orientation() -> None:
+ spec = _build_vertical_three_tensor_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+
+ first_tensor = spec.tensors[0]
+ free_index = first_tensor.indices[0]
+ direction = renderer._index_direction(first_tensor, free_index)
+
+ assert abs(direction.x) > 0.9
+ assert abs(direction.y) < 0.25
+
+
+def test_export_geometry_prefers_linear_component_orientation_over_named_hints() -> (
+ None
+):
+ spec = _build_vertical_three_tensor_named_hint_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+
+ first_tensor = spec.tensors[0]
+ free_index = first_tensor.indices[0]
+ direction = renderer._index_direction(first_tensor, free_index)
+
+ assert abs(direction.x) > 0.9
+ assert abs(direction.y) < 0.25
+
+
+def test_export_geometry_respects_diagonal_linear_chain_orientation() -> None:
+ spec = _build_diagonal_three_tensor_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+
+ first_tensor = spec.tensors[0]
+ free_index = first_tensor.indices[0]
+ direction = renderer._index_direction(first_tensor, free_index)
+ chain_axis = _normalize(CanvasPosition(x=1.0, y=1.0))
+ diagonal_perpendicular = _normalize(CanvasPosition(x=-1.0, y=1.0))
+
+ assert abs(_dot(direction, chain_axis)) < 0.25
+ assert abs(_dot(direction, diagonal_perpendicular)) > 0.9
+
+
+def test_export_geometry_prefers_vertical_mpo_component_orientation_over_index_names() -> (
+ None
+):
+ spec = _build_vertical_mpo_export_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+ first_tensor = spec.tensors[0]
+ bra_index = next(index for index in first_tensor.indices if index.name == "bra")
+ ket_index = next(index for index in first_tensor.indices if index.name == "ket")
+ bra_direction = renderer._index_direction(first_tensor, bra_index)
+ ket_direction = renderer._index_direction(first_tensor, ket_index)
+
+ assert abs(bra_direction.x) > 0.9
+ assert abs(ket_direction.x) > 0.9
+ assert abs(bra_direction.y) < 0.25
+ assert abs(ket_direction.y) < 0.25
+ assert _dot(bra_direction, ket_direction) < -0.85
+
+
+def test_export_geometry_points_cycle_free_indices_outward() -> None:
+ spec = _build_cycle_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+ cycle_center = CanvasPosition(x=200.0, y=200.0)
+
+ for tensor in spec.tensors:
+ free_index = tensor.indices[0]
+ direction = renderer._index_direction(tensor, free_index)
+ radial = _normalize(
+ CanvasPosition(
+ x=tensor.position.x - cycle_center.x,
+ y=tensor.position.y - cycle_center.y,
+ )
+ )
+ assert _dot(direction, radial) > 0.85
+
+
+def test_export_geometry_points_grid_boundary_free_indices_outward() -> None:
+ spec = _build_grid_export_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+ expectations = {
+ "tensor_0_1": CanvasPosition(x=0.0, y=-1.0),
+ "tensor_1_0": CanvasPosition(x=-1.0, y=0.0),
+ "tensor_1_2": CanvasPosition(x=1.0, y=0.0),
+ "tensor_2_1": CanvasPosition(x=0.0, y=1.0),
+ }
+
+ for tensor_id, expected_direction in expectations.items():
+ tensor = next(tensor for tensor in spec.tensors if tensor.id == tensor_id)
+ free_index = tensor.indices[0]
+ direction = renderer._index_direction(tensor, free_index)
+ assert _dot(direction, expected_direction) > 0.85
+
+
+def test_export_geometry_points_rotated_grid_boundary_free_indices_outward() -> None:
+ spec = _build_rotated_grid_export_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+ expectations = {
+ "tensor_0_1": _normalize(CanvasPosition(x=1.0, y=-1.0)),
+ "tensor_1_0": _normalize(CanvasPosition(x=-1.0, y=-1.0)),
+ "tensor_1_2": _normalize(CanvasPosition(x=1.0, y=1.0)),
+ "tensor_2_1": _normalize(CanvasPosition(x=-1.0, y=1.0)),
+ }
+
+ for tensor_id, expected_direction in expectations.items():
+ tensor = next(tensor for tensor in spec.tensors if tensor.id == tensor_id)
+ free_index = tensor.indices[0]
+ direction = renderer._index_direction(tensor, free_index)
+ assert _dot(direction, expected_direction) > 0.85
+
+
+def test_export_geometry_generic_free_indices_point_away_from_local_neighbors() -> None:
+ spec = _build_generic_export_spec()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+ center_tensor = spec.tensors[0]
+ free_index = center_tensor.indices[0]
+
+ direction = renderer._index_direction(center_tensor, free_index)
+ away_from_neighbors = _normalize(CanvasPosition(x=-180.0, y=-120.0))
+
+ assert _dot(direction, away_from_neighbors) > 0.75
+
+
def test_academic_svg_tikz_and_dot_preserve_entity_colors_and_parallel_edges() -> None:
spec = _build_colored_parallel_edge_spec()
@@ -283,12 +740,20 @@ def test_academic_parallel_edges_curve_far_enough_to_separate_three_bonds() -> N
def test_academic_edges_reach_tensor_centers_in_svg_and_tikz() -> None:
spec = _assign_demo_index_offsets()
- edge_render_infos = _SvgRenderer(spec, SvgRenderOptions())._edge_render_infos()
+ renderer = _SvgRenderer(spec, SvgRenderOptions())
+ edge_render_infos = renderer._edge_render_infos()
+ bounds = renderer._compute_bounds(edge_render_infos)
tikz = render_spec_tikz(spec)
assert edge_render_infos[0].source == spec.tensors[0].position
assert edge_render_infos[0].target == spec.tensors[1].position
- assert "(150, 116) -- (390, 116)" in tikz
+ expected_segment = (
+ f"({_number(edge_render_infos[0].source.x - bounds.x1)}, "
+ f"{_number(bounds.y2 - edge_render_infos[0].source.y)}) -- "
+ f"({_number(edge_render_infos[0].target.x - bounds.x1)}, "
+ f"{_number(bounds.y2 - edge_render_infos[0].target.y)})"
+ )
+ assert expected_segment in tikz
def test_academic_svg_renderer_can_hide_tensor_index_and_bond_labels() -> None:
@@ -335,6 +800,20 @@ def test_render_spec_svg_writes_output_path(tmp_path: Path) -> None:
assert output_path.read_text(encoding="utf-8") == svg
+def test_render_spec_svg_omits_solid_background_when_transparent() -> None:
+ pytest.importorskip("matplotlib")
+
+ svg = render_spec_svg(
+ build_sample_spec(),
+ options=SvgRenderOptions(
+ background="#abcdef",
+ transparent_background=True,
+ ),
+ )
+
+ assert "#abcdef" not in svg
+
+
def test_render_spec_svg_reuses_edge_geometry_within_one_render(
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -361,6 +840,46 @@ def counting_edge_render_infos(self: Any) -> list[Any]:
assert edge_render_info_call_count == 1
+def test_render_spec_svg_reuses_component_axis_geometry_within_one_render(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ import tensor_network_editor.rendering as rendering_module
+
+ pytest.importorskip("matplotlib")
+ spec = build_template_spec(
+ "mps",
+ TemplateParameters(
+ graph_size=12,
+ bond_dimension=3,
+ physical_dimension=2,
+ boundary_condition="open",
+ initial_state="zeros",
+ ),
+ )
+ component_primary_axis_call_count = 0
+ original_component_primary_axis = (
+ rendering_module._SvgRenderer._component_primary_axis
+ )
+
+ def counting_component_primary_axis(
+ self: Any,
+ component_tensors: list[TensorSpec],
+ ) -> CanvasPosition:
+ nonlocal component_primary_axis_call_count
+ component_primary_axis_call_count += 1
+ return original_component_primary_axis(self, component_tensors)
+
+ monkeypatch.setattr(
+ rendering_module._SvgRenderer,
+ "_component_primary_axis",
+ counting_component_primary_axis,
+ )
+
+ render_spec_svg(spec)
+
+ assert component_primary_axis_call_count == 1
+
+
def test_render_spec_svg_keeps_labels_as_svg_text_elements() -> None:
pytest.importorskip("matplotlib")
@@ -549,6 +1068,110 @@ def reject_matplotlib_modules() -> tuple[object, object, object, object]:
rendering_module.render_spec_pdf(build_sample_spec())
+def test_load_matplotlib_modules_memoizes_imports(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ import tensor_network_editor.rendering as rendering_module
+
+ if hasattr(rendering_module._load_matplotlib_modules, "cache_clear"):
+ rendering_module._load_matplotlib_modules.cache_clear()
+ import_call_counts: dict[str, int] = {}
+ original_import_module = rendering_module.import_module
+
+ def counting_import_module(name: str) -> Any:
+ import_call_counts[name] = import_call_counts.get(name, 0) + 1
+ return original_import_module(name)
+
+ monkeypatch.setattr(
+ rendering_module,
+ "import_module",
+ counting_import_module,
+ )
+
+ first_modules = rendering_module._load_matplotlib_modules()
+ second_modules = rendering_module._load_matplotlib_modules()
+
+ assert second_modules == first_modules
+ assert import_call_counts == {
+ "matplotlib": 1,
+ "matplotlib.pyplot": 1,
+ "matplotlib.patches": 1,
+ "matplotlib.path": 1,
+ }
+
+
+def test_validate_positive_render_scale_normalizes_and_rejects_invalid_values() -> None:
+ import tensor_network_editor.rendering as rendering_module
+
+ assert rendering_module._validate_positive_render_scale(
+ 2,
+ description="PNG render scale",
+ ) == pytest.approx(2.0)
+ assert rendering_module._validate_positive_render_scale(
+ 1.5,
+ description="TikZ render scale",
+ ) == pytest.approx(1.5)
+
+ for invalid_scale in (True, 0, -1, float("inf"), float("nan"), "2"):
+ with pytest.raises(
+ ValueError,
+ match="PNG render scale must be a positive finite number.",
+ ):
+ rendering_module._validate_positive_render_scale(
+ invalid_scale,
+ description="PNG render scale",
+ )
+
+
+def test_render_spec_output_validates_renders_and_writes_output(
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ import tensor_network_editor.rendering as rendering_module
+
+ spec = build_sample_spec()
+ validated_spec = build_three_tensor_spec()
+ output_path = tmp_path / "network.svg"
+ calls: dict[str, Any] = {}
+
+ def fake_validate(received_spec: NetworkSpec) -> NetworkSpec:
+ calls["validate"] = received_spec
+ return validated_spec
+
+ def fake_render(
+ received_spec: NetworkSpec,
+ received_options: SvgRenderOptions,
+ ) -> str:
+ calls["render"] = (received_spec, received_options)
+ return ""
+
+ def fake_write(
+ path: Path,
+ content: str,
+ *,
+ description: str,
+ ) -> None:
+ calls["write"] = (path, content, description)
+
+ monkeypatch.setattr(rendering_module, "ensure_valid_spec", fake_validate)
+ options = SvgRenderOptions(show_tensor_labels=False)
+
+ rendered = rendering_module._render_spec_output(
+ spec,
+ format_name="svg",
+ options=options,
+ output_path=output_path,
+ description="SVG network rendering",
+ render=fake_render,
+ writer=fake_write,
+ )
+
+ assert rendered == ""
+ assert calls["validate"] is spec
+ assert calls["render"] == (validated_spec, options)
+ assert calls["write"] == (output_path, "", "SVG network rendering")
+
+
def test_render_spec_png_returns_png_bytes_and_writes_output_path(
tmp_path: Path,
) -> None:
@@ -563,6 +1186,19 @@ def test_render_spec_png_returns_png_bytes_and_writes_output_path(
assert output_path.read_bytes() == png_bytes
+def test_render_spec_png_uses_alpha_channel_when_transparent() -> None:
+ pytest.importorskip("matplotlib")
+ from tensor_network_editor.rendering import render_spec_png
+
+ png_bytes = render_spec_png(
+ build_sample_spec(),
+ options=SvgRenderOptions(transparent_background=True),
+ )
+
+ assert png_bytes[12:16] == b"IHDR"
+ assert png_bytes[25] == 6
+
+
def test_render_spec_pdf_returns_pdf_bytes_and_writes_output_path(
tmp_path: Path,
) -> None:
diff --git a/tests/test_scripts.py b/tests/test_scripts.py
index 703dc40..3848faa 100644
--- a/tests/test_scripts.py
+++ b/tests/test_scripts.py
@@ -68,6 +68,9 @@ def seed_generated_artifacts(root: Path) -> None:
root / ".coverage",
root / ".coverage.unit",
root / "coverage.xml",
+ root / "session.log",
+ root / "session.log.1",
+ root / "session.log.7",
]
for file_path in files_to_create:
file_path.write_text("temporary", encoding="utf-8")
@@ -91,6 +94,9 @@ def assert_cleanup_removed_artifacts(root: Path) -> None:
root / ".coverage",
root / ".coverage.unit",
root / "coverage.xml",
+ root / "session.log",
+ root / "session.log.1",
+ root / "session.log.7",
]
for path in removed_paths:
assert not path.exists()
diff --git a/tests/test_session.py b/tests/test_session.py
index 0a5ed5c..53dc03c 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -3,6 +3,7 @@
import logging
import signal
import threading
+from base64 import b64encode
from collections.abc import Iterator
from importlib import import_module
from pathlib import Path
@@ -15,6 +16,7 @@
from tensor_network_editor.app._protocol import JsonDict
from tensor_network_editor.app.session import (
EditorSession,
+ _PywebviewExportApi,
build_blank_network_spec,
wait_for_editor_result,
)
@@ -628,6 +630,643 @@ class FakeThread:
assert "http://127.0.0.1:43210" in captured
+def test_launch_editor_session_pywebview_requires_main_thread(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ class FakeThread:
+ name = "worker"
+
+ monkeypatch.setattr(
+ session_module.threading, "current_thread", lambda: FakeThread()
+ )
+
+ with pytest.raises(RuntimeError, match="pywebview mode must be launched"):
+ session_module.launch_editor_session(ui_mode="pywebview")
+
+
+def test_launch_editor_session_pywebview_missing_dependency_raises_clear_error(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ class FakeEditorServer:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ del args, kwargs
+ self.base_url = "http://127.0.0.1:43210"
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ return None
+
+ class FakeMainThread:
+ name = "MainThread"
+
+ main_thread = FakeMainThread()
+ monkeypatch.setattr(
+ "tensor_network_editor.app.server.EditorServer",
+ FakeEditorServer,
+ )
+ monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread)
+ monkeypatch.setattr(
+ session_module,
+ "_import_pywebview",
+ lambda: (_ for _ in ()).throw(ModuleNotFoundError("No module named 'webview'")),
+ )
+
+ with pytest.raises(RuntimeError, match="tensor-network-editor\\[desktop\\]"):
+ session_module.launch_editor_session(ui_mode="pywebview")
+
+
+def test_launch_editor_session_pywebview_closes_window_after_completion(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ completed_result = EditorResult(
+ spec=build_blank_network_spec(),
+ engine=EngineName.EINSUM_NUMPY,
+ confirmed=True,
+ )
+
+ class FakeEventHook:
+ def __init__(self) -> None:
+ self._callbacks: list[object] = []
+
+ def __iadd__(self, callback: object) -> FakeEventHook:
+ self._callbacks.append(callback)
+ return self
+
+ def fire(self) -> None:
+ for callback in list(self._callbacks):
+ cast(Any, callback)()
+
+ class FakeWindowEvents:
+ def __init__(self) -> None:
+ self.before_show = FakeEventHook()
+ self.closed = FakeEventHook()
+
+ class FakeWindow:
+ def __init__(self) -> None:
+ self.events = FakeWindowEvents()
+ self.destroy_calls = 0
+
+ def destroy(self) -> None:
+ self.destroy_calls += 1
+
+ class FakePywebview:
+ def __init__(self) -> None:
+ self.created_urls: list[str] = []
+ self.created_maximized: list[bool] = []
+ self.created_js_apis: list[object] = []
+ self.window = FakeWindow()
+ self.start_calls = 0
+
+ def create_window(
+ self,
+ title: str,
+ url: str,
+ *,
+ maximized: bool = False,
+ js_api: object | None = None,
+ ) -> FakeWindow:
+ assert title == "Tensor Network Editor"
+ self.created_urls.append(url)
+ self.created_maximized.append(maximized)
+ self.created_js_apis.append(js_api)
+ return self.window
+
+ def start(self, callback: object, window: FakeWindow) -> None:
+ self.start_calls += 1
+ self.window.events.before_show.fire()
+ cast(Any, callback)(window)
+
+ class FakeEditorServer:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ del args, kwargs
+ self.base_url = "http://127.0.0.1:43210"
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ return None
+
+ class FakeMainThread:
+ name = "MainThread"
+
+ main_thread = FakeMainThread()
+ pywebview = FakePywebview()
+ monkeypatch.setattr(
+ "tensor_network_editor.app.server.EditorServer",
+ FakeEditorServer,
+ )
+ monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module, "_import_pywebview", lambda: pywebview)
+ monkeypatch.setattr(
+ session_module,
+ "wait_for_editor_result",
+ lambda _session: completed_result,
+ )
+
+ result = session_module.launch_editor_session(ui_mode="pywebview")
+
+ assert result is completed_result
+ assert pywebview.created_urls == ["http://127.0.0.1:43210"]
+ assert pywebview.created_maximized == [True]
+ assert len(pywebview.created_js_apis) == 1
+ assert isinstance(pywebview.created_js_apis[0], _PywebviewExportApi)
+ assert pywebview.start_calls == 1
+ assert pywebview.window.destroy_calls == 1
+
+
+def test_launch_editor_session_pywebview_applies_native_icon_before_show(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ completed_result = EditorResult(
+ spec=build_blank_network_spec(),
+ engine=EngineName.EINSUM_NUMPY,
+ confirmed=True,
+ )
+
+ class FakeEventHook:
+ def __init__(self) -> None:
+ self._callbacks: list[object] = []
+
+ def __iadd__(self, callback: object) -> FakeEventHook:
+ self._callbacks.append(callback)
+ return self
+
+ def fire(self) -> None:
+ for callback in list(self._callbacks):
+ cast(Any, callback)()
+
+ class FakeWindowEvents:
+ def __init__(self) -> None:
+ self.before_show = FakeEventHook()
+ self.closed = FakeEventHook()
+
+ class FakeNativeWindow:
+ Icon = None
+
+ class FakeWindow:
+ def __init__(self) -> None:
+ self.events = FakeWindowEvents()
+ self.native = FakeNativeWindow()
+
+ def destroy(self) -> None:
+ return None
+
+ class FakePywebview:
+ def __init__(self) -> None:
+ self.window = FakeWindow()
+
+ def create_window(
+ self,
+ title: str,
+ url: str,
+ *,
+ maximized: bool = False,
+ js_api: object | None = None,
+ ) -> FakeWindow:
+ del title, url, maximized, js_api
+ return self.window
+
+ def start(self, callback: object, window: FakeWindow) -> None:
+ self.window.events.before_show.fire()
+ cast(Any, callback)(window)
+
+ class FakeEditorServer:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ del args, kwargs
+ self.base_url = "http://127.0.0.1:43210"
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ return None
+
+ class FakeMainThread:
+ name = "MainThread"
+
+ applied_windows: list[object] = []
+ main_thread = FakeMainThread()
+ monkeypatch.setattr(
+ "tensor_network_editor.app.server.EditorServer",
+ FakeEditorServer,
+ )
+ monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module, "_import_pywebview", lambda: FakePywebview())
+ monkeypatch.setattr(
+ session_module,
+ "wait_for_editor_result",
+ lambda _session: completed_result,
+ )
+ monkeypatch.setattr(
+ session_module,
+ "_apply_pywebview_native_window_icon",
+ lambda window: applied_windows.append(window),
+ )
+
+ result = session_module.launch_editor_session(ui_mode="pywebview")
+
+ assert result is completed_result
+ assert len(applied_windows) == 1
+ assert isinstance(applied_windows[0], FakeWindow)
+
+
+def test_launch_editor_session_pywebview_applies_native_icon_without_before_show(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ completed_result = EditorResult(
+ spec=build_blank_network_spec(),
+ engine=EngineName.EINSUM_NUMPY,
+ confirmed=True,
+ )
+
+ class FakeEventHook:
+ def __init__(self) -> None:
+ self._callbacks: list[object] = []
+
+ def __iadd__(self, callback: object) -> FakeEventHook:
+ self._callbacks.append(callback)
+ return self
+
+ def fire(self) -> None:
+ for callback in list(self._callbacks):
+ cast(Any, callback)()
+
+ class FakeWindowEvents:
+ def __init__(self) -> None:
+ self.closed = FakeEventHook()
+
+ class FakeWindow:
+ def __init__(self) -> None:
+ self.events = FakeWindowEvents()
+
+ def destroy(self) -> None:
+ return None
+
+ class FakePywebview:
+ def __init__(self) -> None:
+ self.window = FakeWindow()
+
+ def create_window(
+ self,
+ title: str,
+ url: str,
+ *,
+ maximized: bool = False,
+ js_api: object | None = None,
+ ) -> FakeWindow:
+ del title, url, maximized, js_api
+ return self.window
+
+ def start(self, callback: object, window: FakeWindow) -> None:
+ cast(Any, callback)(window)
+
+ class FakeEditorServer:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ del args, kwargs
+ self.base_url = "http://127.0.0.1:43210"
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ return None
+
+ class FakeMainThread:
+ name = "MainThread"
+
+ applied_windows: list[object] = []
+ main_thread = FakeMainThread()
+ monkeypatch.setattr(
+ "tensor_network_editor.app.server.EditorServer",
+ FakeEditorServer,
+ )
+ monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module, "_import_pywebview", lambda: FakePywebview())
+ monkeypatch.setattr(
+ session_module,
+ "wait_for_editor_result",
+ lambda _session: completed_result,
+ )
+ monkeypatch.setattr(
+ session_module,
+ "_apply_pywebview_native_window_icon",
+ lambda window: applied_windows.append(window),
+ )
+
+ result = session_module.launch_editor_session(ui_mode="pywebview")
+
+ assert result is completed_result
+ assert len(applied_windows) == 1
+ assert isinstance(applied_windows[0], FakeWindow)
+
+
+def test_launch_editor_session_pywebview_tolerates_missing_closed_event(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ completed_result = EditorResult(
+ spec=build_blank_network_spec(),
+ engine=EngineName.EINSUM_NUMPY,
+ confirmed=True,
+ )
+
+ class FakeEventHook:
+ def __init__(self) -> None:
+ self._callbacks: list[object] = []
+
+ def __iadd__(self, callback: object) -> FakeEventHook:
+ self._callbacks.append(callback)
+ return self
+
+ def fire(self) -> None:
+ for callback in list(self._callbacks):
+ cast(Any, callback)()
+
+ class FakeWindowEvents:
+ def __init__(self) -> None:
+ self.before_show = FakeEventHook()
+
+ class FakeWindow:
+ def __init__(self) -> None:
+ self.events = FakeWindowEvents()
+ self.destroy_calls = 0
+
+ def destroy(self) -> None:
+ self.destroy_calls += 1
+
+ class FakePywebview:
+ def __init__(self) -> None:
+ self.window = FakeWindow()
+
+ def create_window(
+ self,
+ title: str,
+ url: str,
+ *,
+ maximized: bool = False,
+ js_api: object | None = None,
+ ) -> FakeWindow:
+ del title, url, maximized, js_api
+ return self.window
+
+ def start(self, callback: object, window: FakeWindow) -> None:
+ self.window.events.before_show.fire()
+ cast(Any, callback)(window)
+
+ class FakeEditorServer:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ del args, kwargs
+ self.base_url = "http://127.0.0.1:43210"
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ return None
+
+ class FakeMainThread:
+ name = "MainThread"
+
+ applied_windows: list[object] = []
+ main_thread = FakeMainThread()
+ pywebview = FakePywebview()
+ monkeypatch.setattr(
+ "tensor_network_editor.app.server.EditorServer",
+ FakeEditorServer,
+ )
+ monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module, "_import_pywebview", lambda: pywebview)
+ monkeypatch.setattr(
+ session_module,
+ "wait_for_editor_result",
+ lambda _session: completed_result,
+ )
+ monkeypatch.setattr(
+ session_module,
+ "_apply_pywebview_native_window_icon",
+ lambda window: applied_windows.append(window),
+ )
+
+ result = session_module.launch_editor_session(ui_mode="pywebview")
+
+ assert result is completed_result
+ assert applied_windows == [pywebview.window]
+ assert pywebview.window.destroy_calls == 1
+
+
+def test_launch_editor_session_pywebview_window_close_cancels_session(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ from tensor_network_editor.app import session as session_module
+
+ class FakeEventHook:
+ def __init__(self) -> None:
+ self._callbacks: list[object] = []
+
+ def __iadd__(self, callback: object) -> FakeEventHook:
+ self._callbacks.append(callback)
+ return self
+
+ def fire(self) -> None:
+ for callback in list(self._callbacks):
+ cast(Any, callback)()
+
+ class FakeWindowEvents:
+ def __init__(self) -> None:
+ self.before_show = FakeEventHook()
+ self.closed = FakeEventHook()
+
+ class FakeWindow:
+ def __init__(self) -> None:
+ self.events = FakeWindowEvents()
+
+ def destroy(self) -> None:
+ return None
+
+ class FakePywebview:
+ def __init__(self) -> None:
+ self.window = FakeWindow()
+
+ def create_window(
+ self,
+ title: str,
+ url: str,
+ *,
+ maximized: bool = False,
+ js_api: object | None = None,
+ ) -> FakeWindow:
+ del title, url, maximized, js_api
+ return self.window
+
+ def start(self, callback: object, window: FakeWindow) -> None:
+ del callback, window
+ self.window.events.closed.fire()
+
+ class FakeEditorServer:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ del args, kwargs
+ self.base_url = "http://127.0.0.1:43210"
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ return None
+
+ class FakeMainThread:
+ name = "MainThread"
+
+ main_thread = FakeMainThread()
+ monkeypatch.setattr(
+ "tensor_network_editor.app.server.EditorServer",
+ FakeEditorServer,
+ )
+ monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread)
+ monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread)
+ monkeypatch.setattr(
+ session_module,
+ "_import_pywebview",
+ lambda: FakePywebview(),
+ )
+
+ result = session_module.launch_editor_session(ui_mode="pywebview")
+
+ assert result is None
+
+
+def test_pywebview_export_api_writes_text_file_to_selected_path(
+ tmp_path: Path,
+) -> None:
+ output_path = tmp_path / "demo.json"
+
+ class FakePywebview:
+ SAVE_DIALOG = object()
+
+ class FakeWindow:
+ def __init__(self) -> None:
+ self.dialog_calls: list[dict[str, object]] = []
+
+ def create_file_dialog(
+ self,
+ dialog_type: object,
+ *,
+ save_filename: str,
+ file_types: tuple[str, ...],
+ ) -> tuple[str]:
+ self.dialog_calls.append(
+ {
+ "dialog_type": dialog_type,
+ "save_filename": save_filename,
+ "file_types": file_types,
+ }
+ )
+ return (str(output_path),)
+
+ api = _PywebviewExportApi(FakePywebview())
+ window = FakeWindow()
+ api.bind_window(window)
+
+ saved = api.save_text_file(
+ "demo.json",
+ '{\n "ok": true\n}\n',
+ "application/json;charset=utf-8",
+ )
+
+ assert saved is True
+ assert output_path.read_text(encoding="utf-8") == '{\n "ok": true\n}\n'
+ assert window.dialog_calls == [
+ {
+ "dialog_type": FakePywebview.SAVE_DIALOG,
+ "save_filename": "demo.json",
+ "file_types": ("JSON (*.json)",),
+ }
+ ]
+
+
+def test_pywebview_export_api_returns_false_when_save_dialog_is_cancelled(
+ tmp_path: Path,
+) -> None:
+ output_path = tmp_path / "demo.json"
+
+ class FakePywebview:
+ SAVE_DIALOG = object()
+
+ class FakeWindow:
+ def create_file_dialog(
+ self,
+ dialog_type: object,
+ *,
+ save_filename: str,
+ file_types: tuple[str, ...],
+ ) -> tuple[str, ...]:
+ del dialog_type, save_filename, file_types
+ return ()
+
+ api = _PywebviewExportApi(FakePywebview())
+ api.bind_window(FakeWindow())
+
+ saved = api.save_text_file(
+ "demo.json",
+ '{"ok": true}',
+ "application/json;charset=utf-8",
+ )
+
+ assert saved is False
+ assert output_path.exists() is False
+
+
+def test_pywebview_export_api_writes_binary_file_to_selected_path(
+ tmp_path: Path,
+) -> None:
+ output_path = tmp_path / "demo.pdf"
+ binary_payload = b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n"
+
+ class FakePywebview:
+ SAVE_DIALOG = object()
+
+ class FakeWindow:
+ def create_file_dialog(
+ self,
+ dialog_type: object,
+ *,
+ save_filename: str,
+ file_types: tuple[str, ...],
+ ) -> tuple[str]:
+ del dialog_type, save_filename, file_types
+ return (str(output_path),)
+
+ api = _PywebviewExportApi(FakePywebview())
+ api.bind_window(FakeWindow())
+
+ saved = api.save_binary_file(
+ "demo.pdf",
+ b64encode(binary_payload).decode("ascii"),
+ "application/pdf",
+ )
+
+ assert saved is True
+ assert output_path.read_bytes() == binary_payload
+
+
def test_open_editor_passes_template_catalog_path(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
diff --git a/tests/test_template_catalog_internal.py b/tests/test_template_catalog_internal.py
index 1967cb3..e1e31b5 100644
--- a/tests/test_template_catalog_internal.py
+++ b/tests/test_template_catalog_internal.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import importlib
+
import pytest
from tensor_network_editor.internal.models._model_tensor_data import TensorDataMode
@@ -8,6 +10,8 @@
build_template,
)
from tensor_network_editor.internal.templates._template_catalog import (
+ _reset_template_registry_for_tests,
+ get_template_builder,
get_template_definition,
list_template_names,
serialize_template_definitions,
@@ -108,6 +112,66 @@ def test_template_builders_internal_dispatches_to_specific_builder() -> None:
assert len(spec.tensors) == 5
+def test_template_builder_facade_reexports_family_modules() -> None:
+ try:
+ linear_module = importlib.import_module(
+ "tensor_network_editor.internal.templates._template_builders_linear"
+ )
+ grid_module = importlib.import_module(
+ "tensor_network_editor.internal.templates._template_builders_grid"
+ )
+ tree_module = importlib.import_module(
+ "tensor_network_editor.internal.templates._template_builders_tree"
+ )
+ except ModuleNotFoundError as exc:
+ pytest.fail(f"Expected split template-builder modules to exist: {exc}")
+
+ _reset_template_registry_for_tests()
+
+ assert _build_linear_chain_template is linear_module._build_linear_chain_template
+ assert get_template_builder("mps").__module__ == linear_module.__name__
+ assert get_template_builder("peps_2x2").__module__ == grid_module.__name__
+ assert get_template_builder("mera").__module__ == tree_module.__name__
+
+
+def test_template_builder_common_module_exposes_shared_primitives() -> None:
+ try:
+ common_module = importlib.import_module(
+ "tensor_network_editor.internal.templates._template_builders_common"
+ )
+ except ModuleNotFoundError as exc:
+ pytest.fail(f"Expected shared template-builder primitives module: {exc}")
+
+ left_tensor = common_module._make_tensor(
+ "tensor_left",
+ "Left",
+ 10.0,
+ 20.0,
+ [("right", 3, (58.0, 0.0))],
+ )
+ right_tensor = common_module._make_tensor(
+ "tensor_right",
+ "Right",
+ 40.0,
+ 20.0,
+ [("left", 3, (-58.0, 0.0))],
+ )
+ edge = common_module._make_edge(
+ "edge_left_right",
+ left_tensor,
+ "right",
+ right_tensor,
+ "left",
+ )
+
+ assert left_tensor.indices[0].id == "tensor_left_right"
+ assert right_tensor.indices[0].id == "tensor_right_left"
+ assert edge.left.tensor_id == "tensor_left"
+ assert edge.left.index_id == "tensor_left_right"
+ assert edge.right.tensor_id == "tensor_right"
+ assert edge.right.index_id == "tensor_right_left"
+
+
def test_linear_chain_template_helper_reuses_catalog_metadata() -> None:
spec = _build_linear_chain_template(
"mpo",