File size: 5,003 Bytes
f45a0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1b92b9
f45a0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1b92b9
 
f45a0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1857618
 
 
 
 
 
 
f45a0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python
"""HyperView Hugging Face Space demo: CLIP + HyCoCLIP on Imagenette.

Usage:
  python demo.py --precompute   # run during Docker build
  python demo.py                # run as app entrypoint
"""

from __future__ import annotations

import os
import sys

import hyperview as hv

# --- Server binding (overridable through the environment) -------------------
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "7860"))

# --- Dataset source and sampling ---------------------------------------------
# Each setting reads a DEMO_* environment variable and falls back to the
# default given here when the variable is unset.
DATASET_NAME = os.getenv("DEMO_DATASET", "imagenette_clip_hycoclip")
HF_DATASET = os.getenv("DEMO_HF_DATASET", "Multimodal-Fatima/Imagenette_validation")
HF_SPLIT = os.getenv("DEMO_HF_SPLIT", "validation")
HF_IMAGE_KEY = os.getenv("DEMO_HF_IMAGE_KEY", "image")
HF_LABEL_KEY = os.getenv("DEMO_HF_LABEL_KEY", "label")
NUM_SAMPLES = int(os.getenv("DEMO_SAMPLES", "300"))
SAMPLE_SEED = int(os.getenv("DEMO_SEED", "42"))

# --- Embedding model identifiers ----------------------------------------------
CLIP_MODEL_ID = os.getenv("DEMO_CLIP_MODEL", "openai/clip-vit-base-patch32")
HYPER_MODEL_ID = os.getenv("DEMO_HYPER_MODEL", "hycoclip-vit-s")


def _truthy_env(name: str, default: bool = True) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Returns *default* when the variable is not set. Otherwise the value is
    stripped and lower-cased, and the result is ``False`` only for ``"0"``,
    ``"false"``, ``"no"``, ``"off"`` or an empty string; anything else is
    ``True``.
    """
    try:
        raw = os.environ[name]
    except KeyError:
        return default

    normalized = raw.strip().lower()
    return normalized not in ("", "0", "false", "no", "off")


def _find_space(spaces, provider: str, model_id: str):
    """Return the first entry of *spaces* matching *provider* and *model_id*.

    Attributes are read with ``getattr(..., None)``, so an entry that lacks
    ``provider`` or ``model_id`` is skipped instead of raising. Returns
    ``None`` when nothing matches.
    """
    for space in spaces:
        if (
            getattr(space, "provider", None) == provider
            and getattr(space, "model_id", None) == model_id
        ):
            return space
    return None


def _ensure_demo_ready(dataset: hv.Dataset) -> None:
    """Populate *dataset* with samples, embedding spaces and layouts as needed.

    Each step is skipped when its result is already present, so the function
    can be called again on a partially prepared dataset:

    1. If the dataset is empty, load samples from ``HF_DATASET``.
    2. Ensure a CLIP embedding space exists (provider ``embed-anything``).
    3. Unless ``DEMO_COMPUTE_HYPERBOLIC`` is set to a falsy value, try to add
       a hyperbolic embedding space (provider ``hyper-models``). A failure
       here is reported as a warning and does not abort setup.
    4. Ensure a ``euclidean`` layout (from the CLIP space) and a ``poincare``
       layout (from the hyperbolic space when available, otherwise from the
       CLIP space) exist.

    Raises:
        RuntimeError: if no CLIP embedding space is found even after
            computing embeddings.
    """
    if len(dataset) == 0:
        print(f"Loading samples from {HF_DATASET} ({HF_SPLIT})...")
        dataset.add_from_huggingface(
            HF_DATASET,
            split=HF_SPLIT,
            image_key=HF_IMAGE_KEY,
            label_key=HF_LABEL_KEY,
            max_samples=NUM_SAMPLES,
            shuffle=True,
            seed=SAMPLE_SEED,
        )

    spaces = dataset.list_spaces()

    clip_space = _find_space(spaces, "embed-anything", CLIP_MODEL_ID)
    if clip_space is None:
        print(f"Computing CLIP embeddings ({CLIP_MODEL_ID})...")
        dataset.compute_embeddings(model=CLIP_MODEL_ID, provider="embed-anything", show_progress=True)
        # Re-list so the lookup below (and the hyperbolic lookup) sees the new space.
        spaces = dataset.list_spaces()
        clip_space = _find_space(spaces, "embed-anything", CLIP_MODEL_ID)

    if clip_space is None:
        raise RuntimeError("Failed to create CLIP embedding space")

    compute_hyperbolic = _truthy_env("DEMO_COMPUTE_HYPERBOLIC", default=True)
    hyper_space = _find_space(spaces, "hyper-models", HYPER_MODEL_ID)

    if compute_hyperbolic and hyper_space is None:
        # Deliberately best-effort: any failure leaves hyper_space as None and
        # the poincaré layout below falls back to the CLIP space.
        try:
            print(f"Computing hyperbolic embeddings ({HYPER_MODEL_ID})...")
            dataset.compute_embeddings(model=HYPER_MODEL_ID, provider="hyper-models", show_progress=True)
            spaces = dataset.list_spaces()
            hyper_space = _find_space(spaces, "hyper-models", HYPER_MODEL_ID)
        except Exception as exc:
            print(f"WARNING: hyperbolic embeddings failed ({type(exc).__name__}: {exc})")

    layouts = dataset.list_layouts()
    geometries = {getattr(layout, "geometry", None) for layout in layouts}

    if "euclidean" not in geometries:
        print("Computing euclidean layout...")
        dataset.compute_visualization(space_key=clip_space.space_key, geometry="euclidean")

    if "poincare" not in geometries:
        print("Computing poincaré layout...")
        poincare_space_key = hyper_space.space_key if hyper_space is not None else clip_space.space_key
        dataset.compute_visualization(space_key=poincare_space_key, geometry="poincare")


def main() -> None:
    """Open the demo dataset, prepare it if required, then launch HyperView.

    Preparation runs when the dataset has no samples or no layouts; if it
    fails, the traceback and a FATAL message are printed and the process
    exits with status 1. When ``--precompute`` is present in ``sys.argv`` the
    function returns after preparation without starting the server.
    """
    dataset = hv.Dataset(DATASET_NAME)

    # Short-circuit keeps list_layouts() from being queried on an empty dataset.
    needs_setup = len(dataset) == 0 or not dataset.list_layouts()

    if not needs_setup:
        space_count = len(dataset.list_spaces())
        layout_count = len(dataset.list_layouts())
        print(
            f"Loaded cached dataset '{DATASET_NAME}' with "
            f"{space_count} spaces and {layout_count} layouts"
        )
    else:
        print("Preparing demo dataset...")
        try:
            _ensure_demo_ready(dataset)
        except Exception as exc:
            import traceback

            # Full traceback first, then a one-line summary on stderr.
            traceback.print_exc()
            summary = f"\nFATAL: demo setup failed: {type(exc).__name__}: {exc}"
            print(summary, file=sys.stderr)
            sys.exit(1)

    precompute_only = "--precompute" in sys.argv
    if precompute_only:
        print("Precompute complete")
        return

    print(f"Starting HyperView on {HOST}:{PORT}")
    hv.launch(dataset, host=HOST, port=PORT, open_browser=False)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()