import os
import sys
from io import StringIO
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.transforms import Bbox
# plt.rcParams["axes.axisbelow"] = False
from copy import copy
try:
    from ..units import nice_array, nice_scale_prefix
except ImportError:
    # The relative import fails when this module is not imported as part of
    # its package. Importing still succeeds; the plotting functions below then
    # raise NameError when they reach nice_array / nice_scale_prefix.
    pass

# Optional density-estimation backends used by plotScreenImage.
try:
    from fastkde import fastKDE

    fastKDE_installed = True
except ImportError:
    print("fastKDE missing - plotScreenImage will use SciPy")
    fastKDE_installed = False

try:
    from scipy import stats

    SciPy_installed = True
except ImportError:
    SciPy_installed = False

# Default colormap for marginal_plot's hexbin; values below the colormap's
# lower limit (vmin) are drawn in white.
CMAP0 = copy(plt.get_cmap("viridis"))
CMAP0.set_under("white")
# NOTE(review): not referenced by the functions in this section; possibly used elsewhere.
CMAP1 = copy(plt.get_cmap("plasma"))
[docs]
def density_plot(
    particle_group,
    key="x",
    bins=None,
    filename=None,
    **kwargs,
):
    """
    1D density plot. Also see: marginal_plot

    Histograms ``getattr(particle_group, key)`` and draws it as grey bars of
    charge per unit of ``key``. Each particle is weighted by the absolute
    value of ``particle_group.charge``; when ``key == "charge"`` every
    particle gets weight 1 instead. If the key's unit is seconds, the y axis
    is labelled as a current (A).

    Parameters
    ----------
    particle_group :
        Object supporting ``len()`` and exposing ``key`` and ``charge``
        attributes. If the ``key`` attribute has a ``units`` attribute it is
        used in the axis labels.
    key : str
        Attribute of ``particle_group`` to histogram.
    bins : int, optional
        Number of histogram bins. Defaults to ``len(particle_group) // 100``,
        but never fewer than 1.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs :
        Forwarded to ``plt.subplots``.

    Example:
        density_plot(P, 'x', bins=100)
    """
    if not bins:
        # Never request 0 bins (np.histogram rejects it); n // 100 is 0 for
        # fewer than 100 particles.
        bins = max(1, len(particle_group) // 100)

    values = getattr(particle_group, key)
    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(values)

    if key != "charge":
        w = abs(particle_group.charge)
    else:
        # NOTE(review): with unit weights the bars are counts, although the
        # y label below still says C.
        w = np.ones(len(values))

    # Use the attribute's unit symbol when it carries one (marginal_plot reads
    # `.units` the same way); otherwise fall back to a unitless label.
    u1 = str(getattr(values, "units", ""))
    ux = p1 + u1

    fig, ax = plt.subplots(**kwargs)

    hist, bin_edges = np.histogram(x, bins=bins, weights=w)
    bin_widths = np.diff(bin_edges)
    bin_centers = bin_edges[:-1] + bin_widths / 2
    # Density per (nicely scaled) unit of `key`, rescaled to a readable range.
    hist_y, hist_f, hist_prefix = nice_array(hist / bin_widths)
    ax.bar(bin_centers, hist_y, bin_widths, color="grey")

    if u1 == "s":
        # Charge per second is a current: label as A, folding the x-axis
        # scale factor into the prefix.
        _, hist_prefix = nice_scale_prefix(hist_f / f1)
        ax.set_ylabel(f"{hist_prefix}A")
    else:
        ax.set_ylabel(f"{hist_prefix}C/{ux}")
    ax.set_xlabel(f"{key} ({ux})")

    if isinstance(filename, str):
        fig.savefig(filename)
[docs]
def slice_plot(
    particle_group,
    xkey="t",
    ykey="slice_current",
    xlim=None,
    nice=True,
    include_legend=True,
    subtract_mean=True,
    bins=None,
    filename=None,
    **kwargs,
):
    """
    Slice plot: per-slice quantities against slice position.
    Also see: marginal_plot

    Parameters
    ----------
    particle_group :
        Object with a ``slice`` attribute. ``slice.slices`` is set to the
        number of bins (a side effect on ``particle_group``). The function
        reads ``slice.slice_<xkey>`` and ``slice.<ykey>``; each ykey attribute
        must have a ``units`` attribute.
    xkey : str
        Suffix of the x attribute, read as ``slice.slice_<xkey>``. The x axis
        is always labelled in seconds.
    ykey : str or list/tuple of str
        Attribute name(s) on ``particle_group.slice`` to plot against x.
    xlim : pair of float, optional
        Keep only slices whose x (after mean subtraction, before nice
        scaling) lies in this closed range. Defaults to the full x range.
    nice : bool
        Rescale x and y to readable values with SI prefixes.
    include_legend : bool
        Draw a legend. It is always omitted when a single ykey is plotted.
    subtract_mean : bool
        Centre x on its mean.
    bins : int, optional
        Number of slices. Defaults to ``len(particle_group) // 100``, but
        never fewer than 1.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs :
        Forwarded to ``plt.subplots``.

    Raises
    ------
    ValueError
        If the ykeys do not all have the same units.

    Example:
        slice_plot(P, ykey='slice_current', bins=100)
    """
    P = particle_group
    fig, ax = plt.subplots(**kwargs)

    if not bins:
        # Never request 0 slices; n // 100 is 0 for fewer than 100 particles.
        bins = max(1, len(P) // 100)
    P.slice.slices = bins

    X = getattr(P.slice, "slice_" + xkey)
    if subtract_mean:
        X = X - np.mean(X)

    # Accept a single key as well as a list/tuple of keys.
    keys = ykey if isinstance(ykey, (list, tuple)) else [ykey]
    if len(keys) == 1:
        include_legend = False

    # Only get the data we need
    if xlim:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]
    else:
        xlim = X.min(), X.max()
        good = slice(None)  # everything

    # X axis scaling; the unit is assumed to be seconds (see docstring).
    units_x = "s"
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x = prefix_x + units_x
    else:
        factor_x = 1
    ax.set_xlim(xlim[0] / factor_x, xlim[1] / factor_x)
    ax.set_xlabel(f"{xkey} ({units_x})")

    if keys:
        # All curves share one y axis, so their units must agree.
        ulist = [getattr(P.slice, key).units for key in keys]
        for u2 in ulist[1:]:
            if u2 != ulist[0]:
                raise ValueError(f"Incompatible units: {ulist[0]} and {u2}")
        unit = str(ulist[0])

        data = [np.array(getattr(P.slice, key)[good]) for key in keys]
        if nice:
            # One common scale for every curve, chosen from their joint range.
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        for i, (key, dat) in enumerate(zip(keys, data)):
            ax.plot(
                X,
                dat / factor,
                label=f"{key} ({unit})",
                color=f"C{i}",
                linestyle="solid",
            )
        ax.set_ylabel(", ".join(keys) + f" ({unit})")

    if include_legend:
        lines, labels = ax.get_legend_handles_labels()
        ax.legend(lines, labels, loc="best")

    if isinstance(filename, str):
        fig.savefig(filename)
[docs]
def _draw_marginal_histogram(
    ax,
    values,
    bins,
    weights,
    total_charge,
    n_particles,
    base_unit,
    unit_label,
    scale_factor,
    horizontal,
):
    """Histogram ``values`` onto one marginal axis of marginal_plot and label it.

    If ``total_charge`` is non-zero, the bars show charge per unit of
    ``values``, with each of the ``n_particles`` assumed to carry an equal
    share. When ``base_unit`` is ``"s"`` this is labelled as a current (A),
    using ``scale_factor`` (the nice-array factor of ``values``) to pick the
    prefix. If the total charge is zero, raw per-bin counts are drawn.
    ``horizontal`` selects ``barh`` (right-hand marginal) instead of ``bar``
    (top marginal).
    """
    hist, bin_edges = np.histogram(values, bins=bins, weights=weights)
    widths = np.diff(bin_edges)
    centers = bin_edges[:-1] + widths / 2
    # bar(x, height, width) and barh(y, width, height) share the same
    # positional layout, so one call serves both orientations.
    draw_bars = ax.barh if horizontal else ax.bar
    set_value_label = ax.set_xlabel if horizontal else ax.set_ylabel

    if abs(total_charge) > 0:
        # Note: the total charge is negated, so a negative total charge gives
        # positive bars.
        hist_y, hist_f, hist_prefix = nice_array(
            -total_charge * hist / widths / n_particles
        )
        draw_bars(centers, hist_y, widths, color="gray")
        if base_unit == "s":
            # Charge per second is a current: label as A, folding the axis
            # scale factor into the prefix.
            _, hist_prefix = nice_scale_prefix(hist_f / scale_factor)
            set_value_label(f"{hist_prefix}A")
        else:
            set_value_label(f"{hist_prefix}C/{unit_label}")
    else:
        hist_y, hist_f, hist_prefix = nice_array(hist)
        draw_bars(centers, hist_y, widths, color="gray")
        set_value_label(f"{hist_prefix}Counts/{unit_label}")


def marginal_plot(
    particle_group,
    key1="t",
    key2="p",
    bins=None,
    units=("", ""),
    scale=(1, 1),
    subtract_mean=(False, False),
    cmap=None,
    limits=None,
    filename=None,
    **kwargs,
):
    """
    Density plot and projections.

    Draws a hexbin density of ``key2`` against ``key1``, with a marginal
    histogram of ``key1`` above it and of ``key2`` to its right.

    Parameters
    ----------
    particle_group :
        Object supporting ``len()`` and exposing ``key1``, ``key2`` and
        ``charge`` attributes. The key attributes must carry a ``units``
        string, and ``np.sum(charge)`` must have a ``.val`` attribute.
    key1, key2 : str
        Attributes for the x and y axes.
    bins : int, optional
        Hexbin grid size and number of histogram bins. Defaults to
        ``int(sqrt(len(particle_group) / 2))``, but never fewer than 1.
    units :
        Unused; kept for backward compatibility.
    scale : number or pair of numbers
        Multipliers applied before choosing nice units. The plotted values
        are divided by them again afterwards.
    subtract_mean : bool or pair of bools
        Centre each quantity on its mean.
    cmap : Colormap, optional
        Colormap for the hexbin. Defaults to CMAP0.
    limits : sequence of 4 numbers, optional
        ``[xmin, xmax, ymin, ymax]``. Passed to ``Axes.axis`` for the joint
        plot, and also used as the ranges of the two marginals.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs :
        Forwarded to ``plt.figure``.

    If the total charge is non-zero, the marginals show charge density, and
    an axis in seconds is shown as a current (A). If the total charge is
    zero, the marginals show raw per-bin counts.

    Example:
        marginal_plot(P, 't', 'energy', bins=200)
    """
    if not bins:
        # hexbin and np.histogram both need at least one bin.
        bins = max(1, int(np.sqrt(len(particle_group) / 2)))
    cmap = CMAP0 if cmap is None else cmap
    if not isinstance(subtract_mean, (list, tuple)):
        subtract_mean = [subtract_mean, subtract_mean]
    if not isinstance(scale, (list, tuple)):
        scale = [scale, scale]

    raw_x = getattr(particle_group, key1)
    raw_y = getattr(particle_group, key2)
    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(scale[0] * (raw_x - subtract_mean[0] * np.mean(raw_x)))
    y, f2, p2 = nice_array(scale[1] * (raw_y - subtract_mean[1] * np.mean(raw_y)))
    x = x / scale[0]
    y = y / scale[1]
    w = np.full(len(x), 1)  # every particle weighted equally

    charge = particle_group.charge
    total_charge = np.sum(charge).val
    u1, u2 = raw_x.units, raw_y.units
    ux = p1 + u1
    uy = p2 + u2

    fig = plt.figure(**kwargs)
    gs = GridSpec(4, 4)
    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    ax_marg_x = fig.add_subplot(gs[0, 0:3])
    ax_marg_y = fig.add_subplot(gs[1:4, 3])

    # Joint density: unit weights summed per cell.
    ax_joint.hexbin(
        x, y, C=w, reduce_C_function=np.sum, gridsize=bins, cmap=cmap, vmin=1e-20
    )
    if limits is not None:
        ax_joint.axis(limits)

    # Top marginal is over x (labelled in x units); the right-hand one is over y.
    _draw_marginal_histogram(
        ax_marg_x, x, bins, w, total_charge, len(charge), u1, ux, f1, horizontal=False
    )
    _draw_marginal_histogram(
        ax_marg_y, y, bins, w, total_charge, len(charge), u2, uy, f2, horizontal=True
    )
    if limits is not None:
        ax_marg_x.set_xlim(limits[0:2])
        ax_marg_y.set_ylim(limits[2:])

    # Turn off tick labels on marginals
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    ax_joint.set_xlabel(f"{key1} ({ux})")
    ax_joint.set_ylabel(f"{key2} ({uy})")

    if isinstance(filename, str):
        fig.savefig(filename)
[docs]
def plot(self, keys=None, bins=None, type="density", **kwargs):
    """
    Dispatch to density_plot, slice_plot or marginal_plot based on ``keys``.

    A single key (a ``str``, or a list/tuple of length 1) produces a 1D plot:

    - slice_plot, when ``type == "slice"`` or the key contains ``"slice_"``;
    - otherwise density_plot, when ``type == "density"``.

    Two keys produce ``marginal_plot(self, key1=keys[0], key2=keys[1])``.
    ``bins`` and ``**kwargs`` are forwarded to the chosen function, and its
    return value is returned.

    Raises
    ------
    TypeError
        If ``keys`` is None.
    ValueError
        If a single key is given with an unrecognised ``type``, or if ``keys``
        does not hold exactly two keys for a marginal plot.
    """
    if keys is None:
        raise TypeError("plot() requires `keys`: a single key or a pair of keys")

    if isinstance(keys, str) or (isinstance(keys, (list, tuple)) and len(keys) == 1):
        # Normalise to the single key name; a bare string is the key itself.
        ykey = keys if isinstance(keys, str) else keys[0]
        if type == "slice" or "slice_" in ykey:
            return slice_plot(self, ykey=ykey, bins=bins, **kwargs)
        if type == "density":
            return density_plot(self, key=ykey, bins=bins, **kwargs)
        raise ValueError(f"Unknown plot type {type!r} for a single key")

    xkey, ykey = keys
    return marginal_plot(self, key1=xkey, key2=ykey, bins=bins, **kwargs)
[docs]
def _estimate_pdf(x, y, use_scipy, kde_kwargs):
    """Estimate a 2D density of (x, y) on a grid; return ``(pdf, v1, v2)``.

    ``v1``/``v2`` are 1D x/y grid coordinates. Uses fastKDE (forwarding
    ``kde_kwargs``) unless it is missing or ``use_scipy`` is set; otherwise
    uses a 100x100 ``scipy.stats.gaussian_kde`` grid whose pdf has rows
    following y, shape ``(len(v2), len(v1))``. That layout is what
    ``pcolormesh(v1, v2, pdf)`` and the marginal projections in
    plotScreenImage rely on.

    Raises
    ------
    Exception
        If neither fastKDE nor SciPy is available.
    """
    if fastKDE_installed and not use_scipy:
        pdf, (v1, v2) = fastKDE.pdf(x, y, use_xarray=False, **kde_kwargs)
        return pdf, v1, v2
    if SciPy_installed:
        v1 = np.linspace(x.min(), x.max(), 100)
        v2 = np.linspace(y.min(), y.max(), 100)
        grid_x, grid_y = np.meshgrid(v1, v2)  # both (len(v2), len(v1))
        kernel = stats.gaussian_kde(np.vstack([x, y]))
        positions = np.vstack([grid_x.ravel(), grid_y.ravel()])
        return kernel(positions).reshape(grid_x.shape), v1, v2
    raise Exception("fastKDE or SciPy required")


def plotScreenImage(
    beam,
    keys=("x", "y"),
    scale=(1, 1),
    iscale=1,
    colormap=plt.cm.jet,
    size=None,
    grid=False,
    marginals=False,
    limits=None,
    screen=False,
    use_scipy=False,
    subtract_mean=(False, False),
    title="",
    filename=None,
    **kwargs,
):
    """
    Smoothed 2D density image of ``keys[1]`` against ``keys[0]`` for a beam.

    The coordinates are converted to nice units and a 2D PDF is estimated,
    using fastKDE or ``scipy.stats.gaussian_kde`` when fastKDE is missing or
    ``use_scipy`` is set. The PDF is normalised so its peak equals
    ``iscale`` and drawn with ``pcolormesh``.

    Parameters
    ----------
    beam :
        Object exposing the two key attributes, each with a ``units``
        attribute. When ``title`` is empty it must also have a ``filename``.
    keys : pair of str
        Attributes for the x and y axes.
    scale : number or pair of numbers
        Multipliers applied to each coordinate before nice scaling.
    iscale : float
        Peak value of the normalised PDF.
    colormap : Colormap
        Colormap for the image, the marginals and the background.
    size : number or pair of numbers, optional
        Half-width of the plotted region about its centre. Each value is
        divided by that axis's nice-scaling factor before use. The caller's
        object is never modified.
    grid : bool
        Draw major/minor grid lines, clipped to a circle about the centre.
    marginals : bool
        Add projected-PDF histograms above and to the right of the image.
    limits : number, (min, max) or ((xmin, xmax), (ymin, ymax)), optional
        Explicit axis limits. A single number ``l`` means ``(-l, l)`` on both
        axes.
    screen : bool
        Draw as a circular screen: black background, a white-edged disc of
        radius ``max(size) + 0.05``, and the image clipped to radius
        ``max(size)``.
    use_scipy : bool
        Use the SciPy KDE even if fastKDE is installed.
    subtract_mean : bool or pair of bools
        Centre each coordinate on its mean.
    title : str
        Figure title. Defaults to the base name of ``beam.filename`` without
        its extension.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs :
        Forwarded to ``fastKDE.pdf`` (ignored by the SciPy path).

    Raises
    ------
    Exception
        If neither fastKDE nor SciPy is available.
    """
    key1, key2 = keys
    if not isinstance(subtract_mean, (list, tuple)):
        subtract_mean = [subtract_mean, subtract_mean]
    if not isinstance(scale, (list, tuple)):
        scale = [scale, scale]
    if not isinstance(size, (list, tuple)):
        size = [size, size]

    raw_x = getattr(beam, key1)
    raw_y = getattr(beam, key2)
    x, f1, p1 = nice_array(scale[0] * (raw_x - subtract_mean[0] * np.mean(raw_x)))
    y, f2, p2 = nice_array(scale[1] * (raw_y - subtract_mean[1] * np.mean(raw_y)))
    labelx = f"{key1} ({p1 + raw_x.units})"
    labely = f"{key2} ({p2 + raw_y.units})"

    # Density estimate, normalised so that its peak equals iscale.
    myPDF, v1, v2 = _estimate_pdf(x, y, use_scipy, kwargs)
    myPDF = myPDF / myPDF.max() * iscale

    if marginals:
        # Square figure: main axes plus marginal axes at a 2:8 size ratio.
        fig = plt.figure(figsize=(12.41, 12.41))
        gs = fig.add_gridspec(
            2,
            2,
            width_ratios=(8, 2),
            height_ratios=(2, 8),
            left=0.1,
            right=0.9,
            bottom=0.1,
            top=0.95,
            wspace=0.05,
            hspace=0.05,
        )
        ax = fig.add_subplot(gs[1, 0])
        ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
        ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)
    else:
        fig = plt.figure(figsize=(10, 10))
        fig.subplots_adjust(top=0.95)
        ax = fig.add_subplot()

    # Plot centre and per-axis extent, both in nice units.
    if size[0] is None:
        use_size = False
        if screen:
            xmin, xmax, ymin, ymax = -15, 15, -15, 15
            size = [15, 15]
        else:
            xmin, xmax = np.min(v1), np.max(v1)
            ymin, ymax = np.min(v2), np.max(v2)
            size = [xmax - xmin, ymax - ymin]
        meanvalx = 0 if subtract_mean[0] else (xmin + xmax) / 2.0
        meanvaly = 0 if subtract_mean[1] else (ymin + ymax) / 2.0
    else:
        use_size = True
        meanvalx = 0 if subtract_mean[0] else (np.max(v1) + np.min(v1)) / 2.0
        meanvaly = 0 if subtract_mean[1] else (np.max(v2) + np.min(v2)) / 2.0
        # Build a new list so the caller's list (or tuple) is left untouched.
        size = [size[0] / f1, size[1] / f2]

    if marginals:
        # Project the PDF onto each axis. Grid coordinates serve as bin edges,
        # so the last projected value is dropped to give one value per bin.
        for sum_axis, edges, draw_bars in (
            (0, v1, ax_histx.bar),
            (1, v2, ax_histy.barh),
        ):
            hist = myPDF.sum(axis=sum_axis)[:-1]
            widths = np.diff(edges)
            centers = edges[:-1] + widths / 2
            hist_y, _, _ = nice_array(hist / widths)
            draw_bars(
                centers, hist_y, widths, color=colormap(hist_y / np.max(hist_y))
            )

    if screen:
        # Disc marking the screen edge, drawn beneath the image (zorder=-1).
        ax.add_artist(
            plt.Circle(
                (meanvalx, meanvaly),
                max(size) + 0.05,
                fill=True,
                ec="w",
                fc=colormap(0),
                zorder=-1,
            )
        )
        ax.set_facecolor("k")
    else:
        ax.set_facecolor(colormap(0))

    # Circle used to clip the grid and, for screens, the image. Giving it the
    # data transform makes it clip in data coordinates; the patch is never
    # added to the Axes, so it would not otherwise get one.
    clip_radius = max(size) if screen else 3 * max(size)
    circ = plt.Circle(
        (meanvalx, meanvaly), clip_radius, facecolor="none", transform=ax.transData
    )

    if grid:
        ax.grid(which="minor", color="w", alpha=0.3, clip_path=circ)
        ax.grid(which="major", color="w", alpha=0.55, clip_path=circ)

    # Axis limits: explicit limits first, then the screen/size extent (with a
    # 0.5 margin), otherwise the full KDE grid.
    if limits:
        if isinstance(limits, (int, float)):
            limits = (-limits, limits)
        limits_shape = np.shape(limits)
        if limits_shape == (2, 2):
            ax.set_xlim(limits[0])
            ax.set_ylim(limits[1])
        elif limits_shape == (2,):
            ax.set_xlim(limits)
            ax.set_ylim(limits)
    elif screen or use_size:
        half_x = size[0] + 0.5
        half_y = size[1] + 0.5
        ax.set_xlim([meanvalx - half_x, meanvalx + half_x])
        ax.set_ylim([meanvaly - half_y, meanvaly + half_y])
    else:
        ax.set_xlim([np.min(v1), np.max(v1)])
        ax.set_ylim([np.min(v2), np.max(v2)])

    mesh = ax.pcolormesh(v1, v2, myPDF, cmap=colormap, zorder=1, shading="auto")
    if screen:
        mesh.set_clip_path(circ)

    if marginals:
        plt.setp(ax_histx.get_xticklabels(), visible=False)
        plt.setp(ax_histy.get_yticklabels(), visible=False)
    ax.set_xlabel(labelx)
    ax.set_ylabel(labely)

    if title == "":
        # Default title: the beam's file name without directory or extension.
        title = os.path.splitext(os.path.basename(beam.filename))[0]
    fig.suptitle(title)

    if isinstance(filename, str):
        fig.savefig(filename)