18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def _milestone_graph_layout(milestone_wrapper):
    """Build the milestone graph of one model and lay it out with graphviz ``dot``.

    Args:
        milestone_wrapper: milestone wrapper returned by ``fadata.get_milestone_wrapper``.

    Returns:
        tuple: ``(G, pos)`` where ``G`` is an ``nx.DiGraph`` if the wrapper is
        directed (else ``nx.Graph``) containing every milestone id, and ``pos``
        maps each node to its layout position.
    """
    G = nx.from_pandas_edgelist(
        milestone_wrapper.milestone_network,
        source="from",
        target="to",
        edge_attr=True,
        create_using=nx.DiGraph if milestone_wrapper["directed"] else nx.Graph,
    )
    # Milestones without any edge are absent from the edge list; add them so they
    # are laid out too (already-present nodes are left unchanged).
    G.add_nodes_from(milestone_wrapper.id_list)
    return G, nx.nx_agraph.graphviz_layout(G, prog="dot")


def _mix_milestone_embedding(milestone_percentages, milestone_pos):
    """Position every cell as the percentage-weighted sum of its milestones' positions.

    Args:
        milestone_percentages (pd.DataFrame): long table with ``cell_id``,
            ``milestone_id`` and ``percentage`` columns.
        milestone_pos (dict): milestone id -> layout position.

    Returns:
        pd.DataFrame: one row per ``cell_id``, one column per layout dimension.
    """
    milestone_emb_df = pd.DataFrame(milestone_pos).T  # rows: milestones, columns: dimensions
    weighted = (
        milestone_emb_df.loc[milestone_percentages["milestone_id"].to_numpy()]
        .set_axis(milestone_percentages["cell_id"].to_numpy(), axis=0)
        .mul(milestone_percentages["percentage"].to_numpy(), axis=0)
    )
    # A single vectorized group-sum instead of a Python-level apply per cell.
    return weighted.groupby(level=0).sum()


def _aligned_cell_embedding(fadata, cell_emb_df, model_name):
    """Return ``cell_emb_df`` as a float array ordered like ``fadata.obs.index``.

    Cells of ``fadata`` without an embedding row are placed at the origin (0.0)
    and a warning is logged; rows for cells not in ``fadata`` are dropped.
    """
    obs_index = fadata.obs.index
    missing_cells = obs_index.difference(cell_emb_df.index)
    extra_cells = cell_emb_df.index.difference(obs_index)
    if len(missing_cells) > 0 or len(extra_cells) > 0:
        logger.warning(
            f"cell ids are mismatch between fadata.index and cell_emb_df '{model_name}', missing cells: {set(missing_cells)}."
        )
    return cell_emb_df.reindex(obs_index).fillna(0.0).to_numpy(dtype=float)


def _remove_legend(ax):
    """Remove the legend of ``ax`` if it has one; no-op otherwise."""
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


def plot_graph(
    fadata: FateAnnData,
    model_name: str | Sequence[str] | None = None,
    color: str | Sequence[str] | None = None,
    layout_by_row: str = "color",
    nx_draw_kwargs: dict | None = None,
    recompute_milestone_embedding: bool = True,
    save: bool | str | None = None,
    **sc_pl_embedding_kwargs,
):
    """Plot the milestone network (DAG) of trajectory model(s) on top of a cell embedding.

    Milestones are laid out with graphviz ``dot``; each cell is placed at the
    percentage-weighted mix of its milestones' positions. One subplot is drawn
    per (model, color) pair.

    Args:
        fadata (FateAnnData): FateAnnData object with trajectory.
        model_name (str | Sequence[str], optional): model name(s); defaults to
            ``fadata.model_name``.
        color (str | Sequence[str], optional): color key(s) passed to
            ``sc.pl.embedding``; defaults to ``fadata.prior_information["cluster"]``.
            The special value ``"milestone"`` colors each cell with the model's
            ``cell_color_dict`` (and writes ``fadata.obs["milestone"]`` and
            ``fadata.uns["milestone_colors"]``, which are left in place).
        layout_by_row (str, optional): ``"color"`` puts colors on rows and models
            on columns, ``"model"`` the reverse. Overridden when only one model or
            only one color is given.
        nx_draw_kwargs (dict, optional): extra keyword arguments for ``nx.draw``;
            they override this function's drawing defaults.
        recompute_milestone_embedding (bool, optional): whether to recompute the
            milestone embedding. Caching is not implemented yet, so the embedding
            is currently always recomputed.
        save (bool | str, optional): ``True`` saves to
            ``.cafe/<fadata.id>/img/graph_...png``; a string is used as the path.
            ``None``/``False`` do not save. Missing parent directories are created.
        **sc_pl_embedding_kwargs: extra keyword arguments for ``sc.pl.embedding``.
            ``figsize`` is consumed for the whole figure and ``title`` is replaced
            per subplot.

    Returns:
        numpy.ndarray: 2-D array (n_rows, n_cols) of matplotlib axes.

    Raises:
        ValueError: if no color can be determined, the model/color lists are
            empty, or ``layout_by_row`` is neither ``"model"`` nor ``"color"``.
    """
    from pathlib import Path

    basis = "_milestone_network_emb"  # temporary obsm key, removed after each model
    nx_draw_kwargs = {} if nx_draw_kwargs is None else nx_draw_kwargs

    if model_name is None:
        model_name = fadata.model_name
    if color is None:
        color = fadata.prior_information.get("cluster")
    if color is None:
        raise ValueError("no color given and fadata.prior_information has no 'cluster' entry.")
    model_name_list = [model_name] if isinstance(model_name, str) else list(model_name)
    color_list = [color] if isinstance(color, str) else list(color)
    if not model_name_list or not color_list:
        raise ValueError("model_name and color must not be empty.")

    if len(model_name_list) == 1:
        layout_by_row = "model"  # only one model as row
    if len(color_list) == 1:
        layout_by_row = "color"  # only one color as row
    if layout_by_row == "model":
        row_list, col_list = model_name_list, color_list
    elif layout_by_row == "color":
        row_list, col_list = color_list, model_name_list
    else:
        raise ValueError(f"layout_by_row must be 'model' or 'color', got {layout_by_row!r}.")
    n_rows, n_cols = len(row_list), len(col_list)
    # figsize applies to the whole grid, so it must not reach sc.pl.embedding
    figsize = sc_pl_embedding_kwargs.pop("figsize", (7 * n_cols, 5 * n_rows))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    if not recompute_milestone_embedding:
        # TODO: read a cached milestone embedding from fadata once it is stored there.
        logger.debug("cached milestone embedding is not supported yet, recomputing.")

    for i, model_name in enumerate(model_name_list):
        milestone_wrapper = fadata.get_milestone_wrapper(model_name=model_name)
        logger.debug(f"calculate new milestone embedding for model_name:{model_name}.")
        G, milestone_emb_dict = _milestone_graph_layout(milestone_wrapper)
        cell_emb_df = _mix_milestone_embedding(milestone_wrapper.milestone_percentages, milestone_emb_dict)
        divergence_regions = milestone_wrapper.divergence_regions
        milestone_color_dict = milestone_wrapper["milestone_color_dict"]
        # User-supplied kwargs override these defaults instead of raising a duplicate-keyword error.
        draw_kwargs = {
            "with_labels": True,
            "node_color": [milestone_color_dict[node] for node in G.nodes],
            "width": 5,
            "edge_color": "gray",
            "arrowstyle": "simple",
            "arrowsize": 30,
            **nx_draw_kwargs,
        }

        fadata.obsm[basis] = _aligned_cell_embedding(fadata, cell_emb_df, model_name)
        try:
            for j, color in enumerate(color_list):
                ax = axes[i, j] if layout_by_row == "model" else axes[j, i]
                if color == "milestone":
                    # one category per cell, so every cell gets its own mixed milestone color
                    cell_color_key = "milestone"
                    missing_cell_color = "#808080"
                    cell_color_dict = milestone_wrapper["cell_color_dict"]
                    if len(cell_color_dict) != fadata.n_obs:
                        logger.warning(f"milestone cell color length not equal to cell number! set missing color as '{missing_cell_color}'.")
                    fadata.obs[cell_color_key] = pd.Categorical(fadata.obs.index, categories=fadata.obs.index.tolist())
                    fadata.uns[f"{cell_color_key}_colors"] = [
                        cell_color_dict[cell_id] if cell_id in cell_color_dict else missing_cell_color
                        for cell_id in fadata.obs.index
                    ]
                # zorder: 1: line, 2: cell (scanpy), 3: milestone
                embedding_kwargs = {
                    **sc_pl_embedding_kwargs,
                    "title": f"{fadata.get_parsed_model_name(model_name)}({color})",
                }
                sc.pl.embedding(fadata, basis=basis, color=color, show=False, zorder=2, ax=ax, **embedding_kwargs)
                # The per-cell "milestone" legend is useless; with colors on rows the legend
                # is identical across a row, so keep it only in the last model column.
                if color == "milestone" or (layout_by_row == "color" and i < len(model_name_list) - 1):
                    _remove_legend(ax)
                nx.draw(G, milestone_emb_dict, ax=ax, **draw_kwargs)
                if divergence_regions.shape[0] > 0:
                    plot_divergence_region(divergence_regions, milestone_emb_dict, ax=ax)
        finally:
            # never leave the temporary embedding on fadata, even if plotting fails
            del fadata.obsm[basis]

    if save:
        if save is True:
            save = f".cafe/{fadata.id}/img/graph_{basis}_{'_'.join(model_name_list)}.png"
        Path(save).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save)
        logger.debug(f"save trajectory plot to '{save}'")
    return axes