Skip to content

Method List

cafe.method.function.cf_comp1.comp1(adata, repreprocess=True, basis='X_pca', recompute_basis=False, component=1)

Comp1: baseline for linear wrapper, extract an embedded component pseudotime method

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
repreprocess bool

Whether to preprocess the data.

True
basis str

The embedding name in .obsm.

'X_pca'
recompute_basis bool

Whether to recompute the embedding.

False
component int

The component number.

1

Returns:

Name Type Description
dict dict

A trajectory dict of linear wrapper.

Source code in cafe/method/function/cf_comp1.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@method_info(
    name="comp1",
    version="0.0.1",
    description="Comp1: baseline for linear wrapper, extract an embedded component pseudotime method",
    wrapper_type="linear",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def comp1(
    adata: ad.AnnData,
    repreprocess: bool = True,
    basis: str = "X_pca",
    recompute_basis: bool = False,
    component: int = 1,
) -> dict:
    """Comp1: baseline for linear wrapper, extract an embedded component pseudotime method

    One column of a low-dimensional embedding is returned as the pseudotime.
    If ``basis`` is already present in ``adata.obsm`` and ``recompute_basis``
    is False, it is used as-is. Otherwise the embedding is (re)computed with
    scanpy and read back from the ``X_<method>`` key that scanpy writes.

    Args:
        adata (ad.AnnData): The input AnnData object.
        repreprocess (bool, optional): Whether to preprocess the data. The
            pipeline is only run when ``recompute_basis`` is also True.
        basis (str, optional): The embedding name in .obsm, with or without
            the ``X_`` prefix (e.g. ``"X_pca"`` or ``"umap"``).
        recompute_basis (bool, optional): Whether to recompute the embedding.
        component (int, optional): The 1-based component number.

    Returns:
        dict: A trajectory dict of linear wrapper.

    Raises:
        ValueError: If ``component`` is smaller than 1.
    """
    if component < 1:
        # component is 1-based; 0 or negatives would silently wrap around to
        # the trailing columns of the embedding.
        raise ValueError(f"component must be >= 1, got {component}")

    # 1. preprocess
    embedding_method = basis[2:].lower() if basis.startswith("X_") else basis.lower()
    if repreprocess and recompute_basis:
        # stop at sc.pp.pca, or sc.pp.neighbors if other embedding method
        preprocess_pipeline(adata, style="scanpy", if_neighbors=embedding_method != "pca")

    # 2. execute method
    if recompute_basis or (basis not in adata.obsm):
        # TODO: phate, diffmap ...
        if embedding_method == "tsne":
            sc.tl.tsne(adata)
        elif embedding_method == "umap":
            sc.tl.umap(adata)
        elif embedding_method != "pca":
            # default use pca
            print(f"embedding method '{embedding_method}' is not available, use 'PCA' instead")
            embedding_method = "pca"
        # scanpy writes its embeddings to .obsm under "X_<method>", so read
        # the result from that key rather than from the user-supplied spelling.
        basis = f"X_{embedding_method}"
        if embedding_method == "pca" and basis not in adata.obsm:
            # PCA normally comes from preprocess_pipeline; compute it here
            # when preprocessing was skipped and no PCA is stored yet.
            sc.pp.pca(adata)

    # 3. extract results
    pseudotime = adata.obsm[basis][:, component - 1].tolist()

    # 4. save results
    trajectory_dict = {
        "wrapper_type": "linear",
        "pseudotime": pseudotime,
    }

    return trajectory_dict

cafe.method.function.cf_state_comp.state_comp(adata, repreprocess=True, n_comps=2, basis='X_pca', recompute_basis=False, pseudotime_index=1, wrapper_type='probability', cluster_key='clusters')

State_Comp: baseline for probability and lineage wrapper, state transition probability based on embedded components

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
repreprocess bool

Whether to preprocess the data.

True
n_comps int

The number of components.

2
basis str

The embedding name in .obsm.

'X_pca'
recompute_basis bool

Whether to recompute the embedding.

False
pseudotime_index int

The index of the component to use for pseudotime.

1
wrapper_type str

The type of wrapper to use.

'probability'

Returns:

Name Type Description
dict

A trajectory dict of probability or lineage wrapper.

Source code in cafe/method/function/cf_state_comp.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
@method_info(
    name="state_comp",
    version="0.0.1",
    description="State_Comp: baseline for probability and lineage wrapper, state transition probability based on embedded components",
    wrapper_type="probability",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def state_comp(
    adata: ad.AnnData,
    repreprocess: bool = True,
    n_comps: int = 2,
    basis: str = "X_pca",
    recompute_basis: bool = False,
    pseudotime_index: int = 1,
    wrapper_type: str = "probability",  # "probability" or "lineage"
    cluster_key: str = "clusters",
):
    """State_Comp: baseline for probability and lineage wrapper, state transition probability based on embedded components

    The first ``n_comps`` columns of an embedding are min-max scaled and then
    L1-normalised per cell, and the result is treated as end-state
    probabilities (one state per component).

    Args:
        adata (ad.AnnData): The input AnnData object.
        repreprocess (bool, optional): Whether to preprocess the data. The
            pipeline is only run when ``recompute_basis`` is also True.
        n_comps (int, optional): The number of components.
        basis (str, optional): The embedding name in .obsm, with or without
            the ``X_`` prefix.
        recompute_basis (bool, optional): Whether to recompute the embedding.
        pseudotime_index (int, optional): The index of the component to use for
            pseudotime. It is used as a 0-based column index into the scaled
            embedding.
        wrapper_type (str, optional): The type of wrapper to use,
            "probability" or "lineage". Any value other than "lineage" takes
            the probability path.
        cluster_key (str, optional): Returned unchanged under "cluster_key" in
            the lineage output.

    Returns:
        dict: A trajectory dict of probability or lineage wrapper.

    """
    # 1. preprocess
    embedding_method = basis[2:].lower() if basis.startswith("X_") else basis.lower()
    if repreprocess and recompute_basis:
        # stop at sc.pp.pca, or sc.pp.neighbors if other embedding method
        preprocess_pipeline(adata, style="scanpy", if_neighbors=embedding_method != "pca")
    cell_ids = adata.obs.index

    # 2. execute method
    # like comp1 method, but extract multiple components as multiple end states
    if recompute_basis or (basis not in adata.obsm):
        # TODO: phate, diffmap ...
        if embedding_method == "tsne":
            sc.tl.tsne(adata)
        elif embedding_method == "umap":
            sc.tl.umap(adata)
        elif embedding_method != "pca":
            # default use pca
            print(f"embedding method '{embedding_method}' is not available, use 'PCA' instead")
            embedding_method = "pca"
        # scanpy writes its embeddings to .obsm under "X_<method>".
        basis = f"X_{embedding_method}"
        if embedding_method == "pca" and basis not in adata.obsm:
            # PCA normally comes from preprocess_pipeline; compute it here
            # when preprocessing was skipped and no PCA is stored yet.
            sc.pp.pca(adata)

    # extract embedding results as state transition probabilities
    X_emb = adata.obsm[basis][:, :n_comps]
    X_emb_scaled = MinMaxScaler().fit_transform(X_emb)  # Normalization
    # Name the columns from the actual slice width: an embedding with fewer
    # than n_comps columns (e.g. a 2-D UMAP) would otherwise fail to build.
    comp_column_list = [f"comp_{i}" for i in range(1, X_emb.shape[1] + 1)]
    # The normalized embedding is used as the state transition probability, range of [0,1]
    end_state_probabilities = pd.DataFrame(
        columns=comp_column_list,
        data=normalize(X_emb_scaled, norm="l1"),  # l1 transform: each row sums to 1
        index=cell_ids,
    )
    end_state_probabilities["cell_id"] = cell_ids
    end_state_probabilities = end_state_probabilities[["cell_id"] + comp_column_list]

    # 3,4. extract and save results
    if wrapper_type == "lineage":
        # for lineage wrapper: probabilities only, without the cell_id column
        trajectory_dict = {
            "probability": end_state_probabilities[comp_column_list],
            "cluster_key": cluster_key,
        }
    else:
        # for probability wrapper
        pseudotime = X_emb_scaled[:, pseudotime_index]  # specified component for pseudotime
        trajectory_dict = {
            "end_state_probabilities": end_state_probabilities,
            "pseudotime": pseudotime,
        }

    trajectory_dict["wrapper_type"] = wrapper_type
    return trajectory_dict

cafe.method.function.cf_cluster_mst.cluster_mst(adata, repreprocess=True, basis='X_pca', recluster=False, cluster='clusters', distance_metric='euclidean')

Cluster MST: baseline for cluster wrapper, creating a Minimum Spanning Tree (MST) on cluster centers.

This method first clusters the cells (or uses existing clusters), calculates the center of each cluster in a given embedding, and then constructs a Minimum Spanning Tree (MST) connecting these centers to represent the trajectory backbone.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
repreprocess bool

Whether to run the preprocessing pipeline.

True
basis str

The embedding in .obsm to use for calculating cluster centers.

'X_pca'
recluster bool

If True, re-computes cell clusters using the Leiden algorithm.

False
cluster str

The key in adata.obs where cluster information is stored/saved.

'clusters'
distance_metric str

The distance metric to use for calculating distances between cluster centers.

'euclidean'

Returns:

Name Type Description
dict

A trajectory dict compatible with the 'cluster' wrapper, containing the milestone network and cluster assignments.

Source code in cafe/method/function/cf_cluster_mst.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@method_info(
    name="cluster_mst",
    version="0.0.1",
    description="Cluster MST: baseline for cluster wrapper, creating a Minimum Spanning Tree (MST) on cluster centers.",
    wrapper_type="cluster",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def cluster_mst(
    adata: ad.AnnData,
    repreprocess: bool = True,
    basis: str = "X_pca",
    recluster: bool = False,
    cluster: str = "clusters",
    distance_metric: Optional[Literal["euclidean", "cosine", "manhattan", "cityblock", "l1", "l2"]] = "euclidean",
):
    """Cluster MST: baseline for cluster wrapper, creating a Minimum Spanning Tree (MST) on cluster centers.

    Cells are grouped by cluster, each cluster is summarised by its mean
    position in the chosen embedding, and a minimum spanning tree over the
    pairwise distances between those means becomes the milestone network.

    Note:
        ``adata.obs`` is modified in place: its index is replaced with a
        0..n-1 range so that obs labels double as row positions in the
        embedding. The returned cluster Series carries that integer index.

    Args:
        adata (ad.AnnData): The input AnnData object.
        repreprocess (bool, optional): Whether to run the preprocessing pipeline.
        basis (str, optional): The embedding in `.obsm` to use for calculating cluster centers.
        recluster (bool, optional): If True, re-computes cell clusters using the Leiden
            algorithm; the "leiden" column is then used instead of ``cluster``.
        cluster (str, optional): The key in `adata.obs` where cluster information is stored/saved.
        distance_metric (str, optional): The distance metric to use for calculating distances between cluster centers.

    Returns:
        dict: A trajectory dict with ``wrapper_type`` "cluster", the undirected
        milestone network ("from", "to", "length", "directed") and the per-cell
        cluster assignments.
    """
    # 1. preprocess
    if repreprocess:
        # the pipeline is asked to go as far as computing neighbors
        preprocess_pipeline(adata, style="scanpy", if_neighbors=True)
    # positional obs index, so group labels can index the embedding rows directly
    adata.obs.reset_index(drop=True, inplace=True)
    embedding = adata.obsm[basis]

    # 2. execute method
    # (1) optionally re-cluster; the cluster centers act as milestones
    if recluster:
        sc.pp.neighbors(adata)
        sc.tl.leiden(adata)
        cluster = "leiden"

    # (2) mean embedding coordinates per cluster
    def _mean_position(group):
        return embedding[list(group.index)].mean(axis=0)

    center_series = adata.obs.groupby(cluster).apply(_mean_position)
    centers = pd.DataFrame(center_series.tolist(), index=center_series.index)
    milestone_ids = centers.index.tolist()
    cluster_milestones = adata.obs[cluster]

    # (3) pairwise distances between centers, reshaped from wide to long
    distance_matrix = pd.DataFrame(
        pairwise_distances(centers, metric=distance_metric),
        index=milestone_ids,
        columns=milestone_ids,
    )
    long_form = distance_matrix.unstack().reset_index()
    edge_table = pd.DataFrame(data=long_form.values, columns=["from", "to", "weight"])

    # (4) minimum spanning tree over the fully connected milestone graph
    milestone_graph = nx.from_pandas_edgelist(edge_table, source="from", target="to", edge_attr="weight")
    spanning_tree = nx.minimum_spanning_tree(milestone_graph, weight="weight")

    # 3. extract results
    milestone_network = nx.to_pandas_edgelist(spanning_tree).rename(
        columns={"source": "from", "target": "to", "weight": "length"}
    )
    milestone_network["directed"] = False

    # 4. save results
    return {
        "wrapper_type": "cluster",
        "milestone_network": milestone_network,
        "cluster": cluster_milestones,
    }

cafe.method.function.cf_projection_mst.projection_mst(adata, repreprocess=True, basis='X_pca', recluster=True, cluster='clusters', distance_metric='euclidean')

Projection MST: projects cells onto a Minimum Spanning Tree constructed from cluster centers.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
repreprocess bool

Whether to run the preprocessing pipeline.

True
basis str

The embedding in .obsm to use for calculating cluster centers and cell coordinates.

'X_pca'
recluster bool

If True, re-computes cell clusters using the Leiden algorithm.

True
cluster str

The key in adata.obs where cluster information is stored/saved.

'clusters'
distance_metric str

The distance metric to use for calculating distances between cluster centers.

'euclidean'

Returns: dict: A trajectory dict compatible with the 'projection' wrapper, containing the milestone network, embeddings, and cluster assignments.

Source code in cafe/method/function/cf_projection_mst.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
@method_info(
    name="projection_mst",
    version="0.0.1",
    description="Projection MST: projects cells onto a Minimum Spanning Tree constructed from cluster centers.",
    wrapper_type="projection",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def projection_mst(
    adata: ad.AnnData,
    repreprocess: bool = True,
    basis: str = "X_pca",
    recluster: bool = True,
    cluster: str = "clusters",
    distance_metric: Optional[Literal["euclidean", "cosine", "manhattan", "cityblock", "l1", "l2"]] = "euclidean",
):
    """Projection MST: projects cells onto a Minimum Spanning Tree constructed from cluster centers.

    Builds the same cluster-center minimum spanning tree as a milestone
    network and additionally returns the cell and milestone coordinates in
    the embedding.

    Note:
        ``adata.obs`` is modified in place: its index is replaced with a
        0..n-1 range so that obs labels double as row positions in the
        embedding. The returned frames/Series carry that integer index.

    Args:
        adata (ad.AnnData): The input AnnData object.
        repreprocess (bool, optional): Whether to run the preprocessing pipeline.
        basis (str, optional): The embedding in `.obsm` used for the cluster centers
            and the returned cell coordinates.
        recluster (bool, optional): If True, re-computes cell clusters using the Leiden
            algorithm; the "leiden" column is then used instead of ``cluster``.
        cluster (str, optional): The key in `adata.obs` where cluster information is stored/saved.
        distance_metric (str, optional): The distance metric to use for calculating distances between cluster centers.

    Returns:
        dict: A trajectory dict compatible with the 'projection' wrapper, containing the
        milestone network, embeddings, and cluster assignments.
    """
    # 1. preprocess
    if repreprocess:
        # the pipeline is asked to go as far as computing neighbors
        preprocess_pipeline(adata, style="scanpy", if_neighbors=True)
    # positional obs index, so group labels can index the embedding rows directly
    adata.obs.reset_index(drop=True, inplace=True)
    X_emb = adata.obsm[basis]

    # 2. execute method
    # (1) optionally re-cluster; the cluster centers act as milestones
    if recluster:
        sc.pp.neighbors(adata)
        sc.tl.leiden(adata)
        cluster = "leiden"

    # (2) mean embedding coordinates per cluster
    def _mean_position(group):
        return X_emb[list(group.index)].mean(axis=0)

    center_series = adata.obs.groupby(cluster).apply(_mean_position)
    centers = pd.DataFrame(center_series.tolist(), index=center_series.index)
    milestone_ids = centers.index.tolist()
    cluster_milestones = adata.obs[cluster]

    # (3) pairwise distances between centers, reshaped from wide to long
    distance_matrix = pd.DataFrame(
        pairwise_distances(centers, metric=distance_metric),
        index=milestone_ids,
        columns=milestone_ids,
    )
    long_form = distance_matrix.unstack().reset_index()
    edge_table = pd.DataFrame(data=long_form.values, columns=["from", "to", "weight"])

    # (4) minimum spanning tree over the fully connected milestone graph
    milestone_graph = nx.from_pandas_edgelist(edge_table, source="from", target="to", edge_attr="weight")
    spanning_tree = nx.minimum_spanning_tree(milestone_graph, weight="weight")

    # 3. extract results
    milestone_network = nx.to_pandas_edgelist(spanning_tree).rename(
        columns={"source": "from", "target": "to", "weight": "length"}
    )
    milestone_network["directed"] = False
    # cells and milestones share the same component column names
    comp_ids = [f"comp_{dim + 1}" for dim in range(centers.shape[1])]
    X_emb = pd.DataFrame(X_emb, index=adata.obs.index, columns=comp_ids)
    milestone_emb = centers
    milestone_emb.columns = comp_ids

    # 4. save results
    return {
        "wrapper_type": "projection",
        "milestone_network": milestone_network,
        "X_emb": X_emb,
        "milestone_emb": milestone_emb,
        "cluster": cluster_milestones,  # per-cell cluster label
    }

cafe.method.function.cf_graph_mst.graph_mst(adata, repreprocess=True)

Source code in cafe/method/function/cf_graph_mst.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@method_info(
    name="graph_mst",
    version="0.0.1",
    description="Graph MST: baseline for graph wrapper, creating a Minimum Spanning Tree (MST) on cluster centers.",
    wrapper_type="graph",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def graph_mst(
    adata: ad.AnnData,
    repreprocess: bool = True,
):
    """Graph MST: baseline for graph wrapper, a minimum spanning tree over the cell graph.

    The sparse matrix stored in ``adata.obsp["distances"]`` is turned into a
    weighted graph with one node per cell, and its minimum spanning tree is
    returned as an edge list keyed by cell id.

    Args:
        adata (ad.AnnData): The input AnnData object.
        repreprocess (bool, optional): Whether to run the preprocessing pipeline.

    Returns:
        dict: A trajectory dict with ``wrapper_type`` "graph", a ``cell_graph``
        edge list ("from", "to", "length") and ``to_keep`` set to None.
    """
    # 1. preprocess
    if repreprocess:
        # the pipeline is asked to go as far as computing neighbors
        preprocess_pipeline(adata, style="scanpy", if_neighbors=True)

    # 2. execute method
    cell_id_list = adata.obs.index.tolist()
    # nodes are integer row positions of the sparse matrix
    distance_graph = nx.from_scipy_sparse_array(adata.obsp["distances"])
    cell_mst = nx.minimum_spanning_tree(distance_graph, weight="weight")

    # 3. extract results
    edge_table = nx.to_pandas_edgelist(cell_mst, source="from", target="to")
    cell_graph = edge_table.rename(columns={"weight": "length"})
    # translate integer node positions back to cell ids
    for endpoint in ("from", "to"):
        cell_graph[endpoint] = cell_graph[endpoint].map(lambda node: cell_id_list[node])

    # 4. save results
    return {
        "wrapper_type": "graph",
        "cell_graph": cell_graph,
        "to_keep": None,  # keep all
    }

cafe.method.function.cf_paga.paga(adata, start_cell, repreprocess=True, repreprocess_kwargs={}, cluster='clusters', n_dcs=15, connectivity_cutoff=0.5)

PAGA: partition-based graph abstraction.

Parameters:

Name Type Description Default
adata AnnData

AnnData object

required
start_cell str

Starting cell ID for pseudotime calculation.

required
repreprocess bool

Whether to re-run preprocessing on the AnnData object.

True
repreprocess_kwargs dict

Parameters for repreprocess pipeline.

{}
cluster str

Cluster column name in adata.obs.

'clusters'
n_dcs int

Number of diffusion components.

15
connectivity_cutoff float

Cutoff for the connectivity matrix.

0.5

Returns:

Name Type Description
dict

Trajectory results including branch network, branches, and progressions.

Source code in cafe/method/function/cf_paga.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
@method_info(
    name="paga",
    version="0.0.1",
    description="PAGA: partition-based graph abstraction",
    wrapper_type="branch",
    doi="10.1186/s13059-019-1663-x",
    github_url="https://github.com/theislab/paga",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def paga(
    adata: ad.AnnData,
    start_cell: str,
    repreprocess: bool = True,
    repreprocess_kwargs: dict = {},
    cluster: str = "clusters",
    n_dcs: int = 15,
    connectivity_cutoff=0.5,
):
    """PAGA: partition-based graph abstraction.

    Runs diffusion maps, PAGA and DPT, then treats each cluster as a branch:
    PAGA connectivities at or above ``connectivity_cutoff`` become branch
    edges oriented from lower to higher mean pseudotime, and each cell's
    position inside its branch is its DPT pseudotime rescaled within that
    branch.

    Args:
        adata (ad.AnnData): AnnData object
        start_cell (str): Starting cell ID for pseudotime calculation.
        repreprocess (bool, optional): whether to repreprocess the anndata object.
        repreprocess_kwargs (dict, optional):  Parameters for repreprocess pipeline.
            Only unpacked, never mutated.
        cluster (str, optional): Cluster column name in adata.obs. The column is
            accessed through ``.cat``, so it must be categorical.
        n_dcs (int, optional): Number of diffusion components.
        connectivity_cutoff (float, optional): Cutoff for the connectivity matrix.

    Returns:
        dict: Trajectory results including branch network, branches, and progressions.

    Raises:
        IndexError: If ``start_cell`` is not present in ``adata.obs.index``.
    """
    # 1. preprocess
    if repreprocess:
        preprocess_pipeline(adata, **repreprocess_kwargs)
    sc.tl.diffmap(adata)

    # 2. execute method
    sc.tl.paga(adata, groups=cluster)
    # set start point for dpt
    root_positions = np.flatnonzero(adata.obs.index == start_cell)
    if root_positions.size == 0:
        raise IndexError(f"start_cell '{start_cell}' not found in adata.obs.index")
    adata.uns["iroot"] = root_positions[0]
    sc.tl.dpt(adata, n_dcs=n_dcs)

    # 3. extract results
    # (1) parameters for results extracting
    epsilon = 1e-3  # small offset that keeps lengths positive and avoids division by zero
    branch_ids = adata.obs[cluster].unique().to_list()
    pseudotime_by_branch = adata.obs.groupby(cluster, observed=True)["dpt_pseudotime"]
    average_pseudotime = pseudotime_by_branch.mean().to_dict()

    # (2) branches
    branches = pd.DataFrame(
        {
            "branch_id": branch_ids,
            "directed": True,
        }
    )
    # Look lengths up by branch id: branch_ids follows order of appearance,
    # while groupby results follow category order, so positional assignment
    # would attach lengths to the wrong branches.
    branch_span = (pseudotime_by_branch.max() - pseudotime_by_branch.min() + epsilon).to_dict()
    branches["length"] = branches["branch_id"].map(branch_span)

    # (3) branch_network
    categories = adata.obs[cluster].cat.categories
    upper_connectivities = np.triu(adata.uns["paga"]["connectivities"].todense(), k=0)  # keep the upper triangular matrix
    branch_network = pd.DataFrame(upper_connectivities, index=categories, columns=categories).stack().reset_index()
    branch_network.columns = ["from", "to", "length"]
    # set threshold to filter insignificant edges; copy so the edits below act on an independent frame
    branch_network = branch_network[branch_network["length"] >= connectivity_cutoff].copy()

    # Orient every edge from the lower to the higher mean pseudotime.
    reverse = branch_network["from"].map(average_pseudotime) > branch_network["to"].map(average_pseudotime)
    if reverse.any():
        branch_network.loc[reverse, ["from", "to"]] = branch_network.loc[reverse, ["to", "from"]].to_numpy()

    # sort edges by the pseudotime of "from" and "to" to facilitate subsequent milestone numbering
    branch_network["from_pseudotime"] = branch_network["from"].map(average_pseudotime)
    branch_network["to_pseudotime"] = branch_network["to"].map(average_pseudotime)
    branch_network = branch_network.sort_values(["from_pseudotime", "to_pseudotime"])
    branch_network = branch_network[["from", "to"]].reset_index(drop=True)

    # (4) branch_progressions
    branch_progressions = pd.DataFrame({"cell_id": adata.obs.index, "branch_id": adata.obs[cluster], "percentage": adata.obs["dpt_pseudotime"]})
    # Rescale pseudotime within each branch. transform() returns values aligned
    # to the original rows, whereas groupby.apply(...).values can come back in
    # group order and assign percentages to the wrong cells.
    branch_progressions["percentage"] = branch_progressions.groupby("branch_id", observed=True)["percentage"].transform(
        lambda x: (x - x.min()) / (x.max() - x.min() + epsilon)
    )

    # 4. save results
    trajectory_dict = {
        "wrapper_type": "branch",
        "branch_network": branch_network,
        "branches": branches,
        "branch_progressions": branch_progressions,
    }
    return trajectory_dict

cafe.method.function.cf_cytotrace2.cytotrace2(adata, repreprocess=True, cluster=None, cytotrace2_kwargs={})

Cytotrace2: cellular potency categories and absolute developmental potential.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
repreprocess bool

Whether to preprocess the data.

True
cluster str

Cluster column in '.obs' columns.

None
cytotrace2_kwargs dict

CytoTRACE2 core parameter dict; refer to the GitHub source code.

{}

Returns:

Name Type Description
dict dict

A trajectory dict with keys: "wrapper_type" and "pseudotime".

Source code in cafe/method/function/cf_cytotrace2.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@method_info(
    name="cytotrace2",
    version="0.0.1",
    description="Cytotrace2: cellular potency categories and absolute developmental potential",
    wrapper_type="linear",
    doi="10.1101/2024.03.19.585637",
    github_url="https://github.com/digitalcytometry/cytotrace2",
    use_gpu=True,
    cpu_parallelization=True,
    available=True,
)
def cytotrace2(adata: ad.AnnData, repreprocess: bool = True, cluster: str = None, cytotrace2_kwargs: dict = {}) -> dict:
    """Cytotrace2: cellular potency categories and absolute developmental potential.

    The expression matrix (genes x cells) and, optionally, the cluster
    annotation are written as tab-separated files into a temporary directory,
    CytoTRACE2 is run on them, and ``1 - CytoTRACE2_Score`` is returned as
    pseudotime. Both sparse and dense ``adata.X`` are accepted.

    Args:
        adata (ad.AnnData):  The input AnnData object.
        repreprocess (bool, optional): Whether to preprocess the data
            (per-cell normalization only).
        cluster (str, optional): Cluster column in '.obs' columns.
        cytotrace2_kwargs (dict, optional): cytotraces2 core parameter dict, refer to
            [github source code](https://github.com/digitalcytometry/cytotrace2/blob/main/cytotrace2_python/cytotrace2_py/cytotrace2_py.py).
            Only unpacked, never mutated.

    Returns:
        dict: A trajectory dict with keys: "wrapper_type" and "pseudotime".
    """
    # Imported lazily and aliased so it does not shadow this wrapper's own name.
    from cytotrace2_py.cytotrace2_py import cytotrace2 as run_cytotrace2

    with tempfile.TemporaryDirectory() as tmp_wd:
        # 1. preprocess:
        if repreprocess:
            # cytotrace2 don't recommend use log-transformed expression matrix and HVGs, only normalized here.
            sc.pp.normalize_per_cell(adata)
        else:
            matrix = adata.X
            # Sparse matrices expose their stored values via .data; a dense
            # array is inspected directly.
            stored_values = matrix.data if hasattr(matrix, "toarray") else np.asarray(matrix)
            is_count_like = np.issubdtype(matrix.dtype, np.integer) or np.isclose(stored_values, np.round(stored_values)).all()
            if not is_count_like:
                print("warnning: raw expression matrix is transformed, count matrix is not available,")
            else:
                print("use count matrix")
        # Densify only when .X is sparse; dense arrays have no .toarray().
        X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
        # write tmp file: expression matrix and annotation(if required)
        expression_file = f"{tmp_wd}/cytotrace2_expression.csv"
        df = pd.DataFrame(X, index=adata.obs_names, columns=adata.var_names).T
        print(f"write expression matrix({df.shape}) to  {expression_file}")
        df.to_csv(expression_file, sep="\t")
        annotation_path = ""  # an empty path is passed when no annotation is given
        if cluster is not None:
            annotation_path = f"{tmp_wd}/cytotrace2_annotations.csv"
            adata.obs[cluster].to_csv(annotation_path, sep="\t")
            print(f"write annotation to {annotation_path}")

        # 2. execute method
        result = run_cytotrace2(expression_file, annotation_path, disable_plotting=True, **cytotrace2_kwargs)

        # 3. extract results
        # NOTE(review): assumes the result rows are in adata.obs_names order - confirm against cytotrace2 output.
        pseudotime = (1 - result["CytoTRACE2_Score"]).tolist()

        # 4. save results
        trajectory_dict = {
            "wrapper_type": "linear",
            "pseudotime": pseudotime,
        }

    return trajectory_dict

cafe.method.function.cf_sctc.sctc(adata, repreprocess=True)

SCTC: Single-Cell Transcriptional Complexity.

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object.

required
repreprocess bool

Whether to preprocess the data. Defaults to True.

True

Returns:

Name Type Description
dict dict

A trajectory dict with keys: "wrapper_type" and "pseudotime".

Source code in cafe/method/function/cf_sctc.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@method_info(
    name="sctc",
    version="0.0.1",
    description="SCTC: single-Cell Transcriptional Complexity",
    wrapper_type="linear",
    doi="10.1093/nar/gkae340",
    github_url="https://github.com/hailinphysics/sctc",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def sctc(
    adata: ad.AnnData,
    repreprocess: bool = True,
) -> dict:
    """SCTC: Single-Cell Transcriptional Complexity.

    Computes the cell complexity index on the expression matrix and returns
    ``1 - cci`` as pseudotime. Both sparse and dense ``adata.X`` are accepted.

    Args:
        adata (ad.AnnData): The input AnnData object.
        repreprocess (bool, optional): Whether to preprocess the data
            (per-cell normalization, log1p, restriction to highly variable
            genes). Normalization and log1p are applied to ``adata`` in place.
            Defaults to True.

    Returns:
        dict: A trajectory dict with keys: "wrapper_type" and "pseudotime".
    """

    # Imported lazily and aliased so it does not shadow this wrapper's own name.
    import sctc as sctc_lib

    # 1. preprocess
    if repreprocess:
        sc.pp.normalize_per_cell(adata)
        sc.pp.log1p(adata)
        sc.pp.highly_variable_genes(adata)
        adata = adata[:, adata.var["highly_variable"]]

    # 2. execute method
    # Densify only when .X is sparse; dense arrays have no .toarray().
    expression = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    cci, _gci = sctc_lib.complexity_index(expression)  # only use cci for cell

    # 3. extract results
    pseudotime = (1 - cci).tolist()  # pseudotime and cci are negatively correlated

    # 4. save results
    trajectory_dict = {
        "wrapper_type": "linear",
        "pseudotime": pseudotime,
    }
    return trajectory_dict

cafe.method.function.cf_palantir.palantir(adata, start_cell, repreprocess=True, palantir_kwargs={}, palantir_results_kwargs={}, wrapper_type='linear', linear_type='pseudotime', cluster='clusters')

Palantir: characterization of cell fate probabilities

Parameters:

Name Type Description Default
adata AnnData

The input AnnData object

required
start_cell str

The starting cell ID for palantir.

required
repreprocess bool

Whether to preprocess the data.

True
palantir_kwargs dict

Palantir core parameter dict, refer to scanpy.external.tl.palantir.

{}
palantir_results_kwargs dict

Palantir result output parameter dict, refer to scanpy.external.tl.palantir_results.

{}
wrapper_type Literal['linear', 'probability', 'lineage']

Wrapper type for the output.

'linear'
linear_type Literal['pseudotime', 'entropy']

Linear type for linear wrapper.

'pseudotime'
cluster str

Cluster column in '.obs' columns for lineage wrapper.

'clusters'

Returns: dict: A trajectory dict with keys: "wrapper_type" and "pseudotime".

Source code in cafe/method/function/cf_palantir.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@method_info(
    name="palantir",
    version="0.0.1",
    description="Palantir: characterization of cell fate probabilities",
    wrapper_type=["linear", "probability", "lineage"],
    doi="10.1038/s41587-019-0068-4",
    github_url="https://github.com/dpeerlab/Palantir",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def palantir(
    adata: ad.AnnData,
    start_cell: str,
    repreprocess: bool = True,
    palantir_kwargs: dict = {},
    palantir_results_kwargs: dict = {},
    wrapper_type: Literal["linear", "probability", "lineage"] = "linear",
    linear_type: Literal["pseudotime", "entropy"] = "pseudotime",
    cluster: str = "clusters",
):
    """Palantir: characterization of cell fate probabilities

    Args:
        adata (ad.AnnData): The input AnnData object
        start_cell (str): The starting cell ID for palantir.
        repreprocess (bool, optional):  Whether to preprocess the data.
        palantir_kwargs (dict, optional): Palantir core parameter dict, refer to
            [scanpy.external.tl.palantir](https://scanpy.readthedocs.io/en/stable/external/generated/scanpy.external.tl.palantir.html).
        palantir_results_kwargs (dict, optional): Palantir result output parameter dict, refer to
            [scanpy.external.tl.palantir_results](https://scanpy.readthedocs.io/en/stable/external/generated/scanpy.external.tl.palantir_results.html).
            Its "terminal_states" entry (if any) is also read by the lineage wrapper.
        wrapper_type (Literal["linear", "probability", "lineage"], optional): Wrapper type for the output.
            Any value other than "linear" or "probability" is handled as "lineage".
        linear_type (Literal["pseudotime", "entropy"], optional): Linear type for linear wrapper.
            Any value other than "pseudotime" selects the entropy.
        cluster (str, optional): Cluster column in '.obs' columns for lineage wrapper.

    Returns:
        dict: A trajectory dict that always contains "wrapper_type", plus "pseudotime" for the
            linear wrapper, "end_state_probabilities" for the probability wrapper, or
            "probability" and "cluster" for the lineage wrapper.
    """

    # ref: https://palantir.readthedocs.io/en/latest/notebooks/Palantir_sample_notebook.html
    # ref: https://scanpy.readthedocs.io/en/stable/external/generated/scanpy.external.tl.palantir.html
    # 1. preprocess
    if repreprocess:
        sc.pp.normalize_per_cell(adata)
        sc.pp.log1p(adata)
        sc.pp.pca(adata)
        sc.pp.neighbors(adata)
        print("repreprocess finish")

    # 2. execute method
    # TODO: check early_cell in cell_ids
    sce.tl.palantir(adata, **palantir_kwargs)  # DiffusionMap and MAGIC
    pr_res = sce.tl.palantir_results(adata, early_cell=start_cell, **palantir_results_kwargs)  # Pseudotime and branch probabilities
    print("palantir execute finish")

    # 3,4. extract and save results for different wrapper type
    # multiple output data which adapt to multiple wrapper
    # TODO: multiple output wrapper parallelization
    cell_ids = adata.obs.index
    if linear_type == "pseudotime":
        # pseudotime
        pseudotime = pr_res.pseudotime
    else:
        # entropy
        pseudotime = pr_res.entropy

    trajectory_dict = {}
    if wrapper_type == "linear":
        # for linear wrapper
        trajectory_dict["pseudotime"] = pseudotime
    elif wrapper_type == "probability":
        # for probability wrapper
        # NOTE: the "cell_id" column is added to pr_res.branch_probs in place
        end_state_probabilities = pr_res.branch_probs
        end_state_probabilities["cell_id"] = cell_ids
        trajectory_dict["end_state_probabilities"] = end_state_probabilities
    else:
        # TODO: for lineage wrapper
        terminal_states = palantir_results_kwargs.get("terminal_states", [])
        # get_indexer returns integer positions, so select with .iloc explicitly:
        # plain `series[int_array]` on a label-indexed Series relies on a deprecated
        # positional fallback in pandas.
        terminal_positions = cell_ids.get_indexer(terminal_states)
        probability = pr_res.branch_probs
        probability.columns = adata.obs[cluster].iloc[terminal_positions]

        trajectory_dict["probability"] = probability
        trajectory_dict["cluster"] = cluster

    trajectory_dict["wrapper_type"] = wrapper_type

    return trajectory_dict

cafe.method.function.cf_scvelo.scvelo(adata, repreprocess=True, repreprocess_kwargs={}, velocity_kwargs={}, velocity_graph_kwargs={})

scVelo: RNA velocity generalized through dynamical modeling

Parameters:

Name Type Description Default
adata AnnData

AnnData object

required
repreprocess bool

Whether to repreprocess the anndata object.

True
repreprocess_kwargs dict

Parameter dict for repreprocess pipeline.

{}
velocity_kwargs dict

Parameter dict for velocity calculation, refer to scvelo.tl.velocity.

{}
velocity_graph_kwargs dict

Parameter dict for velocity graph calculation, refer to scvelo.tl.velocity_graph.

{}

Returns:

Name Type Description
dict

trajectory dict with keys about velocity

Source code in cafe/method/function/cf_scvelo.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@method_info(
    name="scvelo",
    version="0.0.1",
    description="scVelo: RNA velocity generalized through dynamical modeling",
    wrapper_type="velocity",
    doi="10.1038/s41587-020-0591-3",
    github_url="https://github.com/theislab/scvelo",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def scvelo(
    adata: ad.AnnData,
    repreprocess: bool = True,
    repreprocess_kwargs: dict = {},
    velocity_kwargs: dict = {},
    velocity_graph_kwargs: dict = {},
):
    """scVelo: RNA velocity generalized through dynamical modeling

    Args:
        adata (ad.AnnData): AnnData object
        repreprocess (bool, optional): Whether to repreprocess the anndata object.
        repreprocess_kwargs (dict, optional): Parameter dict for repreprocess pipeline.
        velocity_kwargs (dict, optional): Parameter dict for velocity calculation, refer to [scvelo.tl.velocity](https://scvelo.readthedocs.io/en/stable/scvelo.tl.velocity.html).
        velocity_graph_kwargs (dict, optional): Parameter dict for velocity graph calculation, refer to [scvelo.tl.velocity_graph](https://scvelo.readthedocs.io/en/stable/scvelo.tl.velocity_graph.html).

    Returns:
        dict: trajectory dict with keys about velocity
    """
    # ref: https://scvelo.readthedocs.io/en/stable/VelocityBasics.html
    # 1. preprocess
    if repreprocess:
        preprocess_pipeline(adata, style="scvelo", **repreprocess_kwargs)

    # 2. execute method
    scv.tl.velocity(adata, **velocity_kwargs)  # compute high dimensional velocity
    scv.tl.velocity_graph(adata, **velocity_graph_kwargs)  # compute transition probability
    # Record the method *name* (a string). The bare identifier `scvelo` would store
    # this function object in .uns, which is not a name and cannot be written to h5ad.
    adata.uns["method_name"] = "scvelo"  # to find corresponding function "extract_trajectory_dict" easily

    # 3,4. extract and save results
    trajectory_dict = extract_trajectory_dict(adata)
    return trajectory_dict

cafe.method.function.cf_dynamo.dynamo(adata, basis, repreprocess=True, repreprocess_kwargs={}, moment=True, n_neighbors=30, dynamics_kwargs={}, cell_velocities_kwargs={})

Dynamo: Mapping Transcriptomic Vector Fields of Single Cells

Parameters:

Name Type Description Default
adata AnnData

AnnData object.

required
repreprocess bool

Whether to repreprocess the anndata object.

True
repreprocess_kwargs dict

Parameter dict for repreprocess pipeline with dynamo style.

{}
dynamics_kwargs dict

Parameter dict for cell dynamics high dimensional velocity calculation, refer to dyn.tl.dynamics.

{}
cell_velocities_kwargs dict

Parameter dict for cell low dimensional velocity calculation, refer to dynamo.tl.cell_velocities.

{}

Returns:

Name Type Description
dict

trajectory dict with keys about velocity

Source code in cafe/method/function/cf_dynamo.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@method_info(
    name="dynamo",
    version="0.0.1",
    description="Dynamo: Mapping Transcriptomic Vector Fields of Single Cells",
    wrapper_type="velocity",
    doi="10.1016/j.cell.2021.12.045",
    github_url="https://github.com/aristoteleo/dynamo-release",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def dynamo(
    adata: ad.AnnData,
    basis: str,
    repreprocess: bool = True,
    repreprocess_kwargs: dict = {},
    moment: bool = True,
    n_neighbors: int = 30,
    dynamics_kwargs: dict = {},
    cell_velocities_kwargs: dict = {},
):
    """Dynamo: Mapping Transcriptomic Vector Fields of Single Cells

    Args:
        adata (ad.AnnData): AnnData object.
        basis (str): Embedding name, with or without the leading "X_" (e.g. "X_umap" or "umap").
            The prefix is stripped before it is passed to dynamo.
        repreprocess (bool, optional): Whether to repreprocess the anndata object.
        repreprocess_kwargs (dict, optional): Parameter dict for repreprocess pipeline with dynamo style.
        moment (bool, optional): Currently not used by this function; moments are always
            computed when `repreprocess` is True.
        n_neighbors (int, optional): Number of neighbors used when recomputing the neighbor graph
            during repreprocessing.
        dynamics_kwargs (dict, optional): Parameter dict for cell dynamics high dimensional velocity calculation, refer to [dyn.tl.dynamics](https://dynamo-release.readthedocs.io/en/latest/api/reference/dynamo.tl.dynamics.html#dynamo.tl.dynamics).
        cell_velocities_kwargs (dict, optional): Parameter dict for cell low dimensional velocity calculation, refer to [dynamo.tl.cell_velocities](https://dynamo-release.readthedocs.io/en/latest/api/reference/dynamo.tl.cell_velocities.html#dynamo.tl.cell_velocities).

    Returns:
        dict: trajectory dict with keys about velocity; only "velocity_embedding", "neighbors",
            "obs_index" and "var_index" are filled, the other velocity keys are None.

    """
    # 1. preprocess
    if repreprocess:
        preprocess_pipeline(adata, style="dynamo", **repreprocess_kwargs)
        dyn.tl.moments(adata)  # make sure moments is computed
        adata.layers["Ms"] = adata.layers["M_s"]  # extract moment matrix for transition matrix calculation
        adata.layers["Mu"] = adata.layers["M_u"]
        dyn.tl.neighbors(adata, n_neighbors=n_neighbors)  # recompute neighbors
    # remove "X_" only when it is really there, so a bare name such as "umap" is not truncated
    if basis.startswith("X_"):
        basis = basis[2:]
    # 2. execute method
    # dynamo core function
    dyn.tl.dynamics(adata, **dynamics_kwargs)
    dyn.tl.cell_velocities(adata, basis=basis, **cell_velocities_kwargs)  # scv.tl.velocity_graph(adata)
    # restrict to the genes flagged in .var["use_for_transition"] (rebinds the local name only)
    adata = adata[:, adata.var["use_for_transition"]]  # extract velocity gene
    velocity_embedding = adata.obsm[f"velocity_{basis}"]

    # 3,4. extract and save results
    trajectory_dict = {
        "wrapper_type": "velocity",
        "velocity": None,
        "velocity_graph": None,
        "velocity_graph_neg": None,
        "velocity_embedding": velocity_embedding,
        "neighbors": {"distances": adata.obsp["distances"], "connectivities": adata.obsp["connectivities"]},
        "obs_index": adata.obs.index,
        "var_index": adata.var.index,
    }

    return trajectory_dict

cafe.method.function.cf_velovi.velovi(adata, repreprocess=True, repreprocess_kwargs={}, velovi_model_kwargs={}, velovi_train_kwargs={}, n_sample=25)

VeloVI: Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells.

Args:

adata (ad.AnnData): AnnData object.

repreprocess (bool, optional): Whether to repreprocess the anndata object.

repreprocess_kwargs (dict, optional): Parameter dict for repreprocess pipeline.

velovi_model_kwargs (dict, optional): Parameter dict for VeloVI model initialization, refer to scvi.external.VELOVI.

velovi_train_kwargs (dict, optional): Parameter dict for VeloVI model training, refer to scvi.external.VELOVI.train.

n_sample (int, optional): Sample number from latent space, refer to scvi.external.VELOVI.get_latent_time.

Returns:

Name Type Description
dict

trajectory dict with keys about velocity

Source code in cafe/method/function/cf_velovi.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@method_info(
    name="velovi",
    version="0.0.1",
    description="VeloVI: Deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells",
    wrapper_type="velocity",
    doi="10.1038/s41592-023-01994-w",
    github_url="https://github.com/yoseflab/velovi",
    use_gpu=True,
    cpu_parallelization=True,
    available=True,
)
def velovi(
    adata: ad.AnnData,
    repreprocess: bool = True,
    repreprocess_kwargs: dict = {},
    velovi_model_kwargs: dict = {},
    velovi_train_kwargs: dict = {},
    n_sample: int = 25,
):
    """VeloVI: deep generative modeling of transcriptional dynamics for RNA velocity analysis in single cells

    Args:
        adata (ad.AnnData): AnnData object
        repreprocess (bool, optional): Whether to repreprocess the anndata object.
        repreprocess_kwargs (dict, optional): Parameter dict for repreprocess pipeline.
        velovi_model_kwargs (dict, optional): Parameter dict for VeloVI model initialization, refer to [scvi.external.VELOVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.external.VELOVI.html).
        velovi_train_kwargs (dict, optional): Parameter dict for VeloVI model training, refer to [scvi.external.VELOVI.train](https://docs.scvi-tools.org/en/stable/api/reference/scvi.external.VELOVI.html#scvi.external.VELOVI.train).
        n_sample (int, optional): Sample number from latent space, used for both the latent time and the velocity estimate, refer to [scvi.external.VELOVI.get_latent_time](https://docs.scvi-tools.org/en/stable/api/reference/scvi.external.VELOVI.html#scvi.external.VELOVI.get_latent_time).

    Returns:
        dict: trajectory dict with keys about velocity

    """
    # ref: https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/velovi.html
    # imported lazily: the package is not available in the cafe environment
    from scvi.external import VELOVI

    # 1. preprocess (scvelo-style pipeline)
    if repreprocess:
        preprocess_pipeline(adata, style="scvelo", **repreprocess_kwargs)

    # 2. execute method: fit the model on the "Ms"/"Mu" layers
    VELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")
    model = VELOVI(adata, **velovi_model_kwargs)
    model.train(**velovi_train_kwargs)  # TODO: very slow, need GPU

    # sample latent time first, then the mean velocity, and rescale the velocity
    # by 20 / max latent time before storing it in adata.layers["velocity"]
    latent_time = model.get_latent_time(n_samples=n_sample)
    velocities = model.get_velocity(n_samples=n_sample, velo_statistic="mean")
    time_scaling = 20 / latent_time.max(0)
    adata.layers["velocity"] = velocities / time_scaling
    scv.tl.velocity_graph(adata)

    # 3,4. extract and save results
    return {
        "wrapper_type": "velocity",
        "velocity": adata.layers["velocity"],
        "velocity_graph": adata.uns["velocity_graph"],
        "velocity_graph_neg": adata.uns["velocity_graph_neg"],
        "neighbors": {
            "distances": adata.obsp["distances"],
            "connectivities": adata.obsp["connectivities"],
        },
        "obs_index": adata.obs.index,
        "var_index": adata.var.index,
    }

cafe.method.function.cf_veloae.veloae(adata, repreprocess=True, repreprocess_kwargs={}, veloae_args={})

VeloAE: Representation learning of RNA velocity reveals robust cell transitions

Source code in cafe/method/function/cf_veloae.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
@method_info(
    name="veloae",
    version="0.0.1",
    description="VeloAE: Representation learning of RNA velocity reveals robust cell transitions",
    wrapper_type="velocity",
    doi="10.1073/pnas.2105859118",
    github_url="https://github.com/qiaochen/VeloAE",
    use_gpu=True,
    cpu_parallelization=True,
    available=True,
)
def veloae(
    adata: ad.AnnData,
    repreprocess: bool = True,
    repreprocess_kwargs: dict = {},
    veloae_args: dict = {},
):
    """VeloAE: Representation learning of RNA velocity reveals robust cell transitions

    Args:
        adata (ad.AnnData): AnnData object.
        repreprocess (bool, optional): Whether to repreprocess the anndata object (scvelo style).
        repreprocess_kwargs (dict, optional): Parameter dict for repreprocess pipeline.
        veloae_args (dict, optional): Overrides for the parsed VeloAE arguments, keyed by the
            attribute name on the parsed namespace (e.g. "device", "seed"). Unknown names are
            reported with a warning and ignored.

    Returns:
        dict: trajectory dict with keys about velocity, built from the AnnData returned by
            `veloproj.new_adata`.
    """
    # ref: https://github.com/qiaochen/VeloAE/blob/main/notebooks/pancreas/model-pancreas-gat.ipynb
    import torch
    from veloproj import (
        estimate_ld_velocity,
        fit_model,
        get_parser,
        init_model,
        new_adata,
    )

    # 1. preprocess
    if repreprocess:
        preprocess_pipeline(adata, style="scvelo", **repreprocess_kwargs)

    # default command-line style arguments fed to the veloproj parser (order preserved)
    default_cli_args = {
        "--lr": "1e-6",
        "--n-epochs": "10",
        "--g-rep-dim": "100",
        "--k-dim": "100",
        "--model-name": "pancreas_scv_model.cpt",
        "--exp-name": "CohAE_pancreas_scv",
        "--device": "cuda:0",
        "--gumbsoft_tau": "1",
        "--nb_g_src": "X",
        "--ld_nb_g_src": "X",
        "--n_raw_gene": "2000",
        "--n_conn_nb": "30",
        "--n_nb_newadata": "30",
        "--aux_weight": "1",
        "--fit_offset_train": "false",
        "--fit_offset_pred": "true",
        "--use_offset_pred": "false",
        "--gnn_layer": "GAT",
        "--vis-key": "X_umap",
        "--vis_type_col": "clusters",
        "--scv_n_jobs": "10",
    }
    cli_tokens = [token for flag_and_value in default_cli_args.items() for token in flag_and_value]
    parser = get_parser()
    args = parser.parse_args(args=cli_tokens)
    for k, v in veloae_args.items():
        # argparse.Namespace has no `.hasattr` method; use the builtin
        if hasattr(args, k):
            setattr(args, k, v)
        else:
            print(f"Warning: args has no attribute {k}, please check the parameter name")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # fall back to CPU unless a cuda device was requested and is available
    device = torch.device(args.device if args.device.startswith("cuda") and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. execute scvelo stochastic model result
    # use scvelo stochastic model result
    scv.tl.velocity(adata, vkey="stc_velocity", mode="stochastic")
    # extract tensors
    spliced = adata.layers["Ms"]
    unspliced = adata.layers["Mu"]
    tensor_s = torch.FloatTensor(spliced).to(device)
    tensor_u = torch.FloatTensor(unspliced).to(device)
    tensor_x = torch.FloatTensor(adata.X.toarray()).to(device)
    tensor_v = torch.FloatTensor(adata.layers["stc_velocity"]).to(device)
    inputs = [tensor_s, tensor_u]
    xyids = [0, 1]
    if args.use_x:
        inputs.append(tensor_x)
    # model initialization
    model = init_model(adata, args, device)
    # model training
    model = fit_model(args, adata, model, inputs, tensor_v, xyids, device)
    # model inference to get high dimensional velocity
    model.eval()
    with torch.no_grad():
        x = model.encoder(tensor_x)
        s = model.encoder(tensor_s)
        u = model.encoder(tensor_u)

        # NOTE(review): use_offset is the negation of args.use_offset_pred — confirm this is intended
        v = (
            estimate_ld_velocity(
                s, u, device=device, perc=[5, 95], norm=args.use_norm, fit_offset=args.fit_offset_pred, use_offset=not args.use_offset_pred
            )
            .cpu()
            .numpy()
        )
        x = x.cpu().numpy()
        s = s.cpu().numpy()
        u = u.cpu().numpy()
    # project velocity to low-dim space
    adata = new_adata(adata, x, s, u, v, g_basis=args.ld_nb_g_src, n_nb_newadata=args.n_nb_newadata)
    scv.tl.velocity_graph(adata, vkey="new_velocity", n_jobs=args.scv_n_jobs)

    # 3,4. extract and save results
    trajectory_dict = {
        "wrapper_type": "velocity",
        "X": adata.X,
        "velocity": adata.layers["new_velocity"],
        "velocity_graph": adata.uns["new_velocity_graph"],
        "velocity_graph_neg": adata.uns["new_velocity_graph_neg"],
        "neighbors": {"distances": adata.obsp["distances"], "connectivities": adata.obsp["connectivities"]},
        "obs_index": adata.obs.index,
        "var_index": adata.var.index,  # here the var_index is the latent dimension, can't apply to now
    }

    return trajectory_dict

cafe.method.function.cf_unitvelo.unitvelo(adata, cluster, configuration_kwargs={})

UniTVelo: temporally unified RNA velocity reinforces single-cell trajectory inference

Parameters:

Name Type Description Default
adata AnnData

AnnData object.

required
cluster str

Cluster column name in adata.obs.

required
configuration_kwargs dict

Parameter dict for unitvelo pipeline, refer to config.py.

{}

Returns:

Name Type Description
dict

trajectory dict with keys about velocity

Source code in cafe/method/function/cf_unitvelo.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@method_info(
    name="unitvelo",
    version="0.0.1",
    description="UniTVelo: temporally unified RNA velocity reinforces single-cell trajectory inference",
    wrapper_type="velocity",
    doi="10.1038/s41467-022-34188-7",
    github_url="https://github.com/StatBiomed/UniTVelo",
    use_gpu=True,
    cpu_parallelization=True,
    available=True,
)
def unitvelo(
    adata: ad.AnnData,
    cluster: str,
    configuration_kwargs: dict = {},
):
    """UniTVelo: temporally unified RNA velocity reinforces single-cell trajectory inference

    The model is run from an h5ad file: the path in `adata.uns["filename"]` when present,
    otherwise `adata` is first written to "adata.h5ad" in the current working directory.

    Args:
        adata (ad.AnnData): AnnData object.
        cluster (str): Cluster column name in adata.obs.
        configuration_kwargs (dict, optional): Attribute overrides applied to the unitvelo
            configuration object (MAX_ITER is preset to 1000), refer to
            [config.py](https://github.com/StatBiomed/UniTVelo/blob/main/unitvelo/config.py).

    Returns:
        dict: trajectory dict with keys about velocity
    """

    # ref: https://github.com/StatBiomed/UniTVelo/blob/main/notebooks/Figure3_BoneMarrow.ipynb
    import os
    import shutil

    import unitvelo as utv

    # 1,2 preprocess and execute method
    if "filename" not in adata.uns:
        # for docker test: save adata for the later unitvelo pipeline
        adata_filename = "adata.h5ad"
        adata.write(adata_filename)
        print("save adata for unitvelo pipeline:", adata_filename)
    else:
        adata_filename = adata.uns["filename"]

    # configuration: defaults first, then user overrides
    config = utv.config.Configuration()
    config.MAX_ITER = 1000
    for option, value in configuration_kwargs.items():
        setattr(config, option, value)

    # run model; from here on `adata` is the object returned by unitvelo
    adata = utv.run_model(adata_filename, label=cluster, config_file=config)

    # remove the directory named like the input file without ".h5ad", if it exists
    tmp_dir = adata_filename.replace(".h5ad", "")
    if os.path.exists(tmp_dir):
        print("remove unitvelo tmp dir:", tmp_dir)
        shutil.rmtree(tmp_dir)

    # 3,4. extract and save results
    neighbor_graphs = {
        "distances": adata.obsp["distances"],
        "connectivities": adata.obsp["connectivities"],
    }
    return {
        "wrapper_type": "velocity",
        "velocity": adata.layers["velocity"],
        "velocity_graph": adata.uns["velocity_graph"],
        "velocity_graph_neg": adata.uns["velocity_graph_neg"],
        "neighbors": neighbor_graphs,
        "obs_index": adata.obs.index,
        "var_index": adata.var.index,
    }

cafe.method.function.cf_celldancer.celldancer(adata, cluster, basis, repreprocess=True, repreprocess_kwargs={}, velocity_kwargs={}, compute_cell_velocity_kwargs={})

CellDancer: Estimating Cell-dependent RNA Velocity

Parameters:

Name Type Description Default
adata AnnData

AnnData object

required
repreprocess bool

Whether to repreprocess the anndata object.

True
repreprocess_kwargs dict

Parameter dict for repreprocess pipeline.

{}

Returns:

Name Type Description
dict

trajectory dict with keys about velocity, only velocity_embedding is available

Source code in cafe/method/function/cf_celldancer.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@method_info(
    name="celldancer",
    version="0.0.1",
    description="CellDancer: Estimating Cell-dependent RNA Velocity",
    wrapper_type="velocity",
    doi="10.1038/s41587-023-01728-5",
    github_url="https://github.com/GuangyuWangLab2021/cellDancer",
    use_gpu=True,
    cpu_parallelization=True,
    available=True,
)
def celldancer(
    adata: ad.AnnData,
    cluster: str,
    basis: str,
    repreprocess: bool = True,
    repreprocess_kwargs: dict = {},
    velocity_kwargs: dict = {},
    compute_cell_velocity_kwargs: dict = {},
):
    """CellDancer: Estimating Cell-dependent RNA Velocity

    Args:
        adata (ad.AnnData): AnnData object
        cluster (str): Name passed to cellDancer as the cell type column (`cell_type_para`).
        basis (str): Embedding name passed to cellDancer as `embed_para`. The resulting velocity
            embedding is stored in `adata.obsm` under "velocity_" followed by `basis` without
            its first two characters.
        repreprocess (bool, optional): Whether to repreprocess the anndata object.
        repreprocess_kwargs (dict, optional): Parameter dict for repreprocess pipeline.
        velocity_kwargs (dict, optional): Parameter dict forwarded to `celldancer.velocity`.
        compute_cell_velocity_kwargs (dict, optional): Parameter dict forwarded to
            `celldancer.compute_cell_velocity`.

    Returns:
        dict: trajectory dict with keys about velocity, only velocity_embedding is available
    """

    import celldancer as cd

    # 1. preprocess
    if repreprocess:
        preprocess_pipeline(adata, style="scvelo", **repreprocess_kwargs)

    # 2. execute method
    # convert adata to the cellDancer dataframe format, then estimate and project velocity
    cell_df = cd.utilities.adata_to_df_with_embed(
        adata,
        us_para=["Mu", "Ms"],
        cell_type_para=cluster,
        embed_para=basis,
    )
    _, cell_df = cd.velocity(cell_df, **velocity_kwargs)
    cell_df = cd.compute_cell_velocity(cellDancer_df=cell_df, **compute_cell_velocity_kwargs)

    # 3. extract results
    velocity_key = f"velocity_{basis[2:]}"
    # first row per cell, missing velocities replaced by 0
    per_cell_velocity = cell_df.groupby("cellID").first()[["velocity1", "velocity2"]].fillna(0)
    adata.obsm[velocity_key] = per_cell_velocity.loc[adata.obs.index].values  # align index
    # celldancer generates many zero-velocity cells; keep only cells with a non-zero
    # velocity to construct the trajectory
    is_zero_velocity = (adata.obsm[velocity_key] == 0).all(axis=1)
    adata = adata[~is_zero_velocity].copy()

    # 4. save results
    return extract_trajectory_dict(adata, basis=basis)

cafe.method.function.cf_pyrovelocity.pyrovelocity(adata, configuration_kwargs={})

PyroVelocity: probabilistic modeling of RNA velocity

Parameters:

Name Type Description Default
adata AnnData

AnnData object.

required
configuration_kwargs dict

Configuration dict for pyrovelocity pipeline, refer to pancreas template.

{}

Returns:

Name Type Description
dict

trajectory dict with keys about velocity

Source code in cafe/method/function/cf_pyrovelocity.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
@method_info(
    name="pyrovelocity",
    version="0.0.1",
    description="PyroVelocity: probabilistic modeling of RNA velocity",
    wrapper_type="velocity",
    doi="10.1101/2022.09.12.507691",
    github_url="https://github.com/pinellolab/pyrovelocity",
    use_gpu=True,
    cpu_parallelization=True,
    available=True,
)
def pyrovelocity(
    adata: ad.AnnData,
    configuration_kwargs: dict = {},
):
    """PyroVelocity: probabilistic modeling of RNA velocity

    The pyrovelocity workflow is run from an h5ad file: the path in `adata.uns["filename"]`
    when present, otherwise `adata` is first written to "adata.h5ad" in the current working
    directory. The imported pancreas template configuration object is modified in place.

    Args:
        adata (ad.AnnData): AnnData object.
        configuration_kwargs (dict, optional): Configuration dict for pyrovelocity pipeline, refer to [pancreas template](https://github.com/pinellolab/pyrovelocity/blob/v0.4.5/src/pyrovelocity/workflows/main_configuration.py).
            Maps a configuration category name (an attribute of the template, e.g.
            "training_configuration_1") to a dict of parameter overrides. Unknown categories
            or parameters are reported with a warning and skipped.

    Returns:
        dict: trajectory dict with keys about velocity, plus "save_h5ad" holding the path of
            the postprocessed data reported by the workflow.
    """

    #  ref: https://docs.pyrovelocity.net/templates/user_example/user_example
    #  pyrovelocity workflow add.
    # https://github.com/pyrovelocity/pyrovelocity/blob/v0.4.5/src/pyrovelocity/workflows/main_configuration.py
    import tempfile

    import mlflow
    import scanpy as sc
    from pyrovelocity.workflows.main_configuration import (
        pancreas_configuration as templete_configuration,
    )
    from pyrovelocity.workflows.main_workflow import (
        download_data,
        postprocess_data,
        preprocess_data,
        train_model,
    )

    # working dir is setting to avoid use previous middle file.

    with tempfile.TemporaryDirectory() as tmp_wd:
        if "filename" in adata.uns:
            adata_filename = adata.uns["filename"]
        else:
            adata_filename = "adata.h5ad"
            adata.write(adata_filename)
            print("save adata for pyrovelocity pipeline:", adata_filename)
        # split "<dir>/<name>.h5ad" into the directory part and the bare data set name
        data_set_name = adata_filename.split("/")[-1]
        data_external_path = adata_filename.replace(data_set_name, "")
        data_set_name = data_set_name.replace(".h5ad", "")

        # configuration object construction based on pancreas template
        # input, preprocessed and output filename
        templete_configuration.download_dataset.data_set_name = data_set_name
        templete_configuration.download_dataset.data_external_path = data_external_path
        templete_configuration.download_dataset.source = ""
        templete_configuration.preprocess_data.data_set_name = data_set_name
        templete_configuration.preprocess_data.adata = adata_filename
        templete_configuration.training_configuration_1.data_set_name = data_set_name
        templete_configuration.training_configuration_1.adata = f"{tmp_wd}/{data_set_name}_processed.h5ad"
        # other configuration
        templete_configuration.training_configuration_1.max_epochs = 200
        for category, category_configuration_dict in configuration_kwargs.items():
            # default of None so an unknown category reaches the warning below
            # instead of raising AttributeError
            category_configuration_object = getattr(templete_configuration, category, None)
            if category_configuration_object is None:
                print(f"Warning: no category '{category}' in template configuration")
                continue
            for k, v in category_configuration_dict.items():
                if hasattr(category_configuration_object, k):
                    setattr(category_configuration_object, k, v)
                else:
                    print(f"Warning: no parameter '{k}' in configuration category '{category}'")

        # data
        data = download_data(download_dataset_args=templete_configuration.download_dataset)

        # preprocess
        processed_data = preprocess_data(
            data=data,
            preprocess_data_args=templete_configuration.preprocess_data,
        )

        # train
        mlflow.set_experiment("0")
        model_output = train_model(
            processed_data,
            train_model_configuration=templete_configuration.training_configuration_1,
        )

        # postprocess
        postprocessing_outputs = postprocess_data(
            preprocess_data_args=templete_configuration.preprocess_data,
            training_outputs=model_output,
            postprocess_configuration=templete_configuration.postprocess_configuration,
        )

        # read result adata
        adata = sc.read(postprocessing_outputs.postprocessed_data)
        print("adata result", adata.shape)

    trajectory_dict = {
        "wrapper_type": "velocity",
        "velocity": adata.layers["velocity_pyro"],
        "velocity_graph": adata.uns["velocity_pyro_graph"],
        "velocity_graph_neg": adata.uns["velocity_pyro_graph_neg"],
        "neighbors": {"distances": adata.obsp["distances"], "connectivities": adata.obsp["connectivities"]},
        "obs_index": adata.obs.index,
        "var_index": adata.var.index,
        "save_h5ad": postprocessing_outputs.postprocessed_data,
    }

    return trajectory_dict

cafe.method.function.cf_stavia.stavia(adata, cluster, start_cell, repreprocess=True, data_basis='X_pca', ncomps=30, via_kwargs={}, prune_milestone=True)

StaVia: spatially and temporally aware cartography with higher-order random walks for cell atlases.

Returns: dict: trajectory dict with keys about cluster wrapper

Source code in cafe/method/function/cf_stavia.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
@method_info(
    name="stavia",
    version="0.0.1",
    description="StaVia: spatially and temporally aware cartography with higher-order random walks for cell atlases",
    wrapper_type="cluster",
    doi="10.1186/s13059-024-03347-y",
    github_url="https://github.com/ShobiStassen/VIA",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def stavia(
    adata: ad.AnnData,
    cluster: str,
    start_cell: str,
    repreprocess: bool = True,
    data_basis: str = "X_pca",
    ncomps: int = 30,
    via_kwargs: dict = {},
    prune_milestone: bool = True,  # True keeps only the root-to-terminal shortest-path edges, giving a clean backbone graph similar to native VIA plots.
):
    """StaVia: spatially and temporally aware cartography with higher-order random walks for cell atlases

    Runs VIA on the first ``ncomps`` columns of ``adata.obsm[data_basis]`` and
    turns the resulting milestone graph into a cluster-wrapper trajectory dict.
    As a side effect the per-cell VIA labels are written to
    ``adata.obs["stavia_cluster"]`` (as strings).

    Args:
        adata (ad.AnnData): The input AnnData object.
        cluster (str): Column in ``adata.obs`` passed to VIA as ``true_label``.
        start_cell (str): Value passed to VIA as ``root_user``.
        repreprocess (bool, optional): Accepted for interface parity; the
            preprocess step of this wrapper is empty, so the flag is not read.
        data_basis (str, optional): Key in ``adata.obsm`` holding the embedding.
        ncomps (int, optional): Number of leading embedding columns fed to VIA.
        via_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``via.VIA``. The dict is only unpacked, never modified.
        prune_milestone (bool, optional): If True, keep only the edges that lie
            on a shortest path from the root milestone to a terminal milestone;
            otherwise keep every edge of ``edgelist_maxout``.

    Returns:
        dict: trajectory dict with keys about cluster wrapper
            (``wrapper_type``, ``milestone_network``, ``cluster``).
    """
    # ref: https://pyvia.readthedocs.io/en/latest/notebooks/ViaJupyter_scRNA_Hematopoiesis.html
    import igraph as ig
    import pyVIA.core as via

    # 1. preprocess (intentionally empty)
    # 2. execute method
    embedding = adata.obsm[data_basis][:, :ncomps]
    via_object = via.VIA(
        data=embedding,
        true_label=adata.obs[cluster],
        root_user=start_cell,
        **via_kwargs,
    )
    via_object.run_VIA()

    # 3. extract results
    if not prune_milestone:
        raw_edges = via_object.edgelist_maxout
    else:
        # ref: plot_trajectory_curves (https://github.com/ShobiStassen/VIA/blob/master/VIA/plotting_via.py#L3100)
        all_edges = via_object.edgelist_maxout
        milestone_labels = via_object.labels  # milestone label of each cell
        terminal_milestones = via_object.terminal_clusters
        root_milestone = via_object.root[0]
        backbone = ig.Graph(n=len(set(milestone_labels)), edges=all_edges)
        # Collect the edges of the shortest path from the root to every terminal state.
        path_edges = []
        for terminal in terminal_milestones:
            path = backbone.get_shortest_paths(root_milestone, to=terminal)[0]
            # Consecutive vertex pairs along the path are its edges.
            path_edges.extend(zip(path, path[1:]))
        # Paths to different terminals share edges; keep each edge once.
        raw_edges = list(set(path_edges))

    milestone_network = pd.DataFrame(
        data=[[str(edge[0]), str(edge[1])] for edge in raw_edges],
        columns=["from", "to"],
    )
    milestone_network["length"] = 1
    milestone_network["directed"] = True

    adata.obs["stavia_cluster"] = list(map(str, via_object.labels))

    # 4. save results
    return {
        "wrapper_type": "cluster",
        "milestone_network": milestone_network,
        "cluster": adata.obs["stavia_cluster"],
    }

cafe.method.function.cf_cellrank.cellrank(adata, cluster, repreprocess=True, wrapper_type='probability', kernel='connectivity', kernel_params={}, initial_states=None, terminal_states=None, fit_kwargs={}, predict_terminal_states_kwargs={}, using_macrostate=True)

CellRank 2: unified fate mapping in multiview single-cell data

Parameters:

Name Type Description Default
adata AnnData

AnnData object.

required
repreprocess bool

Whether to repreprocess the anndata object.

True

Returns:

Name Type Description
dict

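trajectory dict for the selected `wrapper_type`: `end_state_probabilities` for the probability wrapper, or `probability` plus `new_cluster_list` / `cluster_key` for the lineage wrapper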

Source code in cafe/method/function/cf_cellrank.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@method_info(
    name="cellrank",
    version="0.0.1",
    description="CellRank 2: unified fate mapping in multiview single-cell data",
    wrapper_type="velocity",
    doi="10.1038/s41592-024-02303-9",
    github_url="https://github.com/theislab/cellrank",
    use_gpu=False,
    cpu_parallelization=True,
    available=True,
)
def cellrank(
    adata: ad.AnnData,
    cluster: str,
    repreprocess: bool = True,
    wrapper_type: str = "probability",
    kernel: str = "connectivity",
    kernel_params: dict = {},
    initial_states=None,
    terminal_states=None,
    fit_kwargs: dict = {},
    predict_terminal_states_kwargs: dict = {},
    using_macrostate: bool = True,
) -> dict:
    """CellRank 2: unified fate mapping in multiview single-cell data

    Builds a CellRank kernel, fits a GPCCA estimator, computes fate
    probabilities and packs them into a trajectory dict. As a side effect the
    lineage object is stored in ``adata.obsm["lineages_fwd"]`` and one column
    per lineage is added to ``adata.obs``.

    Args:
        adata (ad.AnnData): AnnData object.
        cluster (str): Column in ``adata.obs`` used as the default
            ``cluster_key`` for ``GPCCA.fit``.
        repreprocess (bool, optional): Whether to repreprocess the anndata object.
        wrapper_type (str, optional): ``"lineage"`` selects the lineage payload;
            any other value yields the probability payload.
        kernel (str, optional): ``"connectivity"`` or ``"velocity"``.
        kernel_params (dict, optional): Keyword arguments for the kernel constructor.
        initial_states (optional): Collection of state names. A macrostate is
            selected when its name, with a trailing ``_<number>`` removed, is
            contained in this collection. ``None`` leaves initial states unset.
        terminal_states (optional): Same matching rule as ``initial_states``.
            ``None`` lets the estimator predict terminal states.
        fit_kwargs (dict, optional): Keyword arguments for ``GPCCA.fit``.
            Missing ``cluster_key`` defaults to ``cluster``; missing ``n_states``
            defaults to the number of categories of the cluster column.
        predict_terminal_states_kwargs (dict, optional): Keyword arguments for
            ``GPCCA.predict_terminal_states``.
        using_macrostate (bool, optional): Lineage wrapper only. If True, return
            each cell's most likely macrostate as ``new_cluster_list``; otherwise
            strip the numeric suffix from the lineage names and return
            ``cluster_key``.

    Returns:
        dict: ``{"wrapper_type": wrapper_type, ...}`` with
            ``"end_state_probabilities"`` for the probability wrapper, or
            ``"probability"`` plus ``"new_cluster_list"`` / ``"cluster_key"``
            for the lineage wrapper.

    Raises:
        ValueError: If ``kernel`` is not supported, or if the velocity kernel is
            requested but ``adata.layers["velocity"]`` is missing.
    """
    # The signature defaults are shared dict objects and fit_kwargs is filled in
    # below, so work on private copies to keep state from leaking between calls
    # (or into the caller's dict).
    kernel_params = dict(kernel_params or {})
    fit_kwargs = dict(fit_kwargs or {})
    predict_terminal_states_kwargs = dict(predict_terminal_states_kwargs or {})

    # Reject unknown kernels up front instead of failing later on a None kernel.
    # TODO: Other kernel in parameters
    if kernel not in ("connectivity", "velocity"):
        raise ValueError(f"Unsupported kernel {kernel!r}, expected 'connectivity' or 'velocity'.")

    # 1.  preprocess
    if repreprocess:
        preprocess_pipeline(adata, style="scanpy")

    # 2. execute method
    # kernel
    if kernel == "connectivity":
        kernel_obj = cr.kernels.ConnectivityKernel(adata, **kernel_params)
    else:
        if "velocity" not in adata.layers:
            # TODO: check and calculate velocity adata.layer["velocity"] first
            raise ValueError("adata.layers['velocity'] not found, please calculate velocity first.")
        kernel_obj = cr.kernels.VelocityKernel(adata, **kernel_params)
    # TODO: complex kernel with multiple views
    # Computed exactly once, whichever kernel was built.
    kernel_obj.compute_transition_matrix()

    # estimator
    g = cr.estimators.GPCCA(kernel_obj)
    # identify macrostates, related parameters are in fit_kwargs
    if fit_kwargs.get("cluster_key") is None:
        fit_kwargs["cluster_key"] = cluster
        print(f"use default cluster key: {cluster}")
    if fit_kwargs.get("n_states") is None:
        # A user-supplied cluster_key takes precedence over `cluster` from here on.
        cluster = fit_kwargs["cluster_key"]
        n_states = len(adata.obs[cluster].cat.categories)
        fit_kwargs["n_states"] = n_states
        print(f"set n_states={n_states} according to cluster key({cluster})")
    g.fit(**fit_kwargs)
    macrostates = g.macrostates.cat.categories.tolist()  # valid macro states

    # set initial and terminal states
    if initial_states is not None:
        initial_states = [macrostate for macrostate in macrostates if re.sub(r"_\d+$", "", macrostate) in initial_states]
        g.set_initial_states(states=initial_states)
        print(f"set initial states({initial_states}) mannually")
    if terminal_states is not None:
        # set terminal states mannually
        terminal_states = [macrostate for macrostate in macrostates if re.sub(r"_\d+$", "", macrostate) in terminal_states]
        g.set_terminal_states(states=terminal_states)
        print(f"set terminal states({terminal_states}) mannually")
    else:
        g.predict_terminal_states(**predict_terminal_states_kwargs)
        print("predict terminal states automatically")

    # compute fate probabilities
    g.compute_fate_probabilities()
    # extract lineage object
    # NOTE(review): this reads a private attribute; the public `fate_probabilities`
    # property is presumably equivalent - confirm before switching.
    lineage = g._fate_probabilities
    end_state_probabilities = pd.DataFrame(lineage.__array__(), columns=lineage.names, index=adata.obs.index)
    adata.obsm["lineages_fwd"] = lineage
    adata.obs[end_state_probabilities.columns] = end_state_probabilities

    # 3. extract results
    trajectory_dict = {"wrapper_type": wrapper_type}
    if wrapper_type == "lineage":
        if using_macrostate:
            # macrostate as cluster with suffix number: each cell gets the
            # macrostate with the highest membership
            macrostate_df = pd.DataFrame(g.macrostates_memberships.__array__(), columns=g.macrostates.cat.categories.tolist())
            trajectory_dict["new_cluster_list"] = macrostate_df.idxmax(axis=1).tolist()
        else:
            # raw cluster, remove suffix number. regex=True is required because
            # pandas >= 2.0 treats the pattern as a literal string by default;
            # the pattern is anchored like the state matching above.
            end_state_probabilities.columns = end_state_probabilities.columns.str.replace(r"_\d+$", "", regex=True)
            trajectory_dict["cluster_key"] = cluster
        trajectory_dict["probability"] = end_state_probabilities
    else:
        # for probability wrapper
        trajectory_dict["end_state_probabilities"] = end_state_probabilities

    # 4. save results
    return trajectory_dict