Skip to content

cafe.metric.metric_cluster

cafe.metric.metric_cluster

calculate_mapping(fadata, grouping='milestones', simplify=False, ref_model='ref', pred_model='default')

计算轨迹映射指标——只需一个 FateAnnData,通过 ref_model / pred_model 从 fadata.uns["cafe"]['trajectory_history_dict'] 中取出两条不同轨迹进行对比。

Parameters:

Name Type Description Default
fadata FateAnnData

包含多条轨迹的 FateAnnData

required
grouping str

'milestones' 或 'branches'

'milestones'
simplify bool

是否先简化轨迹骨架

False
ref_model str

参考轨迹在 trajectory_history_dict 中的 key

'ref'
pred_model str

预测轨迹在 trajectory_history_dict 中的 key

'default'

Returns:

Type Description
dict

{'recovery': ..., 'relevance': ..., 'F1': ...}

Source code in cafe/metric/metric_cluster.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def _collect_cell_groups(fadata, model, assign_groups, group_key):
    """Return ``{group label: set of cell ids}`` for one stored trajectory.

    ``fadata.model_name`` is switched to ``model`` before ``assign_groups`` is
    called; the caller is responsible for restoring it afterwards.
    """
    fadata.model_name = model
    # The grouping method is expected to write its labels into
    # fadata.obs[group_key], which is the column grouped on below.
    assign_groups(cluster_key=group_key)
    # Plain iteration avoids the pandas deprecation around
    # ``GroupBy.apply`` operating on the grouping column; ``observed=True``
    # skips unused categorical levels so every value is a non-empty set.
    return {
        label: set(cells.index)
        for label, cells in fadata.obs.groupby(group_key, observed=True)
    }


def calculate_mapping(
    fadata: FateAnnData,
    grouping: str = "milestones",
    simplify: bool = False,
    ref_model: str = "ref",
    pred_model: str = "default",
) -> dict:
    """
    Compute trajectory mapping metrics between two trajectories of one object.

    Both trajectories are looked up by key in
    ``fadata.uns["cafe"]["trajectory_history_dict"]``. Cells are grouped once
    per trajectory, every reference group is compared with every predicted
    group by Jaccard index, and the best matches are averaged.

    Args:
        fadata: FateAnnData holding several trajectories. ``fadata.obs`` gets
            the grouping column written to it, and ``simplify=True`` calls
            ``fadata.simplify_trajectory`` on the object itself.
        grouping: ``'milestones'`` or ``'branches'``.
        simplify: Whether to simplify both trajectory skeletons first.
        ref_model: Key of the reference trajectory in the history dict.
        pred_model: Key of the predicted trajectory in the history dict.

    Returns:
        ``{'recovery': ..., 'relevance': ..., 'F1': ...}``. ``recovery`` is the
        mean best Jaccard per reference group, ``relevance`` the mean best
        Jaccard per predicted group, and ``F1`` their harmonic mean. All three
        are ``0.0`` when either key is missing from the history dict or no
        groups were produced.

    Raises:
        ValueError: If ``grouping`` is neither ``'branches'`` nor
            ``'milestones'``.
    """
    if grouping not in ("branches", "milestones"):
        raise ValueError("grouping must be either 'branches' or 'milestones'")

    hist = fadata.uns.get("cafe", {}).get("trajectory_history_dict", {})
    # A missing trajectory is reported as a zero score rather than an error.
    if ref_model not in hist or pred_model not in hist:
        return {"recovery": 0.0, "relevance": 0.0, "F1": 0.0}

    if simplify:
        fadata.simplify_trajectory(ref_model)
        fadata.simplify_trajectory(pred_model)

    # ``grouping`` was validated above, so these two cases are exhaustive.
    if grouping == "milestones":
        group_key = "_cafe_nm_group"
        assign_groups = fadata.group_onto_nearest_milestones
    else:
        group_key = "_cafe_te_group"
        assign_groups = fadata.group_onto_trajectory_edges

    # model_name is repointed at each trajectory in turn; restore it even if
    # grouping raises so the caller's object is not left on the wrong model.
    orig_model = fadata.model_name
    try:
        groups_ref = _collect_cell_groups(fadata, ref_model, assign_groups, group_key)
        groups_pred = _collect_cell_groups(fadata, pred_model, assign_groups, group_key)
    finally:
        fadata.model_name = orig_model

    if not groups_ref or not groups_pred:
        return {"recovery": 0.0, "relevance": 0.0, "F1": 0.0}

    # Rows are reference groups, columns are predicted groups. Every group is
    # non-empty, so the union in the denominator is never zero.
    jaccard = pd.DataFrame(
        [
            [len(rcells & pcells) / len(rcells | pcells) for pcells in groups_pred.values()]
            for rcells in groups_ref.values()
        ],
        index=list(groups_ref),
        columns=list(groups_pred),
        dtype=float,
    )

    recovery = jaccard.max(axis=1).mean()
    relevance = jaccard.max(axis=0).mean()
    if (recovery + relevance) > 0:
        f1 = 2 * recovery * relevance / (recovery + relevance)
    else:
        f1 = 0.0

    return {"recovery": recovery, "relevance": relevance, "F1": f1}

calculate_mapping_branches(fadata, return_type='score', **kwargs)

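计算分支分组映射指标:return_type='score'(默认)时只返回 F1 分数;否则返回键名带 `_branches` 后缀的指标字典(recovery_branches、relevance_branches、F1_branches)。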

Source code in cafe/metric/metric_cluster.py
102
103
104
105
106
107
108
def calculate_mapping_branches(fadata: FateAnnData, return_type: str = "score", **kwargs) -> "float | dict":
    """Compute the mapping metrics with cells grouped by branches.

    Args:
        fadata: Forwarded unchanged to ``calculate_mapping``.
        return_type: ``"score"`` returns only the F1 value; any other value
            returns every metric, with ``_branches`` appended to each key.
        **kwargs: Extra keyword arguments forwarded to ``calculate_mapping``.
            ``grouping`` is fixed to ``"branches"`` here and must not be
            passed again.

    Returns:
        The F1 score, or a dict with the keys ``recovery_branches``,
        ``relevance_branches`` and ``F1_branches``.
    """
    metrics = calculate_mapping(fadata, grouping="branches", **kwargs)
    if return_type == "score":
        return metrics["F1"]
    return {f"{name}_branches": value for name, value in metrics.items()}

calculate_mapping_milestones(fadata, return_type='score', **kwargs)

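计算里程碑分组映射指标:return_type='score'(默认)时只返回 F1 分数;否则返回键名带 `_milestones` 后缀的指标字典(recovery_milestones、relevance_milestones、F1_milestones)。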

Source code in cafe/metric/metric_cluster.py
93
94
95
96
97
98
99
def calculate_mapping_milestones(fadata: FateAnnData, return_type: str = "score", **kwargs) -> "float | dict":
    """Compute the mapping metrics with cells grouped by milestones.

    Args:
        fadata: Forwarded unchanged to ``calculate_mapping``.
        return_type: ``"score"`` returns only the F1 value; any other value
            returns every metric, with ``_milestones`` appended to each key.
        **kwargs: Extra keyword arguments forwarded to ``calculate_mapping``.
            ``grouping`` is fixed to ``"milestones"`` here and must not be
            passed again.

    Returns:
        The F1 score, or a dict with the keys ``recovery_milestones``,
        ``relevance_milestones`` and ``F1_milestones``.
    """
    metrics = calculate_mapping(fadata, grouping="milestones", **kwargs)
    if return_type == "score":
        return metrics["F1"]
    return {f"{name}_milestones": value for name, value in metrics.items()}