Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/fodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pals_schema.MagneticMultipoleParameters import MagneticMultipoleParameters
from pals_schema.DriftElement import DriftElement
from pals_schema.QuadrupoleElement import QuadrupoleElement
from pals_schema.Line import Line
from pals_schema.BeamLine import BeamLine


def main():
Expand Down Expand Up @@ -40,7 +40,7 @@ def main():
length=0.5,
)
# Create line with all elements
line = Line(
line = BeamLine(
line=[
drift1,
quad1,
Expand All @@ -61,7 +61,7 @@ def main():
with open(yaml_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse YAML data
loaded_line = Line(**yaml_data)
loaded_line = BeamLine(**yaml_data)
# Validate loaded data
assert line == loaded_line
# Serialize to JSON
Expand All @@ -76,7 +76,7 @@ def main():
with open(json_file, "r") as file:
json_data = json.loads(file.read())
# Parse JSON data
loaded_line = Line(**json_data)
loaded_line = BeamLine(**json_data)
# Validate loaded data
assert line == loaded_line

Expand Down
10 changes: 5 additions & 5 deletions src/pals_schema/Line.py → src/pals_schema/BeamLine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from pals_schema.QuadrupoleElement import QuadrupoleElement


class Line(BaseModel):
class BeamLine(BaseModel):
"""A line of elements and/or other lines"""

# Validate every time a new value is assigned to an attribute,
# not only when an instance of Line is created
# not only when an instance of BeamLine is created
model_config = ConfigDict(validate_assignment=True)

kind: Literal["Line"] = "Line"
kind: Literal["BeamLine"] = "BeamLine"

line: List[
Annotated[
Expand All @@ -23,7 +23,7 @@ class Line(BaseModel):
ThickElement,
DriftElement,
QuadrupoleElement,
"Line",
"BeamLine",
],
Field(discriminator="kind"),
]
Expand Down Expand Up @@ -82,4 +82,4 @@ def model_dump(self, *args, **kwargs):


# Avoid circular import issues
Line.model_rebuild()
BeamLine.model_rebuild()
30 changes: 15 additions & 15 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pals_schema.ThickElement import ThickElement
from pals_schema.DriftElement import DriftElement
from pals_schema.QuadrupoleElement import QuadrupoleElement
from pals_schema.Line import Line
from pals_schema.BeamLine import BeamLine


def test_BaseElement():
Expand Down Expand Up @@ -98,23 +98,23 @@ def test_QuadrupoleElement():
assert element.MagneticMultipoleP.Bn2 == element_magnetic_multipole_Bn2
assert element.MagneticMultipoleP.Bs2 == element_magnetic_multipole_Bs2
assert element.MagneticMultipoleP.tilt2 == element_magnetic_multipole_tilt2
# Serialize the Line object to YAML
# Serialize the BeamLine object to YAML
yaml_data = yaml.dump(element.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")


def test_Line():
def test_BeamLine():
# Create first line with one base element
element1 = BaseElement(name="element1")
line1 = Line(line=[element1])
line1 = BeamLine(line=[element1])
assert line1.line == [element1]
# Extend first line with one thick element
element2 = ThickElement(name="element2", length=2.0)
line1.line.extend([element2])
assert line1.line == [element1, element2]
# Create second line with one drift element
element3 = DriftElement(name="element3", length=3.0)
line2 = Line(line=[element3])
line2 = BeamLine(line=[element3])
# Extend first line with second line
line1.line.extend(line2.line)
assert line1.line == [element1, element2, element3]
Expand All @@ -126,8 +126,8 @@ def test_yaml():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
# Serialize the Line object to YAML
line = BeamLine(line=[element1, element2])
# Serialize the BeamLine object to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")
# Write the YAML data to a test file
Expand All @@ -137,11 +137,11 @@ def test_yaml():
# Read the YAML data from the test file
with open(test_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse the YAML data back into a Line object
loaded_line = Line(**yaml_data)
# Parse the YAML data back into a BeamLine object
loaded_line = BeamLine(**yaml_data)
# Remove the test file
os.remove(test_file)
# Validate loaded Line object
# Validate loaded BeamLine object
assert line == loaded_line


Expand All @@ -151,8 +151,8 @@ def test_json():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
# Serialize the Line object to JSON
line = BeamLine(line=[element1, element2])
# Serialize the BeamLine object to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
print(f"\n{json_data}")
# Write the JSON data to a test file
Expand All @@ -162,9 +162,9 @@ def test_json():
# Read the JSON data from the test file
with open(test_file, "r") as file:
json_data = json.loads(file.read())
# Parse the JSON data back into a Line object
loaded_line = Line(**json_data)
# Parse the JSON data back into a BeamLine object
loaded_line = BeamLine(**json_data)
# Remove the test file
os.remove(test_file)
# Validate loaded Line object
# Validate loaded BeamLine object
assert line == loaded_line