| | |
| | |
| | """ |
| | Qwen3-Embedding-0.6B 推理测试代码 |
| | 使用 RKLLM API 进行文本嵌入推理 |
| | """ |
| | import faulthandler |
| | faulthandler.enable() |
| | import os |
| | os.environ["RKLLM_LOG_LEVEL"] = "1" |
| | import numpy as np |
| | import time |
| | from typing import List, Dict, Any |
| | from rkllm_binding import * |
| |
|
| |
|
class Qwen3EmbeddingTester:
    """Smoke-test harness for a Qwen3 embedding model served by the RKLLM runtime.

    The tester loads a ``.rkllm`` model through ``RKLLMRuntime``, requests the
    last hidden layer for a prompt, keeps the hidden state of the final token
    as the text embedding, and offers a few demo comparisons (query/document,
    multilingual, source code).
    """

    def __init__(self, model_path: str, library_path: str = "./librkllmrt.so"):
        """Store configuration; the runtime itself is created by ``init_model``.

        Args:
            model_path: Path to the model file (``.rkllm`` format).
            library_path: Path to the RKLLM runtime shared library.
        """
        self.model_path = model_path
        self.library_path = library_path
        self.runtime = None
        # Not used by any method in this class; kept for external callers.
        self.embeddings_buffer = []
        # Filled by callback_function during a run; reset before every run.
        self.current_result = None

    def callback_function(self, result_ptr, userdata_ptr, state_enum):
        """Inference callback: capture the last token's hidden state.

        On a normal-run callback that carries hidden states, the vector of the
        final token is copied into ``self.current_result`` as a dict with the
        keys ``'embedding'``, ``'embd_size'`` and ``'num_tokens'``.

        Args:
            result_ptr: Pointer to the runtime's result structure.
            userdata_ptr: User-data pointer supplied by the runtime (unused).
            state_enum: Raw call-state value, converted to ``LLMCallState``.
        """
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_ERROR:
            print("推理过程发生错误")
            return
        if state != LLMCallState.RKLLM_RUN_NORMAL:
            return

        result = result_ptr.contents
        print(f"result: {result}")

        layer = result.last_hidden_layer
        if not layer.hidden_states or layer.embd_size <= 0:
            return

        embd_size = layer.embd_size
        num_tokens = layer.num_tokens
        print(f"获取到嵌入向量:维度={embd_size}, 令牌数={num_tokens}")

        if num_tokens <= 0:
            return

        # The buffer is indexed as num_tokens consecutive rows of embd_size
        # values; take the final row. Slicing the ctypes pointer yields a
        # Python list, so np.array() makes an independent copy of the data.
        start = (num_tokens - 1) * embd_size
        last_token_embedding = np.array(
            layer.hidden_states[start:start + embd_size]
        )

        self.current_result = {
            'embedding': last_token_embedding,
            'embd_size': embd_size,
            'num_tokens': num_tokens,
        }

        print(f"嵌入向量范数: {np.linalg.norm(last_token_embedding):.4f}")
        print(f"嵌入向量前10维: {last_token_embedding[:10]}")

    def init_model(self):
        """Create the RKLLM runtime and load the model.

        Raises:
            Exception: Whatever the runtime raises during setup; the error is
                printed and then re-raised unchanged.
        """
        try:
            print(f"初始化 RKLLM 运行时,库路径: {self.library_path}")
            self.runtime = RKLLMRuntime(self.library_path)

            print("创建默认参数...")
            params = self.runtime.create_default_param()

            # Basic generation parameters.
            params.model_path = self.model_path.encode('utf-8')
            params.max_context_len = 1024
            params.max_new_tokens = 1
            params.temperature = 1.0
            params.top_k = 1
            params.top_p = 1.0

            # Extended runtime parameters.
            # NOTE(review): 0x0F has the four lowest bits set, matching
            # enabled_cpus_num = 4 -- presumably CPU cores 0-3; confirm
            # against the RKLLM headers for the target board.
            params.extend_param.base_domain_id = 1
            params.extend_param.embed_flash = 0
            params.extend_param.enabled_cpus_num = 4
            params.extend_param.enabled_cpus_mask = 0x0F

            print(f"初始化模型: {self.model_path}")
            self.runtime.init(params, self.callback_function)
            # All three template strings are empty -- presumably so the prompt
            # reaches the model without any chat wrapping; TODO confirm.
            self.runtime.set_chat_template("", "", "")
            print("模型初始化成功!")

        except Exception as e:
            print(f"模型初始化失败: {e}")
            raise

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Build the instruction-style prompt used for queries.

        Args:
            task_description: Description of the retrieval task.
            query: The query text.

        Returns:
            The string ``'Instruct: <task>\\nQuery: <query>'``.
        """
        return f'Instruct: {task_description}\nQuery: {query}'

    def encode_text(self, text: str, task_description: str = None) -> np.ndarray:
        """Encode ``text`` into an L2-normalised embedding vector.

        Args:
            text: Text to encode.
            task_description: Optional task description; when given (and
                non-empty) the text is wrapped by ``get_detailed_instruct``.

        Returns:
            The embedding of the last token, scaled to unit length.

        Raises:
            RuntimeError: If the model has not been initialised, if the
                callback delivered no embedding, or if the embedding has a
                zero or non-finite norm and therefore cannot be normalised.
        """
        try:
            if self.runtime is None:
                raise RuntimeError("模型尚未初始化,请先调用 init_model()")

            if task_description:
                input_text = self.get_detailed_instruct(task_description, text)
            else:
                input_text = text

            print(f"编码文本: {input_text[:100]}{'...' if len(input_text) > 100 else ''}")

            # Prepare the prompt input. The encoded bytes stay referenced by
            # a local name until run() has returned.
            rk_input = RKLLMInput()
            rk_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
            c_prompt = input_text.encode('utf-8')
            rk_input._union_data.prompt_input = c_prompt

            # Ask for the last hidden layer instead of generated text.
            infer_params = RKLLMInferParam()
            infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
            infer_params.keep_history = 0

            # Drop any result and cached context from a previous run so a
            # failed run cannot hand back a stale embedding.
            self.current_result = None
            self.runtime.clear_kv_cache(False)

            start_time = time.time()
            self.runtime.run(rk_input, infer_params)
            end_time = time.time()

            print(f"推理耗时: {end_time - start_time:.3f}秒")

            result = self.current_result
            if not result or 'embedding' not in result:
                raise RuntimeError("未能获取到有效的嵌入向量")

            embedding = result['embedding']
            norm = np.linalg.norm(embedding)
            # Dividing by a zero/NaN/inf norm would silently produce a
            # useless vector; fail loudly instead.
            if not np.isfinite(norm) or norm == 0:
                raise RuntimeError("嵌入向量范数无效,无法归一化")
            return embedding / norm

        except Exception as e:
            print(f"编码文本时发生错误: {e}")
            raise

    def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """Return the cosine similarity of two 1-D embedding vectors.

        The dot product is divided by the product of the norms, so the result
        is correct even when the inputs are not unit length. For vectors
        returned by ``encode_text`` (already normalised) this equals the plain
        dot product up to floating-point rounding.

        Args:
            emb1: First embedding vector.
            emb2: Second embedding vector.

        Returns:
            Cosine similarity as a Python float; ``0.0`` if either vector has
            zero norm.
        """
        denominator = float(np.linalg.norm(emb1) * np.linalg.norm(emb2))
        if denominator == 0.0:
            return 0.0
        return float(np.dot(emb1, emb2) / denominator)

    def _print_pairwise_similarities(self, embeddings: dict) -> None:
        """Print the similarity of every unordered pair, self-pairs included."""
        names = list(embeddings)
        for index, first in enumerate(names):
            for second in names[index:]:
                sim = self.compute_similarity(embeddings[first], embeddings[second])
                print(f"{first} vs {second}: {sim:.4f}")

    def test_embedding_similarity(self):
        """Embed sample queries and documents and print their similarities.

        Queries are encoded with a task instruction, documents without one.

        Returns:
            A list with one row per query; each row holds the similarity of
            that query to every document.
        """
        print("\n" + "="*50)
        print("测试嵌入相似度计算")
        print("="*50)

        task_description = "Given a web search query, retrieve relevant passages that answer the query"

        queries = [
            "What is the capital of China?",
            "Explain gravity",
        ]

        documents = [
            "The capital of China is Beijing.",
            "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
        ]

        print("\n编码查询文本:")
        query_embeddings = []
        for i, query in enumerate(queries):
            print(f"\n查询 {i+1}: {query}")
            query_embeddings.append(self.encode_text(query, task_description))

        print("\n编码文档文本:")
        doc_embeddings = []
        for i, doc in enumerate(documents):
            print(f"\n文档 {i+1}: {doc}")
            doc_embeddings.append(self.encode_text(doc))

        print("\n计算相似度矩阵:")
        print("查询 vs 文档相似度:")
        print("-" * 30)

        similarities = []
        for i, q_emb in enumerate(query_embeddings):
            row_similarities = []
            for j, d_emb in enumerate(doc_embeddings):
                sim = self.compute_similarity(q_emb, d_emb)
                row_similarities.append(sim)
                print(f"查询{i+1} vs 文档{j+1}: {sim:.4f}")
            similarities.append(row_similarities)
            print()

        return similarities

    def test_multilingual_embedding(self):
        """Embed the same greeting in several languages and compare them."""
        print("\n" + "="*50)
        print("测试多语言嵌入能力")
        print("="*50)

        texts = {
            "英语": "Hello, how are you?",
            "中文": "你好,你好吗?",
            "法语": "Bonjour, comment allez-vous?",
            "西班牙语": "Hola, ¿cómo estás?",
            "日语": "こんにちは、元気ですか?",
        }

        embeddings = {}
        print("\n编码多语言文本:")
        for lang, text in texts.items():
            print(f"\n{lang}: {text}")
            embeddings[lang] = self.encode_text(text)

        print("\n跨语言相似度:")
        print("-" * 30)
        self._print_pairwise_similarities(embeddings)

    def test_code_embedding(self):
        """Embed a few short source-code snippets and compare them."""
        print("\n" + "="*50)
        print("测试代码嵌入能力")
        print("="*50)

        codes = {
            "Python函数": (
                "\n"
                "def fibonacci(n):\n"
                "    if n <= 1:\n"
                "        return n\n"
                "    return fibonacci(n-1) + fibonacci(n-2)\n"
            ),
            "JavaScript函数": (
                "\n"
                "function fibonacci(n) {\n"
                "    if (n <= 1) return n;\n"
                "    return fibonacci(n-1) + fibonacci(n-2);\n"
                "}\n"
            ),
            "C++函数": (
                "\n"
                "int fibonacci(int n) {\n"
                "    if (n <= 1) return n;\n"
                "    return fibonacci(n-1) + fibonacci(n-2);\n"
                "}\n"
            ),
            "数组排序": (
                "\n"
                "def bubble_sort(arr):\n"
                "    n = len(arr)\n"
                "    for i in range(n):\n"
                "        for j in range(0, n-i-1):\n"
                "            if arr[j] > arr[j+1]:\n"
                "                arr[j], arr[j+1] = arr[j+1], arr[j]\n"
            ),
        }

        embeddings = {}
        print("\n编码代码文本:")
        for name, code in codes.items():
            print(f"\n{name}:")
            # Show at most the first 100 characters of each snippet.
            preview = code[:100] + "..." if len(code) > 100 else code
            print(preview)
            embeddings[name] = self.encode_text(code)

        print("\n代码相似度:")
        print("-" * 30)
        self._print_pairwise_similarities(embeddings)

    def cleanup(self):
        """Destroy the runtime, if any; safe to call more than once.

        Errors raised by ``destroy()`` are printed rather than propagated so
        that cleanup in a ``finally`` block never masks the original error.
        """
        if self.runtime is None:
            return
        try:
            self.runtime.destroy()
            print("模型资源已清理")
        except Exception as e:
            print(f"清理资源时发生错误: {e}")
        finally:
            # Forget the handle so a repeated cleanup() cannot destroy twice.
            self.runtime = None
| |
|
def main():
    """Command-line entry point: run all embedding demos against a model.

    Parses the model path and the optional runtime-library path from the
    command line, checks that both paths exist, then runs the
    query/document, multilingual and code embedding tests in that order.

    Returns:
        Process exit status: ``0`` on success, ``1`` if a required file is
        missing or a test raised, ``130`` if interrupted with Ctrl-C.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description='Qwen3-Embedding-0.6B 推理测试')
    parser.add_argument('model_path', help='模型文件路径(.rkllm格式)')
    parser.add_argument('--library_path', default="./librkllmrt.so", help='RKLLM库文件路径(默认为./librkllmrt.so)')
    args = parser.parse_args()

    # Fail early, with setup hints, when a required file is missing.
    if not os.path.exists(args.model_path):
        print(f"错误: 模型文件不存在: {args.model_path}")
        print("请确保:")
        print("1. 已下载 Qwen3-Embedding-0.6B 模型")
        print("2. 已使用 rkllm-convert.py 将模型转换为 .rkllm 格式")
        return 1

    if not os.path.exists(args.library_path):
        print(f"错误: RKLLM 库文件不存在: {args.library_path}")
        print("请确保 librkllmrt.so 在当前目录或 LD_LIBRARY_PATH 中")
        return 1

    print("Qwen3-Embedding-0.6B 推理测试")
    print("=" * 50)

    tester = Qwen3EmbeddingTester(args.model_path, args.library_path)
    exit_code = 0

    try:
        tester.init_model()

        print("\n开始运行嵌入测试...")

        tester.test_embedding_similarity()
        tester.test_multilingual_embedding()
        tester.test_code_embedding()

        print("\n" + "="*50)
        print("所有测试完成!")
        print("="*50)

    except KeyboardInterrupt:
        print("\n测试被用户中断")
        exit_code = 130
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        traceback.print_exc()
        exit_code = 1
    finally:
        # Always release the runtime, including after a failure or Ctrl-C.
        tester.cleanup()

    return exit_code


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect failures.
    raise SystemExit(main())
| |
|