# Source code for ncem.pl.type_couplings

import re

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd


def _get_edge_weights(
    adata,
    alpha,
    pvals_key,
    params_key,
):
    """
    Plot cluster frequencies.

    Args:
        adata: AnnData instance with data and annotation.
        alpha:
        pvals_key:
        params_key:

    """
    de_genes = adata.varm[pvals_key].T < alpha
    de_genes = pd.DataFrame(de_genes.sum(axis=1), columns=["de_genes"])

    de_genes["index"] = [re.search("index_(.*):", de_genes.iloc[i].name).group(1) for i in range(de_genes.shape[0])]
    de_genes["neighbor"] = [
        re.search(":neighbor_(.*)", de_genes.iloc[i].name).group(1) for i in range(de_genes.shape[0])
    ]
    if params_key:
        magnitude = pd.DataFrame(np.log10((adata.varm[params_key].T ** 2).sum(axis=1)), columns=["magnitude"])
        edge_weights = de_genes.join(magnitude)
    else:
        edge_weights = de_genes
    return edge_weights


def _get_graph(x, edge_width_scale, edge_type, clip_edges):
    """
    Build the directed coupling graph and its drawing parameters.

    Args:
        x: Edge table (pandas DataFrame) with the columns ``"index"`` (edge
            target), ``"neighbor"`` (edge source), ``"de_genes"`` and,
            optionally, ``"magnitude"``.
        edge_width_scale: Factor multiplied onto the chosen edge attribute to
            obtain the line widths.
        edge_type: Edge attribute used for the line widths, ``"magnitude"`` or
            ``"de_genes"``.
        clip_edges: Only edges whose ``"de_genes"`` value is strictly greater
            than this get a width.

    Returns:
        Tuple ``(G, pos, nodes, width)``: the directed graph without
        self-loops, its circular layout, the unique values of ``x["index"]``
        and the list of widths of the retained edges.

    Raises:
        ValueError: If ``edge_type`` is not supported, or if it is
            ``"magnitude"`` while ``x`` has no ``"magnitude"`` column.
    """
    supported_edge_types = ("magnitude", "de_genes")
    if edge_type not in supported_edge_types:
        # Without this check no width list would exist for an unknown type.
        raise ValueError(f"edge_type must be one of {supported_edge_types}, got {edge_type!r}")

    # "magnitude" is absent when the edge table was built without parameters,
    # so it is only attached to the edges when it is actually available.
    has_magnitude = "magnitude" in x.columns
    if edge_type == "magnitude" and not has_magnitude:
        raise ValueError("edge_type='magnitude' requires a 'magnitude' column in the edge table")
    edge_attr = ["magnitude", "de_genes"] if has_magnitude else ["de_genes"]

    G = nx.from_pandas_edgelist(x, target="index", source="neighbor", edge_attr=edge_attr, create_using=nx.DiGraph())
    G.remove_edges_from(nx.selfloop_edges(G))
    nodes = np.unique(x["index"])
    pos = nx.circular_layout(G)

    # adjust edge scaling; the clipping always refers to the "de_genes" count
    width = [
        attrs[edge_type] * edge_width_scale
        for _, _, attrs in G.edges(data=True)
        if attrs["de_genes"] > clip_edges
    ]

    return G, pos, nodes, width


def _draw_graph(G, pos, nodes, width, ax, palette, clip_edges):
    """
    Draw the coupling graph with colored nodes and curved arrows.

    Args:
        G: networkx graph whose edges carry a ``"de_genes"`` attribute.
        pos: Node positions passed on to ``nx.draw_networkx``.
        nodes: Node names; the i-th name is colored with ``palette[i]``.
        width: Line widths of the drawn edges.
        ax: Matplotlib axes to draw on.
        palette: Indexable collection of colors, at least as long as ``nodes``.
        clip_edges: Only edges whose ``"de_genes"`` value is strictly greater
            than this are drawn.

    """
    # Store the color on the graph so that callers can read it back later.
    color_by_node = {name: palette[position] for position, name in enumerate(nodes)}
    nx.set_node_attributes(G, color_by_node, "color")
    assigned_colors = nx.get_node_attributes(G, "color")

    kept_edges = []
    for source, target, attrs in G.edges(data=True):
        if attrs["de_genes"] > clip_edges:
            kept_edges.append((source, target))

    draw_options = {
        "with_labels": False,
        "edgelist": kept_edges,
        "width": width,
        "arrowstyle": "-|>",
        "node_color": list(assigned_colors.values()),
        "ax": ax,
        "connectionstyle": "arc3, rad = 0.1",
    }
    nx.draw_networkx(G, pos, **draw_options)


def circular(
    adata,
    alpha,
    scale_edge,
    pvals_key: str = "ncem_fdr_pvals",
    params_key: str = "ncem_params",
    figsize=(10, 5),
    edge_type: str = "magnitude",
    clip_edges: int = 0,
):
    """
    Plot the coupling graph on a circular layout next to a color legend.

    The left panel shows the graph, the right panel a legend mapping node
    names to colors. The figure is shown with ``plt.show()``.

    Args:
        adata: AnnData instance with data and annotation; the tables under
            ``adata.varm[pvals_key]`` and ``adata.varm[params_key]`` are used.
        alpha: Threshold below which an entry of the p-value table is counted.
        scale_edge: Factor multiplied onto the edge attribute to obtain the
            line widths.
        pvals_key: Key of the p-value table in ``adata.varm``.
        params_key: Key of the parameter table in ``adata.varm``.
        figsize: Size of the matplotlib figure.
        edge_type: Edge attribute used for the line widths, ``"magnitude"`` or
            ``"de_genes"``.
        clip_edges: Only edges whose ``"de_genes"`` value is strictly greater
            than this are drawn.

    Returns:
        The edge weight table computed by ``_get_edge_weights``.
    """
    from matplotlib import rcParams
    from scanpy.plotting import palettes

    edge_weights = _get_edge_weights(adata, alpha, pvals_key, params_key)
    G, pos, nodes, width = _get_graph(edge_weights, scale_edge, edge_type, clip_edges)

    length = len(nodes)
    # check if default matplotlib palette has enough colors
    if len(rcParams["axes.prop_cycle"].by_key()["color"]) >= length:
        cc = rcParams["axes.prop_cycle"]()
        palette = [next(cc)["color"] for _ in range(length)]
    elif length <= 20:
        palette = palettes.default_20
    elif length <= 28:
        palette = palettes.default_28
    elif length <= len(palettes.default_102):  # 103 colors
        palette = palettes.default_102
    else:
        # no palette is large enough: fall back to one color for every node
        palette = ["grey" for _ in range(length)]

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=figsize)
    ax[0].axis("off")
    ax[1].axis("off")
    _draw_graph(G, pos, nodes, width, ax=ax[0], palette=palette, clip_edges=clip_edges)

    node_color = nx.get_node_attributes(G, "color")
    for node_name, color in node_color.items():
        # make dummy scatterplot to generate labels
        ax[1].scatter([], [], color=color, label=node_name)
    ax[1].legend(loc="center left", frameon=False)

    plt.tight_layout()
    plt.show()
    return edge_weights


def circular_rotated_labels(
    adata,
    alpha,
    scale_edge,
    pvals_key: str = "ncem_fdr_pvals",
    params_key: str = "ncem_params",
    figsize=(10, 5),
    edge_type: str = "magnitude",
    clip_edges: int = 0,
    text_space: float = 1.15,
):
    """
    Plot the coupling graph on a circle with rotated node labels around it.

    Nodes are placed on the unit circle in sorted order and every label is
    moved outward and rotated to the angle of its node. The figure is shown
    with ``plt.show()``; nothing is returned.

    Args:
        adata: AnnData instance with data and annotation; the tables under
            ``adata.varm[pvals_key]`` and ``adata.varm[params_key]`` are used.
        alpha: Threshold below which an entry of the p-value table is counted.
        scale_edge: Factor multiplied onto the edge attribute to obtain the
            line widths.
        pvals_key: Key of the p-value table in ``adata.varm``.
        params_key: Key of the parameter table in ``adata.varm``.
        figsize: Size of the matplotlib figure.
        edge_type: Edge attribute used for the line widths, ``"magnitude"`` or
            ``"de_genes"``.
        clip_edges: Only edges whose ``"de_genes"`` value is strictly greater
            than this are drawn.
        text_space: Radial offset of the labels; a label is centred at
            ``text_space`` plus half of its own width (in data units).
    """
    import logging

    from matplotlib import rcParams
    from scanpy.plotting import palettes

    edge_weights = _get_edge_weights(adata, alpha, pvals_key, params_key)
    G, pos, nodes, width = _get_graph(edge_weights, scale_edge, edge_type, clip_edges)

    length = len(nodes)
    # check if default matplotlib palette has enough colors
    if len(rcParams["axes.prop_cycle"].by_key()["color"]) >= length:
        cc = rcParams["axes.prop_cycle"]()
        palette = [next(cc)["color"] for _ in range(length)]
    elif length <= 20:
        palette = palettes.default_20
    elif length <= 28:
        palette = palettes.default_28
    elif length <= len(palettes.default_102):  # 103 colors
        palette = palettes.default_102
    else:
        palette = ["grey" for _ in range(length)]
        # The previous message used names that do not exist in this module
        # and therefore failed with a NameError instead of logging.
        logging.getLogger(__name__).info(
            "%d node types exceed the largest available palette (%d colors). "
            "Uniform 'grey' color will be used for all nodes.",
            length,
            len(palettes.default_102),
        )

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    ax.axis("off")

    nx.set_node_attributes(G, {node: palette[i] for i, node in enumerate(nodes)}, "color")
    node_color = nx.get_node_attributes(G, "color")
    selected_edges = [(u, v) for u, v, e in G.edges(data=True) if e["de_genes"] > clip_edges]

    # Labels are created at the layout positions first; they are moved to
    # their final place once their rendered size is known.
    description = nx.draw_networkx_labels(G, pos, font_size=17, ax=ax)

    # Place the nodes on the unit circle in sorted order and remember the
    # angle of each node for its label.
    n = len(nodes)
    angle_dict = {}
    pos = {}
    for i, node in enumerate(sorted(G.nodes())):
        theta = 2.0 * np.pi * i / n
        pos[node] = (np.cos(theta), np.sin(theta))
        angle_dict[node] = theta

    r = fig.canvas.get_renderer()
    trans = ax.transData.inverted()
    for node, t in description.items():
        # convert the label's bounding box from display to data coordinates
        bb = t.get_window_extent(renderer=r)
        bbdata = bb.transformed(trans)
        radius = text_space + bbdata.width / 2.0
        position = (radius * np.cos(angle_dict[node]), radius * np.sin(angle_dict[node]))
        t.set_position(position)
        # radians -> degrees
        t.set_rotation(angle_dict[node] * 360.0 / (2.0 * np.pi))
        t.set_clip_on(False)

    nx.draw_networkx(
        G,
        pos,
        with_labels=False,
        edgelist=selected_edges,
        width=width,
        arrowstyle="-|>",
        node_color=list(node_color.values()),
        ax=ax,
        connectionstyle="arc3, rad = 0.1",
    )

    plt.tight_layout()
    plt.show()