Source code for cytopus.tl.hierarchy

#import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout
def build_nested_dict(graph, node):
    '''
    Build a nested dictionary starting at node by recursively following successors.

    graph: networkx.DiGraph view, reverse view of the Cytopus cell type hierarchy
    node: str, name of the node to start from (e.g. a root node of the reversed view)

    returns: dict, {node: {successor: {...}, ...}}; a node without successors maps to {}
    '''
    # Each recursive call returns a single-key dict {child: subtree}; unwrap it.
    children = {
        child: build_nested_dict(graph, child)[child]
        for child in graph.successors(node)
    }
    return {node: children}
def get_hierarchy_dict(G):
    '''
    Reverse the Cytopus cell type hierarchy and build a nested dictionary from it.

    G: Cytopus.KnowledgeBase, containing the cell type hierarchy in G.graph

    returns: dict, nested dict with one top-level entry per root node of the
        reversed cell type view (nodes without incoming edges after reversal),
        each mapping to the nested dict built by build_nested_dict
    '''
    import networkx as nx

    # Cell type nodes of the knowledge base. Stored as a set so the node
    # filter below is an O(1) membership test rather than a list scan.
    celltype_nodes = set(G.filter_nodes(attribute_name='class', attributes=['cell_type']))

    def filter_node(n1):
        return n1 in celltype_nodes

    view = nx.subgraph_view(G.graph, filter_node=filter_node)
    # reverse graph view (going from least granular to most granular cell type)
    reversed_view = view.reverse(copy=True)
    root_nodes = [n for n in reversed_view.nodes if reversed_view.in_degree(n) == 0]

    # build the nested dictionary
    hierarchy_dict = {}
    for root in root_nodes:
        hierarchy_dict.update(build_nested_dict(reversed_view, root))
    return hierarchy_dict
def create_hierarchical_graph(data, type_label):
    '''
    Build a directed graph from a nested hierarchy dictionary.

    data: dict, nested dict {parent: {child: {...}, ...}}; a non-dict value is
        iterated and each of its items is added as a leaf child of its key
    type_label: str, value stored under the 'type' attribute of every node

    returns: networkx.DiGraph with an edge child -> parent for every
        parent/child pair found in data
    '''
    import networkx as nx

    G = nx.DiGraph()

    def _add_level(level):
        # Add one nesting level (and recursively everything below it) to G in
        # place, instead of building and copying a separate subgraph per level.
        for parent, children in level.items():
            G.add_node(parent)
            if isinstance(children, dict):
                _add_level(children)
                for child in children:
                    G.add_edge(child, parent)
            else:
                # iterate only once: children may be any iterable
                for child in children:
                    G.add_node(child)
                    G.add_edge(child, parent)

    _add_level(data)
    # every node gets the same type label, applied once at the end
    nx.set_node_attributes(G, type_label, 'type')
    return G
def get_all_keys(d):
    '''
    Collect the keys of a nested dictionary at every depth.

    d: dict, possibly containing further dicts as values

    returns: set, all keys of d and of every dict nested in its values
    '''
    collected = set(d)
    for value in d.values():
        if isinstance(value, dict):
            collected.update(get_all_keys(value))
    return collected
def get_nodes_of_type(graph, node_type):
    '''
    List the nodes whose 'type' attribute equals node_type, sorted by name.

    graph: networkx graph whose nodes carry a 'type' attribute
    node_type: str, type label to select

    returns: list of node names, ordered by their '.'-separated parts
        (each part compared as a string)
    '''
    def _dotted_parts(name):
        return name.split('.')

    matching = (n for n in graph.nodes() if graph.nodes[n]['type'] == node_type)
    return sorted(matching, key=_dotted_parts)
def get_indices(df, value):
    '''
    Return the index labels of all rows in which any column equals value.

    df: pandas.DataFrame, values are compared after conversion with astype(str)
    value: str, value to look for

    returns: list of index labels, in the order of df.index
    '''
    as_text = df.astype(str)
    row_has_value = as_text.eq(value).any(axis=1)
    return df.index[row_has_value].tolist()
def get_node_labels(graph, node_type):
    '''
    List nodes of the given type in reverse DFS post-order.

    graph: networkx.DiGraph whose nodes carry a 'type' attribute
    node_type: str, type label to select

    returns: list of matching nodes. For an acyclic graph, reverse post-order
        is a topological order: each node precedes the nodes its edges point to.
    '''
    import networkx as nx

    postorder = nx.dfs_postorder_nodes(graph)
    matching = [n for n in postorder if graph.nodes[n]['type'] == node_type]
    matching.reverse()
    return matching
class Hierarchy:
    '''
    Cell type hierarchy stored as a networkx.DiGraph in self.graph.

    Cell type nodes carry type='cell_type' and are linked child -> parent
    (more granular -> less granular). Cells added with add_cells carry
    type='cell' and are attached by an edge cell_type -> barcode.
    '''
    import networkx as nx

    def __init__(self, hierarchy_dict):
        '''
        load hierarchy class

        hierarchy_dict: dict, nested dict containing the cell type hierarchy
        '''
        self.graph = create_hierarchical_graph(hierarchy_dict, type_label='cell_type')
        print(self.__str__())

    def __str__(self):
        all_celltypes = get_nodes_of_type(self.graph, 'cell_type')
        return f"Hierarchy class containing {len(all_celltypes)} cell types:{all_celltypes}"

    def identities(self):
        '''
        print cell types contained in hierarchy
        '''
        print(get_nodes_of_type(self.graph, node_type='cell_type'))

    def plot_celltypes(self, node_color='#8decf5', node_size=1000, edge_width=1, arrow_size=20,
                       edge_color='k', label_size=10, figsize=(30, 30)):
        '''
        Plot all cell types contained in the hierarchy object (cell nodes are hidden).

        node_color, node_size, edge_width, arrow_size, edge_color, label_size:
            drawing options passed on to the networkx drawing functions
        figsize: sequence of two numbers, written to plt.rcParams["figure.figsize"]
            (note: this changes the global matplotlib default)
        '''
        import networkx as nx
        import matplotlib.pyplot as plt

        # set for O(1) membership tests in the node filter
        celltype_nodes = set(get_nodes_of_type(self.graph, 'cell_type'))

        def filter_node(n1):
            return n1 in celltype_nodes

        view = nx.subgraph_view(self.graph, filter_node=filter_node)
        pos = graphviz_layout(view)
        plt.rcParams["figure.figsize"] = figsize
        nx.draw_networkx_nodes(view, pos=pos, node_color=node_color, nodelist=None,
                               node_size=node_size, label=True)
        nx.draw_networkx_edges(view, pos=pos, edgelist=None, width=edge_width,
                               edge_color=edge_color, style='solid', alpha=None,
                               arrowstyle=None, arrowsize=arrow_size, edge_cmap=None,
                               edge_vmin=None, edge_vmax=None, ax=None, arrows=None,
                               label=None, node_size=node_size, nodelist=None,
                               node_shape='o', connectionstyle='arc3',
                               min_source_margin=0, min_target_margin=0)
        nx.draw_networkx_labels(view, pos=pos, font_size=label_size)

    def add_cells(self, adata, obs_columns=None):
        '''
        Add cells to their most granular annotation in the hierarchy object.

        adata: anndata.AnnData, containing the cell type annotations under adata.obs.
        obs_columns: list or str, column(s) in adata.obs where the cell type annotations
            are stored (recommended). If None, all columns of adata.obs are searched.

        Every barcode (index value of adata.obs) whose row contains a cell type name is
        added as a node of type 'cell' with an edge cell_type -> barcode. If the barcode
        is already attached to a related cell type, only the more granular one is kept;
        annotations not related through the hierarchy are kept side by side.
        '''
        import warnings
        import networkx as nx

        if obs_columns is None:
            adata_sub = adata.obs
        else:
            if isinstance(obs_columns, str):
                # a bare column name would select a Series, which has no .columns
                obs_columns = [obs_columns]
            adata_sub = adata.obs[obs_columns]

        # Get cell type annotations from adata (covering both obs_columns and full-obs cases)
        adata_celltypes = set()
        for column in adata_sub.columns:
            adata_celltypes.update(adata_sub[column].dropna().unique())

        # Retrieve cell type nodes from the hierarchy (reverse post-order: granular first)
        celltype_nodes = get_node_labels(self.graph, 'cell_type')
        missing_celltypes = adata_celltypes - set(celltype_nodes)

        # Warn if there are missing cell types
        if missing_celltypes:
            warnings.warn(
                f"Cell types {list(missing_celltypes)} are not contained in the hierarchy. Skipping..."
            )

        # Coarser cell types reachable from each cell type (edges point child -> parent).
        # Computed once on the cell-type-only subgraph, so the per-barcode checks below
        # do not traverse the cell nodes attached to every ancestor. Cell nodes only get
        # incoming edges here, so no path between cell types can pass through them.
        celltype_set = set(celltype_nodes)
        celltype_view = nx.subgraph_view(self.graph, filter_node=lambda n: n in celltype_set)
        coarser_types = {ct: nx.descendants(celltype_view, ct) for ct in celltype_nodes}

        # Convert annotations to str once instead of once per cell type
        # (same matching rule as get_indices)
        adata_str = adata_sub.astype(str)

        # Loop over cell types and assign cells
        for cell_type in celltype_nodes:
            barcodes = adata_str.index[(adata_str == cell_type).any(axis=1)].tolist()
            for barcode in barcodes:
                # Check if the cell is already in the hierarchy
                if barcode in self.graph:
                    current_annotations = [
                        edge[0] for edge in self.graph.in_edges(barcode)
                        if self.graph.nodes[edge[0]]['type'] == 'cell_type'
                    ]
                    skip = False
                    for current_annotation in current_annotations:
                        if (current_annotation == cell_type
                                or cell_type in coarser_types[current_annotation]):
                            # current annotation is the same or more granular: keep it
                            skip = True
                            break
                        if current_annotation in coarser_types[cell_type]:
                            # new cell_type is more granular: replace the current annotation
                            self.graph.remove_edge(current_annotation, barcode)
                    if skip:
                        continue
                    # unrelated annotations are kept alongside the new one
                self.graph.add_node(barcode, type='cell')
                self.graph.add_edge(cell_type, barcode)

    def query_ancestors(self, query_node, adata=None, obs_key='hierarchical_query'):
        '''
        Retrieve all cell barcodes belonging to a cell type and all of its subsets.

        query_node: str, cell type name for which to retrieve barcodes
        adata: anndata.AnnData, optional; if given, the annotations are stored under
            adata.obs[obs_key]
        obs_key: str, column label to store cell type annotations under adata.obs[obs_key]

        returns: dict mapping query_node and each of its more granular cell types to the
            barcodes attached directly to it (also stored as self.annotations), or None
            after printing a message if query_node is not of type 'cell_type'.
        '''
        import networkx as nx
        import anndata

        node_type = 'cell_type'
        if node_type != self.graph.nodes[query_node]['type']:
            print('query_node:', query_node, 'should be of type', node_type, 'stopping...')
            return None

        # nodes with a path to query_node are its more granular subsets
        nodes_of_specific_type = [node for node in nx.ancestors(self.graph, query_node)
                                  if self.graph.nodes[node]['type'] == node_type]
        nodes_of_specific_type.append(query_node)

        cell_nodes = {}
        for node in set(nodes_of_specific_type):
            cell_edges = [edge for edge in self.graph.edges(node)
                          if self.graph.nodes[edge[1]]['type'] == 'cell']
            cell_nodes[node] = [edge[1] for edge in cell_edges]

        cell_nodes_inv = {}
        for k, v in cell_nodes.items():
            for i in v:
                cell_nodes_inv[i] = k

        if isinstance(adata, anndata.AnnData):
            adata.obs[obs_key] = adata.obs_names.map(cell_nodes_inv)
        self.annotations = cell_nodes
        return cell_nodes

    def trim_annotations(self, adata, coarse_labels, obs_key='trimmed_annotation'):
        """
        Trim the hierarchy to revert all labels to their coarse parent labels from a
        defined list of labels.

        adata: anndata.AnnData, adata to store the trimmed annotations under adata.obs[obs_key]
        coarse_labels: list, list of labels to which the hierarchy should be trimmed.
        obs_key: str, column label to store trimmed annotations under adata.obs[obs_key]

        returns: dict, containing the barcodes belonging to each coarse label.
        raises: ValueError, if a label in coarse_labels is not a node of the hierarchy.
        """
        import warnings
        import networkx as nx
        import anndata

        # Check if all coarse labels are in the graph
        for label in coarse_labels:
            if label not in self.graph.nodes:
                raise ValueError(f"Label '{label}' does not exist in the hierarchy.")
            if self.graph.nodes[label]['type'] != 'cell_type':
                warnings.warn(
                    f"Label '{label}' exists in the hierarchy but is not of type 'cell_type'. Skipping..."
                )

        # Create a dictionary to store the trimmed annotations
        trimmed_annotations = {}

        # Iterate over each coarse label and collect its upstream cell nodes
        for label in coarse_labels:
            if self.graph.nodes[label]['type'] == 'cell_type':
                # Get all granular (upstream) nodes
                granular_nodes = nx.ancestors(self.graph, label)
                granular_nodes.add(label)  # Include the label itself

                # Collect all cell nodes under the current coarse label
                cell_nodes = []
                for node in granular_nodes:
                    cell_edges = [edge for edge in self.graph.edges(node)
                                  if self.graph.nodes[edge[1]]['type'] == 'cell']
                    cell_nodes.extend([edge[1] for edge in cell_edges])

                # Store the cell nodes under the current coarse label
                trimmed_annotations[label] = cell_nodes

        # If adata is provided, add the trimmed annotations to adata.obs
        if isinstance(adata, anndata.AnnData):
            cell_nodes_inv = {}
            for k, v in trimmed_annotations.items():
                for i in v:
                    cell_nodes_inv[i] = k
            adata.obs[obs_key] = adata.obs_names.map(cell_nodes_inv)

        return trimmed_annotations

    def get_cells_for_cell_type(self, cell_type):
        """
        Retrieve all cells assigned directly to a specific cell type in the hierarchy.

        cell_type: str, name of the cell type node to query.

        returns: list, of cell barcodes assigned to the given cell type.
        raises: ValueError, if cell_type is not a node or not of type 'cell_type'.
        """
        # Check if the provided node is a valid cell type
        if cell_type not in self.graph.nodes:
            raise ValueError(f"Cell type '{cell_type}' does not exist in the hierarchy.")
        if self.graph.nodes[cell_type]['type'] != 'cell_type':
            raise ValueError(f"Node '{cell_type}' is not of type 'cell_type'.")

        # Retrieve all 'cell' nodes connected to the cell type node
        cell_edges = [
            edge for edge in self.graph.edges(cell_type)
            if self.graph.nodes[edge[1]]['type'] == 'cell'
        ]
        cells = [edge[1] for edge in cell_edges]
        return cells