# Provenance (residue from the Hugging Face file page): uploaded by moberst with
# the commit message "Upload folder using huggingface_hub" (commit 7845551, verified).
import gradio as gr
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from ppi.distributions import BinaryYFDistribution
from ppi.compute_curve import compute_curve
def convert_y_params_to_f_params(p_y, p_f_given_y1, p_f_given_y0):
    """Re-express a Y-first binary parameterisation in terms of F (Bayes' rule).

    Args:
        p_y: P(Y = 1).
        p_f_given_y1: P(F = 1 | Y = 1).
        p_f_given_y0: P(F = 1 | Y = 0).

    Returns:
        Tuple ``(p_f, p_y_given_f1, p_y_given_f0)`` with P(F = 1),
        P(Y = 1 | F = 1) and P(Y = 1 | F = 0). A conditional whose
        conditioning event has probability zero is reported as ``0``.
    """
    joint_f1_y1 = p_f_given_y1 * p_y  # P(F = 1, Y = 1)
    # Law of total probability: P(F = 1) = P(F=1|Y=1)P(Y=1) + P(F=1|Y=0)P(Y=0)
    marginal_f1 = joint_f1_y1 + p_f_given_y0 * (1 - p_y)

    # P(Y = 1 | F = 1) = P(F = 1, Y = 1) / P(F = 1); undefined when P(F = 1) = 0.
    posterior_given_f1 = 0
    if marginal_f1 > 0:
        posterior_given_f1 = joint_f1_y1 / marginal_f1

    # P(Y = 1 | F = 0) = P(F = 0, Y = 1) / P(F = 0); undefined when P(F = 0) = 0.
    posterior_given_f0 = 0
    if marginal_f1 < 1:
        joint_f0_y1 = (1 - p_f_given_y1) * p_y  # P(F = 0, Y = 1)
        posterior_given_f0 = joint_f0_y1 / (1 - marginal_f1)

    return marginal_f1, posterior_given_f1, posterior_given_f0
def plot_custom_probabilities(p_f, p_y_given_f1, p_y_given_f0, show_ppi):
    """Plot relative-variance curves of PPI++ (and optionally PPI) against n.

    The binary joint distribution of the label Y and pseudo-label F is given
    through F: the marginal P(F = 1) and the conditionals of Y given F. The
    per-n curves come from ``compute_curve`` evaluated on n = 4, 6, ..., 198.

    Args:
        p_f: P(F = 1).
        p_y_given_f1: P(Y = 1 | F = 1).
        p_y_given_f0: P(Y = 1 | F = 0).
        show_ppi: When true, also draw the ``relative_var_ppi`` curve
            (dotted); the ``relative_var_cf`` (PPI++) curve is always drawn.

    Returns:
        The matplotlib ``Figure`` holding the plot.
    """
    # Close pyplot's current figure (normally the one created by the previous
    # call) so live slider updates do not accumulate open figures.
    plt.close()
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    binary_dist = BinaryYFDistribution(
        p_f=p_f, p_y_given_f1=p_y_given_f1, p_y_given_f0=p_y_given_f0
    )
    n_range = np.arange(4, 200, 2)
    data = pd.DataFrame(compute_curve(binary_dist, n_range))
    corr = binary_dist.correlation()

    # Asymptotic relative variance: (Var(Y) - Cov(F, Y)**2 / Var(F)) / Var(Y),
    # which equals 1 - Corr(F, Y)**2.
    best_rel_perf = (
        binary_dist.variance_y()
        - binary_dist.covariance_f_y() ** 2 / binary_dist.variance_f()
    ) / binary_dist.variance_y()

    # The default view is [0.5, 2]. When the asymptote drops below 0.5
    # (|Corr(F, Y)| above about 0.71), a fixed lower bound would hide the
    # asymptote line, so lower the bound just enough to keep it in view.
    y_lower = min(0.5, best_rel_perf - 0.05)
    ax.set_ylim(y_lower, 2)

    # PPI++ (cross-fitted) curve is always shown.
    sns.lineplot(
        data=data,
        x="n",
        y="relative_var_cf",
        ax=ax,
        label="Relative Variance PPI++",
        linewidth=2,
    )
    # The plain PPI curve is optional.
    if show_ppi:
        sns.lineplot(
            data=data,
            x="n",
            y="relative_var_ppi",
            ax=ax,
            label="Relative Variance PPI",
            linewidth=2,
            linestyle="dotted",
        )

    # Reference line at y=1: the baseline (sample mean of Y).
    ax.axhline(
        y=1, color="k", linestyle="dotted", alpha=0.7, label="Baseline (Sample Mean)"
    )
    # Reference line at the asymptotic relative variance computed above.
    ax.axhline(
        y=best_rel_perf,
        color="b",
        linestyle="dotted",
        alpha=0.7,
        label="Asymptotic Performance",
    )

    # Mark where the PPI++ curve crosses the baseline. Index i is a crossing
    # when consecutive samples i and i+1 lie on different sides of 1; the
    # marker is drawn at n_values[i]. Using a strict "> 1" test (rather than
    # np.sign) avoids reporting one crossing twice when a sample equals 1.
    n_values = data["n"].to_numpy()
    cf_values = data["relative_var_cf"].to_numpy()
    above_baseline = cf_values > 1
    crossings = np.flatnonzero(above_baseline[1:] != above_baseline[:-1])
    for idx in crossings:
        n_cross = n_values[idx]
        ax.axvline(x=n_cross, color="red", linestyle="--", alpha=0.7)
        ax.text(
            n_cross,
            0.03,
            f"n={n_cross}",
            # x in data units, y as a fraction of the axes height, so the
            # label stays just above the bottom edge whatever the y-limits are.
            transform=ax.get_xaxis_transform(),
            rotation=0,
            ha="center",
            va="bottom",
            color="red",
        )

    ax.set_xlabel("Sample Size (n)")
    ax.set_ylabel("Relative Variance")
    ax.set_title(
        f"Relative Variance Analysis (Correlation of Pseudo-Label: {corr:.2f})"
    )
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig
def plot_y_based_probabilities(p_y, p_f_given_y1, p_f_given_y0, show_ppi):
    """Plot the relative-variance curves for a Y-first parameterisation.

    The inputs P(Y = 1), P(F = 1 | Y = 1), P(F = 1 | Y = 0) are translated by
    ``convert_y_params_to_f_params`` into P(F = 1), P(Y = 1 | F = 1),
    P(Y = 1 | F = 0). These are passed, together with ``show_ppi``, to
    ``plot_custom_probabilities``, whose figure is returned unchanged.
    """
    f_params = convert_y_params_to_f_params(p_y, p_f_given_y1, p_f_given_y0)
    return plot_custom_probabilities(*f_params, show_ppi)
# Tab 1: the joint distribution is specified through the pseudo-label F,
# i.e. P(F = 1) plus the conditional distribution of Y given F.
f_based_interface = gr.Interface(
    fn=plot_custom_probabilities,
    inputs=[
        gr.Slider(0.05, 0.95, value=0.50, step=0.05, label="P(F = 1)"),
        gr.Slider(0.05, 0.95, value=0.60, step=0.05, label="P(Y = 1 | F = 1)"),
        gr.Slider(0.05, 0.95, value=0.40, step=0.05, label="P(Y = 1 | F = 0)"),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of pseudo-label prevalence and Y | F distribution",
    description="""
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.
**Inputs:**
- P(F = 1): Prior probability of the binary pseudo-label
- P(Y = 1 | F = 1): Conditional probability of label given pseudo-label = 1
- P(Y = 1 | F = 0): Conditional probability of label given pseudo-label = 0
""",
    # Re-render on every input change.
    live=True,
    flagging_mode="never",
)
# Tab 2: the joint distribution is specified through the label Y,
# i.e. P(Y = 1) plus the conditional distribution of F given Y; the
# callback converts these to the F-based form before plotting.
y_based_interface = gr.Interface(
    fn=plot_y_based_probabilities,
    inputs=[
        gr.Slider(0.05, 0.95, value=0.1, step=0.05, label="P(Y = 1)"),
        gr.Slider(0.05, 0.95, value=0.9, step=0.05, label="P(F = 1 | Y = 1)"),
        gr.Slider(0.05, 0.95, value=0.1, step=0.05, label="P(F = 1 | Y = 0)"),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of prevalence and F | Y distribution",
    description="""
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.
**Inputs:**
- P(Y = 1): Prior probability of the binary label
- P(F = 1 | Y = 1): Conditional probability of pseudo-label given label = 1
- P(F = 1 | Y = 0): Conditional probability of pseudo-label given label = 0
""",
    # Re-render on every input change.
    live=True,
    flagging_mode="never",
)
# Combine both parameterisations into one app, one tab per interface.
_tab_interfaces = [f_based_interface, y_based_interface]
_tab_names = ["F-based Parameters", "Y-based Parameters"]
demo = gr.TabbedInterface(
    _tab_interfaces,
    _tab_names,
    title="Sample Size Analysis (Binary Case)",
)

if __name__ == "__main__":
    # Only launch when executed as a script, not when imported.
    demo.launch(share=True)