| from utils import CustomDataset, transform, Convert_ONNX |
| from torch.utils.data import Dataset, DataLoader |
| from utils import CustomDataset, TestingDataset, transform |
| from tqdm import tqdm |
| import torch |
| import numpy as np |
| from resnet_model_mask import ResidualBlock, ResNet |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from tqdm import tqdm |
| import torch.nn.functional as F |
| from torch.optim.lr_scheduler import ReduceLROnPlateau |
| import pickle |
| import matplotlib.pyplot as plt |
| import pandas as pd |
|
|
# Seed torch's RNG so anything random below (e.g. DataLoader worker seeding) is repeatable.
torch.manual_seed(1)

# Use the first GPU when CUDA is available; DataParallel (below) wraps the model
# so it can use every visible GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(num_gpus)

test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'
test_dataset = TestingDataset(test_data_dir, transform=transform)

num_classes = 2
# Evaluation needs no shuffling: every metric computed later is order-independent,
# and a fixed order keeps the pickled per-sample records in dataset order.
testloader = DataLoader(test_dataset, batch_size=420, shuffle=False, num_workers=32)

# NOTE(review): meaning of the leading 24 and the [3, 4, 6, 3] layer counts is
# defined by resnet_model_mask.ResNet — confirm there.
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes)
# A single .to(device) after wrapping moves the wrapped module's parameters too.
model = nn.DataParallel(model).to(device)
params = sum(p.numel() for p in model.parameters())
print("num params ", params)
|
|
model_1 = 'models_mask/model-43-99.235_42.pt'

# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on GPU still loads on a CPU-only host. weights_only=True restricts
# unpickling to tensors/primitive containers.
# NOTE(review): the state dict is loaded into the DataParallel wrapper, so its
# keys are presumably 'module.'-prefixed — confirm it was saved the same way.
model.load_state_dict(torch.load(model_1, map_location=device, weights_only=True))
model = model.eval()
|
|
| |
# Run the model over the whole test set, collecting per-sample predictions and
# metadata, and compute overall accuracy.
correct_valid = 0
total = 0
results = {'output': [], 'pred': [], 'true': [], 'freq': [], 'snr': [], 'dm': [], 'boxcar': []}
# Index of each per-sample field inside the `labels` sequence yielded by the
# loader; index 0 is the class label compared against the prediction.
label_fields = {'true': 0, 'dm': 1, 'freq': 2, 'snr': 3, 'boxcar': 4}

model.eval()
with torch.no_grad():
    for images, labels in tqdm(testloader):
        inputs = images.to(device)
        outputs = model(inputs, return_mask=True)
        # Predicted class = argmax over dim 1 of the model output.
        _, predicted = torch.max(outputs, 1)
        results['output'].extend(outputs.cpu().numpy().tolist())
        results['pred'].extend(predicted.cpu().numpy().tolist())
        for key, idx in label_fields.items():
            results[key].extend(labels[idx].cpu().numpy().tolist())
        total += labels[0].size(0)
        correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()

# An empty loader would otherwise raise ZeroDivisionError; report NaN instead.
val_accuracy = correct_valid / total * 100.0 if total else float('nan')
print("===========================")
print('accuracy: ', val_accuracy)
print("===========================")
|
|
| import pickle |
|
|
| |
# Write the per-sample results dict to disk with pickle.
results_path = 'models_mask/test_42.pkl'
with open(results_path, mode='wb') as out_file:
    pickle.dump(results, out_file)
|
|
| from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix |
|
|
| |
# Binary classification metrics over every test sample (sklearn's default pos_label=1).
true = results['true']
pred = results['pred']

precision = precision_score(true, pred)
recall = recall_score(true, pred)
f1 = f1_score(true, pred)

# labels=[0, 1] forces a 2x2 matrix even when only one class occurs in both
# `true` and `pred`; without it confusion_matrix returns 1x1 and this 4-way
# unpack raises ValueError. Assumes binary labels encoded as 0/1.
tn, fp, fn, tp = confusion_matrix(true, pred, labels=[0, 1]).ravel()

# False positive rate = FP / (FP + TN); undefined (NaN) when there are no negatives.
negatives = fp + tn
fpr = fp / negatives if negatives else float('nan')

print(f"False Positive Rate: {fpr:.3f}")

print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 Score: {f1:.3f}")
|
|
| |
| |
# Tabulate the per-sample results (columns: dm, true, pred, snr, freq, boxcar).
df = pd.DataFrame({key: results[key] for key in ('dm', 'true', 'pred', 'snr', 'freq')})
# The stored boxcar value is halved before analysis.
# NOTE(review): the reason/units for the /2 are not visible here — confirm.
df['boxcar'] = np.array(results['boxcar']) / 2

# Keep only positive-class samples; all binned accuracies below are measured on them.
is_positive = df['true'] == 1
df = df.loc[is_positive].copy()
print(f"Filtered to {len(df)} samples with true label = 1")

# 20 evenly spaced edges spanning the observed DM range -> 19 DM bins.
dm_lo, dm_hi = df['dm'].min(), df['dm'].max()
dm_bins = np.linspace(dm_lo, dm_hi, 20)
df['dm_bin'] = pd.cut(df['dm'], bins=dm_bins, include_lowest=True)
print('min boxcar', df['boxcar'].min())
| |
def calculate_accuracy_with_uncertainty(group):
    """Return the accuracy of a group of predictions and its standard error.

    Parameters
    ----------
    group : pandas.DataFrame
        One row per sample, with ``'true'`` and ``'pred'`` columns.

    Returns
    -------
    pandas.Series
        ``accuracy`` (percent), ``std_error`` (percent; normal-approximation
        ``sqrt(p * (1 - p) / n)``) and ``n_samples``. For an empty group,
        ``accuracy`` and ``std_error`` are NaN instead of a 0/0 division.
    """
    total = len(group)
    if total == 0:
        # Defensive: callers group on pd.cut/pd.qcut bins, which can be empty.
        return pd.Series({'accuracy': np.nan, 'std_error': np.nan, 'n_samples': 0})

    correct = (group['true'] == group['pred']).sum()
    p = correct / total
    # Wald standard error of a proportion; note it collapses to 0 when p is 0 or 1.
    se = np.sqrt(p * (1 - p) / total) * 100
    return pd.Series({'accuracy': p * 100, 'std_error': se, 'n_samples': total})
|
|
# Accuracy per DM bin, computed on the positive-class rows held in `df`.
# observed=True skips bins with no samples (np.linspace edges need not all be
# populated) and pins the `observed` default that pandas >= 2.1 warns is
# changing. Selecting ['true', 'pred'] keeps the grouping column out of .apply,
# which pandas >= 2.2 deprecates.
dm_accuracy = (
    df.groupby('dm_bin', observed=True)[['true', 'pred']]
    .apply(calculate_accuracy_with_uncertainty)
    .reset_index()
)

# Plot each bin at its interval midpoint; cast to float so matplotlib gets
# plain numbers rather than a categorical column.
dm_accuracy['dm_midpoint'] = dm_accuracy['dm_bin'].apply(lambda x: x.mid).astype(float)

plt.figure(figsize=(10, 6))
ax1 = plt.gca()
ax1.errorbar(dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'],
             yerr=dm_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax1.set_xlabel('Dispersion Measure (DM) [pc cm$^{-3}$]', fontsize=16)
ax1.set_ylabel('Accuracy (%)', fontsize=16)
ax1.set_title('Accuracy vs Dispersion Measure', fontsize=18)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(97, 100)
ax1.tick_params(axis='both', which='major', labelsize=14)

# Drop any auto-generated tick above 100%.
yticks = ax1.get_yticks()
ax1.set_yticks([tick for tick in yticks if tick <= 100])

# Reference line at the overall test accuracy.
# NOTE(review): val_accuracy covers all test samples, while the binned curve
# uses only true == 1 rows — the two are not directly comparable.
ax1.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax1.legend(fontsize=14)

# Panel tag just outside the lower-left corner (axes coordinates).
ax1.text(-0.1, -0.15, '(a)', transform=ax1.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_dm.pdf', dpi=300, bbox_inches='tight')
plt.show()
|
|
| |
| |
# Accuracy per SNR bin, using only rows with a positive SNR value.
df_snr_filtered = df[df['snr'] > 0].copy()

# 20 evenly spaced edges across the observed SNR range -> 19 bins.
snr_bins = np.linspace(df_snr_filtered['snr'].min(), df_snr_filtered['snr'].max(), 20)
df_snr_filtered['snr_bin'] = pd.cut(df_snr_filtered['snr'], bins=snr_bins, include_lowest=True)

# observed=True skips empty bins and pins the `observed` default (FutureWarning
# in pandas >= 2.1). Selecting ['true', 'pred'] keeps the grouping column out of
# .apply (deprecated in pandas >= 2.2).
snr_accuracy = (
    df_snr_filtered.groupby('snr_bin', observed=True)[['true', 'pred']]
    .apply(calculate_accuracy_with_uncertainty)
    .reset_index()
)

# Bin midpoints as plain floats for plotting.
snr_accuracy['snr_midpoint'] = snr_accuracy['snr_bin'].apply(lambda x: x.mid).astype(float)

plt.figure(figsize=(10, 6))
ax2 = plt.gca()
ax2.errorbar(snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'],
             yerr=snr_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax2.set_xlabel('Signal-to-Noise Ratio (SNR)', fontsize=16)
ax2.set_ylabel('Accuracy (%)', fontsize=16)
ax2.set_title('Accuracy vs SNR', fontsize=18)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(80, 100)
ax2.tick_params(axis='both', which='major', labelsize=14)

# Drop any auto-generated tick above 100%.
yticks = ax2.get_yticks()
ax2.set_yticks([tick for tick in yticks if tick <= 100])

# Reference line at the overall (all-class) test accuracy.
ax2.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax2.legend(fontsize=14)

ax2.text(-0.1, -0.15, '(b)', transform=ax2.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_snr.pdf', dpi=300, bbox_inches='tight')
plt.show()
|
|
| |
| |
| |
| |
# Accuracy per boxcar-width bin, using only rows with a positive boxcar value.
df_boxcar_filtered = df[df['boxcar'] > 0].copy()
# Up to 20 quantile bins; duplicates='drop' merges repeated edges, so fewer
# bins may result when many samples share a boxcar value.
df_boxcar_filtered['boxcar_bin'] = pd.qcut(df_boxcar_filtered['boxcar'], q=20, duplicates='drop')

# observed=True skips empty bins and pins the `observed` default (FutureWarning
# in pandas >= 2.1). Selecting ['true', 'pred'] keeps the grouping column out of
# .apply (deprecated in pandas >= 2.2).
boxcar_accuracy = (
    df_boxcar_filtered.groupby('boxcar_bin', observed=True)[['true', 'pred']]
    .apply(calculate_accuracy_with_uncertainty)
    .reset_index()
)

# Bin midpoints as plain floats for plotting.
boxcar_accuracy['boxcar_midpoint'] = boxcar_accuracy['boxcar_bin'].apply(lambda x: x.mid).astype(float)

plt.figure(figsize=(10, 6))
ax3 = plt.gca()
ax3.errorbar(boxcar_accuracy['boxcar_midpoint'], boxcar_accuracy['accuracy'],
             yerr=boxcar_accuracy['std_error'], fmt='o-', color='#b80707', linewidth=2, markersize=6,
             capsize=5, capthick=2, elinewidth=1)
ax3.set_xscale('log')
ax3.set_xlabel('Boxcar Width (log scale)', fontsize=16)
# Label and title this standalone figure like the DM and SNR plots.
ax3.set_ylabel('Accuracy (%)', fontsize=16)
ax3.set_title('Accuracy vs Boxcar Width', fontsize=18)
ax3.grid(True, alpha=0.3)
ax3.set_ylim(0, 100)
ax3.tick_params(axis='both', which='major', labelsize=14)

# Drop any auto-generated tick above 100%.
yticks = ax3.get_yticks()
ax3.set_yticks([tick for tick in yticks if tick <= 100])

# Reference line at the overall (all-class) test accuracy.
ax3.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
            label=f'Overall: {val_accuracy:.2f}%')
ax3.legend(fontsize=14)

ax3.text(-0.1, -0.15, '(c)', transform=ax3.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_boxcar.pdf', dpi=300, bbox_inches='tight')
plt.show()

print(f"Plots saved to models_mask/accuracy_vs_dm.pdf, models_mask/accuracy_vs_snr.pdf, and models_mask/accuracy_vs_boxcar.pdf")
|
|
| |
# Three-panel figure: accuracy vs DM | SNR | boxcar width, side by side.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

# Marker/line/error-bar styling shared by every panel.
errorbar_style = dict(fmt='o-', color='#b80707', linewidth=2, markersize=6,
                      capsize=5, capthick=2, elinewidth=1)

# One row per panel:
# (axes, x, y, y-error, x label, y label or None, y limits, log-scale x?, panel tag)
# The boxcar panel omits its last bin (.iloc[:-1]).
# NOTE(review): the reason for dropping the last boxcar bin is not stated — confirm.
panel_specs = [
    (ax1, dm_accuracy['dm_midpoint'], dm_accuracy['accuracy'], dm_accuracy['std_error'],
     'Dispersion Measure (DM) [pc cm$^{-3}$]', 'Accuracy (%)', (97, 100.5), False, '(a)'),
    (ax2, snr_accuracy['snr_midpoint'], snr_accuracy['accuracy'], snr_accuracy['std_error'],
     'Signal-to-Noise Ratio (SNR)', None, (88, 100.5), False, '(b)'),
    (ax3, boxcar_accuracy['boxcar_midpoint'].iloc[:-1], boxcar_accuracy['accuracy'].iloc[:-1],
     boxcar_accuracy['std_error'].iloc[:-1], 'Boxcar Width (log scale) [s]', None,
     (96, 100.5), True, '(c)'),
]

for ax, xs, ys, errs, x_label, y_label, y_limits, log_x, tag in panel_specs:
    ax.errorbar(xs, ys, yerr=errs, **errorbar_style)
    if log_x:
        ax.set_xscale('log')
    ax.set_xlabel(x_label, fontsize=16)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(*y_limits)
    ax.tick_params(axis='both', which='major', labelsize=14)
    # Limits reach 100.5 for headroom; discard any tick drawn above 100%.
    ax.set_yticks([tick for tick in ax.get_yticks() if tick <= 100])
    ax.axhline(y=val_accuracy, color='r', linestyle='--', alpha=0.7,
               label=f'Overall: {val_accuracy:.2f}%')
    ax.legend(fontsize=14)
    # Panel tag just outside the lower-left corner (axes coordinates).
    ax.text(-0.1, -0.15, tag, transform=ax.transAxes, fontsize=18, fontweight='bold')

plt.tight_layout()
plt.savefig('models_mask/accuracy_vs_all_parameters.pdf',
            dpi=300, bbox_inches='tight',
            pad_inches=0.1, format='pdf')
plt.show()

print(f"Combined plot saved to models_mask/accuracy_vs_all_parameters.pdf")
|
|