"""Go AST parser using tree-sitter. Extracts import paths and function/method declarations."""

from dataclasses import dataclass, field
from pathlib import Path

import tree_sitter_go as tsgo
from tree_sitter import Language, Parser

GO_LANGUAGE = Language(tsgo.language())

# Standard-library import paths / top-level path elements to filter out.
# filter_imports drops any import whose full path OR first path element
# appears in this set.
GO_STDLIB = {
    "fmt", "os", "io", "log", "net", "http", "context", "sync", "time",
    "strings", "strconv", "bytes", "errors", "sort", "math", "path",
    "encoding", "crypto", "reflect", "testing", "flag", "regexp", "bufio",
    "archive", "compress", "container", "database", "debug", "embed", "go",
    "hash", "html", "image", "index", "internal", "mime", "plugin",
    "runtime", "syscall", "text", "unicode", "unsafe",
    # Newer top-level stdlib packages (Go 1.21+).
    "cmp", "expvar", "iter", "maps", "slices", "structs", "unique", "weak",
    "encoding/json", "encoding/xml", "encoding/base64", "encoding/binary",
    "encoding/csv", "encoding/gob", "encoding/hex", "encoding/pem",
    "net/http", "net/url", "net/http/httptest", "io/ioutil", "io/fs",
    "os/exec", "os/signal", "path/filepath", "sync/atomic", "crypto/tls",
    "crypto/rand", "crypto/sha256", "crypto/hmac", "log/slog",
    "testing/fstest",
}


@dataclass
class FileInfo:
    """Per-file result produced by parse_go_file."""

    path: str  # path exactly as supplied by the caller
    content: str  # full Go source text that was parsed
    imports: list[str] = field(default_factory=list)  # import paths, delimiters stripped
    functions: list[str] = field(default_factory=list)  # function names, then method names


def parse_go_file(filepath: str, content: str, repo_module: str) -> FileInfo:
    """Parse Go source text and collect its import paths and declared names.

    Args:
        filepath: Path recorded on the returned FileInfo (the file is not read).
        content: Go source text to parse.
        repo_module: Module path of the repository. Currently unused; kept so
            existing callers keep working.

    Returns:
        A FileInfo whose ``imports`` lists every import path in source order
        and whose ``functions`` lists the names of all function declarations
        followed by all method declarations (exported or not).
    """
    parser = Parser(GO_LANGUAGE)
    root = parser.parse(content.encode()).root_node
    info = FileInfo(path=filepath, content=content)

    for decl in _find_nodes(root, "import_declaration"):
        for spec in _find_nodes(decl, "import_spec"):
            path_node = spec.child_by_field_name("path")
            if path_node is not None:
                # Go accepts both "interpreted" and `raw` string literals as
                # import paths, so strip either kind of delimiter.
                info.imports.append(path_node.text.decode().strip('"`'))

    # Functions first, then methods, each in source order.
    for node_type in ("function_declaration", "method_declaration"):
        for decl in _find_nodes(root, node_type):
            name_node = decl.child_by_field_name("name")
            if name_node is not None:
                info.functions.append(name_node.text.decode())

    return info


def _in_module(import_path: str, repo_module: str) -> bool:
    """Return True if *import_path* is *repo_module* itself or a package below it."""
    if not repo_module:
        # Unknown module (e.g. no go.mod): nothing can be classified as
        # first-party. Without this guard, startswith("") would match everything.
        return False
    # Require a "/" boundary so "example.com/foo" does not claim
    # "example.com/foobar/...".
    return import_path == repo_module or import_path.startswith(repo_module + "/")


def filter_imports(imports: list[str], repo_module: str) -> list[str]:
    """Keep only first-party imports (same module) and significant third-party.

    An import is dropped when its full path or its first path element is in
    GO_STDLIB. Of the rest, it is kept when it belongs to *repo_module* or when
    its first path element contains a "." (a domain-style third-party path).
    Input order is preserved.
    """
    kept = []
    for imp in imports:
        top = imp.split("/", 1)[0]
        if imp in GO_STDLIB or top in GO_STDLIB:
            continue
        if _in_module(imp, repo_module) or "." in top:
            kept.append(imp)
    return kept


def get_repo_module(repo_path: str) -> str:
    """Return the module path declared in ``<repo_path>/go.mod``.

    Returns "" when go.mod is missing or contains no ``module`` directive.
    """
    gomod = Path(repo_path) / "go.mod"
    if not gomod.is_file():
        return ""
    for raw_line in gomod.read_text(encoding="utf-8").splitlines():
        # Drop a trailing "//" comment, then split on any whitespace so both
        # "module x" and "module\tx" are recognized.
        tokens = raw_line.split("//", 1)[0].split()
        if len(tokens) >= 2 and tokens[0] == "module":
            # go.mod allows the path to be written as a quoted string.
            return tokens[1].strip('"`')
    return ""


def resolve_import_to_file(
    import_path: str, repo_module: str, go_files: dict[str, str]
) -> str | None:
    """Try to resolve a first-party import path to a directory in the repo.

    Despite the name, the result is a directory: the import path relative to
    the module root ("/"-separated). It is returned only if some key of
    *go_files* lives in a directory equal to, or ending with, that relative
    path. Returns None for imports outside the module, for the module's root
    package, and when no file matches.
    """
    if not _in_module(import_path, repo_module):
        return None
    rel_dir = import_path[len(repo_module):].lstrip("/")
    if not rel_dir:
        # Importing the module root itself: there is no sub-directory to return.
        return None
    suffix = "/" + rel_dir
    for fpath in go_files:
        # as_posix() keeps "/" separators on every OS so the comparison with
        # the "/"-separated import path works on Windows too.
        fdir = Path(fpath).parent.as_posix()
        if fdir == rel_dir or fdir.endswith(suffix):
            return rel_dir
    return None


def _find_nodes(node, type_name: str):
    """Yield every node of *type_name* in the subtree rooted at *node*, in pre-order.

    Uses an explicit stack instead of recursion, so deeply nested syntax
    trees cannot exceed Python's recursion limit.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if current.type == type_name:
            yield current
        # Push children reversed so the leftmost child is visited next,
        # keeping left-to-right (source) order.
        stack.extend(reversed(current.children))