Source code for SimulationFramework.Modules.plotting.plotting

import math
import matplotlib.pyplot as plt
from copy import copy
import numpy as np
from ..units import nice_array, nice_scale_prefix
from mpl_axes_aligner import align
from ..Twiss import twissParameter, twiss_defaults

# from units import nice_array, nice_scale_prefix

# Module-level colormaps. Each one is a copy of the object returned by
# ``plt.get_cmap`` so that changes made here stay on this module's own object.
CMAP0 = copy(plt.get_cmap("viridis"))
# "Under" colour: used by Matplotlib for values below the normalisation range.
CMAP0.set_under("white")
CMAP1 = copy(plt.get_cmap("plasma"))


def trans(M):
    """Transpose a rectangular nested sequence into a list of lists.

    Element ``[i][j]`` of the result is ``M[j][i]``. The number of output rows
    is taken from ``len(M[0])``, so ``M`` must contain at least one row.
    """
    row_count = len(M)
    column_count = len(M[0])
    transposed = []
    for column in range(column_count):
        transposed.append([M[row][column] for row in range(row_count)])
    return transposed
def find_nearest(array, value):
    """Return the index of the entry of a sorted ``array`` closest to ``value``.

    The insertion point is located with ``np.searchsorted`` and the closer of
    the two neighbouring entries is chosen. On an exact tie the upper
    neighbour wins, because the lower one is only taken when it is strictly
    closer.
    """
    idx = np.searchsorted(array, value, side="left")
    if idx <= 0:
        # ``value`` sorts before every entry: there is no lower neighbour.
        return idx
    if idx == len(array):
        # ``value`` sorts after every entry: the last entry is the nearest.
        return idx - 1
    distance_below = math.fabs(value - array[idx - 1])
    distance_above = math.fabs(value - array[idx])
    return idx - 1 if distance_below < distance_above else idx
def ASTRA_TW_FieldMap(fielddat, start, stop, cells, p):
    """Expand a single-cell travelling-wave field map into a full structure.

    ``fielddat`` is a 2-D array whose column 0 holds the longitudinal
    position. The rows are split at the first occurrences of ``start`` and
    ``stop`` in that column:

    * rows before ``start`` are kept once, unchanged;
    * rows from ``start`` up to (not including) ``stop`` form the repeated
      cell, emitted ``int(cells / p) + 1`` times, each copy shifted by one
      more cell length (last minus first position of the cell rows);
    * rows from ``stop`` onwards are appended once, shifted by
      ``int(cells / p)`` cell lengths.

    The caller's ``fielddat`` is left untouched; all shifting is done on a
    private copy.

    Parameters:
        fielddat: 2-D array-like of field-map rows, position in column 0.
        start: position value marking the first row of the repeated cell.
        stop: position value marking the first row after the repeated cell.
        cells: cell count, divided by ``p`` to get the number of repeats.
        p: divisor applied to ``cells``.

    Returns:
        A new 2-D ``numpy`` array containing the expanded field map.

    Raises:
        ValueError: if ``start`` or ``stop`` is not present in column 0.
    """
    # Copy first: the slices below are views, and shifting them in place
    # would otherwise silently modify the array passed in by the caller.
    fielddat = np.array(fielddat)
    zpos = list(fielddat[:, 0])
    startpos = zpos.index(start)
    stoppos = zpos.index(stop)

    entrance = fielddat[:startpos]
    rfcell = fielddat[startpos:stoppos]
    exit_section = fielddat[stoppos:]

    n_cells = int(cells / p)
    cell_length = rfcell[-1, 0] - rfcell[0, 0]

    dat = list(entrance)
    for _ in range(n_cells + 1):
        # ``1.0 * rfcell`` snapshots the cell at its current position before
        # the working copy is advanced to the next cell.
        dat += list(1.0 * rfcell)
        rfcell[:, 0] += cell_length
    exit_section[:, 0] += n_cells * cell_length
    dat += list(exit_section)
    return np.array(dat)
def fieldmap_data(element):
    """
    Loads an element's fieldmap in absolute coordinates, scaled to its amplitude.

    The field data is requested from ``element.field_definition`` (after
    ``element.update_field_definition()``), its first column is shifted by the
    third component of ``element.get_field_reference_position()``, and its
    second column is normalised by its largest-magnitude value and multiplied
    by the element's field amplitude. For elements whose ``objecttype`` is
    ``"cavity"`` the amplitude is divided by 1e6 first.

    Travelling-wave 1-D electrodynamic fields are expanded to the full
    structure with ``ASTRA_TW_FieldMap``.

    Parameters:
        element: lattice element exposing the attributes used above.

    Returns:
        A new 2-D array (position, field). The array returned by
        ``get_field_data`` is not modified.
    """
    # Position
    offset = element.get_field_reference_position()[2]

    # Scaling
    # NOTE(review): read as an attribute, not called - presumably a property; confirm.
    scale = element.get_field_amplitude
    if element.objecttype == "cavity":
        scale = scale / 1e6

    # file
    element.update_field_definition()
    field = element.field_definition
    data = field.get_field_data(code="astra")
    if field.field_type == "1DElectroDynamic" and field.cavity_type == "TravellingWave":
        dat = ASTRA_TW_FieldMap(
            np.transpose([field.z.value.val, field.Ez.value.val]),
            field.start_cell_z,
            field.end_cell_z,
            element.n_cells,
            field.mode_denominator,
        )
    else:
        # Work on a copy so repeated calls do not keep shifting/rescaling an
        # array that the field definition may hand out again.
        dat = np.array(data)

    dat[:, 0] += offset
    x = dat[:, 1]
    normalise = max(x.min(), x.max(), key=abs)
    # A peak of zero means the whole column is zero; dividing would only
    # produce NaN, so the column is left as it is.
    if normalise != 0:
        dat[:, 1] *= scale / normalise
    return dat
class magnet_plotting_data:
    """Builds the outline polygon and colour used to draw each element type.

    Every drawing method takes an element ``e`` and returns a tuple
    ``(corners, colour)`` where ``corners`` is a 4x2 array of (z, height)
    points. The method names double as element-type names: callers look them
    up with ``getattr``.

    Parameters:
        kinetic_energy: optional sequence of ``(z, kinetic_energy)`` pairs,
            used to scale quadrupole and sextupole gradients. When omitted, a
            single point ``z=0.0`` with energy ``1.0`` is used.
    """

    def __init__(
        self,
        kinetic_energy=None,
    ):
        if kinetic_energy is None:
            self.z, self.kinetic_energy = ([0.0], [1.0])
        else:
            self.z, self.kinetic_energy = trans(kinetic_energy)

    def half_rectangle(self, e, half_height):
        """Rectangle from the axis (height 0) up to ``half_height`` over the element length."""
        corners = [
            [e.position_start[2], 0],
            [e.position_start[2], half_height],
            [e.position_end[2], half_height],
            [e.position_end[2], 0],
        ]
        return np.array(corners)

    def full_rectangle(self, e, half_height, width=0):
        """Rectangle symmetric about the axis, widened by ``width`` at both ends."""
        z_left = e.position_start[2] - width
        z_right = e.position_end[2] + width
        corners = [
            [z_left, -half_height],
            [z_left, half_height],
            [z_right, half_height],
            [z_right, -half_height],
        ]
        return np.array(corners)

    def _multipole_height(self, e, integrated_strength_name):
        """Signed drawing height shared by the quadrupole and sextupole methods.

        With a gradient, it is scaled by the kinetic energy at the tabulated
        position nearest to the element middle; otherwise a fixed height of
        0.5 carries the sign of the named integrated strength.
        """
        if e.gradient is not None:
            nearest = find_nearest(self.z, e.middle[2])
            energy = self.kinetic_energy[nearest]
            # NOTE(review): 3.3356 looks like a magnetic-rigidity conversion
            # factor and /1e6 an eV -> MeV scaling - confirm the units.
            return 1.0 / (3.3356 * energy / 1e6) * e.gradient
        return np.sign(getattr(e, integrated_strength_name)) * 0.5

    def quadrupole(self, e):
        """Red half rectangle; height from the gradient or the sign of ``e.k1l``."""
        return self.half_rectangle(e, self._multipole_height(e, "k1l")), "red"

    def sextupole(self, e):
        """Green half rectangle; height from the gradient or the sign of ``e.k2l``."""
        return self.half_rectangle(e, self._multipole_height(e, "k2l")), "green"

    def dipole(self, e):
        """Blue half rectangle of height 0.4 with the sign of ``e.angle``."""
        height = np.sign(e.angle) * 0.4
        return self.half_rectangle(e, height), "blue"

    def beam_position_monitor(self, e):
        """Purple full rectangle of half height 0.1."""
        return self.full_rectangle(e, 0.1), "purple"

    def screen(self, e):
        """Green full rectangle of half height 0.33."""
        return self.full_rectangle(e, 0.33), "green"

    def aperture(self, e):
        """Black full rectangle of half height 0.15, widened by 0.01 at each end."""
        return self.full_rectangle(e, 0.15, width=0.01), "black"

    def wall_current_monitor(self, e):
        """Brown full rectangle of half height 0.33."""
        return self.full_rectangle(e, 0.33), "brown"
def load_elements(
    lattice,
    bounds=None,
    sections="All",
    types=["cavity", "solenoid"],
    kinetic_energy=None,
    verbose=False,
    scale=1,
):
    """Collect the drawing data of lattice elements, grouped by element type.

    For every type in ``types`` the matching elements are looked up, either in
    the whole lattice (``sections == "All"``) or in each listed section, and
    optionally filtered to those overlapping ``bounds`` (with 0.1 of slack
    below the lower bound).

    Cavities and solenoids that have a non-``None`` ``field_definition`` are
    converted with ``fieldmap_data``; any other element is drawn with the
    ``magnet_plotting_data`` method named after its type. Types without such a
    method are reported with a printed message.

    ``verbose`` and ``scale`` are accepted but not used by this function.

    Returns:
        dict mapping each type to a dict of ``objectname -> drawing data``.
    """
    drawer = magnet_plotting_data(kinetic_energy=kinetic_energy)
    fmap = {}

    def overlaps_bounds(e):
        return (
            e.position_start[2] <= bounds[1]
            and e.position_end[2] >= bounds[0] - 0.1
        )

    for t in types:
        drawings = {}
        fmap[t] = drawings

        if sections == "All":
            elements = [lattice[rec["name"]] for rec in lattice.getElementType(t)]
        else:
            elements = []
            for s in sections:
                elements.extend(
                    lattice[rec["name"]] for rec in lattice[s].getElementType(t)
                )

        if bounds is not None:
            elements = list(filter(overlaps_bounds, elements))

        is_field_type = t == "cavity" or t == "solenoid"
        for e in elements:
            uses_fieldmap = (
                is_field_type
                and hasattr(e, "field_definition")
                and e.field_definition is not None
            )
            if uses_fieldmap:
                drawings[e.objectname] = fieldmap_data(e)
            elif hasattr(drawer, t):
                draw = getattr(drawer, t)
                drawings[e.objectname] = draw(e)
            else:
                print("Missing drawings for", t)
    return fmap
def add_fieldmaps_to_axes(
    lattice,
    axes,
    bounds=None,
    sections="All",
    fields=["cavity", "solenoid"],
    include_labels=True,
    verbose=False,
):
    """
    Adds fieldmaps to an axes.

    The first entry of ``fields`` is drawn on ``axes`` and the second on a
    twin y axis created here. Each axis gets the y label and colour of its
    field type. A black baseline from (0, 0) to (100, 0) is drawn on ``axes``.

    Parameters:
        lattice: lattice object passed to ``load_elements``.
        axes: Matplotlib axes to draw on.
        bounds: optional (min, max) range forwarded to ``load_elements``.
        sections: ``"All"`` or a list of section names.
        fields: field types to draw; only ``"cavity"`` and ``"solenoid"`` have
            labels and colours, and at most two entries have an axis.
        include_labels: accepted but not used by this function.
        verbose: forwarded to ``load_elements``.

    Raises:
        KeyError: for a field type without a label/colour entry.
        IndexError: if more than two field types are given.
    """
    fmaps = load_elements(
        lattice, bounds=bounds, sections=sections, verbose=verbose, types=fields
    )
    ax1 = axes
    ax1rhs = ax1.twinx()
    ax = [ax1, ax1rhs]

    ylabel = {"cavity": "$E_z$ (MV/m)", "solenoid": "$B_z$ (T)"}
    color = {"cavity": "green", "solenoid": "blue"}

    for i, section in enumerate(fields):
        a = ax[i]
        # Looked up once per field type so the axis label is coloured even
        # when no element of this type lies inside the bounds.
        c = color[section]
        for name, data in fmaps[section].items():
            a.plot(*data.T, label=f"{section}_{name}", color=c)
        a.yaxis.label.set_color(c)
        a.set_ylabel(ylabel[section])

    if len(fields) < 1:
        for a in ax:
            a.set_yticks([])

    # Zero line along the axis.
    data = np.array([[0, 0], [100, 0]])
    ax[0].plot(*data.T, color="black")
def add_magnets_to_axes(
    lattice,
    axes,
    bounds=None,
    sections="All",
    magnets=["quadrupole", "dipole", "sextupole", "beam_position_monitor", "screen"],
    include_labels=True,
    kinetic_energy=None,
    verbose=False,
):
    """
    Adds magnets to an axes.

    Every element type requested in ``magnets`` that ``load_elements`` returns
    an outline for is filled in the colour chosen by its drawing method. A
    twin y axis is created; ``axes`` carries the dipole label and the twin
    carries the quadrupole label. A black baseline from (0, 0) to (100, 0) is
    drawn on ``axes``.

    Parameters:
        lattice: lattice object passed to ``load_elements``.
        axes: Matplotlib axes to draw on.
        bounds: optional (min, max) range; also applied as the x limits.
        sections: ``"All"`` or a list of section names.
        magnets: element types to draw.
        include_labels: accepted but not used by this function.
        kinetic_energy: optional ``(z, kinetic_energy)`` pairs forwarded to
            ``load_elements``.
        verbose: forwarded to ``load_elements``.
    """
    fmaps = load_elements(
        lattice,
        bounds=bounds,
        sections=sections,
        verbose=verbose,
        types=magnets,
        scale=0,
        kinetic_energy=kinetic_energy,
    )
    ax1 = axes
    ax1rhs = ax1.twinx()
    ax = [ax1, ax1rhs]

    ylabel = {
        "dipole": r"$\theta$ (rad)",
        "quadrupole": "$K_n$ (T/m)",
    }  # , "sextupole": "$K_2$ (T/$m^2$)"}
    axis = {"dipole": 0, "quadrupole": 1}
    label_color = {"dipole": "blue", "quadrupole": "red"}

    for section, i in axis.items():
        a = ax[i]
        a.set_ylabel(ylabel[section])
        a.yaxis.label.set_color(label_color[section])

    # NOTE(review): all outlines go on the right-hand axis, which is where the
    # previous implementation drew them (via a leftover loop variable) -
    # confirm that dipoles are not meant for the left axis.
    draw_axis = ax[1]
    # Iterate over everything that was loaded, not over a fixed colour table,
    # so screens, apertures, wall current monitors etc. are drawn as well.
    for drawings in fmaps.values():
        for drawing in drawings.values():
            if not isinstance(drawing, tuple):
                # Field-map arrays (cavity/solenoid) are not outline polygons.
                continue
            data, c = drawing
            draw_axis.fill(*data.T, color=c)

    # Zero line along the axis.
    data = np.array([[0, 0], [100, 0]])
    ax[0].plot(*data.T, color="black")

    if bounds:
        ax1.set_xlim(bounds[0], bounds[1])
    # Put y = 0 of both axes at mid-height.
    align.yaxes(ax[0], 0, ax[1], 0, 0.5)
def plot_fieldmaps(
    lattice,
    sections="All",
    include_labels=True,
    limits=None,
    figsize=(12, 4),
    fields=["cavity", "solenoid"],
    magnets=["quadrupole", "dipole", "beam_position_monitor", "screen"],
    **kwargs,
):
    """
    Simple fieldmap plot

    Creates a new figure with ``plt.subplots`` and draws the field maps of
    ``lattice`` on it with ``add_fieldmaps_to_axes``.

    Parameters:
        lattice: lattice object to draw.
        sections: ``"All"`` or a list of section names.
        include_labels: forwarded to ``add_fieldmaps_to_axes``.
        limits: optional (min, max) range, forwarded as ``bounds``.
        figsize: figure size passed to ``plt.subplots``.
        fields: field types to draw.
        magnets: accepted for backward compatibility; not used, because
            ``add_fieldmaps_to_axes`` draws field maps only.
        **kwargs: extra keyword arguments for ``plt.subplots``.

    Returns:
        None. The figure is created through pyplot and is not returned.
    """
    fig, axes = plt.subplots(figsize=figsize, **kwargs)

    # ``magnets`` is deliberately not forwarded: add_fieldmaps_to_axes has no
    # such parameter and passing it raised TypeError on every call.
    add_fieldmaps_to_axes(
        lattice,
        axes,
        bounds=limits,
        include_labels=include_labels,
        sections=sections,
        fields=fields,
    )
def plot(
    framework_object,
    ykeys=["sigma_x", "sigma_y"],
    ykeys2=["sigma_z"],
    xkey="z",
    limits=None,
    nice=True,
    include_layout=False,
    include_labels=True,
    include_legend=True,
    include_particles=False,
    fields=["cavity", "solenoid"],
    magnets=[
        "quadrupole",
        "dipole",
        "beam_position_monitor",
        "screen",
        "wall_current_monitor",
        "aperture",
    ],
    grid=False,
    **kwargs,
):
    """
    Plots stat output multiple keys.

    If a list of ykeys2 is given, these will be put on the right hand axis.
    This can also be given as a single key.

    Logical switches:
        nice: a nice SI prefix and scaling will be used to make the numbers
            reasonably sized. Default: True
        include_legend: the plot will include the legend. Default: True
        include_layout: the layout plots (field maps and magnets) will be
            displayed at the bottom. Default: False
        include_labels: the layout will include element labels. Default: True
        include_particles: beam statistics are overlaid as points.
            Default: False
        grid: draw the major grid on the main axes. Default: False

    Other parameters:
        framework_object: object providing ``twiss``, ``beams`` and
            ``framework``.
        xkey: statistic used for the x axis.
        limits: optional (min, max) range of ``xkey``. The sequence passed in
            is not modified.
        fields, magnets: element types drawn in the layout plots.
        **kwargs: forwarded to ``plt.subplots``.

    Returns:
        tuple ``(plt, fig, all_axis)``.

    Raises:
        AssertionError: if the keys sharing a y axis have different units.

    Copied almost verbatim from lume-impact's Impact.plot.plot_stats_with_layout
    """
    twiss = framework_object.twiss  # convenience
    twiss.sort()  # sort before plotting!
    P = framework_object.beams

    if include_layout is not False:
        if "sharex" not in kwargs:
            kwargs["sharex"] = True
        fig, all_axis = plt.subplots(
            3,
            gridspec_kw={"height_ratios": [4, 1, 1]},
            subplot_kw=dict(frameon=False),
            **kwargs,
        )
        plt.subplots_adjust(hspace=0.0)
        ax_field_layout = all_axis[1]
        ax_magnet_layout = all_axis[2]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        ax_plot = [all_axis]

    if grid:
        ax_plot[0].grid(visible=True, which="major", color="#666666", linestyle="-")

    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]

    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())

    # No need for a legend if there is only one plot
    if len(ykeys) == 1 and not ykeys2:
        include_legend = False

    X = twiss.stat(xkey)
    if not isinstance(X, twissParameter):
        if xkey in twiss_defaults.keys():
            X = twissParameter(val=X, **twiss_defaults[xkey])
        else:
            X = twissParameter(val=X, name=xkey, unit="")

    # Only get the data we need
    if limits:
        # Private, mutable copy: tuples are accepted and the caller's
        # sequence is never modified.
        limits = list(limits)
        good = np.logical_and(X.val >= limits[0], X.val <= limits[1])
        idx = list(np.where(good)[0])
        # Add one extra sample on either side so the curve reaches the edges.
        if len(idx) > 0:
            if idx[0] > 0:
                good[idx[0] - 1] = True
            if (idx[-1] + 1) < len(good):
                good[idx[-1] + 1] = True
        X = X[good]
        # Shrink the limits if the data does not reach them.
        if X.val.min() > limits[0]:
            limits[0] = X.val.min()
        if X.val.max() < limits[1]:
            limits[1] = X.val.max()
    else:
        limits = X.val.min(), X.val.max()
        good = slice(None, None, None)  # everything

    # Try particles within these bounds
    Pnames = []
    X_particles = []
    if include_particles:
        for pname in range(len(P)):  # Modified from Impact
            xp = np.mean(np.array(getattr(P[pname], xkey)))
            if xp >= limits[0] and xp <= limits[1]:
                Pnames.append(pname)
                X_particles.append(xp)
        X_particles = np.array(X_particles)

    # X axis scaling
    units_x = str(twiss.stat(xkey).unit)
    # Scaled values are kept in a local so the twiss data is not altered.
    x_values = X.val
    if nice:
        x_values, factor_x, prefix_x = nice_array(x_values)
        units_x = prefix_x + units_x
    else:
        factor_x = 1

    # With a layout the bottom axes carries the x label; otherwise the main
    # axes does (as in general_plot).
    if include_layout is not False:
        ax_magnet_layout.set_xlim(limits[0] / factor_x, limits[1] / factor_x)
        ax_magnet_layout.set_xlabel(f"{xkey} ({units_x})")
    else:
        for ax in ax_plot:
            ax.set_xlim(limits[0] / factor_x, limits[1] / factor_x)
            ax.set_xlabel(f"{xkey} ({units_x})")

    # Draw for Y1 and Y2
    linestyles = ["solid", "dashed"]
    legend_labels = []

    ii = -1  # counter for colors
    for ix, keys in enumerate([ykeys, ykeys2]):
        if not keys:
            continue
        ax = ax_plot[ix]
        ax.ticklabel_format(useOffset=False)
        linestyle = linestyles[ix]

        # Units of every key on this axis, wrapping plain values in a
        # twissParameter (with that key's own defaults) when necessary.
        ulist = []
        for key in keys:
            Y = twiss.stat(key)
            if not isinstance(Y, twissParameter):
                if key in twiss_defaults.keys():
                    Y = twissParameter(val=Y, **twiss_defaults[key])
                else:
                    Y = twissParameter(val=Y, name=key, unit="")
            ulist.append(Y.unit)

        # Check that units are compatible
        for u2 in ulist[1:]:
            assert ulist[0] == u2, f"Incompatible units: {ulist[0]} and {u2}"
        # String Unit representation
        unit = str(ulist[0])

        # Data
        data = [twiss.stat(key).val[good] for key in keys]
        # Labels
        labels = [twiss.stat(key).label for key in keys]

        if nice:
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        # Make a line and point
        for key, dat, label in zip(keys, data, labels):
            legend_labels.append("$" + label.replace("sigma", r"\sigma") + "$")
            ii += 1
            color = "C" + str(ii)
            ax.plot(
                x_values,
                dat / factor,
                label=f"{label} ({unit})",
                color=color,
                linestyle=linestyle,
            )

            # Particles
            if Pnames:
                Y_particles = np.array(
                    [
                        (
                            np.std(getattr(P[name], key))
                            if key in P._parameters["data"]
                            else getattr(P[name], key)
                        )
                        for name in Pnames
                    ]
                )
                if not all(v is None for v in Y_particles):
                    ax.scatter(
                        X_particles / factor_x, Y_particles / factor, color=color
                    )
        labels = ["$" + k.replace("sigma", r"\sigma") + "$" for k in labels]
        ax.set_ylabel(", ".join(labels) + f" ({unit})")

    # Collect legend
    if include_legend:
        lines = []
        for ax in ax_plot:
            handles, _ = ax.get_legend_handles_labels()
            lines += handles
        ax_plot[0].legend(lines, legend_labels, loc="best")

    # Layout
    if include_layout is not False:
        if xkey == "z":
            # NOTE(review): these limits are unscaled while the main curve
            # uses the "nice" scaling - confirm this is intended when the
            # x axis is shared and factor_x != 1.
            ax_field_layout.set_xlim(limits[0], limits[1])
            ax_magnet_layout.set_xlim(limits[0], limits[1])

        add_fieldmaps_to_axes(
            framework_object.framework,
            ax_field_layout,
            bounds=limits,
            include_labels=include_labels,
            fields=fields,
        )
        add_magnets_to_axes(
            framework_object.framework,
            ax_magnet_layout,
            bounds=limits,
            include_labels=include_labels,
            magnets=magnets,
            kinetic_energy=list(
                zip(twiss.stat("z").val[good], twiss.stat("kinetic_energy").val[good])
            ),
        )

    return plt, fig, all_axis
def getattrsplit(self, attr):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting from ``self``.

    Each dot-separated name is looked up with ``getattr`` on the result of
    the previous lookup; the final object is returned.
    """
    target = self
    for part in attr.split("."):
        target = getattr(target, part)
    return target
def general_plot(
    framework_object,
    ykeys=[],
    ykeys2=[],
    xkey="z",
    limits=None,
    nice=True,
    include_layout=False,
    include_labels=True,
    include_legend=True,
    include_particles=False,
    fields=["cavity", "solenoid"],
    magnets=[
        "quadrupole",
        "dipole",
        "beam_position_monitor",
        "screen",
        "wall_current_monitor",
        "aperture",
    ],
    grid=False,
    **kwargs,
):
    """
    Plots arbitrary attributes of ``framework_object.twiss`` against each other.

    Keys are dotted attribute paths resolved with ``getattrsplit``; the last
    path component is used as the label. ``ykeys`` go on the left axis and
    ``ykeys2`` on a twin right axis; either may be a single string.

    Parameters:
        framework_object: object providing ``twiss`` and ``framework``.
        xkey: attribute path used for the x axis.
        limits: optional (min, max) range of the x data. The sequence passed
            in is not modified.
        nice: use an SI prefix and scaling for the axes.
        include_layout: add a field-map layout plot below the main plot.
        include_labels: forwarded to ``add_fieldmaps_to_axes``.
        include_legend: draw a legend (skipped for a single curve).
        include_particles, magnets: accepted but not used by this function.
        fields: field types drawn in the layout plot.
        grid: draw the major grid on the main axes.
        **kwargs: forwarded to ``plt.subplots``.

    Returns:
        tuple ``(plt, fig, all_axis)``.

    Raises:
        AssertionError: if the keys sharing a y axis have different units.
    """
    if include_layout is not False:
        fig, all_axis = plt.subplots(2, gridspec_kw={"height_ratios": [4, 1]}, **kwargs)
        ax_layout = all_axis[-1]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        ax_plot = [all_axis]

    if grid:
        # ``visible`` replaces the former ``b`` keyword of Axes.grid.
        ax_plot[0].grid(visible=True, which="major", color="#666666", linestyle="-")

    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]

    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())

    # Resolve the (possibly dotted) keys to data objects
    xdata = getattrsplit(framework_object.twiss, xkey)
    ydata = [getattrsplit(framework_object.twiss, y) for y in ykeys]
    ydata2 = [getattrsplit(framework_object.twiss, y) for y in ykeys2]

    # Split keys
    xkey = xkey.split(".")[-1]
    ykeys = [yk.split(".")[-1] for yk in ykeys]
    ykeys2 = [yk.split(".")[-1] for yk in ykeys2]

    # No need for a legend if there is only one plot
    if len(ydata) == 1 and not ydata2:
        include_legend = False

    X = xdata

    # Only get the data we need
    if limits:
        # Private, mutable copy: tuples are accepted and the caller's
        # sequence is never modified.
        limits = list(limits)
        good = np.logical_and(X >= limits[0], X <= limits[1])
        idx = list(np.where(good)[0])
        # Add one extra sample on either side so the curve reaches the edges.
        if len(idx) > 0:
            if idx[0] > 0:
                good[idx[0] - 1] = True
            if (idx[-1] + 1) < len(good):
                good[idx[-1] + 1] = True
        X = X[good]
        # Shrink the limits if the data does not reach them.
        if X.min() > limits[0]:
            limits[0] = X.min()
        if X.max() < limits[1]:
            limits[1] = X.max()
    else:
        limits = X.min(), X.max()
        good = slice(None, None, None)  # everything

    # X axis scaling
    units_x = str(xdata.unit)
    if nice:
        X, factor_x, prefix_x = nice_array(X.val)
        units_x = prefix_x + units_x
    else:
        factor_x = 1

    # set all but the layout
    for ax in ax_plot:
        ax.set_xlim(limits[0] / factor_x, limits[1] / factor_x)
        ax.set_xlabel(f"{xkey} ({units_x})")

    # Draw for Y1 and Y2
    linestyles = ["solid", "dashed"]

    ii = -1  # counter for colors
    for ix, (d, keys) in enumerate([[ydata, ykeys], [ydata2, ykeys2]]):
        if not keys:
            continue
        ax = ax_plot[ix]
        linestyle = linestyles[ix]

        # Check that units are compatible
        ulist = [dat.unit for dat in d]
        for u2 in ulist[1:]:
            assert ulist[0] == u2, f"Incompatible units: {ulist[0]} and {u2}"
        # String representation
        unit = str(ulist[0])

        # Data
        data = [key.val[good] for key in d]

        if nice:
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        # Make a line and point
        keys = ["$" + k.replace("sigma", r"\sigma") + "$" for k in keys]
        for key, dat in zip(keys, data):
            ii += 1
            color = "C" + str(ii)
            ax.plot(
                X,
                dat / factor,
                label=f"{key} ({unit})",
                color=color,
                linestyle=linestyle,
            )
        ax.set_ylabel(", ".join(keys) + f" ({unit})")

    # Collect legend
    if include_legend:
        lines = []
        labels = []
        for ax in ax_plot:
            handles, names = ax.get_legend_handles_labels()
            lines += handles
            labels += names
        ax_plot[0].legend(lines, labels, loc="best")

    # Layout
    if include_layout is not False:
        if xkey == "z":
            ax_layout.set_xlim(limits[0], limits[1])

        add_fieldmaps_to_axes(
            framework_object.framework,
            ax_layout,
            bounds=limits,
            include_labels=include_labels,
            fields=fields,
        )

    return plt, fig, all_axis