andrecornman committed · verified · Commit 08754cb · Parent(s): 57d7cf6

fix init_weights

Files changed (1): modeling_glm2.py (+18, -2)
modeling_glm2.py CHANGED
@@ -352,7 +352,7 @@ class gLM2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = False
 
     # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
-    def _init_weights(module, initializer_range=0.02):
+    def _init_weights(self, module, initializer_range=0.02):
         if isinstance(module, nn.Linear):
             nn.init.normal_(module.weight, std=initializer_range)
             if module.bias is not None:
@@ -361,6 +361,22 @@ class gLM2PreTrainedModel(PreTrainedModel):
             nn.init.normal_(module.weight, std=initializer_range)
             if module.padding_idx is not None:
                 nn.init.zeros_(module.weight[module.padding_idx])
+        elif isinstance(module, RotaryEmbedding):
+            # Re-calculate the frequencies using the module's stored attributes
+            inv_freq = 1.0 / (
+                module.base
+                ** (
+                    torch.arange(0, module.dim, 2, device=module.inv_freq.device, dtype=torch.float32)
+                    / module.dim
+                )
+            )
+            # Force the buffer to update
+            with torch.no_grad():
+                module.inv_freq.copy_(inv_freq)
+        elif isinstance(module, RMSNorm):
+            if hasattr(module, "variance_epsilon"):
+                with torch.no_grad():
+                    module.variance_epsilon.fill_(self.config.norm_eps)
 
 
 class gLM2Model(gLM2PreTrainedModel):
@@ -412,7 +428,7 @@ class gLM2ForMaskedLM(gLM2PreTrainedModel):
 
         self.glm2 = gLM2Model(config)
         self.lm_head = gLM2LMHead(config)
-        self.init_weights()
+        self.post_init()
 
     def forward(
         self,
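Why the added `self` matters: `_init_weights` is invoked as a bound method by the `PreTrainedModel` initialization machinery (e.g. via `self.apply(...)`). With the old signature `def _init_weights(module, initializer_range=0.02)`, Python binds the model instance itself to `module` and the submodule being visited lands in `initializer_range`, so the `isinstance` branches never match and no weights actually get initialized. A minimal standalone sketch of that binding behavior, using toy class names rather than the repo's:

class Toy:
    def _init_weights(module, initializer_range=0.02):
        # Called as instance._init_weights(x): `module` is bound to the Toy
        # instance, and x is silently swallowed by `initializer_range`.
        return type(module).__name__, type(initializer_range).__name__


class Submodule:
    pass


t = Toy()
print(t._init_weights(Submodule()))  # ('Toy', 'Submodule') -- arguments shifted by one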
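The new `RotaryEmbedding` branch recomputes the standard rotary inverse-frequency buffer, inv_freq[i] = 1 / base**(2i / dim), presumably so that a re-initialization pass repopulates the buffer rather than leaving it in whatever state generic init produced. A self-contained sketch of the same computation; `base` and `dim` are illustrative values here, not gLM2's actual configuration:

import torch

base, dim = 10000.0, 64  # illustrative RoPE hyperparameters
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
print(inv_freq.shape)  # torch.Size([32]): one frequency per rotated channel pair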
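The `self.init_weights()` → `self.post_init()` change follows the Transformers convention: `post_init()` is the documented call at the end of a model's `__init__`, and it runs `init_weights()` (which walks the submodules and applies the fixed `_init_weights`) along with other backward-compatibility setup. A hedged usage sketch; the checkpoint id below is an assumption, not something stated in this commit:

from transformers import AutoModelForMaskedLM

# post_init() runs inside __init__, so the fixed _init_weights is applied to
# any module the checkpoint does not cover.
model = AutoModelForMaskedLM.from_pretrained(
    "tattabio/gLM2_650M",    # assumed repo id for a gLM2 checkpoint
    trust_remote_code=True,  # required for the custom modeling_glm2.py code
)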