# DUDE1.0Beta / finetune.py
# DSDUDEd's picture
# Rename fine-tune.py to finetune.py
# 4900adc verified
# finetune.py
import os
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
)
import torch
# Runtime configuration, read once from the environment at import time.
# Each setting falls back to the default shown here when its variable is unset.
MODEL_ID = os.getenv("MODEL_ID", "openai-community/gpt2")  # checkpoint to fine-tune
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "gpt2-deepwriting-finetuned")  # save location
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "4"))  # per-device training batch size
EPOCHS = int(os.getenv("EPOCHS", "3"))  # number of training epochs
HF_TOKEN = os.getenv("HF_TOKEN")  # passed to load_dataset as token; None when unset
# Load the tokenizer and model for MODEL_ID at import time; tokenize_function
# and main use both as module-level globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# When the tokenizer defines no pad token, register its end-of-sequence
# token in that role.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

# Size the model's token embeddings to the tokenizer's current length.
model.resize_token_embeddings(len(tokenizer))
def _is_text_column(values):
    """Return True when *values* is a string or a list/tuple holding only strings."""
    if isinstance(values, str):
        return True
    return isinstance(values, (list, tuple)) and all(isinstance(v, str) for v in values)


def tokenize_function(examples):
    """Tokenize the text column of one batch handed over by ``Dataset.map``.

    The column is chosen per call:

    1. The first column (in the batch's own order) whose lower-cased name is
       ``text``, ``content``, ``sentence`` or ``story``.
    2. Otherwise, the first column whose values are all strings, so that a
       leading non-text column (for example a numeric id) is skipped instead
       of being passed to the tokenizer.
    3. Otherwise, the very first column, as a last resort.

    Args:
        examples: Mapping from column name to that column's values for the
            batch.

    Returns:
        The output of the module-level ``tokenizer`` for the chosen column,
        truncated to at most 512 tokens and not padded.

    Raises:
        IndexError: If ``examples`` has no columns at all.
    """
    known_names = ("text", "content", "sentence", "story")
    columns = list(examples.keys())

    text_col = next((name for name in columns if name.lower() in known_names), None)
    if text_col is None:
        # No recognisable name: prefer a column that actually contains text.
        # columns[0] is the last-resort default (IndexError on an empty batch).
        text_col = next(
            (name for name in columns if _is_text_column(examples[name])),
            columns[0],
        )

    # padding=False: sequences keep their own length at this stage.
    return tokenizer(examples[text_col], truncation=True, padding=False, max_length=512)
def main():
    """Fine-tune the module-level ``model`` on DeepWriting-20K and save it.

    Loads ``m-a-p/DeepWriting-20K``, tokenizes one split with
    ``tokenize_function``, trains with ``Trainer`` using the module-level
    settings, then writes the model and tokenizer to ``OUTPUT_DIR`` and
    prints a confirmation line.
    """
    # `token=` is used here instead of the deprecated `use_auth_token`.
    dataset = load_dataset("m-a-p/DeepWriting-20K", token=HF_TOKEN)

    # Train on "train" when the dataset has it, otherwise on its first split.
    if "train" in dataset:
        split_name = "train"
    else:
        split_name = list(dataset.keys())[0]
    raw_split = dataset[split_name]

    # Drop every original column so only the tokenizer's outputs remain.
    tokenized_split = raw_split.map(
        tokenize_function,
        batched=True,
        remove_columns=raw_split.column_names,
    )

    # mlm=False: causal language modeling rather than masked-token prediction.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir=OUTPUT_DIR,
            per_device_train_batch_size=BATCH_SIZE,
            num_train_epochs=EPOCHS,
            save_strategy="epoch",
            logging_steps=100,
            fp16=torch.cuda.is_available(),  # fp16 only when CUDA is available
            push_to_hub=False,
        ),
        train_dataset=tokenized_split,
        data_collator=collator,
    )
    trainer.train()

    # Save the model and tokenizer to the same directory.
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("Saved fine-tuned model to", OUTPUT_DIR)
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()