Skip to content

Commit 121cfe6

Browse files
committed
v1.2.1 changes
- Fixed bug during adding negative feedback to include "response" in DBMS_CLOUD_AI.FEEDBACK procedure - Fixed bug during Profiles.list() to list dummy profiles without attributes - Added test cases for negative and positive feedback
1 parent aa3cc84 commit 121cfe6

File tree

9 files changed

+325
-65
lines changed

9 files changed

+325
-65
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ pytest.env
1919
sample_connect.py
2020
async_pipeline_test.py
2121
parquet.py
22+
local_sample

src/select_ai/async_profile.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
)
3131
from select_ai.conversation import AsyncConversation
3232
from select_ai.db import async_cursor, async_get_connection
33-
from select_ai.errors import ProfileExistsError, ProfileNotFoundError
33+
from select_ai.errors import (
34+
ProfileAttributesEmptyError,
35+
ProfileExistsError,
36+
ProfileNotFoundError,
37+
)
3438
from select_ai.feedback import (
3539
FeedbackOperation,
3640
FeedbackType,
@@ -70,34 +74,22 @@ async def _init_profile(self):
7074
profile_exists = False
7175
try:
7276
saved_attributes = await self._get_attributes(
77+
profile_name=self.profile_name,
78+
raise_on_empty=True,
79+
)
80+
saved_description = await self._get_profile_description(
7381
profile_name=self.profile_name
7482
)
7583
profile_exists = True
76-
if not self.replace and not self.merge:
77-
if (
78-
self.attributes is not None
79-
or self.description is not None
80-
):
81-
if self.raise_error_if_exists:
82-
raise ProfileExistsError(self.profile_name)
83-
84-
if self.description is None and not self.replace:
85-
self.description = await self._get_profile_description(
86-
profile_name=self.profile_name
87-
)
84+
self._raise_error_if_profile_exists()
85+
except ProfileAttributesEmptyError:
86+
if self.raise_error_on_empty_attributes:
87+
raise
8888
except ProfileNotFoundError:
8989
if self.attributes is None and self.description is None:
9090
raise
9191
else:
92-
if self.attributes is None:
93-
self.attributes = saved_attributes
94-
if self.merge:
95-
self.replace = True
96-
if self.attributes is not None:
97-
self.attributes = dataclass_replace(
98-
saved_attributes,
99-
**self.attributes.dict(exclude_null=True),
100-
)
92+
self._merge_attributes(saved_attributes, saved_description)
10193
if self.replace or not profile_exists:
10294
await self.create(replace=self.replace)
10395
else: # profile name is None:
@@ -132,12 +124,15 @@ async def _get_profile_description(profile_name) -> Union[str, None]:
132124
raise ProfileNotFoundError(profile_name)
133125

134126
@staticmethod
135-
async def _get_attributes(profile_name) -> ProfileAttributes:
127+
async def _get_attributes(
128+
profile_name: str, raise_on_empty: bool = True
129+
) -> Union[ProfileAttributes, None]:
136130
"""Asynchronously gets AI profile attributes from the Database
137131
138132
:param str profile_name: Name of the profile
133+
:param bool raise_on_empty: Raise an error if attributes are empty
139134
:return: select_ai.provider.ProviderAttributes
140-
:raises: ProfileNotFoundError
135+
:raises: select_ai.errors.ProfileAttributesEmptyError
141136
142137
"""
143138
async with async_cursor() as cr:
@@ -149,7 +144,11 @@ async def _get_attributes(profile_name) -> ProfileAttributes:
149144
if attributes:
150145
return await ProfileAttributes.async_create(**dict(attributes))
151146
else:
152-
raise ProfileNotFoundError(profile_name=profile_name)
147+
if raise_on_empty:
148+
raise ProfileAttributesEmptyError(
149+
profile_name=profile_name
150+
)
151+
return None
153152

154153
async def get_attributes(self) -> ProfileAttributes:
155154
"""Asynchronously gets AI profile attributes from the Database
@@ -387,7 +386,9 @@ async def list(
387386
for row in rows:
388387
profile_name = row[0]
389388
yield await cls(
390-
profile_name=profile_name, raise_error_if_exists=False
389+
profile_name=profile_name,
390+
raise_error_if_exists=False,
391+
raise_error_on_empty_attributes=False,
391392
)
392393

393394
async def generate(

src/select_ai/base_profile.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import json
99
from abc import ABC
1010
from dataclasses import dataclass
11+
from dataclasses import replace as dataclass_replace
1112
from typing import List, Mapping, Optional, Tuple
1213

1314
import oracledb
1415

1516
from select_ai._abc import SelectAIDataClass
1617
from select_ai.action import Action
18+
from select_ai.errors import ProfileExistsError
1719
from select_ai.feedback import (
1820
FeedbackOperation,
1921
FeedbackType,
@@ -159,6 +161,10 @@ class BaseProfile(ABC):
159161
if profile exists in the database and replace = False and
160162
merge = False. Default value is True
161163
164+
:param bool raise_error_on_empty_attributes: Raise
165+
ProfileEmptyAttributesError, if profile attributes are empty
166+
in database. Default value is False.
167+
162168
"""
163169

164170
def __init__(
@@ -169,6 +175,7 @@ def __init__(
169175
merge: Optional[bool] = False,
170176
replace: Optional[bool] = False,
171177
raise_error_if_exists: Optional[bool] = True,
178+
raise_error_on_empty_attributes: Optional[bool] = False,
172179
):
173180
"""Initialize a base profile"""
174181
self.profile_name = profile_name
@@ -182,6 +189,34 @@ def __init__(
182189
self.merge = merge
183190
self.replace = replace
184191
self.raise_error_if_exists = raise_error_if_exists
192+
self.raise_error_on_empty_attributes = raise_error_on_empty_attributes
193+
194+
def _raise_error_if_profile_exists(self):
195+
"""
196+
Helper method to raise ProfileExistsError if profile exists
197+
in the database and replace = False and merge = False
198+
"""
199+
if not self.replace and not self.merge:
200+
if self.attributes is not None or self.description is not None:
201+
if self.raise_error_if_exists:
202+
raise ProfileExistsError(self.profile_name)
203+
204+
def _merge_attributes(self, saved_attributes, saved_description):
205+
"""
206+
Helper method to merge user passed attributes with the attributes saved
207+
in the database.
208+
"""
209+
if self.description is None and not self.replace:
210+
self.description = saved_description
211+
if self.attributes is None:
212+
self.attributes = saved_attributes
213+
if self.merge:
214+
self.replace = True
215+
if self.attributes is not None:
216+
self.attributes = dataclass_replace(
217+
saved_attributes,
218+
**self.attributes.dict(exclude_null=True),
219+
)
185220

186221
def __repr__(self):
187222
return (
@@ -206,15 +241,15 @@ def validate_params_for_feedback(
206241
response: Optional[str] = None,
207242
operation: Optional[FeedbackOperation] = FeedbackOperation.ADD,
208243
):
209-
if sql_id and prompt_spec:
210-
raise AttributeError("Either sql_id or prompt_spec must be specified")
211244
if not sql_id and not prompt_spec:
212245
raise AttributeError("Either sql_id or prompt_spec must be specified")
213-
parameters = {
214-
"feedback_type": feedback_type.value,
215-
"feedback_content": feedback_content,
216-
"operation": operation.value,
217-
}
246+
parameters = {"operation": operation.value}
247+
if feedback_content:
248+
parameters["feedback_content"] = feedback_content
249+
if feedback_type:
250+
parameters["feedback_type"] = feedback_type.value
251+
if response:
252+
parameters["response"] = response
218253
if prompt_spec:
219254
prompt, action = prompt_spec
220255
if action not in (Action.RUNSQL, Action.SHOWSQL, Action.EXPLAINSQL):

src/select_ai/errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def __str__(self):
5656
)
5757

5858

59+
class ProfileAttributesEmptyError(SelectAIError):
60+
"""Profile attributes empty in the database"""
61+
62+
def __init__(self, profile_name: str):
63+
self.profile_name = profile_name
64+
65+
def __str__(self):
66+
return (
67+
f"Profile {self.profile_name} attributes empty in the database. "
68+
)
69+
70+
5971
class VectorIndexNotFoundError(SelectAIError):
6072
"""VectorIndex not found in the database"""
6173

src/select_ai/profile.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
validate_params_for_summary,
2424
)
2525
from select_ai.db import cursor
26-
from select_ai.errors import ProfileExistsError, ProfileNotFoundError
26+
from select_ai.errors import (
27+
ProfileAttributesEmptyError,
28+
ProfileExistsError,
29+
ProfileNotFoundError,
30+
)
2731
from select_ai.feedback import FeedbackOperation, FeedbackType
2832
from select_ai.provider import Provider
2933
from select_ai.sql import (
@@ -56,34 +60,22 @@ def _init_profile(self) -> None:
5660
profile_exists = False
5761
try:
5862
saved_attributes = self._get_attributes(
63+
profile_name=self.profile_name,
64+
raise_on_empty=True,
65+
)
66+
saved_description = self._get_profile_description(
5967
profile_name=self.profile_name
6068
)
6169
profile_exists = True
62-
if not self.replace and not self.merge:
63-
if (
64-
self.attributes is not None
65-
or self.description is not None
66-
):
67-
if self.raise_error_if_exists:
68-
raise ProfileExistsError(self.profile_name)
69-
70-
if self.description is None and not self.replace:
71-
self.description = self._get_profile_description(
72-
profile_name=self.profile_name
73-
)
70+
self._raise_error_if_profile_exists()
71+
except ProfileAttributesEmptyError:
72+
if self.raise_error_on_empty_attributes:
73+
raise
7474
except ProfileNotFoundError:
7575
if self.attributes is None and self.description is None:
7676
raise
7777
else:
78-
if self.attributes is None:
79-
self.attributes = saved_attributes
80-
if self.merge:
81-
self.replace = True
82-
if self.attributes is not None:
83-
self.attributes = dataclass_replace(
84-
saved_attributes,
85-
**self.attributes.dict(exclude_null=True),
86-
)
78+
self._merge_attributes(saved_attributes, saved_description)
8779
if self.replace or not profile_exists:
8880
self.create(replace=self.replace)
8981
else: # profile name is None
@@ -112,12 +104,15 @@ def _get_profile_description(profile_name) -> Union[str, None]:
112104
raise ProfileNotFoundError(profile_name)
113105

114106
@staticmethod
115-
def _get_attributes(profile_name) -> ProfileAttributes:
107+
def _get_attributes(
108+
profile_name, raise_on_empty: bool = False
109+
) -> Union[ProfileAttributes, None]:
116110
"""Get AI profile attributes from the Database
117111
118112
:param str profile_name: Name of the profile
113+
:param bool raise_on_empty: Raise an error if attributes are empty
119114
:return: select_ai.ProfileAttributes
120-
:raises: ProfileNotFoundError
115+
:raises: select_ai.errors.ProfileAttributesEmptyError
121116
"""
122117
with cursor() as cr:
123118
cr.execute(
@@ -128,7 +123,11 @@ def _get_attributes(profile_name) -> ProfileAttributes:
128123
if attributes:
129124
return ProfileAttributes.create(**dict(attributes))
130125
else:
131-
raise ProfileNotFoundError(profile_name=profile_name)
126+
if raise_on_empty:
127+
raise ProfileAttributesEmptyError(
128+
profile_name=profile_name
129+
)
130+
return None
132131

133132
def get_attributes(self) -> ProfileAttributes:
134133
"""Get AI profile attributes from the Database
@@ -278,7 +277,6 @@ def _save_feedback(
278277
operation=operation,
279278
)
280279
params["profile_name"] = self.profile_name
281-
print(params)
282280
with cursor() as cr:
283281
cr.callproc("DBMS_CLOUD_AI.FEEDBACK", keyword_parameters=params)
284282

@@ -363,7 +361,9 @@ def list(
363361
for row in cr.fetchall():
364362
profile_name = row[0]
365363
yield cls(
366-
profile_name=profile_name, raise_error_if_exists=False
364+
profile_name=profile_name,
365+
raise_error_if_exists=False,
366+
raise_error_on_empty_attributes=False,
367367
)
368368

369369
def generate(
@@ -513,7 +513,6 @@ def summarize(
513513
params=params,
514514
)
515515
parameters["profile_name"] = self.profile_name
516-
print(parameters)
517516
with cursor() as cr:
518517
data = cr.callfunc(
519518
"DBMS_CLOUD_AI.SUMMARIZE",

src/select_ai/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
# http://oss.oracle.com/licenses/upl.
66
# -----------------------------------------------------------------------------
77

8-
__version__ = "1.2.0"
8+
__version__ = "1.2.1"

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,13 @@ def async_connection():
125125
return select_ai.db.async_get_connection()
126126

127127

128-
@pytest.fixture
128+
@pytest.fixture(scope="module")
129129
def cursor():
130130
with select_ai.cursor() as cr:
131131
yield cr
132132

133133

134-
@pytest.fixture
134+
@pytest.fixture(scope="module")
135135
async def async_cursor():
136136
async with select_ai.async_cursor() as cr:
137137
yield cr

0 commit comments

Comments
 (0)