Spaces:
Runtime error
import os
import tempfile
from datetime import timedelta

import gradio as gr
import numpy as np
import torch
import torchvision
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

import utils
from inference import Inference
# Model identifier passed to utils.video2image when turning each video
# segment into pixel values (see search_in_video).
encoder_model_name = 'google/vit-large-patch32-224-in21k'

# Model identifier handed to Inference as its caption decoder.
decoder_model_name = 'gpt2-large'

# Number of consecutive frames grouped into one searchable segment.
frame_step = 300

# Caption generator used on the per-segment pixel values.
# NOTE(review): only the decoder name is passed, so the encoder Inference
# loads is whatever its own default is — confirm it matches
# encoder_model_name, which is used for preprocessing.
inference = Inference(decoder_model_name=decoder_model_name)

# Sentence-embedding model used to compare the query with the captions.
model = SentenceTransformer('all-mpnet-base-v2')
def _write_temp_clip(frames, fps):
    """Write *frames* to a fresh temporary MP4 file and return its path.

    A unique file per call keeps concurrent requests from overwriting each
    other's results. The file is deliberately left on disk because Gradio
    reads it after search_in_video has returned.
    """
    fd, path = tempfile.mkstemp(suffix='.mp4')
    os.close(fd)  # write_video opens the path itself
    torchvision.io.write_video(path, frames, fps)
    return path


def search_in_video(video, query):
    """Rank fixed-length segments of *video* by how well their captions match *query*.

    The video is cut into consecutive chunks of ``frame_step`` frames. Each
    chunk is converted by ``utils.video2image`` and captioned by
    ``inference.generate_texts``. The query and all captions are embedded
    with the sentence-transformer ``model`` and compared by cosine
    similarity.

    Args:
        video: The value produced by the Gradio ``'video'`` input, which is
            passed directly to ``torchvision.io.read_video``.
        query: Free-text search string.

    Returns:
        A 4-tuple ``(top1, top2, top3, labels_to_scores)``:
        - ``top1``/``top2``/``top3`` are paths to MP4 files holding the
          best-, second- and third-ranked segments. An entry is ``None``
          when the video has fewer segments than that rank.
        - ``labels_to_scores`` maps each segment's ``"<start>:<end>"``
          label (both ``datetime.timedelta`` strings) to its cosine
          similarity with the query. Entries are inserted in ascending
          order of similarity.

    Raises:
        gr.Error: If no video was supplied or no frames could be read.
    """
    if video is None:
        raise gr.Error('Please upload a video.')

    # read_video returns (video frames, audio frames, metadata dict).
    frames, _, info = torchvision.io.read_video(video)
    total_frames = frames.shape[0]
    if total_frames == 0:
        raise gr.Error('No video frames could be read from the upload.')
    video_fps = info['video_fps']

    # [start, end) frame ranges of consecutive frame_step-sized chunks. The
    # last chunk is shorter when total_frames is not a multiple of frame_step.
    bounds = [
        (start, min(start + frame_step, total_frames))
        for start in range(0, total_frames, frame_step)
    ]
    video_segments = [frames[start:end] for start, end in bounds]

    # Caption every segment, then embed the query together with the captions.
    pixel_values = torch.stack(
        [utils.video2image(segment, encoder_model_name) for segment in video_segments]
    )
    generated_texts = inference.generate_texts(pixel_values)
    sentence_embeddings = model.encode([query] + generated_texts)

    # The only row of the (1, n_segments) matrix: query vs. each caption.
    similarities = cosine_similarity(
        [sentence_embeddings[0]], sentence_embeddings[1:]
    )[0]
    # Segment indices sorted from least to most similar.
    order = np.argsort(similarities)

    # Short videos can have fewer than three segments. Pad with None
    # instead of indexing past the end of the ranking.
    top_paths = [
        _write_temp_clip(video_segments[idx], video_fps) for idx in order[::-1][:3]
    ]
    top_paths += [None] * (3 - len(top_paths))

    labels = []
    for start, end in bounds:
        s = timedelta(seconds=(start / video_fps))
        e = timedelta(seconds=(end / video_fps))
        labels.append(f'{s}:{e}')
    labels_to_scores = {labels[idx]: float(similarities[idx]) for idx in order}

    return top_paths[0], top_paths[1], top_paths[2], labels_to_scores
# Gradio UI. Inputs: an uploaded video and a text query. Outputs: the three
# best-matching clips and the per-segment similarity scores returned by
# search_in_video.
app = gr.Interface(
    fn=search_in_video,
    inputs=['video', 'text'],
    outputs=[
        gr.Video(format='mp4', label='Top1'),
        gr.Video(format='mp4', label='Top2'),
        gr.Video(format='mp4', label='Top3'),
        # The legacy gr.outputs namespace no longer exists in current
        # Gradio. gr.Label takes the {label: score} dict directly, so the
        # old type='auto' argument is not needed.
        gr.Label(num_top_classes=5, label='Scores'),
    ],
)
app.launch()