Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| #FIX | |
| import config as CFG | |
| from modules import TextEncoder, ProjectionHead, ImageEncoder | |
class PoemTextModel(nn.Module):
    """
    Two-tower model that embeds poems (beyts) and texts and compares them.

    Each side has its own encoder followed by a projection head; the two
    projection outputs are what get compared.
    ...
    Attributes:
    -----------
    poem_encoder : TextEncoder
        encoder producing poem features
    text_encoder : TextEncoder
        encoder producing text features
    poem_projection: ProjectionHead
        projection head applied to the poem encoder output
    text_projection: ProjectionHead
        projection head applied to the text encoder output
    temperature: float
        scaling factor used by calculate_loss
    Methods:
    --------
    forward(batch):
        returns (poem_embeddings, text_embeddings) for the batch
    similarity_scores(batch):
        returns the text-by-poem cosine similarity matrix of the batch
    predict(batch):
        returns, for each text, the index of the most similar poem in the batch
    calculate_loss(poem_embeddings, text_embeddings):
        returns the mean contrastive (soft-target cross entropy) loss
    save_current():
        saves the encoders (only if configured as trainable) and both projection heads
    """

    def __init__(
        self,
        poem_encoder_pretrained,
        text_encoder_pretrained,
        temperature=CFG.temperature,
        poem_embedding=CFG.poem_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Builds both encoders and both projection heads from the config.

        Parameters:
        -----------
        poem_encoder_pretrained: bool
            forwarded to the poem TextEncoder as `pretrained`
        text_encoder_pretrained: bool
            forwarded to the text TextEncoder as `pretrained`
        temperature: float, optional
            scaling factor stored for calculate_loss
        poem_embedding: int, optional
            `embedding_dim` given to the poem projection head
        text_embedding: int, optional
            `embedding_dim` given to the text projection head
        """
        super().__init__()

        # Submodules are created in a fixed order (encoders, then projections)
        # so that module registration order stays stable.
        self.poem_encoder = TextEncoder(
            CFG.poem_encoder_model,
            CFG.poem_encoder_pretrained_name,
            pretrained=poem_encoder_pretrained,
            trainable=CFG.poem_encoder_trainable,
        )
        self.text_encoder = TextEncoder(
            CFG.text_encoder_model,
            CFG.text_encoder_pretrained_name,
            pretrained=text_encoder_pretrained,
            trainable=CFG.text_encoder_trainable,
        )

        self.poem_projection = ProjectionHead(embedding_dim=poem_embedding)
        # A falsy load path means "start from freshly initialized weights".
        if CFG.poem_projection_load_path:
            poem_state = torch.load(CFG.poem_projection_load_path, map_location=CFG.device)
            self.poem_projection.load_state_dict(poem_state)

        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if CFG.text_projection_load_path:
            text_state = torch.load(CFG.text_projection_load_path, map_location=CFG.device)
            self.text_projection.load_state_dict(text_state)

        self.temperature = temperature

    def forward(self, batch):
        """
        Encodes and projects the poems and texts of a batch.

        Parameters:
        -----------
        batch: dict
            must provide the keys 'beyt' and 'text'; each value must itself
            provide 'input_ids' and 'attention_mask' (tokenizer output)
        Returns:
        --------
        tuple (poem_embeddings, text_embeddings), the outputs of the two projection heads
        """
        poem_inputs = batch["beyt"]
        text_inputs = batch["text"]

        # Encoder features for each side
        poem_features = self.poem_encoder(
            input_ids=poem_inputs["input_ids"],
            attention_mask=poem_inputs["attention_mask"],
        )
        text_features = self.text_encoder(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
        )

        # Projection of both feature sets
        poem_embeddings = self.poem_projection(poem_features)
        text_embeddings = self.text_projection(text_features)
        return poem_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        Computes cosine similarities between every text and every poem of a batch.

        Parameters:
        -----------
        batch: dict
            same structure as expected by forward
        Returns:
        --------
        similarity matrix whose rows correspond to texts and whose columns correspond to poems
        """
        poem_embeddings, text_embeddings = self.forward(batch)

        # L2-normalizing makes the dot product below a cosine similarity
        poem_unit = F.normalize(poem_embeddings, p=2, dim=-1)
        text_unit = F.normalize(text_embeddings, p=2, dim=-1)

        return torch.matmul(text_unit, poem_unit.T)

    def predict(self, batch):
        """
        Picks, for each text, the poem of the batch with the highest similarity.

        Parameters:
        -----------
        batch: dict
            same structure as expected by forward
        Returns:
        --------
        tensor holding one poem index per text
        """
        scores = self.similarity_scores(batch)
        # Rows are texts, so the argmax runs over the poem axis (dim=1)
        return scores.argmax(dim=1)

    def calculate_loss(self, poem_embeddings, text_embeddings):
        """
        Computes the symmetric contrastive loss between poem and text embeddings.

        Parameters:
        -----------
        poem_embeddings: of shape (batch_size, projection_dim)
            output embeddings of poem projection head
        text_embeddings: of shape (batch_size, projection_dim)
            output embeddings of text projection head
        Returns:
        --------
        mean over the batch of the averaged text-side and poem-side losses
        """
        # Logits: text-vs-poem dot products divided by the temperature
        logits = torch.matmul(text_embeddings, poem_embeddings.T) / self.temperature

        # Soft targets: average of the poem-poem and text-text dot products,
        # multiplied by the temperature and passed through a softmax
        poem_self_similarity = torch.matmul(poem_embeddings, poem_embeddings.T)
        text_self_similarity = torch.matmul(text_embeddings, text_embeddings.T)
        soft_targets = F.softmax(
            (poem_self_similarity + text_self_similarity) / 2 * self.temperature, dim=-1
        )

        # Cross entropy along both axes: rows (texts) and columns (poems)
        per_text_loss = cross_entropy(logits, soft_targets, reduction='none')
        per_poem_loss = cross_entropy(logits.T, soft_targets.T, reduction='none')

        per_pair_loss = (per_poem_loss + per_text_loss) / 2.0  # shape: (batch_size)
        return per_pair_loss.mean()

    def save_current(self):
        """
        Writes the current weights to the paths given in the config.

        Encoders are saved only when the config marks them as trainable;
        both projection heads are always saved.
        """
        if CFG.text_encoder_trainable:
            self.text_encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.poem_encoder_trainable:
            self.poem_encoder.model.save_pretrained(CFG.poem_encoder_save_path)

        text_projection_state = self.text_projection.state_dict()
        torch.save(text_projection_state, CFG.text_projection_save_path)
        poem_projection_state = self.poem_projection.state_dict()
        torch.save(poem_projection_state, CFG.poem_projection_save_path)
class CLIPModel(nn.Module):
    """
    Model predicting poem/text and image embeddings, and their similarities.
    ...
    Attributes:
    -----------
    encoder : TextEncoder
        encoder used for extracting poem/text features
    image_encoder : ImageEncoder
        encoder used for extracting image features
    text_projection: ProjectionHead
        projection head applied to the poem/text encoder output
    image_projection: ProjectionHead
        projection head applied to the image encoder output
    temperature: float
        used to scale the dot similarities in calculate_loss
    text_projection_trainable: bool
        whether the text projection is trained (and therefore saved)
    is_image_poem_pair: bool
        whether the text side of the model is configured for poems
    Methods:
    --------
    forward(batch):
        returns (image_embeddings, text_embeddings) of batch
    similarity_scores(batch):
        computes cosine similarities of a batch of image-text pairs
    predict(batch):
        predicts the most similar poem/text idx for each image
    calculate_loss(image_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems/texts and images.
    save_current():
        saves current model's encoders and projection heads (if trainable).
    """

    def __init__(
        self,
        image_encoder_pretrained,
        text_encoder_pretrained,
        text_projection_trainable,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
        is_image_poem_pair=True
    ):
        """
        Initializes model's submodules
        Parameters:
        -----------
        image_encoder_pretrained: bool
            whether or not to load a pretrained image encoder.
        text_encoder_pretrained: bool
            whether or not to load a pretrained text encoder.
        text_projection_trainable: bool
            whether or not to train text projection
            (since the text projection is frozen in our trainings unlike other projections of models)
        temperature: float, optional
            used to scale the dot similarities
        image_embedding: int, optional
            dim of image encoder's encoding output before projection
        text_embedding: int, optional
            dim of text encoder's encoding output before projection
        is_image_poem_pair: bool, optional
            if True, the text side uses the poem encoder settings and the poem projection load path;
            otherwise it uses the text encoder settings and the text projection load path.
        """
        super().__init__()

        self.image_encoder = ImageEncoder(
            pretrained=image_encoder_pretrained,
            trainable=CFG.image_encoder_trainable,
        )

        # The poem and text variants differ only in which config entries they
        # read, so select those entries once and build the submodules once.
        if is_image_poem_pair:
            encoder_model = CFG.poem_encoder_model
            encoder_pretrained_name = CFG.poem_encoder_pretrained_name
            encoder_trainable = CFG.poem_encoder_trainable
            projection_load_path = CFG.poem_projection_load_path
        else:
            encoder_model = CFG.text_encoder_model
            encoder_pretrained_name = CFG.text_encoder_pretrained_name
            encoder_trainable = CFG.text_encoder_trainable
            projection_load_path = CFG.text_projection_load_path

        self.encoder = TextEncoder(
            encoder_model,
            encoder_pretrained_name,
            pretrained=text_encoder_pretrained,
            trainable=encoder_trainable,
        )
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if projection_load_path:  # if provided, load projection weights from this path
            self.text_projection.load_state_dict(
                torch.load(projection_load_path, map_location=CFG.device)
            )

        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        if CFG.image_projection_load_path:
            self.image_projection.load_state_dict(
                torch.load(CFG.image_projection_load_path, map_location=CFG.device)
            )

        # Freeze the text projection when it is not supposed to be trained
        if not text_projection_trainable:
            for p in self.text_projection.parameters():
                p.requires_grad = False

        self.text_projection_trainable = text_projection_trainable
        self.is_image_poem_pair = is_image_poem_pair
        self.temperature = temperature

    def forward(self, batch):
        """
        returns image and text/poem embeddings of batch
        Parameters:
        -----------
        batch: dict
            must provide the keys 'image' and 'text'; the 'text' value must itself
            provide 'input_ids' and 'attention_mask' (tokenizer output)
        Returns:
        --------
        tuple (image_embeddings, text_embeddings), the outputs of the two projection heads
        """
        images, texts = batch["image"], batch["text"]

        # Getting Image and Text Features
        image_features = self.image_encoder(images)
        text_features = self.encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )

        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        return image_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        computes cosine similarities of a batch of image-text/poem pairs
        Parameters:
        -----------
        batch: dict
            same structure as expected by forward
        Returns:
        --------
        similarity matrix whose rows correspond to images and whose columns correspond to poems/texts
        """
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings, text_embeddings = self.forward(batch)

        # Normalizing embeddings so the dot product below is a cosine similarity
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)

        dot_similarity = image_embeddings_n @ text_embeddings_n.T
        return dot_similarity  # first dim is images, second dim is poems/texts for each image

    def predict(self, batch):
        """
        predicts the most similar poem/text (idx) for each image
        Parameters:
        -----------
        batch: dict
            same structure as expected by forward
        Returns:
        --------
        index of poem/text predicted for each image (of shape (batch_size))
        """
        dot_similarity = self.similarity_scores(batch)
        # Rows are images, so the argmax runs over the poem/text axis (dim=1)
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, image_embeddings, text_embeddings):
        """
        computes contrastive (cross entropy) loss for both poems/texts and images.
        Parameters:
        -----------
        image_embeddings: of shape (batch_size, projection_dim)
            output embeddings of image projection head
        text_embeddings: of shape (batch_size, projection_dim)
            output embeddings of text projection head
        Returns:
        --------
        average of the loss computed from inputs
        """
        # dot similarity of the embeddings divided by temperature (logits);
        # rows are texts, columns are images
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # computing targets for the cross entropy loss to compare with logits:
        # the image-image and text-text dot similarities are averaged,
        # multiplied by the temperature, and turned into a distribution via a softmax
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )

        # taking cross entropy loss in both dimensions: once for texts and once for images
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # average of losses. shape: (batch_size)
        return loss.mean()

    def save_current(self):
        """
        saves current model's encoders and projection heads (if trainable).

        The text/poem encoder and the image encoder are saved only when the
        config marks them as trainable; the text projection is saved only when
        text_projection_trainable is set; the image projection is always saved.
        """
        if self.is_image_poem_pair:
            if CFG.poem_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        else:
            if CFG.text_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.text_encoder_save_path)

        if CFG.image_encoder_trainable:
            torch.save(self.image_encoder.model.state_dict(), CFG.image_encoder_weights_save_path)

        if self.text_projection_trainable:
            # NOTE(review): when is_image_poem_pair is True the projection was
            # loaded from CFG.poem_projection_load_path, yet it is written to the
            # *text* projection save path here - confirm this is intended.
            torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.image_projection.state_dict(), CFG.image_projection_save_path)
def cross_entropy(preds, targets, reduction='none'):
    """
    Computes the cross entropy between logits and (soft) target distributions
    over their last dimension.

    Parameters:
    -----------
    preds: torch.Tensor
        logits; a log-softmax is applied over the last dimension
    targets: torch.Tensor
        target weights, broadcastable against preds (e.g. a probability
        distribution per row)
    reduction: str, optional
        if set to "none", return the loss summed over the last dim only
        (one value per row).
        if set to "mean", return the mean of those per-row losses.
    Returns:
    --------
    loss tensor (reduction="none") or its scalar mean (reduction="mean")
    Raises:
    -------
    ValueError
        if reduction is neither "none" nor "mean"
    """
    # Reject unknown modes up front instead of silently returning None
    if reduction not in ("none", "mean"):
        raise ValueError(
            f"unsupported reduction {reduction!r}; expected 'none' or 'mean'"
        )

    # Reduce over the same (last) axis the log-softmax is taken over, so the
    # function also works for inputs that are not exactly 2-D.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)

    if reduction == "mean":
        return loss.mean()
    return loss