| | """Standalone CodeLoader for loading and processing GitHub repositories.""" |
| |
|
| | import logging |
| | import os |
| | import shutil |
| | from pathlib import Path |
| | from typing import Callable |
| |
|
| | import git |
| | import nbconvert |
| | import nbformat |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | class CodeLoader: |
| | """Load and process GitHub repositories for code analysis.""" |
| |
|
| | def __init__( |
| | self, |
| | github_url: str, |
| | max_file_size_mb: float = 1.0, |
| | raw_repo_dir: str | Path = "data/repos-raw", |
| | ): |
| | logger.info( |
| | f"Initializing CodeLoader for {github_url} with max file size " |
| | f"{max_file_size_mb} MB and raw repo dir {raw_repo_dir}" |
| | ) |
| | self.github_url = github_url |
| | self.max_file_size_mb = max_file_size_mb |
| | self.raw_repo_dir = Path(raw_repo_dir) |
| | self.repo_path = self.raw_repo_dir / self.github_url_to_repo_name |
| |
|
| | self.clone_repo() |
| | self.files = self._get_files() |
| |
|
| | @property |
| | def github_url_to_repo_name(self): |
| | """Convert GitHub URL to a safe directory name.""" |
| | base_name = ( |
| | self.github_url.rstrip("/").split("/")[-2] |
| | + "__" |
| | + self.github_url.rstrip("/").split("/")[-1] |
| | ) |
| | |
| | if base_name.endswith(".git"): |
| | base_name = base_name[:-4] |
| | return base_name |
| |
|
| | def clone_repo(self): |
| | """Clone or validate existing repository.""" |
| | if self.repo_path.exists(): |
| | logger.info(f"Repository already exists at {self.repo_path}") |
| |
|
| | |
| | try: |
| | repo = git.Repo(self.repo_path) |
| | |
| | try: |
| | _ = repo.head.commit.hexsha |
| | except (ValueError, git.BadName) as e: |
| | logger.warning( |
| | f"Repository has missing or corrupted commits at " |
| | f"{self.repo_path}, removing and re-cloning. Error: {e}" |
| | ) |
| | shutil.rmtree(self.repo_path) |
| | self.clone_repo() |
| | return |
| |
|
| | logger.info("Repository already exists and is valid") |
| | return |
| |
|
| | except (git.InvalidGitRepositoryError, git.GitCommandError) as e: |
| | logger.warning( |
| | f"Invalid or corrupted git repository at {self.repo_path}, " |
| | f"removing and re-cloning. Error: {e}" |
| | ) |
| | shutil.rmtree(self.repo_path) |
| | self.clone_repo() |
| | return |
| |
|
| | |
| | logger.info(f"Cloning repo {self.github_url} to {self.repo_path}") |
| | self.raw_repo_dir.mkdir(parents=True, exist_ok=True) |
| | repo = git.Repo.clone_from(self.github_url, str(self.repo_path)) |
| |
|
| | |
| | self._cleanup_repo() |
| |
|
| | def _cleanup_repo(self): |
| | """Remove docs/test directories, convert notebooks, and remove large files.""" |
| | |
| | for root, dirs, _ in os.walk(self.repo_path): |
| | |
| | if ".git" in dirs: |
| | dirs.remove(".git") |
| |
|
| | |
| | dirs_to_remove = [ |
| | dir |
| | for dir in dirs |
| | if dir in ["docs", "doc", "test", "tests", "example", "examples"] |
| | ] |
| | for dir in dirs_to_remove: |
| | dir_path = Path(root) / dir |
| | logger.info(f"Removing directory: {dir_path}") |
| | shutil.rmtree(dir_path) |
| | dirs.remove(dir) |
| |
|
| | |
| | for root, dirs, files in os.walk(self.repo_path): |
| | |
| | if ".git" in dirs: |
| | dirs.remove(".git") |
| |
|
| | for file in files: |
| | if file.endswith(".ipynb"): |
| | logger.info(f"Converting Jupyter Notebook {file} to .py") |
| | try: |
| | nb = nbformat.read(Path(root) / file, as_version=4) |
| | |
| | for cell in nb.cells: |
| | if cell.get("cell_type") == "code": |
| | cell["outputs"] = [] |
| | cell["execution_count"] = None |
| |
|
| | |
| | exporter = nbconvert.PythonExporter() |
| | source, _ = exporter.from_notebook_node(nb) |
| | source = ( |
| | "# This file was converted from a jupyter notebook " |
| | f"called {file}. All outputs have been removed.\n{source}" |
| | ) |
| | with open(Path(root) / file.replace(".ipynb", ".py"), "w") as f: |
| | f.write(source) |
| | |
| | os.remove(Path(root) / file) |
| | except Exception as e: |
| | logger.warning(f"Failed to convert notebook {file}: {e}") |
| | raise e |
| |
|
| | |
| | for root, dirs, files in os.walk(self.repo_path): |
| | |
| | if ".git" in dirs: |
| | dirs.remove(".git") |
| |
|
| | for file in files: |
| | file_path = Path(root) / file |
| | try: |
| | file_size = file_path.stat().st_size |
| | except FileNotFoundError as e: |
| | logger.warning(f"Failed to get size of {file_path}: {e}") |
| | continue |
| | if file_size > self.mb_to_bytes(self.max_file_size_mb): |
| | logger.info(f"Removing large file: {file_path}") |
| | os.remove(file_path) |
| |
|
| | def _get_files(self): |
| | """Get all files from the repository.""" |
| | files = {} |
| | for root, _, _files in os.walk(self.repo_path): |
| | for file in _files: |
| | file_path = Path(root) / file |
| | if ".git" in str(file_path): |
| | continue |
| |
|
| | |
| | file_path_key = file_path.relative_to(self.repo_path) |
| |
|
| | try: |
| | with open(file_path, "r", encoding="utf-8", errors="ignore") as f: |
| | content = f.read() |
| | files[str(file_path_key)] = content |
| | except Exception as e: |
| | logger.warning(f"Could not read {file_path}: {e}") |
| |
|
| | |
| | files = dict(sorted(files.items())) |
| | return files |
| |
|
| | @staticmethod |
| | def mb_to_bytes(mb: float) -> int: |
| | """Convert megabytes to bytes.""" |
| | return int(mb * 1024 * 1024) |
| |
|
| | def get_files_by_extension( |
| | self, extensions: list[str] | None = None |
| | ) -> dict[str, str]: |
| | """Get files filtered by extension.""" |
| | if extensions is None: |
| | |
| | extensions = [ |
| | ".c", |
| | ".cc", |
| | ".cpp", |
| | ".cu", |
| | ".h", |
| | ".hpp", |
| | ".java", |
| | ".jl", |
| | ".m", |
| | ".matlab", |
| | ".Makefile", |
| | ".md", |
| | ".pl", |
| | ".ps1", |
| | ".py", |
| | ".r", |
| | ".sh", |
| | "config.txt", |
| | ".rs", |
| | "readme.txt", |
| | "requirements_dev.txt", |
| | "requirements-dev.txt", |
| | "requirements.dev.txt", |
| | "requirements.txt", |
| | ".scala", |
| | ".yaml", |
| | ".yml", |
| | ] |
| | return { |
| | k: v |
| | for k, v in self.files.items() |
| | if k.lower().endswith(tuple(extensions)) |
| | } |
| |
|
| | def get_repo_tree(self): |
| | """Generate a tree representation of the repository.""" |
| | repo_tree = "" |
| | for root, dirs, files in os.walk(self.repo_path): |
| | |
| | if ".git" in dirs: |
| | dirs.remove(".git") |
| |
|
| | level = str(Path(root).relative_to(self.repo_path)).count(os.sep) |
| | indent = "β " * (level - 1) + "βββ " if level > 0 else "" |
| |
|
| | |
| | if level > 0: |
| | repo_tree += f"{indent}{Path(root).name}/\n" |
| |
|
| | sub_indent = "β " * level + "βββ " |
| | for f in files: |
| | repo_tree += f"{sub_indent}{f}\n" |
| | return repo_tree |
| |
|
| | def get_code_prompt( |
| | self, |
| | file_extensions: list[str] | None = None, |
| | token_counter: Callable | None = None, |
| | max_tokens: int | None = None, |
| | code_changes: list[dict[str, str]] | None = None, |
| | ) -> str: |
| | """Generate code prompt with repo tree and file contents.""" |
| | code_prompt = "Repo tree:\n" + self.get_repo_tree() + "\n\n" |
| | tokens = token_counter(code_prompt) if token_counter is not None else 0 |
| | |
| | if token_counter is not None and max_tokens is not None: |
| | logger.info( |
| | f"Building code prompt: repo tree tokens={tokens}, max_tokens={max_tokens}, " |
| | f"remaining for files={max_tokens - tokens}" |
| | ) |
| |
|
| | files_to_replace = {} |
| | if code_changes: |
| | files_to_replace = { |
| | cc["file_name"]: cc["discrepancy_code"] for cc in code_changes |
| | } |
| | logger.debug( |
| | f"Files to replace: {len(files_to_replace)}: {files_to_replace.keys()}" |
| | ) |
| |
|
| | for file_path, file_content in self.get_files_by_extension( |
| | file_extensions |
| | ).items(): |
| | if file_path in files_to_replace: |
| | logger.debug(f"Replacing code for {file_path} with changed code") |
| | file_content = files_to_replace[file_path] |
| | code_file = f"# ---\n# File: {file_path}\n# Content:\n{file_content}\n" |
| | if token_counter is not None: |
| | logger.debug(f"Adding file: {file_path}") |
| | num_tokens = token_counter(code_file) |
| | |
| | if max_tokens and (tokens + num_tokens) > max_tokens: |
| | logger.warning( |
| | f"Truncating. Max tokens reached for {self.github_url}. " |
| | f"Current tokens: {tokens}, File tokens: {num_tokens}, " |
| | f"Max tokens for code is {max_tokens}" |
| | ) |
| | break |
| | tokens += num_tokens |
| | logger.debug( |
| | f"Number of tokens in file: {num_tokens}. " |
| | f"Total number of tokens in code prompt: {tokens}" |
| | ) |
| | code_prompt += code_file |
| | |
| | |
| | if token_counter is not None: |
| | final_code_tokens = token_counter(code_prompt) |
| | logger.info( |
| | f"Code prompt built: {final_code_tokens} tokens " |
| | f"(max was {max_tokens if max_tokens else 'unlimited'})" |
| | ) |
| | |
| | return code_prompt |
| |
|
| |
|
| |
|