| from transformers import PreTrainedModel | |
| import torch | |
| import os | |
class InceptionV3ModelForImageClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a TorchScript image classifier.

    The wrapped network lives in ``self.model``. Most ``nn.Module``
    introspection and device-placement methods are forwarded to it.

    Only the ``"google-safesearch-mini"`` checkpoint is supported. It is
    downloaded once into the current working directory and loaded with
    ``torch.jit.load``.

    The config must provide:

    * ``model_name``.
    * ``mean`` and ``std``, the normalisation constants used by :meth:`transform`.
    * ``classes``, a mapping from the stringified class index to a label, used
      by :meth:`predict`.
    """

    def __init__(self, config):
        """Load (downloading on first use) the TorchScript checkpoint.

        Raises:
            ValueError: if ``config.model_name`` is not ``"google-safesearch-mini"``.
        """
        super().__init__(config)
        if self.config.model_name != "google-safesearch-mini":
            raise ValueError(f"Model {self.config.model_name} not found.")
        # Relative path: the checkpoint is cached in the current working directory.
        model_path = "google-safesearch-mini.bin"
        if not os.path.exists(model_path):
            import urllib.request
            # Download under a temporary name and rename only once the download
            # has completed. Otherwise an interrupted download would leave a
            # truncated file at ``model_path`` that every later run would skip
            # re-downloading and then fail to load.
            partial_path = model_path + ".part"
            urllib.request.urlretrieve(
                "https://huggingface.co/FredZhang7/google-safesearch-mini/resolve/main/pytorch_model.bin",
                partial_path,
            )
            os.replace(partial_path, model_path)
        self.model = torch.jit.load(model_path)

    def forward(self, input_ids):
        """Run the wrapped network on a batch of images.

        Args:
            input_ids: the image batch. It is named ``input_ids`` only for
                Hugging Face API compatibility. Presumably a float tensor of
                shape (batch, 3, H, W), as produced by :meth:`transform`;
                TODO confirm.

        Returns:
            A 2-tuple ``(outputs, aux)``. ``aux`` is ``None`` when
            ``config.model_name == "inception_v3"``. Otherwise it is the same
            ``outputs`` object, preserving the tuple shape that :meth:`predict`
            unpacks.
        """
        # Evaluate the network exactly once. The second slot reuses the result
        # instead of running a second forward pass.
        outputs = self.model(input_ids)
        aux = None if self.config.model_name == "inception_v3" else outputs
        return outputs, aux

    def freeze(self):
        """Disable gradient tracking for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Re-enable gradient tracking for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = True

    def train(self, mode=True):
        """Set training mode on this wrapper and on the wrapped model.

        Returns:
            ``self``, following the ``nn.Module.train`` convention.
        """
        super().train(mode)
        # ``nn.Module.train`` recurses through ``self.children()``. That method
        # is overridden below to yield the wrapped model's children, so
        # ``self.model``'s own flag has to be set explicitly.
        self.model.train(mode)
        return self

    def eval(self):
        """Switch to inference mode. Returns ``self``."""
        return self.train(False)

    def to(self, *args, **kwargs):
        """Move or cast the wrapped model.

        Accepts anything ``nn.Module.to`` accepts: a device, a dtype, or
        ``device=``/``dtype=`` keywords. Returns ``self``.
        """
        self.model.to(*args, **kwargs)
        return self

    def cuda(self, device=None):
        """Move the wrapped model to a CUDA device.

        ``device`` defaults to the current CUDA device. Returns ``self``.
        """
        self.model.cuda(device)
        return self

    def cpu(self):
        """Move the wrapped model to the CPU. Returns ``self``."""
        return self.to("cpu")

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped model's state dict (keys carry no ``model.`` prefix)."""
        # Keyword arguments: positional use of these parameters is deprecated
        # in ``nn.Module.state_dict``.
        return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped model."""
        return self.model.load_state_dict(state_dict, strict=strict)

    def parameters(self, recurse=True):
        """Iterate over the wrapped model's parameters."""
        return self.model.parameters(recurse=recurse)

    def named_parameters(self, prefix='', recurse=True):
        """Iterate over ``(name, parameter)`` pairs of the wrapped model."""
        return self.model.named_parameters(prefix=prefix, recurse=recurse)

    def children(self):
        """Iterate over the wrapped model's immediate children."""
        return self.model.children()

    def named_children(self):
        """Iterate over ``(name, child)`` pairs of the wrapped model."""
        return self.model.named_children()

    def modules(self):
        """Iterate over all modules of the wrapped model (including itself)."""
        return self.model.modules()

    def named_modules(self, memo=None, prefix=''):
        """Iterate over ``(name, module)`` pairs of the wrapped model."""
        return self.model.named_modules(memo=memo, prefix=prefix)

    def zero_grad(self, set_to_none=False):
        """Reset the wrapped model's gradients.

        Gradients are zeroed by default, or set to ``None`` when
        ``set_to_none`` is true.
        """
        return self.model.zero_grad(set_to_none=set_to_none)

    def share_memory(self):
        """Move the wrapped model's storage to shared memory. Returns ``self``."""
        self.model.share_memory()
        return self

    def transform(self, image):
        """Preprocess a PIL image into a normalised tensor.

        Steps:

        1. Resize the shorter side to 299 pixels, keeping the aspect ratio.
        2. Convert the image to a tensor.
        3. Normalise with ``config.mean`` and ``config.std``.

        No batch dimension is added.
        """
        from torchvision import transforms
        transform = transforms.Compose([
            transforms.Resize(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.config.mean, std=self.config.std)
        ])
        return transform(image)

    def open_image(self, path, timeout=30):
        """Load an image as RGB from an ``http(s)://`` URL or a local path.

        Args:
            path: a URL string, or anything ``PIL.Image.open`` accepts
                (path string, ``os.PathLike``, or binary file object).
            timeout: seconds passed to ``requests.get`` for URL downloads.

        Raises:
            requests.HTTPError: if the URL responds with an error status.
        """
        from PIL import Image
        if isinstance(path, str) and path.startswith(('http://', 'https://')):
            import requests
            from io import BytesIO
            # A timeout keeps an unresponsive server from hanging the call forever.
            response = requests.get(path, timeout=timeout)
            # Surface HTTP failures directly, instead of a confusing decode error
            # when PIL tries to parse an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        # Context manager so the underlying file handle is always released.
        with Image.open(path) as image:
            return image.convert('RGB')

    def predict(self, path, device="cuda", print_tensor=True):
        """Classify a single image and return its label.

        Args:
            path: URL or local path of the image (see :meth:`open_image`).
            device: any device spec ``torch.Tensor.to`` accepts, such as
                ``"cuda"``, ``"cuda:1"``, ``"cpu"`` or a ``torch.device``.
                The model is moved there as a side effect.
            print_tensor: if true, print the raw network output.

        Returns:
            ``config.classes[str(index)]`` for the highest-scoring class index.

        NOTE(review): this does not call :meth:`eval`. The prediction uses
        whatever train/eval mode the model is currently in.
        """
        image = self.transform(self.open_image(path)).unsqueeze(0)  # add batch dim of 1
        image = image.to(device)
        self.to(device)
        with torch.no_grad():
            out, _aux = self(image)
        if print_tensor:
            print(out)
        # Assumes the TorchScript model returns an object with a ``logits``
        # attribute (e.g. torchvision's ``InceptionOutputs``); confirm.
        _, predicted = torch.max(out.logits, 1)
        return self.config.classes[str(predicted.item())]