Spaces:
Runtime error
| import io | |
| import timm | |
| import torch | |
| import streamlit as st | |
| from PIL import Image | |
| from timm.data import resolve_data_config | |
| from timm.data.transforms_factory import create_transform | |
class ImageClassifier(object):
    """Wrap a timm image-classification model together with its label names.

    The preprocessing transform is derived from the model's
    ``pretrained_cfg`` (timm ``resolve_data_config`` + ``create_transform``);
    a single image is run through the model and the output is turned into
    class probabilities with a softmax.
    """

    def __init__(self, model, labels):
        """Store the model and its labels.

        Args:
            model: a timm model; ``create_image_transform`` reads its
                ``pretrained_cfg``.
            labels: sequence of label names, indexed by output class index.
        """
        self.model = model
        self.labels = labels

    def get_top_5_predictions(self, image):
        """Return up to five most probable classes for *image*, best first.

        Returns:
            List of ``{'label': <label>, 'score': float}`` dicts sorted by
            descending probability. Fewer than five entries are returned when
            the model has fewer than five classes.
        """
        return self.get_top_k_predictions(image, k=5)

    def get_top_k_predictions(self, image, k=5):
        """Return up to *k* most probable classes for *image*, best first.

        *k* is clamped to the number of classes, because ``torch.topk``
        raises when asked for more elements than the tensor holds.
        """
        probabilities = self.get_output_probabilities(image)
        k = min(k, probabilities.numel())
        values, indices = torch.topk(probabilities, k)
        return [
            {'label': self.labels[index], 'score': score}
            for index, score in zip(indices.tolist(), values.tolist())
        ]

    def get_output_probabilities(self, image):
        """Return the softmax over the model output for a single *image*."""
        output = self.classify_image(image)
        # classify_image feeds a batch of one, so take the first (only) row.
        return torch.nn.functional.softmax(output[0], dim=0)

    def classify_image(self, image):
        """Preprocess *image* and return the raw model output for it.

        The input is wrapped in a batch of size one; the output is presumably
        (1, num_classes) logits (TODO confirm for other models).
        """
        mode = getattr(image, 'mode', None)
        if isinstance(mode, str) and mode != 'RGB':
            # PIL images in other modes (RGBA / palette PNGs, grayscale 'L')
            # would yield a channel count that presumably does not match the
            # model's 3-channel normalisation; convert them up front.
            image = image.convert('RGB')
        self.model.eval()
        transform = self.create_image_transform()
        batch = transform(image).unsqueeze(0)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            return self.model(batch)

    def create_image_transform(self):
        """Build the eval-time preprocessing transform from the model config."""
        return create_transform(**resolve_data_config(
            self.model.pretrained_cfg, model=self.model))
class ImageClassificationApp(object):
    """Streamlit page: upload an image, display it and its top-5 predictions."""

    def __init__(self, title, classifier):
        """Store the page title and the classifier.

        Args:
            title: text passed to ``st.title``.
            classifier: object with ``get_top_5_predictions(image)`` returning
                dicts that have ``'label'`` and ``'score'`` keys.
        """
        self.title = title
        self.classifier = classifier

    def render(self):
        """Draw the page; classify only once a file has been uploaded."""
        st.title(self.title)
        uploaded_image = self.get_uploaded_image()
        if uploaded_image is not None:
            self.show_image_and_results(uploaded_image)

    def get_uploaded_image(self):
        """Return the uploaded file, or ``None`` until the user picks one."""
        return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])

    def show_image_and_results(self, uploaded_image):
        """Show the uploaded file, then decode and classify it."""
        self.show_uploaded_image(uploaded_image)
        # getvalue() returns the whole buffer no matter where the stream
        # cursor is; read() would only return what is left after any earlier
        # consumer of the same object (here st.image) moved the cursor.
        image_data = uploaded_image.getvalue()
        try:
            image = self.get_image(image_data)
        except OSError:
            # Pillow signals undecodable/truncated data with OSError
            # subclasses (e.g. UnidentifiedImageError); an accepted file
            # extension does not guarantee valid image content.
            st.error('Could not read the uploaded file as an image.')
            return
        self.show_classification_results(image)

    def show_uploaded_image(self, uploaded_image):
        """Render the uploaded file at the column width."""
        st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)

    def show_classification_results(self, image):
        """Render the results header followed by the top-5 predictions."""
        st.subheader('Classification Results:')
        self.write_top_5_predictions(image)

    def write_top_5_predictions(self, image):
        """Write one bullet line per prediction: label and 4-decimal score."""
        for prediction in self.classifier.get_top_5_predictions(image):
            st.write(f"- {prediction['label']}: {prediction['score']:.4f}")

    def get_image(self, image_data):
        """Decode raw bytes into a PIL image.

        ``Image.open`` is lazy; ``load()`` forces the full decode here so that
        corrupt data raises now rather than later inside the classifier.
        """
        image = Image.open(io.BytesIO(image_data))
        image.load()
        return image
if __name__ == '__main__':
    @st.cache_resource
    def load_classifier():
        """Create the pet classifier once and reuse it across reruns.

        Streamlit re-executes the script on every widget interaction; without
        caching, timm would re-create the model and reload its pretrained
        weights each time.
        """
        model = timm.create_model(
            'hf-hub:nateraw/resnet50-oxford-iiit-pet',
            pretrained=True
        )
        cfg = model.pretrained_cfg
        # NOTE(review): newer timm exposes 'label_names'; older hub configs
        # used 'labels' — confirm against the installed timm version.
        labels = cfg.get('label_names') or cfg.get('labels')
        if not labels:
            # No label names in the config: fall back to class indices.
            labels = [str(i) for i in range(model.num_classes)]
        return ImageClassifier(model, labels)

    ImageClassificationApp(
        'Pet Image Classification App',
        load_classifier()
    ).render()