"""
Simframe CSRTrack Module
Various objects and functions to handle CSRTrack lattices and commands. See `CSRTrack manual`_ for more details.
.. _CSRTrack manual: https://www.desy.de/xfel-beam/csrtrack/files/CSRtrack_User_Guide_(actual).pdf
Classes:
- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrackLattice`: The CSRTrack lattice object, used for
converting the :class:`~SimulationFramework.Framework_elements.frameworkObject` instances defined in the
:class:`~SimulationFramework.Framework_objects.frameworkLattice` into a string representation of
the lattice suitable for a CSRTrack input file.
- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_element`: Base class for defining a
CSRTrack instance of a :class:`~SimulationFramework.Framework_objects.frameworkElement`.
- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_forces`: Class for defining the CSR
calculation type.
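- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_track_step`: Class for defining the
tracking step.
- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_tracker`: Class for defining the
tracker and its end-time marker.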
- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_particles`: Class for defining the
particle distribution and format.
- :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_monitor`: Class for defining monitors.
"""
import os
import yaml
from ...Framework_objects import (
frameworkLattice,
frameworkElement,
frameworkCounter,
elementkeywords,
)
from ...Framework_elements import screen
from ...FrameworkHelperFunctions import saveFile, expand_substitution
from ...Modules import Beams as rbf
from typing import Dict, List
# Default keyword values for the CSRTrack namelists, keyed by object name; they are
# applied as attributes in csrtrack_element.model_post_init. The YAML file lives next
# to this module. An explicit encoding keeps decoding independent of the platform locale.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "csrtrack_defaults.yaml"),
    "r",
    encoding="utf-8",
) as infile:
    csrtrack_defaults = yaml.safe_load(infile)
[docs]
class csrtrackLattice(frameworkLattice):
    """
    Class for defining the CSRTrack lattice object, used for
    converting the :class:`~SimulationFramework.Framework_elements.frameworkObject` instances defined in the
    :class:`~SimulationFramework.Framework_objects.frameworkLattice` into a string representation of
    the lattice suitable for a CSRTrack input file.
    """

    code: str = "csrtrack"
    """String indicating the lattice object type"""

    particle_definition: str = ""
    """String representing the initial particle distribution"""

    CSRTrackelementObjects: Dict = {}
    """Dictionary of all CSRTrack namelist objects, keyed by namelist name
    (``particles``, ``forces``, ``track_step``, ``tracker``, ``monitor``)"""

    def model_post_init(self, __context):
        super().model_post_init(__context)
        # Populate the `particles` namelist at construction time; `preProcess`
        # reads it from `CSRTrackelementObjects`.
        self.set_particles_filename()

    def endScreen(self, **kwargs) -> screen:
        """
        Create a final screen object for dumping the particle output after tracking.

        The screen is named ``end`` and takes its centre and global rotation from
        `self.endObject`.

        Parameters
        ----------
        **kwargs
            Additional keyword arguments forwarded to the screen constructor.

        Returns
        -------
        :class:`~SimulationFramework.Elements.screen.screen`
        """
        return screen(
            objectname="end",
            objecttype="screen",
            centre=self.endObject.centre,
            global_rotation=self.endObject.global_rotation,
            global_parameters=self.global_parameters,
            **kwargs,
        )

    def set_particles_filename(self) -> None:
        """
        Set up the `particles` entry of `CSRTrackelementObjects` for the initial
        particle distribution, based on `particle_definition` and the
        `global_parameters` of the lattice.

        If `particle_definition` is ``"initial_distribution"`` the particles are read
        from ``laser.astra``. Otherwise they are read from ``<start element>.astra``,
        where ``<start element>`` is the object name of the lattice's start element.
        """
        particles = csrtrack_particles(
            particle_definition=self.particle_definition,
            global_parameters=self.global_parameters,
            format="astra",
        )
        self.CSRTrackelementObjects["particles"] = particles
        if self.particle_definition == "initial_distribution":
            particles.particle_definition = "laser.astra"
            particles.add_default("array", "#file{name=laser.astra}")
        else:
            start_name = self.elementObjects[self.start].objectname
            particles.particle_definition = start_name
            particles.array = f"#file{{name={start_name}.astra}}"

    @property
    def dipoles_screens_and_bpms(self) -> List:
        """
        Get a list of the dipoles, screens and BPMs sorted by their position in the lattice.

        Elements are ordered by the z component of `position_end`.

        Returns
        -------
        List
            A sorted list of :class:`~SimulationFramework.Framework_objects.frameworkElement`
        """
        return sorted(
            self.getElementType("dipole")
            + self.getElementType("screen")
            + self.getElementType("beam_position_monitor"),
            key=lambda x: x.position_end[2],
        )

    def setCSRMode(self) -> None:
        """
        Set up the `forces` entry of `CSRTrackelementObjects` based on the `csr_mode`
        given in the `csr` block of the settings file for this lattice section.

        - ``"3D"`` selects the ``csr_g_to_p`` force type.
        - ``"1D"`` selects the ``projected`` force type.
        - If there is no `csr` block, no `csr_mode`, or an unrecognised `csr_mode`,
          a `forces` namelist with default settings is used. This ensures the input
          file always contains a `forces` entry.
        """
        force_types = {"3D": "csr_g_to_p", "1D": "projected"}
        # Tolerate an empty/None `csr:` block in the settings file.
        csr_block = self.file_block.get("csr") or {}
        force_type = force_types.get(csr_block.get("csr_mode"))
        if force_type is None:
            self.CSRTrackelementObjects["forces"] = csrtrack_forces()
        else:
            self.CSRTrackelementObjects["forces"] = csrtrack_forces(type=force_type)

    def writeElements(self) -> str:
        """
        Write the lattice elements defined in this object into a CSRTrack-compatible format; see
        :attr:`~SimulationFramework.Framework_objects.frameworkLattice.elementObjects`.

        The ``lattice{...}`` block contains the dipoles, screens and BPMs, sorted by
        position, followed by a final ``end`` screen. The CSRTrack namelists in
        `CSRTrackelementObjects` (particles, forces, track_step, tracker, monitor)
        are then rebuilt and appended, each via its `write_CSRTrack` method.

        Returns
        -------
        str
            The lattice represented as a string compatible with CSRTrack
        """
        fulltext = "io_path{logfile = log.txt}\nlattice{\n"
        # `sub` maps BPMs onto the screen counter (presumably so that screens and
        # BPMs share a single marker numbering sequence).
        counter = frameworkCounter(sub={"beam_position_monitor": "screen"})
        # NOTE: per-element online monitors (csrtrack_online_monitor) are disabled.
        for element in self.dipoles_screens_and_bpms:
            fulltext += element.write_CSRTrack(counter.counter(element.objecttype))
            counter.add(element.objecttype)
        # Build the end screen once. Its index is used both for its lattice entry
        # and for the tracker's end-time marker.
        end_screen = self.endScreen()
        end_index = counter.counter(end_screen.objecttype)
        fulltext += end_screen.write_CSRTrack(end_index)
        fulltext += "}\n"

        self.set_particles_filename()
        self.setCSRMode()
        self.CSRTrackelementObjects["track_step"] = csrtrack_track_step()
        self.CSRTrackelementObjects["tracker"] = csrtrack_tracker(
            end_time_marker="screen" + str(end_index) + "a"
        )
        self.CSRTrackelementObjects["monitor"] = csrtrack_monitor(
            name=self.end + ".fmt2", global_parameters=self.global_parameters
        )
        for namelist in self.CSRTrackelementObjects.values():
            fulltext += namelist.write_CSRTrack(n=0)
        return fulltext

    def write(self) -> None:
        """
        Write the CSRTrack input file generated by
        :func:`~SimulationFramework.Codes.CSRTrack.csrtrackLattice.writeElements`
        to ``<master_subdir>/csrtrk.in``.
        """
        code_file = self.global_parameters["master_subdir"] + "/csrtrk.in"
        saveFile(code_file, self.writeElements())

    def preProcess(self) -> None:
        """
        Convert the beam file from the previous lattice section into the ASTRA format
        that CSRTrack reads, see
        :func:`~SimulationFramework.Codes.CSRTrack.csrtrack_particles.hdf5_to_astra`.

        The file-name prefix is obtained from `get_prefix`.
        """
        super().preProcess()
        prefix = self.get_prefix()
        self.CSRTrackelementObjects["particles"].hdf5_to_astra(prefix)

    def postProcess(self) -> None:
        """
        Convert the beam file from the CSRTrack output into HDF5 format, see
        :func:`~SimulationFramework.Codes.CSRTrack.csrtrack_monitor.csrtrack_to_hdf5`.
        """
        super().postProcess()
        self.CSRTrackelementObjects["monitor"].csrtrack_to_hdf5()
[docs]
class csrtrack_element(frameworkElement):
    """
    Base class for CSRTrack elements, including namelists for the lattice file.

    On construction, any defaults listed under this object's `objectname` in
    ``csrtrack_defaults.yaml`` are set as attributes on the instance.
    """

    header: str = ""
    """Header (namelist name) written before the ``{ ... }`` block in the CSRTrack file"""

    def model_post_init(self, __context):
        # NOTE(review): unlike csrtrackLattice.model_post_init, this does not call
        # super().model_post_init(__context); confirm that is intentional.
        # `or {}` tolerates an entry that is present in the YAML but empty.
        for key, value in (csrtrack_defaults.get(self.objectname) or {}).items():
            setattr(self, key, value)

    def CSRTrack_str(self, s: object) -> str:
        """
        Convert a value into its CSRTrack string representation.

        Parameters
        ----------
        s: object
            Value to convert; booleans are mapped to CSRTrack's yes/no.

        Returns
        -------
        str
            'yes' for `True`, 'no' for `False`, otherwise ``str(s)``
        """
        if s is True:
            return "yes"
        if s is False:
            return "no"
        return str(s)

    def _write_CSRTrack(self, n: int = 0, **kwargs) -> str:
        """
        Create the string for this namelist in CSRTrack format.

        Each keyword registered for this object type in `elementkeywords` is written
        lower-cased as ``keyword=value`` on its own line. The attribute value is used
        when it is not ``None``; otherwise the value from `objectdefaults` is used if
        one exists. Keywords with neither are omitted.

        Parameters
        ----------
        n: int
            Not used by this method.
        **kwargs
            Not used by this method.

        Returns
        -------
        str
            CSRTrack-compatible string for this element, ``header{ ... }``.
        """
        lines = [f"{self.header}{{"]
        for keyword in elementkeywords[self.objecttype]["keywords"]:
            name = keyword.lower()
            value = getattr(self, name)
            if value is not None:
                lines.append(f"{name}={self.CSRTrack_str(value)}")
            elif name in self.objectdefaults:
                lines.append(f"{name}={self.CSRTrack_str(self.objectdefaults[name])}")
        lines.append("}")
        return "\n".join(lines) + "\n"
# class csrtrack_online_monitor(csrtrack_element):
#
# def __init__(self, marker="", **kwargs):
# super(csrtrack_online_monitor, self).__init__(
# "online_monitor", "csrtrack_online_monitor", **kwargs
# )
# self.header = "online_monitor"
# self.end_time_marker = marker + "b"
[docs]
class csrtrack_forces(csrtrack_element):
    """
    CSRTrack ``forces`` namelist, which selects the CSR force calculation.

    Any further keyword values come from the ``forces`` entry of
    ``csrtrack_defaults.yaml``, if present; the base class applies them on construction.
    """

    header: str = "forces"
    """Namelist header written to the CSRTrack input file"""
    objectname: str = "forces"
    """Object name (also the key looked up in the CSRTrack defaults)"""
    objecttype: str = "csrtrack_forces"
    """Object type used to look up the namelist keywords"""
[docs]
class csrtrack_track_step(csrtrack_element):
    """
    CSRTrack ``track_step`` namelist, which defines the tracking step.

    Keyword values come from the ``track_step`` entry of
    ``csrtrack_defaults.yaml``, if present; the base class applies them on construction.
    """

    header: str = "track_step"
    """Namelist header written to the CSRTrack input file"""
    objectname: str = "track_step"
    """Object name (also the key looked up in the CSRTrack defaults)"""
    objecttype: str = "csrtrack_track_step"
    """Object type used to look up the namelist keywords"""
[docs]
class csrtrack_tracker(csrtrack_element):
    """
    CSRTrack ``tracker`` namelist.

    Carries the name of the end-time marker and an optional time shift applied to it.
    """

    header: str = "tracker"
    """Namelist header written to the CSRTrack input file"""
    objectname: str = "tracker"
    """Object name (also the key looked up in the CSRTrack defaults)"""
    objecttype: str = "csrtrack_tracker"
    """Object type used to look up the namelist keywords"""
    end_time_marker: str = ""
    """Name of the end-time marker"""
    end_time_shift_c0: str | float = 0.0
    """Time shift applied at the end marker"""
[docs]
class csrtrack_monitor(csrtrack_element):
    """
    CSRTrack ``monitor`` namelist, naming the ``.fmt2`` file that
    :meth:`csrtrack_to_hdf5` reads back after tracking.
    """

    header: str = "monitor"
    """Namelist header written to the CSRTrack input file"""
    objectname: str = "monitor"
    """Object name (also the key looked up in the CSRTrack defaults)"""
    objecttype: str = "csrtrack_monitor"
    """Object type used to look up the namelist keywords"""
    name: str = ""
    """Monitor output file name, relative to `master_subdir`"""

    def csrtrack_to_hdf5(self) -> None:
        """
        Convert the monitor's CSRTrack distribution to HDF5 in `master_subdir`.

        The ``.fmt2`` file is first converted to an ASTRA file of the same stem. It is
        then read into the shared beam object and written out as ``<stem>.hdf5``.
        """
        beam = self.global_parameters["beam"]
        subdir = self.global_parameters["master_subdir"]
        fmt2_name = self.name
        astra_name = fmt2_name.replace(".fmt2", ".astra")
        hdf5_name = self.name.replace(".fmt2", ".hdf5")

        rbf.astra.convert_csrtrackfile_to_astrafile(
            beam, f"{subdir}/{fmt2_name}", f"{subdir}/{astra_name}"
        )
        rbf.astra.read_astra_beam_file(beam, f"{subdir}/{astra_name}", normaliseZ=False)
        rbf.hdf5.write_HDF5_beam_file(
            beam, f"{subdir}/{hdf5_name}", sourcefilename=fmt2_name
        )
[docs]
class csrtrack_particles(csrtrack_element):
    """
    Class for defining the CSRTrack ``particles`` namelist.
    """

    header: str = "particles"
    """Namelist header written to the CSRTrack input file"""
    objectname: str = "particles"
    """Object name (also the key looked up in the CSRTrack defaults)"""
    objecttype: str = "csrtrack_particles"
    """Object type used to look up the namelist keywords"""
    particle_definition: str = "laser.astra"
    """Particle definition: input file stem, with or without a trailing ``.astra``"""
    array: str = "#file{name=laser.astra}"
    """CSRTrack ``array`` value pointing at the ASTRA-format input file"""

    def hdf5_to_astra(self, prefix: str = "") -> None:
        """
        Convert the HDF5 particle distribution to ASTRA format, suitable for input to CSRTrack.

        ``<base>`` is `particle_definition` with any trailing ``.astra`` removed.

        - The file ``<prefix><base>.hdf5`` is read. Its path is first expanded with
          `expand_substitution`; if that file does not exist, it is taken relative to
          `master_subdir` instead.
        - The beam is then written to ``<master_subdir>/<base>.astra``.

        Parameters
        ----------
        prefix: str
            Prefix for the HDF5 filename
        """
        # Strip the extension once and reuse it for both files. Previously the output
        # name appended ".astra" to the unstripped definition, so the default
        # "laser.astra" produced "laser.astra.astra" instead of the file `array` names.
        base = self.particle_definition.removesuffix(".astra")
        hdf5_filename = prefix + base + ".hdf5"
        expanded = expand_substitution(self, hdf5_filename)
        if os.path.isfile(expanded):
            filepath = expanded
        else:
            filepath = self.global_parameters["master_subdir"] + "/" + hdf5_filename
        rbf.hdf5.read_HDF5_beam_file(
            self.global_parameters["beam"],
            filepath,
        )
        rbf.astra.write_astra_beam_file(
            self.global_parameters["beam"],
            self.global_parameters["master_subdir"] + "/" + base + ".astra",
            normaliseZ=False,
        )