diff --git a/examples/fodo.py b/examples/fodo.py index 3ffdbd2..189dfab 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -9,7 +9,7 @@ from pals_schema.MagneticMultipoleParameters import MagneticMultipoleParameters from pals_schema.DriftElement import DriftElement from pals_schema.QuadrupoleElement import QuadrupoleElement -from pals_schema.Line import Line +from pals_schema.BeamLine import BeamLine def main(): @@ -40,7 +40,7 @@ def main(): length=0.5, ) # Create line with all elements - line = Line( + line = BeamLine( line=[ drift1, quad1, @@ -61,7 +61,7 @@ def main(): with open(yaml_file, "r") as file: yaml_data = yaml.safe_load(file) # Parse YAML data - loaded_line = Line(**yaml_data) + loaded_line = BeamLine(**yaml_data) # Validate loaded data assert line == loaded_line # Serialize to JSON @@ -76,7 +76,7 @@ def main(): with open(json_file, "r") as file: json_data = json.loads(file.read()) # Parse JSON data - loaded_line = Line(**json_data) + loaded_line = BeamLine(**json_data) # Validate loaded data assert line == loaded_line diff --git a/src/pals_schema/Line.py b/src/pals_schema/BeamLine.py similarity index 94% rename from src/pals_schema/Line.py rename to src/pals_schema/BeamLine.py index da07457..e32b125 100644 --- a/src/pals_schema/Line.py +++ b/src/pals_schema/BeamLine.py @@ -7,14 +7,14 @@ from pals_schema.QuadrupoleElement import QuadrupoleElement -class Line(BaseModel): +class BeamLine(BaseModel): """A line of elements and/or other lines""" # Validate every time a new value is assigned to an attribute, - # not only when an instance of Line is created + # not only when an instance of BeamLine is created model_config = ConfigDict(validate_assignment=True) - kind: Literal["Line"] = "Line" + kind: Literal["BeamLine"] = "BeamLine" line: List[ Annotated[ @@ -23,7 +23,7 @@ class Line(BaseModel): ThickElement, DriftElement, QuadrupoleElement, - "Line", + "BeamLine", ], Field(discriminator="kind"), ] @@ -82,4 +82,4 @@ def model_dump(self, *args, **kwargs): # Avoid circular import issues -Line.model_rebuild() +BeamLine.model_rebuild() diff --git a/tests/test_schema.py b/tests/test_schema.py index d7d0560..935078f 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -13,7 +13,7 @@ from pals_schema.ThickElement import ThickElement from pals_schema.DriftElement import DriftElement from pals_schema.QuadrupoleElement import QuadrupoleElement -from pals_schema.Line import Line +from pals_schema.BeamLine import BeamLine def test_BaseElement(): @@ -98,15 +98,15 @@ def test_QuadrupoleElement(): assert element.MagneticMultipoleP.Bn2 == element_magnetic_multipole_Bn2 assert element.MagneticMultipoleP.Bs2 == element_magnetic_multipole_Bs2 assert element.MagneticMultipoleP.tilt2 == element_magnetic_multipole_tilt2 - # Serialize the Line object to YAML + # Serialize the BeamLine object to YAML yaml_data = yaml.dump(element.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") -def test_Line(): +def test_BeamLine(): # Create first line with one base element element1 = BaseElement(name="element1") - line1 = Line(line=[element1]) + line1 = BeamLine(line=[element1]) assert line1.line == [element1] # Extend first line with one thick element element2 = ThickElement(name="element2", length=2.0) @@ -114,7 +114,7 @@ def test_Line(): assert line1.line == [element1, element2] # Create second line with one drift element element3 = DriftElement(name="element3", length=3.0) - line2 = Line(line=[element3]) + line2 = BeamLine(line=[element3]) # Extend first line with second line line1.line.extend(line2.line) assert line1.line == [element1, element2, element3] @@ -126,8 +126,8 @@ def test_yaml(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = Line(line=[element1, element2]) - # Serialize the Line object to YAML + line = BeamLine(line=[element1, element2]) + # Serialize the BeamLine object to YAML yaml_data = yaml.dump(line.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") # Write the YAML data to a test file @@ -137,11 +137,11 @@ def test_yaml(): # Read the YAML data from the test file with open(test_file, "r") as file: yaml_data = yaml.safe_load(file) - # Parse the YAML data back into a Line object - loaded_line = Line(**yaml_data) + # Parse the YAML data back into a BeamLine object + loaded_line = BeamLine(**yaml_data) # Remove the test file os.remove(test_file) - # Validate loaded Line object + # Validate loaded BeamLine object assert line == loaded_line @@ -151,8 +151,8 @@ def test_json(): # Create one thick element element2 = ThickElement(name="element2", length=2.0) # Create line with both elements - line = Line(line=[element1, element2]) - # Serialize the Line object to JSON + line = BeamLine(line=[element1, element2]) + # Serialize the BeamLine object to JSON json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2) print(f"\n{json_data}") # Write the JSON data to a test file @@ -162,9 +162,9 @@ def test_json(): # Read the JSON data from the test file with open(test_file, "r") as file: json_data = json.loads(file.read()) - # Parse the JSON data back into a Line object - loaded_line = Line(**json_data) + # Parse the JSON data back into a BeamLine object + loaded_line = BeamLine(**json_data) # Remove the test file os.remove(test_file) - # Validate loaded Line object + # Validate loaded BeamLine object assert line == loaded_line