| | """ |
| | Unified Evaluation Suite for 8-bit Threshold Computer |
| | ====================================================== |
| | GPU-batched evaluation with per-circuit reporting. |
| | Includes CPU runtime for threshold-weight execution. |
| | |
| | Usage: |
| | python eval.py # Run circuit evaluation |
| | python eval.py --device cpu # CPU mode |
| | python eval.py --pop_size 1000 # Population mode for evolution |
| | python eval.py --cpu-test # Run CPU smoke test |
| | |
| | API (for prune_weights.py): |
| | from eval import load_model, create_population, BatchedFitnessEvaluator |
| | from eval import ThresholdCPU, ThresholdALU, CPUState |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import time |
| | from collections import defaultdict |
| | from dataclasses import dataclass, field |
| | from typing import Callable, Dict, List, Optional, Tuple |
| |
|
| | import torch |
| | from safetensors import safe_open |
| |
|
| |
|
# Default checkpoint location: a safetensors file expected next to this script.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "neural_computer.safetensors")
| |
|
| |
|
@dataclass
class CircuitResult:
    """Outcome of one circuit's test run: pass/total counts plus failure details."""
    name: str
    passed: int
    total: int
    failures: List[Tuple] = field(default_factory=list)

    @property
    def success(self) -> bool:
        """True when every test case passed (vacuously true for zero cases)."""
        return self.total == self.passed

    @property
    def rate(self) -> float:
        """Fraction of passing cases; 0.0 when no cases were run."""
        return 0.0 if self.total <= 0 else self.passed / self.total
| |
|
| |
|
def heaviside(x: torch.Tensor) -> torch.Tensor:
    """Hard threshold activation: maps x >= 0 to 1.0 and x < 0 to 0.0."""
    return x.ge(0).to(torch.float32)
| |
|
| |
|
def load_model(path: str = MODEL_PATH) -> Dict[str, torch.Tensor]:
    """Read every tensor from a safetensors checkpoint, cast to float32."""
    tensors: Dict[str, torch.Tensor] = {}
    with safe_open(path, framework='pt') as f:
        for name in f.keys():
            tensors[name] = f.get_tensor(name).float()
    return tensors
| |
|
| |
|
def load_metadata(path: str = MODEL_PATH) -> Dict:
    """Load safetensors metadata; returns {'signal_registry': dict} (empty if absent)."""
    with safe_open(path, framework='pt') as f:
        meta = f.metadata()
    registry = {}
    if meta and 'signal_registry' in meta:
        registry = json.loads(meta['signal_registry'])
    return {'signal_registry': registry}
| |
|
| |
|
def get_manifest(tensors: Dict[str, torch.Tensor]) -> Dict[str, int]:
    """Extract manifest values from tensors.

    Returns dict with data_bits, addr_bits, memory_bytes, version.
    Defaults to 8-bit data, 16-bit addr for legacy models (addr_bits
    falls back to the legacy 'manifest.pc_width' key before defaulting).
    """
    def _scalar(keys: List[str], fallback: float) -> float:
        # First matching key wins; missing keys fall through to the default.
        for key in keys:
            if key in tensors:
                return tensors[key].item()
        return fallback

    return {
        'data_bits': int(_scalar(['manifest.data_bits'], 8.0)),
        'addr_bits': int(_scalar(['manifest.addr_bits', 'manifest.pc_width'], 16.0)),
        'memory_bytes': int(_scalar(['manifest.memory_bytes'], 65536.0)),
        'version': float(_scalar(['manifest.version'], 1.0)),
    }
| |
|
| |
|
def create_population(
    base_tensors: Dict[str, torch.Tensor],
    pop_size: int,
    device: str = 'cuda'
) -> Dict[str, torch.Tensor]:
    """Replicate base tensors for batched population evaluation.

    Each tensor gains a leading population dimension of size pop_size;
    clone() materializes the expanded view so individuals can be mutated
    independently.
    """
    population: Dict[str, torch.Tensor] = {}
    for name, tensor in base_tensors.items():
        batched = tensor.unsqueeze(0).expand(pop_size, *tensor.shape)
        population[name] = batched.clone().to(device)
    return population
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Human-readable names for the 4 status flags (indices 0..3 of CPUState.flags).
FLAG_NAMES = ["Z", "N", "C", "V"]
# Control-line names (indices 0..3 of CPUState.ctrl); ctrl[0] == 1 means halted.
CTRL_NAMES = ["HALT", "MEM_WE", "MEM_RE", "RESERVED"]

PC_BITS = 16        # program counter width
IR_BITS = 16        # instruction register width
REG_BITS = 8        # data register width (one byte)
REG_COUNT = 4       # number of general-purpose registers
FLAG_BITS = 4       # Z / N / C / V
SP_BITS = 16        # stack pointer width
CTRL_BITS = 4       # control lines (see CTRL_NAMES)
MEM_BYTES = 65536   # full 16-bit address space
MEM_BITS = MEM_BYTES * 8

# Total width of a flattened machine state; must match pack_state()/unpack_state().
STATE_BITS = PC_BITS + IR_BITS + (REG_BITS * REG_COUNT) + FLAG_BITS + SP_BITS + CTRL_BITS + MEM_BITS
| |
|
| |
|
def int_to_bits(value: int, width: int) -> List[int]:
    """Expand an integer into exactly `width` bits, most-significant first."""
    bits: List[int] = []
    for shift in range(width - 1, -1, -1):
        bits.append((value >> shift) & 1)
    return bits
| |
|
| |
|
def bits_to_int(bits: List[int]) -> int:
    """Fold a most-significant-first bit sequence back into an integer."""
    accumulator = 0
    for bit in bits:
        accumulator = (accumulator << 1) | int(bit)
    return accumulator
| |
|
| |
|
def bits_msb_to_lsb(bits: List[int]) -> List[int]:
    """Return a new list with the bit order reversed (MSB-first -> LSB-first)."""
    return bits[::-1]
| |
|
| |
|
@dataclass
class CPUState:
    """Complete machine state: PC, IR, registers, flags, SP, control lines, RAM."""
    pc: int
    ir: int
    regs: List[int]
    flags: List[int]
    sp: int
    ctrl: List[int]
    mem: List[int]

    def copy(self) -> 'CPUState':
        """Independent copy: fresh lists with every element coerced to int."""
        return CPUState(
            pc=int(self.pc),
            ir=int(self.ir),
            regs=list(map(int, self.regs)),
            flags=list(map(int, self.flags)),
            sp=int(self.sp),
            ctrl=list(map(int, self.ctrl)),
            mem=list(map(int, self.mem)),
        )
| |
|
| |
|
def pack_state(state: CPUState) -> List[int]:
    """Flatten a CPUState into a STATE_BITS-long bit list.

    Field order matches unpack_state(): pc, ir, regs, flags, sp, ctrl, mem.
    All multi-bit fields are emitted most-significant bit first.
    """
    bits: List[int] = []
    bits += int_to_bits(state.pc, PC_BITS)
    bits += int_to_bits(state.ir, IR_BITS)
    for value in state.regs:
        bits += int_to_bits(value, REG_BITS)
    bits += [int(flag) for flag in state.flags]
    bits += int_to_bits(state.sp, SP_BITS)
    bits += [int(line) for line in state.ctrl]
    for value in state.mem:
        bits += int_to_bits(value, REG_BITS)
    return bits
| |
|
| |
|
def unpack_state(bits: List[int]) -> CPUState:
    """Reconstruct a CPUState from a flat bit list (inverse of pack_state).

    Raises:
        ValueError: if the bit list is not exactly STATE_BITS long.
    """
    if len(bits) != STATE_BITS:
        raise ValueError(f"Expected {STATE_BITS} bits, got {len(bits)}")

    cursor = 0

    def take(width: int) -> List[int]:
        # Consume the next `width` bits and advance the cursor.
        nonlocal cursor
        chunk = bits[cursor:cursor + width]
        cursor += width
        return chunk

    pc = bits_to_int(take(PC_BITS))
    ir = bits_to_int(take(IR_BITS))
    regs = [bits_to_int(take(REG_BITS)) for _ in range(REG_COUNT)]
    flags = [int(b) for b in take(FLAG_BITS)]
    sp = bits_to_int(take(SP_BITS))
    ctrl = [int(b) for b in take(CTRL_BITS)]
    mem = [bits_to_int(take(REG_BITS)) for _ in range(MEM_BYTES)]

    return CPUState(pc=pc, ir=ir, regs=regs, flags=flags, sp=sp, ctrl=ctrl, mem=mem)
| |
|
| |
|
def decode_ir(ir: int) -> Tuple[int, int, int, int]:
    """Split a 16-bit instruction word into (opcode, rd, rs, imm8).

    Layout (MSB to LSB): 4-bit opcode, 2-bit rd, 2-bit rs, 8-bit immediate.
    """
    return (
        (ir >> 12) & 0xF,
        (ir >> 10) & 0x3,
        (ir >> 8) & 0x3,
        ir & 0xFF,
    )
| |
|
| |
|
def flags_from_result(result: int, carry: int, overflow: int) -> Tuple[int, int, int, int]:
    """Derive the (Z, N, C, V) flag bits from an 8-bit ALU result."""
    zero = int(result == 0)
    negative = (result >> 7) & 1  # sign bit of the 8-bit result
    return zero, negative, int(bool(carry)), int(bool(overflow))
| |
|
| |
|
def alu_add(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit add; returns (result, carry_out, signed_overflow)."""
    total = a + b
    result = total & 0xFF
    carry = int(total > 0xFF)
    # Signed overflow: both operands' sign bits differ from the result's sign bit.
    overflow = int(bool((a ^ result) & (b ^ result) & 0x80))
    return result, carry, overflow
| |
|
| |
|
def alu_sub(a: int, b: int) -> Tuple[int, int, int]:
    """8-bit subtract; returns (result, carry, signed_overflow).

    Carry here follows the "no borrow" convention: 1 when a >= b.
    """
    result = (a - b) & 0xFF
    carry = int(a >= b)
    # Signed overflow: operand signs differ AND result's sign differs from a's.
    overflow = int(bool((a ^ b) & (a ^ result) & 0x80))
    return result, carry, overflow
| |
|
| |
|
def ref_step(state: CPUState) -> CPUState:
    """Reference CPU cycle (pure Python arithmetic).

    Executes one fetch/decode/execute cycle and returns the successor state.
    The input state is never mutated; a halted state (ctrl[0] == 1) is
    returned unchanged as a copy.
    """
    if state.ctrl[0] == 1:
        return state.copy()

    s = state.copy()

    # Fetch: 16-bit big-endian instruction word; PC arithmetic wraps at 64 KiB.
    hi = s.mem[s.pc]
    lo = s.mem[(s.pc + 1) & 0xFFFF]
    s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
    next_pc = (s.pc + 2) & 0xFFFF

    opcode, rd, rs, imm8 = decode_ir(s.ir)
    a = s.regs[rd]
    b = s.regs[rs]

    # Opcodes 0xA-0xE carry a 16-bit operand word after the instruction.
    addr16 = None
    next_pc_ext = next_pc
    if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
        addr_hi = s.mem[next_pc]
        addr_lo = s.mem[(next_pc + 1) & 0xFFFF]
        addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
        next_pc_ext = (next_pc + 2) & 0xFFFF

    write_result = True
    result = a
    carry = 0
    overflow = 0

    if opcode == 0x0:            # ADD rd, rs
        result, carry, overflow = alu_add(a, b)
    elif opcode == 0x1:          # SUB rd, rs
        result, carry, overflow = alu_sub(a, b)
    elif opcode == 0x2:          # AND
        result = a & b
    elif opcode == 0x3:          # OR
        result = a | b
    elif opcode == 0x4:          # XOR
        result = a ^ b
    elif opcode == 0x5:          # SHL (logical, by 1)
        result = (a << 1) & 0xFF
    elif opcode == 0x6:          # SHR (logical, by 1)
        result = (a >> 1) & 0xFF
    elif opcode == 0x7:          # MUL (low byte of product)
        result = (a * b) & 0xFF
    elif opcode == 0x8:          # DIV (0xFF sentinel on divide-by-zero)
        if b == 0:
            result = 0xFF
        else:
            result = a // b
    elif opcode == 0x9:          # CMP: subtract, update flags only
        result, carry, overflow = alu_sub(a, b)
        write_result = False
    elif opcode == 0xA:          # LOAD rd, [addr16]
        result = s.mem[addr16]
    elif opcode == 0xB:          # STORE [addr16], rs
        s.mem[addr16] = b & 0xFF
        write_result = False
    elif opcode == 0xC:          # JMP addr16
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xD:          # Conditional jump; condition in imm8 low 3 bits
        cond_type = imm8 & 0x7
        if cond_type == 0:       # JZ
            take_branch = s.flags[0] == 1
        elif cond_type == 1:     # JNZ
            take_branch = s.flags[0] == 0
        elif cond_type == 2:     # JC
            take_branch = s.flags[2] == 1
        elif cond_type == 3:     # JNC
            take_branch = s.flags[2] == 0
        elif cond_type == 4:     # JN
            take_branch = s.flags[1] == 1
        elif cond_type == 5:     # JP
            take_branch = s.flags[1] == 0
        elif cond_type == 6:     # JV
            take_branch = s.flags[3] == 1
        else:                    # JNV
            take_branch = s.flags[3] == 0
        if take_branch:
            s.pc = addr16 & 0xFFFF
        else:
            s.pc = next_pc_ext
        write_result = False
    elif opcode == 0xE:          # CALL: push return address big-endian, then jump
        ret_addr = next_pc_ext & 0xFFFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = (ret_addr >> 8) & 0xFF
        s.sp = (s.sp - 1) & 0xFFFF
        s.mem[s.sp] = ret_addr & 0xFF
        s.pc = addr16 & 0xFFFF
        write_result = False
    elif opcode == 0xF:          # HALT
        s.ctrl[0] = 1
        write_result = False

    # ALU ops (0x0-0x9) and LOAD (0xA) update the flags.  The original
    # condition `opcode <= 0x9 or opcode in (0xA, 0x7, 0x8)` listed 0x7/0x8
    # redundantly (both already <= 0x9); this equivalent simplification
    # matches ThresholdCPU.step.
    if opcode <= 0xA:
        s.flags = list(flags_from_result(result, carry, overflow))

    if write_result:
        s.regs[rd] = result & 0xFF

    # Jumps, branches, and calls set PC themselves.
    if opcode not in (0xC, 0xD, 0xE):
        s.pc = next_pc_ext

    return s
| |
|
| |
|
def ref_run_until_halt(state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
    """Step the reference CPU until HALT; returns (final_state, cycles_executed)."""
    current = state.copy()
    cycles = 0
    while cycles < max_cycles:
        if current.ctrl[0] == 1:
            return current, cycles
        current = ref_step(current)
        cycles += 1
    return current, max_cycles
| |
|
| |
|
class ThresholdALU:
    """8-bit ALU evaluated gate-by-gate from stored threshold weights.

    Every boolean gate is a single threshold neuron, out = heaviside(w.x + b),
    with weights and biases looked up by dotted name in the safetensors model.
    Correct results therefore depend on the checkpoint encoding the intended
    boolean circuits under exactly these names.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        # Load once, cast to float32, pin to the requested device.
        self.tensors = {k: v.float().to(device) for k, v in load_model(model_path).items()}

    def _get(self, name: str) -> torch.Tensor:
        """Fetch a weight/bias tensor by registry name (KeyError if missing)."""
        return self.tensors[name]

    def _eval_gate(self, weight_key: str, bias_key: str, inputs: List[float]) -> float:
        """Evaluate one threshold gate: heaviside(inputs . w + b) as a 0.0/1.0 float."""
        w = self._get(weight_key)
        b = self._get(bias_key)
        inp = torch.tensor(inputs, device=self.device)
        return heaviside((inp * w).sum() + b).item()

    def _eval_xor(self, prefix: str, inputs: List[float]) -> float:
        """Two-layer XOR: layer1 computes OR and NAND of the inputs; layer2 combines them."""
        inp = torch.tensor(inputs, device=self.device)
        w_or = self._get(f"{prefix}.layer1.or.weight")
        b_or = self._get(f"{prefix}.layer1.or.bias")
        w_nand = self._get(f"{prefix}.layer1.nand.weight")
        b_nand = self._get(f"{prefix}.layer1.nand.bias")
        w2 = self._get(f"{prefix}.layer2.weight")
        b2 = self._get(f"{prefix}.layer2.bias")

        h_or = heaviside((inp * w_or).sum() + b_or).item()
        h_nand = heaviside((inp * w_nand).sum() + b_nand).item()
        hidden = torch.tensor([h_or, h_nand], device=self.device)
        return heaviside((hidden * w2).sum() + b2).item()

    def _eval_full_adder(self, prefix: str, a: float, b: float, cin: float) -> Tuple[float, float]:
        """Full adder built from two half-adders plus a carry OR; returns (sum, carry_out)."""
        ha1_sum = self._eval_xor(f"{prefix}.ha1.sum", [a, b])
        ha1_carry = self._eval_gate(f"{prefix}.ha1.carry.weight", f"{prefix}.ha1.carry.bias", [a, b])

        ha2_sum = self._eval_xor(f"{prefix}.ha2.sum", [ha1_sum, cin])
        ha2_carry = self._eval_gate(
            f"{prefix}.ha2.carry.weight", f"{prefix}.ha2.carry.bias", [ha1_sum, cin]
        )

        cout = self._eval_gate(f"{prefix}.carry_or.weight", f"{prefix}.carry_or.bias", [ha1_carry, ha2_carry])
        return ha2_sum, cout

    def add(self, a: int, b: int) -> Tuple[int, int, int]:
        """Ripple-carry 8-bit add; returns (result, carry_out, signed_overflow)."""
        # Adder chain runs LSB-first, so reverse the MSB-first bit lists.
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        carry = 0.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            sum_bit, carry = self._eval_full_adder(
                f"arithmetic.ripplecarry8bit.fa{bit}", float(a_bits[bit]), float(b_bits[bit]), carry
            )
            sum_bits.append(int(sum_bit))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        # Overflow computed arithmetically (not via gates), same form as alu_add.
        overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def sub(self, a: int, b: int) -> Tuple[int, int, int]:
        """8-bit subtract as a + NOT(b) + 1; returns (result, carry, signed_overflow)."""
        a_bits = bits_msb_to_lsb(int_to_bits(a, REG_BITS))
        b_bits = bits_msb_to_lsb(int_to_bits(b, REG_BITS))

        # The "+1" of two's complement enters as the initial carry-in.
        carry = 1.0
        sum_bits: List[int] = []
        for bit in range(REG_BITS):
            notb = self._eval_gate(
                f"arithmetic.sub8bit.notb{bit}.weight",
                f"arithmetic.sub8bit.notb{bit}.bias",
                [float(b_bits[bit])],
            )

            # Full adder spelled out: sum = a XOR notb XOR carry.
            xor1 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor1", [float(a_bits[bit]), notb])
            xor2 = self._eval_xor(f"arithmetic.sub8bit.fa{bit}.xor2", [xor1, carry])

            # carry_out = (a AND notb) OR (xor1 AND carry_in).
            and1 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and1.weight",
                f"arithmetic.sub8bit.fa{bit}.and1.bias",
                [float(a_bits[bit]), notb],
            )
            and2 = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.and2.weight",
                f"arithmetic.sub8bit.fa{bit}.and2.bias",
                [xor1, carry],
            )
            carry = self._eval_gate(
                f"arithmetic.sub8bit.fa{bit}.or_carry.weight",
                f"arithmetic.sub8bit.fa{bit}.or_carry.bias",
                [and1, and2],
            )

            sum_bits.append(int(xor2))

        result = bits_to_int(list(reversed(sum_bits)))
        carry_out = int(carry)
        overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
        return result, carry_out, overflow

    def bitwise_and(self, a: int, b: int) -> int:
        """Per-bit AND using 8 two-input threshold gates."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        # Weight layout assumed: flat vector, two weights per output bit —
        # TODO confirm against the model exporter.
        w = self._get("alu.alu8bit.and.weight")
        bias = self._get("alu.alu8bit.and.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_or(self, a: int, b: int) -> int:
        """Per-bit OR using 8 two-input threshold gates (same layout as bitwise_and)."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)
        w = self._get("alu.alu8bit.or.weight")
        bias = self._get("alu.alu8bit.or.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit * 2:bit * 2 + 2]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_not(self, a: int) -> int:
        """Per-bit NOT using 8 single-input threshold gates."""
        a_bits = int_to_bits(a, REG_BITS)
        w = self._get("alu.alu8bit.not.weight")
        bias = self._get("alu.alu8bit.not.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit])], device=self.device)
            out = heaviside((inp * w[bit]).sum() + bias[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def bitwise_xor(self, a: int, b: int) -> int:
        """Per-bit XOR via the two-layer OR/NAND construction (see _eval_xor)."""
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        w_or = self._get("alu.alu8bit.xor.layer1.or.weight")
        b_or = self._get("alu.alu8bit.xor.layer1.or.bias")
        w_nand = self._get("alu.alu8bit.xor.layer1.nand.weight")
        b_nand = self._get("alu.alu8bit.xor.layer1.nand.bias")
        w2 = self._get("alu.alu8bit.xor.layer2.weight")
        b2 = self._get("alu.alu8bit.xor.layer2.bias")

        out_bits = []
        for bit in range(REG_BITS):
            inp = torch.tensor([float(a_bits[bit]), float(b_bits[bit])], device=self.device)
            h_or = heaviside((inp * w_or[bit * 2:bit * 2 + 2]).sum() + b_or[bit])
            h_nand = heaviside((inp * w_nand[bit * 2:bit * 2 + 2]).sum() + b_nand[bit])
            hidden = torch.stack([h_or, h_nand])
            out = heaviside((hidden * w2[bit * 2:bit * 2 + 2]).sum() + b2[bit]).item()
            out_bits.append(int(out))

        return bits_to_int(out_bits)

    def shift_left(self, a: int) -> int:
        """Logical shift left by one: each output bit mirrors its right neighbor; LSB gets 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shl.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shl.bit{bit}.bias")
            if bit < 7:
                inp = torch.tensor([float(a_bits[bit + 1])], device=self.device)
            else:
                # Output LSB (MSB-first index 7) is fed a constant zero.
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def shift_right(self, a: int) -> int:
        """Logical shift right by one: each output bit mirrors its left neighbor; MSB gets 0."""
        a_bits = int_to_bits(a, REG_BITS)
        out_bits = []
        for bit in range(REG_BITS):
            w = self._get(f"alu.alu8bit.shr.bit{bit}.weight")
            bias = self._get(f"alu.alu8bit.shr.bit{bit}.bias")
            if bit > 0:
                inp = torch.tensor([float(a_bits[bit - 1])], device=self.device)
            else:
                # Output MSB (index 0) is fed a constant zero.
                inp = torch.tensor([0.0], device=self.device)
            out = heaviside((inp * w).sum() + bias).item()
            out_bits.append(int(out))
        return bits_to_int(out_bits)

    def multiply(self, a: int, b: int) -> int:
        """8-bit multiply using partial product AND gates + shift-add.

        Only the low byte of the product is returned (matching ref_step's
        `(a * b) & 0xFF`).
        """
        a_bits = int_to_bits(a, REG_BITS)
        b_bits = int_to_bits(b, REG_BITS)

        # pp[i][j] = a_bit[i] AND b_bit[j], each evaluated by its own gate.
        pp = [[0] * 8 for _ in range(8)]
        for i in range(8):
            for j in range(8):
                w = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.weight")
                bias = self._get(f"alu.alu8bit.mul.pp.a{i}b{j}.bias")
                inp = torch.tensor([float(a_bits[i]), float(b_bits[j])], device=self.device)
                pp[i][j] = int(heaviside((inp * w).sum() + bias).item())

        result = 0
        for j in range(8):
            if b_bits[j] == 0:
                continue
            # Assemble partial-product row j back into an integer (MSB-first).
            row = 0
            for i in range(8):
                row |= (pp[i][j] << (7 - i))
            shifted = row << (7 - j)
            result, _, _ = self.add(result & 0xFF, shifted & 0xFF)
            # NOTE(review): self.add masks to 0xFF, so `result > 255` can never
            # be true here; the `shifted >> 8` fold looks like a partial carry
            # approximation — verify intent against the training harness.
            if shifted > 255 or result > 255:
                result = (result + (shifted >> 8)) & 0xFF

        return result & 0xFF

    def divide(self, a: int, b: int) -> Tuple[int, int]:
        """8-bit divide using restoring division with threshold gates.

        Returns (quotient, remainder); divide-by-zero yields the sentinel
        (0xFF, a), matching ref_step's DIV behavior for the quotient.
        """
        if b == 0:
            return 0xFF, a

        a_bits = int_to_bits(a, REG_BITS)

        quotient = 0
        remainder = 0

        for stage in range(8):
            # Shift the next dividend bit into the running remainder.
            remainder = ((remainder << 1) | a_bits[stage]) & 0xFF

            rem_bits = int_to_bits(remainder, REG_BITS)
            div_bits = int_to_bits(b, REG_BITS)

            # cmp gate decides whether remainder >= divisor at this stage —
            # semantics assumed from the restoring-division structure; TODO confirm.
            w = self._get(f"alu.alu8bit.div.stage{stage}.cmp.weight")
            bias = self._get(f"alu.alu8bit.div.stage{stage}.cmp.bias")
            inp = torch.tensor([float(rem_bits[i]) for i in range(8)] +
                               [float(div_bits[i]) for i in range(8)], device=self.device)
            cmp_result = int(heaviside((inp * w).sum() + bias).item())

            if cmp_result:
                remainder, _, _ = self.sub(remainder, b)
                quotient = (quotient << 1) | 1
            else:
                quotient = quotient << 1

        return quotient & 0xFF, remainder & 0xFF
| |
|
| |
|
class ThresholdCPU:
    """CPU whose datapath (ALU ops, memory mux, branch mux) runs on threshold gates.

    Mirrors ref_step's instruction semantics, but every arithmetic/logic/memory
    operation is delegated to weight-evaluated circuits instead of Python
    arithmetic.
    """

    def __init__(self, model_path: str = MODEL_PATH, device: str = "cpu") -> None:
        self.device = device
        self.alu = ThresholdALU(model_path, device=device)

    def _addr_decode(self, addr: int) -> torch.Tensor:
        """Address decoder: maps a 16-bit address to per-byte selector lines.

        Presumably one-hot over all 65536 bytes, given how the read/write
        muxes use it — TODO confirm against the weight generator.
        """
        bits = torch.tensor(int_to_bits(addr, PC_BITS), device=self.device, dtype=torch.float32)
        w = self.alu._get("memory.addr_decode.weight")
        b = self.alu._get("memory.addr_decode.bias")
        return heaviside((w * bits).sum(dim=1) + b)

    def _memory_read(self, mem: List[int], addr: int) -> int:
        """Read one byte: AND each memory bit with its selector, then OR across bytes."""
        sel = self._addr_decode(addr)
        # NOTE(review): rebuilds a (MEM_BYTES, 8) tensor on every read —
        # correct but very slow for long programs.
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )
        and_w = self.alu._get("memory.read.and.weight")
        and_b = self.alu._get("memory.read.and.bias")
        or_w = self.alu._get("memory.read.or.weight")
        or_b = self.alu._get("memory.read.or.bias")

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            inp = torch.stack([mem_bits[:, bit], sel], dim=1)
            and_out = heaviside((inp * and_w[bit]).sum(dim=1) + and_b[bit])
            out_bit = heaviside((and_out * or_w[bit]).sum() + or_b[bit]).item()
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def _memory_write(self, mem: List[int], addr: int, value: int) -> List[int]:
        """Write one byte and return the new memory image.

        Per-bit mux over all bytes: new = (old AND not-selected) OR (data AND selected).
        """
        sel = self._addr_decode(addr)
        data_bits = torch.tensor(int_to_bits(value, REG_BITS), device=self.device, dtype=torch.float32)
        mem_bits = torch.tensor(
            [int_to_bits(byte, REG_BITS) for byte in mem],
            device=self.device,
            dtype=torch.float32,
        )

        sel_w = self.alu._get("memory.write.sel.weight")
        sel_b = self.alu._get("memory.write.sel.bias")
        nsel_w = self.alu._get("memory.write.nsel.weight").squeeze(1)
        nsel_b = self.alu._get("memory.write.nsel.bias")
        and_old_w = self.alu._get("memory.write.and_old.weight")
        and_old_b = self.alu._get("memory.write.and_old.bias")
        and_new_w = self.alu._get("memory.write.and_new.weight")
        and_new_b = self.alu._get("memory.write.and_new.bias")
        or_w = self.alu._get("memory.write.or.weight")
        or_b = self.alu._get("memory.write.or.bias")

        # Write-enable is held high; write_sel gates the decoded selector with it,
        # and nsel is its per-byte negation.
        we = torch.ones_like(sel)
        sel_inp = torch.stack([sel, we], dim=1)
        write_sel = heaviside((sel_inp * sel_w).sum(dim=1) + sel_b)
        nsel = heaviside((write_sel * nsel_w) + nsel_b)

        new_mem_bits = torch.zeros((MEM_BYTES, REG_BITS), device=self.device)
        for bit in range(REG_BITS):
            old_bit = mem_bits[:, bit]
            data_bit = data_bits[bit].expand(MEM_BYTES)
            inp_old = torch.stack([old_bit, nsel], dim=1)
            inp_new = torch.stack([data_bit, write_sel], dim=1)

            and_old = heaviside((inp_old * and_old_w[:, bit]).sum(dim=1) + and_old_b[:, bit])
            and_new = heaviside((inp_new * and_new_w[:, bit]).sum(dim=1) + and_new_b[:, bit])
            or_inp = torch.stack([and_old, and_new], dim=1)
            out_bit = heaviside((or_inp * or_w[:, bit]).sum(dim=1) + or_b[:, bit])
            new_mem_bits[:, bit] = out_bit

        return [bits_to_int([int(b) for b in new_mem_bits[i].tolist()]) for i in range(MEM_BYTES)]

    def _conditional_jump_byte(self, prefix: str, pc_byte: int, target_byte: int, flag: int) -> int:
        """Per-bit 2:1 mux on one PC byte: flag selects target_byte over pc_byte.

        out_bit = (pc_bit AND NOT flag) OR (target_bit AND flag).
        """
        pc_bits = int_to_bits(pc_byte, REG_BITS)
        target_bits = int_to_bits(target_byte, REG_BITS)

        out_bits: List[int] = []
        for bit in range(REG_BITS):
            not_sel = self.alu._eval_gate(
                f"{prefix}.bit{bit}.not_sel.weight",
                f"{prefix}.bit{bit}.not_sel.bias",
                [float(flag)],
            )
            and_a = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_a.weight",
                f"{prefix}.bit{bit}.and_a.bias",
                [float(pc_bits[bit]), not_sel],
            )
            and_b = self.alu._eval_gate(
                f"{prefix}.bit{bit}.and_b.weight",
                f"{prefix}.bit{bit}.and_b.bias",
                [float(target_bits[bit]), float(flag)],
            )
            out_bit = self.alu._eval_gate(
                f"{prefix}.bit{bit}.or.weight",
                f"{prefix}.bit{bit}.or.bias",
                [and_a, and_b],
            )
            out_bits.append(int(out_bit))

        return bits_to_int(out_bits)

    def step(self, state: CPUState) -> CPUState:
        """Single CPU cycle using threshold neurons.

        Mirrors ref_step(): fetch, optional 16-bit operand fetch, execute,
        flag update, PC advance.  The input state is never mutated; a halted
        state is returned unchanged as a copy.
        """
        if state.ctrl[0] == 1:
            return state.copy()

        s = state.copy()

        # Fetch the 16-bit big-endian instruction word via the memory circuit.
        hi = self._memory_read(s.mem, s.pc)
        lo = self._memory_read(s.mem, (s.pc + 1) & 0xFFFF)
        s.ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        next_pc = (s.pc + 2) & 0xFFFF

        opcode, rd, rs, imm8 = decode_ir(s.ir)
        a = s.regs[rd]
        b = s.regs[rs]

        # Opcodes 0xA-0xE carry a 16-bit operand word after the instruction.
        addr16 = None
        next_pc_ext = next_pc
        if opcode in (0xA, 0xB, 0xC, 0xD, 0xE):
            addr_hi = self._memory_read(s.mem, next_pc)
            addr_lo = self._memory_read(s.mem, (next_pc + 1) & 0xFFFF)
            addr16 = ((addr_hi & 0xFF) << 8) | (addr_lo & 0xFF)
            next_pc_ext = (next_pc + 2) & 0xFFFF

        write_result = True
        result = a
        carry = 0
        overflow = 0

        if opcode == 0x0:            # ADD
            result, carry, overflow = self.alu.add(a, b)
        elif opcode == 0x1:          # SUB
            result, carry, overflow = self.alu.sub(a, b)
        elif opcode == 0x2:          # AND
            result = self.alu.bitwise_and(a, b)
        elif opcode == 0x3:          # OR
            result = self.alu.bitwise_or(a, b)
        elif opcode == 0x4:          # XOR
            result = self.alu.bitwise_xor(a, b)
        elif opcode == 0x5:          # SHL
            result = self.alu.shift_left(a)
        elif opcode == 0x6:          # SHR
            result = self.alu.shift_right(a)
        elif opcode == 0x7:          # MUL (low byte)
            result = self.alu.multiply(a, b)
        elif opcode == 0x8:          # DIV (remainder discarded)
            result, _ = self.alu.divide(a, b)
        elif opcode == 0x9:          # CMP: subtract, flags only
            result, carry, overflow = self.alu.sub(a, b)
            write_result = False
        elif opcode == 0xA:          # LOAD rd, [addr16]
            result = self._memory_read(s.mem, addr16)
        elif opcode == 0xB:          # STORE [addr16], rs
            s.mem = self._memory_write(s.mem, addr16, b & 0xFF)
            write_result = False
        elif opcode == 0xC:          # JMP addr16
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xD:          # Conditional jump
            cond_type = imm8 & 0x7
            # Each entry: (mux circuit prefix, index of the flag fed to it).
            # NOTE(review): the raw flag value is passed for both polarities
            # (e.g. jz and jnz both receive flags[0]); the inversion is
            # presumably baked into the jn*/jp circuit weights — verify.
            cond_circuits = [
                ("control.jz", 0),
                ("control.jnz", 0),
                ("control.jc", 2),
                ("control.jnc", 2),
                ("control.jn", 1),
                ("control.jp", 1),
                ("control.jv", 3),
                ("control.jnv", 3),
            ]
            circuit_prefix, flag_idx = cond_circuits[cond_type]
            hi_pc = self._conditional_jump_byte(
                circuit_prefix,
                (next_pc_ext >> 8) & 0xFF,
                (addr16 >> 8) & 0xFF,
                s.flags[flag_idx],
            )
            lo_pc = self._conditional_jump_byte(
                circuit_prefix,
                next_pc_ext & 0xFF,
                addr16 & 0xFF,
                s.flags[flag_idx],
            )
            s.pc = ((hi_pc & 0xFF) << 8) | (lo_pc & 0xFF)
            write_result = False
        elif opcode == 0xE:          # CALL: push return address big-endian, jump
            ret_addr = next_pc_ext & 0xFFFF
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, (ret_addr >> 8) & 0xFF)
            s.sp = (s.sp - 1) & 0xFFFF
            s.mem = self._memory_write(s.mem, s.sp, ret_addr & 0xFF)
            s.pc = addr16 & 0xFFFF
            write_result = False
        elif opcode == 0xF:          # HALT
            s.ctrl[0] = 1
            write_result = False

        # ALU ops (0x0-0x9) and LOAD (0xA) update the flags, same as ref_step.
        if opcode <= 0x9 or opcode == 0xA:
            s.flags = list(flags_from_result(result, carry, overflow))

        if write_result:
            s.regs[rd] = result & 0xFF

        # Jumps, branches, and calls set PC themselves.
        if opcode not in (0xC, 0xD, 0xE):
            s.pc = next_pc_ext

        return s

    def run_until_halt(self, state: CPUState, max_cycles: int = 256) -> Tuple[CPUState, int]:
        """Execute until HALT or max_cycles reached.

        Returns (final_state, cycles_executed); the cycle count equals
        max_cycles when the program never halts.
        """
        s = state.copy()
        for i in range(max_cycles):
            if s.ctrl[0] == 1:
                return s, i
            s = self.step(s)
        return s, max_cycles

    def forward(self, state_bits: torch.Tensor, max_cycles: int = 256) -> torch.Tensor:
        """Tensor-in, tensor-out interface for neural integration.

        Expects a flat tensor of STATE_BITS 0/1 values (see pack_state) and
        returns the packed final state as a float32 CPU tensor.
        """
        bits_list = [int(b) for b in state_bits.detach().cpu().flatten().tolist()]
        state = unpack_state(bits_list)
        final, _ = self.run_until_halt(state, max_cycles=max_cycles)
        return torch.tensor(pack_state(final), dtype=torch.float32)
| |
|
| |
|
def encode_instr(opcode: int, rd: int, rs: int, imm8: int) -> int:
    """Pack fields into a 16-bit instruction word: oooo ddss iiiiiiii."""
    word = (opcode & 0xF) << 12
    word |= (rd & 0x3) << 10
    word |= (rs & 0x3) << 8
    word |= imm8 & 0xFF
    return word
| |
|
| |
|
def write_instr(mem: List[int], addr: int, instr: int) -> None:
    """Store a 16-bit instruction word big-endian at addr (wrapping at 64 KiB)."""
    base = addr & 0xFFFF
    mem[base] = (instr >> 8) & 0xFF
    mem[(base + 1) & 0xFFFF] = instr & 0xFF
| |
|
| |
|
def write_addr(mem: List[int], addr: int, value: int) -> None:
    """Store a 16-bit value big-endian at addr (wrapping at 64 KiB)."""
    hi, lo = (value >> 8) & 0xFF, value & 0xFF
    mem[addr & 0xFFFF] = hi
    mem[(addr + 1) & 0xFFFF] = lo
| |
|
| |
|
def run_smoke_test() -> int:
    """Smoke test: LOAD 5, LOAD 7, ADD, STORE, HALT. Expect result = 12.

    Runs the program three ways — the pure-Python reference, the
    threshold-weight CPU, and the forward() tensor interface — and asserts
    they all agree.  Requires the model file at MODEL_PATH.  Returns 0 on
    success (asserts on failure).
    """
    mem = [0] * 65536

    # Program at 0x0000:
    #   LOAD R0, [0x0100]   (opcode 0xA + 16-bit operand word)
    #   LOAD R1, [0x0101]
    #   ADD  R0, R1
    #   STORE [0x0102], R0  (rs == 0, so R0 is stored)
    #   HALT
    write_instr(mem, 0x0000, encode_instr(0xA, 0, 0, 0x00))
    write_addr(mem, 0x0002, 0x0100)
    write_instr(mem, 0x0004, encode_instr(0xA, 1, 0, 0x00))
    write_addr(mem, 0x0006, 0x0101)
    write_instr(mem, 0x0008, encode_instr(0x0, 0, 1, 0x00))
    write_instr(mem, 0x000A, encode_instr(0xB, 0, 0, 0x00))
    write_addr(mem, 0x000C, 0x0102)
    write_instr(mem, 0x000E, encode_instr(0xF, 0, 0, 0x00))

    # Input data: 5 + 7 = 12.
    mem[0x0100] = 5
    mem[0x0101] = 7

    state = CPUState(
        pc=0,
        ir=0,
        regs=[0, 0, 0, 0],
        flags=[0, 0, 0, 0],
        sp=0xFFFE,  # stack grows downward from the top of memory
        ctrl=[0, 0, 0, 0],
        mem=mem,
    )

    print("Running reference implementation...")
    final, cycles = ref_run_until_halt(state, max_cycles=20)

    assert final.ctrl[0] == 1, "HALT flag not set"
    assert final.regs[0] == 12, f"R0 expected 12, got {final.regs[0]}"
    assert final.mem[0x0102] == 12, f"MEM[0x0102] expected 12, got {final.mem[0x0102]}"
    assert cycles <= 10, f"Unexpected cycle count: {cycles}"
    print(f" Reference: R0={final.regs[0]}, MEM[0x0102]={final.mem[0x0102]}, cycles={cycles}")

    # The threshold CPU must reproduce the reference result exactly.
    print("Running threshold-weight implementation...")
    threshold_cpu = ThresholdCPU()
    t_final, t_cycles = threshold_cpu.run_until_halt(state, max_cycles=20)

    assert t_final.ctrl[0] == 1, "Threshold HALT flag not set"
    assert t_final.regs[0] == final.regs[0], f"Threshold R0 mismatch: {t_final.regs[0]} != {final.regs[0]}"
    assert t_final.mem[0x0102] == final.mem[0x0102], (
        f"Threshold MEM[0x0102] mismatch: {t_final.mem[0x0102]} != {final.mem[0x0102]}"
    )
    assert t_cycles == cycles, f"Threshold cycle count mismatch: {t_cycles} != {cycles}"
    print(f" Threshold: R0={t_final.regs[0]}, MEM[0x0102]={t_final.mem[0x0102]}, cycles={t_cycles}")

    # Round-trip through the tensor interface as well.
    print("Validating forward() tensor I/O...")
    bits = torch.tensor(pack_state(state), dtype=torch.float32)
    out_bits = threshold_cpu.forward(bits, max_cycles=20)
    out_state = unpack_state([int(b) for b in out_bits.tolist()])
    assert out_state.regs[0] == final.regs[0], f"Forward R0 mismatch: {out_state.regs[0]} != {final.regs[0]}"
    assert out_state.mem[0x0102] == final.mem[0x0102], (
        f"Forward MEM[0x0102] mismatch: {out_state.mem[0x0102]} != {final.mem[0x0102]}"
    )
    print(f" Forward: R0={out_state.regs[0]}, MEM[0x0102]={out_state.mem[0x0102]}")

    print("\nSmoke test: PASSED")
    return 0
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | class BatchedFitnessEvaluator: |
| | """ |
| | GPU-batched fitness evaluator with per-circuit reporting. |
| | Tests all circuits comprehensively. |
| | """ |
| |
|
    def __init__(self, device: str = 'cuda', model_path: str = MODEL_PATH, tensors: Optional[Dict[str, torch.Tensor]] = None):
        """Initialize the evaluator and pre-compute all test vectors.

        Args:
            device: Torch device string for the test tensors ('cuda' or 'cpu').
            model_path: Path to the safetensors model; metadata (including the
                signal_registry) is always read from this file.
            tensors: Optional pre-loaded tensor dict. When given, the manifest
                is derived from it and the model file is not re-read.
        """
        self.device = device
        self.model_path = model_path
        self.metadata = load_metadata(model_path)
        self.signal_registry = self.metadata.get('signal_registry', {})
        # Per-circuit results accumulated by _record() during an evaluation run.
        self.results: List[CircuitResult] = []
        # Category name -> (score, total); filled in by later test passes.
        self.category_scores: Dict[str, Tuple[float, int]] = {}
        self.total_tests = 0

        # The manifest describes circuit layout (data/addr bit widths).
        # Prefer the caller-supplied tensors to avoid a second disk load.
        if tensors is not None:
            self.manifest = get_manifest(tensors)
        else:
            base_tensors = load_model(model_path)
            self.manifest = get_manifest(base_tensors)
        self.data_bits = self.manifest['data_bits']
        self.addr_bits = self.manifest['addr_bits']

        # Build all device-resident test vectors once, up front.
        self._setup_tests()
| |
|
    def _setup_tests(self):
        """Pre-compute test vectors on device.

        Builds truth tables, expected gate outputs, representative 8-bit and
        32-bit operand sets, and comparator (a, b) pairs as device tensors so
        later test methods can run without host-side allocation.
        """
        d = self.device

        # 2-input truth table; row order 00, 01, 10, 11 (matches self.expected).
        self.tt2 = torch.tensor(
            [[0, 0], [0, 1], [1, 0], [1, 1]],
            device=d, dtype=torch.float32
        )

        # 3-input truth table (all 8 combinations) for the full adder.
        self.tt3 = torch.tensor([
            [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
            [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]
        ], device=d, dtype=torch.float32)

        # Expected outputs per gate, aligned row-for-row with tt2 (2-input
        # gates), not_inputs ('not'), and tt3 ('fa_*').
        self.expected = {
            'and': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32),
            'or': torch.tensor([0, 1, 1, 1], device=d, dtype=torch.float32),
            'nand': torch.tensor([1, 1, 1, 0], device=d, dtype=torch.float32),
            'nor': torch.tensor([1, 0, 0, 0], device=d, dtype=torch.float32),
            'xor': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32),
            'xnor': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32),
            'implies': torch.tensor([1, 1, 0, 1], device=d, dtype=torch.float32),
            'biimplies': torch.tensor([1, 0, 0, 1], device=d, dtype=torch.float32),
            'not': torch.tensor([1, 0], device=d, dtype=torch.float32),
            'ha_sum': torch.tensor([0, 1, 1, 0], device=d, dtype=torch.float32),
            'ha_carry': torch.tensor([0, 0, 0, 1], device=d, dtype=torch.float32),
            'fa_sum': torch.tensor([0, 1, 1, 0, 1, 0, 0, 1], device=d, dtype=torch.float32),
            'fa_cout': torch.tensor([0, 0, 0, 1, 0, 1, 1, 1], device=d, dtype=torch.float32),
        }

        # Inputs for the unary NOT gate.
        self.not_inputs = torch.tensor([[0], [1]], device=d, dtype=torch.float32)

        # Representative 8-bit values: powers of two +/- 1 plus alternating /
        # block bit patterns.
        self.test_8bit = torch.tensor([
            0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,
            0b10101010, 0b01010101, 0b11110000, 0b00001111,
            0b11001100, 0b00110011, 0b10000001, 0b01111110
        ], device=d, dtype=torch.long)

        # The same values unpacked MSB-first into an [N, 8] bit matrix.
        self.test_8bit_bits = torch.stack([
            ((self.test_8bit >> (7 - i)) & 1).float() for i in range(8)
        ], dim=1)

        # (a, b) pairs for the 8-bit comparators: equal pairs, off-by-one
        # near-misses, boundary values, and complementary bit masks.
        comp_tests = [
            (0, 0), (1, 0), (0, 1), (5, 3), (3, 5), (5, 5),
            (255, 0), (0, 255), (128, 127), (127, 128),
            (100, 99), (99, 100), (64, 32), (32, 64),
            (1, 1), (254, 255), (255, 254), (128, 128),
            (0, 128), (128, 0), (64, 64), (192, 192),
            (15, 16), (16, 15), (240, 239), (239, 240),
            (85, 170), (170, 85), (0xAA, 0x55), (0x55, 0xAA),
            (0x0F, 0xF0), (0xF0, 0x0F), (0x33, 0xCC), (0xCC, 0x33),
            (2, 3), (3, 2), (126, 127), (127, 126),
            (129, 128), (128, 129), (200, 199), (199, 200),
            (50, 51), (51, 50), (10, 20), (20, 10),
            (100, 100), (200, 200), (77, 77), (0, 0)
        ]
        self.comp_a = torch.tensor([c[0] for c in comp_tests], device=d, dtype=torch.long)
        self.comp_b = torch.tensor([c[1] for c in comp_tests], device=d, dtype=torch.long)

        # Every 8-bit value, 0..255.
        self.mod_test = torch.arange(256, device=d, dtype=torch.long)

        # 32-bit edge values: word boundaries, the sign bit, well-known
        # constants, and alternating patterns.
        self.test_32bit = torch.tensor([
            0, 1, 2, 255, 256, 65535, 65536,
            0x7FFFFFFF, 0x80000000, 0xFFFFFFFF,
            0x12345678, 0xDEADBEEF, 0xCAFEBABE,
            1000000, 1000000000, 2147483647,
            0x55555555, 0xAAAAAAAA, 0x0F0F0F0F, 0xF0F0F0F0
        ], device=d, dtype=torch.long)

        # (a, b) pairs for the 32-bit comparators.
        comp32_tests = [
            (0, 0), (1, 0), (0, 1), (1000, 999), (999, 1000),
            (0xFFFFFFFF, 0), (0, 0xFFFFFFFF),
            (0x80000000, 0x7FFFFFFF), (0x7FFFFFFF, 0x80000000),
            (1000000, 1000000), (0x12345678, 0x12345678),
            (0xDEADBEEF, 0xCAFEBABE), (0xCAFEBABE, 0xDEADBEEF),
            (256, 255), (255, 256), (65536, 65535), (65535, 65536),
        ]
        self.comp32_a = torch.tensor([c[0] for c in comp32_tests], device=d, dtype=torch.long)
        self.comp32_b = torch.tensor([c[1] for c in comp32_tests], device=d, dtype=torch.long)
| |
|
| | def _record(self, name: str, passed: int, total: int, failures: List[Tuple] = None): |
| | """Record a circuit test result.""" |
| | self.results.append(CircuitResult( |
| | name=name, |
| | passed=passed, |
| | total=total, |
| | failures=failures or [] |
| | )) |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_single_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor, |
| | expected: torch.Tensor) -> torch.Tensor: |
| | """Test single-layer gate (AND, OR, NAND, NOR, IMPLIES).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | w = pop[f'{prefix}.weight'] |
| | b = pop[f'{prefix}.bias'] |
| |
|
| | |
| | out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): |
| | if exp.item() != got.item(): |
| | failures.append((inp.tolist(), exp.item(), got.item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), len(expected), failures) |
| | return correct |
| |
|
| | def _test_twolayer_gate(self, pop: Dict, prefix: str, inputs: torch.Tensor, |
| | expected: torch.Tensor) -> torch.Tensor: |
| | """Test two-layer gate (XOR, XNOR, BIIMPLIES).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | |
| | w1_n1 = pop[f'{prefix}.layer1.neuron1.weight'] |
| | b1_n1 = pop[f'{prefix}.layer1.neuron1.bias'] |
| | w1_n2 = pop[f'{prefix}.layer1.neuron2.weight'] |
| | b1_n2 = pop[f'{prefix}.layer1.neuron2.bias'] |
| |
|
| | h1 = heaviside(inputs @ w1_n1.view(pop_size, -1).T + b1_n1.view(pop_size)) |
| | h2 = heaviside(inputs @ w1_n2.view(pop_size, -1).T + b1_n2.view(pop_size)) |
| | hidden = torch.stack([h1, h2], dim=-1) |
| |
|
| | |
| | w2 = pop[f'{prefix}.layer2.weight'] |
| | b2 = pop[f'{prefix}.layer2.bias'] |
| | out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): |
| | if exp.item() != got.item(): |
| | failures.append((inp.tolist(), exp.item(), got.item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), len(expected), failures) |
| | return correct |
| |
|
| | def _test_xor_ornand(self, pop: Dict, prefix: str, inputs: torch.Tensor, |
| | expected: torch.Tensor) -> torch.Tensor: |
| | """Test XOR with or/nand layer naming.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | w_or = pop[f'{prefix}.layer1.or.weight'] |
| | b_or = pop[f'{prefix}.layer1.or.bias'] |
| | w_nand = pop[f'{prefix}.layer1.nand.weight'] |
| | b_nand = pop[f'{prefix}.layer1.nand.bias'] |
| |
|
| | h_or = heaviside(inputs @ w_or.view(pop_size, -1).T + b_or.view(pop_size)) |
| | h_nand = heaviside(inputs @ w_nand.view(pop_size, -1).T + b_nand.view(pop_size)) |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| |
|
| | w2 = pop[f'{prefix}.layer2.weight'] |
| | b2 = pop[f'{prefix}.layer2.bias'] |
| | out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i, (inp, exp, got) in enumerate(zip(inputs, expected, out[:, 0])): |
| | if exp.item() != got.item(): |
| | failures.append((inp.tolist(), exp.item(), got.item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), len(expected), failures) |
| | return correct |
| |
|
| | def _test_boolean_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test all boolean gates.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== BOOLEAN GATES ===") |
| |
|
| | |
| | for gate in ['and', 'or', 'nand', 'nor', 'implies']: |
| | scores += self._test_single_gate(pop, f'boolean.{gate}', self.tt2, self.expected[gate]) |
| | total += 4 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | |
| | w = pop['boolean.not.weight'] |
| | b = pop['boolean.not.bias'] |
| | out = heaviside(self.not_inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| | correct = (out == self.expected['not'].unsqueeze(1)).float().sum(0) |
| | scores += correct |
| | total += 2 |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for inp, exp, got in zip(self.not_inputs, self.expected['not'], out[:, 0]): |
| | if exp.item() != got.item(): |
| | failures.append((inp.tolist(), exp.item(), got.item())) |
| | self._record('boolean.not', int(correct[0].item()), 2, failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | |
| | for gate in ['xnor', 'biimplies']: |
| | scores += self._test_twolayer_gate(pop, f'boolean.{gate}', self.tt2, self.expected.get(gate, self.expected['xnor'])) |
| | total += 4 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | |
| | scores += self._test_twolayer_gate(pop, 'boolean.xor', self.tt2, self.expected['xor']) |
| | total += 4 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _eval_xor(self, pop: Dict, prefix: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
| | """Evaluate XOR gate with or/nand decomposition. |
| | |
| | Args: |
| | a, b: Tensors of shape [num_tests] or [num_tests, pop_size] |
| | |
| | Returns: |
| | Tensor of shape [num_tests, pop_size] |
| | """ |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | |
| | if a.dim() == 1: |
| | a = a.unsqueeze(1).expand(-1, pop_size) |
| | if b.dim() == 1: |
| | b = b.unsqueeze(1).expand(-1, pop_size) |
| |
|
| | |
| | inputs = torch.stack([a, b], dim=-1) |
| |
|
| | w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, 2) |
| | b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size) |
| | w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, 2) |
| | b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size) |
| |
|
| | |
| | h_or = heaviside((inputs * w_or).sum(-1) + b_or) |
| | h_nand = heaviside((inputs * w_nand).sum(-1) + b_nand) |
| |
|
| | |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| |
|
| | w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, 2) |
| | b2 = pop[f'{prefix}.layer2.bias'].view(pop_size) |
| | return heaviside((hidden * w2).sum(-1) + b2) |
| |
|
| | def _eval_single_fa(self, pop: Dict, prefix: str, |
| | a: torch.Tensor, b: torch.Tensor, cin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| | """Evaluate single full adder. |
| | |
| | Args: |
| | a, b, cin: Tensors of shape [num_tests] or [num_tests, pop_size] |
| | |
| | Returns: |
| | sum_out, cout: Both of shape [num_tests, pop_size] |
| | """ |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | |
| | if a.dim() == 1: |
| | a = a.unsqueeze(1).expand(-1, pop_size) |
| | if b.dim() == 1: |
| | b = b.unsqueeze(1).expand(-1, pop_size) |
| | if cin.dim() == 1: |
| | cin = cin.unsqueeze(1).expand(-1, pop_size) |
| |
|
| | |
| | ha1_sum = self._eval_xor(pop, f'{prefix}.ha1.sum', a, b) |
| |
|
| | |
| | ab = torch.stack([a, b], dim=-1) |
| | w_c1 = pop[f'{prefix}.ha1.carry.weight'].view(pop_size, 2) |
| | b_c1 = pop[f'{prefix}.ha1.carry.bias'].view(pop_size) |
| | ha1_carry = heaviside((ab * w_c1).sum(-1) + b_c1) |
| |
|
| | |
| | ha2_sum = self._eval_xor(pop, f'{prefix}.ha2.sum', ha1_sum, cin) |
| |
|
| | |
| | sc = torch.stack([ha1_sum, cin], dim=-1) |
| | w_c2 = pop[f'{prefix}.ha2.carry.weight'].view(pop_size, 2) |
| | b_c2 = pop[f'{prefix}.ha2.carry.bias'].view(pop_size) |
| | ha2_carry = heaviside((sc * w_c2).sum(-1) + b_c2) |
| |
|
| | |
| | carries = torch.stack([ha1_carry, ha2_carry], dim=-1) |
| | w_cout = pop[f'{prefix}.carry_or.weight'].view(pop_size, 2) |
| | b_cout = pop[f'{prefix}.carry_or.bias'].view(pop_size) |
| | cout = heaviside((carries * w_cout).sum(-1) + b_cout) |
| |
|
| | return ha2_sum, cout |
| |
|
| | def _test_halfadder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test half adder.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== HALF ADDER ===") |
| |
|
| | |
| | scores += self._test_xor_ornand(pop, 'arithmetic.halfadder.sum', self.tt2, self.expected['ha_sum']) |
| | total += 4 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | |
| | scores += self._test_single_gate(pop, 'arithmetic.halfadder.carry', self.tt2, self.expected['ha_carry']) |
| | total += 4 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return scores, total |
| |
|
| | def _test_fulladder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test full adder with all 8 input combinations.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | if debug: |
| | print("\n=== FULL ADDER ===") |
| |
|
| | a = self.tt3[:, 0] |
| | b = self.tt3[:, 1] |
| | cin = self.tt3[:, 2] |
| |
|
| | sum_out, cout = self._eval_single_fa(pop, 'arithmetic.fulladder', a, b, cin) |
| |
|
| | sum_correct = (sum_out == self.expected['fa_sum'].unsqueeze(1)).float().sum(0) |
| | cout_correct = (cout == self.expected['fa_cout'].unsqueeze(1)).float().sum(0) |
| |
|
| | failures_sum = [] |
| | failures_cout = [] |
| | if pop_size == 1: |
| | for i in range(8): |
| | if sum_out[i, 0].item() != self.expected['fa_sum'][i].item(): |
| | failures_sum.append(([a[i].item(), b[i].item(), cin[i].item()], |
| | self.expected['fa_sum'][i].item(), sum_out[i, 0].item())) |
| | if cout[i, 0].item() != self.expected['fa_cout'][i].item(): |
| | failures_cout.append(([a[i].item(), b[i].item(), cin[i].item()], |
| | self.expected['fa_cout'][i].item(), cout[i, 0].item())) |
| |
|
| | self._record('arithmetic.fulladder.sum', int(sum_correct[0].item()), 8, failures_sum) |
| | self._record('arithmetic.fulladder.cout', int(cout_correct[0].item()), 8, failures_cout) |
| |
|
| | if debug: |
| | for r in self.results[-2:]: |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return sum_correct + cout_correct, 16 |
| |
|
    def _test_ripplecarry(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit ripple carry adder.

        For bits <= 4 the full a x b product space is tested; for wider adders
        a reduced set of edge-value pairs is used.  Scores result against
        (a + b) mod 2^bits (wraparound, carry-out discarded).
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== RIPPLE CARRY {bits}-BIT ===")

        prefix = f'arithmetic.ripplecarry{bits}bit'
        max_val = 1 << bits
        # NOTE(review): this value is only used on the bits <= 4 path; the
        # else branch overwrites it with len(pairs) below.
        num_tests = min(max_val * max_val, 65536)

        if bits <= 4:
            # Exhaustive: every (a, b) pair via a cartesian meshgrid.
            test_a = torch.arange(max_val, device=self.device)
            test_b = torch.arange(max_val, device=self.device)
            a_vals, b_vals = torch.meshgrid(test_a, test_b, indexing='ij')
            a_vals = a_vals.flatten()
            b_vals = b_vals.flatten()
        else:
            # Reduced set: edge-value cross product plus complementary pairs
            # (i, 255 - i); de-duplicated (order is not preserved by set()).
            edge_vals = [0, 1, 2, 127, 128, 254, 255]
            pairs = [(a, b) for a in edge_vals for b in edge_vals]
            for i in range(0, 256, 16):
                pairs.append((i, 255 - i))
            pairs = list(set(pairs))
            a_vals = torch.tensor([p[0] for p in pairs], device=self.device)
            b_vals = torch.tensor([p[1] for p in pairs], device=self.device)
            num_tests = len(pairs)

        # Unpack operands MSB-first: column 0 is the most significant bit.
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Ripple the carry from LSB (fa0) upward through the chain.
        carry = torch.zeros(len(a_vals), pop_size, device=self.device)
        sum_bits = []

        for bit in range(bits):
            # fa{bit} consumes bit position `bit` counted from the LSB, so
            # index the MSB-first matrices from the right.
            bit_idx = bits - 1 - bit
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.fa{bit}',
                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry
            )
            sum_bits.append(s)

        # sum_bits was appended LSB-first; reverse to MSB-first, then convert
        # the bit columns back into integer values.
        sum_bits = torch.stack(sum_bits[::-1], dim=-1)
        result = torch.zeros(len(a_vals), pop_size, device=self.device)
        for i in range(bits):
            result += sum_bits[:, :, i] * (1 << (bits - 1 - i))

        # Expected: wraparound addition modulo 2^bits (final carry dropped).
        expected = ((a_vals + b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Collect at most 100 failure tuples for single-candidate runs.
        failures = []
        if pop_size == 1:
            for i in range(min(len(a_vals), 100)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, num_tests
| |
|
| | |
| | |
| | |
| |
|
    def _test_add3(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test 3-operand 8-bit adder (A + B + C).

        Built as two cascaded 8-bit ripple stages: stage1 computes A + B,
        stage2 adds C.  Scored against (A + B + C) mod 256.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== 3-OPERAND ADDER ===")

        prefix = 'arithmetic.add3_8bit'
        bits = 8

        # Test set: small exhaustive cube, edge-value cube, and hand-picked
        # carry-propagation cases; de-duplicated via set().
        test_cases = []
        # Small values: full 3x3x3 cube.
        for a in [0, 1, 2]:
            for b in [0, 1, 2]:
                for c in [0, 1, 2]:
                    test_cases.append((a, b, c))
        # Edge values: full 6x6x6 cube around the 8-bit boundaries.
        edge = [0, 1, 127, 128, 254, 255]
        for a in edge:
            for b in edge:
                for c in edge:
                    test_cases.append((a, b, c))
        # Targeted carry cases (sums straddling 255/256 and beyond).
        test_cases.extend([
            (15, 27, 33),
            (100, 100, 55),
            (100, 100, 56),
            (85, 85, 85),
            (86, 85, 85),
        ])
        test_cases = list(set(test_cases))

        a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
        b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
        c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
        num_tests = len(test_cases)

        # Unpack operands MSB-first: column 0 is the most significant bit.
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Stage 1: ripple-add A + B, LSB (fa0) upward.
        carry1 = torch.zeros(num_tests, pop_size, device=self.device)
        stage1_bits = []
        for bit in range(bits):
            bit_idx = bits - 1 - bit
            s, carry1 = self._eval_single_fa(
                pop, f'{prefix}.stage1.fa{bit}',
                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry1
            )
            stage1_bits.append(s)

        # Stage 2: add C to the stage-1 sum.  stage1_bits[bit] is already
        # LSB-first, so it pairs directly with fa{bit}.
        carry2 = torch.zeros(num_tests, pop_size, device=self.device)
        result_bits = []
        for bit in range(bits):
            bit_idx = bits - 1 - bit
            s, carry2 = self._eval_single_fa(
                pop, f'{prefix}.stage2.fa{bit}',
                stage1_bits[bit],
                c_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry2
            )
            result_bits.append(s)

        # Reverse LSB-first bit list to MSB-first and convert to integers.
        result_bits = torch.stack(result_bits[::-1], dim=-1)
        result = torch.zeros(num_tests, pop_size, device=self.device)
        for i in range(bits):
            result += result_bits[:, :, i] * (1 << (bits - 1 - i))

        # Expected: wraparound sum modulo 256 (carries out of bit 7 dropped).
        expected = ((a_vals + b_vals + c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Collect at most 100 failure tuples for single-candidate runs.
        failures = []
        if pop_size == 1:
            for i in range(min(num_tests, 100)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            if failures:
                for inp, exp, got in failures[:5]:
                    print(f" FAIL: {inp[0]} + {inp[1]} + {inp[2]} = {exp}, got {got}")

        return correct, num_tests
| |
|
| | |
| | |
| | |
| |
|
    def _test_expr_add_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test A + B x C expression circuit (order of operations).

        The multiply is a shift-and-add: 8 mask stages compute the partial
        products B AND C[stage], the accumulator ripple-adds each shifted
        partial product, then a final adder chain adds A.  Scored against
        (A + B * C) mod 256.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== ORDER OF OPERATIONS (A + B × C) ===")

        prefix = 'arithmetic.expr_add_mul'
        bits = 8

        # Test set: hand-picked products, boundary values, and a small
        # exhaustive cube; de-duplicated via set().
        test_cases = []

        # Products exercising multi-bit partial sums.
        test_cases.extend([
            (5, 3, 2),
            (10, 4, 3),
            (1, 10, 10),
            (0, 15, 17),
            (1, 15, 17),
            (100, 5, 5),
        ])

        # Boundary cases: zeros, 255 operands, identity-like products.
        test_cases.extend([
            (0, 0, 0),
            (255, 0, 0),
            (0, 255, 1),
            (0, 1, 255),
            (1, 1, 1),
            (0, 16, 16),
        ])

        # Small exhaustive cube over low operand values.
        for a in [0, 1, 5, 10]:
            for b in [0, 1, 2, 3]:
                for c in [0, 1, 2, 3]:
                    test_cases.append((a, b, c))

        # De-duplicate (order is not preserved by set()).
        test_cases = list(set(test_cases))

        a_vals = torch.tensor([t[0] for t in test_cases], device=self.device)
        b_vals = torch.tensor([t[1] for t in test_cases], device=self.device)
        c_vals = torch.tensor([t[2] for t in test_cases], device=self.device)
        num_tests = len(test_cases)

        # Unpack operands MSB-first: column 0 is the most significant bit.
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        c_bits = torch.stack([((c_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Partial products: masks[stage, :, :, bit] = B[bit] AND C[stage],
        # where both `stage` and `bit` count from the LSB (hence the 7 - x
        # indexing into the MSB-first matrices).
        # NOTE(review): gates are looked up with pop.get(); a missing mask
        # gate silently leaves that partial-product bit at 0.
        masks = torch.zeros(8, num_tests, pop_size, 8, device=self.device)
        for stage in range(8):
            c_stage_bit = c_bits[:, 7 - stage].unsqueeze(1).expand(-1, pop_size)
            for bit in range(8):
                b_bit_val = b_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size)

                w = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.weight')
                bias = pop.get(f'{prefix}.mul.mask.s{stage}.b{bit}.bias')
                if w is not None and bias is not None:
                    w = w.squeeze(-1)
                    b_tensor = bias.squeeze(-1)
                    # einsum contracts the 2-wide input against per-candidate
                    # weights: [tests, pop, 2] x [pop, 2] -> [tests, pop].
                    inp = torch.stack([b_bit_val, c_stage_bit], dim=-1)
                    out = heaviside(torch.einsum('tpi,pi->tp', inp, w) + b_tensor)
                    masks[stage, :, :, bit] = out

        # Accumulator starts at the stage-0 partial product (no shift),
        # indexed LSB-first along the last dimension.
        acc = masks[0].clone()

        for stage in range(1, 8):
            # Shift the stage's partial product left by `stage` positions in
            # the LSB-first layout; bits shifted past bit 7 are dropped
            # (wraparound multiply modulo 256).
            shifted_mask = torch.zeros(num_tests, pop_size, 8, device=self.device)
            for bit in range(8):
                if bit >= stage:
                    shifted_mask[:, :, bit] = masks[stage, :, :, bit - stage]

            # Ripple-add the shifted partial product into the accumulator.
            carry = torch.zeros(num_tests, pop_size, device=self.device)
            new_acc = torch.zeros(num_tests, pop_size, 8, device=self.device)
            for bit in range(8):
                s, carry = self._eval_single_fa(
                    pop, f'{prefix}.mul.acc.s{stage}.fa{bit}',
                    acc[:, :, bit],
                    shifted_mask[:, :, bit],
                    carry
                )
                new_acc[:, :, bit] = s
            acc = new_acc

        # Final stage: ripple-add A onto the product accumulator.
        carry = torch.zeros(num_tests, pop_size, device=self.device)
        result_bits = []
        for bit in range(8):
            a_bit_val = a_bits[:, 7 - bit].unsqueeze(1).expand(-1, pop_size)
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.add.fa{bit}',
                a_bit_val,
                acc[:, :, bit],
                carry
            )
            result_bits.append(s)

        # Reverse LSB-first bit list to MSB-first and convert to integers.
        result_bits = torch.stack(result_bits[::-1], dim=-1)
        result = torch.zeros(num_tests, pop_size, device=self.device)
        for i in range(bits):
            result += result_bits[:, :, i] * (1 << (bits - 1 - i))

        # Expected: A + B*C modulo 256 (multiplication binds tighter).
        expected = ((a_vals + b_vals * c_vals) & 0xFF).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Collect at most 100 failure tuples for single-candidate runs.
        failures = []
        if pop_size == 1:
            for i in range(min(num_tests, 100)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item()), int(c_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            if failures:
                for inp, exp, got in failures[:5]:
                    print(f" FAIL: {inp[0]} + {inp[1]} × {inp[2]} = {exp}, got {got}")

        return correct, num_tests
| |
|
| | |
| | |
| | |
| |
|
| | def _test_comparator(self, pop: Dict, name: str, op: Callable[[int, int], bool], |
| | debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test 8-bit comparator.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | prefix = f'arithmetic.{name}' |
| |
|
| | |
| | expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 |
| | for a, b in zip(self.comp_a, self.comp_b)], |
| | device=self.device) |
| |
|
| | |
| | a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
| | b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
| | inputs = torch.cat([a_bits, b_bits], dim=1) |
| |
|
| | w = pop[f'{prefix}.weight'] |
| | b = pop[f'{prefix}.bias'] |
| | out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(len(self.comp_a)): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append(( |
| | [int(self.comp_a[i].item()), int(self.comp_b[i].item())], |
| | expected[i].item(), |
| | out[i, 0].item() |
| | )) |
| |
|
| | self._record(prefix, int(correct[0].item()), len(self.comp_a), failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return correct, len(self.comp_a) |
| |
|
    def _test_comparators(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test all comparators.

        Single-layer comparators (GT/LT/GE/LE) are delegated to
        _test_comparator; equality is evaluated inline as a two-layer
        GEQ-AND-LEQ circuit.  Comparators whose weights are absent from `pop`
        are skipped silently (KeyError).
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== COMPARATORS ===")

        comparators = [
            ('greaterthan8bit', lambda a, b: a > b),
            ('lessthan8bit', lambda a, b: a < b),
            ('greaterorequal8bit', lambda a, b: a >= b),
            ('lessorequal8bit', lambda a, b: a <= b),
            ('equality8bit', lambda a, b: a == b),
        ]

        for name, op in comparators:
            # Equality has a two-layer structure; handled separately below.
            if name == 'equality8bit':
                continue
            try:
                s, t = self._test_comparator(pop, name, op, debug)
                scores += s
                total += t
            except KeyError:
                # Circuit not present in this population; skip it.
                pass

        # Equality: EQ = GEQ AND LEQ (two-layer circuit evaluated inline).
        try:
            prefix = 'arithmetic.equality8bit'
            expected = torch.tensor([1.0 if a.item() == b.item() else 0.0
                                     for a, b in zip(self.comp_a, self.comp_b)],
                                    device=self.device)

            # Pack both operands MSB-first into one 16-wide bit matrix.
            a_bits = torch.stack([((self.comp_a >> (7 - i)) & 1).float() for i in range(8)], dim=1)
            b_bits = torch.stack([((self.comp_b >> (7 - i)) & 1).float() for i in range(8)], dim=1)
            inputs = torch.cat([a_bits, b_bits], dim=1)

            # Layer 1: the GEQ and LEQ halves.
            w_geq = pop[f'{prefix}.layer1.geq.weight']
            b_geq = pop[f'{prefix}.layer1.geq.bias']
            w_leq = pop[f'{prefix}.layer1.leq.weight']
            b_leq = pop[f'{prefix}.layer1.leq.bias']

            h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size))
            h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size))
            hidden = torch.stack([h_geq, h_leq], dim=-1)

            # Layer 2: AND-like combination of the two hidden activations.
            w2 = pop[f'{prefix}.layer2.weight']
            b2 = pop[f'{prefix}.layer2.bias']
            out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size))

            correct = (out == expected.unsqueeze(1)).float().sum(0)

            # Per-pair failure details only for single-candidate runs.
            failures = []
            if pop_size == 1:
                for i in range(len(self.comp_a)):
                    if out[i, 0].item() != expected[i].item():
                        failures.append((
                            [int(self.comp_a[i].item()), int(self.comp_b[i].item())],
                            expected[i].item(),
                            out[i, 0].item()
                        ))

            self._record(prefix, int(correct[0].item()), len(self.comp_a), failures)
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
            scores += correct
            total += len(self.comp_a)
        except KeyError:
            # Equality circuit not present in this population; skip it.
            pass

        return scores, total
| |
|
| | def _test_comparators_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test N-bit comparator circuits (GT, LT, GE, LE, EQ).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print(f"\n=== {bits}-BIT COMPARATORS ===") |
| |
|
| | if bits == 32: |
| | comp_a = self.comp32_a |
| | comp_b = self.comp32_b |
| | elif bits == 16: |
| | comp_a = self.comp_a.clamp(0, 65535) |
| | comp_b = self.comp_b.clamp(0, 65535) |
| | else: |
| | comp_a = self.comp_a |
| | comp_b = self.comp_b |
| |
|
| | num_tests = len(comp_a) |
| |
|
| | if bits <= 16: |
| | a_bits = torch.stack([((comp_a >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) |
| | b_bits = torch.stack([((comp_b >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1) |
| | inputs = torch.cat([a_bits, b_bits], dim=1) |
| |
|
| | comparators = [ |
| | (f'arithmetic.greaterthan{bits}bit', lambda a, b: a > b), |
| | (f'arithmetic.greaterorequal{bits}bit', lambda a, b: a >= b), |
| | (f'arithmetic.lessthan{bits}bit', lambda a, b: a < b), |
| | (f'arithmetic.lessorequal{bits}bit', lambda a, b: a <= b), |
| | ] |
| |
|
| | for name, op in comparators: |
| | try: |
| | expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 |
| | for a, b in zip(comp_a, comp_b)], device=self.device) |
| | w = pop[f'{name}.weight'] |
| | b = pop[f'{name}.bias'] |
| | out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(num_tests): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append(([int(comp_a[i].item()), int(comp_b[i].item())], |
| | expected[i].item(), out[i, 0].item())) |
| | self._record(name, int(correct[0].item()), num_tests, failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | scores += correct |
| | total += num_tests |
| | except KeyError: |
| | pass |
| |
|
| | prefix = f'arithmetic.equality{bits}bit' |
| | try: |
| | expected = torch.tensor([1.0 if a.item() == b.item() else 0.0 |
| | for a, b in zip(comp_a, comp_b)], device=self.device) |
| | w_geq = pop[f'{prefix}.layer1.geq.weight'] |
| | b_geq = pop[f'{prefix}.layer1.geq.bias'] |
| | w_leq = pop[f'{prefix}.layer1.leq.weight'] |
| | b_leq = pop[f'{prefix}.layer1.leq.bias'] |
| | h_geq = heaviside(inputs @ w_geq.view(pop_size, -1).T + b_geq.view(pop_size)) |
| | h_leq = heaviside(inputs @ w_leq.view(pop_size, -1).T + b_leq.view(pop_size)) |
| | hidden = torch.stack([h_geq, h_leq], dim=-1) |
| | w2 = pop[f'{prefix}.layer2.weight'] |
| | b2 = pop[f'{prefix}.layer2.bias'] |
| | out = heaviside((hidden * w2.view(pop_size, 1, 2)).sum(-1) + b2.view(pop_size)) |
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(num_tests): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append(([int(comp_a[i].item()), int(comp_b[i].item())], |
| | expected[i].item(), out[i, 0].item())) |
| | self._record(prefix, int(correct[0].item()), num_tests, failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | scores += correct |
| | total += num_tests |
| | except KeyError: |
| | pass |
| | else: |
| | num_bytes = bits // 8 |
| | prefix = f"arithmetic.cmp{bits}bit" |
| |
|
| | byte_gt = [] |
| | byte_lt = [] |
| | byte_eq = [] |
| |
|
| | for b in range(num_bytes): |
| | start_bit = b * 8 |
| | a_byte = torch.stack([((comp_a >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1) |
| | b_byte = torch.stack([((comp_b >> (bits - 1 - start_bit - i)) & 1).float() for i in range(8)], dim=1) |
| | byte_input = torch.cat([a_byte, b_byte], dim=1) |
| |
|
| | w_gt = pop[f'{prefix}.byte{b}.gt.weight'].view(pop_size, -1) |
| | b_gt = pop[f'{prefix}.byte{b}.gt.bias'].view(pop_size) |
| | byte_gt.append(heaviside(byte_input @ w_gt.T + b_gt)) |
| |
|
| | w_lt = pop[f'{prefix}.byte{b}.lt.weight'].view(pop_size, -1) |
| | b_lt = pop[f'{prefix}.byte{b}.lt.bias'].view(pop_size) |
| | byte_lt.append(heaviside(byte_input @ w_lt.T + b_lt)) |
| |
|
| | w_geq = pop[f'{prefix}.byte{b}.eq.geq.weight'].view(pop_size, -1) |
| | b_geq = pop[f'{prefix}.byte{b}.eq.geq.bias'].view(pop_size) |
| | w_leq = pop[f'{prefix}.byte{b}.eq.leq.weight'].view(pop_size, -1) |
| | b_leq = pop[f'{prefix}.byte{b}.eq.leq.bias'].view(pop_size) |
| | h_geq = heaviside(byte_input @ w_geq.T + b_geq) |
| | h_leq = heaviside(byte_input @ w_leq.T + b_leq) |
| | w_and = pop[f'{prefix}.byte{b}.eq.and.weight'].view(pop_size, -1) |
| | b_and = pop[f'{prefix}.byte{b}.eq.and.bias'].view(pop_size) |
| | eq_inp = torch.stack([h_geq, h_leq], dim=-1) |
| | byte_eq.append(heaviside((eq_inp * w_and).sum(-1) + b_and)) |
| |
|
| | cascade_gt = [] |
| | cascade_lt = [] |
| | for b in range(num_bytes): |
| | if b == 0: |
| | cascade_gt.append(byte_gt[0]) |
| | cascade_lt.append(byte_lt[0]) |
| | else: |
| | eq_stack = torch.stack(byte_eq[:b], dim=-1) |
| | w_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.weight'].view(pop_size, -1) |
| | b_all_eq = pop[f'{prefix}.cascade.gt.stage{b}.all_eq.bias'].view(pop_size) |
| | all_eq_gt = heaviside((eq_stack * w_all_eq).sum(-1) + b_all_eq) |
| | w_and = pop[f'{prefix}.cascade.gt.stage{b}.and.weight'].view(pop_size, -1) |
| | b_and = pop[f'{prefix}.cascade.gt.stage{b}.and.bias'].view(pop_size) |
| | stage_inp = torch.stack([all_eq_gt, byte_gt[b]], dim=-1) |
| | cascade_gt.append(heaviside((stage_inp * w_and).sum(-1) + b_and)) |
| |
|
| | w_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.weight'].view(pop_size, -1) |
| | b_all_eq_lt = pop[f'{prefix}.cascade.lt.stage{b}.all_eq.bias'].view(pop_size) |
| | all_eq_lt = heaviside((eq_stack * w_all_eq_lt).sum(-1) + b_all_eq_lt) |
| | w_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.weight'].view(pop_size, -1) |
| | b_and_lt = pop[f'{prefix}.cascade.lt.stage{b}.and.bias'].view(pop_size) |
| | stage_inp_lt = torch.stack([all_eq_lt, byte_lt[b]], dim=-1) |
| | cascade_lt.append(heaviside((stage_inp_lt * w_and_lt).sum(-1) + b_and_lt)) |
| |
|
| | gt_stack = torch.stack(cascade_gt, dim=-1) |
| | w_gt_or = pop[f'arithmetic.greaterthan{bits}bit.weight'].view(pop_size, -1) |
| | b_gt_or = pop[f'arithmetic.greaterthan{bits}bit.bias'].view(pop_size) |
| | gt_out = heaviside((gt_stack * w_gt_or).sum(-1) + b_gt_or) |
| |
|
| | lt_stack = torch.stack(cascade_lt, dim=-1) |
| | w_lt_or = pop[f'arithmetic.lessthan{bits}bit.weight'].view(pop_size, -1) |
| | b_lt_or = pop[f'arithmetic.lessthan{bits}bit.bias'].view(pop_size) |
| | lt_out = heaviside((lt_stack * w_lt_or).sum(-1) + b_lt_or) |
| |
|
| | w_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.weight'].view(pop_size, -1) |
| | b_not_lt = pop[f'arithmetic.greaterorequal{bits}bit.not_lt.bias'].view(pop_size) |
| | not_lt = heaviside(lt_out.unsqueeze(-1) @ w_not_lt.T + b_not_lt).squeeze(-1) |
| | w_ge = pop[f'arithmetic.greaterorequal{bits}bit.weight'].view(pop_size, -1) |
| | b_ge = pop[f'arithmetic.greaterorequal{bits}bit.bias'].view(pop_size) |
| | ge_out = heaviside(not_lt.unsqueeze(-1) @ w_ge.T + b_ge).squeeze(-1) |
| |
|
| | w_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.weight'].view(pop_size, -1) |
| | b_not_gt = pop[f'arithmetic.lessorequal{bits}bit.not_gt.bias'].view(pop_size) |
| | not_gt = heaviside(gt_out.unsqueeze(-1) @ w_not_gt.T + b_not_gt).squeeze(-1) |
| | w_le = pop[f'arithmetic.lessorequal{bits}bit.weight'].view(pop_size, -1) |
| | b_le = pop[f'arithmetic.lessorequal{bits}bit.bias'].view(pop_size) |
| | le_out = heaviside(not_gt.unsqueeze(-1) @ w_le.T + b_le).squeeze(-1) |
| |
|
| | eq_stack = torch.stack(byte_eq, dim=-1) |
| | w_eq_all = pop[f'arithmetic.equality{bits}bit.weight'].view(pop_size, -1) |
| | b_eq_all = pop[f'arithmetic.equality{bits}bit.bias'].view(pop_size) |
| | eq_out = heaviside((eq_stack * w_eq_all).sum(-1) + b_eq_all) |
| |
|
| | for name, out, op in [ |
| | (f'arithmetic.greaterthan{bits}bit', gt_out, lambda a, b: a > b), |
| | (f'arithmetic.greaterorequal{bits}bit', ge_out, lambda a, b: a >= b), |
| | (f'arithmetic.lessthan{bits}bit', lt_out, lambda a, b: a < b), |
| | (f'arithmetic.lessorequal{bits}bit', le_out, lambda a, b: a <= b), |
| | (f'arithmetic.equality{bits}bit', eq_out, lambda a, b: a == b), |
| | ]: |
| | expected = torch.tensor([1.0 if op(a.item(), b.item()) else 0.0 |
| | for a, b in zip(comp_a, comp_b)], device=self.device) |
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(num_tests): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append(([int(comp_a[i].item()), int(comp_b[i].item())], |
| | expected[i].item(), out[i, 0].item())) |
| | self._record(name, int(correct[0].item()), num_tests, failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | scores += correct |
| | total += num_tests |
| |
|
| | return scores, total |
| |
|
    def _test_subtractor_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit subtractor circuit (A - B).

        Computes A - B as A + NOT(B) + 1 (two's complement): B is inverted
        through per-bit NOT gates, then fed into a ripple chain of full adders
        with the initial carry set to 1.

        Args:
            pop: dict of stacked circuit tensors; each entry is
                (pop_size, ...)-shaped so a whole population evaluates at once.
            bits: operand width (8-bit or 32-bit test vectors are provided).
            debug: print a per-circuit PASS/FAIL line.

        Returns:
            (correct, num_tests): per-individual count of passing test cases
            (shape [pop_size]) and the number of test cases run.
        """
        pop_size = next(iter(pop.values())).shape[0]

        if debug:
            print(f"\n=== {bits}-BIT SUBTRACTOR ===")

        prefix = f'arithmetic.sub{bits}bit'
        max_val = 1 << bits  # modulus for wraparound arithmetic

        if bits == 32:
            # Hand-picked 32-bit pairs covering borrows, wraparound, identity.
            test_pairs = [
                (1000, 500), (5000, 3000), (1000000, 500000),
                (0xFFFFFFFF, 1), (0x80000000, 1), (100, 100),
                (0, 0), (1, 0), (0, 1), (256, 255),
                (0xDEADBEEF, 0xCAFEBABE), (1000000000, 999999999),
            ]
        else:
            # Cross product of 8-bit boundary values.
            test_pairs = [(a, b) for a in [0, 1, 127, 128, 255] for b in [0, 1, 127, 128, 255]]

        a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long)
        b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long)
        num_tests = len(test_pairs)

        # Decompose operands into bits, MSB-first (column 0 = bit bits-1).
        a_bits = torch.stack([((a_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)
        b_bits = torch.stack([((b_vals >> (bits - 1 - i)) & 1).float() for i in range(bits)], dim=1)

        # Invert B through the circuit's per-bit NOT gates.
        not_b_bits = torch.zeros_like(b_bits)
        for bit in range(bits):
            w = pop[f'{prefix}.not_b.bit{bit}.weight'].view(pop_size, -1)
            b = pop[f'{prefix}.not_b.bit{bit}.bias'].view(pop_size)
            not_b_bits[:, bit] = heaviside(b_bits[:, bit:bit+1] @ w.T + b)[:, 0]

        # Initial carry of 1 supplies the "+1" of the two's complement.
        carry = torch.ones(num_tests, pop_size, device=self.device)
        sum_bits = []

        # Ripple LSB->MSB; fa0 handles the least-significant bit.
        for bit in range(bits):
            bit_idx = bits - 1 - bit  # column index (columns are MSB-first)
            s, carry = self._eval_single_fa(
                pop, f'{prefix}.fa{bit}',
                a_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                not_b_bits[:, bit_idx].unsqueeze(1).expand(-1, pop_size),
                carry
            )
            sum_bits.append(s)

        # Reverse back to MSB-first and reassemble the integer result.
        sum_bits = torch.stack(sum_bits[::-1], dim=-1)
        result = torch.zeros(num_tests, pop_size, device=self.device)
        for i in range(bits):
            result += sum_bits[:, :, i] * (1 << (bits - 1 - i))

        # Ground truth: subtraction wrapped modulo 2**bits.
        expected = ((a_vals - b_vals) & (max_val - 1)).unsqueeze(1).expand(-1, pop_size).float()
        correct = (result == expected).float().sum(0)

        # Failure details only make sense for a single individual; cap at 20.
        failures = []
        if pop_size == 1:
            for i in range(min(num_tests, 20)):
                if result[i, 0].item() != expected[i, 0].item():
                    failures.append((
                        [int(a_vals[i].item()), int(b_vals[i].item())],
                        int(expected[i, 0].item()),
                        int(result[i, 0].item())
                    ))

        self._record(prefix, int(correct[0].item()), num_tests, failures)
        if debug:
            r = self.results[-1]
            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, num_tests
| |
|
| | def _test_bitwise_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test N-bit bitwise operations (AND, OR, XOR, NOT).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print(f"\n=== {bits}-BIT BITWISE OPS ===") |
| |
|
| | if bits == 32: |
| | test_pairs = [ |
| | (0xAAAAAAAA, 0x55555555), (0xFFFFFFFF, 0x00000000), |
| | (0x12345678, 0x87654321), (0xDEADBEEF, 0xCAFEBABE), |
| | (0x0F0F0F0F, 0xF0F0F0F0), (0, 0), (0xFFFFFFFF, 0xFFFFFFFF), |
| | ] |
| | else: |
| | test_pairs = [(0xAA, 0x55), (0xFF, 0x00), (0x0F, 0xF0)] |
| |
|
| | a_vals = torch.tensor([p[0] for p in test_pairs], device=self.device, dtype=torch.long) |
| | b_vals = torch.tensor([p[1] for p in test_pairs], device=self.device, dtype=torch.long) |
| | num_tests = len(test_pairs) |
| |
|
| | ops = [ |
| | ('and', lambda a, b: a & b), |
| | ('or', lambda a, b: a | b), |
| | ('xor', lambda a, b: a ^ b), |
| | ] |
| |
|
| | for op_name, op_fn in ops: |
| | try: |
| | result_bits = [] |
| | for bit in range(bits): |
| | a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float() |
| | b_bit = ((b_vals >> (bits - 1 - bit)) & 1).float() |
| |
|
| | if op_name == 'xor': |
| | prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}' |
| | w_or = pop[f'{prefix}.layer1.or.weight'].view(pop_size, -1) |
| | b_or = pop[f'{prefix}.layer1.or.bias'].view(pop_size) |
| | w_nand = pop[f'{prefix}.layer1.nand.weight'].view(pop_size, -1) |
| | b_nand = pop[f'{prefix}.layer1.nand.bias'].view(pop_size) |
| | inp = torch.stack([a_bit, b_bit], dim=-1) |
| | h_or = heaviside(inp @ w_or.T + b_or) |
| | h_nand = heaviside(inp @ w_nand.T + b_nand) |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| | w2 = pop[f'{prefix}.layer2.weight'].view(pop_size, -1) |
| | b2 = pop[f'{prefix}.layer2.bias'].view(pop_size) |
| | out = heaviside((hidden * w2).sum(-1) + b2) |
| | else: |
| | w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size, -1) |
| | b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size) |
| | inp = torch.stack([a_bit, b_bit], dim=-1) |
| | out = heaviside(inp @ w.T + b) |
| |
|
| | result_bits.append(out[:, 0] if out.dim() > 1 else out) |
| |
|
| | result = sum(int(result_bits[i][j].item()) << (bits - 1 - i) |
| | for i in range(bits) for j in range(1)) |
| | results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) |
| | for i in range(bits)) for j in range(num_tests)], |
| | device=self.device) |
| | expected = torch.tensor([op_fn(a.item(), b.item()) for a, b in zip(a_vals, b_vals)], |
| | device=self.device) |
| |
|
| | correct = (results == expected).float().sum() |
| | self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | scores += correct |
| | total += num_tests |
| | except KeyError as e: |
| | if debug: |
| | print(f" alu.alu{bits}bit.{op_name}: SKIP (missing {e})") |
| |
|
| | try: |
| | test_vals = a_vals |
| | result_bits = [] |
| | for bit in range(bits): |
| | a_bit = ((test_vals >> (bits - 1 - bit)) & 1).float() |
| | w = pop[f'alu.alu{bits}bit.not.bit{bit}.weight'].view(pop_size, -1) |
| | b = pop[f'alu.alu{bits}bit.not.bit{bit}.bias'].view(pop_size) |
| | out = heaviside(a_bit.unsqueeze(-1) @ w.T + b) |
| | result_bits.append(out[:, 0]) |
| |
|
| | results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i) |
| | for i in range(bits)) for j in range(num_tests)], |
| | device=self.device) |
| | expected = torch.tensor([(~a.item()) & ((1 << bits) - 1) for a in test_vals], |
| | device=self.device) |
| |
|
| | correct = (results == expected).float().sum() |
| | self._record(f'alu.alu{bits}bit.not', int(correct.item()), num_tests, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | scores += correct |
| | total += num_tests |
| | except KeyError as e: |
| | if debug: |
| | print(f" alu.alu{bits}bit.not: SKIP (missing {e})") |
| |
|
| | return scores, total |
| |
|
    def _test_shifts_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit shift operations (SHL, SHR).

        Each output bit is a single-weight threshold unit driven by the
        neighbouring source bit; the shift wiring (and the 0 fed into the
        vacated end) is done here in the harness, not in the circuit.

        Returns:
            (scores, total): summed pass counts and total test cases attempted.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print(f"\n=== {bits}-BIT SHIFTS ===")

        if bits == 32:
            test_vals = [0x12345678, 0x80000001, 0x00000001, 0xFFFFFFFF, 0x55555555]
        else:
            test_vals = [0x81, 0x55, 0x01, 0xFF, 0xAA]

        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
        num_tests = len(test_vals)
        max_val = (1 << bits) - 1  # mask keeping SHL results in the word width

        # SHL drops the MSB and shifts in a 0 LSB; SHR is a logical right shift.
        for op_name, op_fn in [('shl', lambda x: (x << 1) & max_val), ('shr', lambda x: x >> 1)]:
            try:
                result_bits = []
                for bit in range(bits):
                    a_bit = ((a_vals >> (bits - 1 - bit)) & 1).float()
                    # NOTE(review): .view(pop_size) implies one scalar weight per
                    # gate; src_bit * w below only broadcasts cleanly when
                    # pop_size == 1 — confirm behavior for population runs.
                    w = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.weight'].view(pop_size)
                    b = pop[f'alu.alu{bits}bit.{op_name}.bit{bit}.bias'].view(pop_size)

                    # Select which source bit feeds this output position.
                    if op_name == 'shl':
                        if bit < bits - 1:
                            src_bit = ((a_vals >> (bits - 2 - bit)) & 1).float()
                        else:
                            src_bit = torch.zeros_like(a_bit)  # vacated LSB -> 0
                    else:
                        if bit > 0:
                            src_bit = ((a_vals >> (bits - bit)) & 1).float()
                        else:
                            src_bit = torch.zeros_like(a_bit)  # vacated MSB -> 0

                    out = heaviside(src_bit * w + b)
                    result_bits.append(out)

                # Reassemble MSB-first output bits into one integer per test.
                results = torch.tensor([sum(int(result_bits[i][j].item()) << (bits - 1 - i)
                                            for i in range(bits)) for j in range(num_tests)],
                                       device=self.device)
                expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)

                correct = (results == expected).float().sum()
                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                scores += correct
                total += num_tests
            except KeyError as e:
                if debug:
                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

        return scores, total
| |
|
    def _test_inc_dec_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit INC and DEC operations.

        Both are ripple chains processed LSB-first: each bit XORs the operand
        bit with the incoming carry/borrow, then a dedicated gate propagates
        the next carry (INC) or borrow via NOT(a) (DEC).

        NOTE(review): weights are read with .flatten()[k] / .item(), which
        takes the first population member's values — this path effectively
        assumes pop_size == 1; confirm before population use.

        Returns:
            (scores, total): summed pass counts and total test cases.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print(f"\n=== {bits}-BIT INC/DEC ===")

        if bits == 32:
            test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000000, 0xFFFFFFFE]
        else:
            test_vals = [0, 1, 254, 255, 127, 128]

        a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long)
        num_tests = len(test_vals)
        max_val = (1 << bits) - 1  # wraparound mask

        for op_name, op_fn in [('inc', lambda x: (x + 1) & max_val), ('dec', lambda x: (x - 1) & max_val)]:
            try:
                # Initial carry/borrow of 1 is the +1 / -1 being applied.
                carry = torch.ones(num_tests, device=self.device)
                result_bits = []

                for bit in range(bits):
                    a_bit = ((a_vals >> bit) & 1).float()  # LSB-first indexing

                    # XOR(a_bit, carry) built from OR/NAND layers + an AND layer.
                    prefix = f'alu.alu{bits}bit.{op_name}.bit{bit}'
                    w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten()
                    b_or = pop[f'{prefix}.xor.layer1.or.bias'].item()
                    w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten()
                    b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item()

                    h_or = heaviside(a_bit * w_or[0] + carry * w_or[1] + b_or)
                    h_nand = heaviside(a_bit * w_nand[0] + carry * w_nand[1] + b_nand)

                    w2 = pop[f'{prefix}.xor.layer2.weight'].flatten()
                    b2 = pop[f'{prefix}.xor.layer2.bias'].item()
                    xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2)
                    result_bits.append(xor_out)

                    # Propagate the chain: carry gate for INC; DEC routes the
                    # operand bit through a NOT gate before the borrow gate.
                    if op_name == 'inc':
                        w_carry = pop[f'{prefix}.carry.weight'].flatten()
                        b_carry = pop[f'{prefix}.carry.bias'].item()
                        carry = heaviside(a_bit * w_carry[0] + carry * w_carry[1] + b_carry)
                    else:
                        w_not = pop[f'{prefix}.not_a.weight'].flatten()
                        b_not = pop[f'{prefix}.not_a.bias'].item()
                        not_a = heaviside(a_bit * w_not[0] + b_not)
                        w_borrow = pop[f'{prefix}.borrow.weight'].flatten()
                        b_borrow = pop[f'{prefix}.borrow.bias'].item()
                        carry = heaviside(not_a * w_borrow[0] + carry * w_borrow[1] + b_borrow)

                # Bits were produced LSB-first, so each shifts by its own index.
                results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit
                                            for bit in range(bits)) for j in range(num_tests)],
                                       device=self.device)
                expected = torch.tensor([op_fn(a.item()) for a in a_vals], device=self.device)

                correct = (results == expected).float().sum()
                self._record(f'alu.alu{bits}bit.{op_name}', int(correct.item()), num_tests, [])
                if debug:
                    r = self.results[-1]
                    print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
                scores += correct
                total += num_tests
            except KeyError as e:
                if debug:
                    print(f"  alu.alu{bits}bit.{op_name}: SKIP (missing {e})")

        return scores, total
| |
|
| | def _test_neg_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test N-bit NEG operation (two's complement negation).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | if debug: |
| | print(f"\n=== {bits}-BIT NEG ===") |
| |
|
| | if bits == 32: |
| | test_vals = [0, 1, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 1000, 1000000] |
| | else: |
| | test_vals = [0, 1, 127, 128, 255, 100] |
| |
|
| | a_vals = torch.tensor(test_vals, device=self.device, dtype=torch.long) |
| | num_tests = len(test_vals) |
| | max_val = (1 << bits) - 1 |
| |
|
| | try: |
| | not_bits = [] |
| | for bit in range(bits): |
| | a_bit = ((a_vals >> bit) & 1).float() |
| | w = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.weight'].flatten() |
| | b = pop[f'alu.alu{bits}bit.neg.not.bit{bit}.bias'].item() |
| | not_bits.append(heaviside(a_bit * w[0] + b)) |
| |
|
| | carry = torch.ones(num_tests, device=self.device) |
| | result_bits = [] |
| |
|
| | for bit in range(bits): |
| | prefix = f'alu.alu{bits}bit.neg.inc.bit{bit}' |
| | not_bit = not_bits[bit] |
| |
|
| | w_or = pop[f'{prefix}.xor.layer1.or.weight'].flatten() |
| | b_or = pop[f'{prefix}.xor.layer1.or.bias'].item() |
| | w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].flatten() |
| | b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].item() |
| |
|
| | h_or = heaviside(not_bit * w_or[0] + carry * w_or[1] + b_or) |
| | h_nand = heaviside(not_bit * w_nand[0] + carry * w_nand[1] + b_nand) |
| |
|
| | w2 = pop[f'{prefix}.xor.layer2.weight'].flatten() |
| | b2 = pop[f'{prefix}.xor.layer2.bias'].item() |
| | xor_out = heaviside(h_or * w2[0] + h_nand * w2[1] + b2) |
| | result_bits.append(xor_out) |
| |
|
| | w_carry = pop[f'{prefix}.carry.weight'].flatten() |
| | b_carry = pop[f'{prefix}.carry.bias'].item() |
| | carry = heaviside(not_bit * w_carry[0] + carry * w_carry[1] + b_carry) |
| |
|
| | results = torch.tensor([sum(int(result_bits[bit][j].item()) << bit |
| | for bit in range(bits)) for j in range(num_tests)], |
| | device=self.device) |
| | expected = torch.tensor([(-a.item()) & max_val for a in a_vals], device=self.device) |
| |
|
| | correct = (results == expected).float().sum() |
| | self._record(f'alu.alu{bits}bit.neg', int(correct.item()), num_tests, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return torch.tensor([correct], device=self.device), num_tests |
| | except KeyError as e: |
| | if debug: |
| | print(f" alu.alu{bits}bit.neg: SKIP (missing {e})") |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | |
| | |
| | |
| |
|
    def _test_threshold_kofn(self, pop: Dict, k: int, name: str, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test k-of-n threshold gate.

        Exhaustively checks a single-layer gate over all 256 byte inputs.
        The circuit name selects the ground truth: 'atleast*' -> popcount >= k,
        'atmost*'/'minority' -> popcount <= k, 'exactly*' -> popcount == k;
        anything else (the k-of-8 gates, 'majority') defaults to popcount >= k.

        Raises:
            KeyError: when the gate's weight/bias tensors are missing (callers
            catch this to skip absent circuits).
        """
        pop_size = next(iter(pop.values())).shape[0]
        prefix = f'threshold.{name}'

        # Reuse precomputed bit patterns if available, else rebuild all 256.
        # NOTE(review): the len(...) == 24 guard looks odd for a 256-row bit
        # matrix — confirm what shape self.test_8bit_bits actually has.
        inputs = self.test_8bit_bits if len(self.test_8bit_bits) == 24 else None
        if inputs is None:
            test_vals = torch.arange(256, device=self.device, dtype=torch.long)
            inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1)

        # Number of set bits per input row drives the expected gate output.
        popcounts = inputs.sum(dim=1)

        if 'atleast' in name:
            expected = (popcounts >= k).float()
        elif 'atmost' in name or 'minority' in name:
            expected = (popcounts <= k).float()
        elif 'exactly' in name:
            expected = (popcounts == k).float()
        else:
            # Default: at-least-k semantics (covers k-of-8 and majority gates).
            expected = (popcounts >= k).float()

        w = pop[f'{prefix}.weight']
        b = pop[f'{prefix}.bias']
        out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size))

        correct = (out == expected.unsqueeze(1)).float().sum(0)

        failures = []
        if pop_size == 1:
            for i in range(min(len(inputs), 256)):
                if out[i, 0].item() != expected[i].item():
                    # Reconstruct the byte value from its MSB-first bits.
                    val = int(sum(inputs[i, j].item() * (1 << (7 - j)) for j in range(8)))
                    failures.append((val, expected[i].item(), out[i, 0].item()))

        self._record(prefix, int(correct[0].item()), len(inputs), failures[:10])
        if debug:
            r = self.results[-1]
            print(f"  {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")

        return correct, len(inputs)
| |
|
| | def _test_threshold_gates(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test all threshold gates.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== THRESHOLD GATES ===") |
| |
|
| | |
| | kofn_gates = [ |
| | (1, 'oneoutof8'), (2, 'twooutof8'), (3, 'threeoutof8'), (4, 'fouroutof8'), |
| | (5, 'fiveoutof8'), (6, 'sixoutof8'), (7, 'sevenoutof8'), (8, 'alloutof8'), |
| | ] |
| |
|
| | for k, name in kofn_gates: |
| | try: |
| | s, t = self._test_threshold_kofn(pop, k, name, debug) |
| | scores += s |
| | total += t |
| | except KeyError: |
| | pass |
| |
|
| | |
| | special = [ |
| | (5, 'majority'), (3, 'minority'), |
| | (4, 'atleastk_4'), (4, 'atmostk_4'), (4, 'exactlyk_4'), |
| | ] |
| |
|
| | for k, name in special: |
| | try: |
| | s, t = self._test_threshold_kofn(pop, k, name, debug) |
| | scores += s |
| | total += t |
| | except KeyError: |
| | pass |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_modular(self, pop: Dict, mod: int, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test modular divisibility circuit (multi-layer for non-powers-of-2).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | prefix = f'modular.mod{mod}' |
| |
|
| | |
| | inputs = torch.stack([((self.mod_test >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
| | expected = ((self.mod_test % mod) == 0).float() |
| |
|
| | |
| | try: |
| | w = pop[f'{prefix}.weight'] |
| | b = pop[f'{prefix}.bias'] |
| | out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| | except KeyError: |
| | |
| | try: |
| | |
| | geq_outputs = {} |
| | leq_outputs = {} |
| | i = 0 |
| | while True: |
| | found = False |
| | if f'{prefix}.layer1.geq{i}.weight' in pop: |
| | w = pop[f'{prefix}.layer1.geq{i}.weight'].view(pop_size, -1) |
| | b = pop[f'{prefix}.layer1.geq{i}.bias'].view(pop_size) |
| | geq_outputs[i] = heaviside(inputs @ w.T + b) |
| | found = True |
| | if f'{prefix}.layer1.leq{i}.weight' in pop: |
| | w = pop[f'{prefix}.layer1.leq{i}.weight'].view(pop_size, -1) |
| | b = pop[f'{prefix}.layer1.leq{i}.bias'].view(pop_size) |
| | leq_outputs[i] = heaviside(inputs @ w.T + b) |
| | found = True |
| | if not found: |
| | break |
| | i += 1 |
| |
|
| | if not geq_outputs and not leq_outputs: |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | |
| | eq_outputs = [] |
| | i = 0 |
| | while f'{prefix}.layer2.eq{i}.weight' in pop: |
| | w = pop[f'{prefix}.layer2.eq{i}.weight'].view(pop_size, -1) |
| | b = pop[f'{prefix}.layer2.eq{i}.bias'].view(pop_size) |
| | |
| | eq_in = torch.stack([geq_outputs.get(i, torch.zeros(256, pop_size, device=self.device)), |
| | leq_outputs.get(i, torch.zeros(256, pop_size, device=self.device))], dim=-1) |
| | eq_out = heaviside((eq_in * w).sum(-1) + b) |
| | eq_outputs.append(eq_out) |
| | i += 1 |
| |
|
| | if not eq_outputs: |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | |
| | eq_stack = torch.stack(eq_outputs, dim=-1) |
| | w3 = pop[f'{prefix}.layer3.or.weight'].view(pop_size, -1) |
| | b3 = pop[f'{prefix}.layer3.or.bias'].view(pop_size) |
| | out = heaviside((eq_stack * w3).sum(-1) + b3) |
| |
|
| | except Exception as e: |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(256): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append((i, expected[i].item(), out[i, 0].item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), 256, failures[:10]) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return correct, 256 |
| |
|
| | def _test_modular_all(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test all modular arithmetic circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== MODULAR ARITHMETIC ===") |
| |
|
| | for mod in range(2, 13): |
| | s, t = self._test_modular(pop, mod, debug) |
| | scores += s |
| | total += t |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_pattern(self, pop: Dict, name: str, expected_fn: Callable[[int], float], |
| | debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test pattern recognition circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | prefix = f'pattern_recognition.{name}' |
| |
|
| | test_vals = torch.arange(256, device=self.device, dtype=torch.long) |
| | inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
| | expected = torch.tensor([expected_fn(v.item()) for v in test_vals], device=self.device) |
| |
|
| | try: |
| | w = pop[f'{prefix}.weight'].view(pop_size, -1) |
| | b = pop[f'{prefix}.bias'].view(pop_size) |
| | out = heaviside(inputs @ w.T + b) |
| | except KeyError: |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(256): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append((i, expected[i].item(), out[i, 0].item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), 256, failures[:10]) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return correct, 256 |
| |
|
| | def _test_patterns(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test pattern recognition circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== PATTERN RECOGNITION ===") |
| |
|
| | |
| | patterns = [ |
| | ('allzeros', lambda v: 1.0 if v == 0 else 0.0), |
| | ('allones', lambda v: 1.0 if v == 255 else 0.0), |
| | ] |
| |
|
| | for name, fn in patterns: |
| | s, t = self._test_pattern(pop, name, fn, debug) |
| | scores += s |
| | total += t |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _eval_xor_tree_stage(self, pop: Dict, prefix: str, stage: int, idx: int, |
| | a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
| | """Evaluate a single XOR in the parity tree.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | xor_prefix = f'{prefix}.stage{stage}.xor{idx}' |
| |
|
| | |
| | if a.dim() == 1: |
| | a = a.unsqueeze(1).expand(-1, pop_size) |
| | if b.dim() == 1: |
| | b = b.unsqueeze(1).expand(-1, pop_size) |
| |
|
| | |
| | w_or = pop[f'{xor_prefix}.layer1.or.weight'].view(pop_size, 2) |
| | b_or = pop[f'{xor_prefix}.layer1.or.bias'].view(pop_size) |
| | w_nand = pop[f'{xor_prefix}.layer1.nand.weight'].view(pop_size, 2) |
| | b_nand = pop[f'{xor_prefix}.layer1.nand.bias'].view(pop_size) |
| |
|
| | inputs = torch.stack([a, b], dim=-1) |
| | h_or = heaviside((inputs * w_or).sum(-1) + b_or) |
| | h_nand = heaviside((inputs * w_nand).sum(-1) + b_nand) |
| |
|
| | |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| | w2 = pop[f'{xor_prefix}.layer2.weight'].view(pop_size, 2) |
| | b2 = pop[f'{xor_prefix}.layer2.bias'].view(pop_size) |
| | return heaviside((hidden * w2).sum(-1) + b2) |
| |
|
| | def _test_parity_xor_tree(self, pop: Dict, prefix: str, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test parity circuit with XOR tree structure.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| |
|
| | test_vals = torch.arange(256, device=self.device, dtype=torch.long) |
| | inputs = torch.stack([((test_vals >> (7 - i)) & 1).float() for i in range(8)], dim=1) |
| |
|
| | |
| | popcounts = inputs.sum(dim=1) |
| | xor_result = (popcounts.long() % 2).float() |
| |
|
| | try: |
| | |
| | s1_out = [] |
| | for i in range(4): |
| | xor_out = self._eval_xor_tree_stage(pop, prefix, 1, i, inputs[:, i*2], inputs[:, i*2+1]) |
| | s1_out.append(xor_out) |
| |
|
| | |
| | s2_out = [] |
| | for i in range(2): |
| | xor_out = self._eval_xor_tree_stage(pop, prefix, 2, i, s1_out[i*2], s1_out[i*2+1]) |
| | s2_out.append(xor_out) |
| |
|
| | |
| | s3_out = self._eval_xor_tree_stage(pop, prefix, 3, 0, s2_out[0], s2_out[1]) |
| |
|
| | |
| | if f'{prefix}.output.not.weight' in pop: |
| | w_not = pop[f'{prefix}.output.not.weight'].view(pop_size) |
| | b_not = pop[f'{prefix}.output.not.bias'].view(pop_size) |
| | out = heaviside(s3_out * w_not + b_not) |
| | |
| | expected = 1.0 - xor_result |
| | else: |
| | out = s3_out |
| | expected = xor_result |
| |
|
| | except KeyError as e: |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(256): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append((i, expected[i].item(), out[i, 0].item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), 256, failures[:10]) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return correct, 256 |
| |
|
| | def _test_error_detection(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test error detection circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== ERROR DETECTION ===") |
| |
|
| | |
| | for prefix in ['error_detection.paritychecker8bit', 'error_detection.paritygenerator8bit']: |
| | s, t = self._test_parity_xor_tree(pop, prefix, debug) |
| | scores += s |
| | total += t |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_mux2to1(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test 2-to-1 multiplexer.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | prefix = 'combinational.multiplexer2to1' |
| |
|
| | |
| | inputs = torch.tensor([ |
| | [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], |
| | [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], |
| | ], device=self.device, dtype=torch.float32) |
| | expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32) |
| |
|
| | try: |
| | w = pop[f'{prefix}.weight'] |
| | b = pop[f'{prefix}.bias'] |
| | out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| | except KeyError: |
| | return torch.zeros(pop_size, device=self.device), 0 |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(8): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append((inputs[i].tolist(), expected[i].item(), out[i, 0].item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), 8, failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return correct, 8 |
| |
|
| | def _test_decoder3to8(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test 3-to-8 decoder.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== DECODER 3-TO-8 ===") |
| |
|
| | inputs = torch.tensor([ |
| | [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], |
| | [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], |
| | ], device=self.device, dtype=torch.float32) |
| |
|
| | for out_idx in range(8): |
| | prefix = f'combinational.decoder3to8.out{out_idx}' |
| | expected = torch.zeros(8, device=self.device) |
| | expected[out_idx] = 1.0 |
| |
|
| | try: |
| | w = pop[f'{prefix}.weight'] |
| | b = pop[f'{prefix}.bias'] |
| | out = heaviside(inputs @ w.view(pop_size, -1).T + b.view(pop_size)) |
| | except KeyError: |
| | continue |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| | scores += correct |
| | total += 8 |
| |
|
| | failures = [] |
| | if pop_size == 1: |
| | for i in range(8): |
| | if out[i, 0].item() != expected[i].item(): |
| | failures.append((inputs[i].tolist(), expected[i].item(), out[i, 0].item())) |
| |
|
| | self._record(prefix, int(correct[0].item()), 8, failures) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return scores, total |
| |
|
| | def _test_combinational(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test combinational logic circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== COMBINATIONAL LOGIC ===") |
| |
|
| | s, t = self._test_mux2to1(pop, debug) |
| | scores += s |
| | total += t |
| |
|
| | s, t = self._test_decoder3to8(pop, debug) |
| | scores += s |
| | total += t |
| |
|
| | s, t = self._test_barrel_shifter(pop, debug) |
| | scores += s |
| | total += t |
| |
|
| | s, t = self._test_priority_encoder(pop, debug) |
| | scores += s |
| | total += t |
| |
|
| | return scores, total |
| |
|
    def _test_barrel_shifter(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test barrel shifter (shift by 0-7 positions).

        Simulates the 3-layer mux network gate by gate: layer k shifts the
        byte left by 2**(2-k) positions when its select bit is set, so the
        three layers compose any shift of 0-7.
        NOTE(review): intermediate gate outputs are read with `[0].item()`,
        so only population member 0 is actually evaluated; the resulting
        pass/fail is broadcast to the whole population via `scores += 1`.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== BARREL SHIFTER ===")

        try:
            # Representative byte patterns: edge bits, nibble masks,
            # alternating bits, all-ones.
            test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, 0xFF]

            for val in test_vals:
                for shift in range(8):
                    # Logical left shift with 8-bit wraparound.
                    expected_val = (val << shift) & 0xFF
                    # MSB-first bit vectors for the value and 3-bit shift amount.
                    val_bits = [float((val >> (7 - i)) & 1) for i in range(8)]
                    shift_bits = [float((shift >> (2 - i)) & 1) for i in range(3)]

                    # Propagate through the three mux layers (shift by 4, 2, 1).
                    layer_in = val_bits[:]
                    for layer in range(3):
                        shift_amount = 1 << (2 - layer)
                        sel = shift_bits[layer]
                        layer_out = []

                        for bit in range(8):
                            prefix = f'combinational.barrelshifter.layer{layer}.bit{bit}'

                            # not_sel = NOT(sel): enables the pass-through path.
                            w_not = pop[f'{prefix}.not_sel.weight'].view(pop_size)
                            b_not = pop[f'{prefix}.not_sel.bias'].view(pop_size)
                            not_sel = heaviside(sel * w_not + b_not)

                            # Source bit for the shifted path; bits shifted in
                            # from beyond the LSB side are zero.
                            shifted_src = bit + shift_amount
                            if shifted_src < 8:
                                shifted_val = layer_in[shifted_src]
                            else:
                                shifted_val = 0.0

                            # and_a = input AND not_sel (pass-through branch).
                            w_and_a = pop[f'{prefix}.and_a.weight'].view(pop_size, 2)
                            b_and_a = pop[f'{prefix}.and_a.bias'].view(pop_size)
                            inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device)
                            and_a = heaviside((inp_a * w_and_a).sum(-1) + b_and_a)

                            # and_b = shifted AND sel (shifted branch).
                            w_and_b = pop[f'{prefix}.and_b.weight'].view(pop_size, 2)
                            b_and_b = pop[f'{prefix}.and_b.bias'].view(pop_size)
                            inp_b = torch.tensor([shifted_val, sel], device=self.device)
                            and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b)

                            # out = and_a OR and_b (mux output for this bit).
                            w_or = pop[f'{prefix}.or.weight'].view(pop_size, 2)
                            b_or = pop[f'{prefix}.or.bias'].view(pop_size)
                            inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device)
                            out = heaviside((inp_or * w_or).sum(-1) + b_or)
                            layer_out.append(out[0].item())

                        layer_in = layer_out

                    # Reassemble the MSB-first bit list into an integer.
                    result = sum(int(layer_in[i]) << (7 - i) for i in range(8))
                    if result == expected_val:
                        scores += 1
                    total += 1

            self._record('combinational.barrelshifter', int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing weights or shape mismatch: skip the whole circuit.
            if debug:
                print(f" combinational.barrelshifter: SKIP ({e})")

        return scores, total
| |
|
    def _test_priority_encoder(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test priority encoder (find highest set bit).

        The encoder emits a `valid` flag (any input bit set) plus a 3-bit
        index of the highest-priority set bit, where index 0 corresponds to
        the MSB of the input byte and index 7 to the LSB.
        NOTE(review): outputs are checked via `[0].item()`, so only
        population member 0 is actually evaluated; its pass/fail is
        broadcast to all members via `scores += 1`.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0

        if debug:
            print("\n=== PRIORITY ENCODER ===")

        try:
            # (input byte, expected valid flag, expected index from the MSB).
            test_cases = [
                (0b00000000, 0, 0),
                (0b00000001, 1, 7),
                (0b00000010, 1, 6),
                (0b00000100, 1, 5),
                (0b00001000, 1, 4),
                (0b00010000, 1, 3),
                (0b00100000, 1, 2),
                (0b01000000, 1, 1),
                (0b10000000, 1, 0),
                (0b10000001, 1, 0),
                (0b01010101, 1, 1),
                (0b00001111, 1, 4),
                (0b11111111, 1, 0),
            ]

            for val, expected_valid, expected_idx in test_cases:
                # MSB-first bit vector of the input byte.
                val_bits = torch.tensor([float((val >> (7 - i)) & 1) for i in range(8)],
                                        device=self.device, dtype=torch.float32)

                # valid output: a single threshold unit over all 8 input bits.
                w_valid = pop['combinational.priorityencoder.valid.weight'].view(pop_size, 8)
                b_valid = pop['combinational.priorityencoder.valid.bias'].view(pop_size)
                out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)

                if int(out_valid[0].item()) == expected_valid:
                    scores += 1
                total += 1

                # Index bits are only defined when some input bit is set.
                # idx0 is the MSB of the 3-bit index (hence the `2 - idx_bit`).
                if expected_valid == 1:
                    for idx_bit in range(3):
                        try:
                            w_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.weight'].view(pop_size, 8)
                            b_idx = pop[f'combinational.priorityencoder.idx{idx_bit}.bias'].view(pop_size)
                            out_idx = heaviside((val_bits * w_idx).sum(-1) + b_idx)
                            expected_bit = (expected_idx >> (2 - idx_bit)) & 1
                            if int(out_idx[0].item()) == expected_bit:
                                scores += 1
                            total += 1
                        except KeyError:
                            # This index bit's unit is absent; skip it silently.
                            pass

            self._record('combinational.priorityencoder', int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing valid unit or shape mismatch: skip the whole circuit.
            if debug:
                print(f" combinational.priorityencoder: SKIP ({e})")

        return scores, total
| |
|
    def _test_barrel_shifter_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit barrel shifter (shift by 0 to bits-1 positions).

        Generalization of the 8-bit barrel shifter test: ceil(log2(bits))
        mux layers, where layer k shifts left by 2**(num_layers-1-k) when
        its select bit is set.  Only shifts 0..min(bits,8)-1 are exercised.
        NOTE(review): as in the 8-bit version, gate outputs are read with
        `[0].item()`, so only population member 0 is actually evaluated.
        """
        import math
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0
        num_layers = max(1, math.ceil(math.log2(bits)))
        max_val = (1 << bits) - 1

        if debug:
            print(f"\n=== {bits}-BIT BARREL SHIFTER ===")

        prefix = f'combinational.barrelshifter{bits}'
        try:
            # Width-appropriate edge/alternating/all-ones patterns.
            if bits == 16:
                test_vals = [0x8001, 0xFF00, 0x00FF, 0xAAAA, 0xFFFF, 0x1234]
            elif bits == 32:
                test_vals = [0x80000001, 0xFFFF0000, 0x0000FFFF, 0xAAAAAAAA, 0xFFFFFFFF, 0x12345678]
            else:
                test_vals = [0b10000001, 0b11110000, 0b00001111, 0b10101010, max_val]

            num_shifts = min(bits, 8)
            for val in test_vals:
                for shift in range(num_shifts):
                    # Logical left shift with N-bit wraparound.
                    expected_val = (val << shift) & max_val
                    # MSB-first bit vectors for the value and shift amount.
                    val_bits = [float((val >> (bits - 1 - i)) & 1) for i in range(bits)]
                    shift_bits = [float((shift >> (num_layers - 1 - i)) & 1) for i in range(num_layers)]

                    layer_in = val_bits[:]
                    for layer in range(num_layers):
                        shift_amount = 1 << (num_layers - 1 - layer)
                        sel = shift_bits[layer]
                        layer_out = []

                        for bit in range(bits):
                            bit_prefix = f'{prefix}.layer{layer}.bit{bit}'

                            # not_sel = NOT(sel): enables the pass-through path.
                            w_not = pop[f'{bit_prefix}.not_sel.weight'].view(pop_size)
                            b_not = pop[f'{bit_prefix}.not_sel.bias'].view(pop_size)
                            not_sel = heaviside(sel * w_not + b_not)

                            # Shifted-path source; zero when shifted past the LSB end.
                            shifted_src = bit + shift_amount
                            if shifted_src < bits:
                                shifted_val = layer_in[shifted_src]
                            else:
                                shifted_val = 0.0

                            # and_a = input AND not_sel (pass-through branch).
                            w_and_a = pop[f'{bit_prefix}.and_a.weight'].view(pop_size, 2)
                            b_and_a = pop[f'{bit_prefix}.and_a.bias'].view(pop_size)
                            inp_a = torch.tensor([layer_in[bit], not_sel[0].item()], device=self.device)
                            and_a = heaviside((inp_a * w_and_a).sum(-1) + b_and_a)

                            # and_b = shifted AND sel (shifted branch).
                            w_and_b = pop[f'{bit_prefix}.and_b.weight'].view(pop_size, 2)
                            b_and_b = pop[f'{bit_prefix}.and_b.bias'].view(pop_size)
                            inp_b = torch.tensor([shifted_val, sel], device=self.device)
                            and_b = heaviside((inp_b * w_and_b).sum(-1) + b_and_b)

                            # out = and_a OR and_b (mux output for this bit).
                            w_or = pop[f'{bit_prefix}.or.weight'].view(pop_size, 2)
                            b_or = pop[f'{bit_prefix}.or.bias'].view(pop_size)
                            inp_or = torch.tensor([and_a[0].item(), and_b[0].item()], device=self.device)
                            out = heaviside((inp_or * w_or).sum(-1) + b_or)
                            layer_out.append(out[0].item())

                        layer_in = layer_out

                    # Reassemble the MSB-first bit list into an integer.
                    result = sum(int(layer_in[i]) << (bits - 1 - i) for i in range(bits))
                    if result == expected_val:
                        scores += 1
                    total += 1

            self._record(prefix, int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing weights or shape mismatch: skip the whole circuit.
            if debug:
                print(f" {prefix}: SKIP ({e})")

        return scores, total
| |
|
    def _test_priority_encoder_nbits(self, pop: Dict, bits: int, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test N-bit priority encoder (find highest set bit).

        The priority encoder is a multi-layer circuit:
        1. any_higher{pos}: OR of bits 0 to pos-1 (all higher-priority positions)
        2. is_highest{0}: bit[0] directly (MSB is always highest if set)
        3. is_highest{pos}: bit[pos] AND NOT(any_higher{pos}) for pos > 0
        4. out{bit}: OR of is_highest{pos} for all pos where (pos >> bit) & 1
        5. valid: OR of all input bits

        NOTE(review): only population member 0's pass/fail is checked
        (`[0].item()`); the score increment is broadcast to all members.
        """
        import math
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0
        # Number of index output bits (e.g. 4 for a 16-bit encoder).
        out_bits = max(1, math.ceil(math.log2(bits)))

        if debug:
            print(f"\n=== {bits}-BIT PRIORITY ENCODER ===")

        prefix = f'combinational.priorityencoder{bits}'
        try:
            # (input value, expected valid, expected index from the MSB):
            # all-zero, every single-bit value, plus mixed width-specific cases.
            test_cases = [(0, 0, 0)]
            for i in range(bits):
                test_cases.append((1 << i, 1, bits - 1 - i))
            if bits == 16:
                test_cases.extend([
                    (0x8001, 1, 0), (0x5555, 1, 1), (0x00FF, 1, 8), (0xFFFF, 1, 0)
                ])
            elif bits == 32:
                test_cases.extend([
                    (0x80000001, 1, 0), (0x55555555, 1, 1), (0x0000FFFF, 1, 16), (0xFFFFFFFF, 1, 0)
                ])

            for val, expected_valid, expected_idx in test_cases:
                # MSB-first bit vector of the input value.
                val_bits = torch.tensor([float((val >> (bits - 1 - i)) & 1) for i in range(bits)],
                                        device=self.device, dtype=torch.float32)

                # Layer 5: valid = OR of all input bits.
                w_valid = pop[f'{prefix}.valid.weight'].view(pop_size, bits)
                b_valid = pop[f'{prefix}.valid.bias'].view(pop_size)
                out_valid = heaviside((val_bits * w_valid).sum(-1) + b_valid)

                if int(out_valid[0].item()) == expected_valid:
                    scores += 1
                total += 1

                # Index outputs are only defined when the input is nonzero.
                if expected_valid == 1:
                    # Layer 1: any_higher[pos] (index 0 unused -- MSB has no
                    # higher-priority positions).
                    any_higher = [None]
                    for pos in range(1, bits):
                        w = pop[f'{prefix}.any_higher{pos}.weight'].view(pop_size, -1)
                        b = pop[f'{prefix}.any_higher{pos}.bias'].view(pop_size)
                        inp = val_bits[:pos]
                        out = heaviside((inp * w[:, :len(inp)]).sum(-1) + b)
                        any_higher.append(out)

                    # Layers 2-3: is_highest[pos].
                    is_highest = []
                    for pos in range(bits):
                        if pos == 0:
                            # MSB is the highest whenever it is set.
                            is_high = val_bits[0].unsqueeze(0).expand(pop_size)
                        else:
                            # not_higher = NOT(any_higher[pos]).
                            w_not = pop[f'{prefix}.is_highest{pos}.not_higher.weight'].view(pop_size, -1)
                            b_not = pop[f'{prefix}.is_highest{pos}.not_higher.bias'].view(pop_size)
                            not_higher = heaviside(any_higher[pos].unsqueeze(-1) * w_not + b_not).squeeze(-1)

                            # is_high = bit[pos] AND not_higher.
                            w_and = pop[f'{prefix}.is_highest{pos}.and.weight'].view(pop_size, -1)
                            b_and = pop[f'{prefix}.is_highest{pos}.and.bias'].view(pop_size)
                            inp = torch.stack([val_bits[pos].expand(pop_size), not_higher], dim=-1)
                            is_high = heaviside((inp * w_and).sum(-1) + b_and)
                        is_highest.append(is_high)

                    # Layer 4: each index output bit ORs the is_highest lines
                    # whose position has that bit set.
                    for idx_bit in range(out_bits):
                        try:
                            w_idx = pop[f'{prefix}.out{idx_bit}.weight'].view(pop_size, -1)
                            b_idx = pop[f'{prefix}.out{idx_bit}.bias'].view(pop_size)
                            relevant = [is_highest[pos] for pos in range(bits) if (pos >> idx_bit) & 1]
                            if len(relevant) > 0:
                                inp = torch.stack(relevant[:w_idx.shape[1]], dim=-1)
                                out_idx = heaviside((inp * w_idx).sum(-1) + b_idx)
                                expected_bit = (expected_idx >> idx_bit) & 1
                                if int(out_idx[0].item()) == expected_bit:
                                    scores += 1
                                total += 1
                        except KeyError:
                            # This index bit's unit is absent; skip it silently.
                            pass

            self._record(prefix, int(scores[0].item()), total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing core units or shape mismatch: skip the whole circuit.
            if debug:
                print(f" {prefix}: SKIP ({e})")

        return scores, total
| |
|
| | |
| | |
| | |
| |
|
| | def _test_conditional_jump(self, pop: Dict, name: str, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test conditional jump circuit (N-bit address aware).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | prefix = f'control.{name}' |
| |
|
| | |
| | inputs = torch.tensor([ |
| | [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], |
| | [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1], |
| | ], device=self.device, dtype=torch.float32) |
| | expected = torch.tensor([0, 0, 0, 1, 1, 0, 1, 1], device=self.device, dtype=torch.float32) |
| |
|
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | for bit in range(self.addr_bits): |
| | bit_prefix = f'{prefix}.bit{bit}' |
| | try: |
| | |
| | w_not = pop[f'{bit_prefix}.not_sel.weight'] |
| | b_not = pop[f'{bit_prefix}.not_sel.bias'] |
| | flag = inputs[:, 2:3] |
| | not_sel = heaviside(flag @ w_not.view(pop_size, -1).T + b_not.view(pop_size)) |
| |
|
| | |
| | w_and_a = pop[f'{bit_prefix}.and_a.weight'] |
| | b_and_a = pop[f'{bit_prefix}.and_a.bias'] |
| | pc_not = torch.cat([inputs[:, 0:1], not_sel], dim=-1) |
| | and_a = heaviside((pc_not * w_and_a.view(pop_size, 1, 2)).sum(-1) + b_and_a.view(pop_size, 1)) |
| |
|
| | |
| | w_and_b = pop[f'{bit_prefix}.and_b.weight'] |
| | b_and_b = pop[f'{bit_prefix}.and_b.bias'] |
| | target_sel = inputs[:, 1:3] |
| | and_b = heaviside((target_sel * w_and_b.view(pop_size, 1, 2)).sum(-1) + b_and_b.view(pop_size, 1)) |
| |
|
| | |
| | w_or = pop[f'{bit_prefix}.or.weight'] |
| | b_or = pop[f'{bit_prefix}.or.bias'] |
| | |
| | and_a_2d = and_a.view(8, pop_size) |
| | and_b_2d = and_b.view(8, pop_size) |
| | ab = torch.stack([and_a_2d, and_b_2d], dim=-1) |
| | out = heaviside((ab * w_or.view(pop_size, 2)).sum(-1) + b_or.view(pop_size)) |
| |
|
| | correct = (out == expected.unsqueeze(1)).float().sum(0) |
| | scores += correct |
| | total += 8 |
| |
|
| | except KeyError: |
| | pass |
| |
|
| | if total > 0: |
| | self._record(prefix, int((scores[0] / total * total).item()), total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| |
|
| | return scores, total |
| |
|
| | def _test_control_flow(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test control flow circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== CONTROL FLOW ===") |
| |
|
| | jumps = ['jz', 'jnz', 'jc', 'jnc', 'jn', 'jp', 'jv', 'jnv', 'conditionaljump'] |
| | for name in jumps: |
| | s, t = self._test_conditional_jump(pop, name, debug) |
| | scores += s |
| | total += t |
| |
|
| | |
| | s, t = self._test_stack_ops(pop, debug) |
| | scores += s |
| | total += t |
| |
|
| | return scores, total |
| |
|
    def _test_stack_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]:
        """Test PUSH/POP/RET stack operation circuits (N-bit address aware).

        PUSH decrements SP through a ripple-borrow subtractor, POP increments
        it through a ripple-carry adder, and RET passes a return address
        through per-bit buffer units.  Each circuit is simulated gate by gate.
        NOTE(review): ripple carries/borrows are read with `[0].item()`, so
        only population member 0 is effectively evaluated per gate chain.
        """
        pop_size = next(iter(pop.values())).shape[0]
        scores = torch.zeros(pop_size, device=self.device)
        total = 0
        addr_bits = self.addr_bits
        addr_mask = (1 << addr_bits) - 1

        if debug:
            print(f"\n=== STACK OPERATIONS ({addr_bits}-bit SP) ===")

        # --- PUSH: SP - 1 via ripple-borrow decrement ---
        try:
            # Boundary SPs: zero (wraps to all-ones), one, mid-range, all-ones.
            # NOTE(review): sp_tests is defined inside this try block but also
            # read by the POP block below -- it relies on this first statement
            # always executing before any KeyError can fire.
            sp_tests = [0, 1, addr_mask // 2, addr_mask]
            if addr_bits >= 8:
                sp_tests.append(0x100 & addr_mask)
            if addr_bits >= 12:
                sp_tests.append(0x1234 & addr_mask)
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            for sp_val in sp_tests:
                # Wraparound decrement is the reference result.
                expected_val = (sp_val - 1) & addr_mask
                # MSB-first bit vector of the current stack pointer.
                sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]

                borrow = 1.0  # seed borrow of 1 implements "subtract 1"
                out_bits = []
                # Ripple from the LSB (index addr_bits-1) up to the MSB (index 0).
                for bit in range(addr_bits - 1, -1, -1):
                    prefix = f'control.push.sp_dec.bit{bit}'

                    # diff = XOR(sp_bit, borrow), built as (OR, NAND) -> AND.
                    w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
                    b_or = pop[f'{prefix}.xor.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].view(pop_size, 2)
                    b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].view(pop_size)
                    w2 = pop[f'{prefix}.xor.layer2.weight'].view(pop_size, 2)
                    b2 = pop[f'{prefix}.xor.layer2.bias'].view(pop_size)

                    inp = torch.tensor([sp_bits[bit], borrow], device=self.device)
                    h_or = heaviside((inp * w_or).sum(-1) + b_or)
                    h_nand = heaviside((inp * w_nand).sum(-1) + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    diff_bit = heaviside((hidden * w2).sum(-1) + b2)
                    out_bits.insert(0, diff_bit)  # keep MSB-first ordering

                    # Next borrow = NOT(sp_bit) AND borrow.
                    not_sp = 1.0 - sp_bits[bit]
                    w_borrow = pop[f'{prefix}.borrow.weight'].view(pop_size, 2)
                    b_borrow = pop[f'{prefix}.borrow.bias'].view(pop_size)
                    borrow_inp = torch.tensor([not_sp, borrow], device=self.device)
                    borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item()

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += addr_bits

            scores += op_scores
            total += op_total
            self._record('control.push.sp_dec', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            # Missing weights or shape mismatch: skip this circuit.
            if debug:
                print(f" control.push.sp_dec: SKIP ({e})")

        # --- POP: SP + 1 via ripple-carry increment ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            for sp_val in sp_tests:
                # Wraparound increment is the reference result.
                expected_val = (sp_val + 1) & addr_mask
                sp_bits = [float((sp_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)]

                carry = 1.0  # seed carry of 1 implements "add 1"
                out_bits = []
                for bit in range(addr_bits - 1, -1, -1):
                    prefix = f'control.pop.sp_inc.bit{bit}'

                    # sum = XOR(sp_bit, carry), built as (OR, NAND) -> AND.
                    w_or = pop[f'{prefix}.xor.layer1.or.weight'].view(pop_size, 2)
                    b_or = pop[f'{prefix}.xor.layer1.or.bias'].view(pop_size)
                    w_nand = pop[f'{prefix}.xor.layer1.nand.weight'].view(pop_size, 2)
                    b_nand = pop[f'{prefix}.xor.layer1.nand.bias'].view(pop_size)
                    w2 = pop[f'{prefix}.xor.layer2.weight'].view(pop_size, 2)
                    b2 = pop[f'{prefix}.xor.layer2.bias'].view(pop_size)

                    inp = torch.tensor([sp_bits[bit], carry], device=self.device)
                    h_or = heaviside((inp * w_or).sum(-1) + b_or)
                    h_nand = heaviside((inp * w_nand).sum(-1) + b_nand)
                    hidden = torch.stack([h_or, h_nand], dim=-1)
                    sum_bit = heaviside((hidden * w2).sum(-1) + b2)
                    out_bits.insert(0, sum_bit)  # keep MSB-first ordering

                    # Next carry = sp_bit AND carry (reuses `inp` from above).
                    w_carry = pop[f'{prefix}.carry.weight'].view(pop_size, 2)
                    b_carry = pop[f'{prefix}.carry.bias'].view(pop_size)
                    carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item()

                out = torch.stack(out_bits, dim=-1)
                expected = torch.tensor([((expected_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
                                        device=self.device, dtype=torch.float32)
                correct = (out == expected.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += addr_bits

            scores += op_scores
            total += op_total
            self._record('control.pop.sp_inc', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" control.pop.sp_inc: SKIP ({e})")

        # --- RET: return address passes through per-bit buffer units ---
        try:
            op_scores = torch.zeros(pop_size, device=self.device)
            op_total = 0

            ret_tests = [0, addr_mask, addr_mask // 2, 1]
            if addr_bits >= 12:
                ret_tests.append(0x1234 & addr_mask)
            for addr_val in ret_tests:
                # MSB-first bit vector of the return address (also the target).
                ret_bits_tensor = torch.tensor([float((addr_val >> (addr_bits - 1 - i)) & 1) for i in range(addr_bits)],
                                               device=self.device, dtype=torch.float32)

                out_bits = []
                for bit in range(addr_bits):
                    # Each bit is an identity buffer: out should equal input.
                    w = pop[f'control.ret.addr.bit{bit}.weight'].view(pop_size)
                    b = pop[f'control.ret.addr.bit{bit}.bias'].view(pop_size)
                    out = heaviside(ret_bits_tensor[bit] * w + b)
                    out_bits.append(out)

                out = torch.stack(out_bits, dim=-1)
                correct = (out == ret_bits_tensor.unsqueeze(0)).float().sum(1)
                op_scores += correct
                op_total += addr_bits

            scores += op_scores
            total += op_total
            self._record('control.ret.addr', int(op_scores[0].item()), op_total, [])
            if debug:
                r = self.results[-1]
                print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}")
        except (KeyError, RuntimeError) as e:
            if debug:
                print(f" control.ret.addr: SKIP ({e})")

        return scores, total
| |
|
| | |
| | |
| | |
| |
|
| | def _test_alu_ops(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test ALU operations (8-bit bitwise).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== ALU OPERATIONS ===") |
| |
|
| | |
| | |
| | |
| |
|
| | test_vals = [(0, 0), (255, 255), (0xAA, 0x55), (0x0F, 0xF0)] |
| |
|
| | |
| | try: |
| | w = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2) |
| | b = pop['alu.alu8bit.and.bias'].view(pop_size, 8) |
| |
|
| | for a_val, b_val in test_vals: |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | |
| | inputs = torch.stack([a_bits, b_bits], dim=-1) |
| | |
| | out = heaviside((inputs * w).sum(-1) + b) |
| | expected = torch.tensor([((a_val & b_val) >> (7 - i)) & 1 for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | scores += correct |
| | total += 8 |
| |
|
| | self._record('alu.alu8bit.and', int(scores[0].item()), total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError): |
| | pass |
| |
|
| | |
| | try: |
| | w = pop['alu.alu8bit.or.weight'].view(pop_size, 8, 2) |
| | b = pop['alu.alu8bit.or.bias'].view(pop_size, 8) |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | for a_val, b_val in test_vals: |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | inputs = torch.stack([a_bits, b_bits], dim=-1) |
| | out = heaviside((inputs * w).sum(-1) + b) |
| | expected = torch.tensor([((a_val | b_val) >> (7 - i)) & 1 for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.or', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError): |
| | pass |
| |
|
| | |
| | try: |
| | w = pop['alu.alu8bit.not.weight'].view(pop_size, 8) |
| | b = pop['alu.alu8bit.not.bias'].view(pop_size, 8) |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | for a_val, _ in test_vals: |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | out = heaviside(a_bits * w + b) |
| | expected = torch.tensor([(((~a_val) & 0xFF) >> (7 - i)) & 1 for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.not', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError): |
| | pass |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | for a_val, _ in test_vals: |
| | expected_val = (a_val << 1) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | out_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size) |
| | b = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size) |
| | if bit < 7: |
| | inp = a_bits[bit + 1].unsqueeze(0).expand(pop_size) |
| | else: |
| | inp = torch.zeros(pop_size, device=self.device) |
| | out = heaviside(inp * w + b) |
| | out_bits.append(out) |
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.shl', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.shl: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | for a_val, _ in test_vals: |
| | expected_val = (a_val >> 1) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | out_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.shr.bit{bit}.weight'].view(pop_size) |
| | b = pop[f'alu.alu8bit.shr.bit{bit}.bias'].view(pop_size) |
| | if bit > 0: |
| | inp = a_bits[bit - 1].unsqueeze(0).expand(pop_size) |
| | else: |
| | inp = torch.zeros(pop_size, device=self.device) |
| | out = heaviside(inp * w + b) |
| | out_bits.append(out) |
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.shr', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.shr: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | mul_tests = [(3, 4), (7, 8), (15, 17), (0, 255)] |
| | for a_val, b_val in mul_tests: |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | b_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | for i in range(8): |
| | for j in range(8): |
| | w = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.weight'].view(pop_size, 2) |
| | b = pop[f'alu.alu8bit.mul.pp.a{i}b{j}.bias'].view(pop_size) |
| | inp = torch.tensor([a_bits[i].item(), b_bits[j].item()], device=self.device) |
| | out = heaviside((inp * w).sum(-1) + b) |
| | expected = float(int(a_bits[i].item()) & int(b_bits[j].item())) |
| | correct = (out == expected).float() |
| | op_scores += correct |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.mul', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.mul: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | div_tests = [(100, 10), (255, 17), (50, 7), (128, 16)] |
| | for a_val, b_val in div_tests: |
| | |
| | for stage in range(8): |
| | w = pop[f'alu.alu8bit.div.stage{stage}.cmp.weight'].view(pop_size, 16) |
| | b = pop[f'alu.alu8bit.div.stage{stage}.cmp.bias'].view(pop_size) |
| |
|
| | |
| | test_rem = (a_val >> (7 - stage)) & 0xFF |
| | rem_bits = torch.tensor([((test_rem >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | div_bits = torch.tensor([((b_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | inp = torch.cat([rem_bits, div_bits]) |
| |
|
| | out = heaviside((inp * w).sum(-1) + b) |
| | expected = float(test_rem >= b_val) |
| | correct = (out == expected).float() |
| | op_scores += correct |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.div', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.div: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | inc_tests = [0, 1, 127, 128, 254, 255] |
| | for a_val in inc_tests: |
| | expected_val = (a_val + 1) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | carry = 1.0 |
| | out_bits = [] |
| | for bit in range(7, -1, -1): |
| | |
| | w_or = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2) |
| | b_or = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.or.bias'].view(pop_size) |
| | w_nand = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2) |
| | b_nand = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer1.nand.bias'].view(pop_size) |
| | w2 = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer2.weight'].view(pop_size, 2) |
| | b2 = pop[f'alu.alu8bit.inc.bit{bit}.xor.layer2.bias'].view(pop_size) |
| |
|
| | inp = torch.tensor([a_bits[bit].item(), carry], device=self.device) |
| | h_or = heaviside((inp * w_or).sum(-1) + b_or) |
| | h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| | sum_bit = heaviside((hidden * w2).sum(-1) + b2) |
| | out_bits.insert(0, sum_bit) |
| |
|
| | |
| | w_carry = pop[f'alu.alu8bit.inc.bit{bit}.carry.weight'].view(pop_size, 2) |
| | b_carry = pop[f'alu.alu8bit.inc.bit{bit}.carry.bias'].view(pop_size) |
| | carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item() |
| |
|
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.inc', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.inc: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | dec_tests = [0, 1, 127, 128, 254, 255] |
| | for a_val in dec_tests: |
| | expected_val = (a_val - 1) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | borrow = 1.0 |
| | out_bits = [] |
| | for bit in range(7, -1, -1): |
| | w_or = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2) |
| | b_or = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.or.bias'].view(pop_size) |
| | w_nand = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2) |
| | b_nand = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer1.nand.bias'].view(pop_size) |
| | w2 = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer2.weight'].view(pop_size, 2) |
| | b2 = pop[f'alu.alu8bit.dec.bit{bit}.xor.layer2.bias'].view(pop_size) |
| |
|
| | inp = torch.tensor([a_bits[bit].item(), borrow], device=self.device) |
| | h_or = heaviside((inp * w_or).sum(-1) + b_or) |
| | h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| | diff_bit = heaviside((hidden * w2).sum(-1) + b2) |
| | out_bits.insert(0, diff_bit) |
| |
|
| | |
| | w_not = pop[f'alu.alu8bit.dec.bit{bit}.not_a.weight'].view(pop_size) |
| | b_not = pop[f'alu.alu8bit.dec.bit{bit}.not_a.bias'].view(pop_size) |
| | not_a = heaviside(a_bits[bit] * w_not + b_not) |
| |
|
| | w_borrow = pop[f'alu.alu8bit.dec.bit{bit}.borrow.weight'].view(pop_size, 2) |
| | b_borrow = pop[f'alu.alu8bit.dec.bit{bit}.borrow.bias'].view(pop_size) |
| | borrow_inp = torch.tensor([not_a[0].item(), borrow], device=self.device) |
| | borrow = heaviside((borrow_inp * w_borrow).sum(-1) + b_borrow)[0].item() |
| |
|
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.dec', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.dec: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | neg_tests = [0, 1, 127, 128, 255] |
| | for a_val in neg_tests: |
| | expected_val = (-a_val) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | not_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.neg.not.bit{bit}.weight'].view(pop_size) |
| | b = pop[f'alu.alu8bit.neg.not.bit{bit}.bias'].view(pop_size) |
| | not_bit = heaviside(a_bits[bit] * w + b) |
| | not_bits.append(not_bit) |
| |
|
| | |
| | carry = 1.0 |
| | out_bits = [] |
| | for bit in range(7, -1, -1): |
| | w_or = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or.weight'].view(pop_size, 2) |
| | b_or = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.or.bias'].view(pop_size) |
| | w_nand = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand.weight'].view(pop_size, 2) |
| | b_nand = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer1.nand.bias'].view(pop_size) |
| | w2 = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer2.weight'].view(pop_size, 2) |
| | b2 = pop[f'alu.alu8bit.neg.inc.bit{bit}.xor.layer2.bias'].view(pop_size) |
| |
|
| | inp = torch.tensor([not_bits[bit][0].item(), carry], device=self.device) |
| | h_or = heaviside((inp * w_or).sum(-1) + b_or) |
| | h_nand = heaviside((inp * w_nand).sum(-1) + b_nand) |
| | hidden = torch.stack([h_or, h_nand], dim=-1) |
| | sum_bit = heaviside((hidden * w2).sum(-1) + b2) |
| | out_bits.insert(0, sum_bit) |
| |
|
| | w_carry = pop[f'alu.alu8bit.neg.inc.bit{bit}.carry.weight'].view(pop_size, 2) |
| | b_carry = pop[f'alu.alu8bit.neg.inc.bit{bit}.carry.bias'].view(pop_size) |
| | carry = heaviside((inp * w_carry).sum(-1) + b_carry)[0].item() |
| |
|
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.neg', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.neg: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | rol_tests = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00] |
| | for a_val in rol_tests: |
| | expected_val = ((a_val << 1) | (a_val >> 7)) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | out_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.rol.bit{bit}.weight'].view(pop_size) |
| | b = pop[f'alu.alu8bit.rol.bit{bit}.bias'].view(pop_size) |
| | |
| | src_bit = (bit + 1) % 8 |
| | out = heaviside(a_bits[src_bit] * w + b) |
| | out_bits.append(out) |
| |
|
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.rol', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.rol: SKIP ({e})") |
| |
|
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | ror_tests = [0b10000000, 0b00000001, 0b10101010, 0b01010101, 0xFF, 0x00] |
| | for a_val in ror_tests: |
| | expected_val = ((a_val >> 1) | (a_val << 7)) & 0xFF |
| | a_bits = torch.tensor([((a_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | out_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.ror.bit{bit}.weight'].view(pop_size) |
| | b = pop[f'alu.alu8bit.ror.bit{bit}.bias'].view(pop_size) |
| | |
| | src_bit = (bit - 1) % 8 |
| | out = heaviside(a_bits[src_bit] * w + b) |
| | out_bits.append(out) |
| |
|
| | out = torch.stack(out_bits, dim=-1) |
| | expected = torch.tensor([((expected_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | correct = (out == expected.unsqueeze(0)).float().sum(1) |
| | op_scores += correct |
| | op_total += 8 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('alu.alu8bit.ror', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" alu.alu8bit.ror: SKIP ({e})") |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_manifest(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Verify manifest values.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== MANIFEST ===") |
| |
|
| | fixed_expected = { |
| | 'manifest.alu_operations': 16.0, |
| | 'manifest.flags': 4.0, |
| | 'manifest.instruction_width': 16.0, |
| | 'manifest.register_width': 8.0, |
| | 'manifest.registers': 4.0, |
| | 'manifest.version': 4.0, |
| | } |
| |
|
| | for name, exp_val in fixed_expected.items(): |
| | try: |
| | val = pop[name][0, 0].item() |
| | if val == exp_val: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(exp_val, val)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | pass |
| |
|
| | variable_checks = ['manifest.memory_bytes', 'manifest.pc_width', 'manifest.turing_complete'] |
| | for name in variable_checks: |
| | try: |
| | val = pop[name][0, 0].item() |
| | valid = val >= 0 |
| | if valid: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [('>=0', val)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'} (value={val})") |
| | except KeyError: |
| | pass |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_memory(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test memory circuits (shape validation).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== MEMORY ===") |
| |
|
| | try: |
| | mem_bytes = int(pop['manifest.memory_bytes'][0].item()) |
| | addr_bits = int(pop['manifest.pc_width'][0].item()) |
| | except KeyError: |
| | mem_bytes = 65536 |
| | addr_bits = 16 |
| |
|
| | if mem_bytes == 0: |
| | if debug: |
| | print(" No memory (pure ALU mode)") |
| | return scores, 0 |
| |
|
| | expected_shapes = { |
| | 'memory.addr_decode.weight': (mem_bytes, addr_bits), |
| | 'memory.addr_decode.bias': (mem_bytes,), |
| | 'memory.read.and.weight': (8, mem_bytes, 2), |
| | 'memory.read.and.bias': (8, mem_bytes), |
| | 'memory.read.or.weight': (8, mem_bytes), |
| | 'memory.read.or.bias': (8,), |
| | 'memory.write.sel.weight': (mem_bytes, 2), |
| | 'memory.write.sel.bias': (mem_bytes,), |
| | 'memory.write.nsel.weight': (mem_bytes, 1), |
| | 'memory.write.nsel.bias': (mem_bytes,), |
| | 'memory.write.and_old.weight': (mem_bytes, 8, 2), |
| | 'memory.write.and_old.bias': (mem_bytes, 8), |
| | 'memory.write.and_new.weight': (mem_bytes, 8, 2), |
| | 'memory.write.and_new.bias': (mem_bytes, 8), |
| | 'memory.write.or.weight': (mem_bytes, 8, 2), |
| | 'memory.write.or.bias': (mem_bytes, 8), |
| | } |
| |
|
| | for name, expected_shape in expected_shapes.items(): |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| |
|
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | pass |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_float16_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float16 core circuits (unpack, pack, classify).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT16 CORE ===") |
| |
|
| | expected_gates = [ |
| | ('float16.unpack.bit0.weight', (1,)), |
| | ('float16.classify.exp_zero.weight', (5,)), |
| | ('float16.classify.exp_max.weight', (5,)), |
| | ('float16.classify.frac_zero.weight', (10,)), |
| | ('float16.classify.is_zero.and.weight', (2,)), |
| | ('float16.classify.is_nan.and.weight', (2,)), |
| | ('float16.normalize.stage0.bit0.not_sel.weight', (1,)), |
| | ('float16.normalize.stage0.bit0.and_a.weight', (2,)), |
| | ('float16.normalize.stage0.bit0.or.weight', (2,)), |
| | ('float16.pack.bit0.weight', (1,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float16_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float16 addition circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT16 ADD ===") |
| |
|
| | expected_gates = [ |
| | ('float16.add.exp_cmp.a_gt_b.weight', (10,)), |
| | ('float16.add.exp_cmp.a_lt_b.weight', (10,)), |
| | ('float16.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float16.add.align.stage0.bit0.not_sel.weight', (1,)), |
| | ('float16.add.sign_xor.layer1.or.weight', (2,)), |
| | ('float16.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float16.add.mant_sub.not_b.bit0.weight', (1,)), |
| | ('float16.add.mant_select.bit0.not_sel.weight', (1,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float16_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float16 multiplication circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT16 MUL ===") |
| |
|
| | expected_gates = [ |
| | ('float16.mul.sign_xor.layer1.or.weight', (2,)), |
| | ('float16.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float16.mul.bias_sub.not_bias.bit0.weight', (1,)), |
| | ('float16.mul.mant_mul.pp.a0b0.weight', (2,)), |
| | ('float16.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float16_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float16 division circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT16 DIV ===") |
| |
|
| | expected_gates = [ |
| | ('float16.div.sign_xor.layer1.or.weight', (2,)), |
| | ('float16.div.exp_sub.not_b.bit0.weight', (1,)), |
| | ('float16.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float16.div.mant_div.stage0.cmp.weight', (22,)), |
| | ('float16.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)), |
| | ('float16.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float16_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float16 comparison circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT16 CMP ===") |
| |
|
| | expected_gates = [ |
| | ('float16.cmp.a.exp_max.weight', (5,)), |
| | ('float16.cmp.a.frac_nz.weight', (10,)), |
| | ('float16.cmp.a.is_nan.weight', (2,)), |
| | ('float16.cmp.either_nan.weight', (2,)), |
| | ('float16.cmp.sign_xor.layer1.or.weight', (2,)), |
| | ('float16.cmp.both_zero.weight', (2,)), |
| | ('float16.cmp.mag_a_gt_b.weight', (30,)), |
| | ('float16.cmp.eq.result.weight', (2,)), |
| | ('float16.cmp.lt.result.weight', (3,)), |
| | ('float16.cmp.gt.result.weight', (3,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_float32_core(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float32 core circuits (unpack, pack, classify).""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT32 CORE ===") |
| |
|
| | expected_gates = [ |
| | ('float32.unpack.bit0.weight', (1,)), |
| | ('float32.classify.exp_zero.weight', (8,)), |
| | ('float32.classify.exp_max.weight', (8,)), |
| | ('float32.classify.frac_zero.weight', (23,)), |
| | ('float32.classify.is_zero.and.weight', (2,)), |
| | ('float32.classify.is_nan.and.weight', (2,)), |
| | ('float32.normalize.stage0.bit0.not_sel.weight', (1,)), |
| | ('float32.pack.bit0.weight', (1,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float32_add(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float32 addition circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT32 ADD ===") |
| |
|
| | expected_gates = [ |
| | ('float32.add.exp_cmp.a_gt_b.weight', (16,)), |
| | ('float32.add.exp_diff.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float32.add.align.stage0.bit0.not_sel.weight', (1,)), |
| | ('float32.add.sign_xor.layer1.or.weight', (2,)), |
| | ('float32.add.mant_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float32.add.mant_sub.not_b.bit0.weight', (1,)), |
| | ('float32.add.mant_select.bit0.not_sel.weight', (1,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float32_mul(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float32 multiplication circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT32 MUL ===") |
| |
|
| | expected_gates = [ |
| | ('float32.mul.sign_xor.layer1.or.weight', (2,)), |
| | ('float32.mul.exp_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float32.mul.bias_sub.not_bias.bit0.weight', (1,)), |
| | ('float32.mul.mant_mul.pp.a0b0.weight', (2,)), |
| | ('float32.mul.mant_mul.acc.s0.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float32_div(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float32 division circuit.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT32 DIV ===") |
| |
|
| | expected_gates = [ |
| | ('float32.div.sign_xor.layer1.or.weight', (2,)), |
| | ('float32.div.exp_sub.not_b.bit0.weight', (1,)), |
| | ('float32.div.bias_add.fa0.ha1.sum.layer1.or.weight', (2,)), |
| | ('float32.div.mant_div.stage0.cmp.weight', (48,)), |
| | ('float32.div.mant_div.stage0.sub.not_d.bit0.weight', (1,)), |
| | ('float32.div.mant_div.stage0.mux.bit0.not_sel.weight', (1,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | def _test_float32_cmp(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test float32 comparison circuits.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== FLOAT32 CMP ===") |
| |
|
| | expected_gates = [ |
| | ('float32.cmp.a.exp_max.weight', (8,)), |
| | ('float32.cmp.a.frac_nz.weight', (23,)), |
| | ('float32.cmp.a.is_nan.weight', (2,)), |
| | ('float32.cmp.either_nan.weight', (2,)), |
| | ('float32.cmp.sign_xor.layer1.or.weight', (2,)), |
| | ('float32.cmp.both_zero.weight', (2,)), |
| | ('float32.cmp.mag_a_gt_b.weight', (62,)), |
| | ('float32.cmp.eq.result.weight', (2,)), |
| | ('float32.cmp.lt.result.weight', (3,)), |
| | ('float32.cmp.gt.result.weight', (3,)), |
| | ] |
| |
|
| | for name, expected_shape in expected_gates: |
| | try: |
| | tensor = pop[name] |
| | actual_shape = tuple(tensor.shape[1:]) |
| | if actual_shape == expected_shape: |
| | scores += 1 |
| | self._record(name, 1, 1, []) |
| | else: |
| | self._record(name, 0, 1, [(expected_shape, actual_shape)]) |
| | total += 1 |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except KeyError: |
| | if debug: |
| | print(f" {name}: SKIP (not found)") |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def _test_integration(self, pop: Dict, debug: bool) -> Tuple[torch.Tensor, int]: |
| | """Test complex operations that chain multiple circuit families.""" |
| | pop_size = next(iter(pop.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total = 0 |
| |
|
| | if debug: |
| | print("\n=== INTEGRATION TESTS ===") |
| |
|
| | |
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | tests = [(10, 20, 25), (100, 50, 200), (255, 1, 0), (0, 0, 1)] |
| | for a, b, c in tests: |
| | sum_val = (a + b) & 0xFF |
| | expected = float(sum_val > c) |
| |
|
| | |
| | sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | c_bits = torch.tensor([((c >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | w = pop['arithmetic.greaterthan8bit.weight'].view(pop_size, 16) |
| | bias = pop['arithmetic.greaterthan8bit.bias'].view(pop_size) |
| | inp = torch.cat([sum_bits, c_bits]) |
| | out = heaviside((inp * w).sum(-1) + bias) |
| | correct = (out == expected).float() |
| | op_scores += correct |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('integration.add_then_compare', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" integration.add_then_compare: SKIP ({e})") |
| |
|
| | |
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | tests = [(3, 5), (4, 6), (7, 11), (9, 9)] |
| | for a, b in tests: |
| | product = (a * b) & 0xFF |
| | expected_mod3 = product % 3 |
| |
|
| | |
| | prod_bits = torch.tensor([((product >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | |
| | w1 = pop['modular.mod3.layer1.weight'].view(pop_size, 8) |
| | b1 = pop['modular.mod3.layer1.bias'].view(pop_size) |
| | h1 = heaviside((prod_bits * w1).sum(-1) + b1) |
| |
|
| | w2 = pop['modular.mod3.layer2.weight'].view(pop_size, 8) |
| | b2 = pop['modular.mod3.layer2.bias'].view(pop_size) |
| | h2 = heaviside((prod_bits * w2).sum(-1) + b2) |
| |
|
| | |
| | op_scores += 1 |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('integration.mul_then_mod', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" integration.mul_then_mod: SKIP ({e})") |
| |
|
| | |
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | tests = [(0b10101010, 0b11110000), (0b00001111, 0b01010101), (0xFF, 0x0F)] |
| | for a, b in tests: |
| | shifted_a = (a << 1) & 0xFF |
| | expected = shifted_a & b |
| |
|
| | a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | shifted_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size) |
| | bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size) |
| | if bit < 7: |
| | inp = a_bits[bit + 1] |
| | else: |
| | inp = torch.tensor(0.0, device=self.device) |
| | out = heaviside(inp * w + bias) |
| | shifted_bits.append(out) |
| |
|
| | |
| | and_bits = [] |
| | w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2) |
| | b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8) |
| | for bit in range(8): |
| | inp = torch.tensor([shifted_bits[bit][0].item(), b_bits[bit].item()], |
| | device=self.device) |
| | out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit]) |
| | and_bits.append(out) |
| |
|
| | out_val = sum(int(and_bits[i][0].item()) << (7 - i) for i in range(8)) |
| | correct = (out_val == expected) |
| | op_scores += float(correct) |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('integration.shift_then_and', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" integration.shift_then_and: SKIP ({e})") |
| |
|
| | |
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | tests = [(50, 30), (30, 50), (100, 100), (0, 1)] |
| | for a, b in tests: |
| | diff = (a - b) & 0xFF |
| | is_negative = a < b |
| | expected = (-diff & 0xFF) if is_negative else diff |
| |
|
| | |
| | |
| | a_bits = torch.tensor([((a >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | b_bits = torch.tensor([((b >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | w = pop['arithmetic.lessthan8bit.weight'].view(pop_size, 16) |
| | bias = pop['arithmetic.lessthan8bit.bias'].view(pop_size) |
| | inp = torch.cat([a_bits, b_bits]) |
| | lt_out = heaviside((inp * w).sum(-1) + bias) |
| |
|
| | correct = (lt_out[0].item() == float(is_negative)) |
| | op_scores += float(correct) |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('integration.sub_then_conditional', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" integration.sub_then_conditional: SKIP ({e})") |
| |
|
| | |
| | |
| | try: |
| | op_scores = torch.zeros(pop_size, device=self.device) |
| | op_total = 0 |
| |
|
| | tests = [(10, 20), (50, 50), (127, 1), (0, 0)] |
| | for a, b in tests: |
| | sum_val = (a + b) & 0xFF |
| | doubled = (sum_val << 1) & 0xFF |
| | expected = doubled & 0xF0 |
| |
|
| | sum_bits = torch.tensor([((sum_val >> (7 - i)) & 1) for i in range(8)], |
| | device=self.device, dtype=torch.float32) |
| | mask_bits = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], |
| | device=self.device, dtype=torch.float32) |
| |
|
| | |
| | shifted_bits = [] |
| | for bit in range(8): |
| | w = pop[f'alu.alu8bit.shl.bit{bit}.weight'].view(pop_size) |
| | bias = pop[f'alu.alu8bit.shl.bit{bit}.bias'].view(pop_size) |
| | if bit < 7: |
| | inp = sum_bits[bit + 1] |
| | else: |
| | inp = torch.tensor(0.0, device=self.device) |
| | out = heaviside(inp * w + bias) |
| | shifted_bits.append(out) |
| |
|
| | |
| | w_and = pop['alu.alu8bit.and.weight'].view(pop_size, 8, 2) |
| | b_and = pop['alu.alu8bit.and.bias'].view(pop_size, 8) |
| | result_bits = [] |
| | for bit in range(8): |
| | inp = torch.tensor([shifted_bits[bit][0].item(), mask_bits[bit].item()], |
| | device=self.device) |
| | out = heaviside((inp * w_and[:, bit]).sum(-1) + b_and[:, bit]) |
| | result_bits.append(out) |
| |
|
| | out_val = sum(int(result_bits[i][0].item()) << (7 - i) for i in range(8)) |
| | correct = (out_val == expected) |
| | op_scores += float(correct) |
| | op_total += 1 |
| |
|
| | scores += op_scores |
| | total += op_total |
| | self._record('integration.complex_expr', int(op_scores[0].item()), op_total, []) |
| | if debug: |
| | r = self.results[-1] |
| | print(f" {r.name}: {r.passed}/{r.total} {'PASS' if r.success else 'FAIL'}") |
| | except (KeyError, RuntimeError) as e: |
| | if debug: |
| | print(f" integration.complex_expr: SKIP ({e})") |
| |
|
| | return scores, total |
| |
|
| | |
| | |
| | |
| |
|
| | def evaluate(self, population: Dict[str, torch.Tensor], debug: bool = False) -> torch.Tensor: |
| | """ |
| | Evaluate population fitness with per-circuit reporting. |
| | |
| | Args: |
| | population: Dict of tensors, each with shape [pop_size, ...] |
| | debug: If True, print per-circuit results |
| | |
| | Returns: |
| | Tensor of fitness scores [pop_size], normalized to [0, 1] |
| | """ |
| | self.results = [] |
| | self.category_scores = {} |
| |
|
| | pop_size = next(iter(population.values())).shape[0] |
| | scores = torch.zeros(pop_size, device=self.device) |
| | total_tests = 0 |
| |
|
| | |
| | s, t = self._test_boolean_gates(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['boolean'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_halfadder(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['halfadder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_fulladder(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['fulladder'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | for bits in [2, 4, 8]: |
| | s, t = self._test_ripplecarry(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | for bits in [16, 32]: |
| | if f'arithmetic.ripplecarry{bits}bit.fa0.ha1.sum.layer1.or.weight' in population: |
| | if debug: |
| | print(f"\n{'=' * 60}") |
| | print(f" {bits}-BIT CIRCUITS") |
| | print(f"{'=' * 60}") |
| |
|
| | s, t = self._test_ripplecarry(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'ripplecarry{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | s, t = self._test_comparators_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'comparators{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'arithmetic.sub{bits}bit.not_b.bit0.weight' in population: |
| | s, t = self._test_subtractor_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'subtractor{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'alu.alu{bits}bit.and.bit0.weight' in population: |
| | s, t = self._test_bitwise_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'bitwise{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'alu.alu{bits}bit.shl.bit0.weight' in population: |
| | s, t = self._test_shifts_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'shifts{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'alu.alu{bits}bit.inc.bit0.xor.layer1.or.weight' in population: |
| | s, t = self._test_inc_dec_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'incdec{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'alu.alu{bits}bit.neg.not.bit0.weight' in population: |
| | s, t = self._test_neg_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'neg{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'combinational.barrelshifter{bits}.layer0.bit0.not_sel.weight' in population: |
| | s, t = self._test_barrel_shifter_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'barrelshifter{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if f'combinational.priorityencoder{bits}.valid.weight' in population: |
| | s, t = self._test_priority_encoder_nbits(population, bits, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores[f'priorityencoder{bits}'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_add3(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['add3'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_expr_add_mul(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['expr_add_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_comparators(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['comparators'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_threshold_gates(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['threshold'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_modular_all(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['modular'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_patterns(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['patterns'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_error_detection(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['error_detection'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_combinational(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['combinational'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_control_flow(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['control'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_alu_ops(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['alu'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_manifest(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['manifest'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | s, t = self._test_memory(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['memory'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | if 'float16.unpack.bit0.weight' in population: |
| | if debug: |
| | print(f"\n{'=' * 60}") |
| | print(f" FLOAT16 CIRCUITS") |
| | print(f"{'=' * 60}") |
| |
|
| | s, t = self._test_float16_core(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float16_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float16.add.exp_cmp.a_gt_b.weight' in population: |
| | s, t = self._test_float16_add(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float16_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float16.mul.sign_xor.layer1.or.weight' in population: |
| | s, t = self._test_float16_mul(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float16_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float16.div.sign_xor.layer1.or.weight' in population: |
| | s, t = self._test_float16_div(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float16_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float16.cmp.a.exp_max.weight' in population: |
| | s, t = self._test_float16_cmp(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float16_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | |
| | if 'float32.unpack.bit0.weight' in population: |
| | if debug: |
| | print(f"\n{'=' * 60}") |
| | print(f" FLOAT32 CIRCUITS") |
| | print(f"{'=' * 60}") |
| |
|
| | s, t = self._test_float32_core(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float32_core'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float32.add.exp_cmp.a_gt_b.weight' in population: |
| | s, t = self._test_float32_add(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float32_add'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float32.mul.sign_xor.layer1.or.weight' in population: |
| | s, t = self._test_float32_mul(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float32_mul'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float32.div.sign_xor.layer1.or.weight' in population: |
| | s, t = self._test_float32_div(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float32_div'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | if 'float32.cmp.a.exp_max.weight' in population: |
| | s, t = self._test_float32_cmp(population, debug) |
| | scores += s |
| | total_tests += t |
| | self.category_scores['float32_cmp'] = (s[0].item() if pop_size == 1 else s.mean().item(), t) |
| |
|
| | self.total_tests = total_tests |
| |
|
| | if debug: |
| | print("\n" + "=" * 60) |
| | print("CATEGORY SUMMARY") |
| | print("=" * 60) |
| | for cat, (got, expected) in sorted(self.category_scores.items()): |
| | pct = 100 * got / expected if expected > 0 else 0 |
| | status = "PASS" if got == expected else "FAIL" |
| | print(f" {cat:20} {int(got):6}/{expected:6} ({pct:6.2f}%) [{status}]") |
| |
|
| | print("\n" + "=" * 60) |
| | print("CIRCUIT FAILURES") |
| | print("=" * 60) |
| | failed = [r for r in self.results if not r.success] |
| | if failed: |
| | for r in failed[:20]: |
| | print(f" {r.name}: {r.passed}/{r.total}") |
| | if r.failures: |
| | print(f" First failure: {r.failures[0]}") |
| | if len(failed) > 20: |
| | print(f" ... and {len(failed) - 20} more") |
| | else: |
| | print(" None!") |
| |
|
| | return scores / total_tests if total_tests > 0 else scores |
| |
|
| |
|
def main():
    """CLI entry point: load the model, evaluate it, and report results.

    Returns:
        Process exit code: 0 when fitness is >= 99.99%, 1 otherwise
        (or the smoke-test's return code when --cpu-test is given).
    """
    parser = argparse.ArgumentParser(description='Unified Evaluation Suite for 8-bit Threshold Computer')
    parser.add_argument('--model', type=str, default=MODEL_PATH, help='Path to safetensors model')
    parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
    parser.add_argument('--pop_size', type=int, default=1, help='Population size for batched evaluation')
    parser.add_argument('--quiet', action='store_true', help='Suppress detailed output')
    parser.add_argument('--cpu-test', action='store_true', help='Run CPU smoke test (LOAD, ADD, STORE, HALT)')
    args = parser.parse_args()

    if args.cpu_test:
        return run_smoke_test()

    # The default device is 'cuda'; fall back gracefully instead of crashing
    # on machines without a usable CUDA runtime.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("WARNING: CUDA not available, falling back to CPU")
        args.device = 'cpu'

    print("=" * 70)
    print(" UNIFIED EVALUATION SUITE")
    print("=" * 70)

    print(f"\nLoading model from {args.model}...")
    model = load_model(args.model)
    print(f" Loaded {len(model)} tensors, {sum(t.numel() for t in model.values()):,} params")

    print(f"\nInitializing evaluator on {args.device}...")
    evaluator = BatchedFitnessEvaluator(device=args.device, model_path=args.model)

    print(f"\nCreating population (size {args.pop_size})...")
    population = create_population(model, pop_size=args.pop_size, device=args.device)

    print("\nRunning evaluation...")
    if args.device == 'cuda':
        # Drain any pending GPU work so it doesn't pollute the timing window.
        torch.cuda.synchronize()
    start = time.perf_counter()

    fitness = evaluator.evaluate(population, debug=not args.quiet)

    if args.device == 'cuda':
        # Kernel launches are async; wait for completion before stopping the clock.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    if args.pop_size == 1:
        print(f" Fitness: {fitness[0].item():.6f}")
    else:
        print(f" Mean Fitness: {fitness.mean().item():.6f}")
        print(f" Min Fitness: {fitness.min().item():.6f}")
        print(f" Max Fitness: {fitness.max().item():.6f}")

    print(f" Total tests: {evaluator.total_tests}")
    print(f" Time: {elapsed * 1000:.2f} ms")

    if args.pop_size > 1:
        print(f" Throughput: {args.pop_size / elapsed:.0f} evals/sec")
        perfect = (fitness >= 0.9999).sum().item()
        print(f" Perfect (>=99.99%): {perfect}/{args.pop_size}")

    if fitness[0].item() >= 0.9999:
        print("\n STATUS: PASS")
        return 0

    # round(), not int(): truncation under-reports, e.g. 1 failure in 1000
    # tests gives (1 - 0.999) * 1000 = 0.999... which int() turns into 0.
    failed_count = round((1 - fitness[0].item()) * evaluator.total_tests)
    print(f"\n STATUS: FAIL ({failed_count} tests failed)")
    return 1
| |
|
| |
|
if __name__ == '__main__':
    # raise SystemExit rather than the site-module `exit()` helper, which is
    # intended for interactive sessions and may be absent when run with -S.
    raise SystemExit(main())
| |
|