Source code for multipac_testbench.util.helper

"""Define general usage functions."""

import logging
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from pathlib import Path
from typing import Any, TypeGuard, TypeVar

import numpy as np
import pandas as pd
from numpy.typing import NDArray

# Generic element type used in the annotations of the helpers of this module.
T = TypeVar("T")


[docs] def is_nested_list(obj: list[T] | list[list[T]]) -> bool: """Tell if ``obj`` is a nested list.""" return bool(obj) and isinstance(obj[0], list)
[docs] def flatten[T](nest: Iterable[T]) -> Iterator[T]: """Flatten nested list of lists of...""" for _in in nest: if isinstance(_in, Iterable) and not isinstance(_in, (str, bytes)): yield from flatten(_in) else: yield _in
[docs] def split_rows_by_masks( df: pd.Series | pd.DataFrame, masks: dict[str, NDArray[np.bool]], ) -> pd.DataFrame: """Split the rows of ``df`` into new columns based on a boolean mask. For each column in the original data, one new column per mask is created with the corresponding suffix. Rows not selected by a mask are filled with ``np.nan``. .. important:: Functions using the splitted ``df`` such as :func:`.styles_from_column_cycle` expect every key of ``masks`` to start with a double underscore (``__``). Examples -------- >>> mask = np.array([True, False, True]) >>> masks = {"__(grows)": mask, "__(decreases)": ~mask} >>> ser = pd.Series([1, 2, 3], name=data) >>> print(split_rows_by_masks(ser, masks)) data__(grows) data__(decreases) 0 1.0 NaN 1 NaN 2.0 2 3.0 NaN >>> df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]}) >>> print(split_rows_by_masks(df, masks)) col1__(grows) col1__(decreases) col2__(grows) col2__(decreases) 0 1.0 NaN 4.0 NaN 1 NaN 2.0 NaN 5.0 2 3.0 NaN 6.0 NaN Raises ------ ValueError If any row is matched by more than one mask or if mask lengths do not match the input. Parameters ---------- df : The input data to split row-wise. masks : A dictionary where each key is a suffix used to label the split columns, and each value is a boolean mask of the same length as the input data. Keys must start with two underscores `(`__``) to enable consistent column naming and compatibility with downstream styling logic (e.g., grouping lines by base column in plots). If multiple masks are ``True`` at the same row index, a ``ValueError`` is raised. Returns ------- A new DataFrame with columns split according to the masks. 
""" if not masks: raise ValueError("At least one mask must be provided.") length = len(df) for name, mask in masks.items(): if len(mask) != length: raise ValueError( f"Mask '{name}' has incorrect length ({len(mask)} != {length})" ) for key in masks.keys(): if key[:2] == "__": continue logging.warning( f"{key = } does not start with a double underscore. Splitted " "columns may not be recognized." ) # Ensure disjoint masks combined = np.zeros(length, dtype=int) for mask in masks.values(): combined += mask.astype(int) if (combined > 1).any(): raise ValueError( "Masks must be disjoint: multiple masks are True at the same " "position." ) df = df.to_frame() if isinstance(df, pd.Series) else df result = {} for col in df.columns: for suffix, mask in masks.items(): col_name = f"{col}{suffix}" result[col_name] = df[col].where(mask) return pd.DataFrame(result)
[docs] def output_filepath( filepath: Path, swr: float, freq_mhz: float, out_folder: str | Path, extension: str, ) -> Path: """Return a new path to save output files. Parameters ---------- filepath : Name of the data ``CSV`` file from LabViewer. swr : Theoretical :math:`SWR` to add to the output file name. freq_mhz : Theoretical rf frequency to add to the output file name. out_folder : Relative name of the folder where data will be saved; it is defined w.r.t. to the parent folder of ``filepath`` if it is a string. If it is a ``Path``, we consider it is absolute. extension : Extension of the output file, with the dot. Returns ------- A full filepath. """ if np.isinf(swr): swr_str = "SWR_infty" else: swr_str = f"SWR_{int(swr):05.0f}" freq_str = f"freq_{freq_mhz:03.0f}MHz" filename = ( filepath.with_stem(("_").join((swr_str, freq_str, filepath.stem))) .with_suffix(extension) .name ) folder = ( filepath.parent / out_folder if isinstance(out_folder, str) else out_folder ) if not folder.is_dir(): folder.mkdir(parents=True) return folder / filename
def save_by_position(
    items: dict[float, Any], base_path: Path, save_fn: Callable, kwargs: dict
):
    """Save every value of ``items`` in a file named after its key (position).

    Each object is written next to ``base_path``, with ``_pos<position>``
    (three decimals) inserted between the stem and the suffix of
    ``base_path``.

    Parameters
    ----------
    items :
        Objects to save, grouped by position.
    base_path :
        Common path of all objects to save.
    save_fn :
        Function to call for saving the objects. It is called as
        ``save_fn(item, path, **kwargs)``.
    kwargs :
        Passed to ``save_fn``.

    """

    def _path_for(position: float) -> Path:
        """Give the position-specific variant of ``base_path``."""
        tagged = f"{base_path.stem}_pos{position:.3f}{base_path.suffix}"
        return base_path.with_name(tagged)

    for position, obj in items.items():
        save_fn(obj, _path_for(position), **kwargs)
def r_squared(
    residue: NDArray[np.float64], expected: NDArray[np.float64]
) -> float:
    """Compute the :math:`R^2` criterion to evaluate a fit.

    It is one minus the ratio of the sum of squared residuals to the sum of
    squared deviations of ``expected`` from its mean.

    For Scipy ``curve_fit`` ``result`` output: ``residue`` is
    ``result[2]['fvec']`` and ``expected`` is the given ``data``.

    """
    deviation = expected - expected.mean()
    sum_sq_total = np.sum(deviation**2)
    sum_sq_error = np.sum(residue**2)
    return 1.0 - sum_sq_error / sum_sq_total
def types(my_list: Sequence) -> set[type]:
    """Collect the distinct exact types of the elements of ``my_list``."""
    return set(map(type, my_list))
def types_match(my_list: Sequence, to_match: type) -> bool:
    """Check if all elements of ``my_list`` have exactly the type ``to_match``.

    The comparison is made on exact types, so instances of subclasses of
    ``to_match`` do not match, and an empty ``my_list`` never matches. A
    deprecation warning is logged at every call.

    .. deprecated:: 1.9.0
       Prefer :func:`is_collection_of`.

    """
    # The replacement helper is ``is_collection_of``; the message names it so
    # that users are pointed to a function that exists.
    logging.warning("Deprecated, prefer ``is_collection_of``.")
    return types(my_list) == {to_match}
[docs] def is_collection_of( coll: Collection[object], typ: type[T] ) -> TypeGuard[Collection[T]]: """Return True if all elements of ``coll`` are instances of ``typ``. This is a clean replacement of :func:`types_match`. """ return all(isinstance(x, typ) for x in coll)
[docs] def drop_repeated_col( df: pd.DataFrame, col: pd.Index | str | None = None ) -> pd.DataFrame: """Remove consecutive rows with the same ``col`` value. If ``x_column`` is not provided, we take the first column in the dataframe. This function is used with :class:`.RPACurrent` and :class:`.RPAPotential` data. """ if col is None: col = df.columns[0] df = df.loc[df[col] != df[col].shift()] return df