Skip to content

Commit 277ec54

Browse files
committed
feat: add pgvector support with typed numpy arrays
Add support for PostgreSQL vector type with proper numpy typing: - Map 'vector' type to NDArray[numpy.float32] - Auto-import numpy and numpy.typing.NDArray - Add test case for vector type generation - Add .DS_Store to .gitignore This enables type-safe pgvector usage in generated Python code.
1 parent 6235819 commit 277ec54

File tree

8 files changed

+117
-0
lines changed

8 files changed

+117
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
bin
22

3+
# macOS
4+
.DS_Store
5+
36
# Devenv
47
.envrc
58
.direnv
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.30.0
4+
import dataclasses
5+
import numpy
6+
from numpy.typing import NDArray
7+
from typing import Optional
8+
9+
10+
@dataclasses.dataclass()
11+
class Item:
12+
id: int
13+
embedding: Optional[NDArray[numpy.float32]]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.30.0
4+
# source: query.sql
5+
import numpy
6+
from numpy.typing import NDArray
7+
from typing import Optional
8+
9+
import sqlalchemy
10+
import sqlalchemy.ext.asyncio
11+
12+
from python import models
13+
14+
15+
CREATE_ITEM = """-- name: create_item \\:one
16+
INSERT INTO items (embedding) VALUES (:p1) RETURNING id, embedding
17+
"""
18+
19+
20+
GET_ITEM = """-- name: get_item \\:one
21+
SELECT id, embedding FROM items WHERE id = :p1
22+
"""
23+
24+
25+
class Querier:
26+
def __init__(self, conn: sqlalchemy.engine.Connection):
27+
self._conn = conn
28+
29+
def create_item(self, *, embedding: Optional[NDArray[numpy.float32]]) -> Optional[models.Item]:
30+
row = self._conn.execute(sqlalchemy.text(CREATE_ITEM), {"p1": embedding}).first()
31+
if row is None:
32+
return None
33+
return models.Item(
34+
id=row[0],
35+
embedding=row[1],
36+
)
37+
38+
def get_item(self, *, id: int) -> Optional[models.Item]:
39+
row = self._conn.execute(sqlalchemy.text(GET_ITEM), {"p1": id}).first()
40+
if row is None:
41+
return None
42+
return models.Item(
43+
id=row[0],
44+
embedding=row[1],
45+
)
46+
47+
48+
class AsyncQuerier:
49+
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
50+
self._conn = conn
51+
52+
async def create_item(self, *, embedding: Optional[NDArray[numpy.float32]]) -> Optional[models.Item]:
53+
row = (await self._conn.execute(sqlalchemy.text(CREATE_ITEM), {"p1": embedding})).first()
54+
if row is None:
55+
return None
56+
return models.Item(
57+
id=row[0],
58+
embedding=row[1],
59+
)
60+
61+
async def get_item(self, *, id: int) -> Optional[models.Item]:
62+
row = (await self._conn.execute(sqlalchemy.text(GET_ITEM), {"p1": id})).first()
63+
if row is None:
64+
return None
65+
return models.Item(
66+
id=row[0],
67+
embedding=row[1],
68+
)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-- name: GetItem :one
2+
SELECT * FROM items WHERE id = $1;
3+
4+
5+
-- name: CreateItem :one
6+
INSERT INTO items (embedding) VALUES ($1) RETURNING *;
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
CREATE TABLE items (
2+
id SERIAL PRIMARY KEY,
3+
embedding vector(3)
4+
);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
version: '2'
2+
plugins:
3+
- name: py
4+
wasm:
5+
url: file://../../../../bin/sqlc-gen-python.wasm
6+
sha256: "839af1f07c31644548192fc095569e62f1511d72c1c30c1a958ddc9c9429edbc"
7+
sql:
8+
- schema: schema.sql
9+
queries: query.sql
10+
engine: postgresql
11+
codegen:
12+
- plugin: py
13+
out: python
14+
options:
15+
package: python
16+
emit_sync_querier: true
17+
emit_async_querier: true

internal/imports.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,5 +272,9 @@ func stdImports(uses func(name string) bool) map[string]importSpec {
272272
if uses("Any") {
273273
std["typing.Any"] = importSpec{Module: "typing", Name: "Any"}
274274
}
275+
if uses("NDArray[numpy.float32]") {
276+
std["numpy"] = importSpec{Module: "numpy"}
277+
std["numpy.typing.NDArray"] = importSpec{Module: "numpy.typing", Name: "NDArray"}
278+
}
275279
return std
276280
}

internal/postgresql_type.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ func postgresType(req *plugin.GenerateRequest, col *plugin.Column) string {
4242
return "str"
4343
case "ltree", "lquery", "ltxtquery":
4444
return "str"
45+
case "vector":
46+
return "NDArray[numpy.float32]"
4547
default:
4648
for _, schema := range req.Catalog.Schemas {
4749
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {

0 commit comments

Comments
 (0)