# Source code for hermes_rheo.transforms.mutation_number

from piblin.transform.abc.measurement_set_transform import MeasurementSetTransform
import piblin.data.datasets.abc.split_datasets.one_dimensional_dataset as one_dimensional_dataset
from piblin.data import Measurement, MeasurementSet
import numpy as np
import copy
from collections import defaultdict
from matplotlib import pyplot as plt


class MutationNumberMeasurementSet(MeasurementSet):
    """
    MeasurementSet subclass that holds mutation number data and provides a dedicated plot.

    `visualize` reads, for every measurement, the 'Angular Frequency' condition and, for every
    dataset, its x values (used as times) and y values (used as mutation numbers).

    Methods:
    --------
    visualize(self, show_all_frequencies=False, y_lim=None, **kwargs):
        Plots the mutation number averaged over all frequencies at each time (with standard deviation) and,
        optionally, the individual values of every frequency.
    """

    def visualize(self, show_all_frequencies=False, y_lim=None, **kwargs):
        """
        Visualizes the mutation number data.

        Parameters:
        -----------
        show_all_frequencies : bool
            If True, draw two panels: Mutation number vs. Time for all frequencies (left) and the average
            Mutation number vs. Time with standard deviation (right). If False, draw only the average panel.
        y_lim : tuple, optional
            A tuple specifying the y-axis limits (y_min, y_max), applied to every panel drawn. Default is None,
            which leaves the limits to matplotlib.
        **kwargs :
            Accepted for signature compatibility; not used by this method.

        Returns:
        --------
        fig : matplotlib.figure.Figure
            The figure object containing the plot(s).
        ax : matplotlib axes or tuple of axes
            `(ax1, ax2)` when `show_all_frequencies` is True, otherwise the single axes `ax2`.
        """
        # frequency -> time -> mutation numbers recorded for that (frequency, time) pair
        mutation_number_by_frequency = defaultdict(lambda: defaultdict(list))
        for measurement in self.measurements:
            freq = measurement.conditions['Angular Frequency']
            for data in measurement.datasets:
                for time, mutation_number in zip(data.x_values, data.y_values):
                    mutation_number_by_frequency[freq][time].append(mutation_number)

        # Every time key was created by an append, so each pooled list below is non-empty.
        unique_times = sorted({time for by_time in mutation_number_by_frequency.values() for time in by_time})
        pooled_by_time = [
            [value for by_time in mutation_number_by_frequency.values() for value in by_time.get(time, ())]
            for time in unique_times
        ]
        average_mutation_numbers = np.array([np.mean(values) for values in pooled_by_time])
        # np.std with its default ddof=0, i.e. the population standard deviation.
        mutation_number_stds = np.array([np.std(values) for values in pooled_by_time])

        # One colour per time step, shared by both panels.
        colors = plt.cm.viridis(np.linspace(0, 1, len(unique_times)))

        if show_all_frequencies:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

            # Left panel: every individual mutation number, coloured by time step.
            for color, time in zip(colors, unique_times):
                for by_time in mutation_number_by_frequency.values():
                    mutation_numbers = by_time.get(time)
                    if mutation_numbers:
                        ax1.scatter([time] * len(mutation_numbers), mutation_numbers, color=color, s=40)

            ax1.set_xlabel('Time (s)', fontsize=20)
            ax1.set_ylabel('Mutation number', fontsize=20)
            ax1.tick_params(axis='x', labelsize=18)
            ax1.tick_params(axis='y', labelsize=18)
            if y_lim is not None:
                ax1.set_ylim(y_lim)
        else:
            fig, ax2 = plt.subplots(figsize=(10, 6))

        # Average panel (the right one when two panels are drawn): mean +/- standard deviation per time.
        for color, time, mean, std in zip(colors, unique_times, average_mutation_numbers, mutation_number_stds):
            ax2.errorbar(time, mean, yerr=std, fmt='o', color=color, ecolor='black', markersize=8, elinewidth=3,
                         capsize=4)

        ax2.set_xlabel('Time (s)', fontsize=20)
        ax2.set_ylabel('Mutation number (average)', fontsize=20)
        ax2.tick_params(axis='x', labelsize=18)
        ax2.tick_params(axis='y', labelsize=18)
        if y_lim is not None:
            ax2.set_ylim(y_lim)

        plt.tight_layout()

        # Parenthesised on purpose: without the parentheses the conditional binds only to the second
        # tuple element and the single-panel case would return (fig, (fig, ax2)).
        return (fig, (ax1, ax2)) if show_all_frequencies else (fig, ax2)


class MutationNumber(MeasurementSetTransform):
    """
    A transform class to calculate the mutation number from measurement sets.

    Args:
    -----
    state : str
        The state variable to be used (default is 'time').
    state_sampling : str
        The method to determine the state value ('average', 'first point', 'last point').
    dependent_variable : str
        The dependent variable to be used for mutation number calculation (default is 'complex modulus').

    Methods:
    --------
    _state_to_condition(self, target):
        Computes the specified state's value using the defined method and adds it as a condition in the dataset.
    _apply(self, target, **kwargs):
        Applies the mutation number calculation to the dataset and returns a MutationNumberMeasurementSet object.
    """

    def __init__(self, state='time', state_sampling='first point', dependent_variable='complex modulus', *args,
                 **kwargs):
        """
        Initializes the MutationNumber transform.

        Parameters:
        -----------
        state : str
            The state variable to be used (default is 'time').
        state_sampling : str
            The method to determine the state value ('average', 'first point', 'last point').
        dependent_variable : str
            The dependent variable to be used for mutation number calculation (default is 'complex modulus').
        *args, **kwargs :
            Forwarded unchanged to the parent transform's initializer.
        """
        super().__init__(*args, **kwargs)
        self.state = state
        self.state_sampling = state_sampling
        self.dependent_variable = dependent_variable

    def _state_to_condition(self, target):
        """
        Computes the specified state's value for each measurement and adds it as a condition.

        This is used when a physical property is collected transiently and needs to be converted to a condition
        at a specific time. For example, the average temperature at which each measurement was taken. The
        state_sampling value allows to choose between 'average', 'first point', or 'last point' of the state
        variable to be used as the condition.

        Only the first dataset of each measurement that contains the state is used. The measurements of
        `target` are modified in place (coordinates are switched and a condition is added).

        Parameters:
        -----------
        target : MeasurementSet
            The target dataset to which the transformation is applied.

        Raises:
        -------
        ValueError:
            If the state name is not found in any dataset of a measurement, or if the state sampling method
            is invalid.
        """
        for measurement in target.measurements:
            state_found = False
            for dataset in measurement.datasets:
                if self.state not in dataset._data_array_names:
                    continue

                # NOTE(review): the independent axis is hard-coded to 'temperature', so the dataset must also
                # expose a 'temperature' array -- confirm this is intended for every supported state.
                dataset.switch_coordinates(independent_name='temperature', dependent_name=self.state)

                if self.state_sampling == 'average':
                    value = np.average(dataset.y_values)
                elif self.state_sampling == 'first point':
                    value = dataset.y_values[0]
                elif self.state_sampling == 'last point':
                    value = dataset.y_values[-1]
                else:
                    raise ValueError("Invalid method. Choose 'average', 'first point', or 'last point'")

                measurement.add_condition(f'{self.state}', value)
                state_found = True
                break

            if not state_found:
                raise ValueError(f"State name {self.state} not found in datasets for one of the measurements")

    def _apply(self, target, **kwargs):
        """
        Applies the mutation number calculation to the target dataset and returns a MutationNumberMeasurementSet.

        For every pair of consecutive measurements (i - 1, i) with different 'time' conditions and for every
        frequency index j, the method evaluates

            d(ln modulus)/dt = (ln modulus[i][j] - ln modulus[i - 1][j]) / (time[i] - time[i - 1])

        and stores `T[i] / (1 / d(ln modulus)/dt)` under the frequency and the time of measurement i. Pairs
        whose derivative is exactly zero are skipped. The first measurement therefore never contributes a value
        of its own.

        Parameters:
        -----------
        target : MeasurementSet
            The target dataset containing the measurements to be processed. Each measurement is read as:
            datasets[0] -> dependent variable vs. 'angular frequency', datasets[1] -> 'time' vs. 'step time',
            conditions['time'] -> wave start time, details['waiting_time'] -> waiting time.

        Returns:
        --------
        MutationNumberMeasurementSet :
            One measurement per frequency, carrying the condition 'Angular Frequency' and a one-dimensional
            dataset of mutation number vs. time.

        Raises:
        -------
        ValueError:
            Propagated from `_state_to_condition` when the state is missing or the sampling method is invalid.
        """
        self._state_to_condition(target)

        moduli = []
        frequencies = []
        time_conditions = []  # wave start time of each measurement
        waiting_times = []
        wave_durations = []

        # Extract data and prepare for mutation number calculation
        for measurement in target.measurements:
            dataset_frequency = measurement.datasets[0]
            dataset_frequency.switch_coordinates(independent_name='angular frequency',
                                                 dependent_name=self.dependent_variable)
            frequencies.append(copy.deepcopy(dataset_frequency.x_values))
            moduli.append(copy.deepcopy(dataset_frequency.y_values))

            # NOTE(review): the key is the literal 'time', not self.state; with the default state='time' this
            # is the condition written by _state_to_condition -- confirm the behaviour wanted for other states.
            time_conditions.append(measurement.conditions['time'])

            waiting_time = measurement.details['waiting_time']
            waiting_times.append(waiting_time)

            dataset_time = measurement.datasets[1]
            dataset_time.switch_coordinates('step time', 'time')
            step_time = copy.deepcopy(dataset_time.x_values)
            wave_durations.append(max(step_time) - waiting_time)

        # NOTE(review): wave_duration already has waiting_time subtracted, so T subtracts it a second time
        # (T = max(step time) - 2 * waiting_time). Kept as in the original; confirm this is intended.
        T_values = [wave_duration - waiting_time
                    for wave_duration, waiting_time in zip(wave_durations, waiting_times)]

        # frequency -> time -> mutation numbers
        mutation_number_by_frequency = defaultdict(lambda: defaultdict(list))

        for i in range(1, len(frequencies)):
            # Identical start times give no finite difference to differentiate over.
            if time_conditions[i] == time_conditions[i - 1]:
                continue

            time_diff = time_conditions[i] - time_conditions[i - 1]
            T = T_values[i]

            # Moduli are paired by index j, i.e. consecutive measurements are assumed to list their
            # frequencies in the same order.
            for j, freq in enumerate(frequencies[i]):
                ln_modulus_curr = np.log(moduli[i][j])
                ln_modulus_prev = np.log(moduli[i - 1][j])
                derivative_ln_modulus = (ln_modulus_curr - ln_modulus_prev) / time_diff

                # A zero derivative would make the reciprocal below divide by zero, so it is skipped.
                if derivative_ln_modulus != 0:
                    mutation_number = T / (1 / derivative_ln_modulus)
                    mutation_number_by_frequency[freq][time_conditions[i]].append(mutation_number)

        # Create Measurement objects with mutation number data
        measurements = []
        for freq, time_dict in mutation_number_by_frequency.items():
            all_times = []
            all_mutation_numbers = []
            for time, mutation_numbers in time_dict.items():
                all_times.extend([time] * len(mutation_numbers))
                all_mutation_numbers.extend(mutation_numbers)

            # Create the mutation number dataset
            mutation_number_dataset = one_dimensional_dataset.OneDimensionalDataset(
                dependent_variable_data=np.array(all_mutation_numbers),
                dependent_variable_names=['mutation number'],
                dependent_variable_units=['a.u.'],
                independent_variable_data=[np.array(all_times)],
                independent_variable_names=['time'],
                independent_variable_units=['s'],
                source='datasets in time and frequency domain')

            # Create a Measurement object
            measurements.append(
                Measurement(datasets=[mutation_number_dataset],
                            conditions={'Angular Frequency': freq},
                            details={}))

        return MutationNumberMeasurementSet(measurements=measurements)