Source code for multipac_testbench.threshold.threshold_set

"""Define an object to hold all thresholds of a multipactor test."""

import logging
import math
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from typing import Self

import numpy as np
import pandas as pd
from multipac_testbench.instruments.instrument import Instrument
from multipac_testbench.threshold.helper import (
    sorter_index_then_way,
    threshold_df_column_header,
)
from multipac_testbench.threshold.threshold import (
    THRESHOLD_DETECTOR_T,
    THRESHOLD_FILTER_T,
    PowerExtremum,
    Threshold,
    create_power_extrema,
    create_thresholds,
)
from multipac_testbench.util.types import MULTIPAC_DETECTOR_T
from numpy.typing import NDArray


[docs] class ThresholdSet: def __init__( self, thresholds: Iterable[Threshold], power_extrema: Iterable[PowerExtremum], ) -> None: """Create object. Parameters ---------- thresholds : Multipactor thresholds detected during a :class:`.MultipactorTest`. power_extrema : Power minima/maxima delimiting the different power cycles in the :class:`.MultipactorTest`. """ self.thresholds = sorted(thresholds, key=lambda t: t.sample_index) self.extrema = sorted(power_extrema, key=lambda p: p.sample_index) self._warn_instruments_at_same_position() self._remove_thresholds_at_seesaw()
[docs] @classmethod def from_instruments( cls, multipac_detector: MULTIPAC_DETECTOR_T, detecting_instruments: Iterable[Instrument], growth_array: NDArray[np.float64], threshold_predicate: THRESHOLD_FILTER_T | None = None, threshold_reducer: THRESHOLD_DETECTOR_T | None = None, ) -> Self: """Create a ThresholdSet using the specified detection strategy. Parameters ---------- multipac_detector : Function that takes in the ``data`` of an :class:`.Instrument` and returns an array, where True means multipactor and False no multipactor. detecting_instruments : Instruments to apply ``multipac_detector`` on. growth_array : Holds ``1.0`` where power increases, ``0.0`` where it is stable, ``-1.0`` where it decreases. threshold_reducer : - not provided: thresholds are computed for each instrument independently. - "any": thresholds appear when multipactor is detected by *any* of the provided detecting instrument. - "all": thresholds appear when multipactor is detected by *all* the provided detecting instrument. threshold_predicate : Function filtering the thresholds. Applied *after* ``threshold_reducer``. 
""" if threshold_reducer is None: thresholds = [ threshold for instr in detecting_instruments if isinstance(instr.position, float) for threshold in create_thresholds( multipac_detector(instr.data), growth_array, detecting_instrument=instr.name, position=instr.position, threshold_predicate=threshold_predicate, color=instr.color, ) ] elif threshold_reducer in {"any", "all"}: multipactors = [ multipac_detector(instr.data) for instr in detecting_instruments if isinstance(instr.position, float) ] reducer = np.any if threshold_reducer == "any" else np.all combined = reducer(multipactors, axis=0) thresholds = create_thresholds( combined, growth_array, detecting_instrument=threshold_reducer, position=np.nan, threshold_predicate=threshold_predicate, color=(0, 0, 0), ) else: raise ValueError(f"Unknown {threshold_reducer = }") power_extrema = create_power_extrema(growth_array) return cls(thresholds, power_extrema)
[docs] @classmethod def last( cls, threshold_set: Self, threshold_predicate: THRESHOLD_FILTER_T | None = None, ) -> Self: """ Create object holding the last threshold measured by every instrument. See Also -------- :class:`AveragedThresholdSet` Parameters ---------- threshold_set : Holds all the detected thresholds. threshold_predicate : Additional predicate, *eg* to exclude thresholds measured during the first power cycles, from a specific detecting instrument, of a certain type... Returns ------- Holds only one lower and one upper :class:`.Threshold` per detecting instrument: the last one measured during the test. """ filtered_thresholds = tuple( [ t for t in threshold_set if threshold_predicate is None or threshold_predicate(t) ] ) last_thresholds_by_instr: dict[str, Threshold] = {} for t in filtered_thresholds[::-1]: if t.detecting_instrument in last_thresholds_by_instr: continue last_thresholds_by_instr[t.detecting_instrument] = t return cls(last_thresholds_by_instr.values(), threshold_set.extrema)
[docs] @classmethod def subset( cls, threshold_set: Self, threshold_predicate: THRESHOLD_FILTER_T ) -> Self: """Return object holding a subset of ``threshold_set``. ``threshold_predicate`` is used to filter on the :class:`.Threshold`. """ thresholds = [t for t in threshold_set if threshold_predicate(t)] return cls(thresholds=thresholds, power_extrema=threshold_set.extrema)
[docs] @classmethod def extreme( cls, threshold_set: Self, threshold_predicate: THRESHOLD_FILTER_T | None = None, ) -> Self: """Create object holding only the most *extreme* :class:`.Threshold`. For each half cycle: - If power increases: keep first lower and last upper threshold. - If power decreases: keep first upper and last lower threshold. - If there was still multipactor somewhere when the half power cycle ended (e.g. instrument with a lower but no upper threshold), no upper threshold is added. - If direction is undetermined: skip the cycle. Parameters ---------- threshold_set : The full set of thresholds. threshold_predicate : A function to select relevant thresholds. Returns ------- A new object containing only selected extreme thresholds. """ assert len(threshold_set.detecting_instruments()) <= 1, ( "This method currently does not handle detection from several " "instruments." ) subset = [] for key, thresholds in threshold_set._thresholds_by_half_power_cycle( threshold_predicate=threshold_predicate ).items(): if not thresholds: continue direction = key.split("(", 1)[-1].removesuffix(")").strip() if direction not in {"increasing", "decreasing"}: logging.warning(f"Skipped undetermined cycle: {key}") continue if direction == "increasing": first = min(thresholds, key=sorter_index_then_way) if first.nature == "lower" and first.way == "enter": subset.append(first) last = max(thresholds, key=sorter_index_then_way) if last.nature == "upper" and last.way == "exit": subset.append(last) continue if direction == "decreasing": first = min(thresholds, key=sorter_index_then_way) if first.nature == "upper" and first.way == "enter": subset.append(first) last = max(thresholds, key=sorter_index_then_way) if last.nature == "lower" and last.way == "exit": subset.append(last) return cls(thresholds=subset, power_extrema=threshold_set.extrema)
def __iter__(self) -> Iterator[Threshold]: """Iterate over stored :class:`.Threshold` objects. Yields ------ Threshold The stored :class:`.Threshold` objects, sorted by sample index. """ return iter(self.thresholds)
[docs] def remove_singularities(self, min_consecutive: int = 1) -> None: """Remove fugitive :class:`.Threshold`. If two :class:`.Threshold` are detected by the same :class:`.Instrument` and their :attr:`.Threshold.sample_index` are separated by ``min_consecutive - 1`` or less, both objects are removed. Parameters ---------- min_consecutive : :class:`.Threshold` objects separated by less than ``min_consecutive`` sample index are removed. The default ``min_consecutive=1`` removes multipactor spanning over a single sample index. """ by_instr: dict[str, list[Threshold]] = defaultdict(list) for t in self.thresholds: by_instr[t.detecting_instrument].append(t) cleaned_thresholds = [] for thresholds in by_instr.values(): thresholds.sort(key=lambda t: t.sample_index) to_remove: list[Threshold] = [] for i in range(len(thresholds) - 1): current = thresholds[i] next_ = thresholds[i + 1] if ( abs(current.sample_index - next_.sample_index) >= min_consecutive ): continue to_remove.append(current) to_remove.append(next_) cleaned = [t for t in thresholds if t not in to_remove] cleaned_thresholds.extend(cleaned) self.thresholds = cleaned_thresholds
[docs] def _remove_thresholds_at_seesaw(self) -> None: """Clean incorrect :class:`.Threshold` if power follows seesaw profile. Consider this: - ``i - 1``: max power of the seesaw profile, there is MP - ``i``: new seesaw (min power of the profile), no MP At ``i - 1``, we did not reach a threshold. But we have a :class:`.Threshold` corresponding to it. So we have to remove it. """ extrema = { extremum.sample_index: extremum for extremum in self.extrema } for t in self.thresholds: matching_extremum = extrema.get(t.sample_index) if matching_extremum is None: continue if matching_extremum.smooth: continue if matching_extremum.nature == "maximum" and t.way == "exit": logging.debug(f"Removed Threshold: {t}") self.thresholds.remove(t)
[docs] def _warn_instruments_at_same_position(self) -> None: """Verify bijection between detecting instruments pos and name.""" pos_to_names: dict[float, str] = {} warned_positions = set() for threshold in self.thresholds: name = threshold.detecting_instrument pos = threshold.position if pos in pos_to_names and pos_to_names[pos] != name: if pos not in warned_positions: msg = ( "Multiple instruments detected at the same position " f"{pos}:\n- {pos_to_names[pos]}\n- {name}" ) logging.warning(msg) warned_positions.add(pos) pos_to_names[pos] = name
[docs] def sample_indexes( self, *, threshold_predicate: THRESHOLD_FILTER_T | None = None ) -> list[int]: """Return sample indexes matching optional filter.""" return [ t.sample_index for t in self if threshold_predicate is None or threshold_predicate(t) ]
[docs] def apply_to(self, instrument: Instrument) -> NDArray[np.float64]: """Extract instrument data at threshold sample indexes.""" idx = self.sample_indexes() return instrument.data[idx]
[docs] def at( self, position: float, tol: float = 1e-10, return_global: bool = False ) -> list[Threshold]: """Gather all thresholds measured at ``position``. Parameters ---------- position : Where you want the thresholds. tol : Tolerance over the position. return_global : To return global multipactors, and also return all thresholds when ``position`` is ``np.nan``. ``np.nan`` position are associated with "global" instruments, such as :class:`.ForwardPower`, and with "global" multipactors, such as obtained by crossing several :class:`.Instrument` data. Returns ------- All multipactor thresholds detected at this position. """ return [ thresh for thresh in self.thresholds if math.isclose(thresh.position, position, abs_tol=tol) or return_global and (thresh.is_global or np.isnan(position)) ]
[docs] def data_at_thresholds( self, instruments: Iterable[Instrument], tol: float = 1e-10, global_instruments: bool = False, global_multipactor: bool = False, xdata_instrument: Instrument | None = None, unique_x_value: float | None = None, ) -> pd.DataFrame: """Return instrument values at threshold sample indices. We match :class:`.Threshold` and :class:`.Instrument` objects by position. Parameters ---------- instruments : Instruments to which data must be plotted. Must have ``.position`` and ``.data`` attributes. tol : Tolerance for position matching. global_instruments : If instruments not position-specific (eg :class:`.ForwardPower`) should be returned. global_multipactor : If multipactor not position-specific (eg thresholds created by merging several other multipactor arrays) should be returned. xdata_instrument : Its data is returned at every threshold. It results in a unique ``xdata`` column, without ``nan``, that can be used as a common x-data for plotting. unique_x_value : If given, this value will replace every value of the ``xdata_instrument`` column. Returns ------- Columns are named by detecting instrument + threshold nature: ``"NI9205_E4 @ upper threshold (according to NI9205_MP4l)"``. If ``xdata_instrument`` was given, also return this instrument values at every sample index (can be unique value if ``unique_x_value`` was given). Indexes are the sample indices at every threshold. 
""" # {column: {sample_index: instrument value}} result: dict[str, dict[int, float]] = defaultdict(dict) for threshold in self: for instrument in instruments: is_close = ( math.isclose( instrument.position, threshold.position, abs_tol=tol ) or (global_instruments and instrument.is_global) or (global_multipactor and threshold.is_global) ) if not is_close: continue label = threshold_df_column_header(instrument, threshold) idx = threshold.sample_index result[label][idx] = instrument.data[idx] if xdata_instrument is None: return pd.DataFrame({k: pd.Series(v) for k, v in result.items()}) xlabel = xdata_instrument.ylabel() result[xlabel] = { t.sample_index: xdata_instrument.data[t.sample_index] for t in self } df = pd.DataFrame({k: pd.Series(v) for k, v in result.items()}) if unique_x_value is not None: df[xlabel] = unique_x_value return df
[docs] def according_to( self, instrument: Instrument | str | THRESHOLD_DETECTOR_T ) -> list[Threshold]: """Give thresholds measured by ``instrument``.""" if isinstance(instrument, Instrument): detecting_name = instrument.name else: detecting_name = instrument thresholds: list[Threshold] = [] for x in self: if isinstance(x.detecting_instrument, Instrument): matching = x.detecting_instrument.name else: matching = x.detecting_instrument if detecting_name == matching: thresholds.append(x) return thresholds
[docs] def remove_detected_by( self, instrument: Instrument | str | THRESHOLD_DETECTOR_T ) -> None: """Remove thresholds detected by ``instrument``.""" to_remove = self.according_to(instrument) cleaned = [t for t in self if t not in to_remove] self.thresholds = cleaned logging.info( f"Removed the {len(to_remove)} thresholds detected by {instrument}" )
[docs] def get_threshold_label_color_map( self, instruments: Sequence[Instrument] ) -> dict[str, tuple[float, float, float]]: """Maps threshold dataframe column headers to corresponding colors. Assumes :attr:`.Threshold.color` is already set to the corresponding :class:`.Instrument` color. Returns ------- Mapping from a header looking like ``"NI9205_E4 @ upper threshold (according to NI9205_MP4l)"``, to the threshold color (usually, this is detecting instrument color). """ label_to_color = {} for threshold in self: for instrument in instruments: header = threshold_df_column_header(instrument, threshold) label_to_color[header] = threshold.color return label_to_color
[docs] def detecting_instruments(self) -> set[str | THRESHOLD_DETECTOR_T]: """Return instruments that detected at least one threshold.""" return {t.detecting_instrument for t in self}
[docs] def _thresholds_by_half_power_cycle( self, threshold_predicate: THRESHOLD_FILTER_T | None = None ) -> dict[str, list[Threshold]]: """Group thresholds by half power cycle, based on sample index range. Each group includes thresholds between two consecutive extrema: ``[extremum_i.sample_index, extremum_{i+1}.sample_index)`` The dictionary key is of the form: - "0 (increasing)" if power increases over the interval - "1 (decreasing)" if power decreases over the interval - "2 (undetermined)" if direction cannot be determined .. note:: Not ultra efficient. To update if necessary. Parameters ---------- threshold_predicate : Filter the :class:`.Threshold` instances. Returns ------- Dictionary mapping half-cycle index to thresholds within that range. Keys are sorted by increasing power cycle index values. """ thresholds_by_cycle: dict[str, list[Threshold]] = {} for i, (ext1, ext2) in enumerate( zip(self.extrema[:-1], self.extrema[1:]) ): if ext1.nature == "minimum" and ext2.nature == "maximum": direction = "increasing" elif ext1.nature == "maximum" and ext2.nature == "minimum": direction = "decreasing" else: direction = "undetermined" key = f"{i} ({direction})" thresholds = [ t for t in self if ext1.sample_index <= t.sample_index < ext2.sample_index and (threshold_predicate is None or threshold_predicate(t)) ] thresholds_by_cycle[key] = thresholds return thresholds_by_cycle
class AveragedThresholdSet(ThresholdSet):
    """Holds average of several thresholds.

    The main difference with a classic ``ThresholdSet`` is that its
    :meth:`.data_at_thresholds` is overridden to return data averaged from
    several :class:`.Threshold`.

    """

    @classmethod
    def from_threshold_set(
        cls,
        threshold_set: ThresholdSet,
        threshold_predicate: THRESHOLD_FILTER_T | None = None,
    ) -> Self:
        """Create an object holding averaged thresholds.

        Parameters
        ----------
        threshold_set :
            The thresholds to average.
        threshold_predicate :
            To filter thresholds to average. A typical example would be
            ``lambda t: t.sample_index > 200`` to keep only conditioned
            thresholds.

        Returns
        -------
            Object built from the thresholds of ``threshold_set`` accepted by
            ``threshold_predicate`` (all of them when it is not given), and
            from the extrema of ``threshold_set``. The averaging itself is
            performed by :meth:`data_at_thresholds`.

        """
        if threshold_predicate is None:
            selected = list(threshold_set)
        else:
            selected = [t for t in threshold_set if threshold_predicate(t)]
        return cls(selected, threshold_set.extrema)

    def data_at_thresholds(self, *args, **kwargs) -> pd.DataFrame:
        """Return average of instrument values at threshold sample indices.

        All arguments are forwarded to the parent ``data_at_thresholds``;
        the medians are computed on the frame it returns.

        Parameters
        ----------
        instruments :
            Instruments to which data must be plotted. Must have
            ``.position`` and ``.data`` attributes.
        tol :
            Tolerance for position matching.
        global_instruments :
            If instruments not position-specific (eg :class:`.ForwardPower`)
            should be returned.
        global_multipactor :
            If multipactor not position-specific (eg thresholds created by
            merging several other multipactor arrays) should be returned.
        xdata_instrument :
            Its data is returned at every threshold. It results in a unique
            ``xdata`` column, without ``nan``, that can be used as a common
            x-data for plotting.

        Returns
        -------
            Columns are named by detecting instrument + threshold nature.

            - If the index of the parent frame is unnamed: a single row
              holding the median of every column.
            - Otherwise: one row per non-empty column, indexed by the median
              of the index values at which this column is defined, and
              holding the median of the column. Rows are sorted by index,
              columns by name.

        """
        raw = super().data_at_thresholds(*args, **kwargs)
        index_name = raw.index.name

        if index_name is None:
            medians = raw.median()
            return medians.to_frame().T

        rows = []
        for column in raw.columns:
            values = raw[column].dropna()
            if values.empty:
                continue
            # Representative x: median of the index where ``column`` exists.
            representative_x = values.index.to_series().median()
            averaged_y = values.median()
            rows.append(pd.Series({column: averaged_y}, name=representative_x))

        averaged = pd.DataFrame(rows).sort_index()
        averaged.index.name = index_name
        return averaged.reindex(sorted(averaged.columns), axis=1)