2800 lines
115 KiB
Python
2800 lines
115 KiB
Python
"""
|
|
OpenCode SDK Client Wrapper for 7000%AUTO
|
|
Uses OpenCode SDK (opencode-ai) for AI agent interactions.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any, AsyncIterator, List, Callable, Awaitable
|
|
|
|
import httpx
|
|
|
|
from config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Agent system prompts cache
|
|
_AGENT_PROMPTS: Dict[str, str] = {}
|
|
|
|
# OpenCode configuration cache
|
|
_OPENCODE_CONFIG: Optional[Dict[str, Any]] = None
|
|
|
|
# Pre-compiled regex patterns for performance
|
|
_FRONTMATTER_PATTERN = re.compile(r'^---\s*\n.*?\n---\s*\n', re.DOTALL)
|
|
_VALID_AGENT_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_-]+$')
|
|
_SSE_DATA_PREFIX = 'data:'
|
|
_SSE_DONE_MARKER = '[DONE]'
|
|
|
|
|
|
def _load_opencode_config() -> Dict[str, Any]:
    """Return the OpenCode configuration, loading and caching it on first use.

    ``opencode.json`` is read from the current working directory. If it is
    missing, unreadable, not valid JSON, or does not contain a JSON object,
    a minimal configuration is generated from ``settings`` instead.

    Returns:
        The configuration dict; the same cached object on later calls.

    Raises:
        RuntimeError: If no usable ``opencode.json`` exists and
            ``settings.OPENCODE_SDK`` or ``settings.OPENCODE_MODEL`` is unset.
    """
    global _OPENCODE_CONFIG

    if _OPENCODE_CONFIG is not None:
        return _OPENCODE_CONFIG

    config_path = Path("opencode.json")
    if config_path.exists():
        try:
            loaded = json.loads(config_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Failed to load opencode.json: {e}")
        else:
            if isinstance(loaded, dict):
                _OPENCODE_CONFIG = loaded
                return _OPENCODE_CONFIG
            # This function is typed to return a Dict; caching a JSON array or
            # scalar would break that contract far from the cause.
            logger.warning(
                f"Ignoring opencode.json: expected a JSON object, got {type(loaded).__name__}"
            )

    # Fallback configuration - should normally be generated by main.py
    # If we reach here without proper config, something is wrong
    if not settings.OPENCODE_SDK or not settings.OPENCODE_MODEL:
        logger.error("OpenCode settings not configured. Set OPENCODE_SDK, OPENCODE_MODEL, etc.")
        raise RuntimeError("Missing required OpenCode environment variables")

    # Derive provider name from SDK package (e.g. @ai-sdk/anthropic -> anthropic)
    provider_name = settings.OPENCODE_SDK.split("/")[-1]

    _OPENCODE_CONFIG = {
        "provider": {
            provider_name: {
                "npm": settings.OPENCODE_SDK,
                "name": provider_name.title(),
                "options": {
                    "baseURL": settings.OPENCODE_API_BASE,
                    "apiKey": "{env:OPENCODE_API_KEY}"
                },
                "models": {
                    settings.OPENCODE_MODEL: {
                        "name": settings.OPENCODE_MODEL,
                        "options": {
                            "max_tokens": settings.OPENCODE_MAX_TOKENS
                        }
                    }
                }
            }
        },
        "model": f"{provider_name}/{settings.OPENCODE_MODEL}"
    }
    return _OPENCODE_CONFIG
|
|
|
|
|
|
def _remove_yaml_frontmatter(content: str) -> str:
    """Strip a leading ``---`` ... ``---`` YAML header from markdown text.

    Whatever remains is returned with surrounding whitespace trimmed; text
    without such a header is only trimmed.
    """
    without_header = _FRONTMATTER_PATTERN.sub('', content)
    return without_header.strip()
|
|
|
|
|
|
def _load_agent_prompt(agent_name: str) -> str:
    """Return the system prompt for ``agent_name``, caching the result.

    The prompt is read from ``.opencode/agent/{agent_name}.md`` with any YAML
    frontmatter removed. A generic one-line prompt is used instead when:

    * the name is not a plain identifier (guards against path traversal),
    * the file does not exist, or
    * the file cannot be read or decoded.

    In the last case the fallback is not cached, so a later call retries.

    Args:
        agent_name: Agent identifier, e.g. ``"planner"``.

    Returns:
        The prompt text.
    """
    if agent_name in _AGENT_PROMPTS:
        return _AGENT_PROMPTS[agent_name]

    fallback = f"You are {agent_name}, an AI assistant. Complete the task given to you."

    # fullmatch, not match: with match, a pattern ending in `$` also accepts
    # a trailing "\n", which would slip into the file path below.
    if not _VALID_AGENT_NAME_PATTERN.fullmatch(agent_name):
        logger.warning(f"Invalid agent name format: {agent_name}")
        _AGENT_PROMPTS[agent_name] = fallback
        return fallback

    agent_path = Path(".opencode") / "agent" / f"{agent_name}.md"
    if not agent_path.exists():
        _AGENT_PROMPTS[agent_name] = fallback
        return fallback

    try:
        raw = agent_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        # Possibly transient; return the generic prompt without caching it.
        logger.warning(f"Failed to read agent prompt {agent_path}: {e}; using generic prompt")
        return fallback

    content = _remove_yaml_frontmatter(raw)
    _AGENT_PROMPTS[agent_name] = content
    return content
|
|
|
|
|
|
class OpenCodeError(Exception):
    """Base exception for errors raised by the OpenCode client wrapper.

    Used for client initialisation failures, session creation failures,
    failed or errored messages, and ``session.error`` events.
    """
|
|
|
|
|
|
class _MessageAlreadySentError(Exception):
|
|
"""Internal exception indicating message was sent but streaming failed.
|
|
|
|
This is used to signal to the caller that fallback should NOT send
|
|
another message, but instead poll for the existing message's completion.
|
|
"""
|
|
def __init__(self, message: str, accumulated_content: List[str] = None):
|
|
super().__init__(message)
|
|
self.accumulated_content = accumulated_content or []
|
|
|
|
|
|
def _safe_get(obj: Any, key: str, default: Any = None) -> Any:
|
|
"""Get a value from either a dict or an object attribute."""
|
|
if obj is None:
|
|
return default
|
|
if isinstance(obj, dict):
|
|
return obj.get(key, default)
|
|
return getattr(obj, key, default)
|
|
|
|
|
|
class OpenCodeClient:
|
|
"""
|
|
Client for AI agent interactions using OpenCode SDK.
|
|
|
|
This client wraps the opencode-ai SDK to provide session management
|
|
and message handling for the 7000%AUTO agent pipeline.
|
|
|
|
Agents are specified via the 'mode' parameter in session.chat(),
|
|
which corresponds to agents defined in opencode.json under the 'agent' key.
|
|
The 'system' parameter provides a fallback prompt from .opencode/agent/*.md files.
|
|
Each agent in opencode.json has its own model, prompt, tools, and permissions.
|
|
"""
|
|
|
|
def __init__(self, base_url: Optional[str] = None):
    """Create a client bound to an OpenCode server.

    Args:
        base_url: Server URL; defaults to ``settings.OPENCODE_SERVER_URL``.

    Raises:
        RuntimeError: If ``settings.OPENCODE_SDK`` or ``settings.OPENCODE_MODEL``
            is not set. ``_load_opencode_config`` may raise this first when no
            ``opencode.json`` is available.
    """
    self.base_url = base_url or settings.OPENCODE_SERVER_URL
    # Lazily created AsyncOpencode instance (see _get_client).
    self._client: Optional[Any] = None
    # session_id -> {"agent", "system_prompt", "opencode_session"}
    self._sessions: Dict[str, Dict[str, Any]] = {}

    # Load and cache opencode.json now so configuration problems surface at
    # construction time. The returned dict itself is not needed here.
    _load_opencode_config()

    if not settings.OPENCODE_SDK or not settings.OPENCODE_MODEL:
        raise RuntimeError("Missing required OpenCode environment variables (OPENCODE_SDK, OPENCODE_MODEL)")

    # Derive provider from SDK package name (e.g. @ai-sdk/anthropic -> anthropic)
    # This matches the provider key generated in main.py's generate_opencode_config()
    self.provider_id = settings.OPENCODE_SDK.split("/")[-1]
    self.model_id = settings.OPENCODE_MODEL

    logger.info(f"OpenCode client using provider={self.provider_id}, model={self.model_id}")
|
|
|
|
async def _get_client(self):
    """Return the shared ``AsyncOpencode`` client, creating it on first use.

    The client uses a 300s request timeout and a 30s connect timeout,
    because model responses can take a long time.

    Raises:
        OpenCodeError: If ``opencode-ai`` is not installed or the client
            cannot be constructed. The original exception is chained as
            ``__cause__``.
    """
    if self._client is not None:
        return self._client

    # Single source of truth for both the httpx timeout and the log line.
    request_timeout_s = 300.0

    try:
        from opencode_ai import AsyncOpencode

        client_kwargs = {
            "timeout": httpx.Timeout(request_timeout_s, connect=30.0),
        }
        if self.base_url:
            client_kwargs["base_url"] = self.base_url

        self._client = AsyncOpencode(**client_kwargs)
        logger.info(
            f"OpenCode client initialized (base_url: {self.base_url or 'default'}, "
            f"timeout: {request_timeout_s:g}s)"
        )
    except ImportError as e:
        raise OpenCodeError("opencode-ai package not installed. Run: pip install opencode-ai") from e
    except Exception as e:
        raise OpenCodeError(f"Failed to initialize OpenCode client: {e}") from e

    return self._client
|
|
|
|
async def _close_client(self):
    """Close the shared OpenCode client if one has been created.

    Errors raised while closing are logged and suppressed. The cached
    reference is always cleared afterwards, so the next ``_get_client``
    call builds a fresh client.
    """
    current = self._client
    if current is None:
        return

    try:
        await current.close()
    except Exception as e:
        logger.warning(f"Error closing OpenCode client: {e}")
    finally:
        self._client = None
|
|
|
|
async def create_session(self, agent_name: str) -> str:
    """
    Create a new OpenCode session and register it for ``agent_name``.

    The agent's system prompt is resolved once here and stored with the
    session, so later ``send_message`` calls can pass it as the fallback
    ``system`` prompt.

    Args:
        agent_name: Name of the agent (e.g., "ideator", "planner")

    Returns:
        Session ID string

    Raises:
        OpenCodeError: If the client cannot be created or the server call
            fails. The underlying exception is chained as ``__cause__``.
    """
    try:
        client = await self._get_client()

        # Create session via OpenCode SDK
        # Pass extra_body={} to send empty JSON body (server expects JSON even if empty)
        session = await client.session.create(extra_body={})
        session_id = session.id

        # Store session metadata including agent name and prompt
        self._sessions[session_id] = {
            "agent": agent_name,
            "system_prompt": _load_agent_prompt(agent_name),
            "opencode_session": session,
        }

        logger.info(f"Created session {session_id} for agent {agent_name}")
        return session_id

    except OpenCodeError:
        raise
    except Exception as e:
        server_url = self.base_url or "default"
        logger.error(f"Failed to create session for agent {agent_name}: {e}")
        raise OpenCodeError(
            f"Failed to create session (is OpenCode server running at {server_url}?): {e}"
        ) from e
|
|
|
|
async def send_message(
    self,
    session_id: str,
    message: str,
    timeout_seconds: int = 120,
    output_callback: Optional[Callable[[str], Awaitable[None]]] = None
) -> Dict[str, Any]:
    """
    Send a message to an agent session and return the agent's reply.

    With ``output_callback`` the reply is streamed through
    ``_send_message_streaming``. Without it, the message is sent with
    ``session.chat()`` and the session is polled until the agent finishes.

    Args:
        session_id: Session ID from create_session
        message: User message to send
        timeout_seconds: Maximum time to wait for agent completion on the
            non-streaming path (default 120s). It is NOT forwarded to the
            streaming path used when ``output_callback`` is given.
        output_callback: Optional async callback to receive streaming output chunks

    Returns:
        Dict with "content" (the agent's response text)

    Raises:
        OpenCodeError: If the session is unknown, the agent reports an error,
            or the request fails. Unexpected exceptions are wrapped, with the
            original chained as ``__cause__``.
    """
    if session_id not in self._sessions:
        raise OpenCodeError(f"Session {session_id} not found")

    session_data = self._sessions[session_id]
    agent_name = session_data["agent"]

    try:
        # If output_callback is provided, use true streaming for real-time output
        if output_callback:
            return await self._send_message_streaming(
                session_id, message, output_callback
            )

        # Otherwise, use polling-based approach
        client = await self._get_client()

        logger.info(f"Sending message to session {session_id} (agent: {agent_name})")

        parts: List[Dict[str, Any]] = [
            {"type": "text", "text": message}
        ]

        # Enable all MCP tools for the agent
        # This allows agents to use search, github, x_api, database tools
        tools: Dict[str, bool] = {"*": True}

        # Send chat message via OpenCode SDK
        # - mode: specifies agent/mode to use (maps to agents in opencode.json)
        # - system: provides fallback system prompt if mode isn't recognized
        # - tools: enables MCP server tools defined in opencode.json
        # - max_tokens is configured in opencode.json model options
        response = await client.session.chat(
            session_id,
            model_id=self.model_id,
            provider_id=self.provider_id,
            parts=parts,
            mode=agent_name,  # Specify agent mode from opencode.json
            system=session_data["system_prompt"],  # Fallback system prompt
            tools=tools,
        )

        if hasattr(response, 'error') and response.error:
            error_msg = str(response.error)
            logger.error(f"OpenCode response error: {error_msg}")
            raise OpenCodeError(f"Agent error: {error_msg}")

        # session.chat() may return before the agent finishes; poll until
        # completion (bounded by timeout_seconds), then fetch the content.
        await self._wait_for_completion(
            client, session_id, response, timeout_seconds, None
        )
        content = await self._fetch_message_content(client, session_id)

        logger.info(f"Received response for session {session_id} ({len(content)} chars)")
        return {"content": content}

    except OpenCodeError:
        raise
    except Exception as e:
        server_url = self.base_url or "default"
        logger.error(f"Failed to send message to session {session_id}: {e}")
        raise OpenCodeError(
            f"Failed to send message (is OpenCode server running at {server_url}?): {e}"
        ) from e
|
|
|
|
async def _send_message_streaming(
    self,
    session_id: str,
    message: str,
    output_callback: Callable[[str], Awaitable[None]]
) -> Dict[str, Any]:
    """
    Send a message with real-time streaming output.

    Three strategies are tried in order. Each may fall through to the next
    only if it failed BEFORE the message was sent:

    1. ``_send_message_with_event_stream``: SSE events from ``event.list()``,
       filtered by session.

       - ``message.updated`` identifies the assistant message ID; any text it
         carries is also forwarded.
       - ``message.part.updated`` carries incremental part text.
       - ``session.idle`` marks completion.
       - ``session.error`` reports failures.
    2. ``_send_message_with_streaming_response``.
    3. ``_send_message_polling``, which sends a new message and polls.

    If strategy 1 or 2 raises ``_MessageAlreadySentError``, the message is
    already on the server. In that case it is not resent; instead
    ``_poll_for_existing_message`` waits for it, seeded with whatever content
    was streamed so far.

    Args:
        session_id: Session ID from create_session
        message: User message to send
        output_callback: Async callback to receive streaming output chunks

    Returns:
        Dict with "content" (full response)

    Raises:
        OpenCodeError: Propagated unchanged from any strategy.
    """
    session_data = self._sessions[session_id]
    agent_name = session_data["agent"]

    logger.info(f"Session {session_id} ({agent_name}): Starting streaming message")

    client = await self._get_client()

    # 1) event.list() based streaming
    try:
        return await self._send_message_with_event_stream(
            client, session_id, message, output_callback
        )
    except OpenCodeError:
        raise
    except _MessageAlreadySentError as e:
        # Message was sent but event streaming failed - poll for completion
        # Do NOT fall back to methods that would send another message
        logger.warning(f"Session {session_id}: Event stream failed after message sent, polling for completion")
        return await self._poll_for_existing_message(
            client, session_id, output_callback, list(e.accumulated_content)
        )
    except Exception as e:
        logger.warning(f"Session {session_id}: event.list() approach failed before message sent ({e}), trying with_streaming_response")

    # 2) with_streaming_response (only reached if the message was NOT yet sent)
    try:
        return await self._send_message_with_streaming_response(
            client, session_id, message, output_callback
        )
    except OpenCodeError:
        raise
    except _MessageAlreadySentError as e:
        # Message was sent but streaming failed - poll for completion
        logger.warning(f"Session {session_id}: with_streaming_response failed after message sent, polling for completion")
        return await self._poll_for_existing_message(
            client, session_id, output_callback, list(e.accumulated_content)
        )
    except Exception as e:
        logger.warning(f"Session {session_id}: with_streaming_response failed before message sent ({e}), falling back to polling")

    # 3) Final fallback: polling-based streaming (sends new message)
    return await self._send_message_polling(
        session_id, message, output_callback
    )
|
|
|
|
async def _send_message_with_event_stream(
    self,
    client: Any,
    session_id: str,
    message: str,
    output_callback: Callable[[str], Awaitable[None]]
) -> Dict[str, Any]:
    """
    Send a message using event.list() for real-time streaming.

    This approach:
    1. Starts the event.list() SSE stream BEFORE sending the message, so no
       events are missed. If that stream has already failed when the message
       is about to be sent, the error is raised without sending anything, so
       the caller can fall back to another strategy.
    2. Sends the message via session.chat().
    3. Processes events belonging to this session. The session ID is taken
       from the event properties or, when absent there, from the nested
       ``part`` / ``info`` object, so events of other sessions are dropped.
    4. Returns when session.idle is received, or after a 300s timeout.

    If the message is sent but streaming fails, raises _MessageAlreadySentError
    to signal that the caller should poll instead of sending another message.
    This includes the SSE stream ending before session.idle. The SSE stream is
    closed on exit.

    Args:
        client: OpenCode client instance
        session_id: Session ID
        message: User message to send
        output_callback: Callback for streaming output; exceptions it raises
            are logged and ignored

    Returns:
        Dict with "content" (full response)

    Raises:
        OpenCodeError: For API/auth errors and session.error events
        _MessageAlreadySentError: If message sent but streaming failed
    """
    session_data = self._sessions[session_id]
    agent_name = session_data["agent"]

    logger.info(f"Session {session_id}: Starting event.list() based streaming")

    # Build message parts
    parts: List[Dict[str, Any]] = [
        {"type": "text", "text": message}
    ]
    tools: Dict[str, bool] = {"*": True}

    accumulated_content: List[str] = []
    message_completed = asyncio.Event()
    message_sent = False
    event_error: List[Exception] = []
    events_stream: Any = None
    # Single-element lists so the nested coroutine can update them in place.
    last_sent_content_length: List[int] = [0]  # Full-text length already forwarded (message.updated / session.updated)
    assistant_message_id: List[Optional[str]] = [None]  # Track assistant message ID to filter events

    def event_session_id_of(properties: Any) -> Optional[str]:
        """Return the session ID an event belongs to, or None if it carries none."""
        # Top-level first (as before); then the nested part/info objects,
        # where part and message events may carry it instead.
        for holder in (properties, _safe_get(properties, 'part'), _safe_get(properties, 'info')):
            if holder is None:
                continue
            found = (
                _safe_get(holder, 'sessionID') or
                _safe_get(holder, 'sessionId') or
                _safe_get(holder, 'session_id')
            )
            if found:
                return found
        return None

    def part_message_id_of(part: Any) -> Optional[str]:
        """Return the ID of the message a part belongs to (model or dict form)."""
        return (
            _safe_get(part, 'message_id') or
            _safe_get(part, 'messageID') or
            _safe_get(part, 'messageId')
        )

    async def emit(text: str) -> None:
        """Record a chunk and forward it; callback exceptions are logged, not raised."""
        accumulated_content.append(text)
        try:
            await output_callback(text)
        except Exception as e:
            logger.warning(f"Session {session_id}: Output callback error: {e}")

    async def process_events():
        """Background task to process SSE events."""
        nonlocal events_stream

        # Track sent content per part ID to avoid duplicates
        # When the same part is updated multiple times, only send new content
        part_content_sent: Dict[str, int] = {}

        try:
            # Subscribe to SSE event stream using event.list()
            events_stream = await client.event.list()
            chunk_count = 0

            async for event in events_stream:
                # Check if we should stop
                if message_completed.is_set():
                    break

                # Log raw events for debugging until the first few chunks are emitted
                if chunk_count < 10:
                    event_debug = self._get_event_debug_info(event)
                    logger.debug(f"Session {session_id}: Raw event {chunk_count}: {event_debug}")

                # Get event properties
                properties = _safe_get(event, 'properties', {})
                if not isinstance(properties, dict):
                    if hasattr(properties, '__dict__'):
                        properties = properties.__dict__
                    elif hasattr(properties, 'model_dump'):
                        try:
                            properties = properties.model_dump()
                        except Exception:
                            properties = {}

                # Filter by session ID (top-level or nested in part/info)
                event_session_id = event_session_id_of(properties)
                if event_session_id and event_session_id != session_id:
                    continue

                event_type = _safe_get(event, 'type', '')

                # message.updated: learn the assistant message ID before any
                # message.part.updated events are filtered against it.
                if event_type == 'message.updated':
                    info = _safe_get(properties, 'info')
                    msg_role = None
                    if info:
                        msg_role = _safe_get(info, 'role')
                        msg_id = _safe_get(info, 'id')
                        if msg_role == 'assistant' and msg_id and assistant_message_id[0] is None:
                            assistant_message_id[0] = msg_id
                            logger.info(f"Session {session_id}: Detected assistant message ID from event: {msg_id}")

                    # Never echo the user's own message back as agent output.
                    if msg_role == 'user':
                        continue

                    full_text = self._extract_text_from_event(event, event_type)
                    # Delta-style: only send content beyond what was already forwarded
                    if full_text and len(full_text) > last_sent_content_length[0]:
                        new_text = full_text[last_sent_content_length[0]:]
                        last_sent_content_length[0] = len(full_text)
                        chunk_count += 1
                        logger.info(f"Session {session_id}: message.updated delta chunk {chunk_count} ({len(new_text)} new chars, total {len(full_text)})")
                        await emit(new_text)
                    continue

                # message.part.updated: real-time text for one part
                if event_type == 'message.part.updated':
                    part = _safe_get(properties, 'part')
                    part_id = _safe_get(part, 'id') if part else None
                    part_message_id = part_message_id_of(part) if part else None

                    # Until the assistant message ID is known, we cannot tell
                    # assistant parts from user parts, so skip everything.
                    if assistant_message_id[0] is None:
                        logger.debug(f"Session {session_id}: Skipping event - assistant message ID not yet set")
                        continue

                    if part_message_id and part_message_id != assistant_message_id[0]:
                        logger.debug(f"Session {session_id}: Skipping event for non-assistant message {part_message_id[:20] if part_message_id else 'N/A'}")
                        continue

                    text = self._extract_text_from_event(event, event_type)
                    if not text:
                        continue

                    if part_id:
                        # Part text is treated as cumulative: forward only the unseen suffix.
                        already_sent = part_content_sent.get(part_id, 0)
                        if len(text) > already_sent:
                            new_text = text[already_sent:]
                            part_content_sent[part_id] = len(text)
                            chunk_count += 1
                            logger.info(f"Session {session_id}: Stream chunk {chunk_count} ({len(new_text)} new chars, part {part_id[:20] if part_id else 'N/A'})")
                            await emit(new_text)
                        # else: no new content, skip duplicate
                    else:
                        # No part ID, send full text (fallback)
                        chunk_count += 1
                        logger.info(f"Session {session_id}: Stream chunk {chunk_count} ({len(text)} chars, no part ID)")
                        await emit(text)
                    continue

                # session.updated may contain message content. It shares the
                # full-text counter with message.updated.
                if event_type == 'session.updated':
                    full_text = self._extract_text_from_session_updated(event, properties)
                    if full_text and len(full_text) > last_sent_content_length[0]:
                        new_text = full_text[last_sent_content_length[0]:]
                        last_sent_content_length[0] = len(full_text)
                        chunk_count += 1
                        logger.info(f"Session {session_id}: session.updated delta chunk {chunk_count} ({len(new_text)} new chars)")
                        await emit(new_text)
                    continue

                # session.diff - diff/streaming updates
                if event_type == 'session.diff':
                    text = self._extract_text_from_session_diff_event(event, properties)
                    if text:
                        chunk_count += 1
                        logger.info(f"Session {session_id}: session.diff chunk {chunk_count} ({len(text)} chars)")
                        await emit(text)
                    continue

                # session.idle - completion
                if event_type == 'session.idle':
                    logger.info(f"Session {session_id}: session.idle received - agent completed")
                    message_completed.set()
                    break

                # session.error - error
                if event_type == 'session.error':
                    error_msg = _safe_get(properties, 'error') or _safe_get(properties, 'message') or 'Unknown error'
                    logger.error(f"Session {session_id}: session.error received: {error_msg}")
                    event_error.append(OpenCodeError(f"Session error: {error_msg}"))
                    message_completed.set()
                    break

                # Log other (non-noisy) event types to help debug what the server sends
                if event_type:
                    if event_type not in ('server.connected', 'session.status', 'lsp.client.diagnostics', 'file.edited', 'file.watcher.updated', 'server.heartbeat'):
                        props_keys = list(properties.keys()) if isinstance(properties, dict) else 'N/A'
                        logger.info(f"Session {session_id}: Event '{event_type}' received, properties keys: {props_keys}")

            if not message_completed.is_set():
                # The server closed the stream without session.idle/session.error.
                # Signal it, instead of leaving the caller to wait out the full timeout.
                logger.warning(f"Session {session_id}: Event stream ended before session.idle")
                event_error.append(ConnectionError("Event stream ended before session.idle"))
                message_completed.set()

        except asyncio.CancelledError:
            logger.debug(f"Session {session_id}: Event processing cancelled")
        except Exception as e:
            logger.warning(f"Session {session_id}: Event processing error: {e}")
            event_error.append(e)
            message_completed.set()

    # Start event processing task BEFORE sending message
    event_task = asyncio.create_task(process_events())

    try:
        # Give event stream a moment to connect
        await asyncio.sleep(0.1)

        # If the stream already failed, bail out while nothing has been sent,
        # so the caller can fall back to another strategy instead of polling.
        if event_error:
            raise event_error[0]

        logger.info(f"Session {session_id}: Sending message via session.chat()")
        response = await client.session.chat(
            session_id,
            model_id=self.model_id,
            provider_id=self.provider_id,
            parts=parts,
            mode=agent_name,
            system=session_data["system_prompt"],
            tools=tools,
        )

        message_sent = True

        # Extract assistant message ID from response for filtering events
        response_info = _safe_get(response, 'info', response)
        extracted_msg_id = _safe_get(response_info, 'id') or _safe_get(response, 'id')
        if extracted_msg_id:
            assistant_message_id[0] = extracted_msg_id
            logger.info(f"Session {session_id}: Assistant message ID: {extracted_msg_id}")

        # Check for immediate error in response
        if hasattr(response, 'error') and response.error:
            error_msg = str(response.error)
            logger.error(f"Session {session_id}: Chat error: {error_msg}")
            raise OpenCodeError(f"Chat error: {error_msg}")

        # Wait for completion with timeout
        try:
            await asyncio.wait_for(message_completed.wait(), timeout=300)
        except asyncio.TimeoutError:
            logger.warning(f"Session {session_id}: Event stream timeout")

        # Check for errors from event processing
        if event_error:
            raise event_error[0]

        content = ''.join(accumulated_content)

        # If no content from streaming, fetch via messages API
        if not content:
            logger.info(f"Session {session_id}: No content from event stream, fetching via messages API")
            content = await self._fetch_message_content(client, session_id)
            if content and output_callback:
                try:
                    await output_callback(content)
                except Exception as e:
                    logger.warning(f"Session {session_id}: Final callback error: {e}")

        logger.info(f"Session {session_id}: Event streaming complete ({len(content)} chars)")
        return {"content": content}

    except OpenCodeError:
        raise
    except _MessageAlreadySentError:
        raise
    except Exception as e:
        if message_sent:
            raise _MessageAlreadySentError(
                f"Event streaming failed after message sent: {e}",
                accumulated_content
            ) from e
        raise
    finally:
        # Clean up event task
        if not event_task.done():
            event_task.cancel()
            try:
                await event_task
            except asyncio.CancelledError:
                pass

        # Close the SSE stream explicitly. Leaving the loop or cancelling the
        # task is not guaranteed to release the underlying HTTP response.
        if events_stream is not None:
            close_stream = getattr(events_stream, 'close', None)
            if callable(close_stream):
                try:
                    closing = close_stream()
                    if hasattr(closing, '__await__'):
                        await closing
                except Exception as e:
                    logger.debug(f"Session {session_id}: Error closing event stream: {e}")
|
|
|
|
def _extract_text_from_stream_event(self, event: Any) -> Optional[str]:
|
|
"""
|
|
Extract text content from a streaming event.
|
|
|
|
Handles various event formats from stream=True responses:
|
|
- Direct text/content fields
|
|
- Delta format (OpenAI style)
|
|
- Parts array format
|
|
- Message content format
|
|
|
|
Args:
|
|
event: Streaming event object
|
|
|
|
Returns:
|
|
Extracted text or None
|
|
"""
|
|
if event is None:
|
|
return None
|
|
|
|
# String - return directly
|
|
if isinstance(event, str):
|
|
return event if event.strip() else None
|
|
|
|
# Try to convert to dict if it's a Pydantic model
|
|
event_dict = event
|
|
if hasattr(event, 'model_dump'):
|
|
try:
|
|
event_dict = event.model_dump()
|
|
except Exception:
|
|
pass
|
|
elif hasattr(event, '__dict__'):
|
|
event_dict = event.__dict__
|
|
|
|
if isinstance(event_dict, dict):
|
|
# Direct content/text fields
|
|
for field in ['content', 'text', 'delta', 'message']:
|
|
val = event_dict.get(field)
|
|
if isinstance(val, str) and val.strip():
|
|
return val
|
|
elif isinstance(val, dict):
|
|
# Nested content (e.g., delta.content)
|
|
text = val.get('content') or val.get('text')
|
|
if isinstance(text, str) and text.strip():
|
|
return text
|
|
|
|
# Parts array format
|
|
parts = event_dict.get('parts')
|
|
if isinstance(parts, list):
|
|
texts = []
|
|
for part in parts:
|
|
if isinstance(part, dict) and part.get('type') == 'text':
|
|
t = part.get('text', '')
|
|
if t:
|
|
texts.append(t)
|
|
if texts:
|
|
return ''.join(texts)
|
|
|
|
# Choices format (OpenAI style)
|
|
choices = event_dict.get('choices')
|
|
if isinstance(choices, list) and choices:
|
|
first = choices[0]
|
|
if isinstance(first, dict):
|
|
delta = first.get('delta', {})
|
|
if isinstance(delta, dict):
|
|
content = delta.get('content')
|
|
if isinstance(content, str):
|
|
return content
|
|
|
|
# Properties format (OpenCode events)
|
|
props = event_dict.get('properties', {})
|
|
if isinstance(props, dict):
|
|
# Try delta in properties
|
|
delta = props.get('delta', {})
|
|
if isinstance(delta, dict):
|
|
text = delta.get('text') or delta.get('content')
|
|
if isinstance(text, str):
|
|
return text
|
|
|
|
# Try part in properties
|
|
part = props.get('part', {})
|
|
if isinstance(part, dict) and part.get('type') == 'text':
|
|
text = part.get('text')
|
|
if isinstance(text, str):
|
|
return text
|
|
|
|
# Try object attributes directly
|
|
for attr in ['content', 'text', 'delta']:
|
|
val = getattr(event, attr, None)
|
|
if isinstance(val, str) and val.strip():
|
|
return val
|
|
|
|
return None
|
|
|
|
def _extract_text_from_event(self, event: Any, event_type: str = "") -> Optional[str]:
    """
    Pull text out of an SSE event delivered by ``event.list()``.

    The lookup order depends on ``event_type``:

    * ``message.part.updated``: ``properties.delta`` (dict with ``text`` /
      ``content``), then ``properties.text``, then ``properties.part`` (type
      ``text``, ``reasoning`` or missing) ``text`` / ``content``.
    * ``message.updated``: joined ``text``/``reasoning`` parts from
      ``properties.parts``, then from ``properties.message.parts``, then
      ``properties.message`` ``content`` / ``text``, then ``properties``
      ``content`` / ``text``.
    * anything else: ``properties`` ``text`` / ``content``, ``properties.delta``
      ``content`` / ``text``, joined ``properties.parts``, then
      ``properties.message`` ``content`` / ``text`` and finally its joined parts.

    Both dicts and attribute-style (e.g. Pydantic) objects are supported.

    Args:
        event: Event object from event.list()
        event_type: The event type string for specialized handling

    Returns:
        Extracted non-empty text, or None
    """
    if event is None:
        return None

    def _str_field(holder: Any, first_key: str, second_key: Optional[str] = None) -> Optional[str]:
        """First truthy value of the two keys, returned only if it is a str."""
        value = _safe_get(holder, first_key)
        if not value and second_key is not None:
            value = _safe_get(holder, second_key)
        return value if value and isinstance(value, str) else None

    def _join_text_parts(candidates: Any) -> Optional[str]:
        """Concatenate the text of 'text'/'reasoning' parts; None if there is none."""
        if not candidates or not hasattr(candidates, '__iter__'):
            return None
        pieces = []
        for piece in candidates:
            if _safe_get(piece, 'type') in ('text', 'reasoning'):
                body = _safe_get(piece, 'text')
                if body:
                    pieces.append(body)
        return ''.join(pieces) if pieces else None

    properties = _safe_get(event, 'properties', {})

    if event_type == 'message.part.updated':
        delta = _safe_get(properties, 'delta')
        if isinstance(delta, dict):
            found = _str_field(delta, 'text', 'content')
            if found:
                return found

        found = _str_field(properties, 'text')
        if found:
            return found

        part = _safe_get(properties, 'part')
        if part is not None:
            part_type = _safe_get(part, 'type')
            if part_type in ('text', 'reasoning', None):
                found = _str_field(part, 'text', 'content')
                if found:
                    logger.debug(f"Extracted text from part (type={part_type}): {len(found)} chars")
                    return found
        return None

    if event_type == 'message.updated':
        found = _join_text_parts(_safe_get(properties, 'parts'))
        if found:
            return found

        message = _safe_get(properties, 'message')
        if message is not None:
            found = _join_text_parts(_safe_get(message, 'parts')) or _str_field(message, 'content', 'text')
            if found:
                return found

        return _str_field(properties, 'content', 'text')

    # Generic fallback for all other event types
    found = _str_field(properties, 'text', 'content')
    if found:
        return found

    delta = _safe_get(properties, 'delta')
    if delta is not None:
        found = _str_field(delta, 'content', 'text')
        if found:
            return found

    found = _join_text_parts(_safe_get(properties, 'parts'))
    if found:
        return found

    message = _safe_get(properties, 'message')
    if message is not None:
        return _str_field(message, 'content', 'text') or _join_text_parts(_safe_get(message, 'parts'))

    return None
|
|
|
|
def _get_event_debug_info(self, event: Any) -> str:
    """
    Summarise the structure of an SSE event object for debug logging.

    Args:
        event: Event object from SSE stream (Pydantic model, plain object or dict)

    Returns:
        A "; "-joined string of key=value fragments: the Python type name,
        a truncated ``model_dump()`` (when the event has one), instance
        attribute names, selected well-known attributes, and dict keys for
        dict events. Inspection failures are reported inline, never raised.
    """
    def clip(text: str, limit: int) -> str:
        # Bound the length of each fragment; mark truncation with an ellipsis.
        return text if len(text) <= limit else text[:limit] + '...'

    probe_names = (
        'type', 'properties', 'data', 'diff', 'content', 'value',
        'message', 'parts', 'delta', 'id', 'session_id',
    )
    fragments: List[str] = []

    try:
        fragments.append(f"type={type(event).__name__}")

        if hasattr(event, 'model_dump'):
            try:
                fragments.append(f"model_dump={clip(str(event.model_dump()), 500)}")
            except Exception as exc:
                fragments.append(f"model_dump_error={exc}")

        if hasattr(event, '__dict__'):
            fragments.append(f"attrs={list(event.__dict__.keys())}")

        for name in probe_names:
            found = _safe_get(event, name)
            if found is None:
                continue
            fragments.append(f"{name}={clip(str(found), 100)}")

        if isinstance(event, dict):
            fragments.append(f"dict_keys={list(event.keys())}")
    except Exception as exc:
        fragments.append(f"debug_error={exc}")

    return "; ".join(fragments)
|
|
|
|
def _extract_text_from_session_updated(self, event: Any, properties: Dict[str, Any]) -> Optional[str]:
    """
    Pull the newest assistant reply out of a session.updated event.

    The session snapshot comes from ``properties``. When that is empty it
    falls back to ``event.properties`` or ``event.model_dump()['properties']``.
    Messages are looked up under ``messages``, ``session.messages`` or
    ``data.messages``. The newest message whose role (``role``, or
    ``info.role``) is ``assistant`` is chosen. Its ``type == 'text'`` parts
    are concatenated; if it has no parts, its string ``content`` is used.
    When no message list exists, extraction is delegated to
    ``self._extract_text_from_info_object`` on ``properties['info']``.

    Args:
        event: The full event object
        properties: Properties dict from the event (may be empty)

    Returns:
        Text of the newest assistant message, or None
    """
    unreadable = object()  # marker: value could not be turned into a mapping

    def as_mapping(obj: Any) -> Any:
        # dict -> itself; object with model_dump() -> its dump; else __dict__.
        if isinstance(obj, dict):
            return obj
        if hasattr(obj, 'model_dump'):
            try:
                return obj.model_dump()
            except Exception:
                return unreadable
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        return unreadable

    def role_of(entry: Dict[str, Any]) -> Any:
        # The role may sit on the message itself or inside its 'info' dict.
        role = entry.get('role')
        if role:
            return role
        nested = entry.get('info', {})
        return nested.get('role') if isinstance(nested, dict) else role

    def locate_messages(container: Dict[str, Any]) -> Any:
        # Try messages, then session.messages, then data.messages.
        found = container.get('messages')
        if found:
            return found
        session = container.get('session')
        if isinstance(session, dict):
            found = session.get('messages')
        elif hasattr(session, 'messages'):
            found = session.messages
        if found:
            return found
        data = container.get('data')
        if isinstance(data, dict):
            return data.get('messages')
        return found

    if not properties:
        if hasattr(event, 'properties'):
            properties = event.properties
        elif hasattr(event, 'model_dump'):
            try:
                properties = event.model_dump().get('properties', {})
            except Exception:
                pass
    if not properties:
        return None

    props_dict = as_mapping(properties)
    if not isinstance(props_dict, dict):
        return None

    messages = locate_messages(props_dict)
    if not messages or not isinstance(messages, list):
        logger.info(f"session.updated: no messages found. Keys: {list(props_dict.keys())}")
        info = props_dict.get('info')
        if not info:
            return None
        text = self._extract_text_from_info_object(info)
        if text:
            logger.info(f"session.updated: extracted text from info object ({len(text)} chars)")
            return text
        logger.debug(f"session.updated: info object keys: {list(info.keys()) if isinstance(info, dict) else 'not a dict'}")
        if isinstance(info, dict):
            logger.debug(f"session.updated: info sample: {str(info)[:500]}")
        return None

    assistant_msg = None
    for candidate in reversed(messages):
        entry = as_mapping(candidate)
        if entry is unreadable:
            continue
        if role_of(entry) == 'assistant':
            assistant_msg = entry
            break
    if not assistant_msg:
        return None

    parts = assistant_msg.get('parts', [])
    if not parts:
        content = assistant_msg.get('content')
        return content if isinstance(content, str) else None

    pieces: List[str] = []
    for part in parts:
        entry = as_mapping(part)
        if entry is unreadable or entry.get('type') != 'text':
            continue
        chunk = entry.get('text', '')
        if chunk:
            pieces.append(chunk)
    return ''.join(pieces) if pieces else None
|
|
|
|
def _extract_text_from_info_object(self, info: Any) -> Optional[str]:
    """
    Extract text content from an 'info' object in session.updated events.

    The info object may carry text in several shapes. They are tried in this
    order and the first hit wins:

    - direct string fields: text, content, message, value, output
    - a ``parts`` list (OpenCode message format)
    - a ``messages`` list; only the most recent assistant message is used
    - ``delta``, ``assistant``, ``response`` and ``output`` sub-objects
    - nested ``data`` / ``result`` / ``payload`` / ``body`` dicts (recursively)

    Args:
        info: Info object from session.updated event properties (dict, object
            with ``model_dump()``/``__dict__``, or a plain string)

    Returns:
        Extracted text or None
    """
    if info is None:
        return None

    def join_text_parts(parts: Any) -> Optional[str]:
        # Concatenate the text of every ``type == 'text'`` part. A missing or
        # null parts list yields None instead of raising TypeError.
        texts = [
            part['text']
            for part in (parts or [])
            if isinstance(part, dict) and part.get('type') == 'text' and part.get('text')
        ]
        return ''.join(texts) if texts else None

    # Convert to dict if needed
    info_dict = info
    if not isinstance(info, dict):
        if hasattr(info, 'model_dump'):
            try:
                info_dict = info.model_dump()
            except Exception:
                pass
        elif hasattr(info, '__dict__'):
            info_dict = info.__dict__

    if not isinstance(info_dict, dict):
        # A bare string is itself the text
        if isinstance(info, str):
            return info if info.strip() else None
        return None

    # Direct text/content fields
    for field in ('text', 'content', 'message', 'value', 'output'):
        value = info_dict.get(field)
        if isinstance(value, str) and value.strip():
            return value

    # Parts array (OpenCode message format). Non-text parts may still carry
    # a string 'content'.
    parts = info_dict.get('parts')
    if isinstance(parts, list):
        texts = []
        for part in parts:
            if not isinstance(part, dict):
                continue
            if part.get('type') == 'text':
                t = part.get('text')
                if t:
                    texts.append(t)
            elif 'content' in part:
                t = part.get('content')
                if isinstance(t, str) and t:
                    texts.append(t)
        if texts:
            return ''.join(texts)

    # Messages array: only the newest assistant message is considered
    messages = info_dict.get('messages')
    if isinstance(messages, list):
        for msg in reversed(messages):
            if isinstance(msg, dict):
                msg_dict = msg
            elif hasattr(msg, 'model_dump'):
                try:
                    msg_dict = msg.model_dump()
                except Exception:
                    continue
            else:
                msg_dict = getattr(msg, '__dict__', {})
            if not isinstance(msg_dict, dict):
                continue

            role = msg_dict.get('role')
            if not role:
                # OpenCode messages may nest the role under 'info'
                nested = msg_dict.get('info')
                if isinstance(nested, dict):
                    role = nested.get('role')
            if role != 'assistant':
                continue

            text = join_text_parts(msg_dict.get('parts'))
            if text:
                return text
            content = msg_dict.get('content')
            if isinstance(content, str):
                return content
            # Stop here: older assistant messages belong to previous turns,
            # and returning them would present a stale reply as current.
            break

    # Delta format (streaming)
    delta = info_dict.get('delta')
    if isinstance(delta, dict):
        text = delta.get('text') or delta.get('content')
        if isinstance(text, str):
            return text

    # Assistant field (string, or dict with text/content/parts)
    assistant = info_dict.get('assistant')
    if assistant:
        if isinstance(assistant, str):
            return assistant
        if isinstance(assistant, dict):
            text = assistant.get('text') or assistant.get('content')
            if isinstance(text, str):
                return text
            text = join_text_parts(assistant.get('parts'))
            if text:
                return text

    # Response / output fields (string, or dict with text/content)
    for field in ('response', 'output'):
        candidate = info_dict.get(field)
        if not candidate:
            continue
        if isinstance(candidate, str):
            return candidate
        if isinstance(candidate, dict):
            text = candidate.get('text') or candidate.get('content')
            if isinstance(text, str):
                return text

    # Recursively check nested containers for the same shapes
    for key in ('data', 'result', 'payload', 'body'):
        nested = info_dict.get(key)
        if isinstance(nested, dict):
            result = self._extract_text_from_info_object(nested)
            if result:
                return result

    return None
|
|
|
|
def _extract_text_from_session_diff_event(self, event: Any, properties: Dict[str, Any]) -> Optional[str]:
    """
    Extract text content from a session.diff event.

    The sources below are consulted in order and the first non-empty text
    is returned:

    1. ``properties`` via ``self._extract_text_from_session_diff``
    2. well-known attributes of the event (read with ``_safe_get``)
    3. the same keys in ``event.model_dump()``; errors here are logged at
       debug level and ignored
    4. the same keys in ``event.__dict__``

    Candidate values are passed through ``self._try_extract_text_from_value``.

    Args:
        event: The full event object
        properties: Properties dict from the event (may be empty)

    Returns:
        Extracted text or None
    """
    candidate_fields = ('diff', 'data', 'content', 'value', 'delta', 'message', 'text', 'parts')

    def scan(lookup: Callable[[str], Any], label: str) -> Optional[str]:
        # Walk the candidate keys with the given accessor; stop at first text.
        for name in candidate_fields:
            raw = lookup(name)
            if raw is None:
                continue
            found = self._try_extract_text_from_value(raw)
            if found:
                logger.debug(f"Extracted text from {label}.{name}")
                return found
        return None

    from_properties = self._extract_text_from_session_diff(properties)
    if from_properties:
        return from_properties

    found = scan(lambda name: _safe_get(event, name), 'event')
    if found:
        return found

    if hasattr(event, 'model_dump'):
        try:
            dump = event.model_dump()
            if isinstance(dump, dict):
                found = scan(dump.get, 'model_dump')
                if found:
                    return found
        except Exception as e:
            logger.debug(f"model_dump extraction failed: {e}")

    if hasattr(event, '__dict__'):
        found = scan(event.__dict__.get, '__dict__')
        if found:
            return found

    return None
|
|
|
|
def _try_extract_text_from_value(self, value: Any) -> Optional[str]:
    """
    Best-effort extraction of text from a value of unknown structure.

    Recognised shapes, in priority order:

    - string: returned unchanged unless blank (blank gives None)
    - dict: the first string among text/content/value/delta/message; then
      delta.text / delta.content; then the joined text of ``type == 'text'``
      parts; then the joined item texts of a diff/operations/patches/changes
      list
    - list: joined item texts (strings, or dicts carrying value/text/content)

    Args:
        value: Value to extract text from (could be str, dict, list, etc.)

    Returns:
        Extracted text or None
    """
    def item_text(item: Any) -> Optional[str]:
        # One operation/list item: a bare string, or a dict's value/text/content.
        if isinstance(item, str):
            return item
        if isinstance(item, dict):
            found = item.get('value') or item.get('text') or item.get('content')
            if isinstance(found, str):
                return found
        return None

    def join_items(items: List[Any]) -> Optional[str]:
        # Concatenate every extractable item; None when nothing matched.
        pieces = [piece for piece in map(item_text, items) if piece is not None]
        return ''.join(pieces) if pieces else None

    if value is None:
        return None
    if isinstance(value, str):
        return value if value.strip() else None
    if isinstance(value, list):
        return join_items(value)
    if not isinstance(value, dict):
        return None

    for key in ('text', 'content', 'value', 'delta', 'message'):
        if isinstance(value.get(key), str):
            return value[key]

    nested_delta = value.get('delta')
    if isinstance(nested_delta, dict):
        found = nested_delta.get('text') or nested_delta.get('content')
        if isinstance(found, str):
            return found

    parts = value.get('parts')
    if isinstance(parts, list):
        texts = [
            p['text']
            for p in parts
            if isinstance(p, dict) and p.get('type') == 'text' and p.get('text')
        ]
        if texts:
            return ''.join(texts)

    for ops_key in ('diff', 'operations', 'patches', 'changes'):
        ops = value.get(ops_key)
        if isinstance(ops, list):
            joined = join_items(ops)
            if joined is not None:
                return joined

    return None
|
|
|
|
def _extract_text_from_session_diff(self, properties: Dict[str, Any]) -> Optional[str]:
    """
    Extract text content from a session.diff event's properties.

    session.diff events contain incremental changes to the session state, and
    servers encode the text in several ways. The locations below are tried
    in order and the first hit is returned:

    1. ``diff``: a list of patch-like operations, a plain string, or a dict
       with text/content/value
    2. ``operations`` / ``patches`` / ``changes`` lists (``changes`` may also
       be a plain string)
    3. direct ``text``/``content``/``delta``/``value``/``message`` fields
       (string, or dict with text/content)
    4. ``parts`` list in OpenCode message format (``type == 'text'`` only)

    Each operation contributes its ``value``, else ``text``, else ``content``
    (first non-empty one only, so the same text is never added twice), or
    itself if it is a bare string. This mirrors ``_try_extract_text_from_value``.

    Args:
        properties: Properties of the session.diff event. Usually a dict; an
            object is converted via ``model_dump()`` first, else ``__dict__``.

    Returns:
        Extracted text or None
    """
    if not properties:
        return None

    if not isinstance(properties, dict):
        # Prefer model_dump(): Pydantic models also expose __dict__, but that
        # keeps nested models as objects, and the dict-based checks below
        # would silently skip them.
        converted: Any = None
        if hasattr(properties, 'model_dump'):
            try:
                converted = properties.model_dump()
            except Exception:
                converted = None
        if not isinstance(converted, dict) and hasattr(properties, '__dict__'):
            converted = properties.__dict__
        if not isinstance(converted, dict):
            return None
        properties = converted

    def op_text(op: Any) -> Optional[str]:
        # Text carried by a single operation (see docstring for precedence).
        if isinstance(op, dict):
            text = op.get('value') or op.get('text') or op.get('content')
            return text if isinstance(text, str) else None
        return op if isinstance(op, str) else None

    def join_ops(ops: List[Any]) -> Optional[str]:
        # Concatenate the text of every operation that carries any.
        texts = [t for t in map(op_text, ops) if t is not None]
        return ''.join(texts) if texts else None

    # 'diff' field: list of operations, raw string, or dict
    diff = properties.get('diff')
    if diff:
        if isinstance(diff, list):
            text = join_ops(diff)
            if text:
                return text
        elif isinstance(diff, str):
            return diff
        elif isinstance(diff, dict):
            text = diff.get('text') or diff.get('content') or diff.get('value')
            if isinstance(text, str):
                return text

    # Other operation-list fields
    for ops_key in ('operations', 'patches', 'changes'):
        ops = properties.get(ops_key)
        if isinstance(ops, list):
            text = join_ops(ops)
            if text:
                return text

    changes = properties.get('changes')
    if isinstance(changes, str):
        return changes

    # Direct content fields
    for field in ('text', 'content', 'delta', 'value', 'message'):
        value = properties.get(field)
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            text = value.get('text') or value.get('content')
            if isinstance(text, str):
                return text

    # 'parts' array (OpenCode message format)
    parts = properties.get('parts')
    if isinstance(parts, list):
        texts = [
            part['text']
            for part in parts
            if isinstance(part, dict) and part.get('type') == 'text' and part.get('text')
        ]
        if texts:
            return ''.join(texts)

    return None
|
|
|
|
def _is_completion_event(self, event_type: str, properties: Dict[str, Any]) -> bool:
    """
    Decide whether an SSE event signals that the agent finished responding.

    The primary signal is the ``session.idle`` event type. A fixed set of
    other ``*.complete``/``*.done`` style event types is also accepted.
    Otherwise the event counts as complete when its properties carry:

    - a terminal ``status`` (completed/complete/done/finished/idle), or
    - a ``finish`` / ``finish_reason`` of stop/end_turn/completed, or
    - a truthy ``time.completed``.

    Args:
        event_type: Event type string
        properties: Event properties dict

    Returns:
        True if this is a completion event
    """
    if event_type == 'session.idle' or event_type in (
        'message.complete',
        'message.completed',
        'message.done',
        'message.finish',
        'session.complete',
        'session.completed',
        'response.complete',
        'response.done',
    ):
        return True

    if _safe_get(properties, 'status') in ('completed', 'complete', 'done', 'finished', 'idle'):
        return True

    finish_reason = _safe_get(properties, 'finish') or _safe_get(properties, 'finish_reason')
    if finish_reason in ('stop', 'end_turn', 'completed'):
        return True

    timing = _safe_get(properties, 'time')
    return bool(timing and _safe_get(timing, 'completed'))
|
|
|
|
def _is_error_event(self, event_type: str, properties: Dict[str, Any]) -> bool:
    """
    Decide whether an SSE event signals an error.

    An event is an error when:

    - its type is ``session.error`` (primary) or error/session.failed/message.error, or
    - its properties contain a truthy ``error``, or
    - its properties have a ``status`` of error/failed.

    Args:
        event_type: Event type string
        properties: Event properties dict

    Returns:
        True if this is an error event
    """
    if event_type in ('session.error', 'error', 'session.failed', 'message.error'):
        return True
    if _safe_get(properties, 'error'):
        return True
    return _safe_get(properties, 'status') in ('error', 'failed')
|
|
|
|
async def _send_message_with_streaming_response(
    self,
    client: Any,
    session_id: str,
    message: str,
    output_callback: Callable[[str], Awaitable[None]]
) -> Dict[str, Any]:
    """
    Send a message using with_streaming_response for SSE streaming.

    This is the fallback approach when event.subscribe() is not available.
    Every non-empty streamed line is passed to ``self._parse_sse_line``. Any
    text it yields is accumulated and forwarded to ``output_callback``. If
    the stream produced no text at all, the final content is fetched with
    ``self._fetch_message_content`` instead.

    If message is sent but streaming fails, raises _MessageAlreadySentError
    to signal the caller should poll instead of sending another message.

    Args:
        client: OpenCode client instance
        session_id: Session ID (must be registered in ``self._sessions``)
        message: User message to send
        output_callback: Callback for streaming output. Errors raised by the
            callback are logged and otherwise ignored.

    Returns:
        Dict with "content" (full response)

    Raises:
        OpenCodeError: If the session is unknown, or for API/auth errors
        _MessageAlreadySentError: If message sent but streaming failed
    """
    # Same failure mode as _send_message_polling instead of a bare KeyError.
    if session_id not in self._sessions:
        raise OpenCodeError(f"Session {session_id} not found")

    session_data = self._sessions[session_id]
    agent_name = session_data["agent"]

    logger.info(f"Session {session_id}: Using with_streaming_response fallback")

    # Build message parts
    parts: List[Dict[str, Any]] = [
        {"type": "text", "text": message}
    ]
    tools: Dict[str, bool] = {"*": True}

    accumulated_content: List[str] = []
    # Set once the streaming context is entered. From then on the server has
    # the message, so retrying would duplicate it.
    message_sent = False

    try:
        async with client.session.with_streaming_response.chat(
            session_id,
            model_id=self.model_id,
            provider_id=self.provider_id,
            parts=parts,
            mode=agent_name,
            system=session_data["system_prompt"],
            tools=tools,
        ) as response:
            message_sent = True
            chunk_count = 0

            async for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                chunk_count += 1

                text = self._parse_sse_line(raw_line)
                if not text:
                    continue

                accumulated_content.append(text)
                logger.info(f"Session {session_id}: Stream chunk {chunk_count} ({len(text)} chars)")
                if output_callback:
                    try:
                        await output_callback(text)
                    except Exception as e:
                        logger.warning(f"Session {session_id}: Output callback error: {e}")

            content = ''.join(accumulated_content)
            logger.info(f"Session {session_id}: with_streaming_response complete ({chunk_count} chunks, {len(content)} chars)")

            # If we got no content from streaming, fetch it via messages API
            if not content:
                logger.info(f"Session {session_id}: No content from streaming, fetching final content")
                content = await self._fetch_message_content(client, session_id)
                if content and output_callback:
                    try:
                        await output_callback(content)
                    except Exception as e:
                        logger.warning(f"Session {session_id}: Output callback error on final content: {e}")

            return {"content": content}

    except OpenCodeError:
        raise
    except Exception as e:
        if message_sent:
            # Signal the caller to poll rather than resend. Keep the original
            # failure attached as __cause__ for debugging.
            raise _MessageAlreadySentError(
                f"with_streaming_response failed after message sent: {e}",
                accumulated_content
            ) from e
        # Message not sent - let caller try another approach
        raise
|
|
|
|
async def _poll_for_existing_message(
    self,
    client: Any,
    session_id: str,
    output_callback: Optional[Callable[[str], Awaitable[None]]],
    already_received: List[str],
    timeout_seconds: float = 120.0,
    poll_interval: float = 0.3,
) -> Dict[str, Any]:
    """
    Poll for an existing message's completion without sending a new message.

    Used when streaming fails after the message was already sent. The rest of
    the response has to be fetched without duplicating the message.

    Each poll fetches the current content and forwards any text beyond what
    was already delivered to ``output_callback``. It then checks whether the
    newest assistant message is complete (``self._is_message_completed``).
    The timeout is a wall-clock deadline measured on the event loop's
    monotonic clock, so the latency of the per-poll API calls counts against
    it.

    Args:
        client: OpenCode client instance
        session_id: Session ID
        output_callback: Optional callback for new content chunks
        already_received: Content chunks already received via streaming
        timeout_seconds: Maximum wall-clock time to keep polling
        poll_interval: Delay between polls, in seconds

    Returns:
        Dict with "content" (full response). On timeout this is the latest
        fetched content, or the already-received chunks if that fetch fails
        or is empty.
    """
    logger.info(f"Session {session_id}: Polling for existing message completion")

    # Length already delivered to the callback. Fetched content is assumed
    # to extend what streaming produced, so only the tail beyond this is new.
    already_sent_length = sum(len(chunk) for chunk in already_received)

    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_seconds
    poll_count = 0

    while loop.time() < deadline:
        poll_count += 1
        try:
            # Fetch current message content
            current_content = await self._fetch_message_content(client, session_id)

            # Send new content to callback
            if output_callback and current_content and len(current_content) > already_sent_length:
                new_content = current_content[already_sent_length:]
                already_sent_length = len(current_content)
                logger.info(f"Session {session_id}: Polling - {len(new_content)} new chars via callback")
                try:
                    await output_callback(new_content)
                except Exception as e:
                    logger.warning(f"Session {session_id}: Output callback error: {e}")

            # Only the newest assistant message decides completion
            messages_response = await client.session.messages(session_id)
            if messages_response:
                for msg in reversed(messages_response):
                    info = _safe_get(msg, 'info')
                    if not (info and _safe_get(info, 'role') == 'assistant'):
                        continue
                    if self._is_message_completed(info, f"Session {session_id} poll {poll_count}"):
                        logger.info(f"Session {session_id}: Message completed after {poll_count} polls")
                        final_content = await self._fetch_message_content(client, session_id)
                        # Send any remaining content to callback
                        if output_callback and final_content and len(final_content) > already_sent_length:
                            try:
                                await output_callback(final_content[already_sent_length:])
                            except Exception as e:
                                logger.warning(f"Session {session_id}: Final output callback error: {e}")
                        return {"content": final_content or ''.join(already_received)}
                    break

            if poll_count % 30 == 0:
                logger.info(f"Session {session_id}: Still polling... ({poll_count} polls)")

        except Exception as e:
            logger.warning(f"Session {session_id}: Poll error: {e}")

        # Sleep at the end of the loop so the first poll happens immediately
        await asyncio.sleep(poll_interval)

    # Timeout - return whatever we have (best effort; never raise here)
    logger.warning(f"Session {session_id}: Polling timeout, returning accumulated content")
    try:
        final_content = await self._fetch_message_content(client, session_id)
    except Exception as e:
        logger.warning(f"Session {session_id}: Final fetch after polling timeout failed: {e}")
        final_content = None
    return {"content": final_content or ''.join(already_received)}
|
|
|
|
async def _send_message_polling(
    self,
    session_id: str,
    message: str,
    output_callback: Callable[[str], Awaitable[None]]
) -> Dict[str, Any]:
    """
    Send a message and stream its output by polling (fallback path).

    The chat request is sent with ``client.session.chat``. The method then
    waits through ``self._wait_for_completion`` (120 s timeout), which
    forwards new content to ``output_callback`` while polling. Finally the
    complete content is fetched with ``self._fetch_message_content``.

    Args:
        session_id: Session ID from create_session
        message: User message to send
        output_callback: Async callback to receive streaming output chunks

    Returns:
        Dict with "content" (full response)

    Raises:
        OpenCodeError: If the session is unknown or the chat response
            reports an error
    """
    if session_id not in self._sessions:
        raise OpenCodeError(f"Session {session_id} not found")

    session = self._sessions[session_id]
    agent_name = session["agent"]
    logger.info(f"Session {session_id} ({agent_name}): Starting polling-based streaming")

    client = await self._get_client()

    # The call is acknowledged right away; the agent keeps working server-side.
    response = await client.session.chat(
        session_id,
        model_id=self.model_id,
        provider_id=self.provider_id,
        parts=[{"type": "text", "text": message}],
        mode=agent_name,
        system=session["system_prompt"],
        tools={"*": True},  # expose every MCP tool
    )

    reported_error = getattr(response, 'error', None)
    if reported_error:
        error_msg = str(reported_error)
        logger.error(f"OpenCode response error: {error_msg}")
        raise OpenCodeError(f"Agent error: {error_msg}")

    await self._wait_for_completion(
        client, session_id, response, timeout_seconds=120, output_callback=output_callback
    )

    content = await self._fetch_message_content(client, session_id)
    char_count = len(content) if content else 0
    logger.info(f"Session {session_id}: Polling streaming complete ({char_count} chars)")

    return {"content": content}
|
|
|
|
def _is_message_completed(self, info: Any, context: str = "") -> bool:
    """
    Decide whether a message has finished, trying several signals.

    Works with both dict and object representations of ``info``. The
    signals are checked in this order, and the first decisive one wins:

    1. ``time.completed`` is not None -> completed
    2. ``finish == 'stop'`` -> completed
    3. ``status == 'completed'`` -> completed; ``status`` of
       ``in_progress``/``incomplete`` -> not completed
    4. ``completed_at`` is not None -> completed

    Anything else (including ``info is None``) is treated as not completed.

    Args:
        info: Message info object or dict
        context: Prefix for debug log lines

    Returns:
        True if message is completed, False otherwise
    """
    if info is None:
        return False

    if logger.isEnabledFor(logging.DEBUG):
        snapshot = {
            name: _safe_get(info, name, '<not found>')
            for name in ('id', 'role', 'status', 'completed_at', 'created_at', 'time', 'error', 'finish')
        }
        logger.debug(f"{context} - info fields: {snapshot}")

    timing = _safe_get(info, 'time')
    completed = _safe_get(timing, 'completed') if timing else None
    if completed is not None:
        logger.debug(f"{context} - time.completed={completed}")
        return True

    finish = _safe_get(info, 'finish')
    if finish == 'stop':
        logger.debug(f"{context} - finish={finish}")
        return True

    status = _safe_get(info, 'status')
    if status is not None:
        logger.debug(f"{context} - status={status}")
        if status == "completed":
            return True
        if status in ("in_progress", "incomplete"):
            return False

    completed_at = _safe_get(info, 'completed_at')
    if completed_at is None:
        return False
    logger.debug(f"{context} - completed_at={completed_at}")
    return True
|
|
|
|
async def _wait_for_completion(
    self,
    client: Any,
    session_id: str,
    initial_response: Any,
    timeout_seconds: int,
    output_callback: Optional[Callable[[str], Awaitable[None]]] = None
) -> None:
    """
    Wait for agent to complete processing while streaming output.

    The session.chat() method returns immediately with an AssistantMessage.
    This method polls session.messages() until that message is reported
    complete by _is_message_completed() (time.completed, finish == "stop",
    status == "completed", or a completed_at timestamp). While polling, any
    text that has grown since the previous poll is forwarded to the callback.

    Args:
        client: OpenCode client instance
        session_id: Session ID
        initial_response: Initial AssistantMessage from session.chat()
        timeout_seconds: Maximum wall-clock time to wait for completion
        output_callback: Optional async callback receiving each new text delta

    Raises:
        OpenCodeError: If an API error is reported on the message, or if the
            agent has not completed within timeout_seconds.
    """
    # Extract message info - could be nested under .info (dict or object)
    initial_info = _safe_get(initial_response, 'info', initial_response)

    # Get the message ID from info (where the real data lives)
    message_id = _safe_get(initial_info, 'id') or _safe_get(initial_response, 'id')

    # Check for error in initial response
    self._check_response_for_error(initial_response, "initial response")

    # Debug log initial response structure
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Session {session_id}: Initial response type={type(initial_response).__name__}")
        if hasattr(initial_response, 'model_dump'):
            try:
                dump = initial_response.model_dump()
                logger.debug(f"Session {session_id}: Initial response dump={dump}")
            except Exception:
                pass  # debug-only; a failing dump must not affect the wait

    # Check if initial response is already complete
    if self._is_message_completed(initial_info, f"Session {session_id} initial"):
        logger.info(f"Session {session_id}: Agent already completed in initial response (message: {message_id})")
        # Still need to stream final content via callback
        if output_callback:
            try:
                final_content = await self._fetch_message_content(client, session_id)
                logger.info(f"Session {session_id}: Fetched final content for callback ({len(final_content) if final_content else 0} chars)")
                if final_content:
                    logger.info(f"Session {session_id}: Calling output_callback with final content")
                    await output_callback(final_content)
            except Exception as e:
                logger.warning(f"Session {session_id}: Output callback error on completed message: {e}")
        return

    # Short interval for responsive streaming updates.
    poll_interval = 0.3  # seconds
    # The timeout is measured in wall-clock time: each poll also spends time in
    # session.messages(), so counting polls alone would overshoot the limit.
    loop = asyncio.get_running_loop()
    started_at = loop.time()
    deadline = started_at + timeout_seconds

    # Length of text already forwarded to output_callback
    last_content_length = 0
    poll_count = 0

    logger.info(f"Session {session_id}: Agent still running (message: {message_id}), polling for completion...")

    # Always poll at least once, even for a timeout shorter than poll_interval.
    while poll_count == 0 or loop.time() < deadline:
        await asyncio.sleep(poll_interval)
        poll_count += 1

        try:
            # Fetch messages to check completion status
            messages_response = await client.session.messages(session_id)

            if not messages_response:
                logger.debug(f"Session {session_id}: No messages in response (poll {poll_count})")
                continue

            # Find the specific message by ID, or fall back to last assistant message
            target_message = None
            for msg in reversed(messages_response):
                info = _safe_get(msg, 'info')
                if not info:
                    continue
                if message_id:
                    # Match by message ID if available
                    if _safe_get(info, 'id') == message_id:
                        target_message = msg
                        break
                elif _safe_get(info, 'role') == 'assistant':
                    # Fallback: match last assistant message
                    target_message = msg
                    break

            if target_message:
                # Stream only the text appended since the previous poll
                if output_callback:
                    current_content = self._extract_message_text(target_message)
                    if len(current_content) > last_content_length:
                        new_content = current_content[last_content_length:]
                        last_content_length = len(current_content)
                        logger.info(f"Session {session_id}: Polling stream - {len(new_content)} new chars via callback")
                        try:
                            await output_callback(new_content)
                        except Exception as e:
                            logger.warning(f"Session {session_id}: Output callback error: {e}")

                info = _safe_get(target_message, 'info')
                if info:
                    # Check for errors in the message (raises OpenCodeError)
                    self._check_response_for_error(info, f"message {message_id}")

                    # Check completion using multiple strategies
                    context = f"Session {session_id} poll {poll_count}"
                    if self._is_message_completed(info, context):
                        elapsed = loop.time() - started_at
                        logger.info(f"Session {session_id}: Agent completed after {elapsed:.1f}s (message: {message_id})")
                        return
            else:
                logger.debug(f"Session {session_id}: No matching assistant message found (poll {poll_count})")

            # Log progress every 10 polls
            if poll_count % 10 == 0:
                elapsed = loop.time() - started_at
                logger.info(f"Session {session_id}: Still waiting... ({elapsed:.1f}s elapsed)")

        except OpenCodeError:
            raise
        except Exception as e:
            # Transient polling failures are retried until the deadline.
            logger.warning(f"Session {session_id}: Error polling for completion: {e}")

    # Timeout reached
    logger.warning(f"Session {session_id}: Timeout after {timeout_seconds}s waiting for agent completion")
    raise OpenCodeError(f"Agent timed out after {timeout_seconds} seconds")
|
|
|
|
def _extract_message_text(self, message: Any) -> str:
    """
    Return the text of a message from session.messages().

    The text of every part whose type is "text" is joined with newlines;
    empty or missing texts are ignored. If no such part is readable directly,
    the message's model_dump() (when present) is searched the same way.

    Args:
        message: Message object from session.messages()

    Returns:
        The joined text, or '' when nothing could be extracted.
    """
    pieces: List[str] = []
    for part in _safe_get(message, 'parts') or []:
        if _safe_get(part, 'type') != 'text':
            continue
        fragment = _safe_get(part, 'text', '')
        if fragment:
            pieces.append(fragment)
    if pieces:
        return '\n'.join(pieces)

    # Some SDK objects only expose their parts after serialisation.
    if not hasattr(message, 'model_dump'):
        return ''
    try:
        for raw_part in message.model_dump().get('parts', []):
            if not (isinstance(raw_part, dict) and raw_part.get('type') == 'text'):
                continue
            fragment = raw_part.get('text', '')
            if fragment:
                pieces.append(fragment)
        if pieces:
            return '\n'.join(pieces)
    except Exception:
        # Best-effort fallback: a malformed dump simply yields no text.
        pass
    return ''
|
|
|
|
def _check_response_for_error(self, response: Any, context: str = "") -> None:
    """
    Check response object for error information and raise OpenCodeError if found.

    The 'error' field is read directly first (dict or object). If that yields
    no message, the Pydantic model_dump() is inspected as a second source.

    Args:
        response: Response object to check (could be message info or full response)
        context: Context string for error messages

    Raises:
        OpenCodeError: If an error is detected in the response
    """
    if response is None:
        return

    # Check for 'error' attribute directly (handles both dict and object)
    error = _safe_get(response, 'error')
    if error:
        error_message = self._extract_error_message(error)
        if error_message:
            logger.error(f"API error detected in {context}: {error_message}")
            raise OpenCodeError(f"API Error: {error_message}")

    # Check via model_dump if available
    if not hasattr(response, 'model_dump'):
        return

    # Only serialisation and parsing are best-effort. The raise below must stay
    # outside this try: a generic `except Exception` would otherwise swallow
    # the OpenCodeError we are trying to surface.
    try:
        dump = response.model_dump()
        dumped_message = None
        if isinstance(dump, dict) and dump.get('error'):
            dumped_message = self._extract_error_message_from_dict(dump['error'])
    except Exception as e:
        logger.debug(f"Could not check model_dump for errors: {e}")
        return

    if dumped_message:
        logger.error(f"API error detected in {context} (from dump): {dumped_message}")
        raise OpenCodeError(f"API Error: {dumped_message}")
|
|
|
|
def _extract_error_message(self, error: Any) -> Optional[str]:
    """
    Render an error object or dict as a readable "<Name>: <detail>" string.

    Dicts are delegated to _extract_error_message_from_dict(). For objects,
    the label is the error's `name` attribute (or its class name), and the
    detail is taken from data.message, data.error or message, in that order.
    Authentication-looking errors get a hint about OPENCODE_API_KEY; anything
    else falls back to str(error), truncated to 200 characters.

    Args:
        error: Error object or dict (could be ProviderAuthError, APIError, etc.)

    Returns:
        Error message string, or None when error is None
    """
    if error is None:
        return None
    if isinstance(error, dict):
        return self._extract_error_message_from_dict(error)

    label = getattr(error, 'name', None) or type(error).__name__

    # Preferred detail sources, in priority order: data.message, data.error, message.
    detail = None
    payload = _safe_get(error, 'data')
    if payload:
        detail = _safe_get(payload, 'message') or _safe_get(payload, 'error')
    if not detail:
        detail = _safe_get(error, 'message')
    if detail:
        return f"{label}: {detail}"

    if 'Auth' in label or 'auth' in str(error).lower():
        return f"{label}: Authentication failed - check your OPENCODE_API_KEY"

    rendered = str(error)
    # An empty or purely type-shaped string carries no information beyond the label.
    if not rendered or rendered == str(type(error)):
        return label
    if len(rendered) > 200:
        rendered = rendered[:200] + "..."
    return f"{label}: {rendered}"
|
|
|
|
def _extract_error_message_from_dict(self, error_data: Any) -> Optional[str]:
    """
    Extract error message from a dict representation of an error.

    The message is looked up in data.message, data.error, message, error
    (first non-empty wins) and prefixed with the error's name ("Error" when
    absent). Auth-named errors without a message get an OPENCODE_API_KEY hint.

    Args:
        error_data: Error data as dict (non-dicts are stringified)

    Returns:
        Error message string or None (for falsy non-dicts, or a dict carrying
        neither a message nor a specific name)
    """
    if not isinstance(error_data, dict):
        return str(error_data) if error_data else None

    # `or` (not a .get default) so a present-but-null name falls back too;
    # str() keeps the substring checks below from raising TypeError.
    error_name = str(error_data.get('name') or 'Error')

    # Check for nested data
    data = error_data.get('data', {})
    if isinstance(data, dict):
        message = data.get('message') or data.get('error')
        if message:
            return f"{error_name}: {message}"

    # Direct message
    message = error_data.get('message') or error_data.get('error')
    if message:
        return f"{error_name}: {message}"

    # Check for auth errors (also covers "ProviderAuth...")
    if 'Auth' in error_name:
        return f"{error_name}: Authentication failed - check your OPENCODE_API_KEY"

    return error_name if error_name != 'Error' else None
|
|
|
|
async def _fetch_message_content(self, client: Any, session_id: str) -> str:
    """
    Fetch actual message content from session messages.

    The session.chat() method returns AssistantMessage which only contains
    metadata. To get the text, session.messages() is called and the newest
    message whose info.role is "assistant" is located; its text parts are
    joined with newlines via _extract_message_text().

    Args:
        client: OpenCode client instance
        session_id: Session ID

    Returns:
        Extracted text content from the last assistant message, or "" when
        there are no messages, no assistant message, no text, or the fetch
        fails for a non-API reason (logged as a warning).

    Raises:
        OpenCodeError: If an API error is detected in the message
    """
    try:
        # Fetch all messages for the session
        messages_response = await client.session.messages(session_id)

        if not messages_response:
            logger.warning(f"No messages found for session {session_id}")
            return ""

        # Scan newest-first and stop at the first assistant message.
        last_assistant_message = None
        for msg in reversed(messages_response):
            info = _safe_get(msg, 'info')
            if info and _safe_get(info, 'role') == 'assistant':
                last_assistant_message = msg
                break

        if not last_assistant_message:
            logger.warning(f"No assistant message found for session {session_id}")
            return ""

        # Check for errors (raises OpenCodeError)
        info = _safe_get(last_assistant_message, 'info')
        if info:
            self._check_response_for_error(info, f"session {session_id}")

        # Reuse the shared extractor so this path and the polling loop agree
        # on which parts count as text.
        content = self._extract_message_text(last_assistant_message)
        if content:
            return content

        logger.warning(f"Session {session_id}: No text content found in assistant message")
        return ""

    except OpenCodeError:
        raise
    except Exception as e:
        logger.warning(f"Failed to fetch message content: {e}")
        return ""
|
|
|
|
def _extract_response_content(self, response: Any) -> str:
    """
    Extract text content from OpenCode AssistantMessage response.

    The response structure may vary, so these locations are tried in order:
    1. response.path.parts (via _extract_parts_content)
    2. response.parts (dict parts only when type == "text")
    3. response.content
    4. response.text
    5. model_dump(): text parts under dump["path"]["parts"], otherwise the
       whole dump rendered as indented JSON
    6. str(response)

    Based on SDK investigation, AssistantMessage has:
    - id, cost, mode, api_model_id, path, provider_id, role, session_id,
      system, time, tokens, error, summary
    - The text content is typically in path.parts
    """
    # First, check path.parts which is the primary location for content
    path = getattr(response, 'path', None)
    if hasattr(path, 'parts'):
        content = self._extract_parts_content(path.parts)
        if content:
            logger.debug("Extracted content from response.path.parts")
            return content

    # Try direct parts attribute
    parts = getattr(response, 'parts', None)
    if isinstance(parts, list):
        texts = []
        for part in parts:
            if isinstance(part, dict):
                text = part.get('text') if part.get('type') == 'text' else None
            else:
                text = getattr(part, 'text', None)
            # Skip missing/None text rather than crashing join() or emitting "None".
            if text is not None:
                texts.append(str(text))
        if texts:
            logger.debug("Extracted content from response.parts")
            return '\n'.join(texts)

    # Try content attribute
    content = getattr(response, 'content', None)
    if content:
        logger.debug("Extracted content from response.content")
        return str(content)

    # Try text attribute
    text = getattr(response, 'text', None)
    if text:
        logger.debug("Extracted content from response.text")
        return str(text)

    # Fallback: convert response to string via model_dump
    if hasattr(response, 'model_dump'):
        logger.debug("Falling back to model_dump for content extraction")
        dump = response.model_dump()
        # Type-check every level: 'parts' in <str> would be a substring test
        # and the subsequent indexing would raise.
        path_dump = dump.get('path') if isinstance(dump, dict) else None
        dumped_parts = path_dump.get('parts') if isinstance(path_dump, dict) else None
        if dumped_parts:
            texts = [
                str(p.get('text') or '')
                for p in dumped_parts
                if isinstance(p, dict) and p.get('type') == 'text'
            ]
            if texts:
                return '\n'.join(texts)
        # default=str: the dump may hold values json cannot encode (e.g.
        # datetimes); this is a display fallback and must not crash.
        return json.dumps(dump, indent=2, default=str)

    logger.warning("Could not extract structured content, using str()")
    return str(response)
|
|
|
|
def _extract_parts_content(self, parts: Any) -> str:
    """
    Extract text content from message parts, joined with newlines.

    Dict parts are used only when type == "text". Object parts are used when
    they have no `type` attribute or type == "text", taking `.text` and
    falling back to `.content`. None or empty values are skipped instead of
    being rendered as "None" or blank lines.

    Args:
        parts: Iterable of part dicts/objects (falsy -> "")

    Returns:
        The joined text, or "" if no text parts were found.
    """
    if not parts:
        return ""

    texts: List[str] = []
    for part in parts:
        if isinstance(part, dict):
            if part.get('type') != 'text':
                continue
            value = part.get('text')
        else:
            part_type = getattr(part, 'type', None)
            # Mirror the dict branch: typed non-text parts (reasoning, tool,
            # step markers) are not response content.
            if part_type is not None and part_type != 'text':
                continue
            value = getattr(part, 'text', None)
            if value is None:
                value = getattr(part, 'content', None)
        if value:
            texts.append(str(value))

    return '\n'.join(texts)
|
|
|
|
def _parse_sse_line(self, sse_line: str) -> Optional[str]:
    """
    Pull response text out of a single Server-Sent Events line.

    Only `data:` lines are considered. Other SSE fields (`event:`, `id:`,
    `retry:`), blank lines, the `[DONE]` sentinel, empty payloads and
    non-JSON payloads all produce None. A JSON payload is handed to
    _extract_text_from_sse_json(), which understands the content / text /
    delta / choices / parts / message shapes.

    Args:
        sse_line: Single SSE line (from iter_lines())

    Returns:
        The extracted text, or None when the line carries no text.
    """
    if not sse_line:
        return None

    stripped = sse_line.strip()
    # Anything that is not a data field (including an all-whitespace line) is ignored.
    if not stripped or not stripped.startswith(_SSE_DATA_PREFIX):
        return None

    payload = stripped[len(_SSE_DATA_PREFIX):].strip()
    if not payload or payload == _SSE_DONE_MARKER:
        return None

    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        # Not JSON, might be raw text - log and skip
        logger.debug(f"SSE data is not JSON: {payload[:100]}")
        return None

    return self._extract_text_from_sse_json(decoded)
|
|
|
|
def _extract_text_from_sse_json(self, data: Any) -> Optional[str]:
    """
    Find the text carried by a decoded SSE JSON payload.

    Provider formats are probed in this order, and the first string found wins:
    top-level "content", top-level "text", delta.content, delta.text,
    choices[0].delta.content, the concatenation of "text"-type entries in
    "parts", and finally message.content.

    Args:
        data: Parsed JSON data from SSE

    Returns:
        The extracted text, or None if the payload is not a dict or has no text.
    """
    if not isinstance(data, dict):
        return None

    # Flat formats: {"content": ...} / {"text": ...}
    for key in ('content', 'text'):
        candidate = data.get(key)
        if isinstance(candidate, str):
            return candidate

    # Delta format (OpenAI streaming style)
    delta = data.get('delta')
    if isinstance(delta, dict):
        for key in ('content', 'text'):
            candidate = delta.get(key)
            if isinstance(candidate, str):
                return candidate

    # Choices array format (OpenAI chat completions) - first choice only
    choices = data.get('choices')
    if isinstance(choices, list) and choices and isinstance(choices[0], dict):
        choice_delta = choices[0].get('delta')
        if isinstance(choice_delta, dict):
            candidate = choice_delta.get('content')
            if isinstance(candidate, str):
                return candidate

    # Parts array format (OpenCode style)
    parts = data.get('parts')
    if isinstance(parts, list):
        fragments = [
            part.get('text')
            for part in parts
            if isinstance(part, dict)
            and part.get('type') == 'text'
            and isinstance(part.get('text'), str)
        ]
        if fragments:
            return ''.join(fragments)

    # Message content (nested format)
    message = data.get('message')
    if isinstance(message, dict):
        candidate = message.get('content')
        if isinstance(candidate, str):
            return candidate

    return None
|
|
|
|
def _extract_text_from_chunk(self, chunk: Any) -> Optional[str]:
    """
    Extract text content from a streaming chunk.

    Accepted shapes:
    - plain strings (returned unchanged unless empty)
    - dicts: {"text": ...}, {"content": ...}, {"parts": [{"type": "text", ...}]},
      {"delta": {"content"|"text": ...}}, {"choices": [{"delta": {"content": ...}}]}
    - SDK objects exposing .text, .content or .parts; Pydantic models are
      retried via model_dump()

    Args:
        chunk: Streaming chunk from session.chat()

    Returns:
        Extracted text, or None if the chunk carries no text
    """
    if chunk is None:
        return None

    # Whitespace-only chunks ("\n", " ") are real content in a token stream;
    # dropping them glues words and lines together. Only skip empty strings.
    if isinstance(chunk, str):
        return chunk or None

    if isinstance(chunk, dict):
        # {"type": "text", "text": "..."} and plain {"text": "..."}
        text = chunk.get('text')
        if isinstance(text, str):
            return text

        # Direct content field
        content = chunk.get('content')
        if isinstance(content, str):
            return content

        # Parts array: {"parts": [{"type": "text", "text": "..."}]}
        parts = chunk.get('parts')
        if isinstance(parts, list):
            texts = [
                part.get('text')
                for part in parts
                if isinstance(part, dict) and part.get('type') == 'text' and part.get('text')
            ]
            if texts:
                return ''.join(texts)

        # Delta format: {"delta": {"content": "..."}}
        delta = chunk.get('delta')
        if isinstance(delta, dict):
            content = delta.get('content') or delta.get('text')
            if isinstance(content, str):
                return content

        # Choices format: {"choices": [{"delta": {"content": "..."}}]}
        choices = chunk.get('choices')
        if isinstance(choices, list) and choices:
            first = choices[0]
            if isinstance(first, dict):
                choice_delta = first.get('delta', {})
                if isinstance(choice_delta, dict):
                    content = choice_delta.get('content')
                    if isinstance(content, str):
                        return content

        return None

    # Object with attributes (Part object from SDK). A text-typed part and any
    # other object exposing .text are handled identically.
    text = getattr(chunk, 'text', None)
    if text:
        return str(text)

    # Try content attribute
    content = getattr(chunk, 'content', None)
    if content:
        return str(content)

    # Try parts attribute
    parts = getattr(chunk, 'parts', None)
    if parts and hasattr(parts, '__iter__'):
        texts = []
        for part in parts:
            if _safe_get(part, 'type') == 'text':
                part_text = _safe_get(part, 'text')
                if part_text:
                    texts.append(str(part_text))
        if texts:
            return ''.join(texts)

    # Try model_dump for Pydantic models (re-enters via the dict branch)
    if hasattr(chunk, 'model_dump'):
        try:
            return self._extract_text_from_chunk(chunk.model_dump())
        except Exception:
            pass

    return None
|
|
|
|
async def stream_response(self, session_id: str, message: str) -> AsyncIterator[str]:
    """
    Stream response from agent.

    Tries two approaches:
    1. Native SDK streaming: call session.chat() and iterate the result when
       it is (or awaits to) an async iterable; a single non-iterable response
       is converted to text with _extract_text_from_chunk().
    2. Fallback: session.with_streaming_response.chat() with SSE parsing.
       Only used when approach 1 raised before yielding any text (or returned
       something neither iterable nor awaitable), so streamed text is never
       duplicated.

    Args:
        session_id: Session ID from create_session
        message: User message to send

    Yields:
        Response text chunks

    Raises:
        OpenCodeError: If the session is unknown or streaming fails (other
            exceptions are wrapped with a hint about the server URL).
    """
    if session_id not in self._sessions:
        raise OpenCodeError(f"Session {session_id} not found")

    session_data = self._sessions[session_id]
    agent_name = session_data["agent"]

    try:
        client = await self._get_client()

        logger.info(f"[STREAM_TRACE] Session {session_id} ({agent_name}): Starting stream_response")

        # Build message parts
        parts: List[Dict[str, Any]] = [
            {"type": "text", "text": message}
        ]

        # Enable all MCP tools
        tools: Dict[str, bool] = {"*": True}

        # Request arguments shared by both approaches.
        chat_kwargs: Dict[str, Any] = {
            "model_id": self.model_id,
            "provider_id": self.provider_id,
            "parts": parts,
            "mode": agent_name,
            "system": session_data["system_prompt"],
            "tools": tools,
        }

        chunk_count = 0
        text_chunk_count = 0

        # APPROACH 1: Try direct async iteration on session.chat()
        logger.info(f"[STREAM_TRACE] Session {session_id}: Trying direct async iteration on session.chat()...")

        approach1_yielded_data = False

        try:
            chat_result = client.session.chat(session_id, **chat_kwargs)

            # Check if result is async iterable (streaming supported)
            if hasattr(chat_result, '__aiter__'):
                logger.info(f"[STREAM_TRACE] Session {session_id}: session.chat() returned async iterable, streaming directly...")

                async for chunk in chat_result:
                    chunk_count += 1

                    # Log chunk info for debugging
                    chunk_type = type(chunk).__name__
                    chunk_preview = str(chunk)[:200] if chunk else 'None'
                    logger.info(f"[STREAM_TRACE] Session {session_id}: Chunk {chunk_count} (type={chunk_type}): {chunk_preview}")

                    text = self._extract_text_from_chunk(chunk)
                    if text:
                        text_chunk_count += 1
                        approach1_yielded_data = True
                        logger.info(f"[STREAM_TRACE] Session {session_id}: Extracted text ({len(text)} chars): {text[:100]}...")
                        yield text

                logger.info(f"[STREAM_TRACE] Session {session_id}: Direct streaming COMPLETE via APPROACH 1 ({chunk_count} chunks, {text_chunk_count} with text)")
                return

            # If not async iterable, it might be awaitable (single response)
            elif hasattr(chat_result, '__await__'):
                logger.info(f"[STREAM_TRACE] Session {session_id}: session.chat() returned awaitable, awaiting...")
                response = await chat_result

                # Check if response is async iterable
                if hasattr(response, '__aiter__'):
                    logger.info(f"[STREAM_TRACE] Session {session_id}: Awaited response is async iterable, streaming...")
                    async for chunk in response:
                        chunk_count += 1
                        text = self._extract_text_from_chunk(chunk)
                        if text:
                            text_chunk_count += 1
                            approach1_yielded_data = True
                            logger.info(f"[STREAM_TRACE] Session {session_id}: Extracted text ({len(text)} chars)")
                            yield text

                    logger.info(f"[STREAM_TRACE] Session {session_id}: Awaited streaming COMPLETE via APPROACH 1 ({chunk_count} chunks, {text_chunk_count} with text)")
                    return
                else:
                    # Single response object - extract text
                    logger.info(f"[STREAM_TRACE] Session {session_id}: Got single response via APPROACH 1, extracting text...")
                    text = self._extract_text_from_chunk(response)
                    if text:
                        approach1_yielded_data = True
                        yield text
                    else:
                        # The message was already sent; re-sending via APPROACH 2
                        # would duplicate it, so only surface the empty result.
                        logger.warning(f"[STREAM_TRACE] Session {session_id}: Single response via APPROACH 1 contained no extractable text")
                    logger.info(f"[STREAM_TRACE] Session {session_id}: Single response COMPLETE via APPROACH 1")
                    return
            else:
                logger.info(f"[STREAM_TRACE] Session {session_id}: session.chat() returned non-iterable: {type(chat_result).__name__}")

        except TypeError as e:
            # TypeError usually means it's not async iterable
            if approach1_yielded_data:
                logger.warning(f"[STREAM_TRACE] Session {session_id}: APPROACH 1 failed after yielding {text_chunk_count} chunks (TypeError: {e}), NOT retrying to avoid duplicates")
                return
            logger.info(f"[STREAM_TRACE] Session {session_id}: Direct iteration failed (TypeError: {e}), trying SSE fallback...")
        except Exception as e:
            if approach1_yielded_data:
                logger.warning(f"[STREAM_TRACE] Session {session_id}: APPROACH 1 failed after yielding {text_chunk_count} chunks ({type(e).__name__}: {e}), NOT retrying to avoid duplicates")
                return
            logger.warning(f"[STREAM_TRACE] Session {session_id}: Direct iteration failed ({type(e).__name__}: {e}), trying SSE fallback...")

        # APPROACH 2: Fallback to with_streaming_response with SSE parsing
        logger.info(f"[STREAM_TRACE] Session {session_id}: Using with_streaming_response fallback...")

        chunk_count = 0
        text_chunk_count = 0
        raw_yield_count = 0

        async with client.session.with_streaming_response.chat(session_id, **chat_kwargs) as response:
            # Log response status for debugging
            status_code = getattr(response, 'status_code', None)
            logger.info(f"[STREAM_TRACE] Session {session_id}: SSE streaming context opened (status_code={status_code}), iterating lines...")

            async for raw_line in response.iter_lines():
                chunk_count += 1

                if not raw_line:
                    continue

                line_preview = raw_line[:200] if len(raw_line) > 200 else raw_line
                logger.info(f"[STREAM_TRACE] Session {session_id}: SSE Line {chunk_count}: {line_preview}")

                # Try to parse as SSE
                text = self._parse_sse_line(raw_line)

                if text:
                    text_chunk_count += 1
                    logger.info(f"[STREAM_TRACE] Session {session_id}: Parsed SSE text ({len(text)} chars): {text[:100]}...")
                    yield text
                else:
                    # Pass through only lines that are not SSE fields at all
                    # (e.g. a plain-text body). SSE lines that produced no text -
                    # "data: [DONE]", data events without text (tool/step events),
                    # "event:", "id:", "retry:", ":" comments - are protocol
                    # framing and must not leak into the response text.
                    stripped_line = raw_line.strip()
                    if stripped_line and not stripped_line.startswith(('data:', 'event:', 'id:', 'retry:', ':')):
                        raw_yield_count += 1
                        logger.info(f"[STREAM_TRACE] Session {session_id}: Yielding raw line: {stripped_line[:100]}")
                        yield stripped_line

        logger.info(f"[STREAM_TRACE] Session {session_id}: SSE streaming COMPLETE via APPROACH 2 ({chunk_count} lines, {text_chunk_count} parsed, {raw_yield_count} raw)")

    except OpenCodeError:
        logger.error(f"[STREAM_TRACE] Session {session_id}: OpenCodeError in stream_response")
        raise
    except Exception as e:
        server_url = self.base_url or "default"
        logger.error(f"[STREAM_TRACE] Session {session_id}: Exception in stream_response: {e}")
        raise OpenCodeError(
            f"Failed to stream response (is OpenCode server running at {server_url}?): {e}"
        )
|
|
|
|
async def close_session(self, session_id: str):
    """
    Close a session: delete it on the server and forget it locally.

    Sessions not tracked in self._sessions are ignored. Server-side deletion
    is best-effort: failures are logged as warnings and the local entry is
    removed regardless.

    Args:
        session_id: Session ID to close
    """
    if session_id not in self._sessions:
        return

    try:
        client = await self._get_client()
        await client.session.delete(session_id)
    except Exception as e:
        logger.warning(f"Error deleting session {session_id}: {e}")
    finally:
        # pop() rather than del: another task may have closed the same session
        # while we awaited the server, and del would then raise KeyError.
        self._sessions.pop(session_id, None)
        logger.info(f"Closed session {session_id}")
|
|
|
|
async def close(self):
    """Close every tracked session one at a time, then close the underlying client."""
    # Iterate over a snapshot: close_session() removes entries from self._sessions.
    for tracked_id in [*self._sessions]:
        await self.close_session(tracked_id)

    await self._close_client()
|