Spaces:
Runtime error
Runtime error
| from models import CLIPModel | |
| from inference import predict_poems_from_image | |
| from utils import get_poem_embeddings | |
| import config as CFG | |
| import json | |
| import torch | |
| import gradio as gr | |
def greet_user(name: str) -> str:
    """Return the Gradio welcome greeting for *name*."""
    return f"Hello {name} Welcome to Gradio!😎"
if __name__ == "__main__":
    # Build the model on the configured device and put it in evaluation
    # mode: this script only serves predictions, it never trains.
    model = CLIPModel(
        image_encoder_pretrained=True,
        text_encoder_pretrained=True,
        text_projection_trainable=False,
        is_image_poem_pair=True,
    ).to(CFG.device)
    model.eval()

    # Load the precomputed poem records. Each record supplies an
    # 'embeddings' vector and the matching 'beyt' text.
    with open('poem_embeddings.json', encoding="utf-8") as f:
        pe = json.load(f)

    # torch.tensor(...) is the documented way to build a tensor from existing
    # Python data; float32 matches what the legacy torch.Tensor(...) produced.
    poem_embeddings = torch.tensor(
        [p['embeddings'] for p in pe], dtype=torch.float32
    ).to(CFG.device)
    print(poem_embeddings.shape)
    poems = [p['beyt'] for p in pe]

    def gradio_make_predictions(image):
        """Return the predicted poems for *image*, one per line.

        Args:
            image: Path of the uploaded image as handed over by the Gradio
                ``Image(type="filepath")`` input, or ``None`` when the user
                submitted the form without an image.

        Returns:
            The entries returned by ``predict_poems_from_image`` (called with
            ``n=10``) joined with newlines, or an empty string when no image
            was provided.
        """
        # Gradio passes None for an empty image input; forwarding it to the
        # predictor would fail, so answer with an empty result instead.
        if image is None:
            return ""

        # Gradients are never needed while serving predictions. This is
        # harmless if the helper already disables them itself.
        with torch.no_grad():
            beyts = predict_poems_from_image(
                model, poem_embeddings, image, poems, n=10
            )
        return "\n".join(beyts)

    CFG.batch_size = 512

    # Single image in, plain text out.
    image_input = gr.Image(type="filepath")
    output = gr.Textbox()
    app = gr.Interface(
        fn=gradio_make_predictions, inputs=image_input, outputs=output
    )
    app.launch()