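"""Train a small Korean language model ("ReLM") with JAX/Flax.

Pipeline: download a Korean corpus and SentencePiece model from Hugging Face,
tokenize to fixed-length sequences, train with a label-smoothed cross-entropy
objective data-parallel across devices via jax.pmap, checkpoint with Flax,
then sample with top-p decoding. The blocks are attention-free: a gated "LoU"
mixing layer driven by a learned exponential moving average over the sequence,
followed by a small residual MLP.
"""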
import os
from functools import partial
from typing import Any

import numpy as np
import requests
import sentencepiece as spm
import tqdm

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
|
|
def download_file(url, save_path):
    """Stream a remote file to disk in 16 KiB chunks."""
    r = requests.get(url, stream=True)
    r.raise_for_status()
    with open(save_path, "wb") as f:
        for chunk in r.iter_content(8192 * 2):
            f.write(chunk)
    print(f"Saved: {save_path}")
# ----- Config -----
SEQ_LEN = 512
GLOBAL_BATCH = 256        # must be divisible by the device count
LIMIT = 200_000           # max corpus lines to load
VOCAB_MODEL = "ko_unigram.model"
CORPUS_PATH = "corpus.txt"
SEED = 42
LEARNING_RATE = 1e-4
EPOCHS = 1
|
|
if not os.path.exists(CORPUS_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/Prototype/resolve/main/corpus_ko.txt?download=true",
        CORPUS_PATH,
    )
|
|
if not os.path.exists(VOCAB_MODEL):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        VOCAB_MODEL,
    )
|
|
# bfloat16 on TPU, float32 elsewhere.
DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
NUM_DEVICES = jax.device_count()
PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES
print("devices:", jax.devices(), "dtype:", DTYPE)
|
|
# ----- Tokenizer -----
sp = spm.SentencePieceProcessor()
sp.load(VOCAB_MODEL)
# Fall back to id 0 if the vocab has no dedicated <pad> piece.
pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
start_id = sp.piece_to_id("<start>")
end_id = sp.piece_to_id("<end>")
vocab_size = sp.get_piece_size()
print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)
|
|
# ----- Data -----
def line_to_ids(line, max_len=SEQ_LEN):
    """Encode a line, truncate to max_len - 1 tokens, append <end>, pad to max_len."""
    ids = sp.encode(line.strip(), out_type=int)
    if len(ids) > max_len - 1:
        ids = ids[:max_len - 1]
    ids += [end_id] + [pad_id] * (max_len - len(ids) - 1)
    return np.array(ids, dtype=np.int32)
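# Example: with max_len = 8 and a 4-token line, line_to_ids returns
#   [t1, t2, t3, t4, end_id, pad_id, pad_id, pad_id]  (always exactly max_len long).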
|
|
def build_dataset(corpus_path, limit=LIMIT):
    arr = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            line = line.strip()
            if not line:
                continue
            arr.append(line_to_ids(line))
    data = np.stack(arr, axis=0)
    print("Loaded dataset:", data.shape)
    return data
|
|
data_np = build_dataset(CORPUS_PATH, LIMIT)
inputs = data_np
# Next-token targets: shift left by one and pad the final position.
targets = np.concatenate([data_np[:, 1:], np.full((data_np.shape[0], 1), pad_id, np.int32)], axis=1)
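# Example: inputs  [t1, t2, t3, <end>, <pad>]
#          targets [t2, t3, <end>, <pad>, <pad>]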
|
|
def create_batch_iter(inputs, targets, batch_size, rng):
    """Shuffle once, then yield full batches (the remainder is dropped)."""
    idx = np.arange(inputs.shape[0])
    rng.shuffle(idx)
    for i in range(0, len(idx) - batch_size + 1, batch_size):
        batch_idx = idx[i:i + batch_size]
        yield inputs[batch_idx], targets[batch_idx]
|
|
def shard(xs):
    # (global_batch, seq) -> (num_devices, per_device_batch, seq) for pmap.
    return xs.reshape(NUM_DEVICES, -1, xs.shape[1])
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: value * silu(gate), computed in float32."""
    d_model: int

    @nn.compact
    def __call__(self, x):
        x_f32 = x.astype(jnp.float32)
        proj = nn.Dense(self.d_model * 2, dtype=jnp.float32)(x_f32)
        x_val, x_gate = jnp.split(proj, 2, axis=-1)
        out = x_val * nn.silu(x_gate)
        out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
        return out.astype(x.dtype)
|
|
class LoU(nn.Module):
    """Attention-free mixing layer: sigmoid-like gates on Q/K features are
    blended along the sequence by a learned exponential moving average."""
    d_model: int
    clip_value: float = 5.0  # currently unused
    eps: float = 1e-6        # currently unused

    @nn.compact
    def __call__(self, x):
        x_f32 = x.astype(jnp.float32)
        residual = x_f32
        x_norm = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(x_f32)
        Q = nn.Dense(self.d_model, dtype=jnp.float32)
        K = nn.Dense(self.d_model, dtype=jnp.float32)
        V = nn.Dense(self.d_model, dtype=jnp.float32)
        q, k, v = Q(x_norm), K(x_norm), V(x_norm)
        g_q = (jnp.tanh(q) + 1) / 2  # gates in (0, 1)
        g_k = (jnp.tanh(k) + 1) / 2
        score = g_q * g_k
        alpha_dynamic = nn.Dense(1, dtype=jnp.float32)(x_norm)  # per-position mixing weight

        # Causal EMA over time: ema_t = a_t * s_t + (1 - a_t) * ema_{t-1}.
        score_t = jnp.transpose(score, (1, 0, 2))          # (seq, batch, d_model)
        alpha_t = jnp.transpose(alpha_dynamic, (1, 0, 2))  # (seq, batch, 1)

        def step(prev, cur):
            s, a = cur
            new = a * s + (1 - a) * prev
            return new, new

        init = score_t[0]
        _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        ema_full = jnp.concatenate([init[None, ...], ema_seq], 0)
        ema = jnp.transpose(ema_full, (1, 0, 2))
        out = v * ema + residual
        out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
        return SwiGLU(self.d_model)(out).astype(x.dtype)
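# Unrolling the scan: with s_t = g_q(t) * g_k(t) and mixing weight a_t,
#     ema_t = a_t * s_t + (1 - a_t) * ema_{t-1}
#           = sum_{j <= t} a_j * s_j * prod_{i=j+1..t} (1 - a_i)   (taking a_1 = 1)
# so each position carries an exponentially decaying summary of earlier gate
# scores; the layer stays causal without any attention matrix.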
|
|
|
|
class Lo(nn.Module):
    """Bottleneck MLP (d_model -> 64 -> d_model) with a residual connection."""
    d_model: int
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(64, dtype=self.dtype)(x)
        h = nn.silu(h)
        h = nn.Dense(self.d_model, dtype=self.dtype)(h)
        return nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(h) + x
|
|
class Block(nn.Module):
    d_model: int
    dtype: Any = DTYPE

    @nn.compact
    def __call__(self, x):
        # NOTE: LoU has no dtype field; passing self.dtype positionally would
        # bind it to clip_value, so LoU is constructed with d_model only.
        x = LoU(self.d_model)(x)
        x = Lo(self.d_model, self.dtype)(x)
        return x
|
|
class ReLM(nn.Module):
    vocab_size: int
    max_seq_len: int
    d_model: int
    n_layers: int
    dtype: Any = DTYPE

    def setup(self):
        self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
        self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
        self.blocks = [Block(self.d_model, self.dtype) for _ in range(self.n_layers)]
        self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

    def __call__(self, x, deterministic=True):
        b, seq = x.shape
        pos = jnp.arange(seq)[None, :]
        x = self.token_embed(x) + self.pos_embed(pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        # Weight-tied output head: project onto the token embedding matrix.
        logits = jnp.einsum("bld,vd->blv", x, self.token_embed.embedding)
        return logits
|
|
def smoothed_ce(logits, targets, pad_id, eps=0.1):
    """Label-smoothed cross-entropy averaged over non-pad positions."""
    logits = logits.astype(jnp.float32)
    targets = targets.astype(jnp.int32)
    vocab = logits.shape[-1]
    mask = (targets != pad_id).astype(jnp.float32)
    one_hot = jax.nn.one_hot(targets, vocab)
    smooth = (1 - eps) * one_hot + eps / vocab
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.sum(smooth * log_probs, axis=-1) * mask
    return jnp.sum(loss) / (jnp.sum(mask) + 1e-8)
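# Note: with eps > 0 the optimum of this loss is not zero but the entropy of
# the smoothed target distribution, so the loss plateaus above zero even on
# perfectly predicted tokens.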
|
|
def masked_ppl(logits, targets, pad_id, eps=0.1):
    """Perplexity under the same smoothed objective: exp(smoothed_ce)."""
    return jnp.exp(smoothed_ce(logits, targets, pad_id, eps))
|
|
# ----- Train state -----
class TrainState(train_state.TrainState):
    pass


def create_train_state(rng, model, lr):
    params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(lr, b1=0.9, b2=0.95, eps=1e-8),
    )
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
|
|
# ----- Training step -----
@partial(jax.pmap, axis_name="batch")
def train_step(state, bx, by, rngs):  # rngs currently unused; kept for dropout hooks
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, bx, deterministic=False)
        return smoothed_ce(logits, by, pad_id), logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    grads = jax.lax.pmean(grads, "batch")  # all-reduce gradients across devices
    state = state.apply_gradients(grads=grads)
    metrics = {"loss": loss, "ppl": masked_ppl(logits, by, pad_id)}
    metrics = jax.lax.pmean(metrics, "batch")
    return state, metrics
|
|
# ----- Sampling -----
def top_p_sample(rng, logits, p=0.9, temperature=1.0):
    probs = jax.nn.softmax(logits / temperature)
    sorted_probs, sorted_idx = jax.lax.top_k(probs, logits.shape[-1])
    cum_probs = jnp.cumsum(sorted_probs)
    # Keep tokens whose cumulative mass *before* them is < p; unlike
    # `cum_probs <= p`, this never masks out every token when the top
    # probability alone exceeds p.
    mask = (cum_probs - sorted_probs) < p
    top_probs = jnp.where(mask, sorted_probs, 0.0)
    top_probs = top_probs / jnp.sum(top_probs)
    return int(sorted_idx[jax.random.categorical(rng, jnp.log(top_probs))])
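# Example: for a three-token distribution (0.6, 0.3, 0.1) and p = 0.9, the
# cumulative mass before each token is (0.0, 0.6, 0.9), so the first two
# tokens are kept and the third is dropped before renormalizing to (2/3, 1/3, 0).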
|
|
def generate_text(state, prompt, max_gen=256, p=0.9, temperature=0.8, min_len=20):
    # Pull a single copy of the parameters from the replicated train state.
    params = jax.tree_util.tree_map(lambda x: np.array(x[0]), state.params)
    tokens = sp.encode("<start> " + prompt, out_type=int)
    generated = tokens.copy()
    rng = random.PRNGKey(SEED)
    for step in range(max_gen):
        cur = generated[-SEQ_LEN:]
        pos = len(cur) - 1  # last real token within the (possibly padded) window
        if len(cur) < SEQ_LEN:
            cur = cur + [pad_id] * (SEQ_LEN - len(cur))
        x = jnp.array([cur], dtype=jnp.int32)
        logits = model.apply({"params": params}, x, deterministic=True)[0, pos]
        # Discourage premature <end> and pad tokens.
        logits = logits.at[end_id].add(-5.0).at[pad_id].add(-10.0)
        rng, sample_rng = random.split(rng)  # fresh key each step
        next_id = top_p_sample(sample_rng, logits, p, temperature)
        generated.append(next_id)
        if next_id == end_id and len(generated) >= min_len:
            break
    return sp.decode(generated)
|
|
# ----- Training -----
rng = random.PRNGKey(SEED)
rng, init_rng = random.split(rng)
model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
state = create_train_state(init_rng, model, LEARNING_RATE)
state = jax.device_put_replicated(state, jax.local_devices())  # one replica per device
|
|
global_step = 0
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    np_rng = np.random.default_rng(SEED + epoch)
    batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
    pbar = tqdm.tqdm(batch_iter, total=max(1, inputs.shape[0] // GLOBAL_BATCH))
    for bx, by in pbar:
        bx_sh, by_sh = shard(bx), shard(by)
        state, metrics = train_step(state, bx_sh, by_sh, jax.random.split(rng, NUM_DEVICES))
        m = jax.tree_util.tree_map(lambda x: x[0], metrics)  # metrics are replicated; take device 0
        pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
        global_step += 1
|
|
# ----- Checkpointing -----
save_dir = os.path.abspath("./checkpoints")  # recent Flax versions require an absolute path
os.makedirs(save_dir, exist_ok=True)

# Unreplicate (take device 0) and move to host memory before saving.
state_to_save = jax.tree_util.tree_map(lambda x: np.array(x[0]), state)
checkpoints.save_checkpoint(save_dir, state_to_save, step=global_step, keep=3)
|
|
| print("Saved checkpoint to",save_dir) |
|
|
# ----- Generation demo -----
| print("\n\n===== ์์ฑ ๊ฒฐ๊ณผ =====") |
| print(generate_text(state,"์ง๋ 2๋
๋์ ์ถ์ฐ์ฐ์ด ๊ตญ๊ฐ๊ฐ ํ์ํ ์ฐ๊ตฌ๋ฅผ",p=0.9)) |
|