"""
Simframe Objects Module
Various objects and functions to handle simulation lattices, commands, and elements.
Classes:
- :class:`~SimulationFramework.Framework_objects.runSetup`: Defines simulation run settings, allowing
for single runs, element scans or jitter/error studies.
- :class:`~SimulationFramework.Framework_objects.frameworkObject`: Base class for generic objects in SimFrame,
including lattice elements and simulation code commands.
- :class:`~SimulationFramework.Framework_objects.frameworkElement`: Base class for generic
lattice elements in SimFrame, extending :class:`~SimulationFramework.Framework_objects.frameworkObject`
with position, rotation and field definitions.
- :class:`~SimulationFramework.Framework_objects.csrdrift`: Drift element including CSR effects.
- :class:`~SimulationFramework.Framework_objects.lscdrift`: Drift element including LSC effects.
- :class:`~SimulationFramework.Framework_objects.edrift`: Basic drift element.
- :class:`~SimulationFramework.Framework_objects.frameworkLattice`: Base class for simulation lattices,
consisting of a line of :class:`~SimulationFramework.Framework_objects.frameworkObject` s.
- :class:`~SimulationFramework.Framework_objects.frameworkCounter`: Used for counting elements of the same
type in ASTRA and CSRTrack
- :class:`~SimulationFramework.Framework_objects.frameworkGroup`: Used for grouping together
:class:`~SimulationFramework.Framework_objects.frameworkObject` s and controlling them all simultaneously.
- :class:`~SimulationFramework.Framework_objects.element_group`: Subclass of
:class:`~SimulationFramework.Framework_objects.frameworkGroup` for grouping elements.
# TODO is this ever used?
- :class:`~SimulationFramework.Framework_objects.r56_group`: Subclass of
:class:`~SimulationFramework.Framework_objects.frameworkGroup` for grouping elements with an R56.
# TODO is this ever used?
- :class:`~SimulationFramework.Framework_objects.chicane`: Subclass of
:class:`~SimulationFramework.Framework_objects.frameworkGroup` for a 4-dipole bunch compressor chicane.
- :class:`~SimulationFramework.Framework_objects.getGrids`: Used for determining the appropriate number
of space charge grids given a number of particles.
"""
import os
import subprocess
from warnings import warn
from copy import deepcopy
import yaml
from .Modules.merge_two_dicts import merge_two_dicts
from .Modules.MathParser import MathParser
from .Framework_Settings import FrameworkSettings
from .FrameworkHelperFunctions import chunks, expand_substitution, checkValue, chop, dot
from .FrameworkHelperFunctions import _rotation_matrix
from .Modules.Fields import field
from .Codes import Executables as exes
try:
import numpy as np
except ImportError:
np = None
from pydantic import (
BaseModel,
field_validator,
PositiveInt,
SerializeAsAny,
computed_field,
ConfigDict,
Field,
)
from typing import (
Dict,
List,
Any,
Tuple,
)
def has_symlink_privilege():
    """
    Platform flag used by the framework for symbolic-link handling.

    Returns
    -------
    bool
        False when running on Windows (``os.name == "nt"``), True otherwise.
    """
    # from .Modules.symmlinks import has_symlink_privilege
    return os.name != "nt"
def _load_packaged_yaml(*relative_parts):
    """
    Load a YAML file stored relative to the directory containing this module.

    Parameters
    ----------
    *relative_parts: str
        Path components below this module's directory.

    Returns
    -------
    Any
        The parsed YAML content as returned by ``yaml.safe_load``.
    """
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), *relative_parts)
    # Read as UTF-8 explicitly so parsing does not depend on the platform's locale encoding
    with open(path, "r", encoding="utf-8") as infile:
        return yaml.safe_load(infile)


# General type/name conversion tables shared by the code interfaces
type_conversion_rules = _load_packaged_yaml("Codes", "type_conversion_rules.yaml")
type_conversion_rules_Elegant = type_conversion_rules["elegant"]
type_conversion_rules_Names = type_conversion_rules["name"]

# Keywords allowed for simulation-code commands and for lattice elements
commandkeywords = _load_packaged_yaml("Codes", "Elegant", "commands_Elegant.yaml")
elementkeywords = _load_packaged_yaml("Elements", "elementkeywords.yaml")

# Elegant keyword conversion rules and element definitions
keyword_conversion_rules_elegant = _load_packaged_yaml(
    "Codes", "Elegant", "keyword_conversion_rules_elegant.yaml"
)
elements_Elegant = _load_packaged_yaml("Codes", "Elegant", "elements_Elegant.yaml")

# Ocelot keyword conversion rules and element definitions
keyword_conversion_rules_ocelot = _load_packaged_yaml(
    "Codes", "Ocelot", "keyword_conversion_rules_ocelot.yaml"
)
elements_Ocelot = _load_packaged_yaml("Codes", "Ocelot", "elements_Ocelot.yaml")
[docs]
class runSetup(object):
"""
Class defining settings for simulations that include multiple runs
such as error studies or parameter scans.
"""
def __init__(self):
# define the number of runs and the random number seed
self.nruns = 1
self.seed = 0
# init errorElement and elementScan settings as None
self.elementErrors = None
self.elementScan = None
[docs]
def setNRuns(self, nruns: int | float) -> None:
"""
Sets the number of simulation runs to a new value.
Parameters
-----------
nruns : int or float
The number of runs to set. If a float is passed, it will be converted to an integer.
Raises
------
TypeError
If `nruns` is not an integer or float.
"""
# enforce integer argument type
if isinstance(nruns, (int, float)):
self.nruns = int(nruns)
else:
raise TypeError(
"Argument nruns passed to runSetup instance must be an integer"
)
[docs]
def setSeedValue(self, seed: int | float) -> None:
"""
Sets the random number seed to a new value for all lattice objects
Parameters
-----------
seed : int or float
The random number seed to set. If a float is passed, it will be converted to an integer.
Raises
------
TypeError
If `seed` is not an integer or float.
"""
# enforce integer argument type
if isinstance(seed, (int, float)):
self.seed = int(seed)
else:
raise TypeError("Argument seed passed to runSetup must be an integer")
[docs]
def loadElementErrors(self, file: str | dict) -> None:
"""
Load error definitions from a file or dictionary and assign them to the elementErrors attribute.
This method can handle both a YAML file and a dictionary containing error definitions.
Parameters
-----------
file: str or dict
- str: Path to a YAML file containing error definitions.
- dict: A dictionary containing error definitions.
"""
# load error definitions from markup file
error_setup = None
if isinstance(file, str) and (".yaml" in file):
with open(file, "r") as inputfile:
error_setup = dict(yaml.safe_load(inputfile))
# define errors from dictionary
elif isinstance(file, dict):
error_setup = file
else:
warn("error_setup must be a str or dict")
if error_setup is not None and "elements" in list(error_setup.keys()):
# assign the element error definitions
self.elementErrors = error_setup["elements"]
self.elementScan = None
# set the number of runs and random number seed, if available
if "nruns" in error_setup:
self.setNRuns(error_setup["nruns"])
if "seed" in error_setup:
self.setSeedValue(error_setup["seed"])
[docs]
def setElementScan(
self,
name: str,
item: str,
scanrange: list | tuple | np.ndarray,
multiplicative: bool = False,
) -> None:
"""
Define a parameter scan for a single parameter of a given machine element
Parameters
-----------
name : str
Name of the machine element to be scanned.
item : str
Name of the item (parameter) to be scanned within the machine element.
scanrange : list or tuple or np.ndarray
A list or tuple containing two floats, representing the minimum and maximum values of the scan range.
multiplicative : bool, optional
If True, the scan will be multiplicative; otherwise, it will be additive. Default is False.
"""
if not (isinstance(name, str) and isinstance(item, str)):
raise TypeError(
"Machine element name and item (parameter) must be defined as strings"
)
if (
isinstance(scanrange, (list, tuple, np.ndarray))
and (len(scanrange) == 2)
and all([isinstance(x, (float, int)) for x in scanrange])
):
minval, maxval = scanrange
else:
raise TypeError("Scan range (min. and max.) must be defined as floats")
if not isinstance(multiplicative, bool):
raise ValueError(
"Argument multiplicative passed to runSetup.setElementScan must be a boolean"
)
# if no type errors were raised, build an assign a dictionary
self.elementScan = {
"name": name,
"item": item,
"min": minval,
"max": maxval,
"multiplicative": multiplicative,
}
self.elementErrors = None
class frameworkObject(BaseModel):
    """
    Class defining a framework object, which is the base class for all elements
    in a simulation lattice. It provides methods to add properties, validate parameters,
    and handle various simulation-specific functionalities.
    """

    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
        validate_assignment=True,
        populate_by_name=True,
    )

    objectname: str = Field(alias="name")
    """Name of the object, used as a unique identifier in the simulation."""

    objecttype: str = Field(alias="type")
    """Type of the object, which determines its behavior and properties in the simulation."""

    objectdefaults: Dict = {}
    """Default values for the object's properties, used when no specific value is provided."""

    allowedkeywords: List | Dict = {}
    """List of allowed keywords for the object, which defines what properties can be set."""

    global_parameters: Dict = {}
    """Global parameters to be cascaded through all objects."""

    def model_post_init(self, __context):
        """
        Re-apply non-annotated values through ``setattr`` and resolve the allowed keywords.

        The allowed keywords are looked up by `objecttype`, first among the command
        keywords and then among the element keywords (merged with the "common" and,
        if present, "framework_keywords" entries), and stored lower-cased.

        Raises
        ------
        NameError
            If `objecttype` is neither a known command nor a known element type.
        """
        # NOTE(review): `__annotations__` may only list fields declared on the concrete
        # class, so inherited fields could be treated as extras here -- confirm.
        extra_fields = {
            k: v for k, v in self.model_dump().items()
            if k not in self.__annotations__
        }
        for k, v in extra_fields.items():
            setattr(self, k, v)

        if self.objecttype in commandkeywords:
            self.allowedkeywords = commandkeywords[self.objecttype]
        elif self.objecttype in elementkeywords:
            self.allowedkeywords = merge_two_dicts(
                elementkeywords[self.objecttype]["keywords"],
                elementkeywords["common"]["keywords"],
            )
            if "framework_keywords" in elementkeywords[self.objecttype]:
                self.allowedkeywords = merge_two_dicts(
                    self.allowedkeywords,
                    elementkeywords[self.objecttype]["framework_keywords"],
                )
        else:
            raise NameError(f"Unknown type = {self.objecttype}")
        self.allowedkeywords = [x.lower() for x in self.allowedkeywords]

    @field_validator("objectname", mode="before")
    @classmethod
    def validate_objectname(cls, value: str) -> str:
        """Validate the objectname to ensure it is a string."""
        if not isinstance(value, str):
            raise ValueError("objectname must be a string.")
        return value

    @field_validator("objecttype", mode="before")
    @classmethod
    def validate_objecttype(cls, value: str) -> str:
        """Validate the objecttype to ensure it is a string."""
        if not isinstance(value, str):
            raise ValueError("objecttype must be a string.")
        return value

    def change_Parameter(self, key: str, value: Any) -> None:
        """
        Change a parameter of the object by setting an attribute.

        Parameters
        ----------
        key: str
            The name of the parameter to change.
        value: Any
            The new value to set for the parameter.
        """
        setattr(self, key, value)

    def add_property(self, key: str, value: Any) -> None:
        """
        Add a property to the object by setting an attribute if the key is allowed.

        The key is lower-cased first; keys not in `allowedkeywords` are ignored.
        A failure while setting the attribute is reported as a warning, not raised.

        Parameters
        ----------
        key: str
            The name of the property to add.
        value: Any
            The value to set for the property.
        """
        key = key.lower()
        if key in self.allowedkeywords:
            try:
                setattr(self, key, value)
            except Exception as e:
                warn(f"add_property error: ({self.objecttype} [{key}]: {e}")

    def add_properties(self, **keyvalues: dict) -> None:
        """
        Add multiple properties to the object by setting attributes for each key-value pair.

        Each key is lower-cased first; keys not in `allowedkeywords` are ignored.
        Failures while setting an attribute are reported as warnings, not raised.

        Parameters
        ----------
        **keyvalues: dict
            A dictionary of key-value pairs where keys are property names
            and values are the corresponding values to set.
        """
        for key, value in keyvalues.items():
            key = key.lower()
            if key in self.allowedkeywords:
                try:
                    setattr(self, key, value)
                except Exception as e:
                    warn(f"add_properties error: ({self.objecttype} [{key}]: {e}")

    def add_default(self, key: str, value: Any) -> None:
        """
        Add a default value for a property of the object, updating `objectdefaults`.

        Parameters
        ----------
        key: str
            The name of the property to set a default value for.
        value: Any
            The default value to store for that property.
        """
        self.objectdefaults[key] = value

    @property
    def parameters(self) -> list:
        """
        Returns a list of all parameters (keys) of the object.

        Returns
        -------
        list
            The names of the declared fields, the computed fields and any extra
            attributes stored on the object.
        """
        # A pydantic model has no dict-style ``keys()``; collect the names explicitly.
        cls = self.__class__
        return [
            *cls.model_fields,
            *cls.model_computed_fields,
            *(self.model_extra or {}),
        ]

    @property
    def objectproperties(self):
        """
        Returns a dictionary of the object's declared model fields and their current values.

        Returns
        -------
        dict
            Mapping of each field name in ``model_fields`` to its value on this object.
        """
        cls = self.__class__
        return {key: getattr(self, key) for key in cls.model_fields}

    def __repr__(self):
        # One "key = value" line for every explicitly-set field that is an allowed keyword
        return "".join(
            f"{k} = {getattr(self, k)}\n"
            for k in self.model_fields_set
            if k in self.allowedkeywords
        )
[docs]
class frameworkElement(frameworkObject):
"""
Class defining a framework element, which is a specific type of framework object
that represents a physical component in a simulation lattice. It extends the frameworkObject
class with additional properties and methods specific to elements, such as position, rotation,
and field definitions.
"""
length: float = 0.0
"""Length of the element in the simulation, typically in meters."""
centre: Tuple[float, float, float] = (0.0, 0.0, 0.0)
"""Centre of the element in the simulation [x,y,z]."""
position_errors: Tuple[float, float, float] = (0.0, 0.0, 0.0)
"""Position errors of the element in the simulation [x,y,z]."""
rotation_errors: Tuple[float, float, float] = (0.0, 0.0, 0.0)
"""Rotation errors of the element in the simulation [x,y,z]."""
global_rotation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
"""Global rotation of the element in the simulation [x,y,z]."""
rotation: SerializeAsAny[Tuple[float, float, float] | float] = (0.0, 0.0, 0.0)
"""Local rotation of the element in the simulation [x,y,z]."""
starting_rotation: SerializeAsAny[float | Tuple[float, float, float]] = (0.0, 0.0, 0.0)
"""Initial rotation of the element, used for specific simulation setups."""
conversion_rules_elegant: Dict = {}
"""Conversion rules for keywords when exporting to Elegant format."""
conversion_rules_ocelot: Dict = {}
"""Conversion rules for keywords when exporting to Ocelot format."""
starting_offset: Tuple[float, float, float] = (0.0, 0.0, 0.0)
"""Initial offset of the element, used for positioning in the simulation."""
subelement: bool = False
"""Flag indicating whether the element is a sub-element of a larger structure."""
field_definition: SerializeAsAny[field | str | None] = None
"""Field definition for the element, can be a field object or a string representing a file."""
wakefield_definition: SerializeAsAny[field | str | None] = None
"""Wakefield definition for the element, can be a field object or a string representing a file."""
PV: SerializeAsAny[str | None] = None
"""EPICS PV root for the element"""
def model_post_init(self, __context):
    """
    Select the Elegant and Ocelot keyword conversion rules for this element,
    then run the parent class's post-initialisation.

    The "general" rules are always used; when rules exist for this element's
    `objecttype` they are merged with the general ones via ``merge_two_dicts``.
    """
    elegant_rules = keyword_conversion_rules_elegant["general"]
    ocelot_rules = keyword_conversion_rules_ocelot["general"]

    if self.objecttype in keyword_conversion_rules_elegant:
        elegant_rules = merge_two_dicts(
            keyword_conversion_rules_elegant[self.objecttype],
            keyword_conversion_rules_elegant["general"],
        )
    if self.objecttype in keyword_conversion_rules_ocelot:
        ocelot_rules = merge_two_dicts(
            keyword_conversion_rules_ocelot[self.objecttype],
            keyword_conversion_rules_ocelot["general"],
        )

    self.conversion_rules_elegant = elegant_rules
    self.conversion_rules_ocelot = ocelot_rules
    super().model_post_init(__context)
# def __setattr__(self, name, value):
# # Let Pydantic set known fields normally
# if name in frameworkElement.model_fields:
# return super().__setattr__(name, value)
# object.__setattr__(self, name, value)
def __mul__(self, other):
    """Return a list holding `other` freshly built ``objectproperties`` dictionaries."""
    copies = []
    for _ in range(other):
        copies.append(self.objectproperties)
    return copies


def __rmul__(self, other):
    """Reflected multiplication; same result as ``self * other``."""
    copies = []
    for _ in range(other):
        copies.append(self.objectproperties)
    return copies


def __neg__(self):
    """Negation leaves the element unchanged and returns the element itself."""
    return self
def __repr__(self):
    """Return the repr of a dict of the explicitly-set fields, minus bookkeeping entries."""
    hidden = (
        "allowedkeywords",
        "conversion_rules_elegant",
        "conversion_rules_ocelot",
        "objectdefaults",
        "global_parameters",
        "objectname",
        "subelement",
    )
    visible = {}
    for name in self.model_fields_set:
        if name not in hidden:
            visible[name] = getattr(self, name)
    return repr(visible)
@property
def propertiesDict(self) -> dict:
    """
    Dictionary of the explicitly-set fields of the element, excluding bookkeeping entries.

    Returns
    -------
    dict
        Every name in ``model_fields_set`` mapped to its current value, except
        the keyword lists, conversion rules, defaults, global parameters and
        the `subelement` flag.
    """
    hidden = (
        "allowedkeywords",
        "conversion_rules_elegant",
        "conversion_rules_ocelot",
        "objectdefaults",
        "global_parameters",
        "subelement",
    )
    visible = {}
    for name in self.model_fields_set:
        if name not in hidden:
            visible[name] = getattr(self, name)
    return visible
@property
def k1(self) -> None:
    """Always None on the base element (presumably overridden by magnet subclasses)."""
    return None


@property
def k2(self) -> None:
    """Always None on the base element (presumably overridden by magnet subclasses)."""
    return None


@property
def k3(self) -> None:
    """Always None on the base element (presumably overridden by magnet subclasses)."""
    return None
@property
def x(self) -> float:
    """
    x-coordinate of the element's centre.

    Returns
    -------
    float
        First component of `centre`.
    """
    return self.centre[0]


@x.setter
def x(self, x: float) -> None:
    """
    Replace the x-coordinate of the element's centre, keeping y and z.

    Parameters
    ----------
    x: float
        New x-coordinate of the centre.
    """
    cy, cz = self.centre[1], self.centre[2]
    self.centre = (x, cy, cz)


@property
def y(self) -> float:
    """
    y-coordinate of the element's centre.

    Returns
    -------
    float
        Second component of `centre`.
    """
    return self.centre[1]


@y.setter
def y(self, y) -> None:
    """
    Replace the y-coordinate of the element's centre, keeping x and z.

    Parameters
    ----------
    y: float
        New y-coordinate of the centre.
    """
    cx, cz = self.centre[0], self.centre[2]
    self.centre = (cx, y, cz)


@property
def z(self):
    """
    z-coordinate of the element's centre.

    Returns
    -------
    float
        Third component of `centre`.
    """
    return self.centre[2]


@z.setter
def z(self, z) -> None:
    """
    Replace the z-coordinate of the element's centre, keeping x and y.

    Parameters
    ----------
    z: float
        New z-coordinate of the centre.
    """
    cx, cy = self.centre[0], self.centre[1]
    self.centre = (cx, cy, z)
@property
def dx(self) -> float:
    """
    x-offset (position error) of the element.

    Returns
    -------
    float
        First component of `position_errors`.
    """
    offsets = self.position_errors
    return offsets[0]


@dx.setter
def dx(self, x: float) -> None:
    """
    Replace the x-offset of the element, keeping the y and z offsets.

    Parameters
    ----------
    x: float
        New x-offset.
    """
    offsets = list(self.position_errors)
    offsets[0] = x
    self.position_errors = (offsets[0], offsets[1], offsets[2])


@property
def dy(self):
    """
    y-offset (position error) of the element.

    Returns
    -------
    float
        Second component of `position_errors`.
    """
    offsets = self.position_errors
    return offsets[1]


@dy.setter
def dy(self, y: float) -> None:
    """
    Replace the y-offset of the element, keeping the x and z offsets.

    Parameters
    ----------
    y: float
        New y-offset.
    """
    offsets = list(self.position_errors)
    offsets[1] = y
    self.position_errors = (offsets[0], offsets[1], offsets[2])


@property
def dz(self):
    """
    z-offset (position error) of the element.

    Returns
    -------
    float
        Third component of `position_errors`.
    """
    offsets = self.position_errors
    return offsets[2]


@dz.setter
def dz(self, z: float) -> None:
    """
    Replace the z-offset of the element, keeping the x and y offsets.

    Parameters
    ----------
    z: float
        New z-offset.
    """
    offsets = list(self.position_errors)
    offsets[2] = z
    self.position_errors = (offsets[0], offsets[1], offsets[2])
@property
def x_rot(self) -> float:
    """
    Global x-rotation of the element.

    Returns
    -------
    float
        First component of `global_rotation`.
    """
    angles = self.global_rotation
    return angles[0]


@x_rot.setter
def x_rot(self, x: float) -> None:
    """
    Replace the global x-rotation of the element, keeping the y and z rotations.

    Parameters
    ----------
    x: float
        New global x-rotation.
    """
    angles = list(self.global_rotation)
    angles[0] = x
    self.global_rotation = (angles[0], angles[1], angles[2])


@property
def y_rot(self) -> float:
    """
    Global y-rotation of the element.

    Returns
    -------
    float
        Second component of `global_rotation`.
    """
    angles = self.global_rotation
    return angles[1]


@y_rot.setter
def y_rot(self, y: float) -> None:
    """
    Replace the global y-rotation of the element, keeping the x and z rotations.

    Parameters
    ----------
    y: float
        New global y-rotation.
    """
    angles = list(self.global_rotation)
    angles[1] = y
    self.global_rotation = (angles[0], angles[1], angles[2])


@property
def z_rot(self) -> float:
    """
    Global z-rotation of the element.

    Returns
    -------
    float
        Third component of `global_rotation`.
    """
    angles = self.global_rotation
    return angles[2]


@z_rot.setter
def z_rot(self, z: float) -> None:
    """
    Replace the global z-rotation of the element, keeping the x and y rotations.

    Parameters
    ----------
    z: float
        New global z-rotation.
    """
    angles = list(self.global_rotation)
    angles[2] = z
    self.global_rotation = (angles[0], angles[1], angles[2])
@property
def dx_rot(self) -> float:
    """
    x-rotation error of the element.

    Returns
    -------
    float
        First component of `rotation_errors`.
    """
    errors = self.rotation_errors
    return errors[0]


@dx_rot.setter
def dx_rot(self, x: float) -> None:
    """
    Replace the x-rotation error of the element, keeping the y and z errors.

    Parameters
    ----------
    x: float
        New x-rotation error.
    """
    errors = list(self.rotation_errors)
    errors[0] = x
    self.rotation_errors = (errors[0], errors[1], errors[2])


@property
def dy_rot(self):
    """
    y-rotation error of the element.

    Returns
    -------
    float
        Second component of `rotation_errors`.
    """
    errors = self.rotation_errors
    return errors[1]


@dy_rot.setter
def dy_rot(self, y: float) -> None:
    """
    Replace the y-rotation error of the element, keeping the x and z errors.

    Parameters
    ----------
    y: float
        New y-rotation error.
    """
    errors = list(self.rotation_errors)
    errors[1] = y
    self.rotation_errors = (errors[0], errors[1], errors[2])


@property
def dz_rot(self):
    """
    z-rotation error of the element.

    Returns
    -------
    float
        Third component of `rotation_errors`.
    """
    errors = self.rotation_errors
    return errors[2]


@dz_rot.setter
def dz_rot(self, z: float) -> None:
    """
    Replace the z-rotation error of the element, keeping the x and y errors.

    Parameters
    ----------
    z: float
        New z-rotation error.
    """
    errors = list(self.rotation_errors)
    errors[2] = z
    self.rotation_errors = (errors[0], errors[1], errors[2])
@property
def tilt(self):
    """
    Tilt of the element, an alias for the z-rotation error `dz_rot`.

    Returns
    -------
    float
        The value of `dz_rot`.
    """
    z_rotation_error = self.dz_rot
    return z_rotation_error
@property
def get_field_amplitude(self) -> float | None:
"""
Returns the field amplitude of the element, scaled by `field_scale` if it exists.
Returns
-------
float or None
The field amplitude of the element, which is either scaled by `field_scale`
or directly taken from `field_amplitude`.
Returns None if `field_amplitude` is not defined
"""
if hasattr(self, "field_amplitude"):
if hasattr(self, "field_scale") and isinstance(self.field_scale, (int, float)):
return float(self.field_scale) * float(
expand_substitution(self, self.field_amplitude)
)
else:
return float(expand_substitution(self, self.field_amplitude))
return None
def get_field_reference_position(self) -> list:
    """
    Position of the field reference point selected by `field_reference_position`.

    Returns
    -------
    list
        `start`, `middle` or `end` of the element, chosen case-insensitively.
        `start` is returned when `field_reference_position` is missing or None.

    Raises
    ------
    ValueError
        If `field_reference_position` is set to anything other than
        'start', 'middle' or 'end'.
    """
    reference = getattr(self, "field_reference_position", None)
    if reference is None:
        return self.start

    keyword = reference.lower()
    if keyword in ("start", "middle", "end"):
        # the keyword is also the name of the matching position property
        return getattr(self, keyword)

    raise ValueError(
        "field_reference_position should be (start/middle/end) not",
        self.field_reference_position,
    )
@property
def theta(self) -> float:
    """
    Returns the rotation angle of the element used to build its rotation matrix.

    Returns
    -------
    float
        The first component of `global_rotation` (0 if `global_rotation` is missing
        or None) plus the starting rotation. `starting_rotation` may be a plain
        number, used as is, or a sequence, whose first component is used.
    """
    if hasattr(self, "global_rotation") and self.global_rotation is not None:
        rotation = (
            self.global_rotation[0]
            if len(self.global_rotation) == 3
            else self.global_rotation
        )
    else:
        rotation = 0

    # `starting_rotation` is declared as float | tuple; indexing a bare float
    # would raise TypeError, so only index when it is not a scalar.
    starting = self.starting_rotation
    if isinstance(starting, (int, float)):
        rotation += starting
    else:
        rotation += starting[0]
    return rotation
@property
def rotation_matrix(self) -> np.ndarray:
    """
    Rotation matrix of the element for its rotation angle :attr:`theta`.

    Returns
    -------
    np.ndarray
        The result of ``_rotation_matrix`` evaluated at :attr:`theta`.
    """
    angle = self.theta
    return _rotation_matrix(angle)
def rotated_position(
    self, pos: tuple = (0, 0, 0), offset: list | tuple = None, theta: float = None
) -> int | float | complex | list:
    """
    Subtract an offset from a position, rotate the result, and chop small values.

    Parameters
    ----------
    pos: tuple, optional
        The position to be rotated. Default is (0, 0, 0).
    offset: list, optional
        The offset subtracted from `pos` before rotating. If not provided, the element's
        `starting_offset` is used, or [0, 0, 0] when that is missing or None.
    theta: float, optional
        The rotation angle passed to ``_rotation_matrix``. If not provided,
        the element's own :attr:`rotation_matrix` is used.

    Returns
    -------
    int | float | complex | list
        ``chop`` (tolerance 1e-6) applied to the dot product of the shifted
        position with the rotation matrix.
    """
    if offset is None:
        own_offset = getattr(self, "starting_offset", None)
        offset = [0, 0, 0] if own_offset is None else own_offset

    shifted = np.array(pos) - np.array(offset)
    matrix = self.rotation_matrix if theta is None else _rotation_matrix(theta)
    return chop(np.dot(shifted, matrix), 1e-6)
@property
def start(self) -> list:
    """
    Starting position of the element; an alias for :attr:`position_start`.

    Returns
    -------
    list
        The value of :attr:`position_start`.
    """
    entrance = self.position_start
    return entrance


@property
def position_start(self) -> list:
    """
    Starting position of the element, derived from its centre and length.

    Returns
    -------
    list
        The centre minus the rotated vector (0, 0, length / 2), where the rotation
        uses `starting_offset` as the offset and `x_rot` as the angle.
    """
    centre_vector = np.array(self.centre)
    half_length = (0, 0, self.length / 2.0)
    shift = self.rotated_position(
        half_length,
        offset=self.starting_offset,
        theta=self.x_rot,
    )
    return list(centre_vector - shift)
@property
def middle(self) -> list | tuple:
"""
Returns the middle position of the element, which is the center of the element's length.
Returns
-------
list
The middle position of the element, which is the center point of the element's length.
"""
return self.centre
@property
def end(self) -> list:
"""
Returns the end position of the element, which is calculated based on its starting position and length.
Returns
-------
list
The end position of the element, which is the position at the end of the element's length.
"""
return self.position_end
@property
def position_end(self) -> list | np.ndarray:
"""
Returns the end position of the element, which is calculated based on its starting position and length.
Returns
-------
list
The end position of the element, which is the position at the end of the element's length.
"""
start = np.array(self.position_start)
end = start + self.rotated_position(
(0, 0, self.length),
offset=self.starting_offset,
theta=self.x_rot,
)
return end
def relative_position_from_centre(self, vec: tuple = (0, 0, 0)) -> list:
    """
    Position obtained by adding a rotated vector to the centre of the element.

    Parameters
    ----------
    vec: tuple, optional
        The vector to rotate (with `starting_offset` as offset and `x_rot` as angle)
        and add to the centre. Default is (0, 0, 0).

    Returns
    -------
    list
        The centre plus the rotated vector.
    """
    centre_vector = np.array(self.centre)
    displacement = self.rotated_position(
        vec,
        offset=self.starting_offset,
        theta=self.x_rot,
    )
    return list(centre_vector + displacement)
def relative_position_from_start(self, vec: tuple = (0, 0, 0)) -> list:
    """
    Position obtained by adding a rotated vector to the start of the element.

    Parameters
    ----------
    vec: tuple, optional
        The vector to rotate (with `starting_offset` as offset and `x_rot` as angle)
        and add to :attr:`position_start`. Default is (0, 0, 0).

    Returns
    -------
    list
        The starting position plus the rotated vector.
    """
    entrance = np.array(self.position_start)
    displacement = self.rotated_position(
        vec,
        offset=self.starting_offset,
        theta=self.x_rot,
    )
    return list(entrance + displacement)
def update_field_definition(self) -> None:
    """
    Replace string field and wakefield definitions by ``field`` objects.

    A `field_definition` given as a string is expanded with ``expand_substitution``
    and wrapped in a ``field`` carrying this element's `field_type`; for elements
    of type "cavity" the frequency, structure type and number of cells are passed too.
    A `wakefield_definition` given as a string is wrapped the same way, always
    with frequency, structure type and number of cells.
    Definitions that are missing, None, or not strings are left untouched.
    """
    field_spec = getattr(self, "field_definition", None)
    if isinstance(field_spec, str):
        options = {
            "filename": expand_substitution(self, self.field_definition),
            "field_type": self.field_type,
        }
        if self.objecttype == "cavity":
            options["frequency"] = self.frequency
            options["cavity_type"] = self.Structure_Type
            options["n_cells"] = self.n_cells
        self.field_definition = field(**options)

    wake_spec = getattr(self, "wakefield_definition", None)
    if isinstance(wake_spec, str):
        self.wakefield_definition = field(
            filename=expand_substitution(self, self.wakefield_definition),
            field_type=self.field_type,
            frequency=self.frequency,
            cavity_type=self.Structure_Type,
            n_cells=self.n_cells,
        )
def _write_ASTRA_dictionary(self, d: dict, n: int | None = 1) -> str:
    """
    Generates a string representation of the object's properties in the ASTRA format.

    Each entry of `d` is resolved with ``checkValue(self, v)``; entries that resolve
    to None are skipped. The remaining entries are written as ``key(n) = value, ``
    (or ``key = value,`` followed by a newline when `n` is None), with a line break
    inserted whenever the current output line would exceed 70 characters.

    Parameters
    ----------
    d: dict
        A dictionary containing the properties of the object to be formatted.
        Each value is expected to support ``"type" in v`` and ``v["type"]``
        (presumably a dict describing the parameter -- confirm against callers).
    n: int, optional
        An optional integer to specify the index for ASTRA objects. Default is 1.
        If None, parameters are written without an index.

    Returns
    -------
    str
        A formatted string representing the object's properties in ASTRA format,
        with the last two characters of the accumulated output removed.
    """
    output = ""
    for k, v in list(d.items()):
        if checkValue(self, v) is not None:
            if "type" in v and v["type"] == "list":
                # "list": one entry per item, indexed as key(item_number, n)
                for i, l in enumerate(checkValue(self, v)):
                    if n is not None:
                        param_string = (
                            k
                            + "("
                            + str(i + 1)
                            + ","
                            + str(n)
                            + ") = "
                            + str(l)
                            + ", "
                        )
                    else:
                        param_string = k + " = " + str(l) + "\n"
                    # break the line before it grows beyond 70 characters
                    if len((output + param_string).splitlines()[-1]) > 70:
                        output += "\n"
                    output += param_string
            elif "type" in v and v["type"] == "array":
                # "array": all items in a single parenthesised, comma-separated entry
                if n is not None:
                    param_string = k + "(" + str(n) + ") = ("
                else:
                    param_string = k + " = ("
                for i, l in enumerate(checkValue(self, v)):
                    param_string += str(l) + ", "
                    if len((output + param_string).splitlines()[-1]) > 70:
                        output += "\n"
                # drop the ", " after the last item before closing the bracket
                output += param_string[:-2] + "),\n"
            elif "type" in v and v["type"] == "not_zero":
                # "not_zero": written only when the value is non-zero
                if abs(checkValue(self, v)) > 0:
                    if n is not None:
                        param_string = (
                            k
                            + "("
                            + str(n)
                            + ") = "
                            + str(checkValue(self, v))
                            + ", "
                        )
                    else:
                        param_string = k + " = " + str(checkValue(self, v)) + ",\n"
                    if len((output + param_string).splitlines()[-1]) > 70:
                        output += "\n"
                    output += param_string
            else:
                # any other entry: a plain "key(n) = value" pair
                if n is not None:
                    param_string = (
                        k + "(" + str(n) + ") = " + str(checkValue(self, v)) + ", "
                    )
                else:
                    param_string = k + " = " + str(checkValue(self, v)) + ",\n"
                if len((output + param_string).splitlines()[-1]) > 70:
                    output += "\n"
                output += param_string
    # Remove the trailing two-character separator (", " or ",\n").
    # NOTE(review): a "list" entry written with n=None ends in just "\n", so this
    # slice would also remove the last character of its value -- confirm intended.
    return output[:-2]
def write_ASTRA(self, n, **kwargs) -> str:
    """
    Public entry point for writing the element in ASTRA format.

    Forwards all arguments to ``self._write_ASTRA``, which is not defined in the
    visible part of this class (presumably provided by subclasses -- confirm).

    Parameters
    ----------
    n: int
        An integer representing the index for ASTRA objects,
        typically used for multiple instances of the same element.
    **kwargs: dict
        Additional keyword arguments passed through unchanged.

    Returns
    -------
    str
        Whatever ``self._write_ASTRA(n, **kwargs)`` returns.
    """
    writer = self._write_ASTRA
    return writer(n, **kwargs)
[docs]
def generate_field_file_name(self, param: field, code: str) -> str | None:
    """
    Write the field file for ``param`` into the master sub-directory and return its file name.

    Parameters
    ----------
    param: field
        Object expected to expose ``filename`` and ``write_field_file``.
    code: str
        Tracking code name, passed to ``param.write_field_file``.

    Returns
    -------
    str | None
        Base name of the path returned by ``param.write_field_file``, or None
        (after a warning) when ``param`` has no ``filename`` attribute.
    """
    if not hasattr(param, "filename"):
        warn(
            f"param does not have a filename: {param}, it must be a `field` object"
        )
        return None
    # Bare file name with any quoting characters removed
    stripped_name = os.path.basename(param.filename)
    for quote in ('"', "'"):
        stripped_name = stripped_name.replace(quote, "")
    # Build the target location with forward slashes before normalising it
    master_dir = self.global_parameters["master_subdir"].replace("\\", "/")
    target = os.path.abspath(master_dir + "/" + stripped_name.replace("\\", "/"))
    written = param.write_field_file(code=code, location=target)
    return os.path.basename(written)
def _write_Elegant(self) -> str:
    """
    Build the Elegant definition string for this object.

    Every model field and computed field whose converted keyword is listed in
    ``elements_Elegant`` for the converted element type, and whose value is not
    None, is written as ``keyword = value``. Lines are continued with ``,&``
    once they would exceed 76 characters.

    Returns
    -------
    str
        The Elegant definition, terminated by ``;`` and a newline.
    """
    etype = self._convertType_Elegant(self.objecttype)
    model_cls = self.__class__
    skipped = ("name", "type", "commandtype")
    finished_lines = ""
    current = self.objectname + ": " + etype
    for field_name in {**model_cls.model_fields, **model_cls.model_computed_fields}:
        if field_name in skipped:
            continue
        if self._convertKeyword_Elegant(field_name) not in elements_Elegant[etype]:
            continue
        value = getattr(self, field_name, None)
        if value is None:
            continue
        keyword = self._convertKeyword_Elegant(field_name)
        # Booleans are written as 1 / 0
        if value is True:
            value = 1
        elif value is False:
            value = 0
        entry = ", " + keyword + " = " + str(value)
        if len(current + entry) > 76:
            # Close the current line and start a new one without the leading ", "
            finished_lines += current + ",&\n"
            current = entry[2::]
        else:
            current += entry
    return finished_lines + current + ";\n"
[docs]
def write_Elegant(self) -> str:
    """
    Return the Elegant definition of this object, see :func:`_write_Elegant`.

    Returns
    -------
    str
        The Elegant definition string, or None when ``subelement`` is truthy.
    """
    if self.subelement:
        # Sub-elements get no Elegant entry of their own
        return None
    return self._write_Elegant()
def _convertType_Elegant(self, etype: str) -> str:
    """
    Map an element type onto its Elegant type name.

    Parameters
    ----------
    etype: str
        The element type to look up in ``type_conversion_rules_Elegant``.

    Returns
    -------
    str
        The mapped Elegant type, or ``etype`` itself when no rule exists.
    """
    if etype in type_conversion_rules_Elegant:
        return type_conversion_rules_Elegant[etype]
    return etype
def _convertKeyword_Elegant(self, keyword: str) -> str:
"""
Converts a keyword to its corresponding Elegant keyword using predefined rules.
Parameters
----------
keyword: str:
The keyword to be converted.
Returns
-------
str
The converted keyword for Elegant, or the original keyword if no conversion rule exists.
"""
return (
self.conversion_rules_elegant[keyword]
if keyword in self.conversion_rules_elegant
else keyword
)
def _write_Ocelot(self) -> object:
    """
    Build the Ocelot object for this element.

    The Ocelot class is looked up by ``objecttype`` and instantiated with the
    element name as ``eid``. Unless the result is exactly an ``Aperture`` or a
    ``Marker``, every entry of ``objectproperties`` (except name/type/commandtype)
    is copied onto it under its converted keyword.

    Returns
    -------
    object
        The populated Ocelot object.
    """
    from ocelot.cpbd.elements import Marker, Aperture
    from .Codes.Ocelot import ocelot_conversion

    builders = ocelot_conversion.ocelot_conversion_rules
    obj = builders[self.objecttype](eid=self.objectname)
    # Apertures and markers are left without any copied properties
    is_bare = type(obj) in [Aperture, Marker]
    for key, stored in self.objectproperties.items():
        if is_bare or key in ["name", "type", "commandtype"]:
            continue
        # Prefer the live attribute; fall back to the stored property value
        live = getattr(self, key, None)
        chosen = stored if live is None else live
        setattr(obj, self._convertKeyword_Ocelot(key), chosen)
    return obj
[docs]
def write_Ocelot(self) -> object:
    """
    Return the Ocelot object for this element, see :func:`_write_Ocelot`.

    Returns
    -------
    object
        The Ocelot object, or None when ``subelement`` is truthy.
    """
    if self.subelement:
        # Sub-elements get no Ocelot object of their own
        return None
    return self._write_Ocelot()
def _convertType_Ocelot(self, etype: str) -> object:
    """
    Map an element type onto its Ocelot counterpart.

    Parameters
    ----------
    etype: str
        The element type to look up in the Ocelot conversion rules.

    Returns
    -------
    object
        The entry of ``ocelot_conversion_rules`` for ``etype``, or ``etype``
        itself when no rule exists.
    """
    from .Codes.Ocelot import ocelot_conversion

    rules = ocelot_conversion.ocelot_conversion_rules
    if etype in rules:
        return rules[etype]
    return etype
def _convertKeyword_Ocelot(self, keyword: str) -> str:
"""
Converts a keyword to its corresponding Ocelot keyword using predefined rules.
Parameters
----------
keyword: str
The keyword to be converted.
Returns
-------
str
The converted keyword for Ocelot, or the original keyword if no conversion rule exists.
"""
return (
self.conversion_rules_ocelot[keyword]
if keyword in self.conversion_rules_ocelot
else keyword
)
def _write_ASTRA(self, n: int = 0, **kwargs: dict) -> str:
pass
[docs]
def write_ASTRA(self, n: int = 0, **kwargs: dict) -> str:
    """
    Return the ASTRA representation of this element, see :func:`_write_ASTRA`.

    Parameters
    ----------
    n: int
        Index forwarded to :func:`_write_ASTRA`.
    **kwargs: dict
        Extra keyword arguments, forwarded unchanged.
    """
    astra_text = self._write_ASTRA(n, **kwargs)
    return astra_text
def _write_CSRTrack(self, n: int = 0, **kwargs: dict) -> str:
pass
[docs]
def write_CSRTrack(self, n: int = 0, **kwargs: dict) -> str:
    """
    Return the CSRTrack representation of this element, see :func:`_write_CSRTrack`.

    Parameters
    ----------
    n: int
        Index forwarded to :func:`_write_CSRTrack`.
    **kwargs: dict
        Extra keyword arguments, forwarded unchanged.
    """
    csrtrack_text = self._write_CSRTrack(n, **kwargs)
    return csrtrack_text
[docs]
def write_GPT(self, Brho: float, ccs: str = "wcs", *args, **kwargs) -> str:
    """
    Return the GPT representation of this element, see :func:`_write_GPT`.

    Parameters
    ----------
    Brho: float
        Forwarded to :func:`_write_GPT`.
    ccs: str
        GPT coordinate system name, forwarded to :func:`_write_GPT`.
    *args, **kwargs
        Forwarded unchanged.
    """
    gpt_text = self._write_GPT(Brho, ccs, *args, **kwargs)
    return gpt_text
def _write_GPT(self, Brho: float, ccs: str = "wcs", *args, **kwargs) -> str:
return ""
[docs]
def gpt_coordinates(self, position: list, rotation: list) -> str:
    """
    Get the GPT coordinates for a given position and rotation

    Parameters
    ----------
    position: list
        The lattice position as three values ``(x, y, z)``; it is passed
        through ``chop`` with a tolerance of 1e-6 before use.
    rotation: list
        Three rotation values ``(psi, phi, theta)``. Only ``theta`` is used.

    Returns
    -------
    str
        A GPT-formatted position string: the offset ``-x, y, z`` followed by
        ``cos(theta), 0, -sin(theta), 0, 1 ,0``.
    """
    x, y, z = chop(position, 1e-6)
    # rotation must unpack into exactly three values; only theta enters the output
    _psi, _phi, theta = rotation
    # The x component is written with its sign flipped
    offset = ", ".join(str(component) for component in (-x, y, z))
    return offset + ", cos(" + str(theta) + "), 0, -sin(" + str(theta) + "), 0, 1 ,0"
[docs]
def gpt_ccs(self, ccs: str) -> str:
    """
    Return the GPT coordinate system that applies after this element.

    This implementation hands the incoming coordinate system back unchanged.

    Parameters
    ----------
    ccs: str
        The current GPT coordinate system.

    Returns
    -------
    str
        The same ``ccs`` that was passed in.
    """
    unchanged = ccs
    return unchanged
[docs]
def array_names_string(self) -> str:
    """
    Format the element's array names as a comma-separated list of quoted names.

    ``array_names`` is used when it is not None, otherwise ``default_array_names``.

    Returns
    -------
    str
        Each name wrapped in double quotes, joined by ``", "``.
    """
    if self.array_names is None:
        names = self.default_array_names
    else:
        names = self.array_names
    quoted = []
    for name in names:
        quoted.append('"' + name + '"')
    return ", ".join(quoted)
[docs]
class csrdrift(frameworkElement):
    """
    Drift element carrying the CSR / LSC settings written to ELEGANT.
    """

    lsc_interpolate: int = 1
    """Flag to allow for interpolation of computed longitudinal space charge wake.
    See `Elegant manual LSC drift`_"""
    csr_enable: bool = True
    """Enable CSR drift calculations"""
    lsc_enable: bool = True
    """Enable LSC drift calculations"""
    use_stupakov: int = 1
    """Use Stupakov formula; see `Elegant manual LSC drift`_"""
    csrdz: float = 0.01
    """Step size for CSR calculations"""
    lsc_bins: int = 20
    """Number of bins for LSC calculations"""
    lsc_high_frequency_cutoff_start: float = -1
    """Spatial frequency at which smoothing filter begins. If not positive, no frequency filter smoothing is done.
    See `Elegant manual LSC drift`_
    """
    lsc_high_frequency_cutoff_end: float = -1
    """Spatial frequency at which smoothing filter is 0. See `Elegant manual LSC drift`_"""
    lsc_low_frequency_cutoff_start: float = -1
    """Highest spatial frequency at which low-frequency cutoff filter is zero. See `Elegant manual LSC drift`_"""
    lsc_low_frequency_cutoff_end: float = -1
    """Lowest spatial frequency at which low-frequency cutoff filter is 1. See `Elegant manual LSC drift`_"""

    def _write_Elegant(self) -> str:
        """
        Build the ELEGANT definition string for this drift.

        Every entry of ``objectproperties`` whose converted keyword is listed in
        ``elements_Elegant`` for the converted element type is written as
        ``keyword = value``; the live attribute is preferred over the stored
        property value when it is not None. Lines are continued with ``,&``
        once they would exceed 76 characters.

        Returns
        -------
        str
            String representation of the element for ELEGANT
        """
        etype = self._convertType_Elegant(self.objecttype)
        finished_lines = ""
        current = self.objectname + ": " + etype
        for prop, stored in self.objectproperties.items():
            if prop in ("name", "type", "commandtype"):
                continue
            if self._convertKeyword_Elegant(prop) not in elements_Elegant[etype]:
                continue
            live = getattr(self, prop, None)
            value = stored if live is None else live
            keyword = self._convertKeyword_Elegant(prop)
            # Booleans are written as 1 / 0
            if value is True:
                value = 1
            elif value is False:
                value = 0
            entry = ", " + keyword + " = " + str(value)
            if len(current + entry) > 76:
                # Close the current line and start a new one without the leading ", "
                finished_lines += current + ",&\n"
                current = entry[2::]
            else:
                current += entry
        return finished_lines + current + ";\n"
[docs]
class lscdrift(csrdrift):
    """
    Drift element including LSC effects.

    Adds nothing of its own: all fields and the ELEGANT writer are inherited
    from :class:`csrdrift`.
    """
[docs]
class edrift(csrdrift):
    """
    Basic drift element.

    Adds nothing of its own: all fields and the ELEGANT writer are inherited
    from :class:`csrdrift`.
    """
[docs]
class frameworkLattice(BaseModel):
"""
Class defining a framework lattice object, which contains all elements and groups
of elements in a simulation lattice. It also contains methods to manipulate and
retrieve information about the elements and groups, as well as methods to run
simulations and process results.
See :ref:`creating-the-lattice-elements`
"""
model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
validate_assignment=True,
)
name: str
"""Name of the lattice, used as a prefix for output files and commands."""
objectname: str | None = ""
"""Name of the lattice, used as a prefix for output files and commands."""
objecttype: str | None = ""
"""Type of the lattice, used as a prefix for output files and commands."""
file_block: Dict
"""File block containing input and output settings for the lattice."""
elementObjects: Dict
"""Dictionary of element objects, where keys are element names and values are element instances."""
groupObjects: Dict
"""Dictionary of group objects, where keys are group names and values are group instances."""
runSettings: runSetup
"""Run settings for the lattice, including number of runs and random seed."""
settings: FrameworkSettings
"""Instance of :class:`~SimulationFramework.Framework_Settings.FrameworkSettings`"""
executables: exes.Executables
"""Executable commands for running simulations, defined in the Executables class.
See :class:`~SimulationFramework.Framework.Codes.Executables.Executables` for more details."""
global_parameters: Dict
"""Global parameters for the lattice, including master subdirectory and other configuration settings."""
allow_negative_drifts: bool = False
"""If True, allows negative drifts in the lattice."""
_csr_enable: bool = True
"""Flag to enable CSR drifts in the lattice."""
csrDrifts: bool = True
"""Flag to enable CSR drifts in the lattice."""
lscDrifts: bool = True
"""Flag to enable LSC drifts in the lattice."""
lsc_bins: int = 20
"""Number of bins for LSC drifts."""
lsc_high_frequency_cutoff_start: float = -1
"""Spatial frequency at which smoothing filter begins. If not positive, no frequency filter smoothing is done.
See `Elegant manual LSC drift`_
.. _Elegant manual LSC drift: https://ops.aps.anl.gov/manuals/elegant_latest/elegantsu168.html#x179-18000010.58"""
lsc_high_frequency_cutoff_end: float = -1
"""Spatial frequency at which smoothing filter is 0. See `Elegant manual LSC drift`_"""
lsc_low_frequency_cutoff_start: float = -1
"""Highest spatial frequency at which low-frequency cutoff filter is zero. See `Elegant manual LSC drift`_"""
lsc_low_frequency_cutoff_end: float = -1
"""Lowest spatial frequency at which low-frequency cutoff filter is 1. See `Elegant manual LSC drift`_"""
sample_interval: int = 1
"""Sample interval for downsampling particles, in units of 2**(3*sample_interval)"""
globalSettings: Dict = {"charge": None}
"""Global settings for the lattice, including charge and other parameters."""
groupSettings: Dict = {}
"""Group settings for the lattice, including group-specific parameters."""
allElements: List = []
"""List of all element names in the lattice."""
initial_twiss: Dict = {}
"""Initial Twiss parameters for the lattice, used for tracking and analysis."""
def model_post_init(self, __context):
    """
    Finish setting up the lattice once the fields have been populated.

    Exposes every element as an attribute, records the element order, applies
    the run settings, and copies the optional ``groups`` / ``sample_interval``
    entries of ``file_block`` plus the ``global`` settings onto the lattice.

    Raises
    ------
    ValueError
        If ``file_block`` is not a dictionary.
    """
    # Make each element reachable as an attribute named after it
    for element_name, element in list(self.elementObjects.items()):
        setattr(self, element_name, element)
    self.allElements = list(self.elementObjects.keys())
    self.objectname = self.name
    # define settings for simulations with multiple runs
    self.updateRunSettings(self.runSettings)
    if not isinstance(self.file_block, dict):
        raise ValueError("file_block must be a dictionary.")
    if "groups" in self.file_block and self.file_block["groups"] is not None:
        self.groupSettings = self.file_block["groups"]
    if "input" in self.file_block and "sample_interval" in self.file_block["input"]:
        self.sample_interval = self.file_block["input"]["sample_interval"]
    self.globalSettings = self.settings["global"]
# @field_validator("file_block", mode="before")
# @classmethod
# def validate_file_block(cls, value: Dict) -> Dict:
# """
# Validate the file_block dictionary to ensure it has the required structure.
# This method checks if the file_block is a dictionary and contains the necessary keys.
#
# Raises
# ------
# ValueError
# If the file_block is not a dictionary or does not contain the required keys.
# """
# if not isinstance(value, dict):
# raise ValueError("file_block must be a dictionary.")
# if "groups" in value:
# if value["groups"] is not None:
# cls.groupSettings = value["groups"]
# if "input" in value:
# if "sample_interval" in value["input"]:
# cls.sample_interval = value["input"]["sample_interval"]
# return value
#
# @field_validator("settings", mode="before")
# @classmethod
# def validate_settings(cls, value: Dict) -> Dict:
# """
# Validate the settings dictionary to ensure it has the required structure.
# This method checks if the settings is a dictionary and contains the necessary keys.
#
# Raises
# ------
# ValueError
# If the settings is not a dictionary or does not contain the required keys.
#
# """
# if not isinstance(value, dict):
# raise ValueError("settings must be a dictionary.")
# if "global" in value:
# if value["global"] is not None:
# cls.globalSettings = value["global"]
# return value
def __setattr__(self, name, value):
    """
    Set an attribute, routing declared fields through Pydantic.

    Names that are not declared model fields are stored directly on the
    instance via ``object.__setattr__``.
    """
    if name not in frameworkLattice.model_fields:
        object.__setattr__(self, name, value)
        return None
    # Let Pydantic set known fields normally
    return super().__setattr__(name, value)
[docs]
def insert_element(self, index: int, element) -> None:
    """
    Insert an element at a specific index of the lattice's element sequence.

    The previous implementation popped items from the dictionary returned by
    :attr:`elements`. That dictionary is rebuilt on every access, and
    ``dict.popitem`` accepts no argument, so the call could never take effect.
    The element is now registered in ``allElements`` and ``elementObjects``,
    which are what :attr:`elements` is built from.

    Parameters
    ----------
    index: int
        Position within :attr:`elements` (the start-to-end window) at which to
        insert the element. ``0`` places it directly before the first element
        of the window, ``len(self.elements)`` directly after the last one.
    element: :class:`SimulationFramework.Framework_objects.frameworkElement`
        The element to insert; it is registered under ``element.objectname``.

    Raises
    ------
    ValueError
        If an element with the same name is already registered.
    IndexError
        If ``index`` lies outside ``0 .. len(self.elements)``.
    """
    new_name = element.objectname
    if new_name in self.elementObjects:
        raise ValueError(f"Element {new_name} already exists in the lattice")
    window = list(self.elements)
    if not 0 <= index <= len(window):
        raise IndexError(
            f"index {index} is outside the range 0..{len(window)} of the lattice elements"
        )
    # Translate the window-relative index into a position in the full ordering
    if index < len(window):
        position = self.allElements.index(window[index])
    elif window:
        position = self.allElements.index(window[-1]) + 1
    else:
        position = len(self.allElements)
    ordered_names = list(self.allElements)
    ordered_names.insert(position, new_name)
    objects = dict(self.elementObjects)
    objects[new_name] = element
    # Rebuild the mapping in place so its key order follows the new sequence
    self.elementObjects.clear()
    self.elementObjects.update(
        {name: objects[name] for name in ordered_names if name in objects}
    )
    self.allElements = ordered_names
    # Elements are also exposed as attributes of the lattice
    setattr(self, new_name, element)
@property
def csr_enable(self) -> bool:
    """
    CSR enable flag; setting it also updates ``csrDrifts``.
    """
    enabled = self._csr_enable
    return enabled

@csr_enable.setter
def csr_enable(self, csr) -> None:
    # The drift flag is written first, then the private flag
    for attribute in ("csrDrifts", "_csr_enable"):
        setattr(self, attribute, csr)
[docs]
def get_prefix(self) -> str:
    """
    Get the prefix from the input file block.

    A missing ``input`` block or a missing ``prefix`` entry is created on the
    fly (as ``{}`` and ``""`` respectively). An ``input`` block that is present
    but ``None`` is treated like a missing one, matching :func:`set_prefix`.

    Returns
    -------
    str
        The prefix string used in the input file block.
    """
    input_block = self.file_block.get("input")
    if input_block is None:
        input_block = self.file_block["input"] = {}
    return input_block.setdefault("prefix", "")
[docs]
def set_prefix(self, prefix: str) -> None:
    """
    Store the prefix in the input file block.

    ``file_block`` and its ``input`` entry are created as empty dictionaries
    when they are missing or None.

    Parameters
    ----------
    prefix: str
        The prefix string to store under ``file_block["input"]["prefix"]``.
    """
    if getattr(self, "file_block", None) is None:
        self.file_block = {}
    if self.file_block.get("input") is None:
        self.file_block["input"] = {}
    self.file_block["input"]["prefix"] = prefix
@computed_field
@property
def prefix(self) -> str:
    # Read access goes through get_prefix()
    current_prefix = self.get_prefix()
    return current_prefix

@prefix.setter
def prefix(self, prefix: str) -> None:
    # Write access goes through set_prefix()
    new_prefix = prefix
    self.set_prefix(new_prefix)
[docs]
def update_groups(self) -> None:
    """
    Apply the stored group settings to the matching group objects.

    Every name in ``groupSettings`` that also exists in ``groupObjects`` is
    exposed as an attribute of the lattice; when its settings are not None they
    are passed to the group's ``update`` method as keyword arguments.
    """
    for group_name in list(self.groupSettings.keys()):
        if group_name not in self.groupObjects:
            continue
        setattr(self, group_name, self.groupObjects[group_name])
        settings = self.groupSettings[group_name]
        if settings is None:
            continue
        self.groupObjects[group_name].update(**settings)
[docs]
def getElement(self, element: str, param: str = None) -> dict | frameworkElement:
    """
    Look up an element or group by name, optionally returning one of its attributes.

    Elements (names present in :attr:`elements`) take precedence over groups.

    Parameters
    ----------
    element: str
        Name of the element or group.
    param: str, optional
        Attribute to read from the object; it is lower-cased before the lookup.
        If None, the object itself is returned.

    Returns
    -------
    dict | :class:`~SimulationFramework.Framework_objects.frameworkElement`
        The object or the requested attribute. An empty dict is returned, after
        a warning, when the name is unknown.
    """
    if element in self.elements:
        target = self.elementObjects[element]
    elif element in list(self.groupObjects.keys()):
        target = self.groupObjects[element]
    else:
        warn(f"WARNING: Element {element} does not exist")
        return {}
    if param is None:
        return target
    return getattr(target, param.lower())
[docs]
def getElementType(
self,
typ: list | tuple | str,
param: list | tuple | str = None,
) -> list | tuple | zip:
"""
Get all elements of a specific type or types from the lattice.
Parameters
----------
typ: list, tuple, or str
The type or types of elements to retrieve.
If a list or tuple is provided, it retrieves elements of all specified types.
param: list, tuple, or str, optional
The specific parameter to retrieve from each element.
Returns
-------
list | tuple | zip
A list or tuple of elements of the specified type(s), or a zip object if multiple parameters are specified.
If `param` is provided, it returns the specified parameter for each element.
"""
if isinstance(typ, (list, tuple)):
return [self.getElementType(t, param=param) for t in typ]
if isinstance(param, (list, tuple)):
return zip(*[self.getElementType(typ, param=p) for p in param])
return [
self.elements[element] if param is None else self.elements[element][param]
for element in list(self.elements.keys())
if self.elements[element].objecttype.lower() == typ.lower()
]
[docs]
def setElementType(
self, typ: list | tuple | str, setting: str, values: list | tuple | Any
) -> None:
"""
Set a specific setting for all elements of a specific type or types in the lattice.
Parameters
----------
typ: list, tuple, or str
The type or types of elements to set the setting for.
setting: str
The setting to be updated for the elements. This can be a single setting or a list of settings.
values: list, tuple, or Any
The values to set for the specified setting.
Raises
------
ValueError
If the number of elements of the specified type does not match the number of values provided.
"""
elems = self.getElementType(typ)
if len(elems) == len(values):
for e, v in zip(elems, values):
e[setting] = v
else:
raise ValueError
@property
def quadrupoles(self) -> list:
    """
    All quadrupole elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"quadrupole"``.
    """
    element_type = "quadrupole"
    return self.getElementType(element_type)
@property
def cavities(self) -> list:
    """
    All cavity elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"cavity"``.
    """
    element_type = "cavity"
    return self.getElementType(element_type)
@property
def solenoids(self) -> list:
    """
    All solenoid elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"solenoid"``.
    """
    element_type = "solenoid"
    return self.getElementType(element_type)
@property
def dipoles(self) -> list:
    """
    All dipole elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"dipole"``.
    """
    element_type = "dipole"
    return self.getElementType(element_type)
@property
def kickers(self) -> list:
    """
    All kicker elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"kicker"``.
    """
    element_type = "kicker"
    return self.getElementType(element_type)
@property
def dipoles_and_kickers(self) -> list:
    """
    All dipole and kicker elements, ordered by the third component of ``position_end``.

    Returns
    -------
    list
        Dipoles and kickers merged into one sorted list.
    """

    def end_z(elem):
        return elem.position_end[2]

    combined = self.getElementType("dipole") + self.getElementType("kicker")
    combined.sort(key=end_z)
    return combined
@property
def wakefields(self) -> list:
    """
    All wakefield elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"wakefield"``.
    """
    element_type = "wakefield"
    return self.getElementType(element_type)
@property
def wakefields_and_cavity_wakefields(self) -> list:
    """
    Cavities that carry a wakefield setting, followed by all wakefield elements.

    A cavity is included when any of ``longitudinal_wakefield``,
    ``transverse_wakefield`` or ``wakefield_definition`` is a ``field`` instance
    or differs from the empty string.

    Returns
    -------
    list
        The selected cavities followed by the wakefield elements.
    """
    wake_attributes = (
        "longitudinal_wakefield",
        "transverse_wakefield",
        "wakefield_definition",
    )

    def carries_wake(cav) -> bool:
        for attribute in wake_attributes:
            if isinstance(getattr(cav, attribute), field) or getattr(cav, attribute) != "":
                return True
        return False

    selected_cavities = [
        cav for cav in self.getElementType("cavity") if carries_wake(cav)
    ]
    return selected_cavities + self.getElementType("wakefield")
@property
def screens(self) -> list:
    """
    All screen elements of the lattice.

    Returns
    -------
    list
        Result of :func:`getElementType` for the type ``"screen"``.
    """
    element_type = "screen"
    return self.getElementType(element_type)
@property
def screens_and_bpms(self) -> list:
    """
    All screen and BPM elements, ordered by the third component of ``position_start``.

    Returns
    -------
    list
        Screens and beam position monitors merged into one sorted list.
    """

    def start_z(elem):
        return elem.position_start[2]

    combined = self.getElementType("screen")
    combined = combined + self.getElementType("beam_position_monitor")
    combined.sort(key=start_z)
    return combined
@property
def screens_and_markers_and_bpms(self) -> list:
    """
    All screen, marker and BPM elements, ordered by the third component of ``position_start``.

    Returns
    -------
    list
        Screens, markers and beam position monitors merged into one sorted list.
    """

    def start_z(elem):
        return elem.position_start[2]

    combined = self.getElementType("screen")
    combined = combined + self.getElementType("marker")
    combined = combined + self.getElementType("beam_position_monitor")
    combined.sort(key=start_z)
    return combined
@property
def apertures(self) -> list:
    """
    All aperture and collimator elements, ordered by the third component of ``position_start``.

    Returns
    -------
    list
        Apertures and collimators merged into one sorted list.
    """

    def start_z(elem):
        return elem.position_start[2]

    combined = self.getElementType("aperture") + self.getElementType("collimator")
    combined.sort(key=start_z)
    return combined
@property
def lines(self) -> list:
    """
    Names of all lines in the lattice.

    Returns
    -------
    list
        The keys of ``lineObjects``.
    """
    line_names = [line_name for line_name in self.lineObjects.keys()]
    return line_names
@property
def start(self) -> str:
    """
    Name of the starting element of the lattice.

    The name is resolved from ``file_block["output"]``:

    * ``start_element``, when present, is returned as-is;
    * otherwise, when ``zstart`` is present, the first element whose
      ``position_start[2]`` lies within 1e-2 of it is returned;
    * otherwise, or when no element matches, the first key of
      ``elementObjects`` is returned.

    The fallbacks used to return the element *object*, while
    :attr:`startObject` and :attr:`elements` both use this value as a *name*.
    A name is now returned in every branch.

    Returns
    -------
    str
        The name of the starting element of the lattice.
    """
    output = self.file_block["output"]
    if "start_element" in output:
        return output["start_element"]
    names = list(self.elementObjects.keys())
    if "zstart" in output:
        for name in names:
            if np.isclose(
                self.elementObjects[name].position_start[2],
                output["zstart"],
                atol=1e-2,
            ):
                return name
    return names[0]
@property
def startObject(self) -> frameworkElement:
    """
    The starting element object of the lattice.

    Looks up :attr:`start` in ``elementObjects``.

    Returns
    -------
    frameworkElement
        The starting element of the lattice.
    """
    start_name = self.start
    return self.elementObjects[start_name]
@property
def end(self) -> str:
    """
    Name of the final element of the lattice.

    The name is resolved from ``file_block["output"]``:

    * ``end_element``, when present, is returned as-is;
    * otherwise, when ``zstop`` is present, the last element whose
      ``position_end[2]`` lies within 1e-2 of it is returned; if none is that
      close, the first element ending beyond ``zstop`` is used;
    * otherwise, or when no element qualifies, the last key of
      ``elementObjects`` is returned.

    The old fallback read ``self.elements[0]``; :attr:`elements` itself reads
    this property, so that path recursed without end. An empty candidate list
    also raised ``IndexError``. Both cases now yield the last element name.

    Returns
    -------
    str
        The name of the final element of the lattice.
    """
    output = self.file_block["output"]
    if "end_element" in output:
        return output["end_element"]
    names = list(self.elementObjects.keys())
    if "zstop" in output:
        zstop = output["zstop"]
        candidates = []
        for name in names:
            z_end = self.elementObjects[name].position_end[2]
            if np.isclose(z_end, zstop, atol=1e-2):
                candidates.append(name)
            elif z_end > zstop and len(candidates) == 0:
                # First element reaching past zstop, kept unless an exact match follows
                candidates.append(name)
        if candidates:
            return candidates[-1]
    return names[-1]
@property
def endObject(self) -> frameworkElement:
    """
    The final element object of the lattice.

    Looks up :attr:`end` in ``elementObjects``.

    Returns
    -------
    frameworkElement
        The final element of the lattice.
    """
    end_name = self.end
    return self.elementObjects[end_name]
@property
def elements(self) -> dict:
    """
    The elements between :attr:`start` and :attr:`end`, inclusive.

    A new dictionary is built on every access.

    Returns
    -------
    dict
        Element names mapped to element objects, in the order of ``allElements``.
    """
    first = self.allElements.index(self.start)
    last = self.allElements.index(self.end)
    window = self.allElements[first : last + 1]
    return {name: self.elementObjects[name] for name in window}
[docs]
def run(self) -> None:
    """
    Run the simulation code for this lattice.

    The command is the executable entry for ``self.code`` followed by the
    lattice name. It is executed inside the master sub-directory, with standard
    output redirected to ``<name>.log`` in that directory.
    """
    command = self.executables[self.code] + [self.name]
    log_path = os.path.relpath(
        self.global_parameters["master_subdir"] + "/" + self.name + ".log",
        ".",
    )
    with open(log_path, "w") as log_file:
        subprocess.call(
            command, stdout=log_file, cwd=self.global_parameters["master_subdir"]
        )
[docs]
def getInitialTwiss(self) -> dict:
    """
    Read the initial Twiss parameters from ``file_block["input"]["twiss"]``.

    Returns
    -------
    dict
        ``{"horizontal": {...}, "vertical": {...}}``, each holding ``alpha``,
        ``beta`` and ``nEmit``. A value is False when the matching key
        (``alpha_x``, ``beta_x``, ``nemit_x`` and their ``_y`` counterparts) is
        absent, or when there is no non-empty ``twiss`` block at all.
    """
    twiss = {}
    if "input" in self.file_block and "twiss" in self.file_block["input"]:
        stored = self.file_block["input"]["twiss"]
        if stored:
            twiss = stored

    def read(key):
        return twiss[key] if key in twiss else False

    result = {}
    for plane, axis in (("horizontal", "x"), ("vertical", "y")):
        result[plane] = {
            "alpha": read("alpha_" + axis),
            "beta": read("beta_" + axis),
            "nEmit": read("nemit_" + axis),
        }
    return result
[docs]
def preProcess(self) -> None:
    """
    Prepare the lattice for a simulation run.

    Stores the result of :func:`getInitialTwiss` in ``initial_twiss``.

    Returns
    -------
    None
    """
    twiss = self.getInitialTwiss()
    self.initial_twiss = twiss
[docs]
def postProcess(self):
    """
    Hook called after a simulation run; this implementation does nothing.
    """
    return None
def __repr__(self):
return self.elements
def __str__(self):
str = self.name + " = ("
for e in self.elements:
if len((str + e).splitlines()[-1]) > 60:
str += "&\n"
str += e + ", "
return str + ")"
[docs]
def createDrifts(
    self, drift_elements: tuple = ("screen", "beam_position_monitor")
) -> dict:
    """
    Insert drifts into a sequence of 'elements'.

    Every element of :attr:`elements` that is not a subelement is kept. Two kinds
    of drift are added around them:

    * An element whose type is in ``drift_elements`` and whose half length
      (rounded to 6 decimals) is positive is replaced by ``<name>-drift-01``,
      a zero-length deep copy of the element, and ``<name>-drift-02``.
    * The gap between the end of one element and the start of the next is
      filled by a drift named ``<objectname>_DRIFT_NN``.

    Parameters
    ----------
    drift_elements: tuple, optional
        A tuple of element types which are split into drift / element / drift.
        Default is ("screen", "beam_position_monitor").

    Returns
    -------
    dict
        The original elements and the new drifts, keyed by name, in insertion order.

    Raises
    ------
    Exception
        If a gap has negative length or points backwards and
        ``allow_negative_drifts`` is False.
    """
    positions = []
    originalelements = dict()
    elementno = 0
    newelements = dict()
    for name in list(self.elements.keys()):
        if not self.elements[name].subelement:
            originalelements[name] = self.elements[name]
            pos = np.array(self.elementObjects[name].position_start)
            # If element is a cavity, we need to offset the cavity by the coupling cell length
            # to make it consistent with ASTRA
            # NOTE(review): if coupling_cell_length is a scalar, this in-place add shifts
            # every component of the start position, not only the last one — confirm.
            if originalelements[name].objecttype == "cavity" and hasattr(
                originalelements[name], "coupling_cell_length"
            ):
                pos += originalelements[name].coupling_cell_length
                # print('Adding coupling_cell_length of ', originalelements[name].coupling_cell_length,'to the start position')
            positions.append(pos)
            positions.append(self.elementObjects[name].position_end)
    # positions is [start1, end1, start2, end2, ...]. Dropping the first entry and
    # repeating the last one re-pairs it as (end1, start2), ..., (endN, endN),
    # presumably split into consecutive pairs by `chunks` — verify against that helper.
    positions = positions[1:]
    positions.append(positions[-1])
    driftdata = list(
        zip(iter(list(originalelements.items())), list(chunks(positions, 2)))
    )
    lscbins = self.lsc_bins if self.lscDrifts is True else 0
    csr = 1 if self.csrDrifts is True else 0
    lsc = 1 if self.lscDrifts is True else 0
    drifttype = lscdrift if self.csrDrifts or self.lscDrifts else edrift
    if drifttype == lscdrift:
        objtype = "lscdrift"
    else:
        objtype = "edrift"
    # e is (name, element); d is the (end of this element, start of the next) pair
    for e, d in driftdata:
        if (e[1].objecttype in drift_elements) and round(e[1].length / 2, 6) > 0:
            name = e[0] + "-drift-01"
            driftparams = {
                "length": round(e[1].length / 2, 6),
                "csr_enable": csr,
                "lsc_enable": lsc,
                "use_stupakov": 1,
                "csrdz": 0.01,
                "lsc_bins": lscbins,
                "lsc_high_frequency_cutoff_start": self.lsc_high_frequency_cutoff_start,
                "lsc_high_frequency_cutoff_end": self.lsc_high_frequency_cutoff_end,
                "lsc_low_frequency_cutoff_start": self.lsc_low_frequency_cutoff_start,
                "lsc_low_frequency_cutoff_end": self.lsc_low_frequency_cutoff_end,
            }
            newdrift = drifttype(
                objectname=name,
                objecttype=objtype,
                global_parameters=self.global_parameters,
                **driftparams,
            )
            newelements[name] = newdrift
            # The element itself is kept as a zero-length copy between the two half drifts
            new_bpm_screen = deepcopy(e[1])
            new_bpm_screen.length = 0
            newelements[e[0]] = new_bpm_screen
            name = e[0] + "-drift-02"
            newdrift = drifttype(
                objectname=name,
                objecttype=objtype,
                global_parameters=self.global_parameters,
                **driftparams,
            )
            newelements[name] = newdrift
        else:
            # print('NOT Drift Element', e[1]["objecttype"], round(e[1]["length"] / 2, 6))
            newelements[e[0]] = e[1]
        # From the first dipole onwards the drift class is re-selected and stays in
        # effect for all later iterations.
        # NOTE(review): objtype is not updated here, so a csrdrift instance is still
        # created with the objecttype string chosen before the loop — confirm intended.
        if e[1].objecttype == "dipole":
            drifttype = (
                csrdrift
                if self.csrDrifts
                else lscdrift if self.lscDrifts else edrift
            )
        if len(d) > 1:
            x1, y1, z1 = d[0]
            x2, y2, z2 = d[1]
            try:
                length = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)
                # Equals -(z2 - z1) when `dot` is a standard dot product (helper not
                # shown here), so it is positive when the next element starts before
                # this one ends.
                vector = dot((d[1] - d[0]), [0, 0, -1])
            except Exception as exc:
                warn(f"Element with error = {e[0]}")
                warn(d)
                raise exc
            if self.allow_negative_drifts or (
                round(length, 6) > 0 and vector < 1e-6
            ):
                elementno += 1
                name = self.objectname + "_DRIFT_" + str(elementno).zfill(2)
                middle = [(a + b) / 2.0 for a, b in zip(d[0], d[1])]
                newdrift = drifttype(
                    objectname=name,
                    objecttype=objtype,
                    global_parameters=self.global_parameters,
                    **{
                        "length": round(length, 6),
                        # "position_start": list(d[0]),
                        # "position_end": list(d[1]),
                        "centre": middle,
                        "csr_enable": csr,
                        "lsc_enable": lsc,
                        "use_stupakov": 1,
                        "csrdz": 0.01,
                        "lsc_bins": lscbins,
                        "lsc_high_frequency_cutoff_start": self.lsc_high_frequency_cutoff_start,
                        "lsc_high_frequency_cutoff_end": self.lsc_high_frequency_cutoff_end,
                        "lsc_low_frequency_cutoff_start": self.lsc_low_frequency_cutoff_start,
                        "lsc_low_frequency_cutoff_end": self.lsc_low_frequency_cutoff_end,
                    },
                )
                newelements[name] = newdrift
            elif length < 0 or vector > 1e-6:
                # A zero-length gap falls through both branches and adds no drift
                raise Exception(
                    "Lattice has negative drifts!",
                    self.allow_negative_drifts,
                    e[0],
                    e[1],
                    length,
                )
    return newelements
[docs]
def getSValues(
self, as_dict: bool = False, at_entrance: bool = False
) -> list | dict:
"""
Get the S values for the elements in the lattice.
This method calculates the cumulative length of the elements in the lattice,
starting from the entrance or the first element, depending on the `at_entrance` parameter.
It returns a list or dict of S values, which represent the positions of the elements along the lattice.
Parameters
----------
as_dict: bool, optional
If True, returns a dictionary with element names as keys and their S values as values.
at_entrance: bool, optional
If True, calculates S values starting from the entrance of the lattice.
If False, calculates S values starting from the first element.
Returns
-------
list | dict
A list or dictionary of S values for the elements in the lattice.
If `as_dict` is True, returns a dictionary with element names as keys and their S values as values.
If `as_dict` is False, returns a list of S values.
"""
elems = self.createDrifts()
s = [0]
for e in list(elems.values()):
s.append(s[-1] + e.length)
s = s[:-1] if at_entrance else s[1:]
if as_dict:
return dict(zip([e.objectname for e in elems.values()], s))
return list(s)
[docs]
def getZValues(self, drifts: bool = True, as_dict: bool = False) -> list | dict:
"""
Get the Z values for the elements in the lattice.
This method calculates the cumulative length of the elements in the lattice,
starting from the entrance or the first element, depending on the `at_entrance` parameter.
It returns a list or dict of S values, which represent the positions of the elements along the lattice.
Parameters
----------
drifts: bool, optional
If True, includes drift elements in the calculation.
If False, only considers the main elements in the lattice.
as_dict: bool, optional
If True, returns a dictionary with element names as keys and their Z values as values.
Returns
-------
list | dict
A list or dictionary of Z values for the elements in the lattice.
If `as_dict` is True, returns a dictionary with element names as keys and their Z values as values.
If `as_dict` is False, returns a list of Z values.
"""
if drifts:
elems = self.createDrifts()
else:
elems = self.elements
if as_dict:
return {e.objectname: [e.start[2], e.end[2]] for e in elems.values()}
return [[e.start[2], e.end[2]] for e in elems.values()]
[docs]
def getNames(self, drifts: bool = True) -> list:
    """
    Return the ``objectname`` of every element of the lattice, in lattice order.

    Parameters
    ----------
    drifts: bool, optional
        If True, use the lattice expanded by :func:`createDrifts`, so generated
        drifts are listed too. If False, use only ``self.elements``.

    Returns
    -------
    list
        The element names.
    """
    source = self.createDrifts() if drifts else self.elements
    names = []
    for member in source.values():
        names.append(member.objectname)
    return names
[docs]
def getElems(self, drifts: bool = True, as_dict: bool = False) -> list | dict:
"""
Get the elements in the lattice.
Parameters
----------
drifts: bool, optional
If True, includes drift elements in the list of elements.
as_dict: bool, optional
If True, returns a dictionary with element names as keys and their corresponding element objects as values.
Returns
-------
list | dict
A list or dictionary of elements in the lattice.
"""
if drifts:
elems = self.createDrifts()
else:
elems = self.elements
if as_dict:
return {e.objectname: e for e in list(elems.values())}
return [e for e in list(elems.values())]
[docs]
def getSNames(self) -> list:
    """
    Pair every element name with its S value.

    Returns
    -------
    list
        A list of ``(name, s)`` tuples built from :func:`getNames` and
        :func:`getSValues`, both called with their default arguments.
    """
    positions = self.getSValues()
    labels = self.getNames()
    return [pair for pair in zip(labels, positions)]
[docs]
def getSNamesElems(self) -> tuple:
    """
    Collect the names, element objects and S values of the lattice.

    Returns
    -------
    tuple
        A 3-tuple ``(names, elems, s)`` holding the results of :func:`getNames`,
        :func:`getElems` and :func:`getSValues`, each called with default arguments.
    """
    positions = self.getSValues()
    labels = self.getNames()
    members = self.getElems()
    result = (labels, members, positions)
    return result
[docs]
def getZNamesElems(self) -> tuple:
    """
    Collect the names, element objects and Z values of the lattice.

    Returns
    -------
    tuple
        A 3-tuple ``(names, elems, z)`` holding the results of :func:`getNames`,
        :func:`getElems` and :func:`getZValues`, each called with default arguments.
    """
    extents = self.getZValues()
    labels = self.getNames()
    members = self.getElems()
    result = (labels, members, extents)
    return result
[docs]
def findS(self, elem) -> list:
    """
    Look up the S value(s) recorded for a named element.

    Parameters
    ----------
    elem: str
        The name of the element to look for.

    Returns
    -------
    list
        Every ``(name, s)`` tuple from :func:`getSNames` whose name equals `elem`.
        An empty list is returned when `elem` is not in ``self.allElements``.
    """
    if elem not in self.allElements:
        return []
    matches = []
    for entry in self.getSNames():
        if entry[0] == elem:
            matches.append(entry)
    return matches
[docs]
def updateRunSettings(self, runSettings: runSetup) -> None:
    """
    Replace the run settings stored on the lattice.

    Parameters
    ----------
    runSettings: runSetup
        The new run settings; stored as ``self.runSettings``.

    Raises
    ------
    TypeError
        If `runSettings` is not a `runSetup` instance.
    """
    # Guard clause: reject anything that is not a runSetup before storing it.
    if not isinstance(runSettings, runSetup):
        raise TypeError(
            "runSettings argument passed to frameworkLattice.updateRunSettings is not a runSetup instance"
        )
    self.runSettings = runSettings
[docs]
class frameworkCommand(frameworkObject):
    """
    A simulation-code command (as opposed to a lattice element), able to render
    itself as text for the setup files of the supported codes.
    """

    def model_post_init(self, __context):
        """
        Validate the command type, then run the parent post-init hook.

        Raises
        ------
        NameError
            If ``self.objecttype`` is not a key of ``commandkeywords``.
        """
        recognised = self.objecttype in commandkeywords
        if not recognised:
            raise NameError("Command '%s' does not exist" % self.objecttype)
        super().model_post_init(__context)

    def write_Elegant(self) -> str:
        """
        Render the command as an ELEGANT namelist.

        A keyword listed for this command type is written when its lower-cased form
        is in ``self.allowedkeywords``, it is not ``objectname``/``objecttype``, and
        the object has an attribute of that name.

        Returns
        -------
        str
            The text ``&<objecttype>`` ... ``&end``, with one tab-indented
            ``key = value`` line per written keyword.
        """
        reserved = ("objectname", "objecttype")
        pieces = ["&" + self.objecttype + "\n"]
        for key in commandkeywords[self.objecttype]:
            lowered = key.lower()
            if lowered not in self.allowedkeywords:
                continue
            if key in reserved:
                continue
            if not hasattr(self, key):
                continue
            pieces.append("\t" + key + " = " + str(getattr(self, lowered)) + "\n")
        pieces.append("&end\n")
        return "".join(pieces)

    def write_MAD8(self) -> str:
        """
        Render the command as a MAD8 statement.

        # TODO deprecated?

        A keyword listed for this command type is written when its lower-cased form
        is a key of ``self.objectproperties`` with a value other than None, and it is
        not ``name``/``type``.

        Returns
        -------
        str
            ``<objecttype>`` followed by ``,key=value`` items and a closing ``;``.
            A ``,&`` continuation plus newline is inserted before an item that would
            push the current line beyond 79 characters.
        """
        text = self.objecttype
        for key in commandkeywords[self.objecttype]:
            lowered = key.lower()
            if lowered not in self.objectproperties:
                continue
            if key == "name" or key == "type":
                continue
            setting = self.objectproperties[lowered]
            if setting is None:
                continue
            clause = "," + key + "=" + str(setting)
            # Only the last physical line counts towards the 79-character limit.
            current_line = (text + clause).splitlines()[-1]
            if len(current_line) > 79:
                text += ",&\n"
            text += clause
        return text + ";\n"
[docs]
class frameworkGroup(object):
    """
    Class defining a framework group, which is used to group together elements to perform coordinated
    actions on them.
    """

    def __init__(self, name, elementObjects, type, elements, **kwargs):
        """
        Parameters
        ----------
        name:
            Name of the group; stored as ``objectname``.
        elementObjects:
            Object exposing ``elementObjects`` and ``groupObjects`` mappings, which are
            stored as ``allElementObjects`` and ``allGroupObjects``.
        type:
            Group type label; stored as ``type``.
        elements:
            Names of the members of the group, used as keys into the mappings above.
        **kwargs:
            Accepted for signature compatibility with subclasses; unused here.
        """
        super(frameworkGroup, self).__init__()
        self.objectname = name
        self.type = type
        self.elements = elements
        self.allElementObjects = elementObjects.elementObjects
        self.allGroupObjects = elementObjects.groupObjects

    def update(self, **kwargs):
        """Hook for subclasses; the base group has nothing to update."""
        pass

    def get_Parameter(self, p: str) -> Any:
        """
        Get a specific parameter associated with the group, i.e. bunch compressor angle

        The group's own attribute (or property) named `p` takes precedence. If reading
        it fails, the value is read from the first member of the group instead, looked
        up in the group objects first and in the element objects otherwise.

        Parameters
        ----------
        p: str
            A parameter associated with the group

        Returns
        -------
        Any
            The parameter, if defined.
        """
        try:
            # The previous `isinstance(type(...), p)` check always raised TypeError
            # (a str is not a valid classinfo), so the group's own attributes were
            # never returned.
            return getattr(self, p)
        except Exception:
            # Deliberately broad: a failing property on the group must still fall
            # back to the first member, as it did before.
            first = self.elements[0]
            if first in self.allGroupObjects:
                return getattr(self.allGroupObjects[first], p)
            return getattr(self.allElementObjects[first], p)

    def change_Parameter(self, p: Any, v: Any) -> None:
        """
        Set a parameter on the group or, failing that, on all elements in the group.

        If the group itself already has an attribute `p`, it is set there (and
        ``set_angle`` is additionally called when `p` is ``"angle"``). If any of that
        raises, `p` is instead set on every member found in ``allElementObjects``.

        Parameters
        ----------
        p: str
            The parameter to be set
        v: Any
            The value to be set.
        """
        try:
            # Raises if the group has no such attribute, selecting the fallback below.
            getattr(self, p)
            setattr(self, p, v)
            if p == "angle":
                self.set_angle(v)
        except Exception:
            for e in self.elements:
                setattr(self.allElementObjects[e], p, v)

    def __repr__(self):
        return str([self.allElementObjects[e].objectname for e in self.elements])

    def __str__(self):
        return str([self.allElementObjects[e].objectname for e in self.elements])

    def __getitem__(self, key):
        """Dictionary-style read access; same as :func:`get_Parameter`."""
        return self.get_Parameter(key)

    def __setitem__(self, key, value):
        """Dictionary-style write access; same as :func:`change_Parameter`."""
        return self.change_Parameter(key, value)
[docs]
class element_group(frameworkGroup):
    """
    A plain group of elements that can be read from and written to together.
    """

    def __init__(self, name, elementObjects, type, elements, **kwargs):
        super().__init__(name, elementObjects, type, elements, **kwargs)

    def __str__(self):
        # Unlike the base class, show the member objects themselves rather than
        # just their names.
        members = []
        for key in self.elements:
            members.append(self.allElementObjects[key])
        return str(members)
[docs]
class r56_group(frameworkGroup):
    """
    Class defining a group of elements with a total R56.

    Each entry of ``elements`` is paired with a parameter name in ``keys`` and an
    expression in ``ratios``; setting :attr:`r56` evaluates the expressions and pushes
    the results to the paired elements.
    """

    def __init__(self, name, elementObjects, type, elements, ratios, keys, **kwargs):
        super().__init__(name, elementObjects, type, elements, **kwargs)
        self.ratios = ratios
        self.keys = keys
        self._r56 = None

    def __str__(self):
        return str({e: k for e, k in zip(self.elements, self.keys)})

    def get_Parameter(self, p: str) -> Any:
        """
        Get a parameter associated with the group.

        Parameters
        ----------
        p: str
            The parameter to be retrieved. ``"r56"`` returns the group's R56; anything
            else is delegated to the base class.

        Returns
        -------
        Any
            The parameter.
        """
        if str(p) == "r56":
            return self.r56
        return super().get_Parameter(p)

    @property
    def r56(self) -> float:
        """
        Get the R56 of the group of elements

        Returns
        -------
        float
            The R56 parameter (None until it has been set)
        """
        return self._r56

    @r56.setter
    def r56(self, r56: float) -> None:
        """
        Set the R56 of the group of elements

        Every expression in ``ratios`` is evaluated by a ``MathParser`` that knows the
        variable ``r56``; the results are then applied to the paired element/key.

        Parameters
        ----------
        r56: float
            The R56 to be set
        """
        self._r56 = r56
        data = {"r56": self._r56}
        parser = MathParser(data)
        # Evaluate every expression before touching any element.
        values = [parser.parse(e) for e in self.ratios]
        for e, k, v in zip(self.elements, self.keys, values):
            self.updateElements(e, k, v)

    def updateElements(self, element: str | list | tuple, key: str, value: Any) -> None:
        """
        Update one or more elements in the group.

        Parameters
        ----------
        element: str, list or tuple
            The element(s) to be updated. A list or tuple is expanded recursively.
        key: str
            The parameter in the element or group of elements to be changed
        value: Any
            The value to which the parameter should be set
        """
        if isinstance(element, (list, tuple)):
            # Recurse over the members of the list that was passed in. Iterating
            # self.elements here (as before) re-entered with the same list and never
            # terminated whenever the group held a list/tuple entry.
            for member in element:
                self.updateElements(member, key, value)
            return
        if element in self.allElementObjects:
            self.allElementObjects[element].change_Parameter(key, value)
        if element in self.allGroupObjects:
            self.allGroupObjects[element].change_Parameter(key, value)
[docs]
class chicane(frameworkGroup):
    """
    Class defining a 4-dipole chicane.

    ``ratios`` holds the factor applied to the requested bending angle for each
    successive dipole of the group.
    """

    def __init__(self, name, elementObjects, type, elements, **kwargs):
        super().__init__(name, elementObjects, type, elements, **kwargs)
        self.ratios = (1, -1, -1, 1)

    def update(self, **kwargs) -> None:
        """
        Update the bending angle and/or dipole width and/or dipole gap of all magnets in the chicane.

        Parameters
        ----------
        **kwargs: Dict
            Parameters to be updated; only ``"dipoleangle"``, ``"width"`` and ``"gap"``
            are acted upon, in that order.
        """
        if "dipoleangle" in kwargs:
            self.set_angle(kwargs["dipoleangle"])
        for setting in ("width", "gap"):
            if setting in kwargs:
                self.change_Parameter(setting, kwargs[setting])
        return None

    @property
    def angle(self) -> float:
        """
        Bending angle of the chicane

        Returns
        -------
        float
            The ``angle`` of the first element of the group, as a float
        """
        members = [self.allElementObjects[key] for key in self.elements]
        leading = members[0]
        return float(leading.angle)

    @angle.setter
    def angle(self, theta: float) -> None:
        """
        Set the bending angle of the chicane; see :func:`~SimulationFramework.Framework_objects.chicane.set_angle`.

        Parameters
        -----------
        theta: float
            Chicane bending angle
        """
        self.set_angle(theta)

    def set_angle(self, a: float) -> None:
        """
        Set the chicane bending angle and re-place the elements downstream of the first dipole.

        All elements lying between the first and last member of the group (in
        ``allElementObjects`` order) are visited. Each group member gets
        ``angle = a * ratios[n]`` for the n-th dipole encountered. Every element after
        the first dipole has its ``centre`` moved to the line leaving the previous
        dipole (x recomputed, y set to 0, z kept) and its ``x_rot`` set to that line's
        angle.

        Parameters
        ----------
        a: float
            The angle to be set
        """
        ordered_keys = list(self.allElementObjects)
        slots = sorted(ordered_keys.index(key) for key in self.elements)
        dipoles = [self.allElementObjects[key] for key in self.elements]
        # Everything from the first to the last dipole, including what lies between.
        span = [
            self.allElementObjects[key]
            for key in ordered_keys[slots[0] : slots[-1] + 1]
        ]
        dipole_number = 0
        ref_pos = None
        ref_angle = None
        for item in span:
            if dipole_number > 0:
                # Place the element on the straight line leaving the previous dipole.
                z_here = item.centre[2]
                adj = z_here - ref_pos[2]
                item.centre = [
                    ref_pos[0] + np.tan(-1.0 * ref_angle) * adj,
                    0,
                    z_here,
                ]
                item.x_rot = ref_angle
            if item in dipoles:
                # A dipole becomes the new reference point and bends the line further.
                ref_pos = item.middle
                item.angle = a * self.ratios[dipole_number]
                ref_angle = item.x_rot + item.angle
                dipole_number += 1

    def __str__(self):
        rows = []
        for key in self.elements:
            rows.append(
                [
                    self.allElementObjects[key].objectname,
                    self.allElementObjects[key].angle,
                    self.allElementObjects[key].global_rotation[2],
                    self.allElementObjects[key].position_start,
                    self.allElementObjects[key].position_end,
                ]
            )
        return str(rows)
[docs]
class s_chicane(chicane):
    """
    An s-type chicane: identical to :class:`chicane` apart from the per-dipole
    bending ratios used by
    :func:`~SimulationFramework.Framework_objects.chicane.set_angle`.
    """

    def __init__(self, name, elementObjects, type, elements, **kwargs):
        super().__init__(name, elementObjects, type, elements, **kwargs)
        # Replace the ratios assigned by the parent constructor.
        self.ratios = (-1, 2, -2, 1)
[docs]
class frameworkCounter(dict):
    """
    Class defining a counter object, used for numbering elements of the same type in ASTRA and CSRTrack

    The counter is a dict mapping an element type to its current count. ``sub`` maps
    an element type onto the type it should be counted as.
    """

    def __init__(self, sub=None):
        """
        Parameters
        ----------
        sub: dict, optional
            Substitution table ``{type: counted_as_type}``. When omitted, each counter
            gets its own empty table (a shared ``{}`` default would leak substitutions
            between instances).
        """
        super(frameworkCounter, self).__init__()
        self.sub = {} if sub is None else sub

    def _resolve(self, typ: str) -> str:
        """Return the type `typ` is counted as, applying ``sub`` when it has an entry."""
        return self.sub[typ] if typ in self.sub else typ

    def counter(self, typ: str) -> int:
        """
        Next count for elements of a given type, without modifying the stored count.

        Parameters
        ----------
        typ: str
            Element type

        Returns
        -------
        int
            The stored count plus one, or 1 if the type has not been counted yet
        """
        typ = self._resolve(typ)
        if typ not in self:
            return 1
        return self[typ] + 1

    def value(self, typ: str) -> int:
        """
        Number of elements of a given type in the lattice.

        Parameters
        ----------
        typ: str
            Element type

        Returns
        -------
        int
            The stored count, or 1 if the type has not been counted yet
        """
        typ = self._resolve(typ)
        if typ not in self:
            return 1
        return self[typ]

    def add(self, typ: str, n: PositiveInt = 1) -> int:
        """
        Add to count of elements of a given type in the lattice.

        Parameters
        ----------
        typ: str
            Element type
        n: PositiveInt, optional
            Add more than one element at a time

        Returns
        -------
        int
            The number of elements of a given type defined so far
        """
        typ = self._resolve(typ)
        if typ not in self:
            self[typ] = n
        else:
            self[typ] += n
        return self[typ]

    def subtract(self, typ: str) -> int:
        """
        Reduce count of elements of a given type in the lattice.

        The count never drops below zero; an unknown type is registered with a count
        of 0.

        Parameters
        ----------
        typ: str
            Element type

        Returns
        -------
        int
            The updated number of elements of a given type defined so far
        """
        typ = self._resolve(typ)
        if typ not in self:
            self[typ] = 0
        else:
            self[typ] = self[typ] - 1 if self[typ] > 0 else 0
        return self[typ]
[docs]
class getGrids(object):
"""
Class defining the appropriate number of space charge bins given the number of particles,
defined as the closest power of 8 to the cube root of the number of particles.
"""
def __init__(self):
self.powersof8 = np.asarray([2**j for j in range(1, 20)])
[docs]
def getGridSizes(self, x: PositiveInt) -> int:
"""
Calculate the 3D space charge grid size given the number of particles, minimum of 4
Parameters
----------
x: PositiveInt
Number of particles
Returns
-------
int
The number of space charge grids
"""
self.x = abs(x)
self.cuberoot = int(round(self.x ** (1.0 / 3)))
return max([4, self.find_nearest(self.powersof8, self.cuberoot)])
[docs]
def find_nearest(self, array: np.ndarray | list, value: int) -> int:
"""
Get the nearest value in an array to the value provided; in this case the array should be a list of
powers of 8.
Parameters
----------
array: np.ndarray or list
Array of values to be checked
value: Value to be found in the array
Returns
-------
int
The closest value in `array` to `value`
"""
self.array = array
self.value = value
self.idx = (np.abs(self.array - self.value)).argmin()
return self.array[self.idx]