ART API

User-facing package, contains mostly visualisation and analysis tools.

The API may be updated without warning

API stuff TODO

1 - ModuleAnalysis

Adds various analysis methods.

Contains the following functions:

- GetETransmission: Calculates the energy transmission from RayListIn to RayListOut in percent.
- GetResultSummary: Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation.
- GetPulseProfile: Retrieve the pulse profile from the given RayList and Detector.
- GetNumericalAperture: Returns the numerical aperture associated with the supplied ray-bundle.
- GetAiryRadius: Returns the radius of the Airy disk.
- GetPlaneWaveFocus: Calculates the approximate polychromatic focal spot of a set of rays.
- GetDiffractionFocus: Calculates the approximate polychromatic focal spot of a set of rays.
- GetClosestSphere: Calculates the closest sphere to the surface of a mirror.
- GetAsphericity: Calculates the maximum distance of the mirror surface to the closest sphere.
- BestFitSphere: Calculates the best sphere to fit a set of points.

Adds the following methods:

- OpticalChain:
    - getETransmission: Calculates the energy transmission from the input to the output of the OpticalChain.
    - getResultsSummary: Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation.
    - getPulseProfile: Retrieve the pulse profile from the output of the OpticalChain.
    - getPlaneWaveFocus: Calculates the approximate polychromatic focal spot of the output of the OpticalChain.
    - getDiffractionFocus: Calculates the approximate polychromatic focal spot of the output of the OpticalChain.
- Mirror:
    - getClosestSphere: Calculates the closest sphere to the surface of a mirror.
    - getAsphericity: Calculates the maximum distance of the mirror surface to the closest sphere.
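The methods listed above are added to existing classes by assigning module-level functions to class attributes. A minimal sketch of that pattern follows. The `Mirror` class, its `radius_mm` attribute and the `getCurvature` method are hypothetical stand-ins for illustration and are not part of ART.

```python
class Mirror:
    """Hypothetical stand-in for a mirror class; it only stores a radius of curvature in mm."""

    def __init__(self, radius_mm):
        self.radius_mm = radius_mm


def _get_curvature(mirror):
    """Curvature in 1/mm. The first argument plays the role of `self`."""
    return 1.0 / mirror.radius_mm


existing = Mirror(200.0)  # created before the method is attached

# Assigning the function to a class attribute turns it into a method of every
# instance, including ones that already exist.
Mirror.getCurvature = _get_curvature

print(existing.getCurvature())
```

Because the assignment runs at import time, the extra methods only exist once the analysis module has been imported.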

Created in July 2024

@author: André Kalouguine + Stefan Haessler + Anthony Guillaume

  1"""
  2Adds various analysis methods.
  3
  4Contains the following functions:
  5    - GetETransmission: Calculates the energy transmission from RayListIn to RayListOut in percent.
  6    - GetResultSummary: Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation.
  7    - GetPulseProfile: Retrieve the pulse profile from the given RayList and Detector.
  8    - GetNumericalAperture: Returns the numerical aperture associated with the supplied ray-bundle.
  9    - GetAiryRadius: Returns the radius of the Airy disk.
 10    - GetPlaneWaveFocus: Calculates the approximate polychromatic focal spot of a set of rays.
 11    - GetDiffractionFocus: Calculates the approximate polychromatic focal spot of a set of rays.
 12    - GetClosestSphere: Calculates the closest sphere to the surface of a mirror.
 13    - GetAsphericity: Calculates the maximum distance of the mirror surface to the closest sphere.
 14    - BestFitSphere: Calculates the best sphere to fit a set of points.
 15
 16Adds the following methods:
 17    - OpticalChain:
 18        - getETransmission: Calculates the energy transmission from the input to the output of the OpticalChain.
 19        - getResultsSummary: Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation.
 20        - getPulseProfile: Retrieve the pulse profile from the output of the OpticalChain.
 21        - getPlaneWaveFocus: Calculates the approximate polychromatic focal spot of the output of the OpticalChain.
 22        - getDiffractionFocus: Calculates the approximate polychromatic focal spot of the output of the OpticalChain.
 23    - Mirror:
 24        - getClosestSphere: Calculates the closest sphere to the surface of a mirror.
 25        - getAsphericity: Calculates the maximum distance of the mirror surface to the closest sphere.
 26
 27Created in July 2024
 28
 29@author: André Kalouguine + Stefan Haessler + Anthony Guillaume
 30"""
 31# %% Module imports
 32import ARTcore.ModuleGeometry as mgeo
 33import ARTcore.ModuleOpticalRay as mray
 34import ARTcore.ModuleProcessing as mp
 35import ARTcore.ModuleOpticalChain as moc
 36import ARTcore.ModuleMirror as mmirror
 37import ART.ModulePlottingMethods as mpm
 38import matplotlib.pyplot as plt
 39from mpl_toolkits.axes_grid1 import make_axes_locatable
 40import numpy as np
 41import math
 42
 43LightSpeed = 299792458000  # speed of light in mm/s
 44
 45# %% Analysis methods
 46def GetETransmission(RayListIn, RayListOut) -> float:
 47    """
 48    Calculates the energy transmission from RayListIn to RayListOut in percent by summing up the
 49    intensity-property of the individual rays.
 50
 51    Parameters
 52    ----------
 53        RayListIn : list(Ray)
 54            List of incoming rays.
 55        
 56        RayListOut : list(Ray)
 57            List of outgoing rays.
 58
 59    Returns
 60    -------
 61        ETransmission : float
 62    """
 63    ETransmission = 100 * sum(Ray.intensity for Ray in RayListOut) / sum(Ray.intensity for Ray in RayListIn)
 64    return ETransmission
 65
 66def _getETransmission(OpticalChain, IndexIn=0, IndexOut=-1) -> float:
 67    """
 68    Calculates the energy transmission from the input to the output of the OpticalChain in percent.
 69
 70    Parameters
 71    ----------
 72        OpticalChain : OpticalChain
 73            An object of the ModuleOpticalChain.OpticalChain-class.
 74
 75        IndexIn : int, optional
 76            Index of the input RayList in the OpticalChain, defaults to 0.
 77
 78        IndexOut : int, optional
 79            Index of the output RayList in the OpticalChain, defaults to -1.
 80
 81    Returns
 82    -------
 83        ETransmission : float
 84    """
 85    Rays = OpticalChain.get_output_rays()
 86    if IndexIn == 0:
 87        RayListIn = OpticalChain.get_input_rays()
 88    else:
 89        RayListIn = Rays[IndexIn]
 90    RayListOut = Rays[IndexOut]
 91    ETransmission = GetETransmission(RayListIn, RayListOut)
 92    return ETransmission
 93
 94moc.OpticalChain.getETransmission = _getETransmission
 95
 96def GetResultSummary(Detector, RayListAnalysed, verbose=False):
 97    """
 98    Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation
 99    for the given Detector and RayList.
100    If verbose, then also print a summary of the results for the given Detector.
101
102    Parameters
103    ----------
104        Detector : Detector
105            An object of the ModuleDetector.Detector-class.
106
107        RayListAnalysed : list(Ray)
108            List of objects of the ModuleOpticalRay.Ray-class.
109
110        verbose : bool
111            Whether to print a result summary.
112
113    Returns
114    -------
115        FocalSpotSizeSD : float
116
117        DurationSD : float
118    """
119    DetectorPointList2DCentre = Detector.get_2D_points(RayListAnalysed)
120    FocalSpotSizeSD = mp.StandardDeviation(DetectorPointList2DCentre)
121    DelayList = Detector.get_Delays(RayListAnalysed)
122    DurationSD = mp.StandardDeviation(DelayList)
123
124    if verbose:
125        FocalSpotSize = mgeo.DiameterPointList(DetectorPointList2DCentre)
126        summarystring = (
127            "At the detector distance of "
128            + "{:.3f}".format(Detector.get_distance())
129            + " mm we get:\n"
130            + "Spatial std : "
131            + "{:.3f}".format(FocalSpotSizeSD * 1e3)
132            + " \u03BCm and min-max: "
133            + "{:.3f}".format(FocalSpotSize * 1e3)
134            + " \u03BCm\n"
135            + "Temporal std : "
136            + "{:.3e}".format(DurationSD)
137            + " fs and min-max : "
138            + "{:.3e}".format(max(DelayList) - min(DelayList))
139            + " fs"
140        )
141
142        print(summarystring)
143
144    return FocalSpotSizeSD, DurationSD
145
146def _getResultsSummary(OpticalChain, Detector = "Focus", verbose=False):
147    """
148    Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation
149    for the given Detector and RayList.
150    If verbose, then also print a summary of the results for the given Detector.
151
152    Parameters
153    ----------
154        OpticalChain : OpticalChain
155            An object of the ModuleOpticalChain.OpticalChain-class.
156
157        Detector : Detector or str, optional
158            An object of the ModuleDetector.Detector-class or "Focus" to use the focus detector, defaults to "Focus".
159
160        verbose : bool
161            Whether to print a result summary.
162
163    Returns
164    -------
165        FocalSpotSizeSD : float
166
167        DurationSD : float
168    """
169    if isinstance(Detector, str):
170        Detector = OpticalChain.detectors[Detector]
171    Index = Detector.index
172    RayListAnalysed = OpticalChain.get_output_rays()[Index]
173    FocalSpotSizeSD, DurationSD = GetResultSummary(Detector, RayListAnalysed, verbose)
174    return FocalSpotSizeSD, DurationSD
175
176moc.OpticalChain.getResultsSummary = _getResultsSummary
177
178def GetPulseProfile(Detector, RayList, Nbins=100):
179    """
180    Retrieve the pulse profile from the given RayList and Detector.
181    The pulse profile is calculated by binning the delays of the rays in the RayList
182    at the Detector. The intensity of the rays is also taken into account.
183
184    Parameters
185    ----------
186        Detector : Detector
187            An object of the ModuleDetector.Detector-class.
188
189        RayList : list(Ray)
190            List of objects of the ModuleOpticalRay.Ray-class.
191
192        Nbins : int
193            Number of bins to use for the histogram.
194
195    Returns
196    -------
197        time : numpy.ndarray
198            The bins of the histogram.
199
200        intensity : numpy.ndarray
201            The histogram values.
202    """
203    DelayList = Detector.get_Delays(RayList)
204    IntensityList = [Ray.intensity for Ray in RayList]
205    intensity, time_edges = np.histogram(DelayList, bins=Nbins, weights=IntensityList)
206    time = 0.5 * (time_edges[1:] + time_edges[:-1])
207    return time, intensity
208
209def _getPulseProfile(OpticalChain, Detector = "Focus", Nbins=100):
210    """
211    Retrieve the pulse profile from the output of the OpticalChain.
212    The pulse profile is calculated by binning the delays of the rays in the RayList
213    at the Detector. The intensity of the rays is also taken into account.
214
215    Parameters
216    ----------
217        OpticalChain : OpticalChain
218            An object of the ModuleOpticalChain.OpticalChain-class.
219
220        Detector : Detector or str, optional
221            An object of the ModuleDetector.Detector-class or "Focus" to use the focus detector, defaults to "Focus".
222
223        Nbins : int
224            Number of bins to use for the histogram.
225
226    Returns
227    -------
228        time : numpy.ndarray
229            The bins of the histogram.
230
231        intensity : numpy.ndarray
232            The histogram values.
233    """
234    if isinstance(Detector, str):
235        Detector = OpticalChain.detectors[Detector]
236    Index = Detector.index
237    RayList = OpticalChain.get_output_rays()[Index]
238    time, intensity = GetPulseProfile(Detector, RayList, Nbins)
239    return time, intensity
240
241moc.OpticalChain.getPulseProfile = _getPulseProfile
242def GetNumericalAperture(RayList: list[mray.Ray], RefractiveIndex: float = 1) -> float:
243    r"""
244    Returns the numerical aperture associated with the supplied ray-bundle 'RayList'.
245    This is $n\sin\theta$, where $\theta$ is the maximum angle between any of the rays and the central ray,
246    and $n$ is the refractive index of the propagation medium.
247
248    Parameters
249    ----------
250        RayList : list of Ray-object
251            The ray-bundle of which to determine the NA.
252
253        RefractiveIndex : float, optional
254            Refractive index of the propagation medium, defaults to 1.
255
256    Returns
257    -------
258        NA : float
259    """
260    CentralRay = mp.FindCentralRay(RayList)
261    if CentralRay is None:
262        CentralVector = np.array([0, 0, 0])
263        for k in RayList:
264            CentralVector = CentralVector + k.vector
265        CentralVector = CentralVector / len(RayList)
266    else:
267        CentralVector = CentralRay.vector
268    ListAngleAperture = []
269    for k in RayList:
270        ListAngleAperture.append(mgeo.AngleBetweenTwoVectors(CentralVector, k.vector))
271
272    return np.sin(np.amax(ListAngleAperture)) * RefractiveIndex
273
274
275def GetAiryRadius(Wavelength: float, NumericalAperture: float) -> float:
276    r"""
277    Returns the radius of the Airy disk: $r = 1.22 \frac{\lambda}{2 NA}$,
278    i.e. the diffraction-limited radius of the focal spot corresponding to a given
279    numerical aperture $NA$ and a light wavelength $\lambda$.
280
281    For very small $NA \leq 10^{-3}$, diffraction effects become negligible and the Airy radius becomes meaningless,
282    so in that case, a radius of 0 is returned.
283
284    Parameters
285    ----------
286        Wavelength : float
287            Light wavelength in mm.
288
289        NumericalAperture : float
290
291    Returns
292    -------
293        AiryRadius : float
294    """
295    if NumericalAperture > 1e-3 and Wavelength is not None:
296        return 1.22 * 0.5 * Wavelength / NumericalAperture
297    else:
298        return 0  # for very small numerical apertures, diffraction effects become negligible and the Airy radius becomes meaningless
299
300
301def GetPlaneWaveFocus(OpticalChain, Detector = "Focus", size=None, Nrays=1000, resolution=100):
302    """
303    This function calculates the approximate polychromatic focal spot of a set of rays.
304    To do so, it positions itself in the detector plane and samples a square area.
305    If the size of the area is not given, it will use three times the Airy radius of the system with the largest wavelength.
306    Otherwise it uses the size.
307    The resolution of the sampling is given by the resolution parameter. So it samples on a grid resolution x resolution.
308
309    To calculate the intensity, it takes Nrays out of the raylist (to subsample if needed).
310    It assimilates every ray to a plane wave, so it calculates a k-vector for each ray (taking into account the wavelength of the ray).
311    It calculates the phase of each ray from the delay and from the intersection position with the detector.
312    The delay from the non-central position is simply sin(alpha)*distance/c, where alpha is the angle between the ray and the normal to the detector
313    and distance is the distance between the intersection point and the central point of the detector. 
314    It then calculates the intensity at each point of the grid by summing the intensity of each plane wave.
315
316    As long as there are not too many rays, this method is faster than doing an FFT.
317    It's also naturally polychromatic.
318    """
319    if isinstance(Detector, str):
320        Detector = OpticalChain.detectors[Detector]
321    Index = Detector.index
322    RayList = OpticalChain.get_output_rays()[Index]
323    if size is None:
324        Wavelengths = [Ray.wavelength for Ray in RayList]
325        Wavelength = max(Wavelengths)
326        NumericalAperture = GetNumericalAperture(RayList)
327        size = 3 * GetAiryRadius(Wavelength, NumericalAperture)
328    X = np.linspace(-size / 2, size / 2, resolution)
329    Y = np.linspace(-size / 2, size / 2, resolution)
330    X, Y = np.meshgrid(X, Y)
331    # We pick Nrays. if there are less than Nrays, we take all of them.
332    PickedRaysGlobal = np.random.choice(RayList, min(Nrays, len(RayList)), replace=False)
333    # We calculate the k-vector of each ray, taking into account the wavelength of the ray.
334    # The units should be in mm^-1
335    PickedRays = [Ray.to_basis(*Detector.basis) for Ray in PickedRaysGlobal]
336    # The rays are now in the reference plane of the detector whose normal is [0,0,1]
337    wavelengths = np.array([Ray.wavelength for Ray in PickedRays])
338    vectors = np.array([Ray.vector for Ray in PickedRays])
339    frequencies = np.array([LightSpeed / Ray.wavelength for Ray in PickedRays]) # mm/s / mm = 1/s
340    k_vectors = 2*np.pi * vectors / wavelengths[:, np.newaxis]
341    angles = np.arccos(np.clip(np.dot(vectors, np.array([0, 0, 1])), -1.0, 1.0))
342    # We calculate the intersection of the rays with the detector plane
343    Intersections = mgeo.IntersectionRayListZPlane(PickedRays)[0]
344    distances = (Intersections - mgeo.Origin[:2]).norm
345    # We calculate the phase of each ray
346    PathDelays = np.array(Detector.get_Delays(PickedRaysGlobal))
347    PositionDelays = np.sum(np.array(Intersections._add_dimension()-mgeo.Origin)*vectors, axis=1)  / LightSpeed * 1e15
348    Delays = (PathDelays+PositionDelays) /  1e15
349    # The delays are now in seconds (converted from fs on the previous line); the phase of each ray is 2*pi*frequency*delay
350    Phases = np.array([2 * np.pi * frequencies[i] * Delays[i] for i in range(len(Delays))])
351    # We also need the intensity of each ray
352    RayIntensities = np.array([Ray.intensity for Ray in PickedRays])
353    # We can now calculate the intensity at each point of the grid
354    Intensity = np.zeros((resolution, resolution))
355    for i in range(len(PickedRays)):
356        Intensity += RayIntensities[i] * np.cos(k_vectors[i][0] * X + k_vectors[i][1] * Y + Phases[i])
357
358    return X, Y, Intensity
359
360moc.OpticalChain.getPlaneWaveFocus = GetPlaneWaveFocus
361
362def GetDiffractionFocus(OpticalChain, Detector = "Focus", size=None, Nrays=1000, resolution=100):
363    """
364    This function calculates the approximate polychromatic focal spot of a set of rays.
365    To do so, it positions itself in the detector plane and samples a square area.
366    If the size of the area is not given, it will use three times the Airy radius of the system with the largest wavelength.
367    Otherwise it uses the size.
368    The resolution of the sampling is given by the resolution parameter. So it samples on a grid resolution x resolution.
369
370    To calculate the intensity, it takes Nrays out of the raylist (to subsample if needed).
371    It assimilates every ray to a plane wave, so it calculates a k-vector for each ray (taking into account the wavelength of the ray).
372    It considers that all the rays are intersecting the detector in the middle and doesn't take into account their phase.
373
374    So it returns the best case scenario for a diffraction limited focus with that numerical aperture
375    """
376    if isinstance(Detector, str):
377        Detector = OpticalChain.detectors[Detector]
378    Index = Detector.index
379    RayList = OpticalChain.get_output_rays()[Index]
380    if size is None:
381        Wavelengths = [Ray.wavelength for Ray in RayList]
382        Wavelength = max(Wavelengths)
383        NumericalAperture = GetNumericalAperture(RayList)
384        size = 3 * GetAiryRadius(Wavelength, NumericalAperture)
385    X = np.linspace(-size / 2, size / 2, resolution)
386    Y = np.linspace(-size / 2, size / 2, resolution)
387    X, Y = np.meshgrid(X, Y)
388    # We pick Nrays. if there are less than Nrays, we take all of them.
389    PickedRaysGlobal = np.random.choice(RayList, min(Nrays, len(RayList)), replace=False)
390    # We calculate the k-vector of each ray, taking into account the wavelength of the ray.
391    # The units should be in mm^-1
392    PickedRays = [Ray.to_basis(*Detector.basis) for Ray in PickedRaysGlobal]
393    # The rays are now in the reference plane of the detector whose normal is [0,0,1]
394    wavelengths = np.array([Ray.wavelength for Ray in PickedRays])
395    vectors = np.array([Ray.vector for Ray in PickedRays])
396    frequencies = np.array([LightSpeed / Ray.wavelength for Ray in PickedRays]) # mm/s / mm = 1/s
397    k_vectors = 2*np.pi * vectors / wavelengths[:, np.newaxis]
398    angles = np.arccos(np.clip(np.dot(vectors, np.array([0, 0, 1])), -1.0, 1.0))
399    # We calculate the intersection of the rays with the detector plane
400    Intersections = mgeo.IntersectionRayListZPlane(PickedRays)[0]
401    distances = (Intersections - mgeo.Origin[:2]).norm
402    # We also need the intensity of each ray
403    RayIntensities = np.array([Ray.intensity for Ray in PickedRays])
404    # We can now calculate the intensity at each point of the grid
405    Intensity = np.zeros((resolution, resolution))
406    for i in range(len(PickedRays)):
407        Intensity += RayIntensities[i] * np.cos(k_vectors[i][0] * X + k_vectors[i][1] * Y)
408
409    return X, Y, Intensity
410
411moc.OpticalChain.getDiffractionFocus = GetDiffractionFocus
412
413# %% Asphericity analysis
414# This code performs asphericity analysis of various surfaces.
415# It calculates the closest sphere to the surface of an optical element.
416# Then there are two functions: one that simply gives an asphericity value
417# and another one that actually plots the distance to the closest sphere
418# in much the same way as we plot the MirrorProjection
419
420def BestFitSphere(X,Y,Z):
421    """
422    This function calculates the best sphere to fit a set of points.
423    It uses the least square method to find the center and the radius of the sphere.
424    Cite https://jekel.me/2015/Least-Squares-Sphere-Fit/
425    """
426    A = np.zeros((len(X),4))
427    A[:,0] = X*2
428    A[:,1] = Y*2
429    A[:,2] = Z*2
430    A[:,3] = 1
431
432    #   Assemble the f matrix
433    f = np.zeros((len(X),1))
434    f[:,0] = X**2 + Y**2 + Z**2
435    C, residuals, rank, singval = np.linalg.lstsq(A, f, rcond=None)
436
437    #   solve for the radius (C has shape (4, 1), so t is a one-element array)
438    t = (C[0]*C[0])+(C[1]*C[1])+(C[2]*C[2])+C[3]
439    radius = math.sqrt(t[0])
440
441    return mgeo.Point(C[:3].flatten()),radius
442
443
444def GetClosestSphere(Mirror, Npoints=1000):
445    """
446    This function calculates the closest sphere to the surface of a mirror.
447    It does so by sampling the surface of the mirror at Npoints points.
448    It then calculates the closest sphere to these points.
449    It returns the radius of the sphere and the center of the sphere.
450    """
451    Points = mpm.sample_support(Mirror.support, Npoints=Npoints)
452    Points += Mirror.r0[:2]
453    Z = Mirror._zfunc(Points)
454    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
455    spX, spY, spZ = Points[:, 0], Points[:, 1], Points[:, 2]
456    Center, Radius = BestFitSphere(spX, spY, spZ)
457    return Center, Radius
458
459def GetAsphericity(Mirror, Npoints=1000):
460    """
461    This function calculates the maximum distance of the mirror surface to the closest sphere. 
462    """
463    center, radius = GetClosestSphere(Mirror, Npoints)
464    Points = mpm.sample_support(Mirror.support, Npoints=Npoints)
465    Points += Mirror.r0[:2]
466    Z = Mirror._zfunc(Points)
467    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
468    Points_centered = Points - center
469    Distance = np.linalg.norm(Points_centered, axis=1) - radius
470    Distance*=1e3 # To convert to µm
471    return np.ptp(Distance)
472
473mmirror.Mirror.getClosestSphere = GetClosestSphere
474mmirror.Mirror.getAsphericity = GetAsphericity
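The asphericity analysis at the end of the listing above fits a sphere to sampled surface points by linear least squares, then measures how far the samples deviate from that sphere. The following is a self-contained sketch of the same idea using only NumPy. The paraboloid test surface, the function name `best_fit_sphere` and the sampling scheme are assumptions made for illustration; they do not use the `Mirror` or `mgeo` classes.

```python
import numpy as np


def best_fit_sphere(points):
    """Least-squares sphere through an (N, 3) array; returns (center, radius)."""
    A = np.hstack([2 * points, np.ones((len(points), 1))])
    f = np.sum(points**2, axis=1)
    # Solves |p|^2 = 2 p.c + (r^2 - |c|^2) for the center c and the constant term.
    C, *_ = np.linalg.lstsq(A, f, rcond=None)
    center = C[:3]
    radius = np.sqrt(C[3] + center @ center)
    return center, radius


# Hypothetical test surface: a paraboloid z = (x^2 + y^2) / (4 f) sampled on a disk.
rng = np.random.default_rng(0)
focal, aperture = 100.0, 10.0  # mm
xy = rng.uniform(-aperture, aperture, size=(2000, 2))
xy = xy[np.linalg.norm(xy, axis=1) <= aperture]
z = np.sum(xy**2, axis=1) / (4 * focal)
points = np.column_stack([xy, z])

center, radius = best_fit_sphere(points)
# Signed distance of each sample to the fitted sphere, peak-to-valley in µm.
deviation = np.linalg.norm(points - center, axis=1) - radius
asphericity_um = np.ptp(deviation) * 1e3
print(f"radius = {radius:.2f} mm, asphericity = {asphericity_um:.3f} \u03BCm")
```

For a paraboloid of focal length 100 mm the fitted radius should come out close to the vertex radius of curvature, 200 mm, with a sub-micrometre peak-to-valley deviation over this aperture.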
LightSpeed = 299792458000
def GetETransmission(RayListIn, RayListOut) -> float:

Calculates the energy transmission from RayListIn to RayListOut in percent by summing up the intensity-property of the individual rays.

Parameters

RayListIn : list(Ray)
    List of incoming rays.

RayListOut : list(Ray)
    List of outgoing rays.

Returns

ETransmission : float
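As an illustration, the transmission is the ratio of the summed output intensity to the summed input intensity, expressed in percent. The sketch below uses a hypothetical minimal `Ray` dataclass instead of the real `ModuleOpticalRay.Ray`, and adds a zero-input check that the original function does not have.

```python
from dataclasses import dataclass


@dataclass
class Ray:
    """Hypothetical minimal ray: only the intensity matters here."""
    intensity: float


def energy_transmission(rays_in, rays_out):
    """Percentage of the summed input intensity that reaches the output."""
    total_in = sum(ray.intensity for ray in rays_in)
    if total_in == 0:
        raise ValueError("input rays carry no intensity")
    return 100 * sum(ray.intensity for ray in rays_out) / total_in


rays_in = [Ray(1.0) for _ in range(10)]
# Suppose two rays are clipped and the remaining eight lose 10 % each.
rays_out = [Ray(0.9) for _ in range(8)]
print(f"{energy_transmission(rays_in, rays_out):.1f} %")
```

Rays that are lost along the way simply do not appear in the output list, so clipping and reflectivity losses both lower the result.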
def GetResultSummary(Detector, RayListAnalysed, verbose=False):

Calculate and return FocalSpotSize-standard-deviation and Duration-standard-deviation for the given Detector and RayList. If verbose, then also print a summary of the results for the given Detector.

Parameters

Detector : Detector
    An object of the ModuleDetector.Detector-class.

RayListAnalysed : list(Ray)
    List of objects of the ModuleOpticalRay.Ray-class.

verbose : bool
    Whether to print a result summary.

Returns

FocalSpotSizeSD : float

DurationSD : float
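The two returned numbers can be pictured with a small standalone calculation. The definition of `mp.StandardDeviation` is not shown on this page, so the sketch below assumes the spatial value is the RMS distance of the 2D impact points from their centroid and the temporal value is the population standard deviation of the delays. The real implementation may differ.

```python
import math
import statistics


def spatial_std(points):
    """RMS distance of 2D points from their centroid (assumed definition)."""
    cx = sum(x for x, _ in points) / len(points)
    cy = sum(y for _, y in points) / len(points)
    return math.sqrt(sum((x - cx) ** 2 + (y - cy) ** 2 for x, y in points) / len(points))


# Hypothetical detector data: impact points in mm, delays in fs.
points_mm = [(0.001, 0.0), (-0.001, 0.0), (0.0, 0.001), (0.0, -0.001)]
delays_fs = [-2.0, -1.0, 1.0, 2.0]

spot_sd = spatial_std(points_mm)
duration_sd = statistics.pstdev(delays_fs)

# Same unit handling as the verbose summary: mm -> µm for the spot, fs for the delays.
print(f"Spatial std : {spot_sd * 1e3:.3f} \u03BCm")
print(f"Temporal std : {duration_sd:.3e} fs")
```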
def GetPulseProfile(Detector, RayList, Nbins=100):

Retrieve the pulse profile from the given RayList and Detector. The pulse profile is calculated by binning the delays of the rays in the RayList at the Detector. The intensity of the rays is also taken into account.

Parameters

Detector : Detector
    An object of the ModuleDetector.Detector-class.

RayList : list(Ray)
    List of objects of the ModuleOpticalRay.Ray-class.

Nbins : int
    Number of bins to use for the histogram.

Returns

time : numpy.ndarray
    The bins of the histogram.

intensity : numpy.ndarray
    The histogram values.
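The binning step is an intensity-weighted histogram of the ray delays, with the bin midpoints used as the time axis. A minimal sketch with made-up delays and intensities (the numbers are assumptions, not output of a real detector):

```python
import numpy as np

# Hypothetical data: arrival delays in fs and one intensity per ray.
delays_fs = np.array([-1.5, -0.5, -0.4, 0.2, 0.3, 1.5])
intensities = np.array([1.0, 2.0, 2.0, 3.0, 3.0, 1.0])

# Each ray adds its intensity (not a count of 1) to the bin its delay falls in.
intensity, time_edges = np.histogram(delays_fs, bins=3, weights=intensities)
# np.histogram returns Nbins + 1 edges; the midpoints give one time per bin.
time = 0.5 * (time_edges[1:] + time_edges[:-1])

print(time)
print(intensity)
```

The summed histogram equals the summed ray intensity, so the profile conserves the total energy of the rays that reach the detector.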
def GetNumericalAperture(RayList: list[ARTcore.ModuleOpticalRay.Ray], RefractiveIndex: float = 1) -> float:

Returns the numerical aperture associated with the supplied ray-bundle 'RayList'. This is $n\sin\theta$, where $\theta$ is the maximum angle between any of the rays and the central ray, and $n$ is the refractive index of the propagation medium.

Parameters

RayList : list of Ray-object
    The ray-bundle of which to determine the NA.

RefractiveIndex : float, optional
    Refractive index of the propagation medium, defaults to 1.

Returns

NA : float
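To make the definition concrete, the sketch below computes $n\sin\theta$ for a handful of direction vectors. It works on plain tuples and takes the central direction as an explicit argument, which is a simplification: the real function receives `Ray` objects and determines the central ray itself.

```python
import math


def angle_between(u, v):
    """Angle in radians between two 3D vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    # Clamp to [-1, 1] to guard against rounding just outside acos' domain.
    return math.acos(max(-1.0, min(1.0, dot / norm)))


def numerical_aperture(directions, central, refractive_index=1.0):
    """n * sin(theta_max), theta_max being the largest angle to the central direction."""
    theta_max = max(angle_between(central, d) for d in directions)
    return refractive_index * math.sin(theta_max)


# Hypothetical bundle converging along z with a half-angle of 0.1 rad.
central = (0.0, 0.0, 1.0)
directions = [
    central,
    (math.sin(0.05), 0.0, math.cos(0.05)),
    (0.0, math.sin(0.1), math.cos(0.1)),
    (-math.sin(0.1), 0.0, math.cos(0.1)),
]
na = numerical_aperture(directions, central)
print(f"NA = {na:.4f}")
```

Only the most inclined ray matters, so a single stray ray at a large angle is enough to raise the reported numerical aperture.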
def GetAiryRadius(Wavelength: float, NumericalAperture: float) -> float:

Returns the radius of the Airy disk: $r = 1.22 \frac{\lambda}{2 NA}$, i.e. the diffraction-limited radius of the focal spot corresponding to a given numerical aperture $NA$ and a light wavelength $\lambda$.

For very small $NA \leq 10^{-3}$, diffraction effects become negligible and the Airy radius becomes meaningless, so in that case, a radius of 0 is returned.

Parameters

Wavelength : float
    Light wavelength in mm.

NumericalAperture : float

Returns

AiryRadius : float
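A short worked example may help with the units. The source evaluates `1.22 * 0.5 * Wavelength / NumericalAperture`, i.e. $0.61\,\lambda/NA$, with the wavelength in mm. The sketch below re-implements that expression under a different name (`airy_radius`, chosen here for illustration) and prints the result in µm for an assumed 800 nm wavelength.

```python
def airy_radius(wavelength_mm, numerical_aperture):
    """0.61 * lambda / NA, or 0 at or below the NA threshold of 1e-3."""
    if wavelength_mm is not None and numerical_aperture > 1e-3:
        return 1.22 * 0.5 * wavelength_mm / numerical_aperture
    return 0


wavelength_mm = 800e-6  # 800 nm expressed in mm
for na in (0.1, 0.01, 1e-4):
    r_um = airy_radius(wavelength_mm, na) * 1e3  # mm -> µm
    print(f"NA = {na:g}: Airy radius = {r_um:.2f} \u03BCm")
```

The last case falls below the threshold and therefore reports 0 rather than a very large radius.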
def GetPlaneWaveFocus(OpticalChain, Detector='Focus', size=None, Nrays=1000, resolution=100):
def GetPlaneWaveFocus(OpticalChain, Detector = "Focus", size=None, Nrays=1000, resolution=100):
    """
    This function calculates the approximate polychromatic focal spot of a set of rays.
    To do so, it positions itself in the detector plane and samples a square area.
    If the size of the area is not given, it uses three times the Airy radius of the system at the largest wavelength.
    Otherwise it uses the given size as the side length of the square.
    The sampling is set by the resolution parameter: the area is sampled on a grid of resolution x resolution points.

    To calculate the intensity, it takes Nrays out of the raylist (to subsample if needed).
    It treats every ray as a plane wave, so it calculates a k-vector for each ray (taking into account the wavelength of the ray).
    It calculates the phase of each ray from the delay and from the intersection position with the detector.
    The delay from the non-central position is simply sin(alpha)*distance/c, where alpha is the angle between the ray and the normal to the detector
    and distance is the distance between the intersection point and the central point of the detector.
    It then calculates the value at each point of the grid by summing, over all rays, the ray intensity times the cosine of the plane-wave phase.

    As long as there are not too many rays, this method is faster than doing an FFT.
    It's also naturally polychromatic.
    """
    if isinstance(Detector, str):
        Detector = OpticalChain.detectors[Detector]
    Index = Detector.index
    RayList = OpticalChain.get_output_rays()[Index]
    if size is None:
        Wavelengths = [Ray.wavelength for Ray in RayList]
        Wavelength = max(Wavelengths)
        NumericalAperture = GetNumericalAperture(RayList)
        size = 3 * GetAiryRadius(Wavelength, NumericalAperture)
    X = np.linspace(-size / 2, size / 2, resolution)
    Y = np.linspace(-size / 2, size / 2, resolution)
    X, Y = np.meshgrid(X, Y)
    # We pick Nrays. If there are fewer than Nrays, we take all of them.
    PickedRaysGlobal = np.random.choice(RayList, min(Nrays, len(RayList)), replace=False)
    # We calculate the k-vector of each ray, taking into account the wavelength of the ray.
    # The units should be in mm^-1
    PickedRays = [Ray.to_basis(*Detector.basis) for Ray in PickedRaysGlobal]
    # The rays are now in the reference plane of the detector whose normal is [0,0,1]
    wavelengths = np.array([Ray.wavelength for Ray in PickedRays])
    vectors = np.array([Ray.vector for Ray in PickedRays])
    frequencies = np.array([LightSpeed / Ray.wavelength for Ray in PickedRays]) # mm/s / mm = 1/s
    k_vectors = 2*np.pi * vectors / wavelengths[:, np.newaxis]
    angles = np.arccos(np.clip(np.dot(vectors, np.array([0, 0, 1])), -1.0, 1.0))  # currently unused below
    # We calculate the intersection of the rays with the detector plane
    Intersections = mgeo.IntersectionRayListZPlane(PickedRays)[0]
    distances = (Intersections - mgeo.Origin[:2]).norm  # currently unused below
    # We calculate the phase of each ray
    PathDelays = np.array(Detector.get_Delays(PickedRaysGlobal))  # in fs
    PositionDelays = np.sum(np.array(Intersections._add_dimension()-mgeo.Origin)*vectors, axis=1)  / LightSpeed * 1e15  # in fs
    Delays = (PathDelays+PositionDelays) /  1e15
    # The delays are now in seconds, so we can calculate the phase, taking into account the wavelength of each ray
    Phases = 2 * np.pi * frequencies * Delays
    # We also need the intensity of each ray
    RayIntensities = np.array([Ray.intensity for Ray in PickedRays])
    # We can now sum the contribution of each plane wave at each point of the grid
    Intensity = np.zeros((resolution, resolution))
    for i in range(len(PickedRays)):
        Intensity += RayIntensities[i] * np.cos(k_vectors[i][0] * X + k_vectors[i][1] * Y + Phases[i])

    return X, Y, Intensity

This function calculates the approximate polychromatic focal spot of a set of rays. To do so, it positions itself in the detector plane and samples a square area. If the size of the area is not given, it uses three times the Airy radius of the system at the largest wavelength. Otherwise it uses the given size as the side length of the square. The sampling is set by the resolution parameter: the area is sampled on a grid of resolution x resolution points.

To calculate the intensity, it takes Nrays out of the raylist (to subsample if needed). It treats every ray as a plane wave, so it calculates a k-vector for each ray (taking into account the wavelength of the ray). It calculates the phase of each ray from the delay and from the intersection position with the detector. The delay from the non-central position is simply sin(alpha)*distance/c, where alpha is the angle between the ray and the normal to the detector and distance is the distance between the intersection point and the central point of the detector. It then calculates the value at each point of the grid by summing, over all rays, the ray intensity times the cosine of the plane-wave phase.

As long as there are not too many rays, this method is faster than doing an FFT. It's also naturally polychromatic.
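
The summation described above can be sketched with plain NumPy on synthetic data. This is a minimal illustration under stated assumptions, not the package's implementation: the function name `plane_wave_sum`, the synthetic ray bundle and all numerical values are hypothetical, and the rays are assumed to be already expressed in the detector frame (normal along z).

```python
import numpy as np


def plane_wave_sum(directions, wavelengths, weights, phases, size, resolution=101):
    """Sum weighted cosines of plane-wave phases on a square grid in the plane z = 0.

    directions : (N, 3) unit vectors, wavelengths : (N,) in mm,
    weights : (N,), phases : (N,) in rad, size : side length of the grid in mm.
    """
    directions = np.asarray(directions, dtype=float)
    wavelengths = np.asarray(wavelengths, dtype=float)
    axis = np.linspace(-size / 2, size / 2, resolution)
    X, Y = np.meshgrid(axis, axis)
    k = 2 * np.pi * directions / wavelengths[:, np.newaxis]  # k-vectors in rad/mm
    total = np.zeros((resolution, resolution))
    for (kx, ky, _), weight, phase in zip(k, weights, phases):
        total += weight * np.cos(kx * X + ky * Y + phase)
    return X, Y, total


# Synthetic converging bundle: 200 rays filling a cone of half-angle 0.1 rad, 800 nm
rng = np.random.default_rng(0)
azimuth = rng.uniform(0, 2 * np.pi, 200)
polar = 0.1 * np.sqrt(rng.uniform(0, 1, 200))
dirs = np.column_stack([np.sin(polar) * np.cos(azimuth),
                        np.sin(polar) * np.sin(azimuth),
                        np.cos(polar)])
X, Y, spot = plane_wave_sum(dirs, np.full(200, 800e-6), np.ones(200), np.zeros(200), size=0.02)
```

With all phases set to zero every cosine equals 1 at the grid centre, so the sum peaks there at the total weight; non-zero phases can only lower that central value.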

def GetDiffractionFocus( OpticalChain, Detector='Focus', size=None, Nrays=1000, resolution=100):
def GetDiffractionFocus(OpticalChain, Detector = "Focus", size=None, Nrays=1000, resolution=100):
    """
    This function calculates the approximate polychromatic focal spot of a set of rays.
    To do so, it positions itself in the detector plane and samples a square area.
    If the size of the area is not given, it uses three times the Airy radius of the system at the largest wavelength.
    Otherwise it uses the given size as the side length of the square.
    The sampling is set by the resolution parameter: the area is sampled on a grid of resolution x resolution points.

    To calculate the intensity, it takes Nrays out of the raylist (to subsample if needed).
    It treats every ray as a plane wave, so it calculates a k-vector for each ray (taking into account the wavelength of the ray).
    It assumes that all rays intersect the detector at its centre and ignores their phase.

    It therefore returns the best-case scenario for a diffraction-limited focus with that numerical aperture.
    """
    if isinstance(Detector, str):
        Detector = OpticalChain.detectors[Detector]
    Index = Detector.index
    RayList = OpticalChain.get_output_rays()[Index]
    if size is None:
        Wavelengths = [Ray.wavelength for Ray in RayList]
        Wavelength = max(Wavelengths)
        NumericalAperture = GetNumericalAperture(RayList)
        size = 3 * GetAiryRadius(Wavelength, NumericalAperture)
    X = np.linspace(-size / 2, size / 2, resolution)
    Y = np.linspace(-size / 2, size / 2, resolution)
    X, Y = np.meshgrid(X, Y)
    # We pick Nrays. If there are fewer than Nrays, we take all of them.
    PickedRaysGlobal = np.random.choice(RayList, min(Nrays, len(RayList)), replace=False)
    # We calculate the k-vector of each ray, taking into account the wavelength of the ray.
    # The units should be in mm^-1
    PickedRays = [Ray.to_basis(*Detector.basis) for Ray in PickedRaysGlobal]
    # The rays are now in the reference plane of the detector whose normal is [0,0,1]
    wavelengths = np.array([Ray.wavelength for Ray in PickedRays])
    vectors = np.array([Ray.vector for Ray in PickedRays])
    frequencies = np.array([LightSpeed / Ray.wavelength for Ray in PickedRays]) # mm/s / mm = 1/s (currently unused below)
    k_vectors = 2*np.pi * vectors / wavelengths[:, np.newaxis]
    angles = np.arccos(np.clip(np.dot(vectors, np.array([0, 0, 1])), -1.0, 1.0))  # currently unused below
    # We calculate the intersection of the rays with the detector plane
    Intersections = mgeo.IntersectionRayListZPlane(PickedRays)[0]
    distances = (Intersections - mgeo.Origin[:2]).norm  # currently unused below
    # We also need the intensity of each ray
    RayIntensities = np.array([Ray.intensity for Ray in PickedRays])
    # We can now sum the contribution of each plane wave at each point of the grid (no phase term)
    Intensity = np.zeros((resolution, resolution))
    for i in range(len(PickedRays)):
        Intensity += RayIntensities[i] * np.cos(k_vectors[i][0] * X + k_vectors[i][1] * Y)

    return X, Y, Intensity

This function calculates the approximate polychromatic focal spot of a set of rays. To do so, it positions itself in the detector plane and samples a square area. If the size of the area is not given, it uses three times the Airy radius of the system at the largest wavelength. Otherwise it uses the given size as the side length of the square. The sampling is set by the resolution parameter: the area is sampled on a grid of resolution x resolution points.

To calculate the intensity, it takes Nrays out of the raylist (to subsample if needed). It treats every ray as a plane wave, so it calculates a k-vector for each ray (taking into account the wavelength of the ray). It assumes that all rays intersect the detector at its centre and ignores their phase.

It therefore returns the best-case scenario for a diffraction-limited focus with that numerical aperture.
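
Why dropping the phases gives a best case can be seen at the grid centre, where the plane-wave sum reduces to a weighted sum of cos(phase). The sketch below uses hypothetical random phases and unit weights purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
weights = np.ones(500)                 # hypothetical ray intensities
phases = rng.normal(0.0, 1.0, 500)     # hypothetical per-ray phases in rad

# Value at the grid centre (X = Y = 0) with and without the per-ray phases
centre_with_phase = np.sum(weights * np.cos(phases))
centre_without_phase = np.sum(weights)
ratio = centre_with_phase / centre_without_phase
print(f"central value relative to the phase-free case: {ratio:.2f}")
```

Since cos(phase) never exceeds 1, the phase-free sum is an upper bound on the central value for the same set of rays.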

def BestFitSphere(X, Y, Z):
def BestFitSphere(X,Y,Z):
    """
    This function calculates the best sphere to fit a set of points.
    It uses the least-squares method to find the center and the radius of the sphere.
    See https://jekel.me/2015/Least-Squares-Sphere-Fit/
    """
    #   Assemble the A matrix: one row [2x, 2y, 2z, 1] per point
    A = np.zeros((len(X),4))
    A[:,0] = X*2
    A[:,1] = Y*2
    A[:,2] = Z*2
    A[:,3] = 1

    #   Assemble the f matrix
    f = np.zeros((len(X),1))
    f[:,0] = X**2 + Y**2 + Z**2
    C, residuals, rank, singval = np.linalg.lstsq(A, f, rcond=None)
    C = C.flatten()  # [cx, cy, cz, r^2 - cx^2 - cy^2 - cz^2]

    #   solve for the radius
    t = (C[0]*C[0])+(C[1]*C[1])+(C[2]*C[2])+C[3]
    radius = math.sqrt(t)

    return mgeo.Point(C[:3]),radius

This function calculates the best sphere to fit a set of points. It uses the least-squares method to find the center and the radius of the sphere. See https://jekel.me/2015/Least-Squares-Sphere-Fit/
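
The fit is linear because the sphere equation $(x-a)^2+(y-b)^2+(z-c)^2=r^2$ can be rearranged to $x^2+y^2+z^2 = 2ax + 2by + 2cz + (r^2-a^2-b^2-c^2)$, which is linear in the four unknowns $a$, $b$, $c$ and $d = r^2-a^2-b^2-c^2$. A self-contained sketch of that idea follows; the name `fit_sphere` and the synthetic test points are assumptions made for the example, and it returns a plain NumPy array rather than an ART point object.

```python
import numpy as np


def fit_sphere(points):
    """Least-squares sphere through an (N, 3) array; returns (centre, radius)."""
    points = np.asarray(points, dtype=float)
    A = np.column_stack([2 * points, np.ones(len(points))])  # rows [2x, 2y, 2z, 1]
    f = np.sum(points**2, axis=1)                            # x^2 + y^2 + z^2
    solution, *_ = np.linalg.lstsq(A, f, rcond=None)
    centre = solution[:3]
    radius = np.sqrt(solution[3] + centre @ centre)          # r^2 = d + |centre|^2
    return centre, radius


# Points on a spherical cap of radius 100 centred at (1, -2, 100)
rng = np.random.default_rng(0)
xy = rng.uniform(-5, 5, size=(500, 2))
true_centre = np.array([1.0, -2.0, 100.0])
z = true_centre[2] - np.sqrt(100.0**2 - (xy[:, 0] - 1.0) ** 2 - (xy[:, 1] + 2.0) ** 2)
centre, radius = fit_sphere(np.column_stack([xy, z]))
```

For noise-free points lying exactly on a sphere, the fit recovers the centre and radius up to floating-point error.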

def GetClosestSphere(Mirror, Npoints=1000):
def GetClosestSphere(Mirror, Npoints=1000):
    """
    This function calculates the closest sphere to the surface of a mirror.
    It does so by sampling the surface of the mirror at Npoints points.
    It then calculates the closest sphere to these points.
    It returns the center of the sphere and the radius of the sphere, in that order.
    """
    Points = mpm.sample_support(Mirror.support, Npoints=Npoints)  # use the requested number of points
    Points += Mirror.r0[:2]
    Z = Mirror._zfunc(Points)
    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
    spX, spY, spZ = Points[:, 0], Points[:, 1], Points[:, 2]
    Center, Radius = BestFitSphere(spX, spY, spZ)
    return Center, Radius

This function calculates the closest sphere to the surface of a mirror. It does so by sampling the surface of the mirror at Npoints points. It then calculates the closest sphere to these points. It returns the center of the sphere and the radius of the sphere, in that order.

def GetAsphericity(Mirror, Npoints=1000):
def GetAsphericity(Mirror, Npoints=1000):
    """
    This function calculates the asphericity of the mirror: the peak-to-valley spread, in µm,
    of the signed distance between the sampled mirror surface and the closest sphere.
    """
    center, radius = GetClosestSphere(Mirror, Npoints)
    Points = mpm.sample_support(Mirror.support, Npoints=Npoints)  # use the requested number of points
    Points += Mirror.r0[:2]
    Z = Mirror._zfunc(Points)
    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
    Points_centered = Points - center
    Distance = np.linalg.norm(Points_centered, axis=1) - radius
    Distance*=1e3 # To convert from mm to µm
    return np.ptp(Distance)

This function calculates the asphericity of the mirror: the peak-to-valley spread, in µm, of the signed distance between the sampled mirror surface and the closest sphere.
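
To illustrate what this peak-to-valley figure measures, here is a standalone sketch that compares a paraboloid and a sphere sampled on the same grid. The function name `peak_to_valley_asphericity`, the regular sampling grid and the surface parameters are assumptions for the example; the package samples the mirror support itself.

```python
import numpy as np


def peak_to_valley_asphericity(points):
    """Peak-to-valley of the radial distance to the best-fit sphere, in the unit of `points`."""
    points = np.asarray(points, dtype=float)
    A = np.column_stack([2 * points, np.ones(len(points))])
    f = np.sum(points**2, axis=1)
    solution, *_ = np.linalg.lstsq(A, f, rcond=None)
    centre = solution[:3]
    radius = np.sqrt(solution[3] + centre @ centre)
    deviation = np.linalg.norm(points - centre, axis=1) - radius  # signed distance to the sphere
    return np.ptp(deviation)


# Regular grid over a 20 mm x 20 mm aperture
axis = np.linspace(-10, 10, 41)
x, y = (a.ravel() for a in np.meshgrid(axis, axis))

# Paraboloid z = (x^2 + y^2) / (4 f) with f = 50 mm, versus a sphere of radius 100 mm
paraboloid = np.column_stack([x, y, (x**2 + y**2) / 200.0])
sphere = np.column_stack([x, y, 100.0 - np.sqrt(100.0**2 - x**2 - y**2)])

print(peak_to_valley_asphericity(paraboloid) * 1e3, "µm")
print(peak_to_valley_asphericity(sphere) * 1e3, "µm")
```

A perfectly spherical surface gives a value at the level of floating-point noise, while the paraboloid gives a clearly non-zero value.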

2 - ModuleAnalysisAndPlots

Provides functions for analysis of the output ray-bundles calculated through ray-tracing, and also for the ART's standard visualization options.

Implements the following functions: - DrawSpotDiagram: Produces an interactive figure with the spot diagram on the selected Detector. - DrawDelaySpots: Produces an interactive figure with a spot diagram resulting from the RayListAnalysed. - _drawDelayGraph: Auxiliary function for DrawDelaySpots. - DrawMirrorProjection: Produces a plot of the ray impact points on the optical element with index 'ReflectionNumber'. - DrawSetup: Renders an image of the optical setup and the traced rays. - DrawAsphericity: Displays a map of the asphericity of the mirror. - DrawCaustics: Displays the caustics of the rays on the detector.

It also implements the following methods: -OpticalChain: - render: Method to render the optical setup. - drawSpotDiagram: Method to draw the spot diagram on the selected detector. - drawDelaySpots: Method to draw the delay spots on the detector. - drawMirrorProjection: Method to draw the mirror projection. - drawCaustics: Method to draw the caustics on the detector. -Mirror: - drawAsphericity: Method to draw the asphericity of the mirror.

Created in Apr 2020

@author: Anthony Guillaume + Stefan Haessler + Andre Kalouguine
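
The interactive figures in this module move the detector through an `Observable` object: key presses change its `value`, registered calculations update the detector, and registered observers redraw the plot. The real class lives in `ART.ModulePlottingUtilities`; the stand-in below is only a minimal sketch of that pattern. The class names `SketchObservable` and `FakeDetector`, and the order in which callbacks run, are assumptions made for the illustration.

```python
class SketchObservable:
    """Hypothetical stand-in: holds a value, runs calculations, then notifies observers."""

    def __init__(self, value):
        self._value = value
        self._calculations = []
        self._observers = []

    def register_calculation(self, callback):
        self._calculations.append(callback)

    def register(self, callback):
        self._observers.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value
        for callback in self._calculations:  # e.g. move the detector
            callback(new_value)
        for callback in self._observers:     # e.g. redraw every linked plot
            callback(new_value)


class FakeDetector:
    """Hypothetical detector exposing only what the sketch needs."""

    def __init__(self, distance):
        self.distance = distance

    def set_distance(self, distance):
        self.distance = distance


detector = FakeDetector(100.0)
position = SketchObservable(detector.distance)
redraws = []
position.register_calculation(detector.set_distance)
position.register(lambda d: redraws.append(d))

dist_step = 0.5
position.value += dist_step  # what pressing the "right" key does
```

Passing the same observable to several drawing functions is what lets one key press update several plots at once.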

  1"""
  2Provides functions for analysis of the output ray-bundles calculated through ray-tracing, and also for the ART's standard visualization options.
  3
  4Implements the following functions:
  5    - DrawSpotDiagram: Produces an interactive figure with the spot diagram on the selected Detector.
  6    - DrawDelaySpots: Produces an interactive figure with a spot diagram resulting from the RayListAnalysed
  7    - _drawDelayGraph: Auxiliary function for DrawDelaySpots.
  8    - DrawMirrorProjection: Produce a plot of the ray impact points on the optical element with index 'ReflectionNumber'.
  9    - DrawSetup: Renders an image of the Optical setup and the traced rays.
 10    - DrawAsphericity: Displays a map of the asphericity of the mirror.
 11    - DrawCaustics: Displays the caustics of the rays on the detector.
 12
 13It also implements the following methods:
 14    -OpticalChain:
 15        - render: Method to render the optical setup.
 16        - drawSpotDiagram: Method to draw the spot diagram on the selected detector.
 17        - drawDelaySpots: Method to draw the delay spots on the detector.
 18        - drawMirrorProjection: Method to draw the mirror projection.
 19        - drawCaustics: Method to draw the caustics on the detector.
 20    -Mirror:
 21        - drawAsphericity: Method to draw the asphericity of the mirror.
 22
 23
 24Created in Apr 2020
 25
 26@author: Anthony Guillaume + Stefan Haessler + Andre Kalouguine
 27"""
 28# %% Module imports
 29import numpy as np
 30import matplotlib.pyplot as plt
 31from mpl_toolkits.mplot3d import Axes3D
 32import pyvista as pv
 33import pyvistaqt as pvqt
 34import colorcet as cc
 35from colorsys import rgb_to_hls, hls_to_rgb
 36import logging
 37
 38logger = logging.getLogger(__name__)
 39
 40
 41import ARTcore.ModuleProcessing as mp
 42import ARTcore.ModuleGeometry as mgeo
 43import ARTcore.ModuleDetector as mdet
 44import ARTcore.ModuleMirror as mmirror
 45import ART.ModulePlottingMethods as mpm
 46import ART.ModulePlottingUtilities as mpu
 47from ART.ModulePlottingUtilities import Observable
 48
 49import ARTcore.ModuleOpticalChain as moc
 50import ART.ModuleAnalysis as man
 51import itertools
 52from copy import copy
 53
 54
 55# %% Spot diagram on detector
 56def DrawSpotDiagram(OpticalChain, 
 57                Detector = "Focus",
 58                DrawAiryAndFourier=False, 
 59                DrawFocalContour=False,
 60                DrawFocal=False,
 61                ColorCoded=None,
 62                Observer = None) -> plt.Figure:
 63    """
 64    Produce an interactive figure with the spot diagram on the selected Detector.
 65    The detector distance can be shifted with the left-right cursor keys. Doing so will actually move the detector.
 66    If DrawAiryAndFourier is True, a circle with the Airy-spot-size will be shown.
 67    If DrawFocalContour is True, the focal contour calculated from some of the rays will be shown.
 68    If DrawFocal is True, a heatmap calculated from some of the rays will be shown. 
 69    The 'spots' can optionally be color-coded by specifying ColorCoded, which can be one of ["Intensity","Incidence","Delay"].
 70
 71    Parameters
 72    ----------
 73        OpticalChain : OpticalChain
 74            An object of the ModuleOpticalChain.OpticalChain-class. Its output rays on the selected detector are analysed.
 75
 76        Detector : Detector or str, optional
 77            An object of the ModuleDetector.Detector-class or the name of the detector. The default is "Focus".
 78
 79        DrawAiryAndFourier : bool, optional
 80            Whether to draw a circle with the Airy-spot-size. The default is False.
 81        
 82        DrawFocalContour : bool, optional
 83            Whether to draw the focal contour. The default is False.
 84
 85        DrawFocal : bool, optional
 86            Whether to draw the focal heatmap. The default is False.
 87
 88        ColorCoded : str, optional
 89            Color-code the spots according to one of ["Intensity","Incidence","Delay"]. The default is None.
 90
 91        Observer : Observable, optional
 92            An Observable object. If None, then we just create a copy of the detector and move it when pressing left-right. 
 93            However, if an Observable is specified, then we will change its value and it will issue 
 94            the required callbacks to update several plots at the same time.
 95
 96    Returns
 97    -------
 98        fig, detectorPosition : matplotlib figure handle, Observable
 99            The interactive figure and the Observable holding the detector distance.
100    """
101    if isinstance(Detector, str):
102        Detector = OpticalChain.detectors[Detector]
103    Index = Detector.index
104    movingDetector = copy(Detector) # We will move this detector when pressing left-right
105    if Observer is None:
106        detectorPosition = Observable(movingDetector.distance) # We will observe the distance of this detector
107    else:
108        detectorPosition = Observer
109        movingDetector.distance = detectorPosition.value
110
111    detectorPosition.register_calculation(lambda x: movingDetector.set_distance(x))
112
113    RayListAnalysed = OpticalChain.get_output_rays()[Index]
114
115    NumericalAperture = man.GetNumericalAperture(RayListAnalysed, 1)  # NA determined from final ray bundle
116    MaxWavelength = np.max([i.wavelength for i in RayListAnalysed])
117    if DrawAiryAndFourier:
118        AiryRadius = man.GetAiryRadius(MaxWavelength, NumericalAperture) * 1e3  # in µm
119    else:
120        AiryRadius = 0
121    
122    if DrawFocalContour or DrawFocal:
123        X,Y,Z = man.GetDiffractionFocus(OpticalChain, movingDetector)  # the third positional argument is size, not the detector index
124        Z/=np.max(Z) 
125
126    DectectorPoint2D_Xcoord, DectectorPoint2D_Ycoord, FocalSpotSize, SpotSizeSD = mpu._getDetectorPoints(
127        RayListAnalysed, movingDetector
128    )
129
130    match ColorCoded:
131        case "Intensity":
132            IntensityList = [k.intensity for k in RayListAnalysed]
133            z = np.asarray(IntensityList)
134            zlabel = "Intensity (arb.u.)"
135            title = "Intensity + Spot Diagram\n press left/right to move detector position"
136            addLine = ""
137        case "Incidence":
138            IncidenceList = [np.rad2deg(k.incidence) for k in RayListAnalysed]  # degree
139            z = np.asarray(IncidenceList)
140            zlabel = "Incidence angle (deg)"
141            title = "Ray Incidence + Spot Diagram\n press left/right to move detector position"
142            addLine = ""
143        case "Delay":
144            DelayList = movingDetector.get_Delays(RayListAnalysed)
145            DurationSD = mp.StandardDeviation(DelayList)
146            z = np.asarray(DelayList)
147            zlabel = "Delay (fs)"
148            title = "Delay + Spot Diagram\n press left/right to move detector position"
149            addLine = "\n" + "{:.2f}".format(DurationSD) + " fs SD"
150        case _:
151            z = "red"
152            title = "Spot Diagram\n press left/right to move detector position"
153            addLine = ""
154
155    distStep = min(50, max(0.0005, round(FocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000))  # in mm
156
157    plt.ion()
158    fig, ax = plt.subplots()
159    if DrawFocal:
160        focal = ax.pcolormesh(X*1e3,Y*1e3,Z)
161    if DrawFocalContour:
162        levels = [1/np.e**2, 0.5]
163        contour = ax.contourf(X*1e3, Y*1e3, Z, levels=levels, cmap='gray')
164
165    if DrawAiryAndFourier:
166        theta = np.linspace(0, 2 * np.pi, 100)
167        x = AiryRadius * np.cos(theta)
168        y = AiryRadius * np.sin(theta)  #
169        ax.plot(x, y, c="black")
170        
171
172    foo = ax.scatter(
173        DectectorPoint2D_Xcoord,
174        DectectorPoint2D_Ycoord,
175        c=z,
176        s=15,
177        label="{:.3f}".format(detectorPosition.value) + " mm\n" + "{:.1f}".format(SpotSizeSD * 1e3) + " \u03BCm SD" + addLine,
178    )
179
180    axisLim = 1.1 * max(AiryRadius, 0.5 * FocalSpotSize * 1000)
181    ax.set_xlim(-axisLim, axisLim)
182    ax.set_ylim(-axisLim, axisLim)
183
184    if ColorCoded == "Intensity" or ColorCoded == "Incidence" or ColorCoded == "Delay":
185        cbar = fig.colorbar(foo)
186        cbar.set_label(zlabel)
187
188    ax.legend(loc="upper right")
189    ax.set_xlabel("X (µm)")
190    ax.set_ylabel("Y (µm)")
191    ax.set_title(title)
192    # ax.margins(x=0)
193
194
195    def update_plot(new_value):
196        nonlocal movingDetector, ColorCoded, zlabel, cbar, detectorPosition, foo, distStep, focal, contour, levels, Index, RayListAnalysed
197
198        newDectectorPoint2D_Xcoord, newDectectorPoint2D_Ycoord, newFocalSpotSize, newSpotSizeSD = mpu._getDetectorPoints(
199            RayListAnalysed, movingDetector
200        )
201
202        if DrawFocal:
203            focal.set_array(Z)
204        if DrawFocalContour:
205            levels = [1/np.e**2, 0.5]
206            for coll in contour.collections:
207                coll.remove()  # Remove old contour lines
208            contour = ax.contourf(X * 1e3, Y * 1e3, Z, levels=levels, cmap='gray')
209        
210        xy = foo.get_offsets()
211        xy[:, 0] = newDectectorPoint2D_Xcoord
212        xy[:, 1] = newDectectorPoint2D_Ycoord
213        foo.set_offsets(xy)
214
215
216        if ColorCoded == "Delay":
217            newDelayList = np.asarray(movingDetector.get_Delays(RayListAnalysed))
218            newDurationSD = mp.StandardDeviation(newDelayList)
219            newaddLine = "\n" + "{:.2f}".format(newDurationSD) + " fs SD"
220            foo.set_array(newDelayList)
221            foo.set_clim(min(newDelayList), max(newDelayList))
222            cbar.update_normal(foo)
223        else:
224            newaddLine = ""
225
226        foo.set_label(
227            "{:.3f}".format(detectorPosition.value) + " mm\n" + "{:.1f}".format(newSpotSizeSD * 1e3) + " \u03BCm SD" + newaddLine
228        )
229        ax.legend(loc="upper right")
230
231        axisLim = 1.1 * max(AiryRadius, 0.5 * newFocalSpotSize * 1000)
232        ax.set_xlim(-axisLim, axisLim)
233        ax.set_ylim(-axisLim, axisLim)
234
235        distStep = min(
236            50, max(0.0005, round(newFocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000)
237        )  # in mm
238
239        fig.canvas.draw_idle()
240
241
242    def press(event):
243        nonlocal detectorPosition, distStep
244        if event.key == "right":
245            detectorPosition.value += distStep
246        elif event.key == "left":
247            if detectorPosition.value > 1.5 * distStep:
248                detectorPosition.value -= distStep
249            else:
250                detectorPosition.value = 0.5 * distStep
251        else:
252            return None
253
254    fig.canvas.mpl_connect("key_press_event", press)
255
256    plt.show()
257
258    detectorPosition.register(update_plot)
259
260
261    return fig, detectorPosition
262
263moc.OpticalChain.drawSpotDiagram = DrawSpotDiagram
264
265# %% Delay graph on detector (3D spot diagram)
266def DrawDelaySpots(OpticalChain, 
267                DeltaFT: int | float,
268                Detector = "Focus",
269                DrawAiryAndFourier=False, 
270                ColorCoded=None,
271                Observer = None
272                ) -> plt.Figure:
273    """
274    Produce an interactive figure with a spot diagram resulting from the output rays of the OpticalChain
275    hitting the Detector, with the ray-delays shown in the 3rd dimension.
276    The detector distance can be shifted with the left-right cursor keys.
277    If DrawAiryAndFourier is True, a cylinder is shown whose diameter is the Airy-spot-size and
278    whose height is the Fourier-limited pulse duration given by 'DeltaFT'.
279    
280    The 'spots' can optionally be color-coded by specifying ColorCoded as ["Intensity","Incidence"].
281
282    Parameters
283    ----------
284        OpticalChain : OpticalChain
285            An object of the ModuleOpticalChain.OpticalChain-class. Its output rays on the Detector are plotted.
286
287        Detector : Detector or str, optional
288            An object of the ModuleDetector.Detector-class or the name of the detector. The default is "Focus".
289
290        DeltaFT : int or float
291            The Fourier-limited pulse duration in fs. Just used as a reference to compare the temporal spread
292            induced by the ray-delays.
293
294        DrawAiryAndFourier : bool, optional
295            Whether to draw a cylinder showing the Airy-spot-size and Fourier-limited-duration.
296            The default is False.
297
298        ColorCoded : str, optional
299            Color-code the spots according to one of ["Intensity","Incidence"].
300            The default is None.
301
302    Returns
303    -------
304        fig : matplotlib figure handle.
305            Shows the interactive figure.
306    """
307    if isinstance(Detector, str):
308        Det = OpticalChain.detectors[Detector]
309    else:
310        Det = Detector
311    Index = Det.index
312    Detector = copy(Det)
313    if Observer is None:
314        detectorPosition = Observable(Detector.distance)
315    else:
316        detectorPosition = Observer
317        Detector.distance = detectorPosition.value
318    
319    RayListAnalysed = OpticalChain.get_output_rays()[Index]
320    fig, NumericalAperture, AiryRadius, FocalSpotSize = _drawDelayGraph(
321        RayListAnalysed, Detector, detectorPosition.value, DeltaFT, DrawAiryAndFourier, ColorCoded
322    )
323
324    distStep = min(50, max(0.0005, round(FocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000))  # in mm
325
326    movingDetector = copy(Detector)
327
328    def update_plot(new_value):
329        nonlocal movingDetector, ColorCoded, detectorPosition, distStep, fig
330        ax = fig.axes[0]
331        cam = [ax.azim, ax.elev, ax._dist]
332        fig, sameNumericalAperture, sameAiryRadius, newFocalSpotSize = _drawDelayGraph(
333            RayListAnalysed, movingDetector, detectorPosition.value, DeltaFT, DrawAiryAndFourier, ColorCoded, fig
334        )
335        ax = fig.axes[0]
336        ax.azim, ax.elev, ax._dist = cam
337        distStep = min(
338            50, max(0.0005, round(newFocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000)
339        )
340        return fig
341
342    def press(event):
343        nonlocal detectorPosition, distStep, movingDetector, fig
344        if event.key == "right":
345            detectorPosition.value += distStep
346        elif event.key == "left":
347            if detectorPosition.value > 1.5 * distStep:
348                detectorPosition.value -= distStep
349            else:
350                detectorPosition.value = 0.5 * distStep
351
352    fig.canvas.mpl_connect("key_press_event", press)
353    detectorPosition.register(update_plot)
354    detectorPosition.register_calculation(lambda x: movingDetector.set_distance(x))
355
356    return fig, detectorPosition  # return the Observable instance, not the Observable class
357
358
359def _drawDelayGraph(RayListAnalysed, Detector, Distance, DeltaFT, DrawAiryAndFourier=False, ColorCoded=None, fig=None):
360    """
361    Draws the 3D-delay-spot-diagram for a fixed detector. See the documentation of DrawDelaySpots above.
362    """
363    NumericalAperture = man.GetNumericalAperture(RayListAnalysed, 1)  # NA determined from final ray bundle
364    Wavelength = RayListAnalysed[0].wavelength
365    AiryRadius = man.GetAiryRadius(Wavelength, NumericalAperture) * 1e3  # in µm
366
367    DectectorPoint2D_Xcoord, DectectorPoint2D_Ycoord, FocalSpotSize, SpotSizeSD = mpu._getDetectorPoints(
368        RayListAnalysed, Detector
369    )
370
371    DelayList = Detector.get_Delays(RayListAnalysed)
372    DurationSD = mp.StandardDeviation(DelayList)
373
374    if ColorCoded == "Intensity":
375        IntensityList = [k.intensity for k in RayListAnalysed]
376    elif ColorCoded == "Incidence":
377        IncidenceList = [np.rad2deg(k.incidence) for k in RayListAnalysed]  # in degree
378
379    plt.ion()
380    if fig is None:
381        fig = plt.figure()
382    else:
383        fig.clear(keep_observers=True)
384
385    ax = Axes3D(fig)
386    fig.add_axes(ax)
387    ax.set_xlabel("X (µm)")
388    ax.set_ylabel("Y (µm)")
389    ax.set_zlabel("Delay (fs)")
390
391    labelLine = (
392        "{:.3f}".format(Distance)
393        + " mm\n"
394        + "{:.1f}".format(SpotSizeSD * 1e3)
395        + " \u03BCm SD\n"
396        + "{:.2f}".format(DurationSD)
397        + " fs SD"
398    )
399    if ColorCoded == "Intensity":
400        ax.scatter(DectectorPoint2D_Xcoord, DectectorPoint2D_Ycoord, DelayList, s=4, c=IntensityList, label=labelLine)
401        # plt.title('Delay + Intensity graph\n press left/right to move detector position')
402        fig.suptitle("Delay + Intensity graph\n press left/right to move detector position")
403
404    elif ColorCoded == "Incidence":
405        ax.scatter(DectectorPoint2D_Xcoord, DectectorPoint2D_Ycoord, DelayList, s=4, c=IncidenceList, label=labelLine)
406        # plt.title('Delay + Incidence graph\n press left/right to move detector position')
407        fig.suptitle("Delay + Incidence graph\n press left/right to move detector position")
408    else:
409        ax.scatter(DectectorPoint2D_Xcoord, DectectorPoint2D_Ycoord, DelayList, s=4, c=DelayList, label=labelLine)
410        # plt.title('Delay graph\n press left/right to move detector position')
411        fig.suptitle("Delay graph\n press left/right to move detector position")
412
413    ax.legend(loc="upper right")
414
415    if DrawAiryAndFourier:
416        x = np.linspace(-AiryRadius, AiryRadius, 40)
417        z = np.linspace(np.mean(DelayList) - DeltaFT * 0.5, np.mean(DelayList) + DeltaFT * 0.5, 40)
418        x, z = np.meshgrid(x, z)
419        y = np.sqrt(AiryRadius**2 - x**2)
420        ax.plot_wireframe(x, y, z, color="grey", alpha=0.1)
421        ax.plot_wireframe(x, -y, z, color="grey", alpha=0.1)
422
423    axisLim = 1.1 * max(AiryRadius, 0.5 * FocalSpotSize * 1000)
424    ax.set_xlim(-axisLim, axisLim)
425    ax.set_ylim(-axisLim, axisLim)
426    #ax.set_zlim(-axisLim/3*10, axisLim/3*10) #same scaling as spatial axes 
427    ax.set_zlim(-DeltaFT, DeltaFT)
428    
429    plt.show()
430    fig.canvas.draw()
431
432    return fig, NumericalAperture, AiryRadius, FocalSpotSize
433
434moc.OpticalChain.drawDelaySpots = DrawDelaySpots
435
436# %% Mirror projection
437def DrawMirrorProjection(OpticalChain, ReflectionNumber: int, ColorCoded=None, Detector="") -> plt.Figure:
438    """
439    Produce a plot of the ray impact points on the optical element with index 'ReflectionNumber'.
440    The points can be color-coded according to one of ["Incidence","Intensity","Delay"], where the ray delay is
441    measured at the Detector.
442
443    Parameters
444    ----------
445        OpticalChain : OpticalChain
446            An object of the ModuleOpticalChain.OpticalChain-class.
447
448        ReflectionNumber : int
449            Index specifying the optical element on which you want to see the impact points.
450
451        Detector : Detector or str, optional
452            Object of the ModuleDetector.Detector-class or the name of the detector. Only necessary to project delays. The default is "" (no detector).
453
454        ColorCoded : str, optional
455            Specifies which ray property to color-code: ["Incidence","Intensity","Delay"]. The default is None.
456
457    Returns
458    -------
459        fig : matlplotlib-figure-handle.
460            Shows the figure.
461    """
462    from mpl_toolkits.axes_grid1 import make_axes_locatable
463    if isinstance(Detector, str):
464        if Detector == "":
465            Detector = None
466        else:
467            Detector = OpticalChain.detectors[Detector]
468
469    Position = OpticalChain[ReflectionNumber].position
470    q = OpticalChain[ReflectionNumber].orientation
471    # n = OpticalChain.optical_elements[ReflectionNumber].normal
472    # m = OpticalChain.optical_elements[ReflectionNumber].majoraxis
473
474    RayListAnalysed = OpticalChain.get_output_rays()[ReflectionNumber]
475    # transform rays into the mirror-support reference frame
476    # (same as mirror frame but without the shift by mirror-centre)
477    r0 = OpticalChain[ReflectionNumber].r0
478    RayList = [r.to_basis(*OpticalChain[ReflectionNumber].basis) for r in RayListAnalysed]
479
480    x = np.asarray([k.point[0] for k in RayList]) - r0[0]
481    y = np.asarray([k.point[1] for k in RayList]) - r0[1]
482    if ColorCoded == "Intensity":
483        IntensityList = [k.intensity for k in RayListAnalysed]
484        z = np.asarray(IntensityList)
485        zlabel = "Intensity (arb.u.)"
486        title = "Ray intensity projected on mirror              "
487    elif ColorCoded == "Incidence":
488        IncidenceList = [np.rad2deg(k.incidence) for k in RayListAnalysed]  # in degree
489        z = np.asarray(IncidenceList)
490        zlabel = "Incidence angle (deg)"
491        title = "Ray incidence projected on mirror              "
492    elif ColorCoded == "Delay":
493        if Detector is not None:
494            z = np.asarray(Detector.get_Delays(RayListAnalysed))
495            zlabel = "Delay (fs)"
496            title = "Ray delay at detector projected on mirror              "
497        else:
498            raise ValueError("If you want to project ray delays, you must specify a detector.")
499    else:
500        z = "red"
501        title = "Ray impact points projected on mirror"
502
503    plt.ion()
504    fig = plt.figure()
505    ax = OpticalChain.optical_elements[ReflectionNumber].support._ContourSupport(fig)
506    p = plt.scatter(x, y, c=z, s=15)
507    if ColorCoded == "Delay" or ColorCoded == "Incidence" or ColorCoded == "Intensity":
508        divider = make_axes_locatable(ax)
509        cax = divider.append_axes("right", size="5%", pad=0.05)
510        cbar = fig.colorbar(p, cax=cax)
511        cbar.set_label(zlabel)
512    ax.set_xlabel("x (mm)")
513    ax.set_ylabel("y (mm)")
514    plt.title(title, loc="right")
515    plt.tight_layout()
516
517    bbox = ax.get_position()
518    bbox.set_points(bbox.get_points() - np.array([[0.01, 0], [0.01, 0]]))
519    ax.set_position(bbox)
520    plt.show()
521
522    return fig
523
524
525moc.OpticalChain.drawMirrorProjection = DrawMirrorProjection
526
527# %% Setup rendering
528def DrawSetup(OpticalChain, 
529                   EndDistance=None, 
530                   maxRays=300, 
531                   OEpoints=2000, 
532                   draw_mesh=False, 
533                   cycle_ray_colors = False,
534                   impact_points = False,
535                   DrawDetectors=True,
536                   DetectedRays = False,
537                   Observers = dict()):
538    """
539    Renders an image of the Optical setup and the traced rays.
540
541    Parameters
542    ----------
543        OpticalChain : OpticalChain
544            An object of the ModuleOpticalChain.OpticalChain-class.
545
546        EndDistance : float, optional
547            The rays of the last ray bundle are drawn with a length given by EndDistance (in mm). If not specified,
548            this distance is set to that between the source point and the 1st optical element.
549
550        maxRays : int, optional
551            The maximum number of rays to render. Rendering all the traced rays is very resource-intensive
552            and not required for a nice image. Default is 300.
553
554        OEpoints : int
555            How many little spheres to draw to represent the optical elements.  Default is 2000.
556
557    Returns
558    -------
559        fig : Pyvista-figure-handle.
560            Shows the figure.
561    """
562
563    RayListHistory = [OpticalChain.source_rays] + OpticalChain.get_output_rays()
564
565    if EndDistance is None:
566        EndDistance = np.linalg.norm(OpticalChain.source_rays[0].point - OpticalChain.optical_elements[0].position)
567
568    print("...rendering image of optical chain...", end="", flush=True)
569    fig = pvqt.BackgroundPlotter(window_size=(1500, 500), notebook=False) # Opening a window
570    fig.set_background('white')
571    
572    if cycle_ray_colors:
573        colors = mpu.generate_distinct_colors(len(OpticalChain)+1)
574    else:
575        colors = [[0.7, 0, 0]]*(len(OpticalChain)+1) # Default color: dark red
576
577    # Optics display
578    # For each optic we will send the figure to the function _RenderOpticalElement and it will add the optic to the figure
579    for i,OE in enumerate(OpticalChain.optical_elements):
580        color = pv.Color(colors[i+1])
581        rgb = color.float_rgb
582        h, l, s = rgb_to_hls(*rgb)
583        s = max(0, min(1, s * 0.3))  # Decrease saturation
584        l = max(0, min(1, l + 0.1))  # Increase lightness
585        new_rgb = hls_to_rgb(h, l, s)
586        darkened_color = pv.Color(new_rgb)
587        mpm._RenderOpticalElement(fig, OE, OEpoints, draw_mesh, darkened_color, index=i)
588    ray_meshes = mpm._RenderRays(RayListHistory, EndDistance, maxRays)
589    for i,ray in enumerate(ray_meshes):
590        color = pv.Color(colors[i])
591        fig.add_mesh(ray, color=color, name=f"RayBundle_{i}")
592    if impact_points:
593        for i,rays in enumerate(RayListHistory):
594            points = np.array([list(r.point) for r in rays], dtype=np.float32)
595            points = pv.PolyData(points)
596            color = pv.Color(colors[i-1])
597            fig.add_mesh(points, color=color, point_size=5, name=f"RayImpactPoints_{i}")
598    
599    detector_copies = {key: copy(OpticalChain.detectors[key]) for key in OpticalChain.detectors.keys()}
600    detector_meshes_list = []
601    detectedpoint_meshes = dict()
602    
603    if OpticalChain.detectors is not None and DrawDetectors:
604        # Detector display
605        for key in OpticalChain.detectors.keys():
606            det = detector_copies[key]
607            index = OpticalChain.detectors[key].index
608            if key in Observers:
609                det.distance = Observers[key].value
610                #Observers[key].register_calculation(lambda x: det.set_distance(x))
611            mpm._RenderDetector(fig, det, name = key, detector_meshes = detector_meshes_list)
612            if DetectedRays:
613                RayListAnalysed = OpticalChain.get_output_rays()[index]
614                points = det.get_3D_points(RayListAnalysed)
615                points = pv.PolyData(points)
616                detectedpoint_meshes[key] = points
617                fig.add_mesh(points, color='purple', point_size=5, name=f"DetectedRays_{key}")
618    detector_meshes = dict(zip(OpticalChain.detectors.keys(), detector_meshes_list))
619    
620    # Now we define a function that will move on the plot the detector with name "detname" when it's called
621    def move_detector(detname, new_value):
622        nonlocal fig, detector_meshes, detectedpoint_meshes, DetectedRays, detector_copies, OpticalChain
623        det = detector_copies[detname]
624        index = OpticalChain.detectors[detname].index
625        det_mesh = detector_meshes[detname]
626        translation = det.normal * (det.distance - new_value)
627        det_mesh.translate(translation, inplace=True)
628        det.distance = new_value
629        if DetectedRays:
630            points_mesh = detectedpoint_meshes[detname]
631            points_mesh.points = det.get_3D_points(OpticalChain.get_output_rays()[index])
632        fig.show()
633    
634    # Now we register the function to the observers
635    for key in OpticalChain.detectors.keys():
636        if key in Observers:
637            Observers[key].register(lambda x, key=key: move_detector(key, x))  # bind key now, not at call time
638
639    #pv.save_meshio('optics.inp', pointcloud)  
640    print(
641        "\r\033[K", end="", flush=True
642    )  # move to beginning of the line with \r and then delete the whole line with \033[K
643    fig.show()
644    return fig
645
646moc.OpticalChain.render = DrawSetup
647
648# %% Asphericity
649
650def DrawAsphericity(Mirror, Npoints=1000):
651    """
652    This function displays a map of the asphericity of the mirror.
653    It's a scatter plot of the points of the mirror surface, with the color representing the distance to the closest sphere.
654    The closest sphere is calculated by the function GetClosestSphere, using a least-squares fit.
655
656    Parameters
657    ----------
658    Mirror : Mirror
659        The mirror to analyse.
660
661    Npoints : int, optional
662        The number of points to sample on the mirror surface. The default is 1000.
663    
664    Returns
665    -------
666    fig : Figure
667        The figure of the plot.
668    """
669    plt.ion()
670    fig = plt.figure()
671    ax = Mirror.support._ContourSupport(fig)
672    center, radius = man.GetClosestSphere(Mirror, Npoints)
673    Points = mpm.sample_support(Mirror.support, Npoints=Npoints)
674    Points += Mirror.r0[:2]
675    Z = Mirror._zfunc(Points)
676    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
677    X, Y = Points[:, 0] - Mirror.r0[0], Points[:, 1] - Mirror.r0[1]
678    Points_centered = Points - center
679    Distance = np.linalg.norm(Points_centered, axis=1) - radius
680    Distance*=1e3 # To convert to µm
681    p = plt.scatter(X, Y, c=Distance, s=15)
682    divider = man.make_axes_locatable(ax)
683    cax = divider.append_axes("right", size="5%", pad=0.05)
684    cbar = fig.colorbar(p, cax=cax)
685    cbar.set_label("Distance to closest sphere (µm)")
686    ax.set_xlabel("x (mm)")
687    ax.set_ylabel("y (mm)")
688    plt.title("Asphericity map", loc="right")
689    plt.tight_layout()
690
691    bbox = ax.get_position()
692    bbox.set_points(bbox.get_points() - np.array([[0.01, 0], [0.01, 0]]))
693    ax.set_position(bbox)
694    plt.show()
695    return fig
696
697mmirror.Mirror.drawAsphericity = DrawAsphericity
698# %% Caustics
699
700def DrawCaustics(OpticalChain, Range=1, Detector="Focus" , Npoints=1000, Nrays=1000):
701    """
702    This function displays the caustics of the rays on the detector.
703    To do so, it calculates the intersections of the rays with the detector over a 
704    range determined by the parameter Range, and then plots the standard deviation of the
705    positions in the x and y directions.
706
707    Parameters
708    ----------
709    OpticalChain : OpticalChain
710        The optical chain to analyse.
711
712    Detector : Detector or str, optional
713        The detector, or the name of the detector, on which the caustics are calculated. The default is "Focus".
714    
715    Range : float, optional
716        The detector distances are sampled from -Range to +Range (in mm). The default is 1.
717
718    Npoints : int, optional
719        The number of detector distances to sample within this range. The default is 1000.
720    
721    Returns
722    -------
723    fig : Figure
724        The figure of the plot.
725    """
726    distances = np.linspace(-Range, Range, Npoints)
727    # Accept either a detector name or a Detector object
728    Det = OpticalChain.detectors[Detector] if isinstance(Detector, str) else Detector
729    Index = Det.index
730    Rays = OpticalChain.get_output_rays()[Index]
731    Nrays = min(Nrays, len(Rays))
732    Rays = np.random.choice(Rays, Nrays, replace=False)
733    LocalRayList = [r.to_basis(*Det.basis) for r in Rays]
734    Points = mgeo.IntersectionRayListZPlane(LocalRayList, distances)
735    x_std = []
736    y_std = []
737    for i in range(len(distances)):
738        x_std.append(mp.StandardDeviation(Points[i][:,0]))
739        y_std.append(mp.StandardDeviation(Points[i][:,1]))
740    plt.ion()
741    fig, ax = plt.subplots()
742    ax.plot(distances, x_std, label="x std")
743    ax.plot(distances, y_std, label="y std")
744    ax.set_xlabel("Detector distance (mm)")
745    ax.set_ylabel("Standard deviation (mm)")
746    ax.legend()
747    plt.title("Caustics")
748    plt.show()
749    return fig
750
751moc.OpticalChain.drawCaustics = DrawCaustics
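
The caustic calculation at the end of the listing above (intersect each ray with a series of planes, then take the spread of the hit points in x and y) can be sketched in plain Python. This is an illustration under stated assumptions, not ART's implementation: `caustic_std` is a hypothetical helper standing in for `mgeo.IntersectionRayListZPlane` and `mp.StandardDeviation`, rays are reduced to `(point, direction)` tuples already expressed in the detector frame, and the population standard deviation is assumed.

```python
from statistics import pstdev


def caustic_std(rays, distances):
    """Return (x_std, y_std), one value per plane z = d.

    rays: list of ((px, py, pz), (vx, vy, vz)) tuples in the detector frame
    (hypothetical representation chosen for this sketch).
    """
    x_std, y_std = [], []
    for d in distances:
        xs, ys = [], []
        for (px, py, pz), (vx, vy, vz) in rays:
            t = (d - pz) / vz  # ray parameter at which the ray reaches the plane z = d
            xs.append(px + t * vx)
            ys.append(py + t * vy)
        # Assumption: population standard deviation of the hit coordinates
        x_std.append(pstdev(xs))
        y_std.append(pstdev(ys))
    return x_std, y_std


# Astigmatic test bundle: rays start at z = -10, focus in x at z = 0 and in y at z = 1.
rays = [((x0, y0, -10.0), (-x0 / 10.0, -y0 / 11.0, 1.0))
        for x0 in (-2.0, -1.0, 1.0, 2.0) for y0 in (-1.0, 1.0)]
distances = [-1.0, 0.0, 1.0]
x_std, y_std = caustic_std(rays, distances)
```

For this test bundle the x spread vanishes at distance 0 and the y spread at distance 1, which is the kind of separation between the "x std" and "y std" curves that the caustics plot is meant to reveal.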
logger = <Logger ART.ModuleAnalysisAndPlots (WARNING)>
def DrawSpotDiagram( OpticalChain, Detector='Focus', DrawAiryAndFourier=False, DrawFocalContour=False, DrawFocal=False, ColorCoded=None, Observer=None) -> matplotlib.figure.Figure:
 57def DrawSpotDiagram(OpticalChain, 
 58                Detector = "Focus",
 59                DrawAiryAndFourier=False, 
 60                DrawFocalContour=False,
 61                DrawFocal=False,
 62                ColorCoded=None,
 63                Observer = None) -> plt.Figure:
 64    """
 65    Produce an interactive figure with the spot diagram on the selected Detector.
 66    The detector distance can be shifted with the left-right cursor keys. Doing so will actually move the detector.
 67    If DrawAiryAndFourier is True, a circle with the Airy-spot-size will be shown.
 68    If DrawFocalContour is True, the focal contour calculated from some of the rays will be shown.
 69    If DrawFocal is True, a heatmap calculated from some of the rays will be shown. 
 70    The 'spots' can optionally be color-coded by specifying ColorCoded, which can be one of ["Intensity","Incidence","Delay"].
 71
 72    Parameters
 73    ----------
 74        OpticalChain : OpticalChain
 75            The optical chain whose output rays at the Detector are analysed.
 76
 77        Detector : Detector or str, optional
 78            An object of the ModuleDetector.Detector-class or the name of the detector. The default is "Focus".
 79
 80        DrawAiryAndFourier : bool, optional
 81            Whether to draw a circle with the Airy-spot-size. The default is False.
 82        
 83        DrawFocalContour : bool, optional
 84            Whether to draw the focal contour. The default is False.
 85
 86        DrawFocal : bool, optional
 87            Whether to draw the focal heatmap. The default is False.
 88
 89        ColorCoded : str, optional
 90            Color-code the spots according to one of ["Intensity","Incidence","Delay"]. The default is None.
 91
 92        Observer : Observer, optional
 93            An observer object. If none, then we just create a copy of the detector and move it when pressing left-right. 
 94            However, if an observer is specified, then we will change the value of the observer and it will issue 
 95            the required callbacks to update several plots at the same time.
 96
 97    Returns
 98    -------
 99        fig : matplotlib-figure-handle.
100            The interactive figure, returned together with the Observable holding the detector distance.
101    """
102    if isinstance(Detector, str):
103        Detector = OpticalChain.detectors[Detector]
104    Index = Detector.index
105    movingDetector = copy(Detector) # We will move this detector when pressing left-right
106    if Observer is None:
107        detectorPosition = Observable(movingDetector.distance) # We will observe the distance of this detector
108    else:
109        detectorPosition = Observer
110        movingDetector.distance = detectorPosition.value
111
112    detectorPosition.register_calculation(lambda x: movingDetector.set_distance(x))
113
114    RayListAnalysed = OpticalChain.get_output_rays()[Index]
115
116    NumericalAperture = man.GetNumericalAperture(RayListAnalysed, 1)  # NA determined from final ray bundle
117    MaxWavelength = np.max([i.wavelength for i in RayListAnalysed])
118    if DrawAiryAndFourier:
119        AiryRadius = man.GetAiryRadius(MaxWavelength, NumericalAperture) * 1e3  # in µm
120    else:
121        AiryRadius = 0
122    
123    if DrawFocalContour or DrawFocal:
124        X,Y,Z = man.GetDiffractionFocus(OpticalChain, movingDetector, Index)
125        Z/=np.max(Z) 
126
127    DectectorPoint2D_Xcoord, DectectorPoint2D_Ycoord, FocalSpotSize, SpotSizeSD = mpu._getDetectorPoints(
128        RayListAnalysed, movingDetector
129    )
130
131    match ColorCoded:
132        case "Intensity":
133            IntensityList = [k.intensity for k in RayListAnalysed]
134            z = np.asarray(IntensityList)
135            zlabel = "Intensity (arb.u.)"
136            title = "Intensity + Spot Diagram\n press left/right to move detector position"
137            addLine = ""
138        case "Incidence":
139            IncidenceList = [np.rad2deg(k.incidence) for k in RayListAnalysed]  # degree
140            z = np.asarray(IncidenceList)
141            zlabel = "Incidence angle (deg)"
142            title = "Ray Incidence + Spot Diagram\n press left/right to move detector position"
143            addLine = ""
144        case "Delay":
145            DelayList = movingDetector.get_Delays(RayListAnalysed)
146            DurationSD = mp.StandardDeviation(DelayList)
147            z = np.asarray(DelayList)
148            zlabel = "Delay (fs)"
149            title = "Delay + Spot Diagram\n press left/right to move detector position"
150            addLine = "\n" + "{:.2f}".format(DurationSD) + " fs SD"
151        case _:
152            z = "red"
153            title = "Spot Diagram\n press left/right to move detector position"
154            addLine = ""
155
156    distStep = min(50, max(0.0005, round(FocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000))  # in mm
157
158    plt.ion()
159    fig, ax = plt.subplots()
160    if DrawFocal:
161        focal = ax.pcolormesh(X*1e3,Y*1e3,Z)
162    if DrawFocalContour:
163        levels = [1/np.e**2, 0.5]
164        contour = ax.contourf(X*1e3, Y*1e3, Z, levels=levels, cmap='gray')
165
166    if DrawAiryAndFourier:
167        theta = np.linspace(0, 2 * np.pi, 100)
168        x = AiryRadius * np.cos(theta)
169        y = AiryRadius * np.sin(theta)  #
170        ax.plot(x, y, c="black")
171        
172
173    foo = ax.scatter(
174        DectectorPoint2D_Xcoord,
175        DectectorPoint2D_Ycoord,
176        c=z,
177        s=15,
178        label="{:.3f}".format(detectorPosition.value) + " mm\n" + "{:.1f}".format(SpotSizeSD * 1e3) + " \u03BCm SD" + addLine,
179    )
180
181    axisLim = 1.1 * max(AiryRadius, 0.5 * FocalSpotSize * 1000)
182    ax.set_xlim(-axisLim, axisLim)
183    ax.set_ylim(-axisLim, axisLim)
184
185    if ColorCoded == "Intensity" or ColorCoded == "Incidence" or ColorCoded == "Delay":
186        cbar = fig.colorbar(foo)
187        cbar.set_label(zlabel)
188
189    ax.legend(loc="upper right")
190    ax.set_xlabel("X (µm)")
191    ax.set_ylabel("Y (µm)")
192    ax.set_title(title)
193    # ax.margins(x=0)
194
195
196    def update_plot(new_value):
197        nonlocal movingDetector, ColorCoded, zlabel, cbar, detectorPosition, foo, distStep, focal, contour, levels, Index, RayListAnalysed
198
199        newDectectorPoint2D_Xcoord, newDectectorPoint2D_Ycoord, newFocalSpotSize, newSpotSizeSD = mpu._getDetectorPoints(
200            RayListAnalysed, movingDetector
201        )
202
203        if DrawFocal:
204            focal.set_array(Z)
205        if DrawFocalContour:
206            levels = [1/np.e**2, 0.5]
207            for coll in contour.collections:
208                coll.remove()  # Remove old contour lines
209            contour = ax.contourf(X * 1e3, Y * 1e3, Z, levels=levels, cmap='gray')
210        
211        xy = foo.get_offsets()
212        xy[:, 0] = newDectectorPoint2D_Xcoord
213        xy[:, 1] = newDectectorPoint2D_Ycoord
214        foo.set_offsets(xy)
215
216
217        if ColorCoded == "Delay":
218            newDelayList = np.asarray(movingDetector.get_Delays(RayListAnalysed))
219            newDurationSD = mp.StandardDeviation(newDelayList)
220            newaddLine = "\n" + "{:.2f}".format(newDurationSD) + " fs SD"
221            foo.set_array(newDelayList)
222            foo.set_clim(min(newDelayList), max(newDelayList))
223            cbar.update_normal(foo)
224        else:
225            newaddLine = ""
226
227        foo.set_label(
228            "{:.3f}".format(detectorPosition.value) + " mm\n" + "{:.1f}".format(newSpotSizeSD * 1e3) + " \u03BCm SD" + newaddLine
229        )
230        ax.legend(loc="upper right")
231
232        axisLim = 1.1 * max(AiryRadius, 0.5 * newFocalSpotSize * 1000)
233        ax.set_xlim(-axisLim, axisLim)
234        ax.set_ylim(-axisLim, axisLim)
235
236        distStep = min(
237            50, max(0.0005, round(newFocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000)
238        )  # in mm
239
240        fig.canvas.draw_idle()
241
242
243    def press(event):
244        nonlocal detectorPosition, distStep
245        if event.key == "right":
246            detectorPosition.value += distStep
247        elif event.key == "left":
248            if detectorPosition.value > 1.5 * distStep:
249                detectorPosition.value -= distStep
250            else:
251                detectorPosition.value = 0.5 * distStep
252        else:
253            return None
254
255    fig.canvas.mpl_connect("key_press_event", press)
256
257    plt.show()
258
259    detectorPosition.register(update_plot)
260
261
262    return fig, detectorPosition

Produce an interactive figure with the spot diagram on the selected Detector. The detector distance can be shifted with the left-right cursor keys. Doing so will actually move the detector. If DrawAiryAndFourier is True, a circle with the Airy-spot-size will be shown. If DrawFocalContour is True, the focal contour calculated from some of the rays will be shown. If DrawFocal is True, a heatmap calculated from some of the rays will be shown. The 'spots' can optionally be color-coded by specifying ColorCoded, which can be one of ["Intensity","Incidence","Delay"].

Parameters

OpticalChain : OpticalChain
    The optical chain whose output rays at the Detector are analysed.

Detector : Detector or str, optional
    An object of the ModuleDetector.Detector-class or the name of the detector. The default is "Focus".

DrawAiryAndFourier : bool, optional
    Whether to draw a circle with the Airy-spot-size. The default is False.

DrawFocalContour : bool, optional
    Whether to draw the focal contour. The default is False.

DrawFocal : bool, optional
    Whether to draw the focal heatmap. The default is False.

ColorCoded : str, optional
    Color-code the spots according to one of ["Intensity","Incidence","Delay"]. The default is None.

Observer : Observer, optional
    An observer object. If none, then we just create a copy of the detector and move it when pressing left-right. 
    However, if an observer is specified, then we will change the value of the observer and it will issue 
    the required callbacks to update several plots at the same time.

Returns

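fig : matplotlib-figure-handle.
    Shows the interactive figure.

detectorPosition : Observable
    The observable holding the detector distance. This is the supplied Observer if one was given.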
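
The Observer mechanism described above can be pictured with a minimal sketch. `SketchObservable` and `SketchDetector` are hypothetical stand-ins written for this illustration, not ART's classes. They assume only what the source of `DrawSpotDiagram` uses: a `value` attribute plus `register_calculation` and `register` methods. The ordering (calculation callbacks before plot callbacks) is also an assumption.

```python
class SketchObservable:
    """Hypothetical stand-in for an observable value (not ART's implementation)."""

    def __init__(self, value):
        self._value = value
        self._calculations = []  # callbacks that update state, e.g. move a detector
        self._observers = []     # callbacks that redraw, e.g. update a plot

    def register_calculation(self, callback):
        self._calculations.append(callback)

    def register(self, callback):
        self._observers.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value
        # Assumption: state updates run first so that redraw callbacks see them.
        for callback in self._calculations:
            callback(new_value)
        for callback in self._observers:
            callback(new_value)


class SketchDetector:
    """Hypothetical detector that only stores its distance (in mm)."""

    def __init__(self, distance):
        self.distance = distance

    def set_distance(self, distance):
        self.distance = distance


detector = SketchDetector(distance=100.0)
position = SketchObservable(detector.distance)
redraws = []  # records what each "plot" would have drawn

position.register_calculation(detector.set_distance)
# Two independent "plots" share one observable, as when one Observer drives several figures.
position.register(lambda d: redraws.append(("spot diagram", detector.distance)))
position.register(lambda d: redraws.append(("delay graph", detector.distance)))

position.value += 0.5  # what a right-arrow key press does, here with a 0.5 mm step
```

Because both callbacks are attached to the same observable, a single change of `value` moves the detector once and refreshes every registered plot, which is the reason for passing a shared Observer to several drawing functions.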
def DrawDelaySpots( OpticalChain, DeltaFT: tuple[int, float], Detector='Focus', DrawAiryAndFourier=False, ColorCoded=None, Observer=None) -> matplotlib.figure.Figure:
267def DrawDelaySpots(OpticalChain, 
268                DeltaFT: tuple[int, float],
269                Detector = "Focus",
270                DrawAiryAndFourier=False, 
271                ColorCoded=None,
272                Observer = None
273                ) -> plt.Figure:
274    """
275    Produce an interactive figure with a spot diagram of the rays of the OpticalChain
276    hitting the Detector, with the ray-delays shown in the 3rd dimension.
277    The detector distance can be shifted with the left-right cursor keys.
278    If DrawAiryAndFourier is True, a cylinder is shown whose diameter is the Airy-spot-size and
279    whose height is the Fourier-limited pulse duration given by 'DeltaFT'.
280    
281    The 'spots' can optionally be color-coded by specifying ColorCoded as ["Intensity","Incidence"].
282
283    Parameters
284    ----------
285        OpticalChain : OpticalChain
286            The optical chain whose output rays at the Detector are analysed.
287
288        Detector : Detector or str, optional
289            An object of the ModuleDetector.Detector-class or the name of the detector. The default is "Focus".
290
291        DeltaFT : int or float
292            The Fourier-limited pulse duration. Just used as a reference to compare the temporal spread
293            induced by the ray-delays.
294
295        DrawAiryAndFourier : bool, optional
296            Whether to draw a cylinder showing the Airy-spot-size and Fourier-limited-duration.
297            The default is False.
298
299        ColorCoded : str, optional
300            Color-code the spots according to one of ["Intensity","Incidence"].
301            The default is None.
302
303    Returns
304    -------
305        fig : matplotlib-figure-handle.
306            Shows the interactive figure.
307    """
308    if isinstance(Detector, str):
309        Det = OpticalChain.detectors[Detector]
310    else:
311        Det = Detector
312    Index = Det.index
313    Detector = copy(Det)
314    if Observer is None:
315        detectorPosition = Observable(Detector.distance)
316    else:
317        detectorPosition = Observer
318        Detector.distance = detectorPosition.value
319    
320    RayListAnalysed = OpticalChain.get_output_rays()[Index]
321    fig, NumericalAperture, AiryRadius, FocalSpotSize = _drawDelayGraph(
322        RayListAnalysed, Detector, detectorPosition.value, DeltaFT, DrawAiryAndFourier, ColorCoded
323    )
324
325    distStep = min(50, max(0.0005, round(FocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000))  # in mm
326
327    movingDetector = copy(Detector)
328
329    def update_plot(new_value):
330        nonlocal movingDetector, ColorCoded, detectorPosition, distStep, fig
331        ax = fig.axes[0]
332        cam = [ax.azim, ax.elev, ax._dist]
333        fig, sameNumericalAperture, sameAiryRadius, newFocalSpotSize = _drawDelayGraph(
334            RayListAnalysed, movingDetector, detectorPosition.value, DeltaFT, DrawAiryAndFourier, ColorCoded, fig
335        )
336        ax = fig.axes[0]
337        ax.azim, ax.elev, ax._dist = cam
338        distStep = min(
339            50, max(0.0005, round(newFocalSpotSize / 8 / np.arcsin(NumericalAperture) * 10000) / 10000)
340        )
341        return fig
342
343    def press(event):
344        nonlocal detectorPosition, distStep, movingDetector, fig
345        if event.key == "right":
346            detectorPosition.value += distStep
347        elif event.key == "left":
348            if detectorPosition.value > 1.5 * distStep:
349                detectorPosition.value -= distStep
350            else:
351                detectorPosition.value = 0.5 * distStep
352
353    fig.canvas.mpl_connect("key_press_event", press)
354    detectorPosition.register(update_plot)
355    detectorPosition.register_calculation(lambda x: movingDetector.set_distance(x))
356
357    return fig, detectorPosition

Produce an interactive figure with a spot diagram of the rays of the OpticalChain hitting the Detector, with the ray-delays shown in the 3rd dimension. The detector distance can be shifted with the left-right cursor keys. If DrawAiryAndFourier is True, a cylinder is shown whose diameter is the Airy-spot-size and whose height is the Fourier-limited pulse duration given by 'DeltaFT'.

The 'spots' can optionally be color-coded by specifying ColorCoded as ["Intensity","Incidence"].

Parameters

OpticalChain : OpticalChain
    The optical chain whose output rays at the Detector are analysed.

Detector : Detector or str, optional
    An object of the ModuleDetector.Detector-class or the name of the detector. The default is "Focus".

DeltaFT : int or float
    The Fourier-limited pulse duration. Just used as a reference to compare the temporal spread
    induced by the ray-delays.

DrawAiryAndFourier : bool, optional
    Whether to draw a cylinder showing the Airy-spot-size and Fourier-limited-duration.
    The default is False.

ColorCoded : str, optional
    Color-code the spots according to one of ["Intensity","Incidence"].
    The default is None.

Returns

fig : matplotlib-figure-handle.
    Shows the interactive figure.
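
The detector shift triggered by the cursor keys follows a small rule that is easier to read in isolation. The sketch below restates the step-size formula and the key handling from the listing above using only the standard library. The function names `detector_step` and `shift_detector` are hypothetical and introduced only for this illustration.

```python
import math


def detector_step(focal_spot_size_mm, numerical_aperture):
    """Step in mm: spot size / 8 / arcsin(NA), rounded to 0.1 µm, clamped to [0.0005, 50] mm."""
    raw = focal_spot_size_mm / 8 / math.asin(numerical_aperture)
    return min(50, max(0.0005, round(raw * 10000) / 10000))


def shift_detector(distance_mm, step_mm, key):
    """Return the new detector distance after a key press."""
    if key == "right":
        return distance_mm + step_mm
    if key == "left":
        if distance_mm > 1.5 * step_mm:
            return distance_mm - step_mm
        return 0.5 * step_mm  # never step to zero or to a negative distance
    return distance_mm  # any other key leaves the detector where it is


step = detector_step(focal_spot_size_mm=0.02, numerical_aperture=0.1)
far = shift_detector(1.0, step, "left")     # ordinary step towards the optic
near = shift_detector(0.03, step, "left")   # too close: clamp to half a step
```

The step scales with the current spot size, so the detector moves in fine increments near the focus and in coarser ones far from it, while the clamp on the left key keeps the distance strictly positive.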
def DrawMirrorProjection( OpticalChain, ReflectionNumber: int, ColorCoded=None, Detector='') -> matplotlib.figure.Figure:
438def DrawMirrorProjection(OpticalChain, ReflectionNumber: int, ColorCoded=None, Detector="") -> plt.Figure:
439    """
440    Produce a plot of the ray impact points on the optical element with index 'ReflectionNumber'.
441    The points can be color-coded according to one of ["Incidence","Intensity","Delay"], where the ray delay is
442    measured at the Detector.
443
444    Parameters
445    ----------
446        OpticalChain : OpticalChain
447            An object of the ModuleOpticalChain.OpticalChain-class.
448
449        ReflectionNumber : int
450            Index specifying the optical element on which you want to see the impact points.
451
452        Detector : Detector or str, optional
453            Object of the ModuleDetector.Detector-class or the name of a detector. Only necessary to project delays. The default is "" (no detector).
454
455        ColorCoded : str, optional
456            Specifies which ray property to color-code: ["Incidence","Intensity","Delay"]. The default is None.
457
458    Returns
459    -------
460        fig : matplotlib-figure-handle.
461            Shows the figure.
462    """
463    from mpl_toolkits.axes_grid1 import make_axes_locatable
464    if isinstance(Detector, str):
465        if Detector == "":
466            Detector = None
467        else:
468            Detector = OpticalChain.detectors[Detector]
469
470    Position = OpticalChain[ReflectionNumber].position
471    q = OpticalChain[ReflectionNumber].orientation
472    # n = OpticalChain.optical_elements[ReflectionNumber].normal
473    # m = OpticalChain.optical_elements[ReflectionNumber].majoraxis
474
475    RayListAnalysed = OpticalChain.get_output_rays()[ReflectionNumber]
476    # transform rays into the mirror-support reference frame
477    # (same as mirror frame but without the shift by mirror-centre)
478    r0 = OpticalChain[ReflectionNumber].r0
479    RayList = [r.to_basis(*OpticalChain[ReflectionNumber].basis) for r in RayListAnalysed]
480
481    x = np.asarray([k.point[0] for k in RayList]) - r0[0]
482    y = np.asarray([k.point[1] for k in RayList]) - r0[1]
483    if ColorCoded == "Intensity":
484        IntensityList = [k.intensity for k in RayListAnalysed]
485        z = np.asarray(IntensityList)
486        zlabel = "Intensity (arb.u.)"
487        title = "Ray intensity projected on mirror              "
488    elif ColorCoded == "Incidence":
489        IncidenceList = [np.rad2deg(k.incidence) for k in RayListAnalysed]  # in degree
490        z = np.asarray(IncidenceList)
491        zlabel = "Incidence angle (deg)"
492        title = "Ray incidence projected on mirror              "
493    elif ColorCoded == "Delay":
494        if Detector is not None:
495            z = np.asarray(Detector.get_Delays(RayListAnalysed))
496            zlabel = "Delay (fs)"
497            title = "Ray delay at detector projected on mirror              "
498        else:
499            raise ValueError("If you want to project ray delays, you must specify a detector.")
500    else:
501        z = "red"
502        title = "Ray impact points projected on mirror"
503
504    plt.ion()
505    fig = plt.figure()
506    ax = OpticalChain.optical_elements[ReflectionNumber].support._ContourSupport(fig)
507    p = plt.scatter(x, y, c=z, s=15)
508    if ColorCoded == "Delay" or ColorCoded == "Incidence" or ColorCoded == "Intensity":
509        divider = make_axes_locatable(ax)
510        cax = divider.append_axes("right", size="5%", pad=0.05)
511        cbar = fig.colorbar(p, cax=cax)
512        cbar.set_label(zlabel)
513    ax.set_xlabel("x (mm)")
514    ax.set_ylabel("y (mm)")
515    plt.title(title, loc="right")
516    plt.tight_layout()
517
518    bbox = ax.get_position()
519    bbox.set_points(bbox.get_points() - np.array([[0.01, 0], [0.01, 0]]))
520    ax.set_position(bbox)
521    plt.show()
522
523    return fig

Produce a plot of the ray impact points on the optical element with index 'ReflectionNumber'. The points can be color-coded according to one of ["Incidence","Intensity","Delay"], where the ray delay is measured at the Detector.

Parameters

OpticalChain : OpticalChain
    The optical chain to analyse: an object of the ModuleOpticalChain.OpticalChain class.

ReflectionNumber : int
    Index specifying the optical element on which you want to see the impact points.

Detector : Detector, optional
    Object of the ModuleDetector.Detector-class. Only necessary to project delays. The default is None.

ColorCoded : str, optional
    Specifies which ray property to color-code: ["Incidence","Intensity","Delay"]. The default is None, in which case all impact points are drawn in red without a color bar.

Returns

fig : matplotlib figure handle.
    The figure that was created and shown.
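The choice of the color-coded property can be sketched in isolation. The following is a minimal illustration, not the ART implementation: `ToyRay` and `color_values` are hypothetical names, the "Delay" option is left out because it needs a detector, and unlike the function above this sketch rejects unknown options instead of falling back to red.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyRay:
    """Hypothetical stand-in for a traced ray; not the ART Ray class."""
    intensity: float
    incidence: float  # incidence angle in radians


def color_values(rays, color_coded=None):
    """Return (values, color bar label) for the chosen ray property."""
    if color_coded == "Intensity":
        return [r.intensity for r in rays], "Intensity (arb.u.)"
    if color_coded == "Incidence":
        return [math.degrees(r.incidence) for r in rays], "Incidence angle (deg)"
    if color_coded is None:
        return "red", None  # a single color, so no color bar is needed
    raise ValueError(f"Unsupported ColorCoded option: {color_coded!r}")


rays = [ToyRay(1.0, math.pi / 4), ToyRay(0.5, math.pi / 6)]
values, label = color_values(rays, "Incidence")
```

The returned `values` can be passed as the `c` argument of a matplotlib scatter plot, and `label` to the color bar.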
def DrawSetup( OpticalChain, EndDistance=None, maxRays=300, OEpoints=2000, draw_mesh=False, cycle_ray_colors=False, impact_points=False, DrawDetectors=True, DetectedRays=False, Observers={}):
def DrawSetup(OpticalChain,
              EndDistance=None,
              maxRays=300,
              OEpoints=2000,
              draw_mesh=False,
              cycle_ray_colors=False,
              impact_points=False,
              DrawDetectors=True,
              DetectedRays=False,
              Observers=dict()):
    """
    Renders an image of the optical setup and the traced rays.

    Parameters
    ----------
        OpticalChain : OpticalChain
            The optical chain to render: an object of the ModuleOpticalChain.OpticalChain class.

        EndDistance : float, optional
            The rays of the last ray bundle are drawn with a length given by EndDistance (in mm). If not specified,
            this distance is set to that between the source point and the 1st optical element.

        maxRays : int, optional
            The maximum number of rays to render. Rendering all the traced rays is an insufferable resource hog
            and not required for a nice image. Default is 300.

        OEpoints : int, optional
            How many little spheres to draw to represent the optical elements. Default is 2000.

        draw_mesh : bool, optional
            Passed on to the optical-element renderer. Default is False.

        cycle_ray_colors : bool, optional
            If True, each ray bundle is drawn in a distinct color; otherwise all rays are dark red. Default is False.

        impact_points : bool, optional
            If True, the points of the rays of each bundle are drawn as well. Default is False.

        DrawDetectors : bool, optional
            If True, the detectors of the optical chain are drawn. Default is True.

        DetectedRays : bool, optional
            If True, the points where the rays hit each drawn detector are drawn as well. Default is False.

        Observers : dict, optional
            Maps detector names to observable objects. A detector with an observer is drawn at the distance
            given by the observer's value and is moved whenever that value changes. Default is an empty dict.

    Returns
    -------
        fig : Pyvista-figure-handle.
            The figure that was created and shown.
    """

    RayListHistory = [OpticalChain.source_rays] + OpticalChain.get_output_rays()

    if EndDistance is None:
        EndDistance = np.linalg.norm(OpticalChain.source_rays[0].point - OpticalChain.optical_elements[0].position)

    print("...rendering image of optical chain...", end="", flush=True)
    fig = pvqt.BackgroundPlotter(window_size=(1500, 500), notebook=False) # Opening a window
    fig.set_background('white')

    if cycle_ray_colors:
        colors = mpu.generate_distinct_colors(len(OpticalChain)+1)
    else:
        colors = [[0.7, 0, 0]]*(len(OpticalChain)+1) # Default color: dark red

    # Optics display
    # For each optic we will send the figure to the function _RenderOpticalElement and it will add the optic to the figure
    for i,OE in enumerate(OpticalChain.optical_elements):
        color = pv.Color(colors[i+1])
        rgb = color.float_rgb
        h, l, s = rgb_to_hls(*rgb)
        s = max(0, min(1, s * 0.3))  # Decrease saturation
        l = max(0, min(1, l + 0.1))  # Increase lightness
        new_rgb = hls_to_rgb(h, l, s)
        darkened_color = pv.Color(new_rgb)
        mpm._RenderOpticalElement(fig, OE, OEpoints, draw_mesh, darkened_color, index=i)
    ray_meshes = mpm._RenderRays(RayListHistory, EndDistance, maxRays)
    for i,ray in enumerate(ray_meshes):
        color = pv.Color(colors[i])
        fig.add_mesh(ray, color=color, name=f"RayBundle_{i}")
    if impact_points:
        for i,rays in enumerate(RayListHistory):
            points = np.array([list(r.point) for r in rays], dtype=np.float32)
            points = pv.PolyData(points)
            color = pv.Color(colors[i-1])
            fig.add_mesh(points, color=color, point_size=5, name=f"RayImpactPoints_{i}")

    detector_copies = {key: copy(OpticalChain.detectors[key]) for key in OpticalChain.detectors.keys()}
    detector_meshes_list = []
    detectedpoint_meshes = dict()

    if OpticalChain.detectors is not None and DrawDetectors:
        # Detector display
        for key in OpticalChain.detectors.keys():
            det = detector_copies[key]
            index = OpticalChain.detectors[key].index
            if key in Observers:
                det.distance = Observers[key].value
                #Observers[key].register_calculation(lambda x: det.set_distance(x))
            mpm._RenderDetector(fig, det, name = key, detector_meshes = detector_meshes_list)
            if DetectedRays:
                RayListAnalysed = OpticalChain.get_output_rays()[index]
                points = det.get_3D_points(RayListAnalysed)
                points = pv.PolyData(points)
                detectedpoint_meshes[key] = points
                fig.add_mesh(points, color='purple', point_size=5, name=f"DetectedRays_{key}")
    detector_meshes = dict(zip(OpticalChain.detectors.keys(), detector_meshes_list))

    # Now we define a function that will move on the plot the detector with name "detname" when it's called
    def move_detector(detname, new_value):
        nonlocal fig, detector_meshes, detectedpoint_meshes, DetectedRays, detector_copies, OpticalChain
        det = detector_copies[detname]
        index = OpticalChain.detectors[detname].index
        det_mesh = detector_meshes[detname]
        translation = det.normal * (det.distance - new_value)
        det_mesh.translate(translation, inplace=True)
        det.distance = new_value
        if DetectedRays:
            points_mesh = detectedpoint_meshes[detname]
            points_mesh.points = det.get_3D_points(OpticalChain.get_output_rays()[index])
        fig.show()

    # Now we register the function to the observers
    for key in OpticalChain.detectors.keys():
        if key in Observers:
            # key=key binds the current detector name; without it every callback would use the last key of the loop
            Observers[key].register(lambda x, key=key: move_detector(key, x))

    #pv.save_meshio('optics.inp', pointcloud)
    print(
        "\r\033[K", end="", flush=True
    )  # move to beginning of the line with \r and then delete the whole line with \033[K
    fig.show()
    return fig

Renders an image of the Optical setup and the traced rays.

Parameters

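OpticalChain : OpticalChain
    The optical chain to render: an object of the ModuleOpticalChain.OpticalChain class.

EndDistance : float, optional
    The rays of the last ray bundle are drawn with a length given by EndDistance (in mm). If not specified,
    this distance is set to that between the source point and the 1st optical element.

maxRays : int, optional
    The maximum number of rays to render. Rendering all the traced rays is an insufferable resource hog
    and not required for a nice image. Default is 300.

OEpoints : int, optional
    How many little spheres to draw to represent the optical elements. Default is 2000.

draw_mesh : bool, optional
    Passed on to the optical-element renderer. Default is False.

cycle_ray_colors : bool, optional
    If True, each ray bundle is drawn in a distinct color; otherwise all rays are dark red. Default is False.

impact_points : bool, optional
    If True, the points of the rays of each bundle are drawn as well. Default is False.

DrawDetectors : bool, optional
    If True, the detectors of the optical chain are drawn. Default is True.

DetectedRays : bool, optional
    If True, the points where the rays hit each drawn detector are drawn as well. Default is False.

Observers : dict, optional
    Maps detector names to observable objects. A detector with an observer is drawn at the distance
    given by the observer's value and is moved whenever that value changes. Default is an empty dict.

Returns

fig : Pyvista-figure-handle.
    The figure that was created and shown.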
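Moving a detector from an observer relies on callbacks that are registered in a loop over the detector names. The sketch below illustrates that pattern with hypothetical stand-ins (`TinyObservable`, `move_detector` and the detector names are assumptions, not the ART classes); it mainly shows why the loop variable should be bound as a default argument of the lambda.

```python
class TinyObservable:
    """Hypothetical minimal observable; not the ART Observable class."""

    def __init__(self, value):
        self._value = value
        self._observers = []

    def register(self, callback):
        self._observers.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value
        for callback in self._observers:
            callback(new_value)  # notify every registered callback


moved = []  # records (detector name, new distance) pairs


def move_detector(name, distance):
    moved.append((name, distance))


observers = {"Focus": TinyObservable(100.0), "Far": TinyObservable(500.0)}
for key, obs in observers.items():
    # key=key freezes the current name. A bare `lambda x: move_detector(key, x)`
    # would look up `key` only when called, i.e. after the loop has finished.
    obs.register(lambda x, key=key: move_detector(key, x))

observers["Focus"].value = 120.0
observers["Far"].value = 480.0
```

After the two assignments, `moved` holds one entry per detector, each with its own name.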
def DrawAsphericity(Mirror, Npoints=1000):
def DrawAsphericity(Mirror, Npoints=1000):
    """
    This function displays a map of the asphericity of the mirror.
    It's a scatter plot of the points of the mirror surface, with the color representing the distance to the closest sphere.
    The closest sphere is calculated by the function GetClosestSphere, i.e. by a least-squares fit.

    Parameters
    ----------
    Mirror : Mirror
        The mirror to analyse.

    Npoints : int, optional
        The number of points to sample on the mirror surface. The default is 1000.

    Returns
    -------
    fig : Figure
        The figure of the plot.
    """
    plt.ion()
    fig = plt.figure()
    ax = Mirror.support._ContourSupport(fig)
    center, radius = man.GetClosestSphere(Mirror, Npoints)
    # Sample the requested number of points (previously hard-coded to 1000)
    Points = mpm.sample_support(Mirror.support, Npoints=Npoints)
    Points += Mirror.r0[:2]
    Z = Mirror._zfunc(Points)
    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
    X, Y = Points[:, 0] - Mirror.r0[0], Points[:, 1] - Mirror.r0[1]
    Points_centered = Points - center
    Distance = np.linalg.norm(Points_centered, axis=1) - radius
    Distance*=1e3 # To convert from mm to µm
    p = plt.scatter(X, Y, c=Distance, s=15)
    # make_axes_locatable is already used directly elsewhere in this module
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(p, cax=cax)
    cbar.set_label("Distance to closest sphere (µm)")
    ax.set_xlabel("x (mm)")
    ax.set_ylabel("y (mm)")
    plt.title("Asphericity map", loc="right")
    plt.tight_layout()

    bbox = ax.get_position()
    bbox.set_points(bbox.get_points() - np.array([[0.01, 0], [0.01, 0]]))
    ax.set_position(bbox)
    plt.show()
    return fig

This function displays a map of the asphericity of the mirror. It's a scatter plot of the points of the mirror surface, with the color representing the distance to the closest sphere. The closest sphere is calculated by the function GetClosestSphere, i.e. by a least-squares fit.

Parameters

Mirror : Mirror The mirror to analyse.

Npoints : int, optional The number of points to sample on the mirror surface. The default is 1000.

Returns

fig : Figure The figure of the plot.
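As an illustration of the quantity that is mapped, the sketch below fits a sphere to sampled surface points and computes their signed distance to it in µm. It assumes a plain algebraic least-squares fit with NumPy; `fit_sphere` and `asphericity_um` are hypothetical helpers, and GetClosestSphere may use a different fitting procedure.

```python
import numpy as np


def fit_sphere(points):
    """Algebraic least-squares sphere fit; returns (center, radius)."""
    points = np.asarray(points, dtype=float)
    # |p|^2 = 2 p.c + d is linear in the center c and in d = r^2 - |c|^2
    A = np.column_stack((2 * points, np.ones(len(points))))
    b = np.sum(points**2, axis=1)
    solution, *_ = np.linalg.lstsq(A, b, rcond=None)
    center = solution[:3]
    radius = np.sqrt(solution[3] + center @ center)
    return center, radius


def asphericity_um(points_mm, center, radius):
    """Signed distance of each point to the sphere, converted from mm to µm."""
    return (np.linalg.norm(points_mm - center, axis=1) - radius) * 1e3


# Toy surface: a cap of a sphere with radius 500 mm centred at (0, 0, 500)
gx, gy = np.meshgrid(np.linspace(-50, 50, 21), np.linspace(-50, 50, 21))
gz = 500.0 - np.sqrt(500.0**2 - gx**2 - gy**2)
surface = np.column_stack((gx.ravel(), gy.ravel(), gz.ravel()))

center, radius = fit_sphere(surface)
deviation_um = asphericity_um(surface, center, radius)
```

For a perfectly spherical surface such as this toy cap the deviation vanishes up to rounding; a parabolic or toroidal surface would give a non-trivial map.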

def DrawCaustics(OpticalChain, Range=1, Detector='Focus', Npoints=1000, Nrays=1000):
def DrawCaustics(OpticalChain, Range=1, Detector="Focus", Npoints=1000, Nrays=1000):
    """
    This function displays the caustics of the rays on the detector.
    To do so, it calculates the intersections of the rays with the detector over a
    range determined by the parameter Range, and then plots the standard deviation of the
    positions in the x and y directions.

    Parameters
    ----------
    OpticalChain : OpticalChain
        The optical chain to analyse.

    Range : float, optional
        The detector distance is scanned from -Range to +Range (in mm). The default is 1.

    Detector : str or Detector, optional
        The detector on which the caustics are calculated, given by its name in
        OpticalChain.detectors or as a Detector object. The default is "Focus".

    Npoints : int, optional
        The number of detector distances to sample. The default is 1000.

    Nrays : int, optional
        The maximum number of rays to use; a random subset is drawn if more are available. The default is 1000.

    Returns
    -------
    fig : Figure
        The figure of the plot.
    """
    distances = np.linspace(-Range, Range, Npoints)
    if isinstance(Detector, str):
        Det = OpticalChain.detectors[Detector]
    else:
        Det = Detector  # a Detector object was passed directly
    Index = Det.index
    Rays = OpticalChain.get_output_rays()[Index]
    Nrays = min(Nrays, len(Rays))
    Rays = np.random.choice(Rays, Nrays, replace=False)
    LocalRayList = [r.to_basis(*Det.basis) for r in Rays]
    Points = mgeo.IntersectionRayListZPlane(LocalRayList, distances)
    x_std = []
    y_std = []
    for i in range(len(distances)):
        x_std.append(mp.StandardDeviation(Points[i][:,0]))
        y_std.append(mp.StandardDeviation(Points[i][:,1]))
    plt.ion()
    fig, ax = plt.subplots()
    ax.plot(distances, x_std, label="x std")
    ax.plot(distances, y_std, label="y std")
    ax.set_xlabel("Detector distance (mm)")
    ax.set_ylabel("Standard deviation (mm)")
    ax.legend()
    plt.title("Caustics")
    plt.show()
    return fig

This function displays the caustics of the rays on the detector. To do so, it calculates the intersections of the rays with the detector over a range determined by the parameter Range, and then plots the standard deviation of the positions in the x and y directions.

Parameters

OpticalChain : OpticalChain The optical chain to analyse.

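Range : float, optional The detector distance is scanned from -Range to +Range (in mm). The default is 1.

Detector : str, optional The name of the detector on which the caustics are calculated. The default is "Focus".

Npoints : int, optional The number of detector distances to sample. The default is 1000.

Nrays : int, optional The maximum number of rays to use; a random subset is drawn if more are available. The default is 1000.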

Returns

fig : Figure The figure of the plot.
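The caustic calculation can be sketched with plain NumPy. This is a simplified illustration under stated assumptions: rays are given as start points and (not necessarily normalised) direction vectors in the detector frame, the detector planes are z = const, and `spot_std_along_axis` is a hypothetical helper that uses `numpy.std` rather than the ART routines.

```python
import numpy as np


def spot_std_along_axis(points, directions, distances):
    """Standard deviation of the ray positions in x and y on the planes z = distances."""
    points = np.asarray(points, dtype=float)
    directions = np.asarray(directions, dtype=float)
    # Ray parameter t at which each ray reaches each plane, shape (Nplanes, Nrays)
    t = (distances[:, None] - points[None, :, 2]) / directions[None, :, 2]
    x = points[None, :, 0] + t * directions[None, :, 0]
    y = points[None, :, 1] + t * directions[None, :, 1]
    return x.std(axis=1), y.std(axis=1)


# Toy bundle: 200 rays starting at z = -100 mm, all aimed at a focus at z = 0.5 mm
rng = np.random.default_rng(1)
starts = np.column_stack(
    (rng.uniform(-5, 5, 200), rng.uniform(-5, 5, 200), np.full(200, -100.0))
)
focus = np.array([0.0, 0.0, 0.5])
directions = focus - starts

distances = np.linspace(-1.0, 1.0, 201)
x_std, y_std = spot_std_along_axis(starts, directions, distances)
best = distances[np.argmin(x_std)]  # plane with the smallest spot in x
```

Plotting `x_std` and `y_std` against `distances` gives the two curves of the caustic plot; for an astigmatic beam their minima would sit at different distances.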

3 - ModulePlottingMethods

Adds plotting and visualisation methods to various classes from ARTcore. Most of the code was simply moved here from the class definitions where it previously lived.

Created in July 2024

@author: André Kalouguine + Stefan Haessler + Anthony Guillaume

  1"""
  2Adds plotting and visualisation methods to various classes from ARTcore.
  3Most of the code is simply moved here from the class definitions previously.
  4
  5Created in July 2024
  6
  7@author: André Kalouguine + Stefan Haessler + Anthony Guillaume
  8"""
  9# %%
 10import numpy as np
 11import matplotlib.pyplot as plt
 12from mpl_toolkits.mplot3d import Axes3D
 13import matplotlib.patches as patches
 14import pyvista as pv
 15import pyvistaqt as pvqt
 16import colorcet as cc
 17
 18import ARTcore.ModuleSupport as msup
 19import ARTcore.ModuleMirror as mmir
 20import ARTcore.ModuleGeometry as mgeo
 21import ARTcore.ModuleMask as mmask
 22from ARTcore.ModuleGeometry import Point, Vector, Origin
 23import itertools
 24
 25# %% Adding support drawing for mirror projection
 26
 27def SupportRound_ContourSupport(self, Figure):
 28    """Draw support contour in MirrorProjection plots."""
 29    axe = Figure.add_subplot(111, aspect="equal")
 30    axe.add_patch(patches.Circle((0, 0), self.radius, alpha=0.08))
 31    return axe
 32
 33msup.SupportRound._ContourSupport = SupportRound_ContourSupport
 34
 35def SupportRoundHole_ContourSupport(self, Figure):
 36    """Draws support contour in MirrorProjection plots."""
 37    axe = Figure.add_subplot(111, aspect="equal")
 38    axe.add_patch(patches.Circle((0, 0), self.radius, alpha=0.08))
 39    axe.add_patch(patches.Circle((self.centerholeX, self.centerholeY), self.radiushole, color="white", alpha=1))
 40    return axe
 41
 42msup.SupportRoundHole._ContourSupport = SupportRoundHole_ContourSupport
 43
 44def SupportRectangle_ContourSupport(self, Figure):
 45    """Draws support contour in MirrorProjection plots."""
 46    axe = Figure.add_subplot(111, aspect="equal")
 47    axe.add_patch(patches.Rectangle((-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY, alpha=0.08))
 48    return axe
 49
 50msup.SupportRectangle._ContourSupport = SupportRectangle_ContourSupport
 51
 52def SupportRectangleHole_ContourSupport(self, Figure):
 53    """Draws support contour in MirrorProjection plots."""
 54    axe = Figure.add_subplot(111, aspect="equal")
 55    axe.add_patch(patches.Rectangle((-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY, alpha=0.08))
 56    axe.add_patch(patches.Circle((self.centerholeX, self.centerholeY), self.radiushole, color="white", alpha=1))
 57    return axe
 58
 59msup.SupportRectangleHole._ContourSupport = SupportRectangleHole_ContourSupport
 60
 61def SupportRectangleRectHole_ContourSupport(self, Figure):
 62    """Draws support contour in MirrorProjection plots."""
 63    axe = Figure.add_subplot(111, aspect="equal")
 64    axe.add_patch(patches.Rectangle((-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY, alpha=0.08))
 65    axe.add_patch(
 66        patches.Rectangle(
 67            (-self.holeX * 0.5 + self.centerholeX, -self.holeY * 0.5 + self.centerholeY),
 68            self.holeX,
 69            self.holeY,
 70            color="white",
 71            alpha=1,
 72        )
 73    )
 74    return axe
 75
 76msup.SupportRectangleRectHole._ContourSupport = SupportRectangleRectHole_ContourSupport
 77
 78# %% Support sampling and meshing 
 79
 80def sample_support(Support, Npoints):
 81    """
 82    This function samples a regular grid on the support and filters out the points that are outside of the support.
 83    For that it uses the _estimate_size method of the support.
 84    It then generates a regular grid over that area and filters out the points that are outside of the support.
 85    If less than half of the points are inside the support, it will increase the number of points by a factor of 2 and try again.
 86    """
 87    if hasattr(Support, "size"):
 88        size = Support.size
 89    else:
 90        size_x = Support._estimate_size()
 91        size = np.array([2*size_x, 2*size_x])
 92    enough = False
 93    while not enough:
 94        X = np.linspace(-size[0] * 0.5, size[0] * 0.5, int(np.sqrt(Npoints)))
 95        Y = np.linspace(-size[1] * 0.5, size[1] * 0.5, int(np.sqrt(Npoints)))
 96        X, Y = np.meshgrid(X, Y)
 97        Points = np.array([X.flatten(), Y.flatten()]).T
 98        Points = [p for p in Points if p in Support]
 99        if len(Points) > Npoints * 0.5:
100            enough = True
101        else:
102            Npoints *= 2
103    return mgeo.PointArray(Points)
104
105def sample_z_values(Support, Npoints):
106    Npoints_init = Npoints
107    if hasattr(Support, "size"):
108        size = Support.size
109    else:
110        size_x = Support._estimate_size()
111        size = np.array([2*size_x, 2*size_x])
112    enough = False
113    while not enough:
114        X = np.linspace(-size[0] * 0.5, size[0] * 0.5, int(np.sqrt(Npoints)))
115        Y = np.linspace(-size[1] * 0.5, size[1] * 0.5, int(np.sqrt(Npoints)))
116        X, Y = np.meshgrid(X, Y)
117        Points = np.array([X.flatten(), Y.flatten()]).T
118        Points = [p for p in Points if p in Support]
119        if len(Points) > Npoints_init * 0.5:
120            enough = True
121        else:
122            Npoints *= 2
123    Points = np.array([X.flatten(), Y.flatten()]).T
124    mask = np.array([p in Support for p in Points]).reshape(X.shape)
125    return X,Y,mask
126
127
128# %% Mesh/pointcloud generation for different mirror/mask types
129
130# For the mirrors that are already implemented in ARTcore
131# we can use the method _zfunc to render the mirror surface. 
132# For that we sample the support and then use the _zfunc method to get the z-values of the surface.
133# We then use the pyvista function add_mesh to render the surface.
134
135def _RenderMirror(Mirror, Npoints=10000):
136    """
137    This function renders a mirror in 3D.
138    It samples the support of the mirror and uses the _zfunc method to get the z-values of the surface.
139    It draws a small sphere at each point of the support and connects them to form a surface.
140    """
141    Points = sample_support(Mirror.support, Npoints)
142    Points += Mirror.r0[:2]
143    Z = Mirror._zfunc(Points)
144    Points = mgeo.PointArray([Points[:, 0], Points[:, 1], Z]).T
145    Points = Points.from_basis(*Mirror.basis)
146    mesh = pv.PolyData(Points)
147    return mesh
148
149
150def _RenderMirrorSurface(Mirror, Npoints=10000):
151    """
152    This function renders a mirror in 3D.
153    It samples the support of the mirror and uses the _zfunc method to get the z-values of the surface.
154    It draws a small sphere at each point of the support and connects them to form a surface.
155    """
156    X,Y,mask = sample_z_values(Mirror.support, Npoints)
157    shape = X.shape
158    X += Mirror.r0[0]
159    Y += Mirror.r0[1]
160    Z = Mirror._zfunc(mgeo.PointArray([X.flatten(), Y.flatten()]).T).reshape(X.shape)
161    #Z += Mirror.r0[2]
162    Z[~mask] = np.nan
163    Points = mgeo.PointArray([X.flatten(), Y.flatten(), Z.flatten()]).T
164    Points = Points.from_basis(*Mirror.basis)
165    X,Y,Z = Points.T
166    X = X.reshape(shape)
167    Y = Y.reshape(shape)
168    Z = Z.reshape(shape)
169    mesh = pv.StructuredGrid(X, Y, Z)
170    return mesh
171
172mmir.MirrorPlane._Render = _RenderMirror
173mmir.MirrorParabolic._Render = _RenderMirror
174mmir.MirrorCylindrical._Render = _RenderMirror
175mmir.MirrorEllipsoidal._Render = _RenderMirror
176mmir.MirrorSpherical._Render = _RenderMirror
177mmir.MirrorToroidal._Render = _RenderMirror
178
179mmask.Mask._Render = _RenderMirror
180
181mmir.MirrorPlane._Render = _RenderMirrorSurface
182mmir.MirrorParabolic._Render = _RenderMirrorSurface
183mmir.MirrorCylindrical._Render = _RenderMirrorSurface
184mmir.MirrorEllipsoidal._Render = _RenderMirrorSurface
185mmir.MirrorSpherical._Render = _RenderMirrorSurface
186mmir.MirrorToroidal._Render = _RenderMirrorSurface
187
188mmask.Mask._Render = _RenderMirrorSurface
189
190# %% Optical element rendering
191def _RenderOpticalElement(fig, OE, OEpoints, draw_mesh = False, color="blue", index=0):
192    """
193    This function renders the optical elements in the 3D plot.
194    It receives a pyvista figure handle, and draws the optical element on it.
195    """
196    mesh = OE._Render(OEpoints)
197    fig.add_mesh(mesh, color = color, name = f"{OE.description} {index}")    
198
199def _RenderDetector(fig, Detector, size = 40, name = "Focus", detector_meshes = None):
200    """
201    Unfinished
202    """
203    Points = [np.array([-size/2,size/2,0]),np.array([size/2,size/2,0]),np.array([size/2,-size/2,0]),np.array([-size/2,-size/2,0])]
204    Points = mgeo.PointArray(Points)
205    Points = Points.from_basis(*Detector.basis)
206    Rect = pv.Rectangle(Points[:3])
207    if detector_meshes is not None:
208        detector_meshes += [Rect]
209    fig.add_mesh(Rect, color="green", name=f"Detector {name}")
210
211# %% Support rendering using the sdf method
212
213def _RenderSupport(Support, Npoints=10000):
214    """
215    This function renders a support in 3D.
216    """
217    X,Y,mask = sample_z_values(Support, Npoints)
218    shape = X.shape
219    Z = np.zeros_like(X)
220    Z[~mask] = np.nan
221    mesh = pv.StructuredGrid(X, Y, Z)
222    return mesh
223    
224msup.SupportRound._Render = _RenderSupport
225msup.SupportRoundHole._Render = _RenderSupport
226msup.SupportRectangle._Render = _RenderSupport
227msup.SupportRectangleHole._Render = _RenderSupport
228msup.SupportRectangleRectHole._Render = _RenderSupport
229
230
231# %% Standalone optical element rendering
232
233def _RenderMirror(Mirror, Npoints=1000, draw_support = False, draw_points = False, draw_vectors = True , recenter_support = True):
234    """
235    This function renders a mirror in 3D.
236    """
237    mesh = Mirror._Render(Npoints)
238    p = pv.Plotter()
239    p.add_mesh(mesh)
240    if draw_support:
241        support = Mirror.support._Render()
242        if recenter_support:
243            support.translate(Mirror.r0,inplace=True)
244        p.add_mesh(support, color="gray", opacity=0.5)
245    
246    if draw_vectors:
247        # We draw the important vectors of the optical element
248        # For that, if we have a "vectors" attribute, we use that
249        #  (a dictionary with the vector names as keys and the colors as values)
250        # Otherwise we use the default vectors: "support_normal", "majoraxis"
251        if hasattr(Mirror, "vectors"):
252            vectors = Mirror.vectors
253        else:
254            vectors = {"support_normal_ref": "red", "majoraxis_ref": "blue"}
255        for vector, color in vectors.items():
256            if hasattr(Mirror, vector):
257                p.add_arrows(Mirror.r0, 10*getattr(Mirror, vector), color=color)
258    
259    if draw_points:
260        # We draw the important points of the optical element
261        # For that, if we have a "points" attribute, we use that
262        #  (a dictionary with the point names as keys and the colors as values)
263        # Otherwise we use the default points: "centre_ref"
264        if hasattr(Mirror, "points"):
265            points = Mirror.points
266        else:
267            points = {"centre_ref": "red"}
268        for point, color in points.items():
269            if hasattr(Mirror, point):
270                p.add_mesh(pv.Sphere(radius=1, center=getattr(Mirror, point)), color=color)
271    else:
272        p.add_mesh(pv.Sphere(radius=1, center=Mirror.r0), color="red")
273    p.show()
274    return p
275
276mmir.MirrorPlane.render = _RenderMirror
277mmir.MirrorParabolic.render = _RenderMirror
278mmir.MirrorParabolic.vectors = {
279    "support_normal": "red",
280    "majoraxis": "blue",
281    "towards_focusing_ref": "green"
282}
283mmir.MirrorParabolic.points = {
284    "centre_ref": "red",
285    "focus_ref": "green"
286}
287mmir.MirrorCylindrical.render = _RenderMirror
288mmir.MirrorEllipsoidal.render = _RenderMirror
289mmir.MirrorEllipsoidal.vectors = {
290    "support_normal_ref": "red",
291    "majoraxis_ref": "blue",
292    "towards_image_ref": "green",
293    "towards_object_ref": "green",
294    "centre_normal_ref": "purple"}
295
296mmir.MirrorSpherical.render = _RenderMirror
297mmir.MirrorToroidal.render = _RenderMirror
298
299mmask.Mask.render = _RenderMirror
300
301# %% Ray rendering
302
303def _RenderRays(RayListHistory, EndDistance, maxRays=150):
304    """
305    Generates a list of Pyvista-meshes representing the rays in RayListHistory.
306    This can then be employed to render the rays in a Pyvista-figure.
307
308    Parameters
309    ----------
310        RayListHistory : list(list(Ray))
311            A list of lists of objects of the ModuleOpticalRay.Ray-class.
312
313        EndDistance : float
314            The rays of the last ray bundle are drawn with a length given by EndDistance (in mm).
315
316        maxRays : int, optional
317            The maximum number of rays to render. Rendering all the traced rays is a insufferable resource hog
318            and not required for a nice image. Default is 150.
319    
320    Returns
321    -------
322        meshes : list(pvPolyData)
323            List of Pyvista PolyData objects representing the rays
324    """
325    meshes = []
326    # Ray display
327    for k in range(len(RayListHistory)):
328        x = []
329        y = []
330        z = []
331        if k != len(RayListHistory) - 1:
332            knums = list(
333                map(lambda x: x.number, RayListHistory[k])
334            )  # make a list of all ray numbers that are still in the game
335            if len(RayListHistory[k + 1]) > maxRays:
336                rays_to_render = np.random.choice(RayListHistory[k + 1], maxRays, replace=False)
337            else:
338                rays_to_render = RayListHistory[k + 1]
339
340            for j in rays_to_render:
341                indx = knums.index(j.number)
342                i = RayListHistory[k][indx]
343                Point1 = i.point
344                Point2 = j.point
345                x += [Point1[0], Point2[0]]
346                y += [Point1[1], Point2[1]]
347                z += [Point1[2], Point2[2]]
348
349        else:
350            if len(RayListHistory[k]) > maxRays:
351                rays_to_render = np.random.choice(RayListHistory[k], maxRays, replace=False)
352            else:
353                rays_to_render = RayListHistory[k]
354
355            for j in rays_to_render:
356                Point = j.point
357                Vector = j.vector
358                x += [Point[0], Point[0] + Vector[0] * EndDistance]
359                y += [Point[1], Point[1] + Vector[1] * EndDistance]
360                z += [Point[2], Point[2] + Vector[2] * EndDistance]
361        points = np.column_stack((x, y, z))
362        meshes += [pv.line_segments_from_points(points)]
363    return meshes
def SupportRound_ContourSupport(self, Figure):
def SupportRound_ContourSupport(self, Figure):
    """Draw support contour in MirrorProjection plots."""
    axe = Figure.add_subplot(111, aspect="equal")
    axe.add_patch(patches.Circle((0, 0), self.radius, alpha=0.08))
    return axe

Draw support contour in MirrorProjection plots.

def SupportRoundHole_ContourSupport(self, Figure):
def SupportRoundHole_ContourSupport(self, Figure):
    """Draws support contour in MirrorProjection plots."""
    axe = Figure.add_subplot(111, aspect="equal")
    axe.add_patch(patches.Circle((0, 0), self.radius, alpha=0.08))
    axe.add_patch(patches.Circle((self.centerholeX, self.centerholeY), self.radiushole, color="white", alpha=1))
    return axe

Draws support contour in MirrorProjection plots.

def SupportRectangle_ContourSupport(self, Figure):
def SupportRectangle_ContourSupport(self, Figure):
    """Draws support contour in MirrorProjection plots."""
    axe = Figure.add_subplot(111, aspect="equal")
    axe.add_patch(patches.Rectangle((-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY, alpha=0.08))
    return axe

Draws support contour in MirrorProjection plots.

def SupportRectangleHole_ContourSupport(self, Figure):
def SupportRectangleHole_ContourSupport(self, Figure):
    """Draws support contour in MirrorProjection plots."""
    axe = Figure.add_subplot(111, aspect="equal")
    axe.add_patch(patches.Rectangle((-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY, alpha=0.08))
    axe.add_patch(patches.Circle((self.centerholeX, self.centerholeY), self.radiushole, color="white", alpha=1))
    return axe

Draws support contour in MirrorProjection plots.

def SupportRectangleRectHole_ContourSupport(self, Figure):
def SupportRectangleRectHole_ContourSupport(self, Figure):
    """Draws support contour in MirrorProjection plots."""
    axe = Figure.add_subplot(111, aspect="equal")
    axe.add_patch(patches.Rectangle((-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY, alpha=0.08))
    axe.add_patch(
        patches.Rectangle(
            (-self.holeX * 0.5 + self.centerholeX, -self.holeY * 0.5 + self.centerholeY),
            self.holeX,
            self.holeY,
            color="white",
            alpha=1,
        )
    )
    return axe

Draws support contour in MirrorProjection plots.
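The contour functions above are plain functions that are attached to the support classes after the classes have been defined. A minimal sketch of that pattern, with a hypothetical `ToyRectangle` class and a `_contour` method that only returns the outline geometry instead of drawing it:

```python
class ToyRectangle:
    """Hypothetical support class, defined without any plotting code."""

    def __init__(self, dimX, dimY):
        self.dimX = dimX
        self.dimY = dimY


def ToyRectangle_contour(self):
    """Return the lower-left corner, width and height of the outline."""
    return (-self.dimX * 0.5, -self.dimY * 0.5), self.dimX, self.dimY


# Attach the function as a method; every instance can now call it
ToyRectangle._contour = ToyRectangle_contour

corner, width, height = ToyRectangle(20, 10)._contour()
```

Keeping the plotting code out of the class definitions this way means the core classes do not need to import any plotting library.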

def sample_support(Support, Npoints):
def sample_support(Support, Npoints):
    """
    This function samples a regular grid on the support and keeps only the points that lie inside the support.
    The grid covers the area given by the size attribute of the support or, if it has none, by its _estimate_size method.
    If fewer than half of the requested points lie inside the support, the number of grid points is doubled and the sampling is repeated.
    """
    # Remember the requested number: comparing against the doubled Npoints
    # would never terminate for supports filling less than half of the grid area
    Npoints_init = Npoints
    if hasattr(Support, "size"):
        size = Support.size
    else:
        size_x = Support._estimate_size()
        size = np.array([2*size_x, 2*size_x])
    enough = False
    while not enough:
        X = np.linspace(-size[0] * 0.5, size[0] * 0.5, int(np.sqrt(Npoints)))
        Y = np.linspace(-size[1] * 0.5, size[1] * 0.5, int(np.sqrt(Npoints)))
        X, Y = np.meshgrid(X, Y)
        Points = np.array([X.flatten(), Y.flatten()]).T
        Points = [p for p in Points if p in Support]
        if len(Points) > Npoints_init * 0.5:
            enough = True
        else:
            Npoints *= 2
    return mgeo.PointArray(Points)

This function samples a regular grid on the support and keeps only the points that lie inside the support. The grid covers the area given by the size attribute of the support or, if it has none, by its _estimate_size method. If fewer than half of the requested points lie inside the support, the number of grid points is doubled and the sampling is repeated.
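The grid-and-filter step can be tried out on its own. The sketch below assumes a hypothetical `ToyRoundSupport` that only provides a `size` attribute and the `in` test; it is not one of the ARTcore support classes, and `sample_grid` performs a single sampling pass without the refinement loop.

```python
import numpy as np


class ToyRoundSupport:
    """Hypothetical round support; mimics only `size` and the `in` test."""

    def __init__(self, radius):
        self.radius = radius
        self.size = np.array([2 * radius, 2 * radius])

    def __contains__(self, point):
        return point[0] ** 2 + point[1] ** 2 <= self.radius**2


def sample_grid(support, npoints):
    """Regular grid over the bounding box, filtered to the points inside the support."""
    n = int(np.sqrt(npoints))
    xs = np.linspace(-support.size[0] / 2, support.size[0] / 2, n)
    ys = np.linspace(-support.size[1] / 2, support.size[1] / 2, n)
    X, Y = np.meshgrid(xs, ys)
    candidates = np.column_stack((X.ravel(), Y.ravel()))
    inside = np.array([p in support for p in candidates])
    return candidates[inside]


points = sample_grid(ToyRoundSupport(10.0), 10000)
fill_fraction = len(points) / 10000
```

For a disc inscribed in its bounding square the fill fraction is close to π/4, so more than half of the requested points survive the filter and a single pass is enough.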

def sample_z_values(Support, Npoints):
def sample_z_values(Support, Npoints):
    """
    Sample a regular grid over the bounding box of the support.

    Returns the meshgrid arrays X and Y together with a boolean mask of the
    same shape that is True where the grid point lies inside the support.
    """
    Npoints_init = Npoints
    if hasattr(Support, "size"):
        size = Support.size
    else:
        size_x = Support._estimate_size()
        size = np.array([2*size_x, 2*size_x])
    enough = False
    while not enough:
        X = np.linspace(-size[0] * 0.5, size[0] * 0.5, int(np.sqrt(Npoints)))
        Y = np.linspace(-size[1] * 0.5, size[1] * 0.5, int(np.sqrt(Npoints)))
        X, Y = np.meshgrid(X, Y)
        Points = np.array([X.flatten(), Y.flatten()]).T
        Points = [p for p in Points if p in Support]
        # Stop once more than half of the initially requested points are inside
        if len(Points) > Npoints_init * 0.5:
            enough = True
        else:
            Npoints *= 2
    # Rebuild the full grid and flag which nodes belong to the support
    Points = np.array([X.flatten(), Y.flatten()]).T
    mask = np.array([p in Support for p in Points]).reshape(X.shape)
    return X, Y, mask
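
The `X, Y, mask` triple returned by `sample_z_values` lends itself to masked arrays, for example to plot a surface only where the support exists. The sketch below assumes a unit-disk support and a made-up sag function; both are illustrative and not taken from ART.

```python
import numpy as np

# Regular grid over the bounding square of a hypothetical unit-disk support
axis = np.linspace(-1.0, 1.0, 21)
X, Y = np.meshgrid(axis, axis)

# Boolean mask with the same shape as X: True where the node lies on the support
mask = X ** 2 + Y ** 2 <= 1.0

# Made-up surface height; nodes outside the support are masked out
Z = np.ma.masked_where(~mask, (X ** 2 + Y ** 2) / 4.0)

print(f"{mask.sum()} of {mask.size} grid nodes lie on the support")
```

A masked array like `Z` can be passed directly to plotting routines such as `matplotlib.pyplot.pcolormesh`, which leave the masked nodes blank.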

4 - ModulePlottingUtilities

Module containing some useful functions for generating various plots. It exists mostly to avoid cluttering the main plotting module.

Created in November 2024

@author: André Kalouguine + Stefan Haessler + Anthony Guillaume

"""
Module containing some useful functions for generating various plots.
It exists mostly to avoid cluttering the main plotting module.

Created in November 2024

@author: André Kalouguine + Stefan Haessler + Anthony Guillaume
"""
# %% Module imports
import weakref
import numpy as np
import colorcet as cc
import matplotlib.pyplot as plt

import ARTcore.ModuleProcessing as mp
import ARTcore.ModuleGeometry as mgeo


# %% Definition of observer class to make plots interactive
class Observable:
    """
    Observer class to make several plots interactive simultaneously.
    It encapsulates a single value.
    When it's modified, it notifies all registered observers.
    """
    def __init__(self, value):
        self._value = value
        self._observers = set()
        self._calculation = set()

    def register_calculation(self, callback):
        """
        Register a calculation to be performed when the value is modified before notifying the observers.
        For instance, this can be used to perform the ray tracing calculation when the value is modified.
        The other callbacks will be notified after the calculation is performed and will update the plots.
        """
        self._calculation.add(callback)

    def unregister_calculation(self, callback):
        self._calculation.discard(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value
        # Run the calculations first so that the observers see up-to-date results
        for callback in self._calculation:
            callback(new_value)
        self.notify(new_value)

    def register(self, callback):
        self._observers.add(callback)

    def unregister(self, callback):
        self._observers.discard(callback)

    def notify(self, event):
        for callback in self._observers:
            callback(event)


# %% Utility functions
def generate_distinct_colors(num_colors):
    """
    Utility function to generate a list of distinct colors for plotting.

    Parameters
    ----------
        num_colors : int
            The number of colors to generate.

    Returns
    -------
        distinct_colors : list
            List of distinct colors.
    """
    # Get a color palette from colorcet
    palette = cc.glasbey

    # Make sure the number of colors does not exceed the palette length
    num_colors = min(num_colors, len(palette))

    # Slice the palette to get the desired number of colors
    distinct_colors = palette[:num_colors]

    return distinct_colors


def _getDetectorPoints(RayListAnalysed, Detector) -> tuple[np.ndarray, np.ndarray, float, float]:
    """
    Prepare the ray impact points on the detector in a format used for the plotting,
    and along the way also calculate the "spotsize" of this point-cloud on the detector.

    Parameters
    ----------
        RayListAnalysed : list(Ray)
            A list of objects of the ModuleOpticalRay.Ray-class.

        Detector : Detector
            An object of the ModuleDetector.Detector-class.

    Returns
    -------
        DetectorPoint2D_Xcoord : np.ndarray with x-coordinates

        DetectorPoint2D_Ycoord : np.ndarray with y-coordinates

        FocalSpotSize : float

        SpotSizeSD : float
    """

    Points2D = Detector.get_2D_points(RayListAnalysed)
    Points2D -= np.mean(Points2D, axis=1)  # Centering the points
    X = Points2D[0][:,0] * 1e3  # Convert to µm (factor 1e3, i.e. from mm)
    Y = Points2D[0][:,1] * 1e3  # Convert to µm (factor 1e3, i.e. from mm)

    FocalSpotSize = float(mgeo.DiameterPointArray(Points2D[0]))
    SpotSizeSD = mp.StandardDeviation(Points2D[0])
    return X, Y, FocalSpotSize, SpotSizeSD


def show():
    """Display all open figures without blocking execution."""
    plt.show(block=False)
class Observable:

Observer class to make several plots interactive simultaneously. It encapsulates a single value. When it's modified, it notifies all registered observers.

Observable(value)
def register_calculation(self, callback):

Register a calculation to be performed when the value is modified before notifying the observers. For instance, this can be used to perform the ray tracing calculation when the value is modified. The other callbacks will be notified after the calculation is performed and will update the plots.
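
The ordering guarantee described above (calculations first, observers afterwards) can be illustrated with a short, self-contained sketch. The class below is a trimmed copy of `Observable` so that the example runs without ART; the callbacks `retrace` and `redraw` and the variable `angle` are hypothetical names chosen for illustration.

```python
class Observable:
    """Trimmed copy of the class documented above, so the example runs standalone."""

    def __init__(self, value):
        self._value = value
        self._observers = set()
        self._calculation = set()

    def register_calculation(self, callback):
        self._calculation.add(callback)

    def register(self, callback):
        self._observers.add(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._value = new_value
        for callback in self._calculation:  # calculations run first
            callback(new_value)
        for callback in self._observers:    # then the observers are notified
            callback(new_value)


log = []


def retrace(new_value):
    # Stands in for an expensive ray-tracing update
    log.append(("calc", new_value))


def redraw(new_value):
    # Stands in for a plot refresh that relies on the updated calculation
    log.append(("plot", new_value))


angle = Observable(0.0)
angle.register(redraw)                # registration order does not matter
angle.register_calculation(retrace)
angle.value = 1.5
print(log)  # [('calc', 1.5), ('plot', 1.5)]
```

Because the callbacks are stored in sets, registering the same callback twice has no additional effect, and the order in which several observers are called is not guaranteed; only the calculation-before-observer ordering is.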

def unregister_calculation(self, callback):

Removes a previously registered calculation callback. Callbacks that were never registered are ignored.

value

The encapsulated value. Assigning to it first runs the registered calculations, then notifies the registered observers.

def register(self, callback):

Registers an observer callback, which is called with the new value after each modification.

def unregister(self, callback):

Removes a previously registered observer callback. Callbacks that were never registered are ignored.

def notify(self, event):

Calls every registered observer with event.

def generate_distinct_colors(num_colors):

Utility function to generate a list of distinct colors for plotting.

Parameters

num_colors : int
    The number of colors to generate.

Returns

distinct_colors : list
    List of distinct colors.
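
One consequence of the implementation is worth spelling out: when more colors are requested than the palette holds, the result is silently truncated to the palette length. The sketch below reproduces that clamping with a short stand-in palette; `take_colors` and the four hex values are illustrative assumptions, not the actual `cc.glasbey` entries.

```python
def take_colors(palette, num_colors):
    """Return the first num_colors entries, clamped to the palette length."""
    num_colors = min(num_colors, len(palette))
    return palette[:num_colors]


# Stand-in palette (hypothetical values, not the colorcet glasbey colors)
palette = ["#ff0000", "#00ff00", "#0000ff", "#ffaa00"]

print(take_colors(palette, 2))        # the first two colors
print(len(take_colors(palette, 10)))  # clamped to the palette length
```

Callers that need one color per curve may therefore want to check the length of the returned list before indexing into it.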
def show():

Displays all open figures without blocking execution (calls plt.show(block=False)).

5 - ModuleTolerancing

Provides functions for analysing the alignment tolerance of a system.

It mostly consists of misaligning the optical elements of the system and checking the impact on the focal spot and delays.

There are two parts to this module:

  • Misaligning/deteriorating the optical elements of the system
  • Analysing the impact of these misalignments on the focal spot and delays

The end goal is to have a simple function "GetTolerance" that will return the tolerance of the system to misalignments of each optical element.

Created in Nov 2024

@author: Andre Kalouguine

"""
Provides functions for analysing the alignment tolerance of a system.

It mostly consists of misaligning the optical elements of the system and checking the impact on the focal spot and delays.

There are two parts to this module:
- Misaligning/deteriorating the optical elements of the system
- Analysing the impact of these misalignments on the focal spot and delays

The end goal is to have a simple function "GetTolerance" that will return the tolerance of the system to misalignments of each optical element.

Created in Nov 2024

@author: Andre Kalouguine
"""
import numpy as np
import matplotlib.pyplot as plt
from copy import copy
from scipy.stats import linregress

import ARTcore.ModuleOpticalChain as moc


def GetTolerance(OpticalChain,
                 time_error=1.0,
                 Detector="Focus",
                 n_iter=5,
                 elements_to_misalign=None):
    """
    Returns the tolerance of the system to misalignments of each optical element.
    It first constructs a normalisation vector:
    For each optical element, for each degree of freedom, it iteratively calculates the required amplitude of misalignment to reach the time_error.
    This gives an n-dimensional vector (6 entries per optical element). The smaller the value, the more sensitive the system is to misalignments of this axis.
    The delay standard deviation obtained for each entry is returned as a second vector.
    """
    # Define the normalisation vector; entries that are never evaluated stay NaN
    n_dof = len(OpticalChain.optical_elements) * 6
    normalisation_vector = np.full(n_dof, np.nan)
    durations = np.full(n_dof, np.nan)
    misalignments = ["rotate_roll_by", "rotate_pitch_by", "rotate_yaw_by", "shift_along_normal", "shift_along_major", "shift_along_cross"]
    if elements_to_misalign is None:
        elements_to_misalign = range(len(OpticalChain.optical_elements))
    if isinstance(Detector, str):
        Det = OpticalChain.detectors[Detector]
    else:
        Det = Detector
    for i in elements_to_misalign:
        for j in range(6):
            # Misalign the optical element, starting from a small amplitude
            amplitude = 1e-3
            duration = np.nan  # stays NaN if every iteration fails
            for k in range(n_iter):
                try:
                    misaligned_optical_chain = copy(OpticalChain)
                    getattr(misaligned_optical_chain[i], misalignments[j])(amplitude)
                    rays = misaligned_optical_chain.get_output_rays()
                    Det.optimise_distance(rays[Det.index], [Det.distance-100, Det.distance+100], Det._spot_size, maxiter=10, tol=1e-10)
                except Exception:
                    print(f"OE {i} failed to misalign {misalignments[j]} by {amplitude}")
                    amplitude /= 10
                    continue
                rays = misaligned_optical_chain.get_output_rays(force=True)
                duration = np.std(Det.get_Delays(rays[Det.index]))
                if len(rays[Det.index]) <= 50:
                    # Too few rays reach the detector: reduce the misalignment
                    amplitude /= 2
                    continue
                if duration > time_error:
                    amplitude /= 3
                elif duration < time_error / 10:
                    amplitude *= 3
                elif duration < time_error / 1.5:
                    amplitude *= 1.2
                else:
                    break
            if not (time_error/2 < duration < time_error*2):
                print(f"OE {i} failed to misalign {misalignments[j]} by {amplitude}: duration = {duration}")
                # No valid evaluation, or the target was overshot: report NaN
                if np.isnan(duration) or duration > time_error:
                    amplitude = np.nan
                else:
                    amplitude = 0.1
            normalisation_vector[i*6+j] = amplitude
            durations[i*6+j] = duration
    return normalisation_vector, durations
def GetTolerance(OpticalChain, time_error=1.0, Detector='Focus', n_iter=5, elements_to_misalign=None):

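Returns the tolerance of the system to misalignments of each optical element. It first constructs a normalisation vector: for each optical element and each of its six degrees of freedom (roll, pitch, yaw, and shifts along the normal, major and cross axes), it iteratively searches for the misalignment amplitude at which the standard deviation of the ray delays on the detector reaches time_error. This gives a vector with six entries per optical element. The smaller the value, the more sensitive the system is to misalignments of this axis. The function returns this vector together with the delay standard deviation obtained for each entry.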
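
The iterative amplitude search can be tried out in isolation. The sketch below applies the same update rules to a toy model in which the delay spread grows linearly with the misalignment; `find_amplitude` and `toy_duration` are hypothetical helpers written for this illustration, and the ray tracing and detector optimisation of the real function are replaced by a plain function call.

```python
def find_amplitude(duration_of, time_error=1.0, n_iter=5, start=1e-3):
    """Adjust the amplitude until duration_of(amplitude) is close to time_error."""
    amplitude = start
    duration = float("nan")
    for _ in range(n_iter):
        duration = duration_of(amplitude)
        if duration > time_error:
            amplitude /= 3        # overshoot: back off strongly
        elif duration < time_error / 10:
            amplitude *= 3        # far too small: grow quickly
        elif duration < time_error / 1.5:
            amplitude *= 1.2      # slightly too small: grow gently
        else:
            break                 # within the accepted band
    return amplitude, duration


def toy_duration(amplitude):
    # Assumed linear response, purely for illustration
    return 1500.0 * amplitude


amplitude, duration = find_amplitude(toy_duration)
print(f"amplitude = {amplitude:.2e}, duration = {duration:.2f}")
```

If the loop runs out of iterations, the last amplitude update is never re-evaluated, which is why GetTolerance checks afterwards that the duration lies between time_error/2 and time_error*2. Since the returned vectors hold six consecutive entries per optical element, `normalisation_vector.reshape(-1, 6)` gives one row per element and one column per degree of freedom.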