File size: 4,056 Bytes
a602628
3bc6e37
 
 
 
a602628
 
 
6c32e21
 
3bc6e37
6c32e21
a602628
 
 
eb72117
3bc6e37
 
 
 
eb72117
6c32e21
3bc6e37
 
eb72117
3bc6e37
6c32e21
 
3bc6e37
6c32e21
3bc6e37
 
eb72117
3bc6e37
 
eb72117
3bc6e37
eb72117
6c32e21
3bc6e37
eb72117
 
 
 
 
 
 
 
 
 
 
 
6c32e21
eb72117
 
 
3bc6e37
eb72117
 
3bc6e37
 
eb72117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c32e21
 
 
 
 
 
 
 
 
 
 
 
 
 
eb72117
 
6c32e21
eb72117
3bc6e37
eb72117
a602628
3bc6e37
 
 
 
 
 
 
 
6c32e21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
HuggingFace Dataset Utilities for LoRA Training Studio.

Provides helpers to download audio datasets from the HuggingFace Hub and to
sync the session's dataset.json back to the dataset repo.
The actual training pipeline lives in acestep/training/.
"""

import logging
import os
import shutil
from pathlib import Path
from typing import Tuple

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lower-case file extensions (leading dot included) that count as audio when
# filtering a dataset repo listing; compared against Path(...).suffix.lower().
AUDIO_SUFFIXES = {
    ".flac",
    ".mp3",
    ".ogg",
    ".opus",
    ".wav",
}


def _link_into_dir(output_dir: Path, cached_path: str, name: str) -> None:
    """Expose a downloaded cache file as ``output_dir / name``.

    Prefers a symlink into the HF cache and falls back to a copy when a
    symlink cannot be created. An existing live entry is kept as-is; a
    dangling symlink (e.g. one whose cache target was deleted) is replaced.
    """
    dest = output_dir / name
    # Path.exists() follows symlinks, so a dangling link reports False; left in
    # place, it would make symlink_to() raise FileExistsError.
    if dest.is_symlink() and not dest.exists():
        dest.unlink()
    if dest.exists():
        return
    try:
        dest.symlink_to(cached_path)
    except OSError:
        # Symlink creation can be unavailable (e.g. Windows without the
        # required privilege); a plain copy keeps the file discoverable.
        shutil.copy2(cached_path, str(dest))


def download_hf_dataset(
    dataset_id: str,
    max_files: int = 50,
    offset: int = 0,
) -> Tuple[str, str]:
    """
    Download a subset of audio files from a HuggingFace dataset repo.

    The repo is listed recursively, files whose extension is in
    ``AUDIO_SUFFIXES`` are kept, and the slice ``[offset, offset + max_files)``
    of that listing is downloaded into the HF cache. Each file is then exposed
    by its base name under ``lora_training/hf/<dataset_id with "/" -> "_">/``
    (symlink, or a copy if symlinks are unavailable). If two selected files
    share a base name, only the first is kept and the others are reported as
    skipped.

    Also pulls dataset.json from the repo if it exists (restoring labels
    and preprocessed flags from a previous session); the pulled copy
    overwrites any local dataset.json in the output directory.

    Uses HF_TOKEN env var for authentication.

    Args:
        dataset_id: HF dataset repo id, e.g. ``"user/name"``.
        max_files: Maximum number of audio files to fetch; negative values
            are treated as 0.
        offset: Index of the first audio file to fetch within the filtered
            listing; negative values are treated as 0.

    Returns:
        Tuple of (output_dir, status_message). On failure output_dir is ""
        and status_message describes the problem; this function does not
        raise.
    """
    # Import on its own so only a genuinely missing package yields the
    # "not installed" hint; ImportErrors raised later are reported as-is.
    try:
        from huggingface_hub import HfApi, hf_hub_download
    except ImportError:
        msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
        logger.error(msg)
        return "", msg

    # A negative start or count would make the slice below index from the end
    # of the listing instead of selecting a forward window.
    start = max(offset, 0)
    count = max(max_files, 0)

    try:
        api = HfApi()
        token = os.environ.get("HF_TOKEN")

        logger.info("Listing files in '%s'...", dataset_id)

        all_files = [
            f.rfilename
            for f in api.list_repo_tree(
                dataset_id, repo_type="dataset", token=token, recursive=True
            )
            # Entries without rfilename (e.g. folders) are ignored.
            if hasattr(f, "rfilename")
            and Path(f.rfilename).suffix.lower() in AUDIO_SUFFIXES
        ]

        total_available = len(all_files)
        selected = all_files[start:start + count]

        if not selected:
            if total_available == 0:
                return "", f"No audio files found in {dataset_id}"
            # Audio exists, but the requested window is empty.
            return "", (
                f"No audio files selected from {dataset_id}: offset {start}, "
                f"max_files {count}, {total_available} available"
            )

        logger.info(
            "Downloading %d/%d audio files...", len(selected), total_available
        )

        output_dir = Path("lora_training") / "hf" / dataset_id.replace("/", "_")
        output_dir.mkdir(parents=True, exist_ok=True)

        linked_names = set()
        skipped = 0
        for i, filename in enumerate(selected, start=1):
            logger.info("  [%d/%d] %s", i, len(selected), filename)
            # Files are flattened into one directory so scan_directory finds
            # them; two repo paths with the same base name cannot both fit.
            name = Path(filename).name
            if name in linked_names:
                logger.warning(
                    "Skipping %s: another file named %s is already in this batch",
                    filename,
                    name,
                )
                skipped += 1
                continue
            linked_names.add(name)
            cached_path = hf_hub_download(
                repo_id=dataset_id,
                filename=filename,
                repo_type="dataset",
                token=token,
            )
            _link_into_dir(output_dir, cached_path, name)

        # Pull dataset.json from repo if it exists (restores previous session
        # state). Best effort: any failure is treated as "no dataset.json".
        try:
            cached_json = hf_hub_download(
                repo_id=dataset_id,
                filename="dataset.json",
                repo_type="dataset",
                token=token,
            )
            shutil.copy2(cached_json, str(output_dir / "dataset.json"))
            logger.info("Pulled dataset.json from HF repo")
        except Exception as exc:
            logger.info("No dataset.json in HF repo (first session)")
            logger.debug("dataset.json fetch failed: %s", exc)

        status = (
            f"Downloaded {len(selected) - skipped} of {total_available} "
            f"audio files from {dataset_id} (offset {start})"
        )
        if skipped:
            status += f"; skipped {skipped} with duplicate file names"
        logger.info(status)
        return str(output_dir), status

    except Exception as e:
        msg = f"Failed to download dataset: {e}"
        # Keep the traceback in the log; the caller only receives the message.
        logger.exception(msg)
        return "", msg


def upload_dataset_json_to_hf(dataset_id: str, json_path: str) -> str:
    """Push dataset.json to the HF dataset repo for persistence across sessions.

    The local file at ``json_path`` is uploaded as ``dataset.json`` at the
    root of the dataset repo ``dataset_id``, authenticated with the HF_TOKEN
    env var.

    Args:
        dataset_id: HF dataset repo id, e.g. ``"user/name"``.
        json_path: Path to the local dataset.json to upload.

    Returns:
        A human-readable status message. This function does not raise: a
        missing huggingface_hub package, missing HF_TOKEN, or upload error is
        reported through the returned string.
    """
    # Same install hint as download_hf_dataset instead of a bare
    # "No module named ..." message.
    try:
        from huggingface_hub import HfApi
    except ImportError:
        msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
        logger.error(msg)
        return msg

    token = os.environ.get("HF_TOKEN")
    if not token:
        # Uploading requires write access, so there is nothing to try.
        return "HF_TOKEN not set — skipped HF sync"

    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=json_path,
            path_in_repo="dataset.json",
            repo_id=dataset_id,
            repo_type="dataset",
            token=token,
        )
    except Exception as e:
        msg = f"HF sync failed: {e}"
        # Keep the traceback in the log; the caller only receives the message.
        logger.exception(msg)
        return msg

    return f"Synced dataset.json to {dataset_id}"