Skip to content

cafe.data.WaypointWrapper

cafe.data.WaypointWrapper

Bases: FateWrapper

Wrapper for trajectory waypoints.

Source code in cafe/data/fate_waypoint_wrapper.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
class WaypointWrapper(FateWrapper):
    """Wrapper for trajectory waypoint"""

    def __init__(
        self,
        milestone_wrapper: MilestoneWrapper,
        name: str = "WaypointWrapper",
        n_waypoints: int = 200,
        # edge length transform function
        transform: Callable[[float], float] = lambda x: x,
        resolution: float = None,
    ):
        """Initialize the WaypointWrapper class.

        Args:
            milestone_wrapper (MilestoneWrapper): MlestoneWrapper object for trajectory
            name (str, optional): name of the wrapper.
            n_waypoints (int, optional): num of waypoint.
            transform (_type_, optional): transform function for milestone network edge length.
            resolution (float, optional): resolution.
        """
        self.id = random_time_string(name)
        # need to be deleted after __init__ function
        self.milestone_wrapper = milestone_wrapper
        self._select_waypoints(n_waypoints, transform, resolution)
        # del self.milestone_wrapper  # delete the attribute to save memory

    def _select_waypoints(
        self,
        n_waypoints: int = 200,
        # edge length transform function
        transform: Callable[[float], float] = lambda x: x,
        resolution: float = None,
    ) -> None:
        """select waypoints base milestone network edge length and resolution parameter

        ref: pydynverse/wrap/wrap_add_waypoints.select_waypoints

        Args:
            n_waypoints (int, optional): num of waypoint.
            transform (_type_, optional): transform function for milestone network edge length.
            resolution (float, optional): resolution.
        """
        mr = self.milestone_wrapper

        if resolution is None:
            # compute resolution automaticall based on the sum of milestone network length after transformation
            resolution = mr.milestone_network["length"].apply(lambda x: transform(x)).sum() / n_waypoints

        # percentage list construction and explode
        def waypoint_id_from_progressions_row(row):
            # get waypoint_id by considering a row comprehensively
            match row["percentage"]:
                case 0:
                    return f"MILESTONE_BEGIN_W{row['from']}_{row['to']}"
                case 1:
                    return f"MILESTONE_END_W{row['from']}_{row['to']}"
                case _:
                    return f"W{row.name+1}"  # waypoint id start from 1

        waypoint_progressions = mr.milestone_network.copy()
        # waypoint_progressions = waypoint_progressions[~(waypoint_progressions["from"] == waypoint_progressions["to"])] # remove self-loop edge
        waypoint_progressions["percentage"] = waypoint_progressions["length"].apply(lambda x: [i / x for i in np.arange(0, x, resolution)] + [1])
        waypoint_progressions = waypoint_progressions[["from", "to", "percentage"]]
        waypoint_progressions = waypoint_progressions.explode("percentage").reset_index(drop=True)
        waypoint_progressions["percentage"] = waypoint_progressions["percentage"].astype("float")
        waypoint_progressions["waypoint_id"] = waypoint_progressions.apply(waypoint_id_from_progressions_row, axis=1)
        self.waypoint_progressions = waypoint_progressions

        self.id_list = waypoint_progressions["waypoint_id"].unique().tolist()

        # progressions -> percentages
        waypoint_progressions_tmp = waypoint_progressions.copy()
        waypoint_progressions_tmp = waypoint_progressions_tmp.rename(columns={"waypoint_id": "cell_id"})  # reuse pre column name
        # tmp "cell_id" column name for MilestoneWrapper.reuse convert_progressions_to_milestone_percentages
        waypoint_milestone_percentages = MilestoneWrapper.convert_progressions_to_milestone_percentages(
            milestone_network=mr.milestone_network, progressions=waypoint_progressions_tmp
        ).rename(columns={"cell_id": "waypoint_id"})
        self.waypoint_milestone_percentages = waypoint_milestone_percentages

        self.waypoint_geodesic_distances = self._calculate_geodesic_distances().loc[waypoint_progressions["waypoint_id"]]

        waypoint_network = (
            waypoint_progressions.sort_values(by=["from", "to", "percentage"])
            .groupby(["from", "to"])
            .apply(
                lambda group: group.assign(
                    from_waypoint=group["waypoint_id"],
                    to_waypoint=group["waypoint_id"].shift(-1),
                )
            )
            .dropna()
            .reset_index(drop=True)
        )  # Sort in ascending percentage within the group. "lead" function get the next row, get None if is the last row in group
        waypoint_network = waypoint_network[["from_waypoint", "to_waypoint", "from", "to"]]
        waypoint_network.columns = ["from", "to", "from_milestone_id", "to_milestone_id"]
        self.waypoint_network = waypoint_network

        waypoints = waypoint_milestone_percentages.iloc[waypoint_milestone_percentages.groupby("waypoint_id")["percentage"].idxmax()].reset_index(
            drop=True
        )
        waypoints["milestone_id"] = waypoints.apply(
            lambda x: x["milestone_id"] if x["percentage"] == 1 else None, axis=1
        )  # if waypoint is not on milestone, the milestone_id=None
        waypoints = waypoints[["waypoint_id", "milestone_id"]]
        self.waypoints = waypoints

    def _calculate_geodesic_distances(self) -> pd.DataFrame:
        """Calculate geodesic distances between cells and waypoints/milestones

        overall idea:
            1. calculate the full path of the target point within each divergent region separately
            2. merge and calculate the distance on the overall graph

        ref: pydynverse/wrap/calculate_geodesic_distances.py

        Returns:
            pd.DataFrame: distances dataframe
        """
        # attribute in the MilestoneWrapper
        # don't affect the original milestone_network
        cell_id_list = self.milestone_wrapper.cell_id_list
        milestone_id_list = self.milestone_wrapper.id_list
        milestone_network = self.milestone_wrapper.milestone_network.copy()
        milestone_percentages = self.milestone_wrapper.milestone_percentages.copy()
        divergence_regions = self.milestone_wrapper.divergence_regions.copy()
        directed = self.milestone_wrapper.directed

        waypoint_id_list = self.id_list
        waypoint_milestone_percentages = self.waypoint_milestone_percentages

        milestone_percentages = pd.concat([milestone_percentages, waypoint_milestone_percentages.rename(columns={"waypoint_id": "cell_id"})])

        # rename all milestone ids to MILESTONE_ID
        def milestone_trafo_fun(x):
            x = str(x)
            if x.startswith("MILESTONE_"):
                return x
            return f"MILESTONE_{x}"

        milestone_network = milestone_network.copy()
        milestone_network["from"] = milestone_network["from"].apply(milestone_trafo_fun)
        milestone_network["to"] = milestone_network["to"].apply(milestone_trafo_fun)
        milestone_id_list = list(map(milestone_trafo_fun, milestone_id_list))
        milestone_percentages["milestone_id"] = milestone_percentages["milestone_id"].apply(milestone_trafo_fun)
        divergence_regions["milestone_id"] = divergence_regions["milestone_id"].apply(milestone_trafo_fun)

        # add an extra divergence area, where normal edges are also treated as divergence areas
        extra_divergences = milestone_network.copy()
        # remove self-loop edge in _select_waypoints function
        # extra_divergences = extra_divergences[~(extra_divergences["from"] == extra_divergences["to"])] # remove self-loop edge
        # extra_divergences = extra_divergences.query("from != to") # query is more elegant, but from is a key for python
        # in_divergence determines whether the current edge is within the existing divergence region
        divergence_regions_set_list = divergence_regions.groupby("divergence_id")["milestone_id"].apply(set).tolist()

        def is_milestone_in_divergence(milestone_set, divergence_regions_set_list):
            for divergence_regions_set in divergence_regions_set_list:
                if milestone_set.issubset(divergence_regions_set):
                    return True
            return False

        extra_divergences["in_divergence"] = extra_divergences.apply(
            lambda x: is_milestone_in_divergence({x["from"], x["to"]}, divergence_regions_set_list), axis=1
        )
        extra_divergences = extra_divergences[~extra_divergences["in_divergence"]]  # only reserve the new divergence area
        extra_divergences["divergence_id"] = extra_divergences.apply(lambda x: f"{x['from']}__{x['to']}", axis=1)
        extra_divergences = pd.concat(
            [
                # add new columns: milestone_id, is_start
                extra_divergences.assign(milestone_id=extra_divergences["from"], is_start=True),
                extra_divergences.assign(milestone_id=extra_divergences["to"], is_start=False),
            ]
        )[["divergence_id", "milestone_id", "is_start"]]

        # merge divergence regions
        divergence_regions = pd.concat([divergence_regions, extra_divergences]).reset_index(drop=True)
        divergence_ids = divergence_regions["divergence_id"].unique()

        # NetworkX for related data from edge DataFrame
        milestone_graph = nx.from_pandas_edgelist(milestone_network, source="from", target="to", edge_attr="length")
        divergence_regions["is_start"] = divergence_regions["is_start"].astype(bool)  # ensure "is_start" column is bool

        # 1. calculate inner-divergence distances separately
        # calculate the distance between cells within the divergent
        def calc_divergence_inner_distance_df(did):
            dir = divergence_regions[divergence_regions["divergence_id"] == did]
            # starting point of the region is milestone_id
            mid = dir[dir["is_start"]]["milestone_id"].tolist()
            # milestone_id of all milestones in the divergence
            tent = dir["milestone_id"].tolist()
            tent_distances = pd.DataFrame(
                index=mid, columns=tent, data=np.zeros((len(mid), len(tent)))
            )  # The distance from the starting point within the region to all milestones
            # extract corresponding edges from the graph
            for i in mid:
                for j in tent:
                    if i == j:
                        tent_distances.loc[i, j] = 0
                    else:
                        tent_distances.loc[i, j] = milestone_graph.edges[(i, j)]["length"]
            # find cell_id of relevant points by reusing is_milestone_in_divergence
            relevant_pct_cell_id_list = milestone_percentages.groupby("cell_id")["milestone_id"].apply(
                lambda x: is_milestone_in_divergence(set(x), [set(tent)])
            )
            relevant_pct_cell_id_list = relevant_pct_cell_id_list[relevant_pct_cell_id_list].index.to_list()
            relevant_pct = milestone_percentages[milestone_percentages["cell_id"].apply(lambda x: x in relevant_pct_cell_id_list)]
            if relevant_pct.shape[0] <= 1:
                return None

            scaled_dists = relevant_pct.copy()
            # scaled_dists["dist"] = scaled_dists.apply(lambda x: x["percentage"] * tent_distances.loc[mid, x["milestone_id"]], axis=1)
            scaled_dists["dist"] = scaled_dists.apply(
                lambda x: x["percentage"] * tent_distances.loc[mid, x["milestone_id"]].values.flatten()[0], axis=1
            )  # fix for self-loop waypoint

            tent_distances_long = tent_distances.melt(var_name="from", value_name="length")  # wide data to long data
            tent_distances_long["to"] = tent_distances_long["from"]

            pct_mat = (
                pd.concat(
                    [
                        scaled_dists[["cell_id", "milestone_id", "dist"]].rename(columns={"cell_id": "from", "milestone_id": "to", "dist": "length"}),
                        tent_distances_long,
                    ]
                )
                .drop_duplicates()
                .pivot(index="from", columns="to", values="length")
                .fillna(0)
            )  # (n_cell+n_milestone+n_waypoint)*n_milestone, long data to wide data, "from" is index

            wp_cells = list(set(pct_mat.index) & set(waypoint_id_list))

            if directed:
                # TODO: directed graph
                pass

            distances = pairwise_distances(pct_mat, pct_mat.loc[wp_cells + tent], metric="manhattan")
            distances = pd.DataFrame(index=pct_mat.index, columns=wp_cells + tent, data=distances)
            distances = distances.reset_index().melt(id_vars="from", var_name="to", value_name="length")  # wide data to long data
            distances = distances[~(distances["from"] == distances["to"])]
            return distances

        cell_in_tent_distances = pd.concat([calc_divergence_inner_distance_df(did) for did in divergence_ids])

        if directed:
            # TODO: directed graph
            pass

        # NOTE: 2. merge calculation(use igraph to accelerate compared to networkx)
        # select the shortest distance mode for subsequent directed graphs
        if directed or directed == "forward":
            mode = "out"
        elif directed == "reverse":
            mode = "in"
        else:
            mode = "all"
        # extract the shortest edge after merging, currently has little effect, may be useful for the ring graph
        edgelist_df = pd.concat([milestone_network, cell_in_tent_distances]).groupby(["from", "to"]).agg({"length": "min"}).reset_index()
        # merge two graph to one graph
        gr = ig.Graph.TupleList(edgelist_df.values, edge_attrs=["length"])
        # TODO: unconnected graph may corrupt
        # print("=========================")
        # print("edgelist_df:\n", edgelist_df)
        # edgelist_df.to_csv("test.csv")
        # # print("gr:\n", gr)
        # print("waypoint_id_list:\n", waypoint_id_list)
        # print("cell_id_list:\n", cell_id_list)

        # for isolated milestones and cells which not shown in did, set shortest path = None
        out = pd.DataFrame(index=waypoint_id_list, columns=cell_id_list)
        valid_nodes = set(edgelist_df["from"].unique()) | set(edgelist_df["from"].unique())
        valid_waypoint_id_list = list(set(waypoint_id_list) & valid_nodes)
        valid_cell_id_list = list(set(cell_id_list) & valid_nodes)

        shortest_paths = gr.shortest_paths(source=valid_waypoint_id_list, target=valid_cell_id_list, weights="length", mode=mode)
        out.loc[valid_waypoint_id_list, valid_cell_id_list] = shortest_paths

        # # TODO: filter cells
        # cell_ids_filtered_list = []
        # if len(cell_ids_filtered_list) > 0:
        #     pass

        return out.loc[waypoint_id_list, cell_id_list].astype(float)

__init__(milestone_wrapper, name='WaypointWrapper', n_waypoints=200, transform=lambda x: x, resolution=None)

Initialize the WaypointWrapper class.

Parameters:

Name Type Description Default
milestone_wrapper MilestoneWrapper

MilestoneWrapper object describing the trajectory.

required
name str

name of the wrapper.

'WaypointWrapper'
n_waypoints int

Target number of waypoints; only used to derive the resolution when resolution is None.

200
transform Callable[[float], float]

transform function for milestone network edge length.

lambda x: x
resolution float

Spacing between consecutive waypoints along an edge; computed from n_waypoints when None.

None
Source code in cafe/data/fate_waypoint_wrapper.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    milestone_wrapper: MilestoneWrapper,
    name: str = "WaypointWrapper",
    n_waypoints: int = 200,
    # edge length transform function
    transform: Callable[[float], float] = lambda x: x,
    resolution: float | None = None,
):
    """Initialize the WaypointWrapper class.

    Args:
        milestone_wrapper (MilestoneWrapper): MilestoneWrapper object describing the trajectory.
        name (str, optional): name of the wrapper, used to build ``self.id``.
        n_waypoints (int, optional): target number of waypoints; only used to derive the
            resolution when ``resolution`` is None.
        transform (Callable[[float], float], optional): transform applied to milestone network
            edge lengths when deriving the automatic resolution.
        resolution (float, optional): spacing between consecutive waypoints along an edge.
    """
    self.id = random_time_string(name)
    # read by _select_waypoints; TODO: could be deleted afterwards to save memory (currently kept)
    self.milestone_wrapper = milestone_wrapper
    self._select_waypoints(n_waypoints, transform, resolution)