Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from ppi.distributions import BinaryYFDistribution | |
| from ppi.compute_curve import compute_curve | |
def convert_y_params_to_f_params(p_y, p_f_given_y1, p_f_given_y0):
    """Re-express a binary (Y, F) distribution in F-conditioned form.

    Args:
        p_y: P(Y = 1).
        p_f_given_y1: P(F = 1 | Y = 1).
        p_f_given_y0: P(F = 1 | Y = 0).

    Returns:
        Tuple ``(P(F = 1), P(Y = 1 | F = 1), P(Y = 1 | F = 0))``. A conditional
        whose conditioning event has no mass (``P(F = 1)`` not > 0, or not < 1)
        is reported as ``0``.
    """
    # Law of total probability over Y.
    p_f = p_f_given_y1 * p_y + p_f_given_y0 * (1 - p_y)

    # Bayes' rule: joint mass P(Y = 1, F = f) divided by the marginal P(F = f).
    joint_y1_f1 = p_f_given_y1 * p_y
    joint_y1_f0 = (1 - p_f_given_y1) * p_y
    p_y_given_f1 = joint_y1_f1 / p_f if p_f > 0 else 0
    p_y_given_f0 = joint_y1_f0 / (1 - p_f) if p_f < 1 else 0

    return p_f, p_y_given_f1, p_y_given_f0
def _threshold_crossings(values, threshold):
    """Return indices ``i`` where the curve crosses ``threshold`` right after point ``i``.

    NaN entries are skipped. A value exactly equal to ``threshold`` counts as
    "not above", so a curve that touches the threshold on its way across
    yields one crossing rather than two.
    """
    values = np.asarray(values, dtype=float)
    valid_idx = np.flatnonzero(~np.isnan(values))
    above = values[valid_idx] > threshold
    # A flip of the above/not-above flag between consecutive valid points is a
    # crossing; report the index of the point just before the flip.
    flips = np.flatnonzero(above[:-1] != above[1:])
    return valid_idx[flips]


def plot_custom_probabilities(p_f, p_y_given_f1, p_y_given_f0, show_ppi):
    """Plot relative-variance curves for a binary (Y, F) distribution.

    Builds a ``BinaryYFDistribution`` from the parameters, evaluates
    ``compute_curve`` for even sample sizes n = 4, 6, ..., 198, and draws:

    * the PPI++ relative-variance curve (column ``relative_var_cf``),
    * the PPI relative-variance curve (column ``relative_var_ppi``), only if
      ``show_ppi`` is true,
    * a baseline at y = 1 labelled as the sample mean,
    * a reference line at ``(Var(Y) - Cov(F, Y)^2 / Var(F)) / Var(Y)``,
    * a red vertical marker where the PPI++ curve crosses y = 1.

    Args:
        p_f: P(F = 1).
        p_y_given_f1: P(Y = 1 | F = 1).
        p_y_given_f0: P(Y = 1 | F = 0).
        show_ppi: Whether to also draw the plain-PPI curve.

    Returns:
        The matplotlib figure containing the plot.
    """
    # plt.close() closes the current pyplot figure (typically the one made by
    # the previous call) so repeated live updates do not pile up open figures.
    plt.close()
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    binary_dist = BinaryYFDistribution(
        p_f=p_f, p_y_given_f1=p_y_given_f1, p_y_given_f0=p_y_given_f0
    )
    n_range = np.arange(4, 200, 2)
    data = pd.DataFrame(compute_curve(binary_dist, n_range))
    corr = binary_dist.correlation()

    ax.set_ylim(0.5, 2)

    # Share of Var(Y) left after removing the part linearly explained by F;
    # drawn as the "Asymptotic Performance" line. Divides by Var(F) and Var(Y),
    # so both must be non-zero.
    best_rel_perf = (
        binary_dist.variance_y()
        - binary_dist.covariance_f_y() ** 2 / binary_dist.variance_f()
    ) / binary_dist.variance_y()

    # PPI++ is always shown.
    sns.lineplot(
        data,
        x="n",
        y="relative_var_cf",
        ax=ax,
        label="Relative Variance PPI++",
        linewidth=2,
    )
    # Plain PPI only on request.
    if show_ppi:
        sns.lineplot(
            data,
            x="n",
            y="relative_var_ppi",
            ax=ax,
            label="Relative Variance PPI",
            linewidth=2,
            linestyle="dotted",
        )

    # Baseline at relative variance 1 (the sample mean of Y).
    ax.axhline(
        y=1, color="k", linestyle="dotted", alpha=0.7, label="Baseline (Sample Mean)"
    )
    # Asymptotic reference level computed above.
    ax.axhline(
        y=best_rel_perf,
        color="b",
        linestyle="dotted",
        alpha=0.7,
        label="Asymptotic Performance",
    )

    # Mark where the PPI++ curve crosses the baseline, labelling the last
    # sample size before each crossing.
    cf_data = data["relative_var_cf"].to_numpy(dtype=float)
    n_data = data["n"].to_numpy()
    for idx in _threshold_crossings(cf_data, 1.0):
        n_cross = n_data[idx]
        ax.axvline(x=n_cross, color="red", linestyle="--", alpha=0.7)
        ax.text(
            n_cross,
            0.55,
            f"n={n_cross}",
            rotation=0,
            ha="center",
            va="bottom",
            color="red",
        )

    ax.set_xlabel("Sample Size (n)")
    ax.set_ylabel("Relative Variance")
    ax.set_title(
        f"Relative Variance Analysis (Correlation of Pseudo-Label: {corr:.2f})"
    )
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig
def plot_y_based_probabilities(p_y, p_f_given_y1, p_f_given_y0, show_ppi):
    """Plot relative-variance curves from Y-conditioned parameters.

    Maps (P(Y = 1), P(F = 1 | Y = 1), P(F = 1 | Y = 0)) onto the F-conditioned
    parameterization via ``convert_y_params_to_f_params`` and delegates the
    drawing to ``plot_custom_probabilities``.

    Returns:
        The figure produced by ``plot_custom_probabilities``.
    """
    f_prob, y_given_f1, y_given_f0 = convert_y_params_to_f_params(
        p_y, p_f_given_y1, p_f_given_y0
    )
    return plot_custom_probabilities(
        p_f=f_prob,
        p_y_given_f1=y_given_f1,
        p_y_given_f0=y_given_f0,
        show_ppi=show_ppi,
    )
# Interface for the F-based parameterization: P(F = 1), P(Y = 1 | F = 1) and
# P(Y = 1 | F = 0). Slider bounds of 0.05-0.95 keep P(F = 1) strictly inside
# (0, 1), so Var(F) > 0 in the plotting code's division.
f_based_interface = gr.Interface(
    fn=plot_custom_probabilities,
    inputs=[
        gr.Slider(0.05, 0.95, value=0.50, step=0.05, label="P(F = 1)"),
        gr.Slider(0.05, 0.95, value=0.60, step=0.05, label="P(Y = 1 | F = 1)"),
        gr.Slider(0.05, 0.95, value=0.40, step=0.05, label="P(Y = 1 | F = 0)"),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of pseudo-label prevalence and Y | F distribution",
    # Blank line keeps "**Inputs:**" from being merged into the previous
    # Markdown paragraph.
    description="""
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.

**Inputs:**
- P(F = 1): Prior probability of the binary pseudo-label
- P(Y = 1 | F = 1): Conditional probability of label given pseudo-label = 1
- P(Y = 1 | F = 0): Conditional probability of label given pseudo-label = 0
""",
    live=True,
    flagging_mode="never",
)
# Interface for the Y-based parameterization: P(Y = 1), P(F = 1 | Y = 1) and
# P(F = 1 | Y = 0). These are converted to F-based parameters before plotting.
# With every slider in 0.05-0.95, the derived P(F = 1) is a mixture of values
# in that range and therefore also stays strictly inside (0, 1).
y_based_interface = gr.Interface(
    fn=plot_y_based_probabilities,
    inputs=[
        gr.Slider(0.05, 0.95, value=0.1, step=0.05, label="P(Y = 1)"),
        gr.Slider(0.05, 0.95, value=0.9, step=0.05, label="P(F = 1 | Y = 1)"),
        gr.Slider(0.05, 0.95, value=0.1, step=0.05, label="P(F = 1 | Y = 0)"),
        gr.Checkbox(value=True, label="Show PPI curve"),
    ],
    outputs=gr.Plot(label="PPI++ Analysis", format="png"),
    title="Example: Specify in terms of prevalence and F | Y distribution",
    # Blank line keeps "**Inputs:**" from being merged into the previous
    # Markdown paragraph.
    description="""
Analyze relative variance curves of PPI and PPI++ (with Cross-Fitting), as compared to using the empirical mean of Y.

**Inputs:**
- P(Y = 1): Prior probability of the binary label
- P(F = 1 | Y = 1): Conditional probability of pseudo-label given label = 1
- P(F = 1 | Y = 0): Conditional probability of pseudo-label given label = 0
""",
    live=True,
    flagging_mode="never",
)
# Top-level app: one tab per parameterization of the binary (Y, F) model.
demo = gr.TabbedInterface(
    [f_based_interface, y_based_interface],
    tab_names=["F-based Parameters", "Y-based Parameters"],
    title="Sample Size Analysis (Binary Case)",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    demo.launch(share=True)