Source code for SimulationFramework.Codes.Ocelot.Ocelot

"""
Simframe Ocelot Module

Various objects and functions to handle OCELOT lattices and commands. See `Ocelot github`_ for more details.

    .. _Ocelot github: https://github.com/ocelot-collab/ocelot

Classes:
    - :class:`~SimulationFramework.Codes.Ocelot.Ocelot.ocelotLattice`: The Ocelot lattice object, used for
    converting the :class:`~SimulationFramework.Framework_elements.frameworkObject` s defined in the
    :class:`~SimulationFramework.Framework_elements.frameworkLattice` into an Ocelot lattice object,
    and for tracking through it.

"""
from ...Framework_objects import frameworkLattice, getGrids
from ...Framework_elements import screen
from ...FrameworkHelperFunctions import expand_substitution
from ...Modules import Beams as rbf
from ...Modules.Fields import field
from .mbi import MBI
from ocelot.cpbd.magnetic_lattice import MagneticLattice
from ocelot.cpbd.track import track
from ocelot.cpbd.io import save_particle_array, load_particle_array
from ocelot.cpbd.navi import Navigator
from ocelot.cpbd.sc import SpaceCharge, LSC
from ocelot.cpbd.csr import CSR
from ocelot.cpbd.wake3D import Wake, WakeTable
from ocelot.cpbd.physics_proc import SaveBeam
from ocelot.cpbd.beam import ParticleArray
from ocelot.cpbd.transformations.second_order import SecondTM
from ocelot.cpbd.transformations.kick import KickTM
from ocelot.cpbd.transformations.runge_kutta import RungeKuttaTM
from ocelot.cpbd.elements import Octupole, Undulator
from copy import deepcopy
from numpy import array, mean, savez_compressed, linspace, save
import os
from yaml import safe_load

# Default Ocelot settings shipped next to this module. They are used by
# ocelotLattice.model_post_init whenever settings["global"] has no
# "OCELOTsettings" entry. The file is read as UTF-8 explicitly so that
# decoding does not depend on the platform's locale encoding.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "ocelot_defaults.yaml"),
    "r",
    encoding="utf-8",
) as infile:
    oceglobal = safe_load(infile)
from typing import Dict, List


class ocelotLattice(frameworkLattice):
    """
    Class for defining the OCELOT lattice object, used for converting the
    :class:`~SimulationFramework.Framework_elements.frameworkObject` objects defined in the
    :class:`~SimulationFramework.Framework_elements.frameworkLattice` into an Ocelot lattice
    object, and for tracking through it.
    """

    code: str = "ocelot"
    """String indicating the lattice object type"""

    trackBeam: bool = True
    """Flag to indicate whether to track the beam"""

    lat_obj: MagneticLattice = None
    """
    Lattice object as an Ocelot `MagneticLattice`_

    .. _MagneticLattice: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/magnetic_lattice.py
    """

    pin: ParticleArray = None
    """
    Initial particle distribution as an Ocelot `ParticleArray`_

    .. _ParticleArray: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/beam.py
    """

    pout: ParticleArray = None
    """Final particle distribution as an Ocelot `ParticleArray`_"""

    tws: List = None
    """
    List containing Ocelot `Twiss`_ objects

    .. _Twiss: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/beam.py
    """

    names: List = None
    """Names of elements in the lattice"""

    grids: getGrids = None
    """Class for calculating the required number of space charge grids"""

    oceglobal: Dict = {}
    """Global settings for Ocelot, read in from `ocelotLattice.settings["global"]["OCELOTsettings"]` and `ocelot_defaults.yaml`"""

    unit_step: float = 0.01
    """
    Step for Ocelot `PhysProc`_ objects

    .. _PhysProc: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/physics_proc.py
    """

    smooth_param: float = 0.01
    """Smoothing parameter"""

    lsc: bool = True
    """Flag to enable LSC calculations"""

    random_mesh: bool = True
    """Random meshing for space charge calculations"""

    nbin_csr: int = 10
    """Number of longitudinal bins for CSR calculations"""

    mbin_csr: int = 5
    """Number of macroparticle bins for CSR calculations"""

    wake_factor: float = 1.0
    """Multiplication factor for wakefields"""

    sigmamin_csr: float = 1e-5
    """Minimum size for CSR calculations"""

    wake_sampling: int = 1000
    """Number of samples for wake calculations"""

    wake_filter: int = 10
    """Filter parameter for wake calculations"""

    particle_definition: str = None
    """Initial particle distribution as a string"""

    final_screen: screen | None = None
    """Final screen object"""

    mbi_navi: MBI | None = None
    """Physics process for calculating microbunching gain"""

    mbi: Dict = {}
    """Dictionary containing settings for microbunching gain calculation"""

    def model_post_init(self, __context):
        """
        Finish initialisation after pydantic validation.

        This does four things:

        * Selects the Ocelot global settings. It uses ``settings["global"]["OCELOTsettings"]``
          if present, otherwise the module-level defaults read from ``ocelot_defaults.yaml``.
        * Copies every setting whose key matches a model field onto this instance.
        * Determines :attr:`particle_definition`.
        * Creates the space-charge grid helper :attr:`grids`.
        """
        super().model_post_init(__context)
        global_settings = self.settings["global"]
        if "OCELOTsettings" in global_settings:
            self.oceglobal = global_settings["OCELOTsettings"]
        else:
            self.oceglobal = oceglobal
        # A setting whose key matches a field name overrides that field's value.
        for attr in type(self).model_fields:
            if attr in self.oceglobal:
                setattr(self, attr, self.oceglobal[attr])

        input_block = self.file_block["input"] if "input" in self.file_block else {}
        if "particle_definition" in input_block:
            requested = input_block["particle_definition"]
            # The special name "initial_distribution" is mapped to "laser".
            self.particle_definition = (
                "laser" if requested == "initial_distribution" else requested
            )
        else:
            # Default: the distribution named after the first element of this lattice.
            self.particle_definition = self.elementObjects[self.start].objectname
        self.grids = getGrids()

    def endScreen(self, **kwargs) -> screen:
        """
        Create a final screen object for dumping the particle output after tracking.

        The screen takes its name, centre and global rotation from the lattice's end object.

        Parameters
        ----------
        **kwargs
            Extra keyword arguments forwarded to the `screen` constructor.

        Returns
        -------
        :class:`~SimulationFramework.Elements.screen.screen`
        """
        end = self.endObject
        return screen(
            objectname=end.objectname,
            objecttype="screen",
            centre=end.centre,
            global_rotation=end.global_rotation,
            global_parameters=self.global_parameters,
            **kwargs,
        )

    def writeElements(self) -> None:
        """
        Create Ocelot objects for all the elements in the lattice.

        Sets :attr:`~SimulationFramework.Codes.Ocelot.Ocelot.ocelotLattice.lat_obj` and
        :attr:`~SimulationFramework.Codes.Ocelot.Ocelot.ocelotLattice.names`. If the end object
        is not already a screen or BPM, :attr:`final_screen` is created for it as well.

        Sub-elements are skipped. An element whose ``write_Ocelot`` raises is reported
        with ``print`` and left out of the lattice.
        """
        self.final_screen = None
        if self.endObject not in self.screens_and_bpms:
            self.final_screen = self.endScreen(
                output_filename=self.endObject.objectname + ".npz"
            )
        elements = self.createDrifts()
        mag_lat = []
        for element in list(elements.values()):
            if element.subelement:
                continue
            try:
                mag_lat.append(element.write_Ocelot())
            except Exception as e:
                # Best effort: report the failing element and keep building the rest.
                print("Ocelot writeElements error:", element.objectname, e)
        # Second-order maps everywhere, kick maps for octupoles, Runge-Kutta for undulators.
        method = {"global": SecondTM, Octupole: KickTM, Undulator: RungeKuttaTM}
        self.lat_obj = MagneticLattice(mag_lat, method=method)
        self.names = [str(elem.id) for elem in self.lat_obj.sequence]

    def write(self) -> None:
        """
        Create the lattice object and save it as a python file to `master_subdir`.

        The lattice object is built via
        :func:`~SimulationFramework.Codes.Ocelot.Ocelot.ocelotLattice.writeElements`.
        """
        self.writeElements()
        self.lat_obj.save_as_py_file(
            f'{self.global_parameters["master_subdir"]}/{self.objectname}.py'
        )

    def preProcess(self) -> None:
        """
        Get the initial particle distribution defined in `file_block['input']['prefix']` if it exists.
        """
        super().preProcess()
        prefix = self.get_prefix()
        if not self.trackBeam:
            prefix = prefix + self.particle_definition
        self.hdf5_to_npz(prefix)

    def hdf5_to_npz(self, prefix: str = "", write: bool = True) -> None:
        """
        Convert the initial HDF5 particle distribution to Ocelot format.

        This also sets :attr:`~SimulationFramework.Codes.Ocelot.Ocelot.ocelotLattice.pin`
        accordingly. The beam is read from ``<prefix><particle_definition>.hdf5``. The
        substitution-expanded path is used if it exists; otherwise the file is looked up
        in `master_subdir`. The beam is then rewritten as
        ``<master_subdir>/<start element>.hdf5`` and ``<start element>.ocelot.npz``.

        Parameters
        ----------
        prefix: str
            Prefix for particle file
        write: bool
            Flag to indicate whether to save the file
        """
        subdir = self.global_parameters["master_subdir"]
        beam = self.global_parameters["beam"]
        hdf5_name = prefix + self.particle_definition + ".hdf5"
        expanded = expand_substitution(self, hdf5_name)
        filepath = expanded if os.path.isfile(expanded) else subdir + "/" + hdf5_name
        rbf.hdf5.read_HDF5_beam_file(beam, os.path.abspath(filepath))
        # Build both output names from one stem. The previous str.replace("hdf5", ...)
        # also rewrote any "hdf5" occurring in the directory or element name.
        stem = f"{subdir}/{self.elementObjects[self.start].objectname}"
        rbf.hdf5.write_HDF5_beam_file(beam, stem + ".hdf5")
        self.pin = rbf.beam.write_ocelot_beam_file(
            beam, stem + ".ocelot.npz", write=write
        )

    def run(self) -> None:
        """
        Run the code, and set :attr:`~tws` and :attr:`~pout`.

        A deep copy of :attr:`pin` is tracked, so the stored input distribution is left
        untouched. When ``sample_interval > 1`` the copy is thinned to every n-th particle
        first.
        """
        navi = self.navi_setup()
        pin = deepcopy(self.pin)
        if self.sample_interval > 1:
            pin = pin.thin_out(nth=self.sample_interval)
        self.tws, self.pout = track(
            self.lat_obj,
            pin,
            navi=navi,
            calc_tws=True,
            twiss_disp_correction=True,
        )

    def postProcess(self) -> None:
        """
        Convert the outputs from Ocelot to HDF5 format and save them to `master_subdir`.

        The method writes the following files:

        * The final particle array as ``<end>.ocelot.npz``.
        * ``<name>.hdf5`` for every screen/BPM dump and for the end object. The end object
          is converted only once, even if it is also a screen.
        * The twiss data as ``<objectname>_twiss.npz``, with ``s`` shifted by the lattice
          start position.
        * The microbunching gain, if it was computed.
        """
        super().postProcess()
        subdir = self.global_parameters["master_subdir"]
        save_particle_array(f"{subdir}/{self.endObject.objectname}.ocelot.npz", self.pout)

        outputs = list(self.screens_and_bpms)
        if self.endObject not in outputs:
            outputs.append(self.endObject)
        for elem in outputs:
            stem = f"{subdir}/{elem.objectname}"
            ocebeamname = stem + ".ocelot.npz"
            parray = load_particle_array(ocebeamname)
            beam = rbf.beam(ocebeamname)
            rbf.hdf5.write_HDF5_beam_file(
                beam,
                stem + ".hdf5",
                centered=False,
                sourcefilename=ocebeamname,
                pos=0.0,
                xoffset=mean(parray.x()),
                yoffset=mean(parray.y()),
                zoffset=[parray.s],
            )

        if self.tws:
            s_offset = self.startObject.position_start[2]
            twsdat = {key: [] for key in vars(self.tws[0])}
            for tw in self.tws:
                for key, val in vars(tw).items():
                    # Offset the s values to the start of the lattice. Use "+" rather than
                    # "+=" so that array-valued entries are not modified in place.
                    twsdat[key].append(val + s_offset if key == "s" else val)
            savez_compressed(f"{subdir}/{self.objectname}_twiss.npz", **twsdat)

        if self.mbi_navi is not None:
            # NOTE: numpy.save appends ".npy" to names without that suffix, so the file on
            # disk is "<objectname>_mbi.dat.npy".
            save(f"{subdir}/{self.objectname}_mbi.dat", self.mbi_navi.bf)

    def navi_setup(self) -> Navigator:
        """
        Set up the physics processes for Ocelot (i.e. space charge, CSR, wakes etc).

        The returned navigator carries the following processes:

        * LSC over the whole lattice, if :attr:`lsc` is set.
        * 3D space charge over the whole lattice, if
          ``file_block["charge"]["space_charge_mode"]`` is ``"3d"`` (case-insensitive).
        * The CSR processes from :func:`physproc_csr`, if ``file_block`` has a ``"csr"`` entry.
        * A microbunching-gain (`MBI`) process, if ``mbi["set_mbi"]`` is truthy.
        * Wakes for cavity (``wakefield_definition``) and wakefield (``field_definition``)
          elements that define them.
        * A `SaveBeam` dump at every screen / BPM.

        .. _Navigator: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/navi.py

        Returns
        -------
        Navigator
            An Ocelot `Navigator`_ object
        """
        navi_processes = []
        navi_locations_start = []
        navi_locations_end = []

        def attach(process, start, end):
            # Queue a process; all queued processes are added to the navigator at the end.
            navi_processes.append(process)
            navi_locations_start.append(start)
            navi_locations_end.append(end)

        navi = Navigator(self.lat_obj, unit_step=self.unit_step)
        sequence = self.lat_obj.sequence
        if self.lsc:
            attach(self.physproc_lsc(), sequence[0], sequence[-1])

        space_charge_set = False
        csr_set = False
        if "charge" in self.file_block:
            charge = self.file_block["charge"]
            if (
                "space_charge_mode" in charge
                and charge["space_charge_mode"].lower() == "3d"
            ):
                gridsize = self.grids.getGridSizes(
                    (len(self.global_parameters["beam"].x) / self.sample_interval)
                )
                g1 = self.sc_grid if hasattr(self, "sc_grid") else gridsize
                attach(self.physproc_sc([g1 for _ in range(3)]), sequence[0], sequence[-1])
                space_charge_set = True

        if "csr" in self.file_block:
            for process, start, end in zip(*self.physproc_csr()):
                attach(process, start, end)
                # Record that CSR is really present so the MBI calculation includes it.
                csr_set = True

        # .get: `mbi` defaults to {} when the Ocelot settings do not provide it.
        if self.mbi.get("set_mbi", False):
            self.mbi_navi = MBI(
                lattice=self.lat_obj,
                lamb_range=list(
                    linspace(
                        float(self.mbi["min"]),
                        float(self.mbi["max"]),
                        int(self.mbi["nstep"]),
                    )
                ),
                lsc=space_charge_set,
                csr=csr_set,
                slices=self.mbi["slices"],
            )
            # The navigator is copied here, before the queued processes are added to it.
            self.mbi_navi.navi = deepcopy(navi)
            self.mbi_navi.lattice = deepcopy(self.lat_obj)
            # NOTE(review): this overrides the lsc=space_charge_set passed above; confirm intended.
            self.mbi_navi.lsc = True
            navi.add_physics_proc(self.mbi_navi, sequence[0], sequence[-1])

        wake_attr_by_type = {
            "cavity": "wakefield_definition",
            "wakefield": "field_definition",
        }
        for name, obj in self.elements.items():
            attr = wake_attr_by_type.get(obj.objecttype)
            if attr is None:
                continue
            definition = getattr(obj, attr)
            if definition is None:
                continue
            wake, w_ind = self.physproc_wake(name, definition, obj.n_cells)
            # The wake acts from the element to the next one in the sequence.
            attach(wake, sequence[w_ind], sequence[w_ind + 1])

        for w in self.screens_and_bpms:
            name = w.output_filename.replace(".sdds", "")
            loc = sequence[self.names.index(name)]
            subdir = self.global_parameters["master_subdir"]
            attach(SaveBeam(filename=f"{subdir}/{name}.ocelot.npz"), loc, loc)

        navi.add_physics_processes(
            navi_processes, navi_locations_start, navi_locations_end
        )
        return navi

    def physproc_lsc(self) -> LSC:
        """
        Get an Ocelot `LSC`_ physics process, using :attr:`smooth_param` as its smoothing parameter.

        .. _LSC: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/sc.py

        Returns
        -------
        LSC
            The Ocelot LSC PhysProc
        """
        lsc = LSC()
        lsc.smooth_param = self.smooth_param
        return lsc

    def physproc_sc(self, grids: List[int]) -> SpaceCharge:
        """
        Get an Ocelot `SpaceCharge`_ physics process.

        .. _SpaceCharge: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/sc.py

        Parameters
        ----------
        grids: List[int]
            The space charge grid number in x,y,z

        Returns
        -------
        SpaceCharge
            The Ocelot SpaceCharge PhysProc
        """
        sc = SpaceCharge(step=1)
        sc.nmesh_xyz = grids
        sc.random_mesh = self.random_mesh
        return sc

    def physproc_csr(self) -> tuple:
        """
        Get Ocelot `CSR`_ physics processes based on the start and end positions in `file_block`.

        If `file_block` provides no start and end positions, a single CSR process covering
        the entire lattice is returned. ``start`` and ``end`` may each be a single element
        name or a list of names, and are paired by position.

        .. _CSR: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/csr.py

        Returns
        -------
        tuple
            A list of CSR PhysProcs, and their start and end positions
            (returned as a 3-element list ``[csrs, starts, ends]``)

        Raises
        ------
        ValueError
            If ``start`` and ``end`` contain different numbers of element names.
        """

        def make_csr():
            # A fresh CSR process configured from this lattice's CSR settings.
            csr = CSR()
            csr.n_bin = self.nbin_csr
            csr.m_bin = self.mbin_csr
            csr.sigma_min = self.sigmamin_csr
            return csr

        csr_block = self.file_block["csr"]
        sequence = self.lat_obj.sequence
        if "start" in csr_block and "end" in csr_block:
            start = csr_block["start"]
            end = csr_block["end"]
            st = [start] if isinstance(start, str) else list(start)
            en = [end] if isinstance(end, str) else list(end)
            if len(st) != len(en):
                raise ValueError(
                    f"CSR start/end lists differ in length: {len(st)} start(s) vs {len(en)} end(s)"
                )
            csrlist = []
            stlist = []
            enlist = []
            for st_name, en_name in zip(st, en):
                stlist.append(sequence[self.names.index(st_name)])
                enlist.append(sequence[self.names.index(en_name)])
                csrlist.append(make_csr())
        else:
            # Whole-lattice CSR. The process must be returned along with its bounds;
            # previously it was built but never added to the list.
            csrlist = [make_csr()]
            stlist = [sequence[0]]
            enlist = [sequence[-1]]
        return [csrlist, stlist, enlist]

    def physproc_wake(
        self,
        name: str,
        loc: field | str,
        ncell: int,
    ) -> tuple:
        """
        Get an Ocelot `Wake`_ physics process based on the wakefield provided.

        .. _Wake: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/wake3D.py

        Parameters
        ----------
        name: str
            Name of lattice object associated with the wake
        loc: :class:`~SimulationFramework.Modules.Fields.field` or str
            If `field`, then write the field file to ASTRA format
        ncell: int
            Number of cells, which provides a multiplication factor for the wake

        Returns
        -------
        tuple
            A Wake PhysProc, and its index in the lattice
            (returned as a 2-element list ``[wake, index]``)
        """
        if isinstance(loc, field):
            loc = loc.write_field_file(code="astra")
        wake = Wake(
            step=100,
            w_sampling=self.wake_sampling,
            filter_order=self.wake_filter,
        )
        # The wake strength scales with the number of cells and the global wake factor.
        wake.factor = ncell * self.wake_factor
        wake.wake_table = WakeTable(expand_substitution(self, loc))
        w_ind = self.names.index(name)
        return [wake, w_ind]