Skip to content

Commit 051015c

Browse files
authored
fix bug with ws reasoning (#514)
Signed-off-by: Jessie Frazelle <[email protected]>
1 parent 4329afe commit 051015c

File tree

5 files changed

+77
-56
lines changed

5 files changed

+77
-56
lines changed

generate/type_generators.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -557,48 +557,37 @@ def generate_one_of_type(path: str, name: str, schema: dict, data: dict):
557557
flattened_schema["required"] = required_fields
558558

559559
extra_imports = None
560-
if not alias_imports_needed:
561-
extra_imports = [
562-
"from pydantic import model_serializer, model_validator\n"
563-
]
564-
alias_imports_needed = True
565-
560+
extra_body_lines = []
566561
wrapped_key = outer_name
567-
extra_body_lines = [
568-
' @model_validator(mode="before")',
569-
" @classmethod",
570-
" def _unwrap(cls, data):",
571-
" if isinstance(data, dict) and '"
572-
+ wrapped_key
573-
+ "' in data and isinstance(data['"
574-
+ wrapped_key
575-
+ "'], dict):",
576-
" return data['" + wrapped_key + "']",
577-
" return data",
578-
]
579562

580563
if wrap_entire_payload:
564+
if not alias_imports_needed:
565+
extra_imports = [
566+
"from pydantic import model_serializer, model_validator\n"
567+
]
568+
alias_imports_needed = True
569+
581570
extra_body_lines.extend(
582571
[
583-
' @model_serializer(mode="wrap")',
584-
" def _wrap(self, handler, info):",
585-
" payload = handler(self, info)",
586-
" return {'" + wrapped_key + "': payload}",
572+
' @model_validator(mode="before")',
573+
" @classmethod",
574+
" def _unwrap(cls, data):",
575+
" if isinstance(data, dict) and '"
576+
+ wrapped_key
577+
+ "' in data and isinstance(data['"
578+
+ wrapped_key
579+
+ "'], dict):",
580+
" return data['" + wrapped_key + "']",
581+
" return data",
587582
]
588583
)
589-
else:
584+
590585
extra_body_lines.extend(
591586
[
592587
' @model_serializer(mode="wrap")',
593588
" def _wrap(self, handler, info):",
594589
" payload = handler(self, info)",
595-
" if isinstance(payload, dict) and '"
596-
+ wrapped_key
597-
+ "' in payload:",
598-
" value = payload['" + wrapped_key + "']",
599-
" else:",
600-
" value = payload",
601-
" return {'" + wrapped_key + "': value}",
590+
" return {'" + wrapped_key + "': payload}",
602591
]
603592
)
604593

kittycad/models/ml_copilot_server_message.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -156,30 +156,6 @@ class Reasoning(KittyCadBaseModel):
156156

157157
reasoning: ReasoningMessage
158158

159-
@model_validator(mode="before")
160-
@classmethod
161-
def _unwrap(cls, data):
162-
if (
163-
isinstance(data, dict)
164-
and "reasoning" in data
165-
and isinstance(data["reasoning"], dict)
166-
):
167-
return data["reasoning"]
168-
169-
return data
170-
171-
@model_serializer(mode="wrap")
172-
def _wrap(self, handler, info):
173-
payload = handler(self, info)
174-
175-
if isinstance(payload, dict) and "reasoning" in payload:
176-
value = payload["reasoning"]
177-
178-
else:
179-
value = payload
180-
181-
return {"reasoning": value}
182-
183159

184160
class Replay(KittyCadBaseModel):
185161
"""Replay containing raw bytes for previously-saved messages for a conversation. Includes server messages and client `User` messages.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import json
2+
import uuid
3+
from typing import cast
4+
5+
from websockets.sync.client import ClientConnection as ClientConnectionSync
6+
7+
from kittycad import WebSocketMlReasoningWs
8+
from kittycad.models.ml_copilot_server_message import Reasoning, SessionData
9+
from kittycad.models.reasoning_message import OptionText
10+
11+
12+
class FakeWS:
13+
def __init__(self, messages):
14+
self._messages = iter(messages)
15+
16+
def recv(self, timeout=60): # pragma: no cover - tiny helper
17+
try:
18+
return next(self._messages)
19+
except StopIteration as exc:
20+
raise AssertionError("unexpected recv() after messages exhausted") from exc
21+
22+
23+
def test_ml_reasoning_ws_recv_parses_reasoning_messages():
24+
cache_buster = uuid.uuid4().hex
25+
reasoning_content = (
26+
f":mag: Querying relevant KCL code examples... cache-buster-{cache_buster}"
27+
)
28+
29+
fake_ws = FakeWS(
30+
[
31+
json.dumps({"session_data": {"api_call_id": "abc123"}}),
32+
json.dumps(
33+
{
34+
"reasoning": {
35+
"type": "text",
36+
"content": reasoning_content,
37+
}
38+
}
39+
),
40+
]
41+
)
42+
43+
websocket = cast(
44+
WebSocketMlReasoningWs, WebSocketMlReasoningWs.__new__(WebSocketMlReasoningWs)
45+
)
46+
websocket.ws = cast(ClientConnectionSync, fake_ws)
47+
48+
session_message = websocket.recv()
49+
assert isinstance(session_message.root, SessionData)
50+
assert session_message.root.api_call_id == "abc123"
51+
52+
reasoning_message = websocket.recv()
53+
assert isinstance(reasoning_message.root, Reasoning)
54+
resolved_reasoning = reasoning_message.root.reasoning.root
55+
assert isinstance(resolved_reasoning, OptionText)
56+
assert resolved_reasoning.content == reasoning_content

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "kittycad"
7-
version = "1.2.3"
7+
version = "1.2.4"
88
description = "A client library for accessing KittyCAD"
99
authors = []
1010
readme = "README.md"

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)