diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c49fa4c..60ef93a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,17 @@ jobs: cache-dependency-path: setup.py - name: Install dependencies run: | - pip install -e '.[test]' + pip install -e '.[test,docs]' + pip install mypy ruff - name: Run tests run: | - pytest + pytest --cov=github_to_sqlite --cov-branch -q + - name: Run ruff + run: | + ruff check . --format=github --exit-zero + - name: Run mypy + run: | + mypy github_to_sqlite --no-error-summary + - name: Build docs + run: | + sphinx-build -b html docs docs/_build -W diff --git a/PLAN.md b/PLAN.md index 55859f3..cf3ca8a 100644 --- a/PLAN.md +++ b/PLAN.md @@ -1,5 +1,168 @@ # Embeddings feature plan +This document decomposes the work required to generate sentence-transformer embeddings for starred repositories. The work is split into three phases so that core functionality lands first, then documentation tooling, followed by publishing the new docs. + +## Phase 1: Generate and store embeddings + +This phase introduces embeddings for starred repositories. + +### Dependencies +- [x] **Add runtime dependencies** + - [x] Install `sentence-transformers` for embedding inference. + - [x] Install `sqlite-vec` to store and query embedding vectors in SQLite. + - [x] Install `semantic-chunkers` from GitHub to chunk README text using + `semantic_chunkers.chunkers.StatisticalChunker`. + - [x] Install `fd` to locate build definition files across the repository tree. + `find_build_files()` prefers `fd` but falls back to `find` or `os.walk` if + needed. +- [x] **Add development dependencies** + - [x] Include `pytest-cov` for coverage reports. + - [x] Update `setup.py` or `pyproject.toml` accordingly. + +### Database changes +- [x] **Create `repo_embeddings` table** + - [x] Columns: `repo_id` (FK to `repos`), `title_embedding`, `description_embedding`, `readme_embedding`. + - [x] Store embeddings using `sqlite-vec` vec0 virtual tables for efficient vector search. + - [x] Add indexes on `repo_id` for fast lookup. +- [x] **Create `readme_chunk_embeddings` table** + - [x] Columns: `repo_id` (FK to `repos`), `chunk_index`, `chunk_text`, `embedding`. + - [x] Use `sqlite-vec` for the `embedding` column to enable similarity search over + individual README chunks. + - [x] Add a composite index on `repo_id` and `chunk_index`. +- [x] **Create `repo_build_files` table** + - [x] Columns: `repo_id` (FK to `repos`), `file_path`, `metadata` (JSON). + - [x] Store one row per build definition (e.g. `pyproject.toml`, `package.json`). + - [x] The `metadata` column captures the entire parsed contents of the file so that + fields such as package name or author can be queried later. +- [x] **Create `repo_metadata` table** + - [x] Columns: `repo_id` (FK to `repos`), `language`, `directory_tree`. + - [x] Capture the primary programming language and a serialized directory structure + for quick reference. + - [x] **Migration script** + - [x] Provide SQL script or CLI command that creates the table if it does not exist. + - [x] Document migration process in README. + +### Embedding generation +- [x] **Model loading** + - [x] Default to `huggingface.co/Alibaba-NLP/gte-modernbert-base`. + - [x] Allow overriding the model path via CLI option or environment variable. +- [x] **Data collection** + - [x] Fetch starred repositories from GitHub using existing API utilities. + - [x] Retrieve README HTML or markdown for each repo. 
+ - [x] Locate common build files (`pyproject.toml`, `package.json`, + `Cargo.toml`, `Gemfile`) using `fd` when available, otherwise `find` or + `os.walk`. + - [x] Parse each file and store its entire contents as JSON in the + `repo_build_files.metadata` column. Package name and author can then be + derived from this JSON as needed. + - [x] Record the repository's primary programming language and generate a serialized + directory tree for storage in `repo_metadata`. +- [x] **Chunking** + - [x] Use `semantic_chunkers.chunkers.StatisticalChunker` to split README text + into semantically meaningful chunks. See `docs/00-chunkers-intro.ipynb` in + the `semantic-chunkers` repository for usage examples. + - [x] If that library is not available at runtime, fall back to splitting on + blank lines to ensure tests run without optional dependencies. +- [x] **Vector inference** + - [x] Run the model on the repository title, description and each README chunk. + - [x] Batch requests when possible to speed up inference. +- [x] **Storage** + - [x] Save repository-level vectors to `repo_embeddings`. + - [x] Save each chunk's embedding to `readme_chunk_embeddings` along with the + chunk text and index. + - [x] Skip entries that already exist unless `--force` is supplied. + +### CLI integration +- [x] **New command** `starred-embeddings` + - [x] Accept database path and optional model path. + - [x] Iterate through all starred repos and compute embeddings. + - [x] Chunk each README using `StatisticalChunker` and store chunk embeddings. + - [x] Collect build metadata using `find_build_files()` (using `fd`, `find` or + `os.walk` as available) and store the entire parsed JSON in the + `repo_build_files.metadata` column. + - [x] Support `--force` and `--verbose` flags. +- [x] **Error handling** + - [x] Handle missing READMEs gracefully. + - [x] Retry transient network failures. + +### Testing +- [x] **Unit tests** + - [x] Mock GitHub API calls and README fetches. + - [x] Verify embeddings are generated and stored correctly, including per-chunk + embeddings. + - [x] Ensure build metadata is parsed and stored as JSON in `repo_build_files`. + - [x] **Coverage** + - [x] Run `pytest --cov --cov-branch` in CI to ensure branch coverage does not regress. +- [x] **Integration tests** + - [x] Simulate GitHub API responses with `requests_mock` and run the + `starred-embeddings` command end-to-end. + - [x] Confirm embeddings, chunked README data and build metadata are + stored in the database. + +### Documentation +- [x] **README updates** + - [x] Describe the new command and its options. + - [x] Mention default model and how to override it. + - [x] Document how README files are chunked using `semantic-chunkers` before + embedding. + - [x] Explain how build files are detected using `find_build_files()` + (preferring `fd`) and stored for analysis. +- [x] **Changelog entry** + - [x] Summarize the feature and dependencies. + +## Phase 2: Documentation tooling + +- [x] **Introduce RST and Sphinx** + - [x] Add `sphinx` and `sphinx-rtd-theme` to development dependencies. + - [x] Configure a `docs/` directory with Sphinx `conf.py` and initial structure. +- [x] **Convert existing documentation** + - [x] Migrate `README.md` or relevant guides into RST as needed. + - [x] Ensure the embeddings feature is documented in the new docs site. +- [ ] **Automation** + - [ ] Update CI to build documentation and fail on warnings. 
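As a reference for the "queried later" claim in Phase 1, here is a minimal sketch of pulling a field such as the package name back out of the stored `repo_build_files.metadata` JSON with SQLite's `json_extract()`. The database path matches the README examples; the `$.name` key is illustrative and depends on the build file format (e.g. `$.name` for `package.json` versus `$.project.name` for a standard `pyproject.toml`).

```python
import sqlite_utils

# Database path as used in the README examples; table and column names
# match the Phase 1 schema above.
db = sqlite_utils.Database("github.db")

rows = db.execute(
    """
    select repo_id, file_path,
           json_extract(metadata, '$.name') as package_name
    from repo_build_files
    where json_extract(metadata, '$.name') is not null
    """
).fetchall()

for repo_id, file_path, package_name in rows:
    print(repo_id, file_path, package_name)
```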
+ +## Phase 3: Publish documentation + +- [ ] **Deployment** + - [ ] Publish the documentation using GitHub Pages or another hosting service. + - [ ] Automate deployment on release so new docs are available immediately. + +## Next task: publish documentation site + +With the documentation building in CI, the next step is to publish it so users +can browse the docs online. + +Steps: + +- [ ] Set up a GitHub Pages workflow that uploads ``docs/_build`` + from the main branch. +- [ ] Trigger the deployment after tests pass on ``main``. + +Completed build steps: + +- [x] Install documentation dependencies in the CI environment. +- [x] Run ``sphinx-build -b html docs docs/_build`` during CI. +- [x] Treat warnings as errors so the build fails on broken docs. + +- [x] Add a `starred-embeddings` Click command in `cli.py`. + - [x] Accept a database path argument. + - [x] Accept `--model` to override the default model. + - [x] Support `--force` and `--verbose` flags. +- [x] Load the sentence-transformers model using the configured name. +- [x] Iterate through starred repositories using existing API helpers. + - [x] Save repository metadata to the database. + - [x] Fetch README content for each repository. + - [x] Use `StatisticalChunker` to split README text. + - [x] Run embeddings for titles, descriptions and README chunks. + - [x] Save vectors to `repo_embeddings` and `readme_chunk_embeddings`. + - [x] Extract build files using `find_build_files()` and store metadata in + `repo_build_files`. + - [x] Capture the primary language and directory tree in `repo_metadata`. +- [x] Write unit tests for the new command using mocks to avoid network calls. + - [x] Ensure coverage passes with `pytest --cov --cov-branch`. + - [x] Add tests for utility helpers like `vector_to_blob`, `parse_build_file`, + `directory_tree` and `_maybe_load_sqlite_vec`. +======= This document decomposes the work required to generate sentence-transformer embeddings for starred repositories. Each bullet point expands into further tasks until reaching granular actionable steps. ## 1. Dependencies @@ -57,3 +220,4 @@ This document decomposes the work required to generate sentence-transformer embe - **Changelog entry** - Summarize the feature and dependencies. + diff --git a/README.md b/README.md index a45bfc0..f75b8fa 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ Save data from GitHub to a SQLite database. - [Scraping dependents for a repository](#scraping-dependents-for-a-repository) - [Fetching emojis](#fetching-emojis) - [Making authenticated API calls](#making-authenticated-api-calls) +- [Running migrations](#running-migrations) +- [Generating embeddings for starred repositories](#generating-embeddings-for-starred-repositories) @@ -258,3 +260,38 @@ Many GitHub APIs are [paginated using the HTTP Link header](https://docs.github. You can outline newline-delimited JSON for each item using `--nl`. This can be useful for streaming items into another tool. $ github-to-sqlite get /users/simonw/repos --nl + +## Running migrations + +Run the `migrate` command to create any optional tables and indexes: + + $ github-to-sqlite migrate github.db + +The command ensures embedding tables exist and sets up FTS, foreign keys and +views using the same logic as the main CLI commands. It will create +`repo_embeddings`, `readme_chunk_embeddings`, `repo_build_files`, and +`repo_metadata` tables, using `sqlite-vec` if available. + +## Build file detection + +Some commands extract metadata from standard build files. 
The helper prefers the +[`fd`](https://github.com/sharkdp/fd) tool if available, falling back to the +`find` utility or a Python implementation. + +## Generating embeddings for starred repositories + +Use the `starred-embeddings` command to compute embeddings for repositories you +have starred. The command loads the sentence-transformers model configured in +`config.default_model` (currently `Alibaba-NLP/gte-modernbert-base`) unless you +specify `--model`. You can also set the `GITHUB_TO_SQLITE_MODEL` environment +variable to override the default. + +``` +$ github-to-sqlite starred-embeddings github.db --model my/custom-model +``` + +Embeddings for repository titles, descriptions and README chunks are stored in +`repo_embeddings` and `readme_chunk_embeddings`. Build files discovered using +`find_build_files()` are parsed and saved to `repo_build_files`, while basic +language information and a directory listing are recorded in `repo_metadata`. + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..fdca8f8 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,12 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) + +project = 'github-to-sqlite' +author = 'Simon Willison' +release = '2.9' + +extensions = ['sphinx.ext.autodoc'] + +html_theme = 'sphinx_rtd_theme' diff --git a/docs/embeddings.rst b/docs/embeddings.rst new file mode 100644 index 0000000..7a07e4e --- /dev/null +++ b/docs/embeddings.rst @@ -0,0 +1,10 @@ +Generating embeddings +===================== + +The ``starred-embeddings`` command computes sentence-transformer embeddings for repositories you have starred. It loads the model configured in ``config.default_model`` (``Alibaba-NLP/gte-modernbert-base`` by default) unless you specify ``--model`` or set the ``GITHUB_TO_SQLITE_MODEL`` environment variable. + +.. code-block:: console + + $ github-to-sqlite starred-embeddings github.db --model my/custom-model + +The command stores repository-level vectors in ``repo_embeddings`` and README chunk vectors in ``readme_chunk_embeddings``. Build files discovered via ``find_build_files()`` are parsed and saved to ``repo_build_files``. Basic language information and the directory listing are recorded in ``repo_metadata``. diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..3279fc2 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,9 @@ +Welcome to github-to-sqlite's documentation! +============================================= + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + embeddings + migrations diff --git a/docs/migrations.rst b/docs/migrations.rst new file mode 100644 index 0000000..ede3d4c --- /dev/null +++ b/docs/migrations.rst @@ -0,0 +1,15 @@ +Migrations and build files +========================== + +Run the ``migrate`` command to create any optional tables and indexes used by the embeddings feature: + +.. code-block:: console + + $ github-to-sqlite migrate github.db + +This sets up the ``repo_embeddings``, ``readme_chunk_embeddings``, ``repo_build_files`` and ``repo_metadata`` tables. The helper prefers the ``sqlite-vec`` extension when available. + +Build file detection +-------------------- + +Some commands look for standard build definitions such as ``pyproject.toml`` or ``package.json``. The ``find_build_files()`` helper uses the ``fd`` command if installed, otherwise falling back to ``find`` or a Python implementation. 
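To complement the docs above, here is a minimal sketch of a nearest-neighbour query against the `repo_embeddings` vec0 table created by `migrate`, assuming `sqlite-vec` and `sentence-transformers` are installed. The `match ... and k = 5` KNN syntax follows recent sqlite-vec releases (older releases used `limit` instead of a `k` constraint), and the query text and database path are illustrative.

```python
import sqlite_utils
import sqlite_vec
from sentence_transformers import SentenceTransformer

from github_to_sqlite.config import config

db = sqlite_utils.Database("github.db")
# Standard sqlite-vec loading pattern; github_to_sqlite.utils._maybe_load_sqlite_vec()
# wraps the same sqlite_vec.load() call with error handling.
db.conn.enable_load_extension(True)
sqlite_vec.load(db.conn)
db.conn.enable_load_extension(False)

model = SentenceTransformer(config.default_model)
query = sqlite_vec.serialize_float32(
    model.encode("sqlite full text search").tolist()
)

rows = db.execute(
    """
    select repo_id, distance
    from repo_embeddings
    where readme_embedding match ?
      and k = 5
    order by distance
    """,
    [query],
).fetchall()

for repo_id, distance in rows:
    print(repo_id, distance)
```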
diff --git a/github_to_sqlite/cli.py b/github_to_sqlite/cli.py index e6a2d88..0891c4e 100644 --- a/github_to_sqlite/cli.py +++ b/github_to_sqlite/cli.py @@ -4,9 +4,11 @@ import pathlib import textwrap import os +import importlib.util import sqlite_utils import time import json +from typing import Any, Optional, cast from github_to_sqlite import utils @@ -140,10 +142,12 @@ def pull_requests(db_path, repo, pull_request_ids, auth, load, orgs, state, sear repos_seen.add(pr_repo_url) utils.save_pull_requests(db, [pull_request], pr_repo) else: + from typing import Iterable, Any + + repos: Iterable[dict[str, Any]] if orgs: repos = itertools.chain.from_iterable( - utils.fetch_all_repos(token=token, org=org) - for org in orgs + utils.fetch_all_repos(token=token, org=org) for org in orgs ) else: repos = [utils.fetch_repo(repo, token)] @@ -306,10 +310,12 @@ def repos(db_path, usernames, auth, repo, load, readme, readme_html): def _repo_readme(db, token, repo_id, full_name, readme, readme_html): if readme: readme = utils.fetch_readme(token, full_name) - db["repos"].update(repo_id, {"readme": readme}, alter=True) + cast(sqlite_utils.db.Table, db["repos"]).update(repo_id, {"readme": readme}, alter=True) if readme_html: readme_html = utils.fetch_readme(token, full_name, html=True) - db["repos"].update(repo_id, {"readme_html": readme_html}, alter=True) + cast(sqlite_utils.db.Table, db["repos"]).update( + repo_id, {"readme_html": readme_html}, alter=True + ) @cli.command() @@ -424,21 +430,24 @@ def commits(db_path, repos, all, auth): db = sqlite_utils.Database(db_path) token = load_token(auth) - def stop_when(commit): + from typing import Callable + + def stop_when(commit: Any) -> bool: try: - db["commits"].get(commit["sha"]) + cast(sqlite_utils.db.Table, db["commits"]).get(commit["sha"]) return True except sqlite_utils.db.NotFoundError: return False + stop_when_func: Optional[Callable[[Any], bool]] = stop_when if all: - stop_when = None + stop_when_func = None for repo in repos: repo_full = utils.fetch_repo(repo, token) utils.save_repo(db, repo_full) - commits = utils.fetch_commits(repo, token, stop_when) + commits = utils.fetch_commits(repo, token, stop_when_func) utils.save_commits(db, commits, repo_full["id"]) time.sleep(1) @@ -467,9 +476,7 @@ def stop_when(commit): ) def scrape_dependents(db_path, repos, auth, verbose): "Scrape dependents for specified repos" - try: - import bs4 - except ImportError: + if importlib.util.find_spec("bs4") is None: raise click.ClickException("Optional dependency bs4 is needed for this command") db = sqlite_utils.Database(db_path) token = load_token(auth) @@ -480,7 +487,7 @@ def scrape_dependents(db_path, repos, auth, verbose): for dependent_repo in utils.scrape_dependents(repo, verbose): # Don't fetch repo details if it's already in our DB - existing = list(db["repos"].rows_where("full_name = ?", [dependent_repo])) + existing = list(cast(sqlite_utils.db.Table, db["repos"]).rows_where("full_name = ?", [dependent_repo])) dependent_id = None if not existing: dependent_full = utils.fetch_repo(dependent_repo, token) @@ -490,12 +497,13 @@ def scrape_dependents(db_path, repos, auth, verbose): else: dependent_id = existing[0]["id"] # Only insert if it isn't already there: - if not db["dependents"].exists() or not list( - db["dependents"].rows_where( + dependents_table = cast(sqlite_utils.db.Table, db["dependents"]) + if not dependents_table.exists() or not list( + dependents_table.rows_where( "repo = ? 
and dependent = ?", [repo_full["id"], dependent_id] ) ): - db["dependents"].insert( + dependents_table.insert( { "repo": repo_full["id"], "dependent": dependent_id, @@ -534,7 +542,7 @@ def emojis(db_path, auth, fetch): "Fetch GitHub supported emojis" db = sqlite_utils.Database(db_path) token = load_token(auth) - table = db.table("emojis", pk="name") + table = cast(sqlite_utils.db.Table, db.table("emojis", pk="name")) table.upsert_all(utils.fetch_emojis(token)) if fetch: # Ensure table has 'image' column @@ -637,6 +645,48 @@ def workflows(db_path, repos, auth): utils.ensure_db_shape(db) +@cli.command(name="starred-embeddings") +@click.argument( + "db_path", + type=click.Path(file_okay=True, dir_okay=False, allow_dash=False), + required=True, +) +@click.option("--model", help="Model to use for embeddings") +@click.option("--force", is_flag=True, help="Overwrite existing embeddings") +@click.option("--verbose", is_flag=True, help="Show progress information") +@click.option( + "-a", + "--auth", + type=click.Path(file_okay=True, dir_okay=False, allow_dash=True), + default="auth.json", + help="Path to auth.json token file", +) +def starred_embeddings(db_path, model, force, verbose, auth): + """Generate embeddings for repositories starred by the user.""" + token = load_token(auth) + db = sqlite_utils.Database(db_path) + utils.generate_starred_embeddings( + db, + token, + model_name=model, + force=force, + verbose=verbose, + ) + + +@cli.command() +@click.argument( + "db_path", + type=click.Path(file_okay=True, dir_okay=False, allow_dash=False), + required=True, +) +def migrate(db_path): + """Ensure all optional tables, FTS and foreign keys exist.""" + db = sqlite_utils.Database(db_path) + utils.ensure_db_shape(db) + click.echo("Database migrated") + + def load_token(auth): try: token = json.load(open(auth))["github_personal_token"] diff --git a/github_to_sqlite/config.py b/github_to_sqlite/config.py new file mode 100644 index 0000000..e16bf2e --- /dev/null +++ b/github_to_sqlite/config.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class Config(BaseModel): + default_model: str = "Alibaba-NLP/gte-modernbert-base" + onnx_provider: str = "cpu" + max_length: int = 8192 + + +config = Config() diff --git a/github_to_sqlite/sentencizer_chunker.py b/github_to_sqlite/sentencizer_chunker.py new file mode 100644 index 0000000..d644d42 --- /dev/null +++ b/github_to_sqlite/sentencizer_chunker.py @@ -0,0 +1,26 @@ +from typing import List, Sequence, Any + + +class BasicSentencizerChunker: + """Chunk token vectors on a designated period token.""" + + def __init__(self, period_token: str = "."): + self.period_token = period_token + + def chunk( + self, tokens: Sequence[str], vectors: Sequence[Any] + ) -> List[List[Any]]: + if len(tokens) != len(vectors): + raise ValueError("tokens and vectors must be the same length") + chunks: List[List[Any]] = [] + current: List[Any] = [] + for token, vec in zip(tokens, vectors): + current.append(vec) + if token == self.period_token: + chunks.append(current) + current = [] + # drop incomplete final chunk + return chunks + + __call__ = chunk + diff --git a/github_to_sqlite/simple_chunker.py b/github_to_sqlite/simple_chunker.py new file mode 100644 index 0000000..ca83f56 --- /dev/null +++ b/github_to_sqlite/simple_chunker.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, List, Optional, cast + +from nltk.tokenize import sent_tokenize + +from .config import config + +try: # 
Optional dependencies + from colorama import Fore, Style +except Exception: # pragma: no cover - color output not essential + class _Color(str, Enum): + RED = "" + GREEN = "" + BLUE = "" + MAGENTA = "" + + class _Style(Enum): + RESET_ALL = "" + + Fore = _Color + Style = _Style + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from semantic_router.encoders.base import DenseEncoder +else: + try: + from semantic_router.encoders.base import DenseEncoder + except Exception: # pragma: no cover - optional dependency not installed + @dataclass + class DenseEncoder: + """Fallback encoder used only for typing.""" + + name: str = "default" + + +@dataclass(slots=True) +class Chunk: + """A single chunk of text produced by the chunker.""" + + content: str + token_count: int + is_triggered: bool + triggered_score: float + + +@dataclass(slots=True) +class BaseSplitter: + """Callable object that splits text into sentence strings.""" + + def __call__(self, doc: str) -> List[str]: + raise NotImplementedError("Subclasses must implement this method") + + +@dataclass +class BaseChunker: + """Base class for chunkers.""" + + name: str + splitter: BaseSplitter + encoder: Optional[DenseEncoder] = None + + def __post_init__(self) -> None: + if self.encoder is None: + self.encoder = DenseEncoder() + + def __call__(self, docs: List[str]) -> List[List[Chunk]]: + raise NotImplementedError + + def _split(self, doc: str) -> List[str]: + return self.splitter(doc) + + def _chunk(self, splits: List[Any]) -> List[Chunk]: + raise NotImplementedError + + def print(self, document_splits: List[Chunk]) -> None: + """Display chunks using color if ``colorama`` is installed.""" + + if hasattr(Fore, "RED"): + colors = [Fore.RED, Fore.GREEN, Fore.BLUE, Fore.MAGENTA] + reset = getattr(Style, "RESET_ALL", "") + else: + colors = ["", "", "", ""] + reset = "" + for i, split in enumerate(document_splits): + color = colors[i % len(colors)] + colored_content = f"{color}{split.content}{reset}" + if split.is_triggered: + triggered = f"{split.triggered_score:.2f}" + elif i == len(document_splits) - 1: + triggered = "final split" + else: + triggered = "token limit" + print( + f"Split {i + 1}, tokens {split.token_count}, triggered by: {triggered}" + ) + print(colored_content) + print("-" * 88) + print() + + +@dataclass +class SimpleChunker(BaseChunker): + """Chunk text into groups of ``target_length`` sentences.""" + + target_length: int = field(default=config.max_length) + + def __call__(self, docs: List[str]) -> List[List[Chunk]]: + return [self._chunk(self._split(doc)) for doc in docs] + + def _split(self, doc: str) -> List[str]: + try: + return cast(List[str], sent_tokenize(doc)) + except LookupError: # pragma: no cover - depends on environment + import nltk + + class PunktResource(str, Enum): + PUNKT = "punkt" + PUNKT_TAB = "punkt_tab" + + for resource in PunktResource: + nltk.download(resource.value) + try: + return cast(List[str], sent_tokenize(doc)) + except LookupError: + continue + raise + + def _chunk(self, sentences: List[str]) -> List[Chunk]: + chunks: List[Chunk] = [] + for i in range(0, len(sentences), self.target_length): + piece = sentences[i : i + self.target_length] + if len(piece) < self.target_length: + break + content = " ".join(piece) + chunks.append( + Chunk( + content=content, + token_count=len(piece), + is_triggered=False, + triggered_score=0.0, + ) + ) + return chunks + diff --git a/github_to_sqlite/tokenization.py b/github_to_sqlite/tokenization.py new file mode 100644 index 0000000..c0a2c2a --- /dev/null +++ 
b/github_to_sqlite/tokenization.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +import os +from tokenizers import Tokenizer + +from .config import config + + +@dataclass(frozen=True, slots=True) +class Token: + id: int + value: str + offsets: tuple[int, int] + + + +def load_tokenizer(model: str | None = None) -> Tokenizer: + """Load a Hugging Face tokenizer. + + Uses the provided model name or the ``GITHUB_TO_SQLITE_MODEL`` environment + variable, falling back to :data:`config.default_model`. + """ + model_name = model or os.environ.get("GITHUB_TO_SQLITE_MODEL", config.default_model) + return Tokenizer.from_pretrained(model_name) diff --git a/github_to_sqlite/utils.py b/github_to_sqlite/utils.py index 833572d..1f826ce 100644 --- a/github_to_sqlite/utils.py +++ b/github_to_sqlite/utils.py @@ -1,10 +1,20 @@ import base64 import sys +import os +import subprocess +import shutil import requests import re import time import urllib.parse import yaml +import sqlite_utils +import pathlib +import sqlite3 +import json +from typing import Optional, cast + +from . import config from urllib3 import Retry @@ -432,7 +442,8 @@ def fetch_tags(repo, token=None): def fetch_commits(repo, token=None, stop_when=None): if stop_when is None: - stop_when = lambda commit: False + def stop_when(commit): + return False headers = make_headers(token) url = "https://api.github.com/repos/{}/commits".format(repo) try: @@ -506,7 +517,7 @@ def paginate(url, headers=None): if isinstance(data, dict) and data.get("message"): print(GitHubError.from_response(response), file=sys.stderr) try: - url = response.links.get("next").get("url") if response.status_code == 200 else url + url = response.links.get("next", {}).get("url") if response.status_code == 200 else url except AttributeError: url = None yield data @@ -716,12 +727,143 @@ def ensure_foreign_keys(db): db[table].add_foreign_key(column, table2, column2) +_SQLITE_VEC_LOADED: bool | None = None + + +def _create_table_if_missing( + db: sqlite_utils.Database, + tables: set[str], + name: str, + columns: dict[str, type], + pk: str | tuple[str, ...], + foreign_keys: list[tuple[str, str, str]] | None = None, +) -> None: + """Create *name* if it doesn't exist in *tables*.""" + if name not in tables: + table = cast(sqlite_utils.db.Table, db[name]) + table.create(columns, pk=pk, foreign_keys=foreign_keys or []) + + +def _create_virtual_table_if_missing(db: sqlite_utils.Database, name: str, sql: str) -> None: + """Create a virtual table using provided SQL if it does not already exist.""" + existing = set(db.table_names()) + if name not in existing: + db.execute(sql) + + +def _maybe_load_sqlite_vec(db): + """Attempt to load sqlite-vec extension, returning True if available.""" + global _SQLITE_VEC_LOADED + if _SQLITE_VEC_LOADED is not None: + return _SQLITE_VEC_LOADED + try: + import sqlite_vec + except ImportError: + _SQLITE_VEC_LOADED = False + return _SQLITE_VEC_LOADED + try: + sqlite_vec.load(db.conn) + except (OSError, sqlite3.DatabaseError, AttributeError): + _SQLITE_VEC_LOADED = False + else: + _SQLITE_VEC_LOADED = True + return _SQLITE_VEC_LOADED + + +def ensure_embedding_tables(db): + """Create tables used for embedding storage if they do not exist.""" + using_vec = _maybe_load_sqlite_vec(db) + + tables = set(db.table_names()) + + if "repo_embeddings" not in tables: + if using_vec: + _create_virtual_table_if_missing( + db, + "repo_embeddings", + """ + create virtual table if not exists repo_embeddings using vec0( + repo_id int primary key, + title_embedding float[768], + 
description_embedding float[768], + readme_embedding float[768] + ) + """, + ) + else: + _create_table_if_missing( + db, + tables, + "repo_embeddings", + { + "repo_id": int, + "title_embedding": bytes, + "description_embedding": bytes, + "readme_embedding": bytes, + }, + pk="repo_id", + foreign_keys=[("repo_id", "repos", "id")] if "repos" in tables else [], + ) + + if "readme_chunk_embeddings" not in tables: + if using_vec: + _create_virtual_table_if_missing( + db, + "readme_chunk_embeddings", + """ + create virtual table if not exists readme_chunk_embeddings using vec0( + repo_id int, + chunk_index int, + chunk_text text, + embedding float[768] + ) + """, + ) + db.execute( + "create index if not exists readme_chunk_idx on readme_chunk_embeddings(repo_id, chunk_index)" + ) + else: + _create_table_if_missing( + db, + tables, + "readme_chunk_embeddings", + { + "repo_id": int, + "chunk_index": int, + "chunk_text": str, + "embedding": bytes, + }, + pk=("repo_id", "chunk_index"), + foreign_keys=[("repo_id", "repos", "id")] if "repos" in tables else [], + ) + + _create_table_if_missing( + db, + tables, + "repo_build_files", + {"repo_id": int, "file_path": str, "metadata": str}, + pk=("repo_id", "file_path"), + foreign_keys=[("repo_id", "repos", "id")] if "repos" in tables else [], + ) + + _create_table_if_missing( + db, + tables, + "repo_metadata", + {"repo_id": int, "language": str, "directory_tree": str}, + pk="repo_id", + foreign_keys=[("repo_id", "repos", "id")] if "repos" in tables else [], + ) + + def ensure_db_shape(db): "Ensure FTS is configured and expected FKS, views and (soon) indexes are present" # Foreign keys: ensure_foreign_keys(db) db.index_foreign_keys() + ensure_embedding_tables(db) + # FTS: existing_tables = set(db.table_names()) for table, columns in FTS_CONFIG.items(): @@ -732,7 +874,6 @@ def ensure_db_shape(db): db[table].enable_fts(columns, create_triggers=True) # Views: - existing_views = set(db.view_names()) existing_tables = set(db.table_names()) for view, (tables, sql) in VIEWS.items(): # Do all of the tables exist? @@ -745,14 +886,14 @@ def scrape_dependents(repo, verbose=False): # Optional dependency: from bs4 import BeautifulSoup - url = "https://github.com/{}/network/dependents".format(repo) + url: str | None = "https://github.com/{}/network/dependents".format(repo) while url: if verbose: print(url) response = requests.get(url) soup = BeautifulSoup(response.content, "html.parser") repos = [ - a["href"].lstrip("/") + str(a["href"]).lstrip("/") for a in soup.select("a[data-hovercard-type=repository]") ] if verbose: @@ -764,7 +905,10 @@ def scrape_dependents(repo, verbose=False): except IndexError: break if next_link is not None: - url = next_link["href"] + from bs4.element import Tag + + tag = cast(Tag, next_link) + url = cast(Optional[str], tag.get("href")) time.sleep(1) else: url = None @@ -826,6 +970,26 @@ def rewrite_readme_html(html): return html +def chunk_readme(text): + """Return a list of textual chunks for the provided README content. + + Attempts to use ``semantic_chunkers.StatisticalChunker`` if available; + otherwise falls back to splitting on blank lines. This allows tests to run + without the optional dependency installed. 
+ """ + + try: + from semantic_chunkers.chunkers import StatisticalChunker + except ImportError: + pass + else: + chunker = StatisticalChunker() + return list(chunker.chunk(text)) + + # Fallback: split on blank lines + return [p.strip() for p in re.split(r"\n{2,}", text) if p.strip()] + + def fetch_workflows(token, full_name): headers = make_headers(token) url = "https://api.github.com/repos/{}/contents/.github/workflows".format(full_name) @@ -912,3 +1076,211 @@ def save_workflow(db, repo_id, filename, content): pk="id", foreign_keys=["job", "repo"], ) + +# Utility to locate build definition files using fd, find or os.walk + + +BUILD_PATTERNS = ["pyproject.toml", "package.json", "Cargo.toml", "Gemfile"] + + +def _post_process_build_files(found: list[str], base: str) -> list[str]: + """Normalize paths, filter junk and deduplicate while preserving order.""" + unique: list[str] = [] + seen = set() + for item in found: + if "/.git/" in item or "/node_modules/" in item: + continue + norm_path = os.path.normpath(item) + if os.path.isabs(norm_path) or norm_path.startswith(os.path.normpath(base) + os.sep): + norm = os.path.relpath(norm_path, base) + else: + norm = norm_path + if norm not in seen: + unique.append(norm) + seen.add(norm) + return unique + +def find_build_files(path: str) -> list[str]: + """Return a list of build definition files under *path*. + + The helper prefers the ``fd`` command if available, then falls back to + ``find`` and finally to walking the directory tree with ``os.walk``. Paths + are returned relative to *path*. + """ + found: list[str] = [] + + if shutil.which("fd"): + for pattern in BUILD_PATTERNS: + try: + result = subprocess.run( + ["fd", "-HI", "-t", "f", pattern, path], + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError: + continue + found.extend(result.stdout.splitlines()) + elif shutil.which("find"): + for pattern in BUILD_PATTERNS: + try: + result = subprocess.run( + ["find", path, "-name", pattern, "-type", "f"], + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError: + continue + found.extend(result.stdout.splitlines()) + else: + for pattern in BUILD_PATTERNS: + for full in pathlib.Path(path).rglob(pattern): + if full.is_file(): + found.append(str(full)) + return _post_process_build_files(found, path) + + +def vector_to_blob(vec) -> bytes: + """Return a float32 byte string for the provided vector.""" + import numpy as np + + arr = np.asarray(vec, dtype="float32") + return arr.tobytes() + + +def parse_build_file(path: str) -> dict: + """Parse a supported build file and return its contents as a dict.""" + import json + try: + import tomllib + except ImportError: # Python <3.11 + import tomli as tomllib + + try: + if path.endswith(".json"): + with open(path) as fp: + return json.load(fp) + with open(path, "rb") as fp: + return tomllib.load(fp) + except (OSError, json.JSONDecodeError, tomllib.TOMLDecodeError): + return {} + + +def directory_tree(path: str) -> dict: + """Return a simple directory tree representation for *path*.""" + tree = {} + for root, dirs, files in os.walk(path): + rel = os.path.relpath(root, path) + tree[rel] = {"dirs": sorted(dirs), "files": sorted(files)} + return tree + + +def generate_starred_embeddings( + db: sqlite_utils.Database, + token: str, + model_name: str | None = None, + *, + force: bool = False, + verbose: bool = False, +) -> None: + """Generate embeddings for repos starred by the authenticated user.""" + from sentence_transformers import 
SentenceTransformer + import numpy as np + + ensure_db_shape(db) + using_vec = _maybe_load_sqlite_vec(db) + if verbose: + if using_vec: + print("Using sqlite-vec for embedding storage") + else: + print("sqlite-vec extension not loaded; storing embeddings as BLOBs") + + env_model = os.environ.get("GITHUB_TO_SQLITE_MODEL") + model_name = model_name or env_model or config.config.default_model + embedder = SentenceTransformer(model_name) + + batch_size = 32 + for star in fetch_all_starred(token=token): + repo = star["repo"] + repo_id = save_repo(db, repo) + + repo_embeddings = cast(sqlite_utils.db.Table, db["repo_embeddings"]) + if not force and repo_embeddings.count_where("repo_id = ?", [repo_id]): + if verbose: + print(f"Skipping {repo['full_name']} (already processed)") + continue + + title = repo.get("name") or "" + description = repo.get("description") or "" + readme = fetch_readme(token, repo["full_name"]) or "" + chunks = chunk_readme(readme) + + title_vec, desc_vec = embedder.encode([title, description]) + chunk_vecs = [] + for i in range(0, len(chunks), batch_size): + part = chunks[i : i + batch_size] + chunk_vecs.extend(embedder.encode(list(part))) + + readme_vec = np.mean(chunk_vecs, axis=0) if chunk_vecs else np.zeros_like( + title_vec + ) + + if using_vec: + import sqlite_vec + + title_val = sqlite_vec.serialize_float32(list(title_vec)) + desc_val = sqlite_vec.serialize_float32(list(desc_vec)) + readme_val = sqlite_vec.serialize_float32(list(readme_vec)) + else: + title_val = vector_to_blob(title_vec) + desc_val = vector_to_blob(desc_vec) + readme_val = vector_to_blob(readme_vec) + + repo_embeddings.upsert( + { + "repo_id": repo_id, + "title_embedding": title_val, + "description_embedding": desc_val, + "readme_embedding": readme_val, + }, + pk="repo_id", + ) + + for i, (chunk, vec) in enumerate(zip(chunks, chunk_vecs)): + chunk_val = ( + sqlite_vec.serialize_float32(list(vec)) if using_vec else vector_to_blob(vec) + ) + cast(sqlite_utils.db.Table, db["readme_chunk_embeddings"]).upsert( + { + "repo_id": repo_id, + "chunk_index": i, + "chunk_text": chunk, + "embedding": chunk_val, + }, + pk=("repo_id", "chunk_index"), + ) + + for build_path in find_build_files(repo["full_name"]): + metadata = parse_build_file(os.path.join(repo["full_name"], build_path)) + cast(sqlite_utils.db.Table, db["repo_build_files"]).upsert( + { + "repo_id": repo_id, + "file_path": build_path, + "metadata": json.dumps(metadata), + }, + pk=("repo_id", "file_path"), + ) + + cast(sqlite_utils.db.Table, db["repo_metadata"]).upsert( + { + "repo_id": repo_id, + "language": repo.get("language") or "", + "directory_tree": json.dumps(directory_tree(repo["full_name"])), + }, + pk="repo_id", + ) + + if verbose: + print(f"Processed {repo['full_name']}") + diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..b6ae6b1 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +python_version = 3.10 +ignore_missing_imports = True +show_error_codes = True +pretty = True +sqlite_cache = True +cache_fine_grained = True +check_untyped_defs = True +warn_unused_ignores = True +warn_unused_configs = True diff --git a/setup.py b/setup.py index de72b51..7726550 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,25 @@ def get_long_description(): [console_scripts] github-to-sqlite=github_to_sqlite.cli:cli """, - install_requires=["sqlite-utils>=2.7.2", "requests", "PyYAML"], - extras_require={"test": ["pytest", "requests-mock", "bs4"]}, - tests_require=["github-to-sqlite[test]"], + install_requires=[ + "sqlite-utils>=2.7.2", + 
"requests", + "PyYAML", + ], + extras_require={ + "test": ["pytest", "pytest-cov", "requests-mock", "bs4", "mypy", "ruff"], + "semantic_chunkers": [ + "semantic-chunkers @ https://github.com/aurelio-labs/semantic-chunkers/archive/refs/tags/v0.1.1.tar.gz" + ], + "vector-search-index": [ + "sentence-transformers[onnx]", + "sqlite-vec", + "nltk", + "onnx", + "pydantic>=2.0", + "tokenizers", + ], + "gpu": ["sentence-transformers[onnx-gpu]"], + "docs": ["sphinx", "sphinx-rtd-theme"], + }, ) diff --git a/tests/test_chunk_readme.py b/tests/test_chunk_readme.py new file mode 100644 index 0000000..36c6628 --- /dev/null +++ b/tests/test_chunk_readme.py @@ -0,0 +1,29 @@ +import sys +from github_to_sqlite import utils + + +def test_chunk_readme_fallback(): + text = """Paragraph one. + +Paragraph two. + +Paragraph three.""" + chunks = utils.chunk_readme(text) + assert chunks == ["Paragraph one.", "Paragraph two.", "Paragraph three."] + + +def test_chunk_readme_with_chunker(monkeypatch): + class DummyChunker: + def chunk(self, text): + return ["chunk1", "chunk2"] + + def dummy_init(): + return DummyChunker() + + monkeypatch.setitem( + sys.modules, + 'semantic_chunkers.chunkers', + type('m', (), {'StatisticalChunker': dummy_init}) + ) + assert utils.chunk_readme('anything') == ['chunk1', 'chunk2'] + diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..ce73f31 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,16 @@ +import importlib +from github_to_sqlite import tokenization + + +def test_load_tokenizer(monkeypatch): + calls = [] + + def fake_from_pretrained(model): + calls.append(model) + return 'tok' + + monkeypatch.setattr('tokenizers.Tokenizer.from_pretrained', fake_from_pretrained) + monkeypatch.setenv('GITHUB_TO_SQLITE_MODEL', 'env-model') + importlib.reload(tokenization) + assert tokenization.load_tokenizer() == 'tok' + assert calls == ['env-model'] diff --git a/tests/test_embedding_tables.py b/tests/test_embedding_tables.py new file mode 100644 index 0000000..c9384cc --- /dev/null +++ b/tests/test_embedding_tables.py @@ -0,0 +1,14 @@ +import sqlite_utils +from github_to_sqlite import utils + + +def test_embedding_tables_created(): + db = sqlite_utils.Database(memory=True) + utils.ensure_db_shape(db) + tables = set(db.table_names()) + assert { + "repo_embeddings", + "readme_chunk_embeddings", + "repo_build_files", + "repo_metadata", + }.issubset(tables) diff --git a/tests/test_find_build_files.py b/tests/test_find_build_files.py new file mode 100644 index 0000000..92cfe4d --- /dev/null +++ b/tests/test_find_build_files.py @@ -0,0 +1,79 @@ +import subprocess +import shutil + + +from github_to_sqlite.utils import find_build_files + + +def test_find_build_files_fd(monkeypatch): + calls = [] + + def fake_which(cmd): + return '/usr/bin/fd' if cmd == 'fd' else None + + def fake_run(args, capture_output, text, check): + pattern = args[4] if len(args) > 4 else args[1] + calls.append((args[0], pattern)) + if pattern == 'package.json': + raise subprocess.CalledProcessError(1, args) + class Res: + def __init__(self): + self.stdout = { + 'pyproject.toml': 'a/pyproject.toml\n', + 'package.json': '', + 'Cargo.toml': 'Cargo.toml\n', + 'Gemfile': '' + }[pattern] + return Res() + + monkeypatch.setattr(shutil, 'which', fake_which) + monkeypatch.setattr(subprocess, 'run', fake_run) + + result = find_build_files('repo') + assert result == ['a/pyproject.toml', 'Cargo.toml'] + assert calls + + +def test_find_fallback(monkeypatch): + calls = [] + + def fake_which(cmd): + if cmd 
== 'find': + return '/usr/bin/find' + return None + + def fake_run(args, capture_output, text, check): + pattern = args[3] + calls.append((args[0], pattern)) + if pattern == 'package.json': + raise subprocess.CalledProcessError(1, args) + class R: + def __init__(self): + self.stdout = 'repo/' + pattern + '\n' + return R() + + monkeypatch.setattr(shutil, 'which', fake_which) + monkeypatch.setattr(subprocess, 'run', fake_run) + + result = find_build_files('repo') + assert 'package.json' not in result + assert 'pyproject.toml' in result and 'Cargo.toml' in result + assert calls + + +def test_walk_fallback(monkeypatch, tmp_path): + def fake_which(cmd): + return None + + repo_dir = tmp_path / "repo" + repo_dir.mkdir() + (repo_dir / "pyproject.toml").write_text("") + sub = repo_dir / "sub" + sub.mkdir() + (sub / "Cargo.toml").write_text("") + + monkeypatch.setattr(shutil, "which", fake_which) + + result = find_build_files(str(repo_dir)) + assert set(result) == {"pyproject.toml", "sub/Cargo.toml"} + diff --git a/tests/test_migrate_command.py b/tests/test_migrate_command.py new file mode 100644 index 0000000..0e8f2d5 --- /dev/null +++ b/tests/test_migrate_command.py @@ -0,0 +1,18 @@ +from click.testing import CliRunner +import sqlite_utils +from github_to_sqlite import cli + + +def test_migrate_creates_tables(tmpdir): + db_path = str(tmpdir / "test.db") + runner = CliRunner() + result = runner.invoke(cli.cli, ["migrate", db_path]) + assert result.exit_code == 0 + db = sqlite_utils.Database(db_path) + tables = set(db.table_names()) + assert { + "repo_embeddings", + "readme_chunk_embeddings", + "repo_build_files", + "repo_metadata", + }.issubset(tables) diff --git a/tests/test_repos.py b/tests/test_repos.py index 9432992..65fea66 100644 --- a/tests/test_repos.py +++ b/tests/test_repos.py @@ -2,11 +2,9 @@ import pytest import pathlib import sqlite_utils -from sqlite_utils.db import ForeignKey import json from click.testing import CliRunner from github_to_sqlite import cli -import pytest README_HTML = """
  • Filtering tables
  • @@ -45,6 +43,10 @@ def test_repos(mocked, tmpdir): "users", "licenses", "repos", + "repo_embeddings", + "readme_chunk_embeddings", + "repo_build_files", + "repo_metadata", "licenses_fts", "licenses_fts_data", "licenses_fts_idx", diff --git a/tests/test_sentencizer_chunker.py b/tests/test_sentencizer_chunker.py new file mode 100644 index 0000000..ce2883c --- /dev/null +++ b/tests/test_sentencizer_chunker.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +from github_to_sqlite.sentencizer_chunker import BasicSentencizerChunker + + +def test_sentencizer_chunks_vectors(): + tokens = ["hello", ".", "world", "."] + vecs = [np.array([1]), np.array([2]), np.array([3]), np.array([4])] + chunker = BasicSentencizerChunker() + chunks = chunker(tokens, vecs) + assert len(chunks) == 2 + assert len(chunks[0]) == 2 + assert len(chunks[1]) == 2 + + +def test_sentencizer_drops_incomplete(): + tokens = ["hello", ".", "world"] + vecs = [np.array([1]), np.array([2]), np.array([3])] + chunker = BasicSentencizerChunker() + chunks = chunker(tokens, vecs) + assert len(chunks) == 1 + + +def test_sentencizer_length_mismatch(): + tokens = ["a", "."] + vecs = [np.array([1])] + chunker = BasicSentencizerChunker() + with pytest.raises(ValueError): + chunker(tokens, vecs) diff --git a/tests/test_simple_chunker.py b/tests/test_simple_chunker.py new file mode 100644 index 0000000..eb26c60 --- /dev/null +++ b/tests/test_simple_chunker.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from github_to_sqlite.simple_chunker import SimpleChunker, BaseSplitter + + +@dataclass +class LambdaSplitter(BaseSplitter): + func: callable + + def __call__(self, doc: str): + return self.func(doc) + + +def test_simple_chunker_drops_partial(tmp_path): + text = "Sentence one. Sentence two. Sentence three. Sentence four. Sentence five. Sentence six. Extra." # 7 sentences + chunker = SimpleChunker( + name="test", splitter=LambdaSplitter(func=lambda d: d), target_length=3 + ) + chunks = chunker([text])[0] + # Expect two chunks of exactly 3 sentences each, dropping the last partial chunk + assert len(chunks) == 2 + assert "Sentence three." in chunks[0].content + assert "Sentence six." 
in chunks[1].content + + +def test_punkt_download(monkeypatch): + calls = [] + + def failing(text): + monkeypatch.setattr('github_to_sqlite.simple_chunker.sent_tokenize', lambda t: ['ok']) + raise LookupError + + monkeypatch.setattr('github_to_sqlite.simple_chunker.sent_tokenize', failing) + monkeypatch.setattr('nltk.download', lambda name: calls.append(name)) + + chunker = SimpleChunker(name='t', splitter=LambdaSplitter(func=lambda d: d), target_length=1) + assert chunker._split('hi') == ['ok'] + assert 'punkt' in calls + diff --git a/tests/test_stargazers.py b/tests/test_stargazers.py index 6de4620..e885ad1 100644 --- a/tests/test_stargazers.py +++ b/tests/test_stargazers.py @@ -3,7 +3,6 @@ import pathlib import pytest import sqlite_utils -from sqlite_utils.db import ForeignKey @pytest.fixture diff --git a/tests/test_starred.py b/tests/test_starred.py index 22f8e65..3278545 100644 --- a/tests/test_starred.py +++ b/tests/test_starred.py @@ -38,6 +38,10 @@ def test_tables(db): "repos_fts", "repos_fts_idx", "repos", + "repo_embeddings", + "readme_chunk_embeddings", + "repo_build_files", + "repo_metadata", "licenses_fts", "users_fts_docsize", "users_fts", diff --git a/tests/test_starred_embeddings_command.py b/tests/test_starred_embeddings_command.py new file mode 100644 index 0000000..b5ef102 --- /dev/null +++ b/tests/test_starred_embeddings_command.py @@ -0,0 +1,123 @@ +import json +import sys +import types +from pathlib import Path + +import sqlite_utils +from click.testing import CliRunner + +from github_to_sqlite import cli, utils + + +def test_starred_embeddings_command(monkeypatch, tmpdir): + starred = json.load(open(Path(__file__).parent / "starred.json")) + repo = starred[0]["repo"] + + repo_dir = Path(tmpdir) / repo["full_name"] + repo_dir.mkdir(parents=True) + (repo_dir / "pyproject.toml").write_text("name = 'pkg'") + + monkeypatch.setattr(utils, "fetch_all_starred", lambda token=None: starred) + monkeypatch.setattr(utils, "fetch_readme", lambda token, full_name: "Readme") + monkeypatch.setattr(utils, "chunk_readme", lambda text: ["chunk1", "chunk2"]) + monkeypatch.setattr(utils, "find_build_files", lambda path: ["pyproject.toml"]) + monkeypatch.setattr(utils, "directory_tree", lambda path: {".": {"dirs": [], "files": ["pyproject.toml"]}}) + monkeypatch.setattr(utils, "parse_build_file", lambda path: {"name": "pkg"}) + + class DummyModel: + def encode(self, texts): + return [[1.0, 0.0]] * len(texts) + + dummy_mod = types.SimpleNamespace(SentenceTransformer=lambda name: DummyModel()) + monkeypatch.setitem(sys.modules, "sentence_transformers", dummy_mod) + monkeypatch.setattr(utils, "_maybe_load_sqlite_vec", lambda db: False) + + db_path = str(Path(tmpdir) / "test.db") + runner = CliRunner() + result = runner.invoke(cli.cli, ["starred-embeddings", db_path]) + assert result.exit_code == 0 + + db = sqlite_utils.Database(db_path) + assert db["repo_embeddings"].count == 1 + assert db["readme_chunk_embeddings"].count == 2 + assert db["repo_build_files"].count == 1 + assert db["repo_metadata"].count == 1 + + +def test_starred_embeddings_command_sqlite_vec(monkeypatch, tmpdir): + starred = json.load(open(Path(__file__).parent / "starred.json")) + repo = starred[0]["repo"] + + repo_dir = Path(tmpdir) / repo["full_name"] + repo_dir.mkdir(parents=True) + (repo_dir / "pyproject.toml").write_text("name = 'pkg'") + + monkeypatch.setattr(utils, "fetch_all_starred", lambda token=None: starred) + monkeypatch.setattr(utils, "fetch_readme", lambda token, full_name: "Readme") + 
monkeypatch.setattr(utils, "chunk_readme", lambda text: ["chunk1", "chunk2"]) + monkeypatch.setattr(utils, "find_build_files", lambda path: ["pyproject.toml"]) + monkeypatch.setattr(utils, "directory_tree", lambda path: {".": {"dirs": [], "files": ["pyproject.toml"]}}) + monkeypatch.setattr(utils, "parse_build_file", lambda path: {"name": "pkg"}) + + class DummyModel: + def encode(self, texts): + return [[1.0, 0.0]] * len(texts) + + dummy_mod = types.SimpleNamespace(SentenceTransformer=lambda name: DummyModel()) + monkeypatch.setitem(sys.modules, "sentence_transformers", dummy_mod) + + def fake_ensure_embedding_tables(db): + tables = set(db.table_names()) + if "repo_embeddings" not in tables: + db["repo_embeddings"].create({"repo_id": int, "title_embedding": bytes, "description_embedding": bytes, "readme_embedding": bytes}, pk="repo_id") + if "readme_chunk_embeddings" not in tables: + db["readme_chunk_embeddings"].create({"repo_id": int, "chunk_index": int, "chunk_text": str, "embedding": bytes}, pk=("repo_id", "chunk_index")) + if "repo_build_files" not in tables: + db["repo_build_files"].create({"repo_id": int, "file_path": str, "metadata": str}, pk=("repo_id", "file_path")) + if "repo_metadata" not in tables: + db["repo_metadata"].create({"repo_id": int, "language": str, "directory_tree": str}, pk="repo_id") + + monkeypatch.setattr(utils, "_maybe_load_sqlite_vec", lambda db: True) + monkeypatch.setattr(utils, "ensure_embedding_tables", fake_ensure_embedding_tables) + dummy_sqlite_vec = types.SimpleNamespace( + serialize_float32=lambda v: b"".join(float(x).hex().encode() for x in v), + load=lambda conn: None, + ) + monkeypatch.setitem(sys.modules, "sqlite_vec", dummy_sqlite_vec) + + db_path = str(Path(tmpdir) / "test.db") + runner = CliRunner() + result = runner.invoke(cli.cli, ["starred-embeddings", db_path, "--verbose"]) + assert result.exit_code == 0 + assert "Using sqlite-vec for embedding storage" in result.output + +def test_starred_embeddings_env_model(monkeypatch, tmpdir): + starred = json.load(open(Path(__file__).parent / "starred.json")) + repo = starred[0]["repo"] + + repo_dir = Path(tmpdir) / repo["full_name"] + repo_dir.mkdir(parents=True) + (repo_dir / "pyproject.toml").write_text("name = 'pkg'") + + monkeypatch.setattr(utils, "fetch_all_starred", lambda token=None: starred) + monkeypatch.setattr(utils, "fetch_readme", lambda token, full_name: "Readme") + monkeypatch.setattr(utils, "chunk_readme", lambda text: ["chunk1"]) + monkeypatch.setattr(utils, "find_build_files", lambda path: []) + monkeypatch.setattr(utils, "directory_tree", lambda path: {}) + monkeypatch.setattr(utils, "parse_build_file", lambda path: {}) + monkeypatch.setattr(utils, "_maybe_load_sqlite_vec", lambda db: False) + + called = [] + class DummyModel: + def __init__(self, name): + called.append(name) + def encode(self, texts): + return [[0.0]] * len(texts) + dummy_mod = types.SimpleNamespace(SentenceTransformer=DummyModel) + monkeypatch.setitem(sys.modules, "sentence_transformers", dummy_mod) + monkeypatch.setenv("GITHUB_TO_SQLITE_MODEL", "env-model") + + db_path = str(Path(tmpdir) / "test.db") + result = CliRunner().invoke(cli.cli, ["starred-embeddings", db_path]) + assert result.exit_code == 0 + assert called == ["env-model"] diff --git a/tests/test_starred_embeddings_integration.py b/tests/test_starred_embeddings_integration.py new file mode 100644 index 0000000..83fd72c --- /dev/null +++ b/tests/test_starred_embeddings_integration.py @@ -0,0 +1,75 @@ +import json +import base64 +import pathlib +from 
click.testing import CliRunner +import sqlite_utils +import types +import sys + +from github_to_sqlite import cli, utils + + +def test_starred_embeddings_integration(requests_mock, tmp_path, monkeypatch): + starred = json.load(open(pathlib.Path(__file__).parent / "starred.json")) + repo = starred[0]["repo"] + + # Mock GitHub API + requests_mock.get( + "https://api.github.com/user/starred?per_page=100", + json=starred, + ) + encoded = base64.b64encode(b"Readme").decode("utf-8") + requests_mock.get( + f"https://api.github.com/repos/{repo['full_name']}/readme", + json={"content": encoded}, + ) + + # Stub sentence_transformers + class DummyModel: + def encode(self, texts): + return [[1.0]] * len(texts) + + dummy_mod = types.SimpleNamespace(SentenceTransformer=lambda name: DummyModel()) + monkeypatch.setitem(sys.modules, "sentence_transformers", dummy_mod) + monkeypatch.setattr(utils, "_maybe_load_sqlite_vec", lambda db: False) + + # Prepare filesystem + auth_path = tmp_path / "auth.json" + auth_path.write_text(json.dumps({"github_personal_token": "x"})) + + db_path = tmp_path / "test.db" + + runner = CliRunner() + with runner.isolated_filesystem(temp_dir=tmp_path): + repo_dir = pathlib.Path(repo["full_name"]) + repo_dir.mkdir(parents=True) + (repo_dir / "pyproject.toml").write_text("name='pkg'") + + result = runner.invoke( + cli.cli, + ["starred-embeddings", str(db_path), "-a", str(auth_path), "--verbose"], + ) + assert result.exit_code == 0 + assert "sqlite-vec extension not loaded" in result.output + + db = sqlite_utils.Database(db_path) + assert db["repo_embeddings"].count == 1 + assert db["readme_chunk_embeddings"].count == 1 + assert db["repo_build_files"].count == 1 + assert db["repo_metadata"].count == 1 + + row = db["repo_embeddings"].get(repo["id"]) + assert row["title_embedding"] == b"\x00\x00\x80?" + assert row["description_embedding"] == b"\x00\x00\x80?" + assert row["readme_embedding"] == b"\x00\x00\x80?" + + chunk = db["readme_chunk_embeddings"].get((repo["id"], 0)) + assert chunk["chunk_text"] == "Readme" + assert chunk["embedding"] == b"\x00\x00\x80?" 
+ + build = db["repo_build_files"].get((repo["id"], "pyproject.toml")) + assert build["metadata"] == '{"name": "pkg"}' + + meta = db["repo_metadata"].get(repo["id"]) + assert meta["language"] == repo["language"] + assert meta["directory_tree"] == '{".": {"dirs": [], "files": ["pyproject.toml"]}}' diff --git a/tests/test_utils_functions.py b/tests/test_utils_functions.py new file mode 100644 index 0000000..12b1c03 --- /dev/null +++ b/tests/test_utils_functions.py @@ -0,0 +1,79 @@ +import sys +import sqlite3 + +import sqlite_utils + +from github_to_sqlite import utils + + +def test_vector_to_blob_round_trip(): + import numpy as np + + vec = np.array([1.0, 2.0, 3.5], dtype="float32") + blob = utils.vector_to_blob(vec) + assert isinstance(blob, bytes) + # Byte length should match the float32 representation + assert len(blob) == vec.astype("float32").nbytes + arr = np.frombuffer(blob, dtype="float32") + assert arr.dtype == np.float32 + assert arr.tolist() == [1.0, 2.0, 3.5] + + +def test_parse_build_file_json_and_toml(tmp_path): + json_file = tmp_path / "package.json" + json_file.write_text('{"name": "pkg", "author": "me"}') + toml_file = tmp_path / "pyproject.toml" + toml_file.write_text('name = "pkg"\nauthor = "you"') + + assert utils.parse_build_file(str(json_file)) == {"name": "pkg", "author": "me"} + result = utils.parse_build_file(str(toml_file)) + assert result["name"] == "pkg" + assert result["author"] == "you" + + +def test_directory_tree(tmp_path): + b = tmp_path / "b" + a = tmp_path / "a" + b.mkdir() + a.mkdir() + (a / "2.txt").write_text("data") + (a / "1.txt").write_text("data") + + tree = utils.directory_tree(str(tmp_path)) + # Root should list directories sorted alphabetically + assert tree["."]["dirs"] == ["a", "b"] + assert tree["."]["files"] == [] + # Files should also be sorted + assert tree["a"]["files"] == ["1.txt", "2.txt"] + + +def test_maybe_load_sqlite_vec(monkeypatch): + db = sqlite_utils.Database(memory=True) + + # No sqlite_vec module -> False + sys.modules.pop("sqlite_vec", None) + utils._SQLITE_VEC_LOADED = None + assert utils._maybe_load_sqlite_vec(db) is False + + # Module with load that succeeds -> True + dummy = type("M", (), {"load": lambda conn: None}) + monkeypatch.setitem(sys.modules, "sqlite_vec", dummy) + utils._SQLITE_VEC_LOADED = None + assert utils._maybe_load_sqlite_vec(db) is True + utils._SQLITE_VEC_LOADED = None + +def test_parse_build_file_invalid(tmp_path): + bad = tmp_path / "bad.toml" + bad.write_text("not : valid") + assert utils.parse_build_file(str(bad)) == {} + + +def test_maybe_load_sqlite_vec_failure(monkeypatch): + db = sqlite_utils.Database(memory=True) + class Dummy: + def load(self, conn): + raise sqlite3.DatabaseError("boom") + monkeypatch.setitem(sys.modules, "sqlite_vec", Dummy()) + utils._SQLITE_VEC_LOADED = None + assert utils._maybe_load_sqlite_vec(db) is False + utils._SQLITE_VEC_LOADED = None diff --git a/tests/test_workflows.py b/tests/test_workflows.py index 8ca7d0d..9535e6b 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -3,7 +3,6 @@ import pathlib import pytest import sqlite_utils -from sqlite_utils.db import ForeignKey import textwrap
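Finally, a small sketch of how the BLOB fallback storage exercised by these tests (vectors written with `vector_to_blob()` when sqlite-vec is unavailable) can still be used for similarity comparisons: `np.frombuffer(..., dtype="float32")` is the inverse of `vector_to_blob()`. The database path is illustrative.

```python
import numpy as np
import sqlite_utils

db = sqlite_utils.Database("github.db")
rows = list(db["readme_chunk_embeddings"].rows)


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Decode the float32 blobs written by vector_to_blob()
vectors = [np.frombuffer(row["embedding"], dtype="float32") for row in rows]

# Compare the first stored chunk against every other stored chunk
if len(vectors) >= 2:
    for row, vec in zip(rows[1:], vectors[1:]):
        print(row["repo_id"], row["chunk_index"], round(cosine(vectors[0], vec), 3))
```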