import spaces
from huggingface_hub import hf_hub_download
import subprocess
import importlib
import site
import torch

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()


def sh(cmd):
    subprocess.check_call(cmd, shell=True)


flash_attention_installed = False

try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    sh(f"pip install {flash_attention_wheel}")

    # tell Python to re-scan site-packages now that the egg-link exists
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()

    flash_attention_installed = True
    print("FlashAttention installed successfully.")

except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

attn_implementation = "flash_attention_2" if flash_attention_installed else "sdpa"
dtype = torch.bfloat16 if flash_attention_installed else None
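
# Illustrative sketch (not part of the original script): attn_implementation and
# dtype are the kind of values typically forwarded to a Transformers model load.
# The helper name and model id below are placeholders/assumptions; the function
# is defined but never called here, so it does not change the app's behavior.
def _load_model_example(model_id: str):
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )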