Skip to content

cafe.plot.plot_graph

cafe.plot.plot_graph

plot_graph(fadata, model_name=None, color=None, layout_by_row='color', nx_draw_kwargs={}, recompute_milestone_embedding=True, save=None, **sc_pl_embedding_kwargs)

Plot a DAG based on the milestone network and show the cell embedding.

Parameters:

Name Type Description Default
fadata FateAnnData

FateAnnData object with trajectory.

required
model_name str | Sequence[str]

Model name(s) to plot; when omitted, `fadata.model_name` is used.

None
color str | Sequence[str]

Color key(s); when omitted, the `cluster` entry of `fadata.prior_information` is used.

None
layout_by_row str

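Which variable indexes the subplot rows: 'model' or 'color' (the other one indexes the columns). Forced to 'model' when a single model is given and to 'color' when a single color is given.

'color'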
nx_draw_kwargs dict

additional keyword arguments for networkx draw.

required
sc_pl_embedding_kwargs dict

additional keyword arguments for the scanpy embedding plot; `figsize` is used for the whole figure and `title` is overwritten per subplot.

{}
recompute_milestone_embedding bool

whether to recompute the milestone embedding (the embedding is currently always recomputed, because a cached embedding is not yet stored).

True
save bool | str

path to save the plot; `True` saves to a default path under `.cafe/<fadata.id>/img/`.

None
sc_pl_embedding_kwargs dict

additional keyword arguments for scanpy embedding plot.

required

Returns:

Name Type Description
axes

axes

Source code in cafe/plot/plot_graph.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def plot_graph(
    fadata: FateAnnData,
    model_name: str | Sequence[str] | None = None,
    color: str | Sequence[str] | None = None,
    layout_by_row: str = "color",
    nx_draw_kwargs: dict | None = None,
    recompute_milestone_embedding: bool = True,
    save: bool | str | None = None,
    **sc_pl_embedding_kwargs,
):
    """Plot the milestone network as a graph and overlay the cell embedding.

    For every model a graphviz ``dot`` layout of the milestone network is
    computed. Each cell is positioned at the percentage-weighted sum of the
    positions of its milestones, drawn with ``sc.pl.embedding``, and the
    milestone graph is drawn on top with ``nx.draw``. One subplot is created
    per (model, color) pair.

    Args:
        fadata (FateAnnData): FateAnnData object with trajectory.
        model_name (str | Sequence[str], optional): model name(s). Defaults to
            ``fadata.model_name``.
        color (str | Sequence[str], optional): color key(s). Defaults to the
            ``"cluster"`` entry of ``fadata.prior_information``. The special
            value ``"milestone"`` colors each cell with the model's
            ``cell_color_dict`` and writes ``fadata.obs["milestone"]`` and
            ``fadata.uns["milestone_colors"]``.
        layout_by_row (str, optional): ``"model"`` or ``"color"``; which
            variable indexes the subplot rows. Forced to ``"model"`` when a
            single model is given and to ``"color"`` when a single color is
            given (the latter wins if both apply).
        nx_draw_kwargs (dict, optional): additional keyword arguments for
            ``nx.draw``. They override the built-in drawing defaults. The
            dict is not modified.
        recompute_milestone_embedding (bool, optional): whether to recompute
            the milestone embedding. Currently has no effect: the embedding
            is always recomputed because a cached one is not stored yet.
        save (bool | str, optional): path to save the figure to. ``True``
            uses a default path under ``.cafe/<fadata.id>/img/``; ``None`` or
            ``False`` disables saving.
        **sc_pl_embedding_kwargs: additional keyword arguments for
            ``sc.pl.embedding``. ``figsize`` is consumed for the whole figure
            and ``title`` is overwritten for every subplot.

    Returns:
        axes: 2D array of matplotlib axes with shape ``(n_rows, n_cols)``.

    Raises:
        ValueError: if ``layout_by_row`` is neither ``"model"`` nor ``"color"``
            and is not forced by a single model / single color.
    """
    # temporary obsm key; also part of the default file name when saving
    basis = "_milestone_network_emb"

    if model_name is None:
        model_name = fadata.model_name
    if color is None:
        color = fadata.prior_information.get("cluster")

    model_name_list = [model_name] if isinstance(model_name, str) else list(model_name)
    color_list = [color] if isinstance(color, str) else list(color)

    if len(model_name_list) == 1:
        layout_by_row = "model"  # only one model as row
    if len(color_list) == 1:
        layout_by_row = "color"  # only one color as row

    # create subplots
    if layout_by_row == "model":
        row_list, col_list = model_name_list, color_list
    elif layout_by_row == "color":
        row_list, col_list = color_list, model_name_list
    else:
        raise ValueError(f"layout_by_row must be 'model' or 'color', got '{layout_by_row}'.")
    n_rows = len(row_list)
    n_cols = len(col_list)
    figsize = sc_pl_embedding_kwargs.pop("figsize", (7 * n_cols, 5 * n_rows))  # replace sc plt figsize
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    # user supplied networkx kwargs take precedence over the drawing defaults
    draw_kwargs = {
        "with_labels": True,
        "width": 5,
        "edge_color": "gray",
        "arrowstyle": "simple",
        "arrowsize": 30,
        **(nx_draw_kwargs or {}),
    }

    # multiple model and color support
    for i, current_model in enumerate(model_name_list):
        milestone_wrapper = fadata.get_milestone_wrapper(model_name=current_model)  # extract milestone network
        milestone_id_list = milestone_wrapper.id_list
        milestone_network = milestone_wrapper.milestone_network
        milestone_percentages = milestone_wrapper.milestone_percentages
        divergence_regions = milestone_wrapper.divergence_regions
        is_directed = milestone_wrapper["directed"]
        milestone_color_dict = milestone_wrapper["milestone_color_dict"]

        # TODO: store the embedding in fadata so that recompute_milestone_embedding=False can reuse it
        logger.debug(f"calculate new milestone embedding for model_name:{current_model}.")
        G = nx.from_pandas_edgelist(
            milestone_network,
            source="from",
            target="to",
            edge_attr=True,
            create_using=nx.DiGraph if is_directed else nx.Graph,
        )
        # milestones without any edge are absent from the edge list and must be added explicitly
        G.add_nodes_from(set(milestone_id_list) - set(G.nodes))
        milestone_emb_dict = nx.nx_agraph.graphviz_layout(G, prog="dot")  # milestone positions
        milestone_emb_df = pd.DataFrame(milestone_emb_dict).T

        def mix_emb(mpg, emb_df=milestone_emb_df):
            # mix related milestone emb to get position for a cell
            mpg_emb = emb_df.loc[mpg["milestone_id"]]
            return mpg_emb.apply(lambda emb_dim: (emb_dim.array * mpg["percentage"].array)).sum()

        cell_emb_df = milestone_percentages.groupby("cell_id").apply(mix_emb)

        # align the cell positions with fadata.obs; cells without milestone percentages are placed at zero
        obs_index = fadata.obs.index
        fadata_index_set = set(obs_index)
        emb_index_set = set(cell_emb_df.index)
        if fadata.shape[0] != cell_emb_df.shape[0] or fadata_index_set != emb_index_set:
            missing_cell_set = fadata_index_set - emb_index_set
            logger.warning(f"cell ids are mismatch between fadata.index and cell_emb_df '{current_model}', missing cells: {missing_cell_set}.")
        fadata.obsm[basis] = cell_emb_df.reindex(obs_index, fill_value=0.0).to_numpy(dtype=float)

        try:
            for j, color_key in enumerate(color_list):
                if layout_by_row == "model":
                    ax = axes[i, j]  # row is model_name, col is color
                else:
                    ax = axes[j, i]  # row is color, col is model_name

                if color_key == "milestone":
                    # every cell is its own category so that it can get an individual color
                    cell_color_key = "milestone"
                    missing_cell_color = "#808080"
                    cell_color_dict = milestone_wrapper["cell_color_dict"]
                    if len(cell_color_dict) != fadata.n_obs:
                        logger.warning(f"milestone cell color length not equal to cell number! set missing color as '{missing_cell_color}'.")
                    fadata.obs[cell_color_key] = pd.Categorical(obs_index, categories=obs_index.tolist())
                    fadata.uns[f"{cell_color_key}_colors"] = [
                        cell_color_dict[cell_id] if cell_id in cell_color_dict else missing_cell_color
                        for cell_id in obs_index
                    ]

                # base scanpy embedding scatter plot
                # plot single str for color parameter
                # zorder: 1: line, 2: cell(scanpy), 3: milestone
                sc_pl_embedding_kwargs["title"] = f"{fadata.get_parsed_model_name(current_model)}({color_key})"  # add title for subplot
                sc.pl.embedding(fadata, basis=basis, color=color_key, show=False, zorder=2, ax=ax, **sc_pl_embedding_kwargs)

                # drop the per-cell legend of "milestone"; with colors as rows keep
                # the legend only on the last model column
                if color_key == "milestone" or (layout_by_row == "color" and i < len(model_name_list) - 1):
                    legend = ax.get_legend()
                    if legend is not None:
                        legend.remove()

                # the graph itself is identical for every color, but has to be drawn on every ax
                node_color = [milestone_color_dict[node] for node in G.nodes]
                nx.draw(G, milestone_emb_dict, **{"node_color": node_color, **draw_kwargs, "ax": ax})
                if divergence_regions.shape[0] > 0:
                    plot_divergence_region(divergence_regions, milestone_emb_dict, ax=ax)  # divergence region
        finally:
            # the embedding is only a temporary plotting aid; never leave it in fadata
            del fadata.obsm[basis]

    if save:  # None and False both mean "do not save"
        if save is True:
            save = f".cafe/{fadata.id}/img/graph_{basis}_{'_'.join(model_name_list)}.png"
        fig.savefig(save)
        logger.debug(f"save trajectory plot to '{save}'")
    return axes