diff --git a/integrations/aws-strands/python/src/ag_ui_strands/agent.py b/integrations/aws-strands/python/src/ag_ui_strands/agent.py index 6790e4b26..744ca35f3 100644 --- a/integrations/aws-strands/python/src/ag_ui_strands/agent.py +++ b/integrations/aws-strands/python/src/ag_ui_strands/agent.py @@ -79,10 +79,14 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[Any]: # Each thread (user session) maintains its own conversation state thread_id = input_data.thread_id or "default" if thread_id not in self._agents_by_thread: + session_manager = None + if self.config.session_manager_provider: + session_manager = self.config.session_manager_provider(input_data) self._agents_by_thread[thread_id] = StrandsAgentCore( model=self._model, system_prompt=self._system_prompt, tools=self._tools, + session_manager=session_manager, **self._agent_kwargs, ) strands_agent = self._agents_by_thread[thread_id] diff --git a/integrations/aws-strands/python/src/ag_ui_strands/config.py b/integrations/aws-strands/python/src/ag_ui_strands/config.py index 6aebff4f9..152a65e08 100644 --- a/integrations/aws-strands/python/src/ag_ui_strands/config.py +++ b/integrations/aws-strands/python/src/ag_ui_strands/config.py @@ -17,6 +17,8 @@ from ag_ui.core import RunAgentInput +from strands.session import SessionManager + StatePayload = Dict[str, Any] @@ -45,6 +47,7 @@ class ToolResultContext(ToolCallContext): StateFromResult = Callable[[ToolResultContext], Awaitable[Optional[StatePayload]] | Optional[StatePayload]] CustomResultHandler = Callable[[ToolResultContext], AsyncIterator[Any]] StateContextBuilder = Callable[[RunAgentInput, str], str] +SessionManagerProvider = Callable[[RunAgentInput], SessionManager] @dataclass @@ -83,6 +86,8 @@ class StrandsAgentConfig: tool_behaviors: Dict[str, ToolBehavior] = field(default_factory=dict) state_context_builder: Optional[StateContextBuilder] = None + session_manager_provider: Optional[SessionManagerProvider] = None + async def maybe_await(value: Any) -> Any: