"""BirdNET Audio Classification Script

This script loads a WAV file and uses the BirdNET ONNX model to predict bird species.
The model expects audio input of shape [batch_size, 144000] (3 seconds at 48kHz).

Created using Copilot.
"""
| |
|
| | from __future__ import annotations |
| |
|
| | import numpy as np |
| | import librosa |
| | import onnxruntime as ort |
| | import argparse |
| | import os |
| | from collections import defaultdict |
| |
|
| |
|
def load_audio(
    file_path: str, target_sr: int = 48000, duration: float = 3.0
) -> np.ndarray:
    """
    Load and preprocess audio file for BirdNET model.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)
        duration (float): Duration in seconds (3.0 for BirdNET)

    Returns:
        np.ndarray: float32 audio of exactly ``int(target_sr * duration)``
        samples (144000 by default), zero-padded or truncated as needed.

    Raises:
        RuntimeError: If the file cannot be loaded or decoded.
    """
    try:
        # librosa resamples to target_sr and mixes down to mono by default.
        audio, _ = librosa.load(file_path, sr=target_sr, duration=duration)

        target_length = int(target_sr * duration)

        if len(audio) < target_length:
            # Zero-pad short clips at the end up to the expected length.
            audio = np.pad(audio, (0, target_length - len(audio)))
        elif len(audio) > target_length:
            # Truncate anything past the expected length.
            audio = audio[:target_length]

        return audio.astype(np.float32)

    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error loading audio file {file_path}: {str(e)}") from e
| |
|
| |
|
def load_labels(labels_path: str) -> list[str]:
    """
    Load BirdNET species labels from the labels file.

    Each non-empty line is expected in ``Scientific name_Common name`` form;
    only the common name (text after the first underscore) is kept. Lines
    without an underscore are kept verbatim. Blank lines are skipped.

    Args:
        labels_path (str): Path to the labels file

    Returns:
        list[str]: List of species names

    Raises:
        RuntimeError: If the labels file cannot be read.
    """
    try:
        labels = []
        with open(labels_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if "_" in line:
                    # Keep only the common name after the first underscore.
                    labels.append(line.split("_", 1)[1])
                else:
                    labels.append(line)
        return labels
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error loading labels file {labels_path}: {str(e)}") from e
| |
|
| |
|
def load_audio_full(file_path: str, target_sr: int = 48000) -> np.ndarray:
    """
    Load full audio file for moving window analysis.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)

    Returns:
        np.ndarray: Full float32 audio array, resampled to ``target_sr``.

    Raises:
        RuntimeError: If the file cannot be loaded or decoded.
    """
    try:
        # librosa resamples to target_sr and mixes down to mono by default.
        audio, _ = librosa.load(file_path, sr=target_sr)
        return audio.astype(np.float32)
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error loading audio file {file_path}: {str(e)}") from e
| |
|
| |
|
def create_audio_windows(
    audio: np.ndarray,
    window_size: int = 144000,
    overlap: float = 0.5,
    sample_rate: int = 48000,
) -> tuple[np.ndarray, list[float]]:
    """
    Create overlapping windows from audio for analysis.

    Args:
        audio (np.ndarray): Full audio array
        window_size (int): Size of each window (144000 for 3 seconds at 48kHz)
        overlap (float): Overlap ratio (0.5 = 50% overlap)
        sample_rate (int): Sample rate used to convert window starts to
            seconds (48000 for BirdNET; previously hard-coded)

    Returns:
        tuple[np.ndarray, list[float]]: (windows array, timestamps in seconds)
    """
    # Guard against overlap ~= 1.0: a zero step would make range() raise
    # ValueError. Always advance by at least one sample.
    step_size = max(1, int(window_size * (1 - overlap)))
    windows = []
    timestamps = []

    # The range bound guarantees each slice is exactly window_size samples,
    # so no per-window length check is needed.
    for start in range(0, len(audio) - window_size + 1, step_size):
        windows.append(audio[start : start + window_size])
        # Timestamp of the window start, in seconds.
        timestamps.append(start / float(sample_rate))

    return np.array(windows), timestamps
| |
|
| |
|
def load_onnx_model(model_path: str) -> ort.InferenceSession:
    """
    Load ONNX model for inference.

    Args:
        model_path (str): Path to the ONNX model file

    Returns:
        ort.InferenceSession: Loaded ONNX model session

    Raises:
        RuntimeError: If the model cannot be loaded.
    """
    try:
        # Default providers (CPU unless GPU providers are installed).
        return ort.InferenceSession(model_path)
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error loading ONNX model {model_path}: {str(e)}") from e
| |
|
| |
|
def predict_audio(session: ort.InferenceSession, audio_data: np.ndarray) -> np.ndarray:
    """
    Run inference on audio data using the ONNX model.

    Args:
        session (ort.InferenceSession): ONNX model session
        audio_data (np.ndarray): Audio data of shape [144000] or [batch, 144000]

    Returns:
        np.ndarray: Model predictions (first model output)

    Raises:
        RuntimeError: If inference fails.
    """
    try:
        # The model expects a batch dimension; add one for single clips.
        if audio_data.ndim == 1:
            input_data = np.expand_dims(audio_data, axis=0)
        else:
            input_data = audio_data

        input_name = session.get_inputs()[0].name

        # run(None, ...) returns all model outputs; BirdNET's scores are first.
        outputs = session.run(None, {input_name: input_data})
        return outputs[0]

    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error during model inference: {str(e)}") from e
| |
|
| |
|
def predict_audio_batch(
    session: ort.InferenceSession,
    windows_batch: np.ndarray,
    batch_size: int = 128,
    show_progress: bool = True,
) -> np.ndarray:
    """
    Run inference on batches of audio windows for better performance.

    Args:
        session (ort.InferenceSession): ONNX model session
        windows_batch (np.ndarray): Array of windows, shape [num_windows, 144000]
        batch_size (int): Number of windows to process in each batch
        show_progress (bool): Whether to show progress updates

    Returns:
        np.ndarray: All predictions concatenated, shape [num_windows, num_classes]

    Raises:
        RuntimeError: If inference fails on any batch.
    """
    try:
        all_predictions = []
        num_windows = len(windows_batch)

        # Resolve the input name once; it is invariant across batches.
        input_name = session.get_inputs()[0].name

        for batch_num, start_idx in enumerate(range(0, num_windows, batch_size), start=1):
            end_idx = min(start_idx + batch_size, num_windows)
            current_batch = windows_batch[start_idx:end_idx]

            # Print the first batch and then every fifth batch to keep
            # console output manageable on long recordings.
            if show_progress and (batch_num % 5 == 0 or batch_num == 1):
                progress = (end_idx / num_windows) * 100
                print(
                    f"  Batch {batch_num}: processing windows {start_idx + 1}-{end_idx} ({progress:.1f}%)"
                )

            outputs = session.run(None, {input_name: current_batch})
            all_predictions.append(outputs[0])

        # Stack per-batch score arrays back into one [num_windows, num_classes].
        return np.concatenate(all_predictions, axis=0)

    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error during batch model inference: {str(e)}") from e
| |
|
| |
|
| | def analyze_detections( |
| | all_predictions: np.ndarray, |
| | timestamps: list[float], |
| | labels: list[str], |
| | confidence_threshold: float = 0.1, |
| | ) -> dict[str, list[dict[str, float | int]]]: |
| | """ |
| | Analyze predictions across all windows and summarize detections. |
| | |
| | Args: |
| | all_predictions (np.ndarray): Predictions from all windows, shape [num_windows, num_classes] |
| | timestamps (list[float]): Timestamps for each window |
| | labels (list[str]): Species labels |
| | confidence_threshold (float): Minimum confidence for detection |
| | |
| | Returns: |
| | dict[str, list[dict[str, float | int]]]: Summary of detections with timestamps |
| | """ |
| | detections = defaultdict(list) |
| |
|
| | |
| | for i, (predictions, timestamp) in enumerate(zip(all_predictions, timestamps)): |
| | |
| | scores = predictions |
| |
|
| | |
| | above_threshold = np.where(scores > confidence_threshold)[0] |
| |
|
| | for idx in above_threshold: |
| | confidence = float(scores[idx]) |
| | species_name = labels[idx] if idx < len(labels) else f"Class {idx}" |
| |
|
| | detections[species_name].append( |
| | {"timestamp": timestamp, "confidence": confidence, "window": i} |
| | ) |
| |
|
| | return dict(detections) |
| |
|
| |
|
def main() -> int:
    """CLI entry point: parse arguments, run BirdNET inference, print results.

    Returns 0 on success, 1 on any error (missing input files or a runtime
    failure during loading/inference).
    """
    parser = argparse.ArgumentParser(
        description="BirdNET Audio Classification with Moving Window"
    )
    parser.add_argument("audio_file", help="Path to the WAV audio file")
    parser.add_argument(
        "--model", default="model.onnx", help="Path to the ONNX model file"
    )
    parser.add_argument(
        "--labels",
        default="BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        help="Path to the labels file",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top predictions to show per window",
    )
    parser.add_argument(
        "--overlap", type=float, default=0.5, help="Window overlap ratio (0.0-1.0)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.1,
        help="Minimum confidence threshold for detections",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size for inference (default: 128)",
    )
    parser.add_argument(
        "--single-window",
        action="store_true",
        help="Analyze only first 3 seconds (single window)",
    )

    args = parser.parse_args()

    # Validate all required input paths before doing any heavy work.
    if not os.path.exists(args.audio_file):
        print(f"Error: Audio file '{args.audio_file}' not found.")
        return 1

    if not os.path.exists(args.model):
        print(f"Error: Model file '{args.model}' not found.")
        return 1

    if not os.path.exists(args.labels):
        print(f"Error: Labels file '{args.labels}' not found.")
        return 1

    try:
        # Load species labels (one per output class of the model).
        print(f"Loading labels from: {args.labels}")
        labels = load_labels(args.labels)
        print(f"Loaded {len(labels)} species labels")

        # Create the ONNX Runtime session once and reuse it for all windows.
        print(f"Loading ONNX model: {args.model}")
        session = load_onnx_model(args.model)

        # Echo model I/O metadata so the user can sanity-check shapes.
        input_info = session.get_inputs()[0]
        output_info = session.get_outputs()[0]
        print(f"Model input: {input_info.name}, shape: {input_info.shape}")
        print(f"Model output: {output_info.name}, shape: {output_info.shape}")

        if args.single_window:
            # Single-window mode: classify only the first 3 seconds.
            print(f"Loading first 3 seconds of audio file: {args.audio_file}")
            audio_data = load_audio(args.audio_file)
            print(f"Audio loaded successfully. Shape: {audio_data.shape}")

            print("Running inference on single window...")
            predictions = predict_audio(session, audio_data)

            # Collapse a [1, num_classes] batch result to a flat score vector.
            predictions = np.array(predictions)
            if len(predictions.shape) > 1:
                scores = predictions[0]
            else:
                scores = predictions

            # Indices of the top-k scores, highest first.
            top_indices = np.argsort(scores)[-args.top_k :][::-1]

            print(f"\nTop {args.top_k} predictions for first 3 seconds:")
            for i, idx in enumerate(top_indices):
                confidence = float(scores[idx])
                species_name = labels[idx] if idx < len(labels) else f"Class {idx}"
                print(f"{i + 1:2d}. {species_name}: {confidence:.6f}")

        else:
            # Moving-window mode: analyze the full recording.
            print(f"Loading full audio file: {args.audio_file}")
            full_audio = load_audio_full(args.audio_file)
            # 48000 matches load_audio_full's default target sample rate.
            audio_duration = len(full_audio) / 48000.0
            print(f"Audio loaded successfully. Duration: {audio_duration:.2f} seconds")

            # Slice the audio into overlapping 3-second windows.
            print(f"Creating windows with {args.overlap * 100:.0f}% overlap...")
            windows, timestamps = create_audio_windows(full_audio, overlap=args.overlap)
            print(f"Created {len(windows)} windows of 3 seconds each")

            # Run inference in batches for throughput.
            print(
                f"Running batch inference on {len(windows)} windows (batch size: {args.batch_size})..."
            )
            num_batches = (len(windows) + args.batch_size - 1) // args.batch_size
            print(f"Processing {num_batches} batches...")

            all_predictions = predict_audio_batch(session, windows, args.batch_size)
            print(f"Completed batch inference on {len(windows)} windows")

            # Group above-threshold scores by species with their timestamps.
            print(
                f"Analyzing detections with confidence threshold {args.confidence}..."
            )
            detections = analyze_detections(
                all_predictions, timestamps, labels, args.confidence
            )

            # Order species by their single best (maximum) confidence.
            sorted_species = sorted(
                detections.items(),
                key=lambda x: max(det["confidence"] for det in x[1]),
                reverse=True,
            )

            print("\n=== DETECTION SUMMARY ===")
            print(f"Audio duration: {audio_duration:.2f} seconds")
            print(f"Windows analyzed: {len(windows)}")
            print(
                f"Species detected (>{args.confidence:.2f} confidence): {len(sorted_species)}"
            )

            if sorted_species:
                print("\nTop detections:")
                for species, detections_list in sorted_species[: args.top_k]:
                    max_conf = max(det["confidence"] for det in detections_list)
                    num_detections = len(detections_list)
                    first_detection = min(det["timestamp"] for det in detections_list)
                    last_detection = max(det["timestamp"] for det in detections_list)

                    print(f"\n{species}")
                    print(f"  Max confidence: {max_conf:.6f}")
                    print(f"  Detections: {num_detections}")
                    print(
                        f"  Time range: {first_detection:.1f}s - {last_detection:.1f}s"
                    )

                    # Show the three strongest individual detections per species.
                    strong_detections = sorted(
                        detections_list, key=lambda x: x["confidence"], reverse=True
                    )[:3]
                    for det in strong_detections:
                        print(f"    {det['timestamp']:6.1f}s: {det['confidence']:.6f}")
            else:
                print(
                    f"No detections found above confidence threshold {args.confidence}"
                )

        return 0

    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit code.
        print(f"Error: {str(e)}")
        return 1
| |
|
| |
|
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the site-provided exit()
    # builtin, which is meant for interactive use and may be absent when the
    # interpreter runs without the site module (python -S).
    raise SystemExit(main())
| |
|