Spaces:
Configuration error
Configuration error
| from __future__ import annotations | |
| import aiohttp | |
| import os | |
| import traceback | |
| import logging | |
| from folder_paths import models_dir | |
| import re | |
| from typing import Callable, Any, Optional, Awaitable, Dict | |
| from enum import Enum | |
| import time | |
| from dataclasses import dataclass | |
class DownloadStatusType(Enum):
    """Lifecycle states of a model download.

    ``DownloadModelStatus`` stores the member's string value (not the member
    itself), so these lowercase strings are what progress callbacks see.
    """

    # Download requested; no bytes transferred yet.
    PENDING = "pending"
    # Body is being streamed to disk.
    IN_PROGRESS = "in_progress"
    # File is on disk (freshly downloaded or already present).
    COMPLETED = "completed"
    # Validation, HTTP, or I/O failure.
    ERROR = "error"
class DownloadModelStatus():
    """
    Snapshot of a model download's state, as passed to progress callbacks.

    ``status`` holds the string value of a ``DownloadStatusType`` member rather
    than the enum member itself, so ``to_dict()`` contains only plain values.
    """

    status: str
    progress_percentage: float
    message: str
    already_existed: bool = False

    def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str,
                 already_existed: bool = False):
        """
        Args:
            status: Enum member describing the state; only its ``.value`` is stored.
            progress_percentage: Progress in percent (this module passes values in 0-100).
            message: Human-readable description of the current state.
            already_existed: True when the file was found on disk and nothing was
                downloaded. Defaults to False, matching the class-level default above.
        """
        self.status = status.value  # Store the string value of the Enum
        self.progress_percentage = progress_percentage
        self.message = message
        self.already_existed = already_existed

    def to_dict(self) -> Dict[str, Any]:
        """Return the four status fields as a plain dict."""
        return {
            "status": self.status,
            "progress_percentage": self.progress_percentage,
            "message": self.message,
            "already_existed": self.already_existed
        }

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(status={self.status!r}, "
                f"progress_percentage={self.progress_percentage!r}, "
                f"message={self.message!r}, already_existed={self.already_existed!r})")
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
                         model_name: str,
                         model_url: str,
                         model_sub_directory: str,
                         progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                         progress_interval: float = 1.0) -> DownloadModelStatus:
    """
    Download a model file from a given URL into the models directory.

    Args:
        model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
            A function that makes an HTTP request. This makes it easier to mock in unit tests.
        model_name (str):
            The name of the model file to be downloaded. This will be the filename on disk.
        model_url (str):
            The URL from which to download the model.
        model_sub_directory (str):
            The subdirectory within the main models directory where the model
            should be saved (e.g., 'checkpoints', 'loras', etc.).
        progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
            An asynchronous function to call with progress updates. It receives
            the "subdirectory/filename" relative path and the current status.
        progress_interval (float):
            Passed to track_download_progress as the minimum number of seconds
            between intermediate progress callbacks.

    Returns:
        DownloadModelStatus: The result of the download operation. Invalid
        names, non-200 responses and errors raised while preparing the path or
        downloading are reported as an ERROR status instead of being raised.
    """
    if not validate_model_subdirectory(model_sub_directory):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid model subdirectory",
            False
        )

    if not validate_filename(model_name):
        return DownloadModelStatus(
            DownloadStatusType.ERROR,
            0,
            "Invalid model name",
            False
        )

    # Identifier used for error reporting should create_model_path itself fail;
    # built the same way create_model_path builds its relative path.
    relative_path = f"{model_sub_directory}/{model_name}"

    try:
        # Path creation (makedirs, containment check) can raise; keep it inside
        # the try so such failures are reported like any other download error.
        file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)

        existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
        if existing_file:
            return existing_file

        status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
        await progress_callback(relative_path, status)

        response = await model_download_request(model_url)
        if response.status != 200:
            error_message = f"Failed to download {model_name}. Status code: {response.status}"
            logging.error(error_message)
            # The body will never be read; close the connection now rather than
            # leaving the unread response for the garbage collector.
            response.close()
            status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
            await progress_callback(relative_path, status)
            return status

        return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)

    except Exception as e:
        logging.error(f"Error in downloading model: {e}")
        return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
    """
    Build the on-disk and relative paths for a model and ensure its directory exists.

    Args:
        model_name: File name of the model.
        model_directory: Sub-directory, relative to ``models_base_dir``, to place it in.
        models_base_dir: Root models directory; the resulting path must stay inside it.

    Returns:
        tuple[str, str]: ``(file_path, relative_path)``. ``file_path`` is
        ``os.path.join(models_base_dir, model_directory, model_name)`` and
        ``relative_path`` is ``"model_directory/model_name"``.

    Raises:
        ValueError: If the resulting path would lie outside ``models_base_dir``
            (subclass of ``Exception``, which is what this function used to raise).
        OSError: If the model directory cannot be created.
    """
    full_model_dir = os.path.join(models_base_dir, model_directory)
    file_path = os.path.join(full_model_dir, model_name)

    # Check containment BEFORE creating anything, so a traversal attempt cannot
    # create directories outside the base. commonpath compares whole path
    # components; commonprefix is character-based and would accept
    # "/models_evil/x" as being inside "/models".
    abs_file_path = os.path.abspath(file_path)
    abs_base_dir = os.path.abspath(str(models_base_dir))
    try:
        is_inside_base = os.path.commonpath([abs_file_path, abs_base_dir]) == abs_base_dir
    except ValueError:
        # commonpath raises e.g. for paths on different Windows drives.
        is_inside_base = False
    if not is_inside_base:
        raise ValueError(f"Invalid model directory: {model_directory}/{model_name}")

    os.makedirs(full_model_dir, exist_ok=True)

    relative_path = '/'.join([model_directory, model_name])
    return file_path, relative_path
async def check_file_exists(file_path: str,
                            model_name: str,
                            progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                            relative_path: str) -> Optional[DownloadModelStatus]:
    """
    Short-circuit a download when the target file is already on disk.

    When ``file_path`` exists, a COMPLETED status (100%, flagged
    ``already_existed``) is sent to ``progress_callback`` under
    ``relative_path`` and returned. Otherwise no callback is made and
    ``None`` is returned.
    """
    if not os.path.exists(file_path):
        return None

    already_present = DownloadModelStatus(
        DownloadStatusType.COMPLETED,
        100,
        f"{model_name} already exists",
        True,
    )
    await progress_callback(relative_path, already_present)
    return already_present
async def track_download_progress(response: aiohttp.ClientResponse,
                                  file_path: str,
                                  model_name: str,
                                  progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                  relative_path: str,
                                  interval: float = 1.0) -> DownloadModelStatus:
    """
    Stream the response body into ``file_path`` while reporting progress.

    The body is read in chunks of up to 8192 bytes and written as it arrives.
    An IN_PROGRESS status is sent to ``progress_callback`` whenever at least
    ``interval`` seconds have passed since the previous update, once more after
    the last chunk, and finally a COMPLETED status is sent and returned.
    Percentages are derived from the ``Content-Length`` header; when it is
    missing (or 0) progress is reported as 0.

    If writing fails, the truncated file is deleted and an ERROR status is
    reported and returned instead of raising. Otherwise a later
    ``check_file_exists`` would treat it as an existing, complete model. On
    cancellation the partial file is also deleted, and the cancellation then
    propagates.

    Args:
        response: The HTTP response whose body is streamed.
        file_path: Destination path; opened with ``'wb'`` (truncated if present).
        model_name: Used in status messages and logs.
        progress_callback: Awaited with ``(relative_path, status)`` for each update.
        relative_path: Identifier passed to ``progress_callback``.
        interval: Minimum seconds between intermediate progress updates.

    Returns:
        DownloadModelStatus: COMPLETED on success, ERROR on failure.
    """
    # True only while file_path holds an incomplete download written by us.
    partial_file_on_disk = False

    def discard_partial_file() -> None:
        """Best-effort removal of an incomplete download at file_path."""
        if not partial_file_on_disk:
            return
        try:
            os.remove(file_path)
        except OSError as remove_error:
            logging.warning(f"Could not remove partial download {file_path}: {remove_error}")

    try:
        total_size = int(response.headers.get('Content-Length', 0))
        downloaded = 0
        last_update_time = time.time()

        async def update_progress():
            nonlocal last_update_time
            progress = (downloaded / total_size) * 100 if total_size > 0 else 0
            status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
            await progress_callback(relative_path, status)
            last_update_time = time.time()

        with open(file_path, 'wb') as f:
            partial_file_on_disk = True
            chunk_iterator = response.content.iter_chunked(8192)
            # NOTE(review): explicit __anext__ calls rather than `async for`;
            # possibly so tests can mock __anext__ directly — confirm before changing.
            while True:
                try:
                    chunk = await chunk_iterator.__anext__()
                except StopAsyncIteration:
                    break
                f.write(chunk)
                downloaded += len(chunk)

                if time.time() - last_update_time >= interval:
                    await update_progress()

        # The body is fully written and the file closed; a failing callback
        # from here on must not delete the finished download.
        partial_file_on_disk = False

        await update_progress()

        logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
        status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
        await progress_callback(relative_path, status)
        return status

    except Exception as e:
        logging.error(f"Error in track_download_progress: {e}")
        logging.error(traceback.format_exc())
        discard_partial_file()
        return await handle_download_error(e, model_name, progress_callback, relative_path)
    except BaseException:
        # e.g. asyncio.CancelledError: not reported as a download error, but it
        # must not leave a truncated file behind either.
        discard_partial_file()
        raise
async def handle_download_error(e: Exception,
                                model_name: str,
                                progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
                                relative_path: str) -> DownloadModelStatus:
    """
    Report a download failure to the progress callback and return it.

    Args:
        e: The exception that caused the failure; its text goes into the message.
        model_name: Name of the model being downloaded, used in the message.
        progress_callback: Awaited with ``(relative_path, status)``.
        relative_path: Identifier passed to ``progress_callback``.

    Returns:
        DownloadModelStatus: An ERROR status at 0% progress.
    """
    error_message = f"Error downloading {model_name}: {str(e)}"
    status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
    await progress_callback(relative_path, status)
    return status
def validate_model_subdirectory(model_subdirectory: str) -> bool:
    """
    Validate that the model subdirectory is safe to install into.

    Must not contain relative paths, nested paths or special characters
    other than underscores and hyphens.

    Args:
        model_subdirectory (str): The subdirectory for the specific model type.
            Non-string values such as ``None`` are rejected rather than raising.

    Returns:
        bool: True if the subdirectory is a non-empty string of at most 50
        ASCII letters, digits, underscores or hyphens; False otherwise.
    """
    if not isinstance(model_subdirectory, str) or len(model_subdirectory) > 50:
        return False

    # Redundant with the allow-list below, but states the traversal concern explicitly.
    if '..' in model_subdirectory or '/' in model_subdirectory:
        return False

    # fullmatch rather than match with '$': '$' also matches just before a
    # trailing newline, so "loras\n" would otherwise be accepted.
    return re.fullmatch(r'[a-zA-Z0-9_-]+', model_subdirectory) is not None
def validate_filename(filename: str) -> bool:
    """
    Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

    A valid filename:
      * is a string that is not empty or only whitespace,
      * ends with ``.sft`` or ``.safetensors`` (case-insensitive),
      * contains no ``..``, ``/``, ``\\``, newline, carriage return, tab or NUL,
      * does not start with a dot (hidden file),
      * uses only ASCII letters, digits, ``_``, ``-``, ``.`` and spaces,
      * is at most 255 characters long.

    Args:
        filename (str): The filename to validate. Non-string values such as
            ``None`` are rejected rather than raising.

    Returns:
        bool: True if the filename is valid, False otherwise
    """
    # Checked first so None / non-string input is rejected instead of raising
    # AttributeError on .lower() below.
    if not isinstance(filename, str) or not filename.strip():
        return False

    if not filename.lower().endswith(('.sft', '.safetensors')):
        return False

    # Check for any directory traversal attempts or invalid characters
    if any(bad in filename for bad in ('..', '/', '\\', '\n', '\r', '\t', '\0')):
        return False

    # Reject hidden files
    if filename.startswith('.'):
        return False

    # Whitelist of allowed characters; fullmatch avoids '$' accepting a trailing newline.
    if re.fullmatch(r'[a-zA-Z0-9_\-. ]+', filename) is None:
        return False

    # Ensure the filename isn't too long
    return len(filename) <= 255