Source code for sarcgraph.sg

import numpy as np
import pandas as pd
import trackpy as tp
import networkx as nx

import skvideo.io
import skimage.io
import skvideo.utils

from skimage.filters import laplace, gaussian, threshold_otsu
from skimage import measure
from sklearn.cluster import OPTICS
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import distance_matrix
from pathlib import Path
from typing import List, Union, Tuple

from sarcgraph.config import Config


[docs] class SarcGraph: def __init__(self, config: Config = None, **kwargs): """Zdiscs and sarcomeres segmentation and tracking. Parameters ---------- config : Config, optional Configuration settings for the application. If not provided, default settings are used, and any additional settings can be passed via keyword arguments. **kwargs Additional settings to override the defaults or values in the provided Config object. """ if config is None: config = Config() self.config = config self._update_config(**kwargs) self._create_output_dir() self.print_config()
[docs] def print_config(self): self.config.print()
def _update_config(self, **kwargs): """ Update configurations with given keyword arguments. Parameters ---------- **kwargs : dict Keyword arguments corresponding to configuration attributes to be updated. """ for key, value in kwargs.items(): if hasattr(self.config, key): setattr(self.config, key, value) if key == "output_dir": self._create_output_dir() else: raise AttributeError( f"{key} is not a valid configuration parameter. Use " "print_config() to see all available parameters." ) def _create_output_dir(self): Path(self.config.output_dir).mkdir(parents=True, exist_ok=True) def _pop_kwargs(self, *args, **kwargs): """Pop keys from kwargs and return the value and the updated kwargs. Parameters ---------- *args Keys to pop from kwargs. **kwargs Keyword arguments to pop from. Returns ------- tuple (popped_dict, updated_kwargs) """ popped_dict = {key: kwargs.pop(key, None) for key in args} return popped_dict, kwargs ############################################################# # Data Processing # #############################################################
[docs] def load_data(self, file_path: str = None) -> np.ndarray: """Loads a video/image file. Parameters ---------- file_path : str, optional A direct path to a video/image file to load. Returns ------- numpy.ndarray All frames in the raw format in gray scale. """ if file_path is None: raise ValueError("The input file_path is not specified.") if self.config.input_type == "video": try: data = skvideo.io.vread(file_path) data = np.squeeze(skvideo.utils.rgb2gray(data)) self._check_validity(data, self.config.input_type) except ValueError: data = skimage.io.imread(file_path, plugin="tifffile") data = np.squeeze(skvideo.utils.rgb2gray(data)) self._check_validity(data, self.config.input_type) else: print("Trying to load image") data = skimage.io.imread(file_path) print("got error!") data = np.squeeze(skvideo.utils.rgb2gray(data)) self._check_validity(data, self.config.input_type) return data
def _check_validity(self, data, input_type): if input_type == "image": if data.ndim != 2: raise ValueError( "Loaded image data is not valid, expected a 2D image." ) else: if not (data.ndim == 3 and data.shape[0] > 1): raise ValueError( "Loaded video data is not valid, expected a 3D data." ) return True
[docs] def save_data( self, data: Union[np.ndarray, List, pd.DataFrame], file_name: str ) -> None: """Saves a numpy array or a pandas dataframe based on the type of data. Parameters ---------- data : Union[np.ndarray, List, pd.DataFrame] file_name: str """ if data is None: return if not isinstance(file_name, str): raise TypeError("file_name must be a string.") full_path = f"{self.config.output_dir}/{file_name}" if isinstance(data, (np.ndarray, list)): np.save(f"{full_path}.npy", data, allow_pickle=True) elif isinstance(data, pd.DataFrame): data.to_csv(f"{full_path}.csv") else: raise TypeError( "data must be either a numpy.ndarray, a list, or a" " pandas.DataFrame." )
[docs] def filter_frames(self, frames: np.ndarray) -> np.ndarray: """Convolves each image with Laplacian and Gaussian filters. Parameters ---------- frames : np.ndarray A 2D image or a stack of 2D images. Returns ------- np.ndarray A stack of filtered images. """ sigma = self.config.sigma if frames.ndim == 2: frames = frames[np.newaxis, ...] elif frames.ndim != 3: raise ValueError( "Input must be a 2D image or a stack of 2D images." ) # Initialize an array to hold the filtered images filtered_frames = np.zeros_like(frames, dtype=np.float64) # Apply filters to each frame for i in range(frames.shape[0]): frame = frames[i] # Apply the Laplacian filter laplacian_filtered = laplace(frame) # Apply the Gaussian filter gaussian_filtered = gaussian(laplacian_filtered, sigma=sigma) filtered_frames[i] = gaussian_filtered return filtered_frames
############################################################# # Z-Disc Segmentation # #############################################################
[docs] def zdisc_segmentation(self, **kwargs): """ Perform z-disc segmentation using various inputs and optional processing functions. - Pre-filtered frames can be provided by specifying 'filtered_frames'. - Raw frames can be provided by specifying 'raw_frames'; overrides 'filtered_frames'. - An input file path can be provided by specifying 'input_file'; overrides 'raw_frames' and 'filtered_frames'. - Processing functions can be provided as a list of callable functions by specifying 'processing_functions'. Any configuration parameters can also be updated to customize the behavior of the segmentation and processing. Parameters ---------- **kwargs : dict Arbitrary keyword arguments including: - 'input_file': Path to the input file to load raw frames. - 'raw_frames': Array-like structure with raw frames for processing. - 'filtered_frames': Array-like structure with pre-filtered frames for processing. - 'processing_functions': List of function references to apply additional processing. - Any configuration parameters to update. Returns ------- pd.DataFrame The processed z-disc information after segmentation. """ # Extract specific kwargs for input data inputs_dict, kwargs = self._pop_kwargs( "input_file", "raw_frames", "filtered_frames", "processing_functions", **kwargs, ) input_file = inputs_dict["input_file"] raw_frames = inputs_dict["raw_frames"] filtered_frames = inputs_dict["filtered_frames"] processing_functions = [] if inputs_dict["processing_functions"] is not None: processing_functions = inputs_dict["processing_functions"] self._update_config(**kwargs) # Validate provided processing functions if not all(callable(fn) for fn in processing_functions): raise ValueError( "All items in 'processing_functions' must be ", "callable." ) mock_contour_input = np.array( [[0, 1], [1, 0], [2, 1], [2, 2], [1, 3], [0, 2], [0, 1]] ) for fn in processing_functions: test_output = fn(mock_contour_input) if not isinstance(test_output, dict): raise ValueError( f"The function {fn.__name__} did not return a dictionary." ) # Load data if 'input_file' is provided if input_file is not None: raw_frames = self.load_data(input_file) # Generate filtered frames if 'raw_frames' are provided if raw_frames is not None: filtered_frames = self.filter_frames(raw_frames) # If no 'filtered_frames' are provided by now, raise an error if filtered_frames is None: raise ValueError( "No valid input data provided. Please specify " "'input_file', 'raw_frames', or 'filtered_frames'" "." ) # Detect contours from filtered frames contours = self._detect_contours(filtered_frames) # Default processing functions default_processing_methods = [ self._zdisc_center, self._zdisc_endpoints, ] # Apply additional processing functions provided by the user zdiscs = self._process_contours( contours, default_processing_methods + processing_functions ) # Save output if configured to do so if self.config.save_output: self.save_data(raw_frames, "raw_frames") self.save_data(filtered_frames, "filtered_frames") self.save_data(contours, "contours") self.save_data(zdiscs, "segmented_zdiscs") return zdiscs
def _detect_contours( self, filtered_frames: np.ndarray ) -> List[np.ndarray]: """Returns contours of detected zdiscs in all frames filtered by the length threshold specified in the configuration. Parameters ---------- filtered_frames : np.ndarray Returns ------- np.ndarray of shape (num_frames, num_contours, contour_length, 2) """ if filtered_frames.ndim != 3: raise ValueError( "The input must be a 3D numpy array: (frames, " "dim_1, dim_2)" ) valid_contours = [] for frame in filtered_frames: contours = self._find_frame_contours(frame) valid_contours_for_frame = self._validate_contours(contours) valid_contours.append(valid_contours_for_frame) return np.array(valid_contours, dtype=object) def _validate_contours( self, contours: List[np.ndarray] ) -> List[np.ndarray]: """Validates contours based on their length and ensure they are closed. Parameters ---------- contours : List[np.ndarray] List of numpy arrays of detected contours. Returns ------- List[np.ndarray] List of numpy arrays of valid contours. """ valid_contours = [] for contour in contours: if not np.allclose(contour[0], contour[-1]): continue if ( self.config.zdisc_min_length <= len(contour) and len(contour) <= self.config.zdisc_max_length ): valid_contours.append(contour) return valid_contours def _find_frame_contours(self, frame: np.ndarray) -> List[np.ndarray]: """Detects contours within a single frame. Parameters ---------- frame : np.ndarray Single frame of the filtered image stack. Returns ------- List[np.array] List of numpy arrays representing detected contours. """ contour_thresh = threshold_otsu(frame) contours = measure.find_contours(frame, contour_thresh) return contours def _process_contours( self, contours_all: Union[List, np.ndarray], processing_functions: List[callable], ) -> pd.DataFrame: """ Processes a list of contours across all frames using a list of default and user-defined functions and saves the results as a dataframe Parameters ---------- contours_all : List or np.ndarray A list of z-discs contours as 2D numpy arrays. processing_functions : List[callable] A list of functions that each take a contour as input and return a dictionary where keys are attribute names and values are the corresponding data (e.g. {}'x': 125.64}). Returns ------- pd.DataFrame A dataframe where each row represents a zdisc and includes data from all processing functions. """ data_frame_list = [] for frame_index, contours in enumerate(contours_all): for contour in contours: zdisc_info = {"frame": frame_index} for func in processing_functions: zdisc_info.update(func(contour)) data_frame_list.append(zdisc_info) return pd.DataFrame(data_frame_list) def _zdisc_center(self, contour: np.ndarray) -> dict: """ Calculate the centroid of a zdisc given its contour. Parameters ---------- contour : np.ndarray NumPy array of shape (contour_length, 2) representing z-disc Returns ------- dict A dictionary with keys 'x' and 'y' representing the z-disc centroid """ center_coords = np.mean(np.unique(contour, axis=0), axis=0) return {"x": center_coords[0], "y": center_coords[1]} def _zdisc_endpoints(self, contour: Union[List, np.ndarray]) -> dict: """ Identify the main axis of a zdisc by finding the two points in the contour that are farthest apart from each other. Parameters ---------- contour : np.ndarray NumPy array of shape (contour_length, 2) representing z-disc Returns ------- dict A dictionary with keys 'p1_x', 'p1_y', 'p2_x', 'p2_y' representing the coordinates of the two endpoints. """ dist_mat = distance_matrix(contour, contour) indices = np.unravel_index(dist_mat.argmax(), dist_mat.shape) p1, p2 = contour[indices[0]], contour[indices[1]] return {"p1_x": p1[0], "p1_y": p1[1], "p2_x": p2[0], "p2_y": p2[1]} ######################################################### # Z-disc Tracking # ######################################################### def _merge_tracked_zdiscs( self, tracked_zdiscs: pd.DataFrame, ) -> pd.DataFrame: """A post processing step to group related partially tracked zdiscs using the OPTICS algorithm. Increases the robustness of zdisc tracking as well as the number of fully tracked zdiscs. Parameters ---------- tracked_zdiscs : pd.DataFrame tracked zdiscs information for all frames Returns ------- pd.DataFrame Notes ----- For a detailed description of the OPTICS algorithm check: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html """ full_track_ratio = self.config.full_track_ratio num_frames = tracked_zdiscs.frame.max() + 1 tracked_zdiscs_grouped = tracked_zdiscs.groupby("particle")["particle"] tracked_zdiscs["freq"] = tracked_zdiscs_grouped.transform("count") fully_tracked_zdiscs = tracked_zdiscs.loc[ tracked_zdiscs.freq == num_frames ] partially_tracked_zdiscs = tracked_zdiscs.loc[ tracked_zdiscs.freq < num_frames ] if partially_tracked_zdiscs.empty: return fully_tracked_zdiscs partially_tracked_clusters = ( partially_tracked_zdiscs[["x", "y", "particle"]] .groupby(by=["particle"]) .mean() ) # merge related clusters (neighbors) all_clusters_xy = ( tracked_zdiscs.groupby("particle").mean()[["x", "y"]].to_numpy() ) clusters_min_dist = np.min( distance_matrix(all_clusters_xy, all_clusters_xy) + 1e6 * np.eye(len(all_clusters_xy)), axis=1, ) optics_max_eps = np.mean(clusters_min_dist) data = np.array(partially_tracked_clusters) optics_model = OPTICS(max_eps=optics_max_eps, min_samples=2) optics_result = optics_model.fit_predict(data) optics_clusters = np.unique(optics_result) all_merged_zdiscs = [] for i, optics_cluster in enumerate(optics_clusters): index = np.where(optics_result == optics_cluster)[0] particles_in_cluster = partially_tracked_clusters.iloc[ index ].index.to_numpy() if optics_cluster >= 0: merged_zdiscs = ( partially_tracked_zdiscs.loc[ partially_tracked_zdiscs["particle"].isin( particles_in_cluster ) ] .groupby("frame") .mean() ) if len(merged_zdiscs) > num_frames * full_track_ratio: merged_zdiscs.particle = -(i + 1) merged_zdiscs.freq = len(merged_zdiscs) all_merged_zdiscs.append(merged_zdiscs.reset_index()) else: for p in particles_in_cluster: no_merge_zdiscs = partially_tracked_zdiscs.loc[ partially_tracked_zdiscs["particle"] == p ] if len(no_merge_zdiscs) > num_frames * full_track_ratio: all_merged_zdiscs.append(no_merge_zdiscs) if all_merged_zdiscs: all_merged_zdiscs = pd.concat(all_merged_zdiscs) return pd.concat((fully_tracked_zdiscs, all_merged_zdiscs)) else: return fully_tracked_zdiscs
[docs] def zdisc_tracking(self, **kwargs) -> pd.DataFrame: """Track detected Z-Discs in video data, similar to zdisc_segmentation, but with an additional tracking step. If 'segmented_zdiscs' is provided, it directly proceeds to tracking; otherwise, it first runs the z-discs segmentation process. This function shares similar inputs to `zdisc_segmentation` with the addition of an optional 'segmented_zdiscs' DataFrame parameter. Parameters ---------- **kwargs : dict, optional Shared keyword arguments with `zdisc_segmentation`, plus: - segmented_zdiscs: pd.DataFrame, optional Pre-segmented zdiscs information. If provided, it should contain at least the following columns: 'frame', 'x', 'y', 'p1_x', 'p1_y', 'p2_x', 'p2_y'. Returns ------- pd.DataFrame Tracked zdiscs information; adds 'particle' column for each tracked zdisc. Notes ----- - For details on shared parameters and segmentation process, refer to :func:`sarcgraph.sg.SarcGraph.zdisc_segmentation`. See Also -------- :func:`sarcgraph.sg.SarcGraph.zdisc_segmentation` """ # Extract specific kwargs for input data inputs_dict, kwargs = self._pop_kwargs("segmented_zdiscs", **kwargs) segmented_zdiscs = inputs_dict["segmented_zdiscs"] if segmented_zdiscs is None: detected_zdiscs = self.zdisc_segmentation(**kwargs) return self.zdisc_tracking( segmented_zdiscs=detected_zdiscs, **kwargs ) # Validate segmented_zdiscs if provided required_columns = { "frame", "x", "y", "p1_x", "p1_y", "p2_x", "p2_y", } if ( not isinstance(segmented_zdiscs, pd.DataFrame) or segmented_zdiscs.empty or not required_columns.issubset(segmented_zdiscs.columns) ): raise ValueError( "Provided 'segmented_zdiscs' DataFrame is not in the " "correct format. The DataFrame must be non-empty and " "include the following columns: 'frame', 'x', 'y', 'p1_x'," " 'p1_y', 'p2_x', 'p2_y'." ) _, kwargs = self._pop_kwargs( "input_file", "raw_frames", "filtered_frames", **kwargs ) self._update_config(**kwargs) if self.config.input_type == "image": segmented_zdiscs["particle"] = np.arange(len(segmented_zdiscs)) tracked_zdiscs = segmented_zdiscs else: num_frames = len(segmented_zdiscs["frame"].unique()) t = tp.link_df( segmented_zdiscs, search_range=self.config.tp_depth, memory=num_frames, ) tracked_zdiscs = tp.filter_stubs(t, 2).reset_index(drop=True) if not self.config.skip_merge: tracked_zdiscs = self._merge_tracked_zdiscs(tracked_zdiscs) if self.config.save_output: self.save_data(tracked_zdiscs, "tracked_zdiscs") return tracked_zdiscs
############################################################# # Sarcomere Detection # #############################################################
[docs] def sarcomere_detection( self, **kwargs ) -> Tuple[pd.DataFrame, List[nx.Graph]]: """Detect sarcomeres in a video/image using dynamic keyword arguments. Sarcomere detection can be initiated with tracked zdisc information, or by first running zdisc tracking. Parameters ---------- **kwargs : dict Keyword arguments dynamically passed. Relevant keys are: 'input_file' : str The address of an image or a video file to be loaded. 'raw_frames' : np.ndarray Raw input frames as a numpy array. 'filtered_frames' : np.ndarray Pre-filtered frames for processing. 'segmented_zdiscs' : pd.DataFrame Information of all detected zdiscs in every frame. 'tracked_zdiscs' : pd.DataFrame Information of tracked zdiscs. Must be non-empty and include the following columns: 'frame', 'x', 'y', 'p1_x', 'p1_y', 'p2_x', 'p2_y', 'particle'. If provided and valid, sarcomere detection is initiated immediately. Additional configuration parameters relevant for sarcomere detection and tracking can also be provided via kwargs. Returns ------- Tuple[pd.DataFrame, List[nx.Graph]] A tuple containing: - A DataFrame with detected sarcomeres information including 'frame', 'sarc_id', 'zdiscs', 'x', 'y', 'length', 'width', and 'angle'. - A list of Graph objects representing connected sarcomeres (myofibrils). Notes ----- - If 'tracked_zdiscs' is not provided, this function will internally call `zdisc_tracking` to generate tracked zdiscs from the provided keyword arguments. It is assumed that `zdisc_tracking` can handle error checking and appropriate defaults for its parameters. - For a detailed description of the Trackpy package check: http://soft-matter.github.io/trackpy/v0.5.0/tutorial.html - For a detailed description of the OPTICS algorithm check: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html See Also -------- :func:`sarcgraph.sg.SarcGraph.zdisc_segmentation` :func:`sarcgraph.sg.SarcGraph.zdisc_tracking` """ # Extract specific kwargs for input data inputs_dict, kwargs = self._pop_kwargs("tracked_zdiscs", **kwargs) tracked_zdiscs = inputs_dict["tracked_zdiscs"] if tracked_zdiscs is None: tracked_zdiscs = self.zdisc_tracking(**kwargs) return self.sarcomere_detection( tracked_zdiscs=tracked_zdiscs, **kwargs ) # Validate segmented_zdiscs if provided required_columns = { "frame", "x", "y", "p1_x", "p1_y", "p2_x", "p2_y", "particle", } if ( not isinstance(tracked_zdiscs, pd.DataFrame) or tracked_zdiscs.empty or not required_columns.issubset(tracked_zdiscs.columns) ): raise ValueError( "Provided 'tracked_zdiscs' DataFrame is not in the correct " "format. The DataFrame must be non-empty and include at least " "the following columns: 'frame', 'x', 'y', 'p1_x', 'p1_y', " "'p2_x', 'p2_y', 'particle'." ) _, kwargs = self._pop_kwargs( "input_file", "raw_frames", "filtered_frames", "segmented_zdiscs", **kwargs, ) self._update_config(**kwargs) zdiscs_clusters = ( tracked_zdiscs.groupby("particle") .mean() .reset_index()[["x", "y", "particle"]] .to_numpy() ) G = self._zdisc_to_graph(zdiscs_clusters) G = self._score_graph(G) G = self._prune_graph(G) myofibrils = [G.subgraph(c).copy() for c in nx.connected_components(G)] sarcs = self._process_sarcomeres(G, tracked_zdiscs) if self.config.save_output: self.save_data(sarcs, "sarcomeres") return sarcs, myofibrils
def _zdisc_to_graph(self, zdiscs: np.array) -> nx.Graph: """Creates a graph with zdiscs as nodes. Each zdisc is connected to its ``K`` nearest neighbors. Parameters ---------- zdiscs : np.array, shape=(N, 3) zdiscs information as an array. The first two columns are the x and y location of zdisc centers and the last is the particle id. Returns ------- nx.Graph """ G = self._graph_initialization(zdiscs) return self._add_edges(G, self._find_nearest_neighbors(zdiscs[:, 0:2])) def _graph_initialization(self, zdiscs: np.array) -> nx.Graph: """Initializes a graph of z-discs. Parameters ---------- zdiscs : np.array An array of z-disc data. The first two columns are expected to be x and y coordinates, and the last column is expected to be the particle ID. Returns ------- nx.Graph A networkx Graph object where each node represents a z-disc and has 'pos' (position) and 'particle_id' attributes. """ num_nodes = len(zdiscs) G = nx.Graph() G.add_nodes_from(range(num_nodes)) nodes_pos_dict = {i: pos for i, pos in enumerate(zdiscs[:, 0:2])} nodes_particle_dict = {j: id for j, id in enumerate(zdiscs[:, -1])} nx.set_node_attributes(G, values=nodes_pos_dict, name="pos") nx.set_node_attributes( G, values=nodes_particle_dict, name="particle_id" ) return G def _find_nearest_neighbors(self, zdiscs: np.array) -> np.array: """Finds the K nearest neighbors of each z-disc. K can be specified as a config parameter num_neighbors. Parameters ---------- zdiscs : np.array Array of z-disc positions (x and y). Returns ------- np.array Array of indices of nearest neighbors for each z-disc. """ K = self.config.num_neighbors neigh = NearestNeighbors(n_neighbors=2) neigh.fit(zdiscs) nearestNeighbors = neigh.kneighbors( zdiscs, K + 1, return_distance=False ) return nearestNeighbors def _add_edges(self, G: nx.Graph, nearestNeighbors: np.array) -> nx.Graph: """Adds edges to the graph by connecting on K nearest neighbors. Parameters ---------- G : nx.Graph The graph to which edges will be added. nearestNeighbors : np.array Array of indices of nearest neighbors. Returns ------- nx.Graph The graph with added edges. """ edges = [] for node, neighbors in enumerate(nearestNeighbors[:, 1:]): for neighbor in neighbors: edges.append((node, neighbor)) G.add_edges_from(edges) return G def _score_graph(self, G: nx.Graph) -> nx.Graph: """Assigns a score to each connection of the input graph. Higher score indicates the two corresponding zdiscs are likely to be two ends of a sarcomere. Parameters ---------- G : nx.Graph Returns ------- nx.Graph a graph of zdiscs with all connections scored """ c_avg_length = self.config.coeff_avg_length l_avg = self.config.avg_sarc_length l_max = self.config.max_sarc_length l_min = self.config.min_sarc_length edges_attr_dict = {} for node in range(G.number_of_nodes()): for neighbor in G.neighbors(node): score = 0 v1, l1 = self._sarc_vector(G, node, neighbor) if l1 <= l_max and l1 >= l_min: avg_length_score = np.exp(-np.pi * (1 - l1 / l_avg) ** 2) for far_neighbor in G.neighbors(neighbor): if far_neighbor in [node, neighbor]: pass else: v2, l2 = self._sarc_vector( G, far_neighbor, neighbor ) sum_scores = self._sarc_score(v1, v2, l1, l2) score = np.max((score, sum_scores)) score += c_avg_length * avg_length_score edges_attr_dict[(node, neighbor)] = score edges_attr_dict_keep_max = {} for key in edges_attr_dict.keys(): node_1 = key[0] node_2 = key[1] max_score = max( edges_attr_dict[(node_1, node_2)], edges_attr_dict[(node_2, node_1)], ) edges_attr_dict_keep_max[(min(key), max(key))] = max_score nx.set_edge_attributes( G, values=edges_attr_dict_keep_max, name="score" ) return G def _sarc_vector(self, G, node1, node2): """Calculates the vector and length between two nodes in the graph. Parameters ---------- G : nx.Graph node1, node2 The node ids for which the vector is calculated. Returns ------- tuple The vector connecting node1 to node2 and its length. """ sarc = G.nodes[node2]["pos"] - G.nodes[node1]["pos"] length = np.linalg.norm(sarc) return sarc, length def _length_score(self, l1, l2): """Calculates the length score between two potential sarcomeres. Parameters ---------- l1, l2 The lengths of the two sarcomeres to compare. Returns ------- float The calculated length score. """ d_l = np.abs(l2 - l1) / l1 return 1 / (1 + d_l) def _sarcs_angle(self, v1, v2, l1, l2): """Calculates the angle between two vectors. Parameters ---------- v1, v2 : np.array l1, l2 : float Returns ------- float The angle between the two vectors. """ return np.arccos(np.dot(v1, v2) / (l1 * l2)) / (np.pi / 2) def _angle_score(self, v1, v2, l1, l2): """ Calculates the angle score between two potential sarcomeres. Parameters ---------- v1, v2 : np.array l1, l2 : float Returns ------- float The calculated angle score. """ theta = self._sarcs_angle(v1, v2, l1, l2) return np.power(theta - 1, 2) if theta >= 1 else 0 def _sarc_score(self, v1, v2, l1, l2): """ Calculates sarcomere score based on length and angle scores. Parameters ---------- v1, v2 : np.array The vectors representing two potential connected sarcomeres. l1, l2 : float The lengths of the two sarcomeres. Returns ------- float The sarcomere score. """ c_len = self.config.coeff_neighbor_length c_ang = self.config.coeff_neighbor_angle len_score = self._length_score(l1, l2) ang_score = self._angle_score(v1, v2, l1, l2) return c_len * len_score + c_ang * ang_score def _prune_graph(self, G: nx.Graph) -> nx.Graph: """Prunes the input graph to get rid of invalid or less probable connections. Parameters ---------- G : nx.Graph A scored graph of zdiscs clusters Returns ------- nx.Graph """ score_threshold = self.config.score_threshold angle_threshold = self.config.angle_threshold nx.set_edge_attributes(G, values=0, name="validity") for node in range(G.number_of_nodes()): vectors = [] scores = [] neighbors = list(G.neighbors(node)) for neighbor in neighbors: vectors.append(G.nodes[neighbor]["pos"] - G.nodes[node]["pos"]) scores.append(G[node][neighbor]["score"]) # sort by scores if np.max(scores) > score_threshold: sort_indices = np.argsort(scores)[::-1] best_vector = vectors[sort_indices[0]] G[node][neighbors[sort_indices[0]]]["validity"] += 1 for idx in sort_indices[1:]: s = scores[idx] n = neighbors[idx] v = vectors[idx] l1 = np.linalg.norm(best_vector) l2 = np.linalg.norm(v) theta = self._sarcs_angle(v, best_vector, l1, l2) if theta > angle_threshold and s > score_threshold: G[node][n]["validity"] += 1 break edges2remove = [] for edge in G.edges(): if G.edges[edge]["validity"] < 2: edges2remove.append(edge) G.remove_edges_from(edges2remove) return G def _get_connected_zdiscs( self, G: nx.Graph, tracked_zdiscs: pd.DataFrame, edge: Tuple[int, int], ) -> [pd.DataFrame, pd.DataFrame]: """ Retrieves the tracking data for the z-discs connected by a given edge in the graph. Parameters ---------- G : nx.Graph The graph representing z-disc connections. tracked_zdiscs : pd.DataFrame The dataframe with tracking data for all z-discs. edge : tuple The tuple representing the edge connecting two z-discs in the graph. Returns ------- Tuple[pd.DataFrame, pd.DataFrame] A tuple of dataframes, each corresponding to the tracking data of one of the two z-discs connected by the edge. """ p1 = int(G.nodes[edge[0]]["particle_id"]) p2 = int(G.nodes[edge[1]]["particle_id"]) z1 = tracked_zdiscs[tracked_zdiscs.particle == p1] z2 = tracked_zdiscs[tracked_zdiscs.particle == p2] z1.columns = np.insert(z1.columns.values[1:] + "_p1", 0, "frame") z2.columns = np.insert(z2.columns.values[1:] + "_p2", 0, "frame") return z1, z2 def _initialize_sarc( self, z0: pd.DataFrame, z1: pd.DataFrame, z2: pd.DataFrame ) -> pd.DataFrame: """ This function creates a base DataFrame for a single sarcomere by merging base frame information from `z0` with z-disc data from `z1` and `z2`. Parameters ---------- z0 : pd.DataFrame A DataFrame with a single column 'frame' representing all frames in the dataset. z1 : pd.DataFrame A DataFrame containing the details of the first z-disc of the sarcomere in each frame where it is present. z2 : pd.DataFrame A DataFrame containing the details of the second z-disc of the sarcomere in each frame where it is present. Returns ------- pd.DataFrame A merged DataFrame representing the initial state of a single sarcomere across all frames. """ sarc = pd.merge(z0, z1, how="outer", on="frame") sarc = pd.merge(sarc, z2, how="outer", on="frame") return sarc def _process_sarc(self, sarc: pd.DataFrame) -> pd.DataFrame: """ Processes a single sarcomere's data to calculate its properties. Takes a dataframe initialized with sarcomere information and computes various properties such as center position, length, width, and angle based on the positions of the connected z-discs. Parameters ---------- sarc : pd.DataFrame The initialized dataframe for a single sarcomere containing the positions of the connected z-discs and other relevant information. Returns ------- pd.DataFrame The dataframe for a single sarcomere with additional computed properties. """ sarc["x"] = (sarc.x_p1 + sarc.x_p2) / 2 sarc["y"] = (sarc.y_p1 + sarc.y_p2) / 2 length = np.sqrt( (sarc.x_p1 - sarc.x_p2) ** 2 + (sarc.y_p1 - sarc.y_p2) ** 2 ) sarc["length"] = length width1 = np.sqrt( (sarc.p1_x_p1 - sarc.p2_x_p1) ** 2 + (sarc.p1_y_p1 - sarc.p2_y_p1) ** 2 ) width2 = np.sqrt( (sarc.p1_x_p2 - sarc.p2_x_p2) ** 2 + (sarc.p1_y_p2 - sarc.p2_y_p2) ** 2 ) sarc["width"] = (width1 + width2) / 2 angle = np.arctan2(sarc.x_p2 - sarc.x_p1, sarc.y_p2 - sarc.y_p1) angle[angle < 0] += np.pi angle = np.pi - angle sarc["angle"] = angle sarc = sarc[ [ "frame", "sarc_id", "x", "y", "length", "width", "angle", "zdiscs", ] ] return sarc def _process_sarcomeres( self, G: nx.Graph, tracked_zdiscs: pd.DataFrame ) -> pd.DataFrame: """ Processes sarcomeres from a graph of connected z-discs. This function takes a graph where nodes represent z-discs and edges represent connections between them, alongside a dataframe of tracked z-discs across frames. It combines information from connected z-discs, and compiles processed sarcomere properties into a dataframe. Parameters ---------- G : nx.Graph A networkx Graph object where nodes represent z-discs and edges represent sarcomeres connecting these z-discs. tracked_zdiscs : pd.DataFrame A pandas dataframe containing tracked positions of z-discs across different frames. Returns ------- pd.DataFrame A dataframe containing processed sarcomere data with assigned identifiers and their properties such as length, width, and angle for each frame. """ sarcs = [] num_frames = tracked_zdiscs.frame.max() + 1 z0 = pd.DataFrame(np.arange(0, num_frames, 1), columns=["frame"]) for i, edge in enumerate(G.edges): z1, z2 = self._get_connected_zdiscs(G, tracked_zdiscs, edge) sarc = self._initialize_sarc(z0, z1, z2) sarc["sarc_id"] = i p1 = z1.particle_p1.values[0] p2 = z2.particle_p2.values[0] sarc["zdiscs"] = ",".join(map(str, sorted((p1, p2)))) sarcs.append(self._process_sarc(sarc)) sarcs = pd.concat(sarcs).reset_index().drop("index", axis=1) return sarcs