Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| AnalysisGNN Gradio App | |
| A Gradio interface for AnalysisGNN music analysis. | |
| Users can upload MusicXML scores, run the model, and view results. | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import logging | |
| import os | |
| import shutil | |
| import subprocess | |
| import tempfile | |
| import time | |
| import torch | |
| import urllib.request | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from contextlib import contextmanager | |
| from pathlib import Path | |
| from typing import Tuple, Optional, Dict | |
| import traceback | |
| import warnings | |
| # Suppress warnings for cleaner output | |
| warnings.filterwarnings('ignore') | |
| # Import partitura and AnalysisGNN | |
| import partitura as pt | |
| from analysisgnn.models.analysis import ContinualAnalysisGNN | |
| from analysisgnn.utils.chord_representations import available_representations, NoteDegree49 | |
# Make the 49-class note-degree representation available for decoding under
# the "note_degree" key, without overriding an entry that is already there.
if NoteDegree49 is not None and "note_degree" not in available_representations:
    available_representations["note_degree"] = NoteDegree49
# --- Logging ---------------------------------------------------------------
# ANALYSISGNN_LOG_LEVEL selects the level; a value that is not an attribute of
# the logging module falls back to INFO.
LOG_LEVEL = os.environ.get("ANALYSISGNN_LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s %(name)s: %(message)s",
    level=getattr(logging, LOG_LEVEL, logging.INFO),
)
logger = logging.getLogger("analysisgnn_app")

# --- Parallelism -----------------------------------------------------------
PARALLEL_CONFIG = os.environ.get("ANALYSISGNN_PARALLEL", "auto").strip().lower()
CPU_COUNT = os.cpu_count() or 1

# --- MuseScore (primary AppImage) ------------------------------------------
MUSESCORE_APPIMAGE_URL = "https://www.modelscope.cn/studio/Genius-Society/piano_trans/resolve/master/MuseScore.AppImage"
MUSESCORE_STORAGE_DIR = Path("artifacts", "musescore")
MUSESCORE_ENV_VAR = "MUSESCORE_BIN"
# Timeouts are in seconds and can be overridden through the environment.
MUSESCORE_RENDER_TIMEOUT = int(os.environ.get("MUSESCORE_RENDER_TIMEOUT", "180"))
MUSESCORE_EXTRACT_TIMEOUT = int(os.environ.get("MUSESCORE_EXTRACT_TIMEOUT", "240"))
_MUSESCORE_BINARY: Optional[str] = None  # cached path of the resolved binary
_MUSESCORE_READY: bool = False  # set once startup initialization succeeded

# --- MuseScore 3 (legacy AppImage) -----------------------------------------
MUSESCORE_V3_APPIMAGE_URL = "https://github.com/musescore/MuseScore/releases/download/v3.6.2/MuseScore-3.6.2.548021370-x86_64.AppImage"
MUSESCORE_V3_STORAGE_DIR = Path("artifacts", "musescore_v3")
MUSESCORE_V3_ENV_VAR = "MUSESCORE_V3_BIN"
_MUSESCORE_V3_BINARY: Optional[str] = None  # cached path of the resolved binary

# --- Rendering output and virtual display ----------------------------------
RENDER_OUTPUT_DIR = Path("artifacts", "rendered_scores")
XVFB_ENV_VAR = "XVFB_BIN"
XVFB_STORAGE_DIR = Path("artifacts", "xvfb")
_XVFB_BINARY: Optional[str] = None  # cached path of the resolved binary

# --- Model and device ------------------------------------------------------
# Global model variable, populated lazily on first use.
MODEL = None
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info("Using device: %s", DEVICE)
if DEVICE == "cuda":
    logger.info("CUDA device: %s", torch.cuda.get_device_name(0))
@contextmanager
def log_timing(label: str):
    """Context manager that logs the start, outcome and duration of a block.

    The function is a generator and is consumed with ``with log_timing(...)``
    throughout the module, so it has to be wrapped by ``contextmanager``.

    Parameters
    ----------
    label : str
        Human readable name of the operation, used in every log line.

    Raises
    ------
    Exception
        Any exception raised inside the ``with`` body is logged (with
        traceback and elapsed time) and then re-raised unchanged.
    """
    start = time.perf_counter()
    logger.info("▶ %s", label)
    try:
        yield
    except Exception:
        elapsed = time.perf_counter() - start
        logger.exception("✗ %s failed after %.2fs", label, elapsed)
        raise
    else:
        elapsed = time.perf_counter() - start
        logger.info("✓ %s in %.2fs", label, elapsed)
def should_parallelize() -> bool:
    """
    Tell whether analysis and visualization should run in parallel.

    The decision is driven by the ANALYSISGNN_PARALLEL environment value
    (already stripped and lower-cased into PARALLEL_CONFIG):
    - "0", "false", "no", "off": always sequential
    - "1", "true", "yes", "on": always parallel
    - anything else (default "auto"): parallel when more than one CPU core
      is reported
    """
    forced_choice = {
        **dict.fromkeys(("0", "false", "no", "off"), False),
        **dict.fromkeys(("1", "true", "yes", "on"), True),
    }.get(PARALLEL_CONFIG)
    if forced_choice is not None:
        return forced_choice
    return CPU_COUNT > 1
def download_wandb_checkpoint(artifact_path: str = "melkisedeath/AnalysisGNN/model-uvj2ddun:v1") -> str:
    """Return a local checkpoint path, downloading it from W&B only if needed.

    Cached files inside the local ``checkpoint`` directory are preferred, in
    this order:

    1. ``checkpoint/model.ckpt``
    2. any ``*.ckpt`` file directly inside ``checkpoint`` (alphabetically
       first, so the choice does not depend on directory listing order)
    3. ``checkpoint/<basename of artifact_path>/model.ckpt``

    Parameters
    ----------
    artifact_path : str
        W&B artifact identifier (``entity/project/name:alias``).

    Returns
    -------
    str
        Path to an existing ``.ckpt`` file.

    Raises
    ------
    FileNotFoundError
        If the downloaded artifact does not contain any ``.ckpt`` file.
    """
    artifacts_dir = "checkpoint"
    os.makedirs(artifacts_dir, exist_ok=True)

    cached_candidates = [os.path.join(artifacts_dir, "model.ckpt")]
    cached_candidates.extend(
        os.path.join(artifacts_dir, fname)
        for fname in sorted(os.listdir(artifacts_dir))
        if fname.endswith(".ckpt")
    )
    cached_candidates.append(
        os.path.join(artifacts_dir, os.path.basename(artifact_path), "model.ckpt")
    )
    for candidate in cached_candidates:
        if os.path.exists(candidate):
            logger.info("Using cached checkpoint: %s", candidate)
            return candidate

    # Only import and use wandb if checkpoint is not cached
    import wandb

    logger.info("Downloading checkpoint from W&B: %s", artifact_path)
    # Fetch through the public API: it needs no run, whereas an offline run
    # (as used previously) cannot resolve artifacts.
    artifact = wandb.Api().artifact(artifact_path, type="model")
    with log_timing("Downloading W&B checkpoint"):
        download_dir = str(artifact.download(root=artifacts_dir))

    # Find the checkpoint file
    checkpoint_path = os.path.join(download_dir, "model.ckpt")
    if os.path.exists(checkpoint_path):
        return checkpoint_path
    for fname in sorted(os.listdir(download_dir)):
        if fname.endswith(".ckpt"):
            return os.path.join(download_dir, fname)
    raise FileNotFoundError(
        f"No .ckpt file found in downloaded artifact '{artifact_path}' under {download_dir}"
    )
def load_model() -> ContinualAnalysisGNN:
    """Return the shared AnalysisGNN model, loading it on first use.

    The instance is stored in the module-level ``MODEL`` so later calls reuse
    it instead of loading the checkpoint again.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    ckpt_file = download_wandb_checkpoint()
    logger.info("Loading model from: %s", ckpt_file)
    MODEL = ContinualAnalysisGNN.load_from_checkpoint(ckpt_file, map_location=DEVICE)
    # Switch to inference mode and place the weights on the selected device.
    MODEL.eval()
    MODEL.to(DEVICE)
    logger.info("Model loaded successfully!")
    return MODEL
| def _format_bytes(num_bytes: float) -> str: | |
| """Return human readable size string.""" | |
| units = ["B", "KB", "MB", "GB", "TB"] | |
| size = float(num_bytes) | |
| for unit in units: | |
| if size < 1024: | |
| return f"{size:.1f}{unit}" | |
| size /= 1024 | |
| return f"{size:.1f}PB" | |
def _download_file(url: str, destination: Path, timeout: float = 60.0) -> bool:
    """Download ``url`` to ``destination``, logging progress every few seconds.

    The payload is streamed into a sibling ``<name>.part`` file and only moved
    onto ``destination`` after the transfer completed, so an interrupted
    download never leaves a truncated file where callers expect a complete one.

    Parameters
    ----------
    url : str
        Location to fetch.
    destination : Path
        Final path of the downloaded file; parent directories are created.
    timeout : float
        Timeout in seconds handed to ``urlopen`` so a stalled connection
        cannot block forever.

    Returns
    -------
    bool
        True when the file was downloaded completely, False on any failure
        (the error is logged with its traceback).
    """
    partial = destination.with_name(destination.name + ".part")
    try:
        destination.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Starting download: %s -> %s", url, destination)
        with urllib.request.urlopen(url, timeout=timeout) as response, open(partial, "wb") as out_file:
            total_size = int(response.headers.get("Content-Length") or 0)
            downloaded = 0
            chunk_size = 1024 * 256
            last_log = time.perf_counter()
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                out_file.write(chunk)
                downloaded += len(chunk)
                now = time.perf_counter()
                if now - last_log > 5:
                    pct = (downloaded / total_size * 100) if total_size else 0
                    logger.info(
                        "Downloading... %s / %s (%.1f%%)",
                        _format_bytes(downloaded),
                        _format_bytes(total_size) if total_size else "unknown",
                        pct,
                    )
                    last_log = now
        # A connection that closes early can end the loop without an error.
        if total_size and downloaded != total_size:
            raise OSError(
                f"incomplete download: received {downloaded} of {total_size} bytes"
            )
        os.replace(partial, destination)
        logger.info(
            "Finished download: %s (%s)",
            destination,
            _format_bytes(destination.stat().st_size),
        )
        return True
    except Exception as exc:
        logger.exception("Error downloading %s: %s", url, exc)
        # Best effort: do not leave the partial payload behind.
        try:
            partial.unlink()
        except OSError:
            pass
        return False
def _cleanup_musescore_artifacts(remove_appimage: bool = False) -> None:
    """Delete leftovers of a failed MuseScore extraction so a retry starts clean.

    Parameters
    ----------
    remove_appimage : bool
        When True, the downloaded ``MuseScore.AppImage`` is deleted as well
        (in addition to the ``squashfs-root`` extraction directory).
    """
    stale_extract = MUSESCORE_STORAGE_DIR / "squashfs-root"
    if stale_extract.exists():
        logger.warning("Removing stale MuseScore extract at %s", stale_extract)
        shutil.rmtree(stale_extract, ignore_errors=True)

    if not remove_appimage:
        return

    image_file = MUSESCORE_STORAGE_DIR / "MuseScore.AppImage"
    if not image_file.exists():
        return
    try:
        image_file.unlink()
    except Exception:
        logger.warning("Could not remove MuseScore AppImage at %s", image_file)
    else:
        logger.warning("Removed corrupt MuseScore AppImage at %s", image_file)
def ensure_musescore_binary() -> Optional[str]:
    """Ensure a MuseScore binary is available for rendering.

    Resolution order: the cached result of an earlier call, the path in the
    ``MUSESCORE_BIN`` environment variable, a MuseScore executable on PATH, a
    previously extracted AppRun, and finally downloading and extracting the
    AppImage (two attempts; the first failure also discards the AppImage so
    it is downloaded again).

    Returns
    -------
    Optional[str]
        Path of the binary, or None when none could be provided.
    """
    global _MUSESCORE_BINARY
    if _MUSESCORE_BINARY and os.path.exists(_MUSESCORE_BINARY):
        return _MUSESCORE_BINARY

    env_path = os.environ.get(MUSESCORE_ENV_VAR)
    if env_path and os.path.exists(env_path):
        logger.info("Using MuseScore binary from %s", MUSESCORE_ENV_VAR)
        _MUSESCORE_BINARY = env_path
        return _MUSESCORE_BINARY

    for candidate in ("mscore", "mscore3", "musescore3", "musescore", "MuseScore3"):
        found = shutil.which(candidate)
        if found:
            logger.info("Found MuseScore executable on PATH: %s", found)
            _MUSESCORE_BINARY = found
            return _MUSESCORE_BINARY

    MUSESCORE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
    appimage_path = (MUSESCORE_STORAGE_DIR / "MuseScore.AppImage").resolve(strict=False)
    apprun_path = (MUSESCORE_STORAGE_DIR / "squashfs-root" / "AppRun").resolve(strict=False)
    if apprun_path.exists():
        logger.info("Using cached MuseScore AppRun at %s", apprun_path)
        os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")
        _MUSESCORE_BINARY = str(apprun_path)
        return _MUSESCORE_BINARY

    for attempt in (1, 2):
        # Only the first failure throws the AppImage away to force a re-download.
        discard_appimage = attempt == 1
        if not appimage_path.exists() or appimage_path.stat().st_size == 0:
            logger.info("MuseScore AppImage missing. Downloading (attempt %s)...", attempt)
            if not _download_file(MUSESCORE_APPIMAGE_URL, appimage_path):
                return None
        try:
            os.chmod(appimage_path, 0o755)
        except Exception as exc:
            logger.warning("Could not chmod MuseScore AppImage: %s", exc)
        try:
            with log_timing("Extracting MuseScore AppImage"):
                subprocess.run(
                    [str(appimage_path), "--appimage-extract"],
                    cwd=MUSESCORE_STORAGE_DIR,
                    check=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    timeout=MUSESCORE_EXTRACT_TIMEOUT,
                )
        except subprocess.CalledProcessError as exc:
            stderr = exc.stderr.decode(errors='ignore') if exc.stderr else str(exc)
            logger.error("MuseScore extraction failed: %s", stderr)
            _cleanup_musescore_artifacts(remove_appimage=discard_appimage)
            continue
        except subprocess.TimeoutExpired:
            logger.error("MuseScore extraction timed out after %ss", MUSESCORE_EXTRACT_TIMEOUT)
            _cleanup_musescore_artifacts(remove_appimage=discard_appimage)
            continue
        except OSError as exc:
            # A corrupt or non-executable AppImage fails at launch rather than
            # with an exit code; treat it like any other failed attempt.
            logger.error("MuseScore AppImage could not be executed: %s", exc)
            _cleanup_musescore_artifacts(remove_appimage=discard_appimage)
            continue
        if apprun_path.exists():
            os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")
            _MUSESCORE_BINARY = str(apprun_path)
            try:
                os.chmod(apprun_path, 0o755)
            except Exception:
                logger.debug("Could not chmod MuseScore AppRun; continuing anyway.")
            logger.info("MuseScore AppRun ready at %s", _MUSESCORE_BINARY)
            return _MUSESCORE_BINARY
        logger.error("MuseScore extraction completed but AppRun was not found.")
        _cleanup_musescore_artifacts(remove_appimage=discard_appimage)

    logger.error("MuseScore binary unavailable after retries.")
    return None
def ensure_musescore_v3_binary() -> Optional[str]:
    """Ensure a MuseScore 3.x binary is available for rendering.

    Resolution order: the cached result of an earlier call, the path in the
    ``MUSESCORE_V3_BIN`` environment variable, a previously extracted AppRun,
    and finally downloading and extracting the MuseScore 3 AppImage.

    Returns
    -------
    Optional[str]
        Path of the binary, or None when it could not be provided.
    """
    global _MUSESCORE_V3_BINARY
    if _MUSESCORE_V3_BINARY and os.path.exists(_MUSESCORE_V3_BINARY):
        return _MUSESCORE_V3_BINARY

    env_path = os.environ.get(MUSESCORE_V3_ENV_VAR)
    if env_path and os.path.exists(env_path):
        logger.info("Using MuseScore 3 binary from %s", MUSESCORE_V3_ENV_VAR)
        _MUSESCORE_V3_BINARY = env_path
        return _MUSESCORE_V3_BINARY

    storage = MUSESCORE_V3_STORAGE_DIR
    storage.mkdir(parents=True, exist_ok=True)
    appimage_path = (storage / "MuseScore-3.AppImage").resolve(strict=False)
    apprun_path = (storage / "squashfs-root" / "AppRun").resolve(strict=False)
    if apprun_path.exists():
        logger.info("Using cached MuseScore 3 AppRun at %s", apprun_path)
        _MUSESCORE_V3_BINARY = str(apprun_path)
        return _MUSESCORE_V3_BINARY

    # An empty file is what an aborted download leaves behind; fetch it again.
    if not appimage_path.exists() or appimage_path.stat().st_size == 0:
        logger.info("MuseScore 3 AppImage missing. Downloading (first run only)...")
        if not _download_file(MUSESCORE_V3_APPIMAGE_URL, appimage_path):
            return None
    try:
        os.chmod(appimage_path, 0o755)
    except Exception as exc:
        logger.warning("Could not chmod MuseScore 3 AppImage: %s", exc)

    try:
        with log_timing("Extracting MuseScore 3 AppImage"):
            subprocess.run(
                [str(appimage_path), "--appimage-extract"],
                cwd=storage,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=MUSESCORE_EXTRACT_TIMEOUT,
            )
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode(errors='ignore') if exc.stderr else str(exc)
        logger.error("MuseScore 3 extraction failed: %s", stderr)
        return None
    except subprocess.TimeoutExpired:
        logger.error("MuseScore 3 extraction timed out after %ss", MUSESCORE_EXTRACT_TIMEOUT)
        return None
    except OSError as exc:
        # Launching a corrupt or non-executable AppImage raises instead of
        # returning an exit code; report it like the other failures.
        logger.error("MuseScore 3 AppImage could not be executed: %s", exc)
        return None

    if apprun_path.exists():
        _MUSESCORE_V3_BINARY = str(apprun_path)
        try:
            os.chmod(apprun_path, 0o755)
        except Exception:
            # Best effort: the extracted AppRun is normally executable already.
            pass
        logger.info("MuseScore 3 AppRun ready at %s", _MUSESCORE_V3_BINARY)
        return _MUSESCORE_V3_BINARY

    logger.error("MuseScore 3 extraction did not produce the expected AppRun binary.")
    return None
def _download_xvfb_package(dest_dir: Path) -> Optional[Path]:
    """Fetch the Xvfb ``.deb`` into ``dest_dir`` by running ``apt download xvfb``.

    Returns
    -------
    Optional[Path]
        The most recently modified ``xvfb_*.deb`` in ``dest_dir``, or None
        when the command is missing, fails, or produces no package file.
    """
    apt_command = ["apt", "download", "xvfb"]
    try:
        result = subprocess.run(
            apt_command,
            cwd=str(dest_dir),
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except FileNotFoundError:
        logger.error("'apt' command not available; cannot download Xvfb automatically.")
        return None
    except subprocess.CalledProcessError as exc:
        logger.error(
            "Failed to download Xvfb package (exit %s): %s",
            exc.returncode,
            exc.stderr.strip() if exc.stderr else exc,
        )
        return None

    logger.debug("apt download xvfb stdout: %s", result.stdout.strip())
    if result.stderr:
        logger.debug("apt download xvfb stderr: %s", result.stderr.strip())

    # Pick the newest package in case older downloads are still lying around.
    newest_deb = max(
        dest_dir.glob("xvfb_*.deb"),
        key=lambda pkg: pkg.stat().st_mtime,
        default=None,
    )
    if newest_deb is None:
        logger.error("apt download xvfb did not produce any .deb files under %s", dest_dir)
    return newest_deb
def _extract_xvfb_binary(deb_path: Path, target_dir: Path) -> Optional[Path]:
    """Unpack ``deb_path`` into ``target_dir/pkg`` and locate the Xvfb binary.

    Any existing ``pkg`` directory is removed first. On success the binary is
    made executable and the ``.deb`` is deleted (both best effort).

    Returns
    -------
    Optional[Path]
        Path to ``pkg/usr/bin/Xvfb``, or None when extraction fails or the
        binary is not part of the package.
    """
    unpack_root = target_dir / "pkg"
    if unpack_root.exists():
        shutil.rmtree(unpack_root, ignore_errors=True)

    try:
        subprocess.run(
            ["dpkg-deb", "-x", str(deb_path), str(unpack_root)],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except FileNotFoundError:
        logger.error("'dpkg-deb' command not available; cannot extract Xvfb package.")
        return None
    except subprocess.CalledProcessError as exc:
        details = exc.stderr
        if isinstance(details, bytes):
            details = details.decode(errors="ignore")
        logger.error("Failed to extract Xvfb package: %s", details or exc)
        return None

    binary = unpack_root / "usr/bin/Xvfb"
    if not binary.exists():
        logger.error("Extracted Xvfb package but could not find usr/bin/Xvfb inside %s", unpack_root)
        return None

    logger.info("Xvfb binary extracted to %s", binary)
    # Neither step is required for the binary to be usable, so failures are ignored.
    for best_effort_step in (lambda: os.chmod(binary, 0o755), deb_path.unlink):
        try:
            best_effort_step()
        except Exception:
            pass
    return binary
def ensure_xvfb_binary() -> Optional[str]:
    """Locate (or fetch) an Xvfb binary for headless rendering.

    Sources are tried in order: the cached result, the ``XVFB_BIN``
    environment variable, ``Xvfb`` on PATH, a previously extracted package
    under the Xvfb storage directory, and finally a fresh package download.

    Returns
    -------
    Optional[str]
        Path of the binary, or None when none could be provided.
    """
    global _XVFB_BINARY
    if _XVFB_BINARY and os.path.exists(_XVFB_BINARY):
        return _XVFB_BINARY

    override = os.environ.get(XVFB_ENV_VAR)
    if override and os.path.exists(override):
        resolved = override
    else:
        resolved = shutil.which("Xvfb")

    if not resolved:
        XVFB_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
        unpacked_earlier = XVFB_STORAGE_DIR / "pkg" / "usr" / "bin" / "Xvfb"
        if unpacked_earlier.exists():
            resolved = str(unpacked_earlier)
        else:
            package = _download_xvfb_package(XVFB_STORAGE_DIR)
            fresh = _extract_xvfb_binary(package, XVFB_STORAGE_DIR) if package else None
            resolved = str(fresh) if fresh else None

    if not resolved:
        return None
    _XVFB_BINARY = resolved
    return _XVFB_BINARY
def initialize_musescore_backend() -> bool:
    """Resolve the MuseScore binaries up front so rendering need not fetch them.

    Both the primary and the MuseScore 3 binary are probed; one of them being
    available is enough. A successful result is remembered in
    ``_MUSESCORE_READY`` and short-circuits later calls.

    Returns
    -------
    bool
        True when at least one binary is available, False otherwise.
    """
    global _MUSESCORE_READY
    if _MUSESCORE_READY:
        return True

    probes = (
        (ensure_musescore_binary, "MuseScore 4 AppRun ready at startup: %s"),
        (ensure_musescore_v3_binary, "MuseScore 3 AppRun ready at startup: %s"),
    )
    any_ready = False
    for resolve, ready_message in probes:
        binary = resolve()
        if binary:
            logger.info(ready_message, binary)
            any_ready = True

    if not any_ready:
        logger.warning("No MuseScore AppRun binaries could be initialized; score visualization will fail.")
        return False
    _MUSESCORE_READY = True
    return True
| def _coalesce_musescore_output(output_path: str) -> Optional[str]: | |
| """ | |
| Normalize MuseScore CLI output when it renders multiple PNG pages. | |
| MuseScore writes `basename-1.png`, `basename-2.png`, ... even if we request | |
| a single filename. We promote the first page to the requested output path | |
| so downstream code can always load one predictable image. | |
| """ | |
| target = Path(output_path) | |
| if target.exists(): | |
| return str(target) | |
| suffix = target.suffix | |
| pattern = f"{target.stem}-*{suffix}" if suffix else f"{target.name}-*" | |
| matches = sorted(target.parent.glob(pattern)) | |
| if not matches: | |
| return None | |
| first_page = matches[0] | |
| normalized_path: Optional[Path] = None | |
| try: | |
| shutil.move(str(first_page), str(target)) | |
| normalized_path = target | |
| except Exception: | |
| try: | |
| shutil.copy(str(first_page), str(target)) | |
| normalized_path = target | |
| except Exception: | |
| normalized_path = first_page | |
| if normalized_path == target: | |
| logger.debug("Normalized MuseScore output %s -> %s", first_page, target) | |
| else: | |
| logger.debug("Using MuseScore page %s as output", first_page) | |
| # Remove leftover pages to avoid clutter, keep best-effort | |
| for extra in matches: | |
| if extra == first_page: | |
| continue | |
| try: | |
| extra.unlink() | |
| except Exception: | |
| pass | |
| return str(normalized_path) | |
def persist_rendered_image(src_path: str) -> Optional[str]:
    """Copy a rendered PNG into the persistent render directory for the UI.

    The copy is named ``<millisecond timestamp>_<original file name>``.

    Returns
    -------
    Optional[str]
        Path of the copy; the unchanged ``src_path`` if copying failed; None
        when ``src_path`` is empty or does not exist.
    """
    if not (src_path and os.path.exists(src_path)):
        return None
    try:
        RENDER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        stamp = int(time.time()*1000)
        target = RENDER_OUTPUT_DIR / f"{stamp}_{Path(src_path).name}"
        shutil.copy2(src_path, target)
    except Exception as exc:
        logger.warning("Could not persist rendered image %s: %s", src_path, exc)
        return src_path
    return str(target)
@contextmanager
def xvfb_session():
    """Context manager that runs a temporary Xvfb server for the ``with`` body.

    Yields
    ------
    Optional[str]
        The DISPLAY string (for example ``":99"``) of the started server, or
        None when no Xvfb binary is available, no display number between 99
        and 159 is free, or the server exits right after starting.

    The server process is terminated (and killed if it does not stop within
    five seconds) when the ``with`` block ends.
    """
    xvfb_bin = ensure_xvfb_binary()
    if not xvfb_bin:
        logger.warning("Xvfb binary unavailable; proceeding without virtual display.")
        yield None
        return

    base_dir = Path("/tmp/.X11-unix")
    try:
        base_dir.mkdir(mode=0o1777, exist_ok=True)
    except Exception:
        # Best effort: the directory usually exists already.
        pass

    # Entries named X<n> in the socket directory mark displays already in use.
    used = {p.name for p in base_dir.glob("X*")}
    display = None
    for candidate in range(99, 160):
        if f"X{candidate}" not in used:
            display = f":{candidate}"
            break
    if display is None:
        logger.warning("No free DISPLAY slots for Xvfb.")
        yield None
        return

    cmd = [
        xvfb_bin,
        display,
        "-screen",
        "0",
        "1920x1080x24",
        "-nolisten",
        "tcp",
    ]
    logger.debug("Starting Xvfb with command: %s", " ".join(cmd))
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # Give the server a moment to come up (or to fail fast).
    time.sleep(0.5)
    if proc.poll() is not None:
        logger.error("Xvfb failed to start (exit %s).", proc.returncode)
        yield None
        return

    try:
        yield display
    finally:
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            # Reap the killed process so it does not linger as a zombie.
            proc.wait()
def render_with_musescore(musicxml_path: Optional[str], output_path: str) -> Optional[str]:
    """Render a MusicXML file to an image with the MuseScore command line.

    MuseScore 3 is tried first, then the primary MuseScore binary. Each
    attempt runs inside its own temporary Xvfb display; a candidate that
    fails for any reason is logged and the next one is tried.

    Parameters
    ----------
    musicxml_path : Optional[str]
        Path of the MusicXML file to render.
    output_path : str
        Requested path of the rendered image.

    Returns
    -------
    Optional[str]
        Path of the rendered image, or None when the input is missing, no
        binary is available, or every candidate failed.
    """
    if not musicxml_path or not os.path.exists(musicxml_path):
        return None

    candidates = []
    legacy = ensure_musescore_v3_binary()
    if legacy:
        candidates.append(("MuseScore 3", legacy, True))
    primary = ensure_musescore_binary()
    if primary:
        candidates.append(("MuseScore 4", primary, True))
    if not candidates:
        logger.warning("No MuseScore binaries available for rendering.")
        return None

    last_error = None
    for label, musescore_bin, requires_display in candidates:
        env = os.environ.copy()
        env.setdefault("QTWEBENGINE_DISABLE_SANDBOX", "1")
        env.setdefault("MUSESCORE_NO_AUDIO", "1")
        cmd = [musescore_bin, "-o", output_path, musicxml_path]
        logger.info("Attempting rendering with %s (%s).", label, musescore_bin)
        try:
            with xvfb_session() as display:
                if display:
                    env["DISPLAY"] = display
                    env["QT_QPA_PLATFORM"] = "xcb"
                    logger.debug("%s: using Xvfb display %s", label, display)
                else:
                    if requires_display:
                        logger.warning("%s requires an X11 display but Xvfb could not be started.", label)
                        last_error = f"{label} has no X11 display"
                        continue
                    env["QT_QPA_PLATFORM"] = "offscreen"
                    logger.debug("%s: using Qt offscreen platform.", label)
                with log_timing(f"{label} rendering"):
                    subprocess.run(
                        cmd,
                        check=True,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                        env=env,
                        timeout=MUSESCORE_RENDER_TIMEOUT,
                    )
        except subprocess.CalledProcessError as exc:
            stderr = exc.stderr.decode(errors='ignore') if exc.stderr else str(exc)
            logger.error("%s rendering failed: %s", label, stderr)
            last_error = stderr
            continue
        except subprocess.TimeoutExpired:
            logger.error("%s rendering timed out after %ss", label, MUSESCORE_RENDER_TIMEOUT)
            last_error = f"{label} timed out"
            continue
        except OSError as exc:
            # A binary that cannot be launched must not prevent the fallback
            # to the next candidate.
            logger.error("%s could not be executed: %s", label, exc)
            last_error = str(exc)
            continue

        normalized_path = _coalesce_musescore_output(output_path)
        if normalized_path and os.path.exists(normalized_path):
            logger.info("%s rendered %s -> %s", label, musicxml_path, normalized_path)
            return normalized_path
        logger.error("%s rendered score but the expected output file was not found.", label)
        last_error = "output missing"

    logger.error("All MuseScore binaries failed to render the score. Last error: %s", last_error)
    return None
def resolve_musicxml_path(musicxml_file) -> Optional[str]:
    """Extract a filesystem path from the uploaded MusicXML file value.

    Accepted forms are a string or path-like object, a dict with a ``"name"``
    entry, or any object exposing a non-empty ``name`` attribute. Anything
    else (including None) yields None.
    """
    if musicxml_file is None:
        return None
    if isinstance(musicxml_file, (str, os.PathLike)):
        return str(musicxml_file)
    if isinstance(musicxml_file, dict) and "name" in musicxml_file:
        return musicxml_file["name"]
    # An empty or missing ``name`` counts as "no path".
    return getattr(musicxml_file, "name", None) or None
def save_parsed_musicxml(score: pt.score.Score, original_path: Optional[str]) -> Optional[str]:
    """
    Persist the parsed/normalized score to a temporary MusicXML file.

    The temporary file keeps the original ``.xml`` / ``.musicxml`` suffix when
    there is one and uses ``.musicxml`` otherwise. If saving fails, the
    temporary file is removed again so failed attempts do not accumulate.

    Returns the path to the saved file or None if saving fails.
    """
    tmp_path = None
    try:
        suffix = ".musicxml"
        if original_path:
            original_suffix = Path(original_path).suffix.lower()
            if original_suffix in {".xml", ".musicxml"}:
                suffix = original_suffix
        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        # Only the path is needed; the writer opens the file itself.
        os.close(fd)
        with log_timing("Saving parsed MusicXML"):
            pt.save_musicxml(score, tmp_path)
        return tmp_path
    except Exception as exc:
        logger.warning("Could not save parsed MusicXML: %s", exc)
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None
def render_score_to_image(
    score: pt.score.Score,
    output_path: str,
    source_musicxml_path: Optional[str] = None
) -> Optional[str]:
    """
    Render the score image through MuseScore, with no alternative renderer.

    ``score`` is accepted only for backward compatibility with the earlier
    pipeline that rendered from a score object; rendering is driven entirely
    by ``source_musicxml_path``. Returns the rendered image path, or None when
    the MusicXML path is missing or rendering fails.
    """
    del score  # Render is driven solely by the MusicXML path
    has_source = bool(source_musicxml_path) and os.path.exists(source_musicxml_path)
    if has_source:
        return render_with_musescore(source_musicxml_path, output_path)
    logger.error("Cannot render score: MusicXML path '%s' not found.", source_musicxml_path)
    return None
def predict_analysis(
    model: ContinualAnalysisGNN,
    score: pt.score.Score,
    tasks: list
) -> Dict[str, np.ndarray]:
    """
    Run the model on a score and decode the requested tasks.

    Parameters
    ----------
    model : ContinualAnalysisGNN
        The model to use for prediction
    score : pt.score.Score
        The score to analyze
    tasks : list
        Names of the analysis tasks to decode; tasks the model did not
        predict are skipped

    Returns
    -------
    dict
        Decoded predictions per task, a ``<task>_confidence`` entry for tasks
        with multi-dimensional output, and ``onset_beat`` / ``measure`` timing
        entries when they can be obtained
    """
    with torch.no_grad(), log_timing("Model prediction"):
        predictions = model.predict(score)

    results = {}
    for task in tasks:
        if task not in predictions:
            continue
        raw = predictions[task]
        if len(raw.shape) > 1:
            # Confidence is the softmax probability of the winning class.
            probabilities = torch.softmax(raw, dim=-1)
            class_idx = torch.argmax(raw, dim=-1)
            results[f"{task}_confidence"] = probabilities.max(dim=-1).values.cpu().numpy()
        else:
            class_idx = raw

        if task not in available_representations:
            results[task] = class_idx.cpu().numpy()
            continue
        try:
            decoded = available_representations[task].decode(class_idx.reshape(-1, 1))
            # The decoder may hand back a plain list instead of an array.
            as_array = np.array(decoded) if isinstance(decoded, list) else decoded
            results[task] = as_array.flatten()
        except (IndexError, ValueError) as e:
            logger.warning("Error decoding %s predictions: %s", task, e)
            # Fallback to raw indices
            results[task] = class_idx.cpu().numpy()

    # Timing information: prefer what the model returned, else ask the score.
    try:
        results["onset_beat"] = (
            predictions["onset"].cpu().numpy()
            if "onset" in predictions
            else score.note_array()["onset_beat"]
        )
    except (AttributeError, KeyError, IndexError) as e:
        logger.warning("Could not add onset timing: %s", e)
    try:
        results["measure"] = (
            predictions["s_measure"].cpu().numpy()
            if "s_measure" in predictions
            else score[0].measure_number_map(score.note_array()["onset_div"])
        )
    except (AttributeError, KeyError, IndexError) as e:
        logger.warning("Could not add measure information: %s", e)
    return results
| def process_musicxml( | |
| musicxml_file, | |
| selected_tasks: list | |
| ) -> Tuple[Optional[str], Optional[pd.DataFrame], Optional[str], str]: | |
| """ | |
| Process a MusicXML file and return visualization and analysis results. | |
| Parameters | |
| ---------- | |
| musicxml_file : file | |
| Uploaded MusicXML file | |
| selected_tasks : list | |
| List of selected analysis tasks | |
| Returns | |
| ------- | |
| tuple | |
| (image_path, dataframe, parsed_musicxml_path, status_message) | |
| """ | |
| if musicxml_file is None: | |
| return None, None, None, "Please upload a MusicXML file." | |
| if not selected_tasks: | |
| return None, None, None, "Please select at least one analysis task." | |
| try: | |
| score_path = resolve_musicxml_path(musicxml_file) | |
| if score_path is None or not os.path.exists(score_path): | |
| return None, None, None, "Could not locate the uploaded MusicXML file." | |
| # Load the model | |
| status_msg = "Loading model..." | |
| logger.info(status_msg) | |
| model = load_model() | |
| # Load the score | |
| status_msg = "Loading score..." | |
| logger.info(status_msg) | |
| score = pt.load_musicxml(score_path) | |
| parsed_score_path = save_parsed_musicxml(score, score_path) | |
| # Render score to image | |
| with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img: | |
| img_path = tmp_img.name | |
| rendered_path: Optional[str] = None | |
| predictions: Dict[str, np.ndarray] = {} | |
| source_path = parsed_score_path or score_path | |
| parallel_enabled = should_parallelize() | |
| logger.info("Rendering score (parallel analysis enabled=%s)...", parallel_enabled) | |
| if parallel_enabled: | |
| logger.info("Running analysis and visualization in parallel (threads=%s).", 2) | |
| render_success = False | |
| analysis_success = False | |
| with ThreadPoolExecutor(max_workers=2) as executor: | |
| future_map = { | |
| executor.submit( | |
| render_score_to_image, | |
| score, | |
| img_path, | |
| source_musicxml_path=source_path, | |
| ): "render", | |
| executor.submit( | |
| predict_analysis, | |
| model, | |
| score, | |
| selected_tasks, | |
| ): "analysis", | |
| } | |
| for future in as_completed(future_map): | |
| task_name = future_map[future] | |
| try: | |
| result = future.result() | |
| except Exception: | |
| logger.exception("%s task failed.", task_name.capitalize()) | |
| continue | |
| if task_name == "render": | |
| rendered_path = result | |
| render_success = True | |
| else: | |
| predictions = result or {} | |
| analysis_success = True | |
| if not render_success: | |
| logger.info("Retrying score rendering sequentially after parallel failure.") | |
| rendered_path = render_score_to_image( | |
| score, | |
| img_path, | |
| source_musicxml_path=source_path, | |
| ) | |
| if not analysis_success: | |
| logger.info("Retrying analysis sequentially after parallel failure.") | |
| predictions = predict_analysis(model, score, selected_tasks) | |
| else: | |
| logger.info("Running analysis and visualization sequentially (parallel disabled).") | |
| rendered_path = render_score_to_image( | |
| score, | |
| img_path, | |
| source_musicxml_path=source_path, | |
| ) | |
| predictions = predict_analysis(model, score, selected_tasks) | |
| persisted_path = persist_rendered_image(rendered_path) if rendered_path else None | |
| if rendered_path is None or persisted_path is None: | |
| logger.warning("MuseScore AppRun could not render the score or save the PNG; visualization will be unavailable.") | |
| # Create DataFrame | |
| if predictions: | |
| df = pd.DataFrame(predictions) | |
| # Add note/event IDs | |
| if 'note_id' not in df.columns: | |
| df.insert(0, 'note_id', range(len(df))) | |
| # Convert tpc_in_label logits into NCT-friendly labels | |
| if 'tpc_in_label' in df.columns: | |
| df['tpc_in_label'] = np.where( | |
| df['tpc_in_label'].astype(int) == 0, | |
| "NCT", | |
| "Chord Tone" | |
| ) | |
| # Reorder columns to have timing info first, then predictions, then confidence | |
| timing_cols = [col for col in ['note_id', 'onset_beat', 'measure'] if col in df.columns] | |
| confidence_cols = [col for col in df.columns if col.endswith('_confidence')] | |
| prediction_cols = [col for col in df.columns if col not in timing_cols and col not in confidence_cols] | |
| # Interleave predictions with their confidence scores | |
| ordered_cols = timing_cols.copy() | |
| for pred_col in prediction_cols: | |
| ordered_cols.append(pred_col) | |
| conf_col = f"{pred_col}_confidence" | |
| if conf_col in confidence_cols: | |
| ordered_cols.append(conf_col) | |
| df = df[ordered_cols] | |
| # Apply user-friendly column names | |
| rename_map = {} | |
| for key, label in DISPLAY_NAME_OVERRIDES.items(): | |
| if key in df.columns: | |
| rename_map[key] = label | |
| conf_key = f"{key}_confidence" | |
| if conf_key in df.columns: | |
| rename_map[conf_key] = f"{label} Confidence" | |
| if rename_map: | |
| df = df.rename(columns=rename_map) | |
| status_msg = f"β Analysis complete! Analyzed {len(df)} notes with {len(selected_tasks)} task(s)." | |
| if parsed_score_path: | |
| status_msg += " Parsed MusicXML ready for download." | |
| else: | |
| df = pd.DataFrame() | |
| status_msg = "β Analysis returned no predictions." | |
| if parsed_score_path: | |
| status_msg += " Parsed MusicXML ready for download." | |
| return persisted_path, df, parsed_score_path, status_msg | |
| except Exception as e: | |
| error_msg = f"Error processing file: {str(e)}\n\n{traceback.format_exc()}" | |
| logger.error(error_msg) | |
| return None, None, None, error_msg | |
# Define available tasks
# Maps each internal task key to the label shown in the task selector.
# Insertion order is the order of the checkbox choices, and the labels must
# stay unique because analyze_wrapper inverts this mapping (label -> key).
AVAILABLE_TASKS: Dict[str, str] = {
    "cadence": "Cadence Detection",
    "localkey": "Local Key",
    "tonkey": "Tonalized Key",
    "quality": "Chord Quality",
    "root": "Chord Root",
    "bass": "Bass Note",
    "inversion": "Chord Inversion",
    "degree1": "Primary Degree",
    "degree2": "Secondary Degree",
    "romanNumeral": "Roman Numeral Analysis",
    "phrase": "Phrase Segmentation",
    "section": "Section Detection",
    "hrhythm": "Harmonic Rhythm",
    "pcset": "Pitch-Class Set",
    "tpc_in_label": "Non-Chord Tone (NCT)",
    "note_degree": "Note Degree",
}
# Column-name overrides for the results table: a prediction column named by a
# key here is renamed to the given label, and its "<key>_confidence" companion
# column to "<label> Confidence". Columns not listed keep their internal name.
DISPLAY_NAME_OVERRIDES: Dict[str, str] = {
    "tpc_in_label": "NCT",
    "note_degree": "Note Degree",
}
# Ensure MuseScore AppRun is available before the UI is constructed
# NOTE: this is a module-level call (not under the __main__ guard), so it runs
# whenever the module is imported as well as when it is executed as a script.
initialize_musescore_backend()
# Create Gradio interface
# The components created here (file_input, task_selector, analyze_btn, ...) are
# bound as module-level names and wired to their handlers further down.
# NOTE(review): the icon characters in several labels below (e.g. "π΅", "πΌ")
# look like mis-decoded emoji — confirm the file encoding before touching them.
with gr.Blocks(title="AnalysisGNN Music Analysis", theme=gr.themes.Soft()) as demo:
    # Page header and summary of what the app does
    gr.Markdown("""
# π΅ AnalysisGNN Music Analysis
Upload a MusicXML score to perform automatic music analysis using Graph Neural Networks.
**Supported Analysis Tasks:**
- Cadence Detection
- Key Analysis (Local & Tonalized)
- Harmonic Analysis (Chords, Inversions, Roman Numerals)
- Phrase & Section Segmentation
- Non-Chord Tone Detection (TPC-in-label / NCT)
- Note Degree Labeling
**Model:** Pre-trained AnalysisGNN from [manoskary/analysisGNN](https://github.com/manoskary/analysisGNN)
""")
    # First row: inputs in a narrow column (scale=1), status in a wider one (scale=2)
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### π Input")
            # type="filepath": the handler is given a path string for the upload
            file_input = gr.File(
                label="Upload MusicXML Score",
                file_types=[".musicxml", ".xml", ".mxl"],
                type="filepath"
            )
            # Choices are the display labels of AVAILABLE_TASKS; the handler
            # translates the selected labels back to internal task keys.
            task_selector = gr.CheckboxGroup(
                choices=list(AVAILABLE_TASKS.values()),
                value=["Cadence Detection", "Local Key", "Roman Numeral Analysis"],
                label="Select Analysis Tasks",
                info="Choose which tasks to perform"
            )
            analyze_btn = gr.Button("πΌ Analyze Score", variant="primary", size="lg")
            gr.Markdown("---")
            example_btn = gr.Button("π΅ Try Example (Mozart K.158)", size="sm")
        with gr.Column(scale=2):
            # Output section
            gr.Markdown("### π Results")
            # Read-only status line; also receives the example-loading message
            status_output = gr.Textbox(
                label="Status",
                lines=2,
                interactive=False
            )
    # NOTE(review): the two rows below are laid out as direct children of the
    # Blocks root (full width) rather than of the results column — confirm
    # against the intended layout.
    with gr.Row():
        with gr.Column():
            # Score visualization
            gr.Markdown("### πΌ Score Visualization")
            image_output = gr.Image(
                label="Rendered Score",
                type="filepath"
            )
            parsed_score_output = gr.File(
                label="Parsed MusicXML (Download)",
                interactive=False
            )
    with gr.Row():
        with gr.Column():
            # Analysis results table
            gr.Markdown("### π Analysis Results")
            table_output = gr.Dataframe(
                label="Analysis Output",
                wrap=True,
                interactive=False
            )
            # The button exports the current table; the file component offers
            # the resulting CSV for download.
            download_btn = gr.Button("πΎ Download Results as CSV")
            csv_output = gr.File(label="Download CSV")
    # Example section
    gr.Markdown("""
### π‘ Tips & Information
**Getting Started:**
- Click "Try Example" to load a Mozart quartet, or upload your own MusicXML file
- Select the analysis tasks you're interested in
- Click "Analyze Score" to run the model
**Analysis Output:**
The table shows note-level predictions for all selected tasks:
- **Onset & Measure**: Timing information
- **Keys**: Detected key areas (local and tonalized)
- **Chords**: Harmonic analysis with Roman numerals
- **Cadences**: Identified cadence points and types
**Score Visualization:**
Requires MuseScore or LilyPond for rendering. If unavailable, analysis will still work.
""")
| # Event handlers | |
def analyze_wrapper(file, tasks_selected):
    """Run the analysis for the tasks ticked in the UI.

    Args:
        file: Value of the score upload component.
        tasks_selected: Display labels selected in the task checkbox group.

    Returns:
        Whatever ``process_musicxml`` returns for the file and the internal
        task keys matching the selected labels.
    """
    # AVAILABLE_TASKS is key -> label; the UI hands back labels, so invert it.
    key_by_label = dict(zip(AVAILABLE_TASKS.values(), AVAILABLE_TASKS.keys()))
    # Labels with no known key are silently dropped.
    known_labels = (label for label in tasks_selected if label in key_by_label)
    internal_keys = [key_by_label[label] for label in known_labels]
    return process_musicxml(file, internal_keys)
def load_example():
    """Load example Mozart score.

    The score is downloaded once into ``./artifacts`` and reused afterwards.

    Returns:
        A ``(path, message)`` pair: the local MusicXML path and a status
        message, or ``(None, error message)`` if the download failed.
    """
    url = "https://raw.githubusercontent.com/manoskary/humdrum-mozart-quartets/refs/heads/master/musicxml/k158-01.musicxml"
    # Create artifacts directory if it doesn't exist
    os.makedirs("./artifacts", exist_ok=True)
    example_path = "./artifacts/k158-01.musicxml"
    if not os.path.exists(example_path):
        # Download into a uniquely named sibling file and rename it into place
        # only on success. Writing straight to example_path would leave a
        # truncated file behind after an interrupted transfer, and every later
        # call would then treat that broken file as a valid cached example.
        partial_path = None
        try:
            fd, partial_path = tempfile.mkstemp(dir="./artifacts", suffix=".part")
            os.close(fd)
            logger.info("Downloading example score from: %s", url)
            urllib.request.urlretrieve(url, partial_path)
            # Same directory, so the rename is a single atomic step.
            os.replace(partial_path, example_path)
            logger.info("Example score saved to: %s", example_path)
        except Exception as e:
            # UI boundary: report the failure in the status box instead of raising.
            logger.exception("Failed to download example score.")
            if partial_path is not None:
                try:
                    os.remove(partial_path)
                except OSError:
                    pass
            return None, f"Error downloading example: {e}"
    return example_path, "Example loaded! Click 'Analyze Score' to proceed."
# Run the analysis on the uploaded file with the selected tasks. The four
# values returned by analyze_wrapper fill, in order, the rendered image, the
# results table, the parsed-score download and the status box.
analyze_btn.click(
    fn=analyze_wrapper,
    inputs=[file_input, task_selector],
    outputs=[image_output, table_output, parsed_score_output, status_output]
)
# Load the example score: load_example takes no inputs and returns a
# (path, message) pair that fills the upload component and the status box.
example_btn.click(
    fn=load_example,
    outputs=[file_input, status_output]
)
def save_csv(df):
    """Export the results table to a temporary CSV file.

    Args:
        df: Current value of the results table; may be ``None``.

    Returns:
        Path of the written ``.csv`` file, or ``None`` when there is no table
        or it has no rows.
    """
    has_rows = df is not None and len(df) > 0
    if not has_rows:
        return None
    # delete=False: the file has to outlive this call so it can be downloaded.
    handle = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
    with handle:
        csv_path = handle.name
        df.to_csv(csv_path, index=False)
    return csv_path
# Export the table currently shown in the results view: save_csv receives its
# value and returns a CSV path (or None) for the download component.
download_btn.click(
    fn=save_csv,
    inputs=[table_output],
    outputs=[csv_output]
)
# Launch the app
if __name__ == "__main__":
    separator = "=" * 50
    # Startup banner
    for startup_line in (
        separator,
        "Initializing AnalysisGNN app...",
        separator,
        "Pre-loading model at startup...",
    ):
        logger.info(startup_line)
    # Pre-load the model before the interface starts serving
    load_model()
    for startup_line in ("Model ready. Launching Gradio interface...", separator):
        logger.info(startup_line)
    demo.launch()