Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/fodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from schema.DriftElement import DriftElement
from schema.QuadrupoleElement import QuadrupoleElement

from schema.Line import Line
from schema.BeamLine import BeamLine


def main():
Expand Down Expand Up @@ -42,7 +42,7 @@ def main():
length=0.5,
)
# Create line with all elements
line = Line(
line = BeamLine(
line=[
drift1,
quad1,
Expand All @@ -63,7 +63,7 @@ def main():
with open(yaml_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse YAML data
loaded_line = Line(**yaml_data)
loaded_line = BeamLine(**yaml_data)
# Validate loaded data
assert line == loaded_line
# Serialize to JSON
Expand All @@ -78,7 +78,7 @@ def main():
with open(json_file, "r") as file:
json_data = json.loads(file.read())
# Parse JSON data
loaded_line = Line(**json_data)
loaded_line = BeamLine(**json_data)
# Validate loaded data
assert line == loaded_line

Expand Down
10 changes: 5 additions & 5 deletions schema/Line.py → schema/BeamLine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from schema.QuadrupoleElement import QuadrupoleElement


class Line(BaseModel):
class BeamLine(BaseModel):
"""A line of elements and/or other lines"""

# Validate every time a new value is assigned to an attribute,
# not only when an instance of Line is created
# not only when an instance of BeamLine is created
model_config = ConfigDict(validate_assignment=True)

kind: Literal["Line"] = "Line"
kind: Literal["BeamLine"] = "BeamLine"

line: List[
Annotated[
Expand All @@ -23,7 +23,7 @@ class Line(BaseModel):
ThickElement,
DriftElement,
QuadrupoleElement,
"Line",
"BeamLine",
],
Field(discriminator="kind"),
]
Expand Down Expand Up @@ -82,4 +82,4 @@ def model_dump(self, *args, **kwargs):


# Avoid circular import issues
Line.model_rebuild()
BeamLine.model_rebuild()
30 changes: 15 additions & 15 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from schema.DriftElement import DriftElement
from schema.QuadrupoleElement import QuadrupoleElement

from schema.Line import Line
from schema.BeamLine import BeamLine


def test_BaseElement():
Expand Down Expand Up @@ -96,23 +96,23 @@ def test_QuadrupoleElement():
assert element.MagneticMultipoleP.Bn2 == element_magnetic_multipole_Bn2
assert element.MagneticMultipoleP.Bs2 == element_magnetic_multipole_Bs2
assert element.MagneticMultipoleP.tilt2 == element_magnetic_multipole_tilt2
# Serialize the Line object to YAML
# Serialize the BeamLine object to YAML
yaml_data = yaml.dump(element.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")


def test_Line():
def test_BeamLine():
# Create first line with one base element
element1 = BaseElement(name="element1")
line1 = Line(line=[element1])
line1 = BeamLine(line=[element1])
assert line1.line == [element1]
# Extend first line with one thick element
element2 = ThickElement(name="element2", length=2.0)
line1.line.extend([element2])
assert line1.line == [element1, element2]
# Create second line with one drift element
element3 = DriftElement(name="element3", length=3.0)
line2 = Line(line=[element3])
line2 = BeamLine(line=[element3])
# Extend first line with second line
line1.line.extend(line2.line)
assert line1.line == [element1, element2, element3]
Expand All @@ -124,8 +124,8 @@ def test_yaml():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
# Serialize the Line object to YAML
line = BeamLine(line=[element1, element2])
# Serialize the BeamLine object to YAML
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
print(f"\n{yaml_data}")
# Write the YAML data to a test file
Expand All @@ -135,11 +135,11 @@ def test_yaml():
# Read the YAML data from the test file
with open(test_file, "r") as file:
yaml_data = yaml.safe_load(file)
# Parse the YAML data back into a Line object
loaded_line = Line(**yaml_data)
# Parse the YAML data back into a BeamLine object
loaded_line = BeamLine(**yaml_data)
# Remove the test file
os.remove(test_file)
# Validate loaded Line object
# Validate loaded BeamLine object
assert line == loaded_line


Expand All @@ -149,8 +149,8 @@ def test_json():
# Create one thick element
element2 = ThickElement(name="element2", length=2.0)
# Create line with both elements
line = Line(line=[element1, element2])
# Serialize the Line object to JSON
line = BeamLine(line=[element1, element2])
# Serialize the BeamLine object to JSON
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
print(f"\n{json_data}")
# Write the JSON data to a test file
Expand All @@ -160,9 +160,9 @@ def test_json():
# Read the JSON data from the test file
with open(test_file, "r") as file:
json_data = json.loads(file.read())
# Parse the JSON data back into a Line object
loaded_line = Line(**json_data)
# Parse the JSON data back into a BeamLine object
loaded_line = BeamLine(**json_data)
# Remove the test file
os.remove(test_file)
# Validate loaded Line object
# Validate loaded BeamLine object
assert line == loaded_line