Werli committed on
Commit 4025d33 · verified · 1 Parent(s): 5382ef8

Delete modules/beautify_model.py

Files changed (1)
  1. modules/beautify_model.py +0 -99
modules/beautify_model.py DELETED
@@ -1,99 +0,0 @@
-import os
-import io, copy, traceback, requests, spaces, gradio as gr, numpy as np
-from transformers import T5ForConditionalGeneration, T5Tokenizer
-
-# Prompt-enhancement checkpoint used by this module.
-LAMINI_PROMPT_LONG = "gokaygokay/Lamini-Prompt-Enchance-Long"
-
-class beautify_class:
-    def __init__(self, repoId: str, device: str = None, loadModel: bool = False):
-        self.modelPath = self.download_model(repoId)
-        if device is None:
-            import torch
-            # Pick "cuda" only when enough VRAM is available for the chosen checkpoint.
-            self.totalVram = 0
-            if torch.cuda.is_available():
-                try:
-                    deviceId = torch.cuda.current_device()
-                    self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
-                except Exception as e:
-                    print(traceback.format_exc())
-                    print("Error detecting vram: " + str(e))
-                device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
-            else:
-                device = "cpu"
-        self.device = device
-        self.system_prompt = "Summarize, beautify and enhance the following English labels describing a single image into a readable English article:\n\n"
-        if loadModel:
-            self.load_model()
-
-    def download_model(self, repoId):
-        import huggingface_hub
-        # Only fetch the files needed for the safetensors checkpoint and its tokenizer.
-        allowPatterns = [
-            # "tf_model.h5",
-            # "model.ckpt.index",
-            # "flax_model.msgpack",
-            # "pytorch_model.bin",
-            "config.json",
-            "generation_config.json",
-            "model.safetensors",
-            "tokenizer.json",
-            "tokenizer_config.json",
-            "special_tokens_map.json",
-            "vocab.json",
-            "added_tokens.json",
-            "spiece.model",
-        ]
-        kwargs = {"allow_patterns": allowPatterns}
-        try:
-            return huggingface_hub.snapshot_download(repoId, **kwargs)
-        except (huggingface_hub.utils.HfHubHTTPError, requests.exceptions.ConnectionError) as exception:
-            import warnings
-            warnings.warn(
-                "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
-                % (repoId, exception)
-            )
-            warnings.warn(
-                "Trying to load the model directly from the local cache, if it exists."
-            )
-            # Fall back to whatever is already in the local cache.
-            kwargs["local_files_only"] = True
-            return huggingface_hub.snapshot_download(repoId, **kwargs)
-
-    def load_model(self):
-        try:
-            print('\n\nLoading model: %s\n\n' % self.modelPath)
-            self.Tokenizer = T5Tokenizer.from_pretrained(self.modelPath)
-            self.Model = T5ForConditionalGeneration.from_pretrained(self.modelPath).to(self.device)
-        except Exception as e:
-            self.release_vram()
-            raise e
-
-    def release_vram(self):
-        try:
-            import torch
-            if torch.cuda.is_available():
-                if getattr(self, "Model", None) is not None:
-                    self.Model.to('cpu')
-                    del self.Model
-                if getattr(self, "Tokenizer", None) is not None:
-                    del self.Tokenizer
-                import gc
-                gc.collect()
-                torch.cuda.empty_cache()
-            print("release vram end.")
-        except Exception as e:
-            print(traceback.format_exc())
-            print("Error releasing vram: " + str(e))
-
-    def beautify(self, text: str, max_length: int = 400):
-        try:
-            input_ids = self.Tokenizer(self.system_prompt + text, return_tensors="pt").input_ids.to(self.device)
-            output = self.Model.generate(input_ids, max_length=max_length, no_repeat_ngram_size=3, num_beams=2, early_stopping=True)
-            result = self.Tokenizer.decode(output[0], skip_special_tokens=True)
-            return result
-        except Exception as e:
-            print(traceback.format_exc())
-            print("Error found: " + str(e))
-            return None
-
-beautify_list = [LAMINI_PROMPT_LONG]
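
For reference, a minimal sketch of how the removed class was presumably driven; the call pattern below is an assumption inferred from the deleted file itself, not taken from the actual calling code.

    # Hypothetical usage of the deleted beautify_class (illustration only).
    from modules.beautify_model import beautify_class, LAMINI_PROMPT_LONG

    # Download (or reuse the cached) checkpoint and load it onto the auto-detected device.
    beautifier = beautify_class(LAMINI_PROMPT_LONG, loadModel=True)

    # Turn a list of image labels into a short article-style description.
    print(beautifier.beautify("a dog, a red ball, green grass, sunny day"))

    # Free GPU memory when done.
    beautifier.release_vram()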