# Source code for SimulationFramework.Codes.CSRTrack.CSRTrack

"""
Simframe CSRTrack Module

Various objects and functions to handle CSRTrack lattices and commands. See `CSRTrack manual`_ for more details.

    .. _CSRTrack manual: https://www.desy.de/xfel-beam/csrtrack/files/CSRtrack_User_Guide_(actual).pdf

Classes:
    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrackLattice`: The CSRTrack lattice object, used for
    converting the :class:`~SimulationFramework.Framework_elements.frameworkObject` s defined in the
    :class:`~SimulationFramework.Framework_elements.frameworkLattice` into a string representation of
    the lattice suitable for a CSRTrack input file.

    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_element`: Base class for defining a
    CSRTrack instance of a :class:`~SimulationFramework.Framework_objects.frameworkElement`.

    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_forces`: Class for defining the CSR
    calculation type.

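    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_track_step`: Class for defining the
    tracking step.

    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_tracker`: Class for defining the
    tracker and its end marker.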

    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_particles`: Class for defining the
    particle distribution and format.

    - :class:`~SimulationFramework.Codes.CSRTrack.CSRTrack.csrtrack_monitor`: Class for defining monitors.
"""

import os
import yaml
from ...Framework_objects import (
    frameworkLattice,
    frameworkElement,
    frameworkCounter,
    elementkeywords,
)
from ...Framework_elements import screen
from ...FrameworkHelperFunctions import saveFile, expand_substitution
from ...Modules import Beams as rbf
from typing import Dict, List

# Per-object default settings for the CSRTrack namelists, keyed by object name
# (e.g. "forces", "tracker"); applied in csrtrack_element.model_post_init.
# The YAML file lives next to this module. Read it as UTF-8 explicitly so the
# result does not depend on the platform's default encoding. An empty file
# (safe_load -> None) is treated as "no defaults" so membership tests still work.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "csrtrack_defaults.yaml"),
    "r",
    encoding="utf-8",
) as infile:
    csrtrack_defaults = yaml.safe_load(infile) or {}


class csrtrackLattice(frameworkLattice):
    """
    Class for defining the CSRTrack lattice object, used for converting the
    :class:`~SimulationFramework.Framework_elements.frameworkObject` s defined in the
    :class:`~SimulationFramework.Framework_elements.frameworkLattice` into a string
    representation of the lattice suitable for a CSRTrack input file.
    """

    code: str = "csrtrack"
    """String indicating the lattice object type"""

    particle_definition: str = ""
    """String representing the initial particle distribution"""

    # NOTE(review): mutable default; this relies on the base model copying field
    # defaults per instance (pydantic does) -- confirm for frameworkLattice.
    CSRTrackelementObjects: Dict = {}
    """Dictionary representing all CSRTrack object namelists"""

    def model_post_init(self, __context):
        """Run the base post-init hook, then build the initial `particles` namelist."""
        super().model_post_init(__context)
        # Populate CSRTrackelementObjects["particles"] up front so the entry exists
        # for methods such as preProcess() that read it directly.
        self.set_particles_filename()

    def endScreen(self, **kwargs) -> screen:
        """
        Create a final screen object for dumping the particle output after tracking.

        The screen is named "end" and placed at the centre and global rotation of
        `endObject`.

        Parameters
        ----------
        **kwargs
            Extra keyword arguments forwarded to the :class:`screen` constructor.

        Returns
        -------
        :class:`~SimulationFramework.Elements.screen.screen`
        """
        return screen(
            objectname="end",
            objecttype="screen",
            centre=self.endObject.centre,
            global_rotation=self.endObject.global_rotation,
            global_parameters=self.global_parameters,
            **kwargs,
        )

    def set_particles_filename(self) -> None:
        """
        Set up the `particles` entry of `CSRTrackelementObjects` for the initial
        particle distribution, based on `particle_definition` and the
        `global_parameters` of the lattice.

        If `particle_definition` is "initial_distribution" the particles are read
        from ``laser.astra``; otherwise from ``<start element name>.astra``.
        """
        particles = csrtrack_particles(
            particle_definition=self.particle_definition,
            global_parameters=self.global_parameters,
            format="astra",
        )
        self.CSRTrackelementObjects["particles"] = particles
        if self.particle_definition == "initial_distribution":
            particles.particle_definition = "laser.astra"
            particles.add_default("array", "#file{name=laser.astra}")
        else:
            start_name = self.elementObjects[self.start].objectname
            particles.particle_definition = start_name
            particles.array = "#file{name=" + start_name + ".astra" + "}"

    @property
    def dipoles_screens_and_bpms(self) -> List:
        """
        Get a list of the dipoles, screens and BPMs sorted by their position in the
        lattice (the third component of each element's `position_end`).

        Returns
        -------
        List
            A sorted list of :class:`~SimulationFramework.Framework_objects.frameworkElement`
        """
        return sorted(
            self.getElementType("dipole")
            + self.getElementType("screen")
            + self.getElementType("beam_position_monitor"),
            key=lambda x: x.position_end[2],
        )

    def setCSRMode(self) -> None:
        """
        Set up the `forces` entry in `CSRTrackelementObjects` based on the `csr_mode`
        defined in the settings file (``file_block["csr"]["csr_mode"]``) for this
        lattice section.

        - "3D" selects the CSRTrack force type "csr_g_to_p".
        - "1D" selects the CSRTrack force type "projected".
        - If `csr_mode` is absent, the default :class:`csrtrack_forces` is used.
        - If `csr_mode` has any other value, a warning is issued and the default
          :class:`csrtrack_forces` is used. Previously such a value silently left
          `forces` unset, or stale from an earlier call.
        """
        csr_mode = None
        if "csr" in self.file_block and "csr_mode" in self.file_block["csr"]:
            csr_mode = self.file_block["csr"]["csr_mode"]

        if csr_mode == "3D":
            forces = csrtrack_forces(type="csr_g_to_p")
        elif csr_mode == "1D":
            forces = csrtrack_forces(type="projected")
        else:
            if csr_mode is not None:
                import warnings

                warnings.warn(
                    f"Unrecognised CSRTrack csr_mode {csr_mode!r}; "
                    "expected '3D' or '1D'. Using default forces.",
                    stacklevel=2,
                )
            forces = csrtrack_forces()
        self.CSRTrackelementObjects["forces"] = forces

    def writeElements(self) -> str:
        """
        Write the lattice elements defined in this object into a CSRTrack-compatible
        format; see
        :attr:`~SimulationFramework.Framework_objects.frameworkLattice.elementObjects`.

        The output has three parts, in order:

        1. An ``io_path`` block.
        2. A ``lattice`` block containing the dipoles, screens and BPMs (sorted by
           position), followed by a final end screen.
        3. The namelists held in `CSRTrackelementObjects`, in insertion order.

        Each element writes its own text via its ``write_CSRTrack`` method.

        Returns
        -------
        str
            The lattice represented as a string compatible with CSRTrack
        """
        fulltext = "io_path{logfile = log.txt}\nlattice{\n"
        # Presumably maps BPMs onto the "screen" numbering sequence -- see frameworkCounter.
        counter = frameworkCounter(sub={"beam_position_monitor": "screen"})
        for element in self.dipoles_screens_and_bpms:
            fulltext += element.write_CSRTrack(counter.counter(element.objecttype))
            counter.add(element.objecttype)
        # The end screen takes the next index of its type but is not added to the
        # counter; the same index is reused below for the tracker's end marker.
        end_screen = self.endScreen()
        end_index = counter.counter(end_screen.objecttype)
        fulltext += end_screen.write_CSRTrack(end_index)
        fulltext += "}\n"

        self.set_particles_filename()
        self.setCSRMode()
        self.CSRTrackelementObjects["track_step"] = csrtrack_track_step()
        self.CSRTrackelementObjects["tracker"] = csrtrack_tracker(
            end_time_marker="screen" + str(end_index) + "a"
        )
        self.CSRTrackelementObjects["monitor"] = csrtrack_monitor(
            name=self.end + ".fmt2", global_parameters=self.global_parameters
        )
        fulltext += "".join(
            namelist.write_CSRTrack(n=0)
            for namelist in self.CSRTrackelementObjects.values()
        )
        return fulltext

    def write(self) -> None:
        """
        Write the CSRTrack input file produced by
        :func:`~SimulationFramework.Codes.CSRTrack.csrtrackLattice.writeElements`
        to <master_subdir>/csrtrk.in.
        """
        code_file = self.global_parameters["master_subdir"] + "/csrtrk.in"
        saveFile(code_file, self.writeElements())

    def preProcess(self) -> None:
        """
        Run the base pre-processing, then convert the beam file from the previous
        lattice section (HDF5) into the ASTRA format read by CSRTrack; see
        :func:`~SimulationFramework.Codes.CSRTrack.csrtrack_particles.hdf5_to_astra`.
        """
        super().preProcess()
        prefix = self.get_prefix()
        self.CSRTrackelementObjects["particles"].hdf5_to_astra(prefix)

    def postProcess(self) -> None:
        """
        Run the base post-processing, then convert the CSRTrack output beam file into
        HDF5 format; see
        :func:`~SimulationFramework.Codes.CSRTrack.csrtrack_monitor.csrtrack_to_hdf5`.
        """
        super().postProcess()
        self.CSRTrackelementObjects["monitor"].csrtrack_to_hdf5()
class csrtrack_element(frameworkElement):
    """
    Base class for CSRTrack elements, i.e. the namelists written to the lattice file.

    Subclasses set `header`, the namelist name emitted by :meth:`_write_CSRTrack`.
    """

    header: str = ""
    """Header for CSRtrack file types"""

    def model_post_init(self, __context):
        """
        Apply the per-object defaults from ``csrtrack_defaults.yaml`` (keyed by
        `objectname`) as attributes of this element.
        """
        # NOTE(review): these YAML values are set after construction, so they
        # overwrite any same-named values passed to the constructor
        # (e.g. csrtrack_forces(type=...)). This hook also does not call
        # super().model_post_init(). Confirm both are intended.
        if self.objectname not in csrtrack_defaults:
            return
        for attribute, default_value in csrtrack_defaults[self.objectname].items():
            setattr(self, attribute, default_value)

    def CSRTrack_str(self, s: bool) -> str:
        """
        Render a value the way CSRTrack expects it.

        Parameters
        ----------
        s: bool
            Value to convert; booleans are checked by identity.

        Returns
        -------
        str
            ``"yes"`` for `True`, ``"no"`` for `False`, otherwise ``str(s)``.
        """
        if s is True:
            return "yes"
        if s is False:
            return "no"
        return str(s)

    def _write_CSRTrack(self, n: int = 0, **kwargs) -> str:
        """
        Render this element as a CSRTrack namelist: ``header{ key=value ... }``.

        For every keyword listed in ``elementkeywords[objecttype]`` (lower-cased),
        the attribute value is written when it is not None. Otherwise the value from
        `objectdefaults` is written when one exists; if neither applies, the keyword
        is skipped.

        Returns
        -------
        str
            CSRTrack-compatible string for this element.
        """
        pieces = [str(self.header) + "{\n"]
        for keyword in elementkeywords[self.objecttype]["keywords"]:
            name = keyword.lower()
            value = getattr(self, name)
            if value is None:
                if name not in self.objectdefaults:
                    continue
                value = self.objectdefaults[name]
            pieces.append(name + "=" + self.CSRTrack_str(value) + "\n")
        pieces.append("}\n")
        return "".join(pieces)
# class csrtrack_online_monitor(csrtrack_element): # # def __init__(self, marker="", **kwargs): # super(csrtrack_online_monitor, self).__init__( # "online_monitor", "csrtrack_online_monitor", **kwargs # ) # self.header = "online_monitor" # self.end_time_marker = marker + "b"
class csrtrack_forces(csrtrack_element):
    """
    CSRTrack ``forces`` namelist, which selects the CSR calculation type.
    """

    header: str = "forces"
    """Header for CSRtrack element"""

    objectname: str = "forces"
    """Name of object"""

    objecttype: str = "csrtrack_forces"
    """Type of object"""
class csrtrack_track_step(csrtrack_element):
    """
    CSRTrack ``track_step`` namelist, which defines the tracking step.
    """

    header: str = "track_step"
    """Header for CSRtrack element"""

    objectname: str = "track_step"
    """Name of object"""

    objecttype: str = "csrtrack_track_step"
    """Type of object"""
class csrtrack_tracker(csrtrack_element):
    """
    CSRTrack ``tracker`` namelist, which defines the tracker and the marker at
    which tracking ends.
    """

    header: str = "tracker"
    """Header for CSRtrack element"""

    objectname: str = "tracker"
    """Name of object"""

    objecttype: str = "csrtrack_tracker"
    """Type of object"""

    end_time_marker: str = ""
    """Name of end marker"""

    end_time_shift_c0: str | float = 0.0
    """Time shift for end"""
class csrtrack_monitor(csrtrack_element):
    """
    CSRTrack ``monitor`` namelist, which names the particle output file.
    """

    header: str = "monitor"
    """Header for CSRtrack element"""

    objectname: str = "monitor"
    """Name of object"""

    objecttype: str = "csrtrack_monitor"
    """Type of object"""

    name: str = ""
    """File name for monitor"""

    def csrtrack_to_hdf5(self) -> None:
        """
        Convert the particle distribution written by this CSRTrack monitor into HDF5.

        The monitor file ``<master_subdir>/<name>`` is processed in three steps:

        1. Convert it to ``<name with .fmt2 -> .astra>``.
        2. Read that ASTRA file into ``global_parameters["beam"]``.
        3. Write the beam to ``<name with .fmt2 -> .hdf5>``.

        All files live in `master_subdir`.
        """
        beam = self.global_parameters["beam"]
        subdir = self.global_parameters["master_subdir"]

        fmt2_file = self.name
        astra_file = fmt2_file.replace(".fmt2", ".astra")
        hdf5_file = fmt2_file.replace(".fmt2", ".hdf5")
        astra_path = subdir + "/" + astra_file

        rbf.astra.convert_csrtrackfile_to_astrafile(
            beam,
            subdir + "/" + fmt2_file,
            astra_path,
        )
        rbf.astra.read_astra_beam_file(beam, astra_path, normaliseZ=False)
        rbf.hdf5.write_HDF5_beam_file(
            beam,
            subdir + "/" + hdf5_file,
            sourcefilename=fmt2_file,
        )
class csrtrack_particles(csrtrack_element):
    """
    Class for defining CSRTrack particles: the initial particle distribution file
    and its format.
    """

    header: str = "particles"
    """Header for CSRtrack element"""

    objectname: str = "particles"
    """Name of object"""

    objecttype: str = "csrtrack_particles"
    """Type of object"""

    particle_definition: str = "laser.astra"
    """Particle definition file"""

    array: str = "#file{name=laser.astra}"
    """File name array"""

    def hdf5_to_astra(self, prefix: str = "") -> None:
        """
        Convert an HDF5 particle distribution to ASTRA format, suitable for input
        to CSRTrack.

        The HDF5 file name is ``prefix + <particle_definition without ".astra"> + ".hdf5"``.
        It is used as given (after `expand_substitution`) if that file exists;
        otherwise it is looked up in `master_subdir`. The beam is read into
        ``global_parameters["beam"]`` and written to
        ``<master_subdir>/<particle_definition>`` with a ``.astra`` suffix. The suffix
        is appended only when `particle_definition` does not already end in
        ``.astra``, so the file name matches the ``#file{name=...}`` entry in `array`.

        Parameters
        ----------
        prefix: str
            Prefix prepended to the HDF5 file name.
        """
        HDF5filename = prefix + self.particle_definition.replace(".astra", "") + ".hdf5"
        expanded_path = expand_substitution(self, HDF5filename)
        if os.path.isfile(expanded_path):
            filepath = expanded_path
        else:
            filepath = self.global_parameters["master_subdir"] + "/" + HDF5filename

        beam = self.global_parameters["beam"]
        rbf.hdf5.read_HDF5_beam_file(beam, filepath)

        # Previously ".astra" was appended unconditionally. With the default
        # particle_definition "laser.astra" that produced "laser.astra.astra",
        # while `array` tells CSRTrack to read "laser.astra".
        if self.particle_definition.endswith(".astra"):
            astrabeamfilename = self.particle_definition
        else:
            astrabeamfilename = self.particle_definition + ".astra"
        rbf.astra.write_astra_beam_file(
            beam,
            self.global_parameters["master_subdir"] + "/" + astrabeamfilename,
            normaliseZ=False,
        )