| | """dcode - Text to Polargraph Gcode via Stable Diffusion""" |
| |
|
| | import re |
| | import os |
| | import json |
| | import gradio as gr |
| | import torch |
| | import torch.nn as nn |
| | from pathlib import Path |
| | import spaces |
| |
|
| | |
# Drawable area of the polargraph, centred on the origin. The extents span
# 841 x 1189 units, which the UI footer reports as millimetres.
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}

# Process-wide cache populated by the first get_model() call; None until then.
_model = None
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GcodeDecoderConfigV3:
    """Hyper-parameter container for the v3 gcode decoder.

    Every constructor argument is stored unchanged as an attribute of the
    same name; no validation is performed here.

    Attributes:
        latent_channels: Channel count of the latent fed to the projector.
        latent_size: Presumably the latent's spatial side length (as in the
            v2 config); stored as given.
        hidden_size: Transformer model width.
        num_layers: Number of decoder layers.
        num_heads: Attention heads per layer.
        vocab_size: Size of the token vocabulary.
        max_seq_len: Number of learned position embeddings.
        dropout: Dropout probability.
        ffn_mult: Feed-forward width as a multiple of ``hidden_size``.
    """

    def __init__(
        self,
        latent_channels: int = 4,
        latent_size: int = 64,
        hidden_size: int = 1024,
        num_layers: int = 12,
        num_heads: int = 16,
        vocab_size: int = 8192,
        max_seq_len: int = 2048,
        dropout: float = 0.1,
        ffn_mult: int = 4,
    ):
        # Latent geometry.
        self.latent_channels, self.latent_size = latent_channels, latent_size
        # Transformer geometry.
        self.hidden_size, self.num_layers, self.num_heads = hidden_size, num_layers, num_heads
        # Token space and context window.
        self.vocab_size, self.max_seq_len = vocab_size, max_seq_len
        # Regularisation and feed-forward expansion.
        self.dropout, self.ffn_mult = dropout, ffn_mult
| |
|
| |
|
class CNNLatentProjector(nn.Module):
    """CNN that turns a latent image into a short sequence of memory tokens.

    Four stride-2 convolutions (each followed by LayerNorm and GELU) shrink
    the latent's spatial grid; every remaining spatial cell becomes one
    memory token of width ``config.hidden_size``, to which a learned
    positional offset is added.

    The LayerNorm shapes and the token count are derived from
    ``config.latent_size`` instead of being fixed for a 64x64 latent. For
    ``latent_size == 64`` the module has exactly the same parameter names and
    shapes as before (grids 32/16/8/4, 16 tokens), so existing checkpoints
    still load.

    Args:
        config: Object exposing ``latent_channels``, ``hidden_size`` and,
            optionally, ``latent_size`` (64 is assumed when it is absent).
    """

    def __init__(self, config: "GcodeDecoderConfigV3"):
        super().__init__()

        channels = [config.latent_channels, 64, 128, 256, config.hidden_size]
        side = getattr(config, "latent_size", 64)

        stages = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            # Output size of a 3x3 convolution with stride 2 and padding 1.
            side = (side - 1) // 2 + 1
            stages += [
                nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                nn.LayerNorm([out_ch, side, side]),
                nn.GELU(),
            ]
        self.cnn = nn.Sequential(*stages)

        # One memory token per cell of the final feature map.
        self.num_memory_tokens = side * side
        self.memory_pos = nn.Parameter(torch.randn(1, self.num_memory_tokens, config.hidden_size) * 0.02)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Project ``latent`` (batch, channels, H, W) to (batch, tokens, hidden_size)."""
        features = self.cnn(latent)
        # (B, C, H, W) -> (B, H*W, C): each spatial cell becomes a token.
        tokens = features.flatten(2).transpose(1, 2)
        # memory_pos has a leading dimension of 1 and broadcasts over the batch.
        return tokens + self.memory_pos
| |
|
| |
|
class GcodeDecoderV3(nn.Module):
    """Transformer decoder (v3) producing gcode token logits from a latent.

    The latent is converted to memory tokens by ``CNNLatentProjector``. Token
    embeddings plus learned position embeddings pass through a stack of
    pre-norm ``nn.TransformerDecoderLayer`` blocks that cross-attend to that
    memory under a causal mask, followed by a final LayerNorm and a bias-free
    linear head over the vocabulary.
    """

    def __init__(self, config: GcodeDecoderConfigV3):
        super().__init__()
        self.config = config
        width = config.hidden_size

        self.latent_proj = CNNLatentProjector(config)
        self.token_embed = nn.Embedding(config.vocab_size, width)
        self.pos_embed = nn.Embedding(config.max_seq_len, width)
        self.embed_drop = nn.Dropout(config.dropout)

        def build_layer() -> nn.TransformerDecoderLayer:
            return nn.TransformerDecoderLayer(
                d_model=width,
                nhead=config.num_heads,
                dim_feedforward=width * config.ffn_mult,
                dropout=config.dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )

        self.layers = nn.ModuleList(build_layer() for _ in range(config.num_layers))

        self.ln_f = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, seq_len, vocab_size).

        Args:
            latent: Latent tensor handed to the CNN projector.
            input_ids: Token ids of shape (batch, seq_len).
        """
        _, seq_len = input_ids.shape
        device = input_ids.device

        memory = self.latent_proj(latent)

        positions = torch.arange(seq_len, device=device)
        hidden = self.embed_drop(self.token_embed(input_ids) + self.pos_embed(positions))

        # The mask is built in the latent's dtype.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            seq_len, device=device, dtype=latent.dtype
        )

        for block in self.layers:
            hidden = block(hidden, memory, tgt_mask=causal_mask)

        return self.lm_head(self.ln_f(hidden))
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GcodeDecoderConfigV2:
    """Hyper-parameter container for the v2 gcode decoder.

    Constructor arguments are stored unchanged as attributes of the same
    name. ``latent_dim`` is derived as
    ``latent_channels * latent_size * latent_size``, i.e. the length of the
    flattened latent.
    """

    def __init__(
        self,
        latent_channels: int = 4,
        latent_size: int = 64,
        hidden_size: int = 768,
        num_layers: int = 6,
        num_heads: int = 12,
        vocab_size: int = 32128,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
    ):
        # Latent geometry and its flattened length.
        self.latent_channels, self.latent_size = latent_channels, latent_size
        self.latent_dim = latent_channels * latent_size ** 2
        # Transformer geometry.
        self.hidden_size, self.num_layers, self.num_heads = hidden_size, num_layers, num_heads
        # Token space, context window and regularisation.
        self.vocab_size, self.max_seq_len = vocab_size, max_seq_len
        self.dropout = dropout
| |
|
| |
|
| | class GcodeDecoderV2(nn.Module): |
| | def __init__(self, config: GcodeDecoderConfigV2): |
| | super().__init__() |
| | self.config = config |
| | |
| | self.latent_proj = nn.Sequential( |
| | nn.Linear(config.latent_dim, config.hidden_size * 4), |
| | nn.GELU(), |
| | nn.Linear(config.hidden_size * 4, config.hidden_size * 16), |
| | nn.LayerNorm(config.hidden_size * 16), |
| | ) |
| | |
| | self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) |
| | self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size) |
| | |
| | self.layers = nn.ModuleList([ |
| | nn.TransformerDecoderLayer( |
| | d_model=config.hidden_size, |
| | nhead=config.num_heads, |
| | dim_feedforward=config.hidden_size * 4, |
| | dropout=config.dropout, |
| | activation='gelu', |
| | batch_first=True, |
| | norm_first=True, |
| | ) |
| | for _ in range(config.num_layers) |
| | ]) |
| | |
| | self.ln_f = nn.LayerNorm(config.hidden_size) |
| | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| | self.lm_head.weight = self.token_embed.weight |
| | |
| | def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: |
| | batch_size, seq_len = input_ids.shape |
| | device = input_ids.device |
| | dtype = latent.dtype |
| | |
| | latent_flat = latent.view(batch_size, -1) |
| | memory = self.latent_proj(latent_flat) |
| | memory = memory.view(batch_size, 16, self.config.hidden_size) |
| | |
| | positions = torch.arange(seq_len, device=device) |
| | x = self.token_embed(input_ids) + self.pos_embed(positions) |
| | |
| | causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype) |
| | |
| | for layer in self.layers: |
| | x = layer(x, memory, tgt_mask=causal_mask) |
| | |
| | x = self.ln_f(x) |
| | return self.lm_head(x) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Return the entries of ``state_dict`` under ``prefix``, keyed without it.

    Only the leading prefix is removed; later occurrences of the same text
    inside a key are left alone (``str.replace`` would remove them all).
    """
    cut = len(prefix)
    return {key[cut:]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _load_partial(module, state: dict, label: str) -> None:
    """Non-strictly load ``state`` into ``module`` and report keys that did not line up."""
    outcome = module.load_state_dict(state, strict=False)
    print(f"Loaded {len(state)} {label} weights")
    if outcome.missing_keys or outcome.unexpected_keys:
        print(
            f"  {label}: {len(outcome.missing_keys)} missing / "
            f"{len(outcome.unexpected_keys)} unexpected keys"
        )


def get_model():
    """Load and cache the SD-Gcode model.

    On the first call this downloads the config and checkpoint from the
    Hugging Face Hub, builds the Stable Diffusion pipeline and the gcode
    decoder (v3 or v2, chosen from the config), loads the fine-tuned weights
    and a tokenizer, and stores everything in the module-level ``_model``.
    Later calls return that cached dict.

    Returns:
        dict with keys ``pipe``, ``gcode_decoder``, ``gcode_tokenizer``,
        ``device``, ``dtype``, ``num_inference_steps`` and ``is_v3``.
    """
    global _model
    if _model is not None:
        return _model

    # Imported lazily so importing this module stays cheap.
    from diffusers import StableDiffusionPipeline
    from transformers import AutoTokenizer, PreTrainedTokenizerFast
    from huggingface_hub import hf_hub_download

    repo_id = "twarner/dcode-sd-gcode-v3"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    print("Loading SD-Gcode model...")

    config_path = hf_hub_download(repo_id, "config.json")
    weights_path = hf_hub_download(repo_id, "pytorch_model.bin")

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # A config counts as v3 when it declares ffn_mult or a hidden size >= 1024.
    gcode_cfg = config.get("gcode_decoder", {})
    is_v3 = gcode_cfg.get("ffn_mult") is not None or gcode_cfg.get("hidden_size", 768) >= 1024

    print(f"Model version: {'v3' if is_v3 else 'v2'}")

    sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
    print(f"Loading SD from {sd_model_id}...")
    pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)

    if is_v3:
        decoder_config = GcodeDecoderConfigV3(
            latent_channels=gcode_cfg.get("latent_channels", 4),
            latent_size=gcode_cfg.get("latent_size", 64),
            hidden_size=gcode_cfg.get("hidden_size", 1024),
            num_layers=gcode_cfg.get("num_layers", 12),
            num_heads=gcode_cfg.get("num_heads", 16),
            vocab_size=gcode_cfg.get("vocab_size", 8192),
            max_seq_len=gcode_cfg.get("max_seq_len", 2048),
            ffn_mult=gcode_cfg.get("ffn_mult", 4),
        )
        gcode_decoder = GcodeDecoderV3(decoder_config).to(device, dtype)
    else:
        decoder_config = GcodeDecoderConfigV2(
            latent_channels=gcode_cfg.get("latent_channels", 4),
            latent_size=gcode_cfg.get("latent_size", 64),
            hidden_size=gcode_cfg.get("hidden_size", 768),
            num_layers=gcode_cfg.get("num_layers", 6),
            num_heads=gcode_cfg.get("num_heads", 12),
            vocab_size=gcode_cfg.get("vocab_size", 32128),
            max_seq_len=gcode_cfg.get("max_seq_len", 1024),
        )
        gcode_decoder = GcodeDecoderV2(decoder_config).to(device, dtype)

    print("Loading finetuned weights...")
    # NOTE(security): weights_only=False unpickles arbitrary objects, so this
    # executes code embedded in the downloaded checkpoint. Kept as-is because
    # the checkpoint layout is not visible here; switch to weights_only=True
    # once the file is confirmed to contain only tensors.
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)

    text_encoder_state = _strip_prefix(state_dict, "text_encoder.")
    if text_encoder_state:
        _load_partial(pipe.text_encoder, text_encoder_state, "text encoder")

    unet_state = _strip_prefix(state_dict, "unet.")
    if unet_state:
        _load_partial(pipe.unet, unet_state, "UNet")

    decoder_state = _strip_prefix(state_dict, "gcode_decoder.")
    if decoder_state:
        try:
            gcode_decoder.load_state_dict(decoder_state, strict=True)
            print(f"Loaded {len(decoder_state)} decoder weights (strict)")
        except RuntimeError as e:
            # Best effort: fall back to a partial load, but say what is missing
            # so a half-initialised decoder does not go unnoticed.
            print(f"Strict load failed: {e}")
            outcome = gcode_decoder.load_state_dict(decoder_state, strict=False)
            print(f"Loaded {len(decoder_state)} decoder weights (non-strict)")
            print(
                f"  decoder: {len(outcome.missing_keys)} missing / "
                f"{len(outcome.unexpected_keys)} unexpected keys"
            )
    else:
        print("Warning: checkpoint has no gcode_decoder.* weights; decoder is randomly initialised")

    gcode_decoder.eval()

    try:
        tokenizer_path = hf_hub_download(repo_id, "gcode_tokenizer/tokenizer.json")
        gcode_tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_path,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
        )

        print(f"Loaded custom gcode tokenizer (vocab={gcode_tokenizer.vocab_size})")
        print(f"  BOS='{gcode_tokenizer.bos_token}' (id={gcode_tokenizer.bos_token_id})")
        print(f"  EOS='{gcode_tokenizer.eos_token}' (id={gcode_tokenizer.eos_token_id})")
        print(f"  PAD='{gcode_tokenizer.pad_token}' (id={gcode_tokenizer.pad_token_id})")

        # Round-trip a tiny sample as a smoke test of the tokenizer.
        test = "G0 X100 Y200\nG1 X150 Y250"
        enc = gcode_tokenizer.encode(test)
        dec = gcode_tokenizer.decode(enc)
        print(f"  Test encode: {len(enc)} tokens")
        print(f"  Test decode: '{dec[:50]}...'")
    except Exception as e:
        # Deliberate catch-all: any download or parsing problem falls back to T5.
        print(f"Failed to load custom tokenizer: {e}")
        import traceback
        traceback.print_exc()

        gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        print("Using fallback T5 tokenizer")
        if gcode_tokenizer.vocab_size != decoder_config.vocab_size:
            print(
                f"Warning: fallback tokenizer vocab ({gcode_tokenizer.vocab_size}) differs from "
                f"decoder vocab ({decoder_config.vocab_size}); output may be garbage"
            )

    _model = {
        "pipe": pipe,
        "gcode_decoder": gcode_decoder,
        "gcode_tokenizer": gcode_tokenizer,
        "device": device,
        "dtype": dtype,
        "num_inference_steps": config.get("num_inference_steps", 20),
        "is_v3": is_v3,
    }
    print("Model loaded!")

    return _model
| |
|
| |
|
| | |
| | |
| | |
| |
|
def is_valid_coord(s: str) -> bool:
    """Return True when ``s`` parses as a number strictly between -1000 and 1000.

    Anything ``float()`` rejects with ``ValueError`` or ``TypeError`` (for
    example ``"abc"`` or ``None``) yields False.
    """
    try:
        value = float(s)
    except (ValueError, TypeError):
        return False
    return -1000 < value < 1000
| |
|
| |
|
# Whole-line artefacts the model sometimes emits; dropped outright.
_NOISE_LINES = ("dcode", "gcode", "code", "output")
# Metadata headers that are not gcode commands.
_METADATA_PREFIXES = ("Source:", ";Generated", "Workarea:", "Algorithm:")
# How many recently kept coordinates are remembered for duplicate suppression.
_RECENT_COORD_WINDOW = 50


def _parse_coord(text: str) -> "float | None":
    """Parse a captured coordinate; None when malformed or outside (-1000, 1000).

    Uses the same acceptance range as ``is_valid_coord`` but returns the
    parsed value so callers do not have to parse twice.
    """
    try:
        value = float(text)
    except ValueError:
        return None
    return value if -1000 < value < 1000 else None


def clean_gcode(gcode: str) -> str:
    """Clean up generated gcode - fix formatting, remove garbage.

    Steps, in order:
      * turn ``<newline>`` markers into real newlines;
      * if the text has fewer than 10 newlines, start a new line before every
        ``G<digits>`` / ``M<digits>`` command;
      * insert the missing space in ``G0X..`` / ``G1F..`` style commands;
      * drop blank lines, known noise words and metadata lines, lines with
        garbled axis runs such as ``X-Y-``, and lines whose X or Y value does
        not parse or lies outside (-1000, 1000);
      * repair doubled minus signs (``X--5`` -> ``X-5``);
      * once more than 5 coordinates are remembered, drop lines that repeat
        an (X, Y) pair seen among the last 50 distinct pairs;
      * keep only lines starting with ``G``, ``M`` (either case) or ``;``.

    Prints the number of surviving lines.

    Args:
        gcode: Raw decoded model output.

    Returns:
        The cleaned program, lines joined with ``\\n``.
    """
    gcode = gcode.replace("<newline>", "\n")

    if gcode.count("\n") < 10:
        gcode = re.sub(r'([GM]\d+)', r'\n\1', gcode)

    # Done on the whole text, so no per-line repeat of this pass is needed.
    gcode = re.sub(r'(G[01])([XYZ])', r'\1 \2', gcode)
    gcode = re.sub(r'(G[01])F', r'\1 F', gcode)

    cleaned_lines = []
    # Insertion-ordered dict used as an ordered set: the first key is always
    # the oldest, so eviction is well defined (slicing a set was arbitrary).
    recent_coords = {}

    for raw_line in gcode.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        if line.lower() in _NOISE_LINES or line.startswith(_METADATA_PREFIXES):
            continue

        if re.search(r'X-Y-|Y-X-|X-X-|Y-Y-', line):
            continue

        line = re.sub(r'X--(\d)', r'X-\1', line)
        line = re.sub(r'Y--(\d)', r'Y-\1', line)

        x_match = re.search(r'X([-\d.]+)', line)
        y_match = re.search(r'Y([-\d.]+)', line)
        x = _parse_coord(x_match.group(1)) if x_match else None
        y = _parse_coord(y_match.group(1)) if y_match else None

        # A coordinate that is present but unusable invalidates the line.
        if (x_match and x is None) or (y_match and y is None):
            continue

        if x_match and y_match:
            coord = (round(x, 1), round(y, 1))
            if coord in recent_coords and len(recent_coords) > 5:
                continue
            # Re-insert so a repeated coordinate counts as most recent.
            recent_coords.pop(coord, None)
            recent_coords[coord] = None
            if len(recent_coords) > _RECENT_COORD_WINDOW:
                del recent_coords[next(iter(recent_coords))]

        if line[0] in "GMgm;":
            cleaned_lines.append(line)

    result = "\n".join(cleaned_lines)
    print(f"Cleaned gcode: {len(cleaned_lines)} lines")
    return result
| |
|
| |
|
def center_and_scale_gcode(gcode: str, bounds: "dict | None" = None, fill: float = 0.8) -> str:
    """Center the drawing on the workplane and scale it to fill part of it.

    The bounding box is taken over lines that carry both an X and a Y value
    inside (-1000, 1000). The first X and first Y on every line are then
    mapped so the box is centred in ``bounds`` and scaled uniformly to fit
    ``fill`` of the work area, and finally clamped to ``bounds``.

    A drawing that is flat along one axis (extent below 1 unit) is scaled
    using the other axis only. The input is returned unchanged when fewer
    than two usable points exist or both extents are below 1 unit.

    Prints the bounding box and the chosen scale.

    Args:
        gcode: Program text, one command per line.
        bounds: Mapping with ``left``, ``right``, ``top`` and ``bottom``;
            defaults to the module-level ``BOUNDS``.
        fill: Fraction of the work area the drawing may occupy (default 0.8).

    Returns:
        The transformed program, coordinates written with two decimals.
    """
    bounds = BOUNDS if bounds is None else bounds
    left, right = bounds["left"], bounds["right"]
    bottom, top = bounds["bottom"], bounds["top"]

    x_re = re.compile(r"X([-\d.]+)", re.IGNORECASE)
    y_re = re.compile(r"Y([-\d.]+)", re.IGNORECASE)

    def first_value(pattern, text):
        """Return the first coordinate ``pattern`` finds in ``text`` as a float, else None."""
        found = pattern.search(text)
        if found is None:
            return None
        try:
            return float(found.group(1))
        except ValueError:
            return None

    lines = gcode.split("\n")

    coords = []
    for line in lines:
        x, y = first_value(x_re, line), first_value(y_re, line)
        if x is None or y is None:
            continue
        # Ignore wild values so one outlier cannot blow up the bounding box.
        if -1000 < x < 1000 and -1000 < y < 1000:
            coords.append((x, y))

    if len(coords) < 2:
        return gcode

    xs, ys = zip(*coords)
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)
    width = max_x - min_x
    height = max_y - min_y

    # Only axes with a usable extent constrain the scale (avoids dividing by ~0).
    ratios = []
    if width >= 1:
        ratios.append((right - left) * fill / width)
    if height >= 1:
        ratios.append((top - bottom) * fill / height)
    if not ratios:
        return gcode
    scale = min(ratios)

    cx = (min_x + max_x) / 2
    cy = (min_y + max_y) / 2
    target_cx = (left + right) / 2
    target_cy = (bottom + top) / 2

    print(f"Centering: bbox=({min_x:.0f},{min_y:.0f})-({max_x:.0f},{max_y:.0f}), scale={scale:.2f}")

    result = []
    for line in lines:
        new_line = line

        x = first_value(x_re, line)
        if x is not None:
            new_x = max(left, min(right, (x - cx) * scale + target_cx))
            new_line = x_re.sub(f"X{new_x:.2f}", new_line, count=1)

        y = first_value(y_re, line)
        if y is not None:
            new_y = max(bottom, min(top, (y - cy) * scale + target_cy))
            new_line = y_re.sub(f"Y{new_y:.2f}", new_line, count=1)

        result.append(new_line)

    return "\n".join(result)
| |
|
| |
|
def validate_gcode(gcode: str, bounds: "dict | None" = None) -> str:
    """Clamp coordinates to machine bounds.

    Every ``X<number>`` / ``Y<number>`` token (case-insensitive) is clamped
    to the horizontal / vertical range of ``bounds`` and rewritten with an
    upper-case axis letter and two decimals. Each token is clamped using its
    own value, so a line holding several commands keeps distinct
    coordinates. Tokens whose number does not parse are left untouched.

    Args:
        gcode: Program text.
        bounds: Mapping with ``left``, ``right``, ``top`` and ``bottom``;
            defaults to the module-level ``BOUNDS``.

    Returns:
        The program with all parseable X/Y values inside the bounds.
    """
    bounds = BOUNDS if bounds is None else bounds
    limits = {
        "X": (bounds["left"], bounds["right"]),
        "Y": (bounds["bottom"], bounds["top"]),
    }

    def clamp(found):
        axis = found.group(1).upper()
        try:
            value = float(found.group(2))
        except ValueError:
            # e.g. "X1.2.3" or "X-": not a number, leave the text as it was.
            return found.group(0)
        low, high = limits[axis]
        return f"{axis}{max(low, min(high, value)):.2f}"

    # The pattern cannot match across a newline, so one pass over the whole
    # text is equivalent to processing it line by line.
    return re.sub(r"([XY])([-\d.]+)", clamp, gcode, flags=re.IGNORECASE)
| |
|
| |
|
def gcode_to_svg(gcode: str, bounds: "dict | None" = None) -> str:
    """Convert gcode to SVG for visual preview.

    The program is replayed with a simple pen model: ``M280 ... S<angle>``
    lowers the pen when the angle is below 50 and raises it otherwise, and
    any command carrying X and/or Y moves the pen. Each pen-down stretch
    with at least two points becomes one SVG path; Y is negated because SVG's
    Y axis points down.

    Details of the replay:
      * text after ``;`` is a comment and is ignored;
      * several commands on one line are split before each ``G<digit>`` /
        ``M<digit>``;
      * a stroke starts at the position where the pen was lowered (once a
        position is known), so its first segment is drawn;
      * lifting the pen always ends the stroke, even a one-point stroke, so
        separate strokes are never joined together.

    Args:
        gcode: Program text (``<newline>`` markers are accepted).
        bounds: Mapping with ``left``, ``right``, ``top`` and ``bottom``
            describing the work area; defaults to the module-level ``BOUNDS``.

    Returns:
        A complete ``<svg>`` element as a string.
    """
    bounds = BOUNDS if bounds is None else bounds

    gcode = gcode.replace("<newline>", "\n")

    # Split the text into individual commands.
    commands = []
    for raw_line in gcode.split("\n"):
        code = raw_line.split(";", 1)[0].strip()
        if not code:
            continue
        for part in re.split(r'(?=[GM]\d)', code):
            part = part.strip()
            if part and part[0] in "GMgm":
                commands.append(part)

    paths = []
    current_path = []
    x, y = 0.0, 0.0
    have_position = False  # becomes True once any coordinate has been parsed
    pen_down = False

    for command in commands:
        if "M280" in command.upper():
            servo = re.search(r"S(\d+)", command, re.IGNORECASE)
            if servo:
                was_down = pen_down
                pen_down = int(servo.group(1)) < 50
                if was_down and not pen_down:
                    # Pen lifted: close the stroke and always start afresh.
                    if len(current_path) > 1:
                        paths.append(current_path)
                    current_path = []
                elif pen_down and not was_down:
                    # Pen lowered: the stroke begins where the pen currently is.
                    current_path = [(x, y)] if have_position else []

        moved = False
        x_match = re.search(r"X([-\d.]+)", command, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", command, re.IGNORECASE)
        if x_match:
            try:
                x = float(x_match.group(1))
                moved = True
            except ValueError:
                pass
        if y_match:
            try:
                y = float(y_match.group(1))
                moved = True
            except ValueError:
                pass

        if moved:
            have_position = True
            if pen_down:
                current_path.append((x, y))

    if len(current_path) > 1:
        paths.append(current_path)

    left, top = bounds["left"], bounds["top"]
    w = bounds["right"] - left
    h = top - bounds["bottom"]
    padding = 20

    pieces = [f'''<svg xmlns="http://www.w3.org/2000/svg"
viewBox="{left - padding} {-top - padding} {w + 2*padding} {h + 2*padding}"
class="gcode-preview"
style="width: 100%; height: 480px; border-radius: 8px; border: 1px solid var(--block-border-color); background: var(--block-background-fill);">
<defs>
<style>
.gcode-preview .work-area {{ fill: var(--background-fill-primary); stroke: var(--block-border-color); }}
.gcode-preview .draw-path {{ stroke: var(--body-text-color); }}
.gcode-preview .info-text {{ fill: var(--body-text-color-subdued); }}
</style>
</defs>
<rect class="work-area" x="{left}" y="{-top}" width="{w}" height="{h}" stroke-width="1"/>
''']

    for path in paths:
        d = " ".join(
            f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}" for i, (px, py) in enumerate(path)
        )
        pieces.append(
            f'<path class="draw-path" d="{d}" fill="none" stroke-width="1" '
            f'stroke-linecap="round" stroke-linejoin="round"/>'
        )

    total_points = sum(len(p) for p in paths)
    pieces.append(f'''
<text class="info-text" x="{left + 8}" y="{-top + 20}" font-family="monospace" font-size="12">
{len(paths)} paths / {total_points} points
</text>
''')
    pieces.append("</svg>")
    return "".join(pieces)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _indefinite_article(phrase: str) -> str:
    """Choose "a" or "an" for ``phrase`` with a small spelling-based heuristic.

    English picks the article by sound, so this is an approximation: it
    handles the common silent-h words, "you"/"w"-sounding vowel spellings
    (``unicorn``, ``european``, ``one``) and short-u words (``umbrella``,
    ``uncle``), and otherwise goes by the first letter.
    """
    word = phrase.lstrip()
    if word.startswith(("hour", "honest", "honor", "honour", "heir")):
        return "an"
    if word.startswith(("eu", "ew", "one", "once", "uni")):
        return "a"
    if word[:1] in ("a", "e", "i", "o"):
        return "an"
    if word[:2] in ("um", "un", "up", "ug", "ul"):
        return "an"
    return "a"


def enhance_prompt(prompt: str) -> str:
    """Enhance prompt to match BLIP caption style from training data.

    BLIP generates captions like:
    - "a drawing of a horse"
    - "a sketch of a cat"
    - "a black and white drawing"
    - "an illustration of a flower"

    The prompt is stripped and lower-cased, then:
      * left as is when it already starts with "a ", "an " or "the ";
      * prefixed with an article when it already mentions a drawing, sketch,
        illustration or image;
      * otherwise wrapped as "a drawing of a/an <prompt>".
    A fixed style suffix is appended in every case. An empty prompt yields
    just "a drawing" plus the suffix.

    Args:
        prompt: Free-form user text.

    Returns:
        The caption-style prompt.
    """
    style_suffix = ", black and white, simple lines, sketch style"
    prompt = prompt.strip().lower()

    if not prompt:
        return "a drawing" + style_suffix

    if prompt.startswith(("a ", "an ", "the ")):
        enhanced = prompt
    elif any(word in prompt for word in ("drawing", "sketch", "illustration", "image")):
        enhanced = f"{_indefinite_article(prompt)} {prompt}"
    else:
        enhanced = f"a drawing of {_indefinite_article(prompt)} {prompt}"

    return enhanced + style_suffix
| |
|
| |
|
def _apply_repetition_penalty(logits: torch.Tensor, token_ids, penalty: float = 2.0) -> torch.Tensor:
    """Make ``token_ids`` less likely, in place, and return ``logits``.

    Positive logits are divided by ``penalty`` and negative ones multiplied
    by it, so the score always moves down. Scaling every logit by the same
    factor below 1 would instead raise the probability of tokens whose logit
    is negative.
    """
    if not token_ids:
        return logits
    index = torch.tensor(sorted(token_ids), dtype=torch.long, device=logits.device)
    picked = logits[:, index]
    logits[:, index] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return logits


def _sample_next_token(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.92) -> torch.Tensor:
    """Sample one token id per row using top-k followed by nucleus (top-p) filtering.

    Args:
        logits: Tensor of shape (batch, vocab).
        top_k: Keep at most this many highest-scoring tokens.
        top_p: Among those, keep the smallest prefix whose cumulative
            probability exceeds ``top_p`` (the best token is always kept).

    Returns:
        Tensor of shape (batch, 1) with the sampled token ids.
    """
    k = min(top_k, logits.shape[-1])  # vocabularies smaller than top_k are fine
    # topk returns values in descending order, so no extra sort is needed.
    top_logits, top_indices = torch.topk(logits, k, dim=-1)

    cumulative = torch.cumsum(torch.softmax(top_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    # Shift right so the token that crosses the threshold is still kept.
    remove[:, 1:] = remove[:, :-1].clone()
    remove[:, 0] = False
    top_logits = top_logits.masked_fill(remove, float('-inf'))

    choice = torch.multinomial(torch.softmax(top_logits, dim=-1), num_samples=1)
    return top_indices.gather(-1, choice)


@spaces.GPU
def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float, seed: int = -1):
    """Generate gcode from text prompt.

    Runs Stable Diffusion on the enhanced prompt to obtain a latent, decodes
    gcode tokens autoregressively from that latent, then cleans, centres,
    clamps and renders the result.

    Args:
        prompt: User text; blank input short-circuits with a hint message.
        temperature: Sampling temperature applied to the decoder logits.
        max_tokens: Upper bound on generated tokens (also limited by the
            decoder's ``max_seq_len``).
        num_steps: Number of diffusion steps.
        guidance: Classifier-free guidance scale.
        seed: Non-negative value for reproducible diffusion noise; negative
            or ``None`` means random.

    Returns:
        ``(gcode_text, svg_markup)``. On failure the text is ``"; Error: ..."``
        and the SVG shows an empty work area; the traceback is printed.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")

    try:
        # UI widgets may hand over floats (sliders) or None (cleared number box).
        max_tokens = int(max_tokens)
        num_steps = int(num_steps)
        temperature = max(float(temperature), 1e-5)
        seed = -1 if seed is None else int(seed)

        m = get_model()
        pipe = m["pipe"]
        gcode_decoder = m["gcode_decoder"]
        gcode_tokenizer = m["gcode_tokenizer"]
        device = m["device"]
        dtype = m["dtype"]
        is_v3 = m.get("is_v3", False)

        enhanced = enhance_prompt(prompt)
        print(f"Enhanced prompt: {enhanced}")

        generator = None
        if seed >= 0:
            generator = torch.Generator(device=device).manual_seed(seed)
            print(f"Using seed: {seed}")

        with torch.no_grad():
            result = pipe(
                enhanced,
                negative_prompt="color, shading, gradient, photorealistic, 3d, complex, detailed texture",
                num_inference_steps=num_steps,
                guidance_scale=guidance,
                output_type="latent",
                generator=generator,
            )
            latent = result.images.to(dtype)
            print(f"Latent shape: {latent.shape}, dtype: {latent.dtype}")

        with torch.no_grad():
            bos_id = gcode_tokenizer.bos_token_id
            eos_id = gcode_tokenizer.eos_token_id
            pad_id = gcode_tokenizer.pad_token_id
            unk_id = getattr(gcode_tokenizer, "unk_token_id", None)

            if is_v3:
                # Seed the v3 decoder with a fixed gcode preamble.
                start_text = "G21\nG90\nM280 P0 S90\nG28\n"
                start_tokens = gcode_tokenizer.encode(start_text, add_special_tokens=False)
                if bos_id is not None:
                    start_tokens = [bos_id] + start_tokens
                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=device)
            else:
                start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
                start_id = start_tokens[0] if start_tokens else 0
                input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)

            print(f"Starting with {input_ids.shape[1]} tokens, BOS={bos_id}, EOS={eos_id}")

            max_gen = min(max_tokens, gcode_decoder.config.max_seq_len - input_ids.shape[1])
            vocab_size = gcode_decoder.config.vocab_size

            # Tokens that must never be sampled: padding, unknown, and id 1.
            # NOTE(review): id 1 was hard-coded originally (presumably <unk>
            # in the custom tokenizer - confirm). It is never banned when it
            # is the EOS id, otherwise generation could not terminate.
            banned_ids = sorted({
                token_id
                for token_id in (pad_id, unk_id, 1)
                if token_id is not None and token_id != eos_id and 0 <= token_id < vocab_size
            })

            recent_tokens = []

            for step in range(max_gen):
                logits = gcode_decoder(latent, input_ids)
                # Sample in float32: half precision is too coarse for cumsum/softmax.
                next_logits = logits[:, -1, :].float() / temperature

                if banned_ids:
                    next_logits[:, banned_ids] = float('-inf')

                _apply_repetition_penalty(next_logits, set(recent_tokens[-50:]), penalty=2.0)

                next_token = _sample_next_token(next_logits, top_k=50, top_p=0.92)
                token_id = next_token.item()
                input_ids = torch.cat([input_ids, next_token], dim=1)
                recent_tokens.append(token_id)

                if step < 5:
                    tok_str = gcode_tokenizer.decode([token_id])
                    print(f"  Step {step}: token={token_id}, str='{tok_str}'")

                if eos_id is not None and token_id == eos_id:
                    print(f"Hit EOS at step {step}")
                    break

                # Bail out when the last 30 tokens cycle through fewer than 5 values.
                if len(recent_tokens) > 30 and len(set(recent_tokens[-30:])) < 5:
                    print(f"Stopping due to repetition at step {step}")
                    break

            print(f"Generated {input_ids.shape[1]} total tokens")

        gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=False)

        for special in ("<pad>", "<s>", "</s>", "<unk>"):
            gcode = gcode.replace(special, "")
        gcode = gcode.replace("<newline>", "\n")

        print(f"Raw decoded (first 300 chars): {repr(gcode[:300])}")

        gcode = clean_gcode(gcode)
        gcode = center_and_scale_gcode(gcode)
        gcode = validate_gcode(gcode)
        line_count = sum(1 for line in gcode.split("\n") if line.strip())
        svg = gcode_to_svg(gcode)

        # Collapse whitespace so a multi-line prompt cannot spill out of the
        # comment line and be read as a command by the machine.
        prompt_line = " ".join(prompt.split())
        header = f"; dcode output\n; prompt: {prompt_line}\n; {line_count} commands\n\n"
        return header + gcode, svg

    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        import traceback
        traceback.print_exc()
        return f"; Error: {e}", gcode_to_svg("")
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Page styling: monospace font everywhere, a narrow centred layout, no footer.
css = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500&display=swap');

* {
font-family: 'IBM Plex Mono', monospace !important;
}

.gradio-container {
max-width: 900px !important;
margin: auto;
}

footer {
display: none !important;
}
"""

# Gradio layout: prompt and settings on the left, SVG preview and gcode on the right.
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
    gr.Markdown("# dcode")
    gr.Markdown("text → polargraph gcode via stable diffusion")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="prompt",
                placeholder="describe what to draw...",
                lines=2,
                show_label=True,
            )

            with gr.Accordion("settings", open=False):
                temperature = gr.Slider(0.3, 1.2, value=0.7, label="temperature", step=0.1)
                max_tokens = gr.Slider(256, 2048, value=2048, step=256, label="max tokens")
                num_steps = gr.Slider(20, 75, value=50, step=5, label="diffusion steps")
                guidance = gr.Slider(5.0, 20.0, value=12.0, step=0.5, label="guidance")
                seed = gr.Number(value=-1, label="seed (-1 = random)", precision=0)

            generate_btn = gr.Button("generate", variant="secondary")

            example_prompts = [
                "a drawing of a horse",
                "a sketch of a cat",
                "a simple flower drawing",
                "a drawing of a tree",
                "abstract lines",
                "a portrait sketch",
            ]
            gr.Examples(
                examples=[[text] for text in example_prompts],
                inputs=prompt,
                label=None,
                examples_per_page=6,
            )

        with gr.Column(scale=2):
            # Start with an empty work area.
            preview = gr.HTML(value=gcode_to_svg(""))

            with gr.Accordion("gcode", open=False):
                gcode_output = gr.Code(label=None, language=None, lines=12)

    gr.Markdown("---")
    gr.Markdown("machine: 841×1189mm / pen servo 40-90° / [github](https://github.com/Twarner491/dcode) / [model](https://huggingface.co/twarner/dcode-sd-gcode-v3) / mit")

    # The button and pressing Enter in the prompt box trigger the same call.
    generate_inputs = [prompt, temperature, max_tokens, num_steps, guidance, seed]
    generate_outputs = [gcode_output, preview]
    for register in (generate_btn.click, prompt.submit):
        register(generate, generate_inputs, generate_outputs)


if __name__ == "__main__":
    demo.launch()
| |
|