|
1 | | -from pydantic import ConfigDict, Field, model_validator |
2 | | -from typing import Annotated, List, Literal, Union |
| 1 | +from pydantic import model_validator |
| 2 | +from typing import List, Literal |
3 | 3 |
|
4 | | -from .BaseElement import BaseElement |
5 | | -from .ThickElement import ThickElement |
6 | | - |
7 | | -from .ACKicker import ACKicker |
8 | | -from .BeamBeam import BeamBeam |
9 | | -from .BeginningEle import BeginningEle |
10 | | -from .Converter import Converter |
11 | | -from .CrabCavity import CrabCavity |
12 | | -from .Drift import Drift |
13 | | -from .EGun import EGun |
14 | | -from .Feedback import Feedback |
15 | | -from .Fiducial import Fiducial |
16 | | -from .FloorShift import FloorShift |
17 | | -from .Foil import Foil |
18 | | -from .Fork import Fork |
19 | | -from .Girder import Girder |
20 | | -from .Instrument import Instrument |
21 | | -from .Kicker import Kicker |
22 | | -from .Marker import Marker |
23 | | -from .Mask import Mask |
24 | | -from .Match import Match |
25 | | -from .Multipole import Multipole |
26 | | -from .NullEle import NullEle |
27 | | -from .Octupole import Octupole |
28 | | -from .Patch import Patch |
29 | | -from .Quadrupole import Quadrupole |
30 | | -from .RBend import RBend |
31 | | -from .RFCavity import RFCavity |
32 | | -from .SBend import SBend |
33 | | -from .Sextupole import Sextupole |
34 | | -from .Solenoid import Solenoid |
35 | | -from .Taylor import Taylor |
36 | | -from .UnionEle import UnionEle |
37 | | -from .Wiggler import Wiggler |
| 4 | +from .all_elements import get_all_elements_as_annotation |
| 5 | +from .mixin import BaseElement |
38 | 6 |
|
39 | 7 |
|
40 | 8 | class BeamLine(BaseElement): |
41 | 9 | """A line of elements and/or other lines""" |
42 | 10 |
|
43 | | - # Validate every time a new value is assigned to an attribute, |
44 | | - # not only when an instance of BeamLine is created |
45 | | - model_config = ConfigDict(validate_assignment=True) |
46 | | - |
47 | 11 | kind: Literal["BeamLine"] = "BeamLine" |
48 | 12 |
|
49 | | - line: List[ |
50 | | - Annotated[ |
51 | | - Union[ |
52 | | - # Base classes (for testing compatibility) |
53 | | - BaseElement, |
54 | | - ThickElement, |
55 | | - # User-Facing element kinds |
56 | | - "BeamLine", |
57 | | - ACKicker, |
58 | | - BeamBeam, |
59 | | - BeginningEle, |
60 | | - Converter, |
61 | | - CrabCavity, |
62 | | - Drift, |
63 | | - EGun, |
64 | | - Feedback, |
65 | | - Fiducial, |
66 | | - FloorShift, |
67 | | - Foil, |
68 | | - Fork, |
69 | | - Girder, |
70 | | - Instrument, |
71 | | - Kicker, |
72 | | - Marker, |
73 | | - Mask, |
74 | | - Match, |
75 | | - Multipole, |
76 | | - NullEle, |
77 | | - Octupole, |
78 | | - Patch, |
79 | | - Quadrupole, |
80 | | - RBend, |
81 | | - RFCavity, |
82 | | - SBend, |
83 | | - Sextupole, |
84 | | - Solenoid, |
85 | | - Taylor, |
86 | | - UnionEle, |
87 | | - Wiggler, |
88 | | - ], |
89 | | - Field(discriminator="kind"), |
90 | | - ] |
91 | | - ] |
| 13 | + line: List[get_all_elements_as_annotation()] |
92 | 14 |
|
93 | 15 | @model_validator(mode="before") |
94 | 16 | @classmethod |
95 | | - def unpack_yaml_structure(cls, data): |
96 | | - # Handle the top-level one-key dict: unpack the line's name |
97 | | - if isinstance(data, dict) and len(data) == 1: |
98 | | - name, value = list(data.items())[0] |
99 | | - if not isinstance(value, dict): |
100 | | - raise TypeError( |
101 | | - f"Value for line key {name!r} must be a dict, but we got {value!r}" |
102 | | - ) |
103 | | - value["name"] = name |
104 | | - data = value |
105 | | - # Handle the 'line' field: unpack each element's name |
106 | | - if "line" not in data: |
107 | | - raise ValueError("'line' field is missing") |
108 | | - if not isinstance(data["line"], list): |
109 | | - raise TypeError("'line' must be a list") |
110 | | - new_line = [] |
111 | | - # Loop over all elements in the line |
112 | | - for item in data["line"]: |
113 | | - # An element can be a string that refers to another element |
114 | | - if isinstance(item, str): |
115 | | - raise RuntimeError("Reference/alias elements not yet implemented") |
116 | | - # An element can be a dict |
117 | | - elif isinstance(item, dict): |
118 | | - if not (len(item) == 1): |
119 | | - raise ValueError( |
120 | | - f"Each element must be a dict with exactly one key (the element's name), but we got {item!r}" |
121 | | - ) |
122 | | - name, fields = list(item.items())[0] |
123 | | - if not isinstance(fields, dict): |
124 | | - raise TypeError( |
125 | | - f"Value for element key {name!r} must be a dict (the element's properties), but we got {fields!r}" |
126 | | - ) |
127 | | - fields["name"] = name |
128 | | - new_line.append(fields) |
129 | | - # An element can be an instance of an existing model |
130 | | - elif isinstance(item, BaseElement): |
131 | | - # Nothing to do, keep the element as is |
132 | | - new_line.append(item) |
133 | | - else: |
134 | | - raise TypeError( |
135 | | - f"Value for element key {name!r} must be a reference string or a dict, but we got {item!r}" |
136 | | - ) |
137 | | - data["line"] = new_line |
138 | | - return data |
| 17 | + def unpack_json_structure(cls, data): |
| 18 | + """Deserialize the JSON/YAML/...-like dict for BeamLine elements""" |
| 19 | + from pals.kinds.mixin.all_element_mixin import unpack_element_list_structure |
139 | 20 |
|
140 | | - def model_dump(self, *args, **kwargs): |
141 | | - """This makes sure the element name property is moved out and up to a one-key dictionary""" |
142 | | - # Use base element dump first and return a dict {key: value}, where 'key' |
143 | | - # is the name of the line and 'value' is a dict with all other properties |
144 | | - data = super().model_dump(*args, **kwargs) |
145 | | - # Reformat 'line' field as list of element dicts |
146 | | - new_line = [] |
147 | | - for elem in self.line: |
148 | | - # Use custom dump for each line element, which now returns a dict |
149 | | - elem_dict = elem.model_dump(**kwargs) |
150 | | - new_line.append(elem_dict) |
151 | | - data[self.name]["line"] = new_line |
152 | | - return data |
| 21 | + return unpack_element_list_structure(data, "line", "line") |
153 | 22 |
|
| 23 | + def model_dump(self, *args, **kwargs): |
| 24 | + """Custom model dump for BeamLine to handle element list formatting""" |
| 25 | + from pals.kinds.mixin.all_element_mixin import dump_element_list |
154 | 26 |
|
155 | | -# Avoid circular import issues |
156 | | -BeamLine.model_rebuild() |
| 27 | + return dump_element_list(self, "line", *args, **kwargs) |
0 commit comments