Skip to content

Commit 0bc4878

Browse files
authored
Fix BeamLine name attribute (pals-project#16)
### Overview With the implementation in the main branch, if we add a `name` attribute to the `line` object created in the FODO example, the attribute is ignored and the output is ```yaml kind: Line line: - drift1: kind: Drift length: 0.25 - quad1: MagneticMultipoleP: Bn1: 1.0 kind: Quadrupole length: 1.0 - drift2: kind: Drift length: 0.5 - quad2: MagneticMultipoleP: Bn1: -1.0 kind: Quadrupole length: 1.0 - drift3: kind: Drift length: 0.5 ``` This is not consistent with the example in https://github.com/campa-consortium/pals/blob/main/examples/fodo.pals.yaml. With the implementation in this branch, if we add a `name` attribute to the `line` object created in the FODO example, the attribute is stored and the output is ```yaml fodo_cell: kind: BeamLine line: - drift1: kind: Drift length: 0.25 - quad1: MagneticMultipoleP: Bn1: 1.0 kind: Quadrupole length: 1.0 - drift2: kind: Drift length: 0.5 - quad2: MagneticMultipoleP: Bn1: -1.0 kind: Quadrupole length: 1.0 - drift3: kind: Drift length: 0.5 ``` which is consistent with the example in https://github.com/campa-consortium/pals/blob/main/examples/fodo.pals.yaml. This is achieved primarily by deriving the `Line` (`BeamLine` after pals-project#15) class from the base element class `BaseElement` instead of deriving it from Pydantic's base model class `BaseModel`. I also merged the custom field validation of the `line` field with the custom model validation, since we needed one anyways to fix the deserialization. I don't see why they should be separate. I wonder if the right way to do this is using `@model_serializer` instead of `@model_validator`... ### To do - [x] Serialization (fixed with [c391e43](pals-project@c391e43), very few lines changed) - [x] Deserialization (fixed with the custom model validator, many more lines changed) - [x] Cleaning
1 parent f7b9dbb commit 0bc4878

File tree

4 files changed

+58
-44
lines changed

4 files changed

+58
-44
lines changed

examples/fodo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@ def main():
4141
)
4242
# Create line with all elements
4343
line = BeamLine(
44+
name="fodo_cell",
4445
line=[
4546
drift1,
4647
quad1,
4748
drift2,
4849
quad2,
4950
drift3,
50-
]
51+
],
5152
)
5253
# Serialize to YAML
5354
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)

src/pals_schema/BaseElement.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,8 @@ def model_dump(self, *args, **kwargs):
2121
name = elem_dict.pop("name", None)
2222
if name is None:
2323
raise ValueError("Element missing 'name' attribute")
24-
data = [{name: elem_dict}]
24+
# Return a dict {name: properties} rather than a single-item list
25+
# This makes the serialized form a plain dict so it can be passed to
26+
# constructors using keyword expansion (e.g., Model(**data))
27+
data = {name: elem_dict}
2528
return data

src/pals_schema/BeamLine.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import BaseModel, ConfigDict, Field, field_validator
1+
from pydantic import ConfigDict, Field, model_validator
22
from typing import Annotated, List, Literal, Union
33

44
from pals_schema.BaseElement import BaseElement
@@ -7,7 +7,7 @@
77
from pals_schema.QuadrupoleElement import QuadrupoleElement
88

99

10-
class BeamLine(BaseModel):
10+
class BeamLine(BaseElement):
1111
"""A line of elements and/or other lines"""
1212

1313
# Validate every time a new value is assigned to an attribute,
@@ -29,55 +29,65 @@ class BeamLine(BaseModel):
2929
]
3030
]
3131

32-
@field_validator("line", mode="before")
32+
@model_validator(mode="before")
3333
@classmethod
34-
def parse_list_of_dicts(cls, value):
35-
"""This method inserts the key of the one-key dictionary into
36-
the name attribute of the elements"""
37-
if not isinstance(value, list):
38-
raise TypeError("line must be a list")
39-
40-
if value and isinstance(value[0], BaseModel):
41-
# Already a list of models; nothing to do
42-
return value
43-
44-
# we expect a list of dicts or strings
45-
elements = []
46-
for item_dict in value:
47-
# an element is either a reference string to another element or a dict
48-
if isinstance(item_dict, str):
34+
def unpack_yaml_structure(cls, data):
35+
# Handle the top-level one-key dict: unpack the line's name
36+
if isinstance(data, dict) and len(data) == 1:
37+
name, value = list(data.items())[0]
38+
if not isinstance(value, dict):
39+
raise TypeError(
40+
f"Value for line key {name!r} must be a dict, but we got {value!r}"
41+
)
42+
value["name"] = name
43+
data = value
44+
# Handle the 'line' field: unpack each element's name
45+
if "line" not in data:
46+
raise ValueError("'line' field is missing")
47+
if not isinstance(data["line"], list):
48+
raise TypeError("'line' must be a list")
49+
new_line = []
50+
# Loop over all elements in the line
51+
for item in data["line"]:
52+
# An element can be a string that refers to another element
53+
if isinstance(item, str):
4954
raise RuntimeError("Reference/alias elements not yet implemented")
50-
51-
elif isinstance(item_dict, dict):
52-
if not (isinstance(item_dict, dict) and len(item_dict) == 1):
55+
# An element can be a dict
56+
elif isinstance(item, dict):
57+
if not (len(item) == 1):
5358
raise ValueError(
54-
f"Each line element must be a dict with exactly one key, the name of the element, but we got: {item_dict!r}"
59+
f"Each element must be a dict with exactly one key (the element's name), but we got {item!r}"
5560
)
56-
[(name, fields)] = item_dict.items()
57-
61+
name, fields = list(item.items())[0]
5862
if not isinstance(fields, dict):
59-
raise ValueError(
60-
f"Value for element key '{name}' must be a dict (got {fields!r})"
63+
raise TypeError(
64+
f"Value for element key {name!r} must be a dict (the element's properties), but we got {fields!r}"
6165
)
62-
63-
# Insert the name into the fields dict
64-
fields["name"] = name
65-
elements.append(fields)
66-
return elements
66+
fields["name"] = name
67+
new_line.append(fields)
68+
# An element can be an instance of an existing model
69+
elif isinstance(item, BaseElement):
70+
# Nothing to do, keep the element as is
71+
new_line.append(item)
72+
else:
73+
raise TypeError(
74+
f"Value for element key {name!r} must be a reference string or a dict, but we got {item!r}"
75+
)
76+
data["line"] = new_line
77+
return data
6778

6879
def model_dump(self, *args, **kwargs):
6980
"""This makes sure the element name property is moved out and up to a one-key dictionary"""
70-
# Use default dump for non-line fields
81+
# Use base element dump first and return a dict {key: value}, where 'key'
82+
# is the name of the line and 'value' is a dict with all other properties
7183
data = super().model_dump(*args, **kwargs)
72-
73-
# Reformat 'line' field as list of single-key dicts
84+
# Reformat 'line' field as list of element dicts
7485
new_line = []
7586
for elem in self.line:
76-
# Use custom dump for each line element
77-
elem_dict = elem.model_dump(**kwargs)[0]
87+
# Use custom dump for each line element, which now returns a dict
88+
elem_dict = elem.model_dump(**kwargs)
7889
new_line.append(elem_dict)
79-
80-
data["line"] = new_line
90+
data[self.name]["line"] = new_line
8191
return data
8292

8393

tests/test_schema.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,15 @@ def test_QuadrupoleElement():
106106
def test_BeamLine():
107107
# Create first line with one base element
108108
element1 = BaseElement(name="element1")
109-
line1 = BeamLine(line=[element1])
109+
line1 = BeamLine(name="line1", line=[element1])
110110
assert line1.line == [element1]
111111
# Extend first line with one thick element
112112
element2 = ThickElement(name="element2", length=2.0)
113113
line1.line.extend([element2])
114114
assert line1.line == [element1, element2]
115115
# Create second line with one drift element
116116
element3 = DriftElement(name="element3", length=3.0)
117-
line2 = BeamLine(line=[element3])
117+
line2 = BeamLine(name="line2", line=[element3])
118118
# Extend first line with second line
119119
line1.line.extend(line2.line)
120120
assert line1.line == [element1, element2, element3]
@@ -126,7 +126,7 @@ def test_yaml():
126126
# Create one thick element
127127
element2 = ThickElement(name="element2", length=2.0)
128128
# Create line with both elements
129-
line = BeamLine(line=[element1, element2])
129+
line = BeamLine(name="line", line=[element1, element2])
130130
# Serialize the BeamLine object to YAML
131131
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
132132
print(f"\n{yaml_data}")
@@ -151,7 +151,7 @@ def test_json():
151151
# Create one thick element
152152
element2 = ThickElement(name="element2", length=2.0)
153153
# Create line with both elements
154-
line = BeamLine(line=[element1, element2])
154+
line = BeamLine(name="line", line=[element1, element2])
155155
# Serialize the BeamLine object to JSON
156156
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
157157
print(f"\n{json_data}")

0 commit comments

Comments
 (0)