# Captured from a Hugging Face Spaces file view (Space status at capture: "Runtime error").
# File size: 5,748 bytes; id shown on the page (presumably the commit): 7845551.
import gradio as gr
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from ppi.distributions import BinaryYFDistribution
from ppi.compute_curve import compute_curve
def convert_y_params_to_f_params(p_y, p_f_given_y1, p_f_given_y0):
    """Re-express a binary (Y, F) model given as P(Y), P(F | Y) as P(F), P(Y | F).

    Args:
        p_y: P(Y = 1).
        p_f_given_y1: P(F = 1 | Y = 1).
        p_f_given_y0: P(F = 1 | Y = 0).

    Returns:
        Tuple ``(p_f, p_y_given_f1, p_y_given_f0)`` holding P(F = 1),
        P(Y = 1 | F = 1) and P(Y = 1 | F = 0). A conditional is reported
        as 0 when its conditioning event is not possible (P(F = 1) not
        above 0, or not below 1, respectively).
    """
    # Joint masses of the two cells with Y = 1.
    joint_y1_f1 = p_f_given_y1 * p_y
    joint_y1_f0 = (1 - p_f_given_y1) * p_y

    # Law of total probability: P(F = 1) = P(F = 1, Y = 1) + P(F = 1, Y = 0).
    p_f = joint_y1_f1 + p_f_given_y0 * (1 - p_y)

    # Bayes' rule: P(Y = 1 | F = 1) = P(F = 1, Y = 1) / P(F = 1).
    if p_f > 0:
        p_y_given_f1 = joint_y1_f1 / p_f
    else:
        p_y_given_f1 = 0

    # Bayes' rule: P(Y = 1 | F = 0) = P(F = 0, Y = 1) / P(F = 0).
    if p_f < 1:
        p_y_given_f0 = joint_y1_f0 / (1 - p_f)
    else:
        p_y_given_f0 = 0

    return p_f, p_y_given_f1, p_y_given_f0
def plot_custom_probabilities(p_f, p_y_given_f1, p_y_given_f0, show_ppi):
    """Plot relative-variance curves for a binary (Y, F) distribution.

    Args:
        p_f: P(F = 1).
        p_y_given_f1: P(Y = 1 | F = 1).
        p_y_given_f0: P(Y = 1 | F = 0).
        show_ppi: When truthy, the plain PPI curve is drawn in addition to
            the PPI++ curve (which is always drawn).

    Returns:
        The matplotlib figure containing the plot.
    """
    # Drop pyplot's current figure before creating the one for this call.
    plt.close()
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    binary_dist = BinaryYFDistribution(
        p_f=p_f,
        p_y_given_f1=p_y_given_f1,
        p_y_given_f0=p_y_given_f0,
    )

    # Even sample sizes 4, 6, ..., 198.
    sample_sizes = np.arange(4, 200, 2)
    # The frame is read below through the columns "n", "relative_var_cf"
    # and "relative_var_ppi".
    data = pd.DataFrame(compute_curve(binary_dist, sample_sizes))
    corr = binary_dist.correlation()

    ax.set_ylim(0.5, 2)

    # (Var(Y) - Cov(F, Y)^2 / Var(F)) / Var(Y), drawn below as the
    # "Asymptotic Performance" reference level.
    residual_variance = (
        binary_dist.variance_y()
        - binary_dist.covariance_f_y() ** 2 / binary_dist.variance_f()
    )
    best_rel_perf = residual_variance / binary_dist.variance_y()

    # Settings shared by both curves.
    curve_kwargs = dict(x="n", ax=ax, linewidth=2)

    # The PPI++ curve is always drawn.
    sns.lineplot(
        data,
        y="relative_var_cf",
        label="Relative Variance PPI++",
        **curve_kwargs,
    )

    # The plain PPI curve is optional.
    if show_ppi:
        sns.lineplot(
            data,
            y="relative_var_ppi",
            label="Relative Variance PPI",
            linestyle="dotted",
            **curve_kwargs,
        )

    # Reference level y = 1 (labelled as the sample-mean baseline).
    ax.axhline(
        y=1,
        color="k",
        linestyle="dotted",
        alpha=0.7,
        label="Baseline (Sample Mean)",
    )

    # Reference level at the asymptotic relative variance computed above.
    ax.axhline(
        y=best_rel_perf,
        color="b",
        linestyle="dotted",
        alpha=0.7,
        label="Asymptotic Performance",
    )

    # Mark the sample sizes where the PPI++ curve changes side relative to 1.
    cf_data = data["relative_var_cf"].values
    n_data = data["n"].values

    # A non-zero difference of consecutive signs means the two neighbouring
    # samples lie on different sides of 1; the index is that of the first one.
    cf_crossings = np.where(np.diff(np.sign(cf_data - 1)))[0]
    last_allowed = len(n_data) - 1
    for idx in cf_crossings:
        # np.diff yields one entry fewer than its input, so this bound is
        # expected to hold for every index; kept as a defensive check.
        if not idx < last_allowed:
            continue
        n_at_crossing = n_data[idx]
        ax.axvline(x=n_at_crossing, color="red", linestyle="--", alpha=0.7)
        ax.text(
            n_at_crossing,
            0.55,
            f"n={n_at_crossing}",
            rotation=0,
            ha="center",
            va="bottom",
            color="red",
        )

    ax.set_xlabel("Sample Size (n)")
    ax.set_ylabel("Relative Variance")
    # NOTE(review): "Psuedo" in the title below is a typo for "Pseudo".
    ax.set_title(
        f"Relative Variance Analysis (Correlation of Psuedo-Label: {corr:.2f})"
    )
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig
def plot_y_based_probabilities(p_y, p_f_given_y1, p_f_given_y0, show_ppi):
    """Plot the relative-variance curves from Y-based parameters.

    Args:
        p_y: P(Y = 1).
        p_f_given_y1: P(F = 1 | Y = 1).
        p_f_given_y0: P(F = 1 | Y = 0).
        show_ppi: Forwarded unchanged to ``plot_custom_probabilities``.

    Returns:
        Whatever ``plot_custom_probabilities`` returns for the converted
        parameters.
    """
    # Translate to (P(F = 1), P(Y = 1 | F = 1), P(Y = 1 | F = 0)) and reuse
    # the F-based plotting routine.
    f_params = convert_y_params_to_f_params(p_y, p_f_given_y1, p_f_given_y0)
    return plot_custom_probabilities(*f_params, show_ppi)
# Tab 1: the distribution is specified through P(F) and P(Y | F), which are
# passed straight to plot_custom_probabilities.
# NOTE(review): "psuedo-label" in the title below is a typo for "pseudo-label".
f_based_interface = gr.Interface(
    fn=plot_custom_probabilities,
    inputs=[
        # Three probability sliders sharing the same range and step.
        *(
            gr.Slider(
                minimum=0.05,
                maximum=0.95,
                value=default,
                step=0.05,
                label=text,
            )
            for text, default in (
                ("P(F = 1)", 0.50),
                ("P(Y = 1 | F = 1)", 0.60),
                ("P(Y = 1 | F = 0)", 0.40),
            )
        ),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of psuedo-label prevalence and Y | F distribution",
    description="""
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.
**Inputs:**
- P(F = 1): Prior probability of the binary pseudo-label
- P(Y = 1 | F = 1): Conditional probability of label given pseudo-label = 1
- P(Y = 1 | F = 0): Conditional probability of label given pseudo-label = 0
""",
    live=True,
    flagging_mode="never",
)
# Tab 2: the distribution is specified through P(Y) and P(F | Y);
# plot_y_based_probabilities converts these before plotting.
# NOTE(review): "prevalance" in the title below is a typo for "prevalence".
y_based_interface = gr.Interface(
    fn=plot_y_based_probabilities,
    inputs=[
        # Three probability sliders sharing the same range and step.
        *(
            gr.Slider(
                minimum=0.05,
                maximum=0.95,
                value=default,
                step=0.05,
                label=text,
            )
            for text, default in (
                ("P(Y = 1)", 0.1),
                ("P(F = 1 | Y = 1)", 0.9),
                ("P(F = 1 | Y = 0)", 0.1),
            )
        ),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of prevalance and F | Y distribution",
    description="""
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.
**Inputs:**
- P(Y = 1): Prior probability of the binary label
- P(F = 1 | Y = 1): Conditional probability of pseudo-label given label = 1
- P(F = 1 | Y = 0): Conditional probability of pseudo-label given label = 0
""",
    live=True,
    flagging_mode="never",
)
# Top-level app: one tab per parameterization, in the order listed.
demo = gr.TabbedInterface(
    interface_list=[f_based_interface, y_based_interface],
    tab_names=["F-based Parameters", "Y-based Parameters"],
    title="Sample Size Analysis (Binary Case)",
)

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)
|