"""Plotting helpers for relative-variance curves computed by ``ppi.compute_curve``."""

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

from ppi.distributions import BinaryYFDistribution
from ppi.compute_curve import compute_curve


def _binary_dist(eps):
    """Build the binary Y/F distribution shared by both plots.

    Uses ``p_f=0.5`` and conditional probabilities ``0.5 + eps`` / ``0.5 - eps``.
    """
    return BinaryYFDistribution(
        p_f=0.5, p_y_given_f1=0.5 + eps, p_y_given_f0=0.5 - eps
    )


def _curve_frame(eps, n_range):
    """Return ``compute_curve`` output for ``eps`` over ``n_range`` as a DataFrame."""
    return pd.DataFrame(compute_curve(_binary_dist(eps), n_range))


def _show_message(ax, message):
    """Draw ``message`` centred in ``ax`` (axes coordinates) instead of a plot."""
    ax.text(
        0.5,
        0.5,
        message,
        ha="center",
        va="center",
        transform=ax.transAxes,
    )


def plot_curves(eps):
    """Plot PPI and PPI++ relative-variance curves for a single ``eps``.

    Args:
        eps: Offset applied around 0.5 to the two conditional probabilities
            of the binary distribution.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    data = _curve_frame(eps, np.arange(4, 200, 2))

    sns.lineplot(
        data, x="n", y="relative_var_cf", ax=ax, label="Relative Variance PPI++"
    )
    sns.lineplot(
        data, x="n", y="relative_var_ppi", ax=ax, label="Relative Variance PPI"
    )
    # Use the Axes API: plt.legend() targets pyplot's *current* axes, which is
    # not guaranteed to be the ``ax`` created above.
    ax.legend()
    # Reference line at relative variance 1.
    ax.axhline(y=1, color="k", linestyle="dotted")
    ax.set_title(f"Relative Variance Curve (ε = {eps})")
    return fig


def plot_multiple_curves(eps_values, n_min, n_max, n_step):
    """Overlay PPI++ relative-variance curves for several ``eps`` values.

    Invalid input never raises. A figure carrying an explanatory message is
    returned instead.

    Args:
        eps_values: Comma-separated string of epsilon values, e.g. ``"0.1, 0.2"``.
            Values outside ``[0, 0.5]`` are skipped.
        n_min: First sample size of the range.
        n_max: Last sample size of the range (inclusive when reachable by the step).
        n_step: Step between consecutive sample sizes; must be non-zero.

    Returns:
        The matplotlib ``Figure`` containing the plot or the error message.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    # Parse epsilon values from string input.
    try:
        eps_list = [float(x.strip()) for x in eps_values.split(",") if x.strip()]
    except ValueError:
        _show_message(
            ax, "Invalid epsilon values. Please enter comma-separated numbers."
        )
        return fig

    if not eps_list:
        _show_message(ax, "Please enter at least one epsilon value.")
        return fig

    # np.arange raises ZeroDivisionError on a zero step; report it instead.
    if n_step == 0:
        _show_message(ax, "Step size must be non-zero.")
        return fig

    # +1 so that n_max itself is included when the step lands on it.
    n_range = np.arange(n_min, n_max + 1, n_step)
    if n_range.size == 0:
        _show_message(
            ax, "The sample-size range is empty; check n_min, n_max and step."
        )
        return fig

    # Only eps in [0, 0.5] is accepted (this keeps 0.5 +/- eps within [0, 1]).
    # Out-of-range values are skipped. If nothing remains, say so rather
    # than returning a blank plot with an empty legend.
    valid_eps = [eps for eps in eps_list if 0 <= eps <= 0.5]
    if not valid_eps:
        _show_message(ax, "No valid epsilon values: each must lie in [0, 0.5].")
        return fig

    for eps in valid_eps:
        data = _curve_frame(eps, n_range)
        sns.lineplot(data, x="n", y="relative_var_cf", ax=ax, label=f"ε = {eps}")

    ax.axhline(y=1, color="k", linestyle="dotted", alpha=0.7)
    ax.set_xlabel("Sample Size (n)")
    ax.set_ylabel("Relative Variance")
    ax.set_title("Comparison of Relative Variance Curves")
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig