File size: 1,168 Bytes
6229e10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch

def standard_loss(self, model, inputs, return_outputs=False):
	"""Run ``model`` on ``inputs`` and return the mean of its first output.

	Args:
		self: Object exposing ``args.past_index``. When that index is
			non-negative, ``self._past`` is set to the model output at
			that position.
		model: Callable invoked as ``model(**inputs)``; its result is
			indexed by position.
		inputs: Mapping of keyword arguments forwarded to ``model``.
		return_outputs: When true, return ``(loss, outputs)`` instead of
			the loss alone.

	Returns:
		``outputs[0].mean()``, optionally paired with the raw model outputs.
	"""
	outputs = model(**inputs)

	past_index = self.args.past_index
	if past_index >= 0:
		# Stash the selected output on the caller-provided object.
		self._past = outputs[past_index]

	loss = outputs[0].mean()
	print("loss : ", loss)

	if return_outputs:
		return loss, outputs
	return loss

def hinge_cost(m, a, b):
	"""Return the mean squared hinge penalty on Euclidean distances.

	The distance is the square root of ``(a - b) ** 2`` summed over
	dimension 1. Each distance ``d`` contributes ``max(0, m - d) ** 2`` and
	the result is the mean of those contributions.

	Args:
		m: Margin; distances of at least ``m`` contribute zero.
		a: Tensor with at least two dimensions (it is reduced over dim 1).
		b: Tensor broadcastable against ``a``.

	Returns:
		Zero-dimensional tensor holding the mean penalty.
	"""
	sq_dist = torch.sum((a - b) ** 2, dim=1)

	# The derivative of sqrt is infinite at 0, so a row with a == b would
	# back-propagate inf * 0 = NaN and poison every gradient. Route zero rows
	# around the sqrt: they get distance 0 and a zero (sub)gradient, while all
	# other rows are computed exactly as before.
	is_zero = sq_dist == 0
	safe_sq = torch.where(is_zero, torch.ones_like(sq_dist), sq_dist)
	root = torch.sqrt(safe_sq)
	dist = torch.where(is_zero, torch.zeros_like(root), root)

	return torch.mean(torch.clamp(m - dist, min=0) ** 2)

def sim_metric_loss(self, model, inputs, return_outputs=False):
	"""Compute a three-term similarity-metric loss from one forward pass.

	The model output is split along its first dimension into four
	consecutive groups ``x_p, x_n, y_p, y_n`` (presumably positive and
	negative pairs; confirm against the data collator). The loss is the sum of:

	* the mean squared distance between ``x_p`` and ``y_p``;
	* ``negative_importance * hinge_cost(negative_threshold, x_n, y_n)``;
	* ``entropy_importance`` times the mean of the summed squared norms of
	  ``x_p`` and ``y_p``.

	Args:
		self: Unused here; present so the function can be bound as a method.
		model: Callable invoked as ``model(**inputs)`` whose result can be
			sliced along dimension 0 and reduced along dimension 1. It (or
			its ``.module`` under ``DataParallel``) must expose
			``negative_importance``, ``negative_threshold`` and
			``entropy_importance``.
		inputs: Mapping forwarded to ``model``; ``inputs["input_ids"]``
			determines the group size (a quarter of its length).
		return_outputs: When true, return ``(loss, None)``.

	Returns:
		The loss tensor, or ``(loss, None)`` when ``return_outputs`` is true.
	"""
	# single pass version: one forward call, then slice into four groups
	quarter = len(inputs["input_ids"]) // 4
	embeddings = model(**inputs)
	x_p = embeddings[:quarter]
	x_n = embeddings[quarter:2 * quarter]
	y_p = embeddings[2 * quarter:3 * quarter]
	y_n = embeddings[3 * quarter:]

	# Under DataParallel the loss weights live on the wrapped module.
	core = model.module if isinstance(model, torch.nn.DataParallel) else model

	pos_sq_dist = torch.sum((x_p - y_p) ** 2, dim=1)
	pos_sq_norms = torch.sum(x_p ** 2, dim=1) + torch.sum(y_p ** 2, dim=1)

	cost_p = torch.mean(pos_sq_dist)
	cost_n = core.negative_importance * hinge_cost(core.negative_threshold, x_n, y_n)
	cost_e = core.entropy_importance * torch.mean(pos_sq_norms)
	loss = cost_p + cost_n + cost_e

	print(loss)
	if return_outputs:
		return loss, None
	return loss