import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from ppi.distributions import BinaryYFDistribution
from ppi.compute_curve import compute_curve


def find_crossing_point(x_data, y_data):
    """Find the x-coordinate where the curve first crosses y=1, if it exists.

    Consecutive samples are scanned in order; the first pair that lies on
    opposite sides of y=1 (one value >= 1, the other < 1) is linearly
    interpolated to estimate the crossing location.

    Args:
        x_data: Indexable sequence of x-coordinates.
        y_data: Sliceable sequence of y-values, paired by position with
            ``x_data``.

    Returns:
        The interpolated x-coordinate of the first crossing, or ``None`` when
        no adjacent pair straddles y=1 (including when fewer than two samples
        are given).
    """
    for i, (y1, y2) in enumerate(zip(y_data, y_data[1:])):
        # Both directions are spelled out (instead of an XOR on ``>= 1``) so
        # that a NaN sample, for which every comparison is False, is never
        # reported as a crossing.
        if (y1 >= 1 and y2 < 1) or (y1 < 1 and y2 >= 1):
            # Linear interpolation between the two bracketing samples.
            # y1 != y2 is guaranteed here because they are on opposite sides
            # of 1, so the division is safe.
            x1, x2 = x_data[i], x_data[i + 1]
            return x1 + (1 - y1) * (x2 - x1) / (y2 - y1)
    return None


def create_sequential_plots():
    """Render a cumulative series of relative-variance plots as PNG files.

    One curve is computed per value in ``eps_values`` via ``compute_curve``
    on a ``BinaryYFDistribution``. Plot ``k`` (1-based) shows the curves for
    the first ``k`` eps values, a dashed horizontal reference line at y=1 and,
    for every curve that crosses y=1, a dotted vertical line at the crossing.
    Each figure is written to ``sequential_plot_<k>.png`` (relative path) and
    a confirmation line is printed.
    """
    eps_values = [0, 0.1, 0.2, 0.3]
    n_range = np.arange(4, 61, 2)

    # Generate the curve and the distribution's correlation for every eps.
    all_data = {}
    corr = {}
    for eps in eps_values:
        binary_dist = BinaryYFDistribution(
            p_f=0.5, p_y_given_f1=0.5 + eps, p_y_given_f0=0.5 - eps
        )
        all_data[eps] = pd.DataFrame(compute_curve(binary_dist, n_range))
        corr[eps] = binary_dist.correlation()

    # One figure per prefix of eps_values; the bound is derived from the list
    # so that editing eps_values cannot desynchronise the two.
    for plot_num in range(1, len(eps_values) + 1):
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            ax.set_ylim(0.6, 2)

            # Plot lines for eps values up to the current plot.
            for eps in eps_values[:plot_num]:
                data = all_data[eps]

                line = ax.plot(
                    data["n"],
                    data["relative_var_cf"],
                    label=f"ε = {eps}, $\\rho = ${corr[eps]:.1f}",
                )
                # Reuse the curve's colour for its crossing marker.
                line_color = line[0].get_color()

                # Mark where this curve crosses y=1, if it does.
                crossing_x = find_crossing_point(
                    data["n"].values, data["relative_var_cf"].values
                )
                if crossing_x is not None:
                    ax.axvline(
                        x=crossing_x, color=line_color, linestyle=":", alpha=0.7
                    )

            # Format the plot.
            ax.axhline(y=1, color="k", linestyle="--", alpha=0.7)
            ax.set_xlabel("Sample Size (n)")
            ax.set_ylabel("Relative Variance")
            ax.set_title("Comparison of Relative Variance Curves")
            ax.legend(loc="upper right")
            ax.grid(True, alpha=0.3)

            # Save this stage as a PNG image.
            filename = f"sequential_plot_{plot_num}.png"
            fig.savefig(filename, bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        print(f"Saved {filename}")


if __name__ == "__main__":
    create_sequential_plots()