231 lines
7.2 KiB
Python
231 lines
7.2 KiB
Python
from mcp import ClientSession, StdioServerParameters
|
|
from mcp.client.stdio import stdio_client
|
|
import asyncio
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from openai import AzureOpenAI
|
|
import json
|
|
from contextlib import AsyncExitStack
|
|
from typing import Dict, List
|
|
from mcp.client.sse import sse_client
|
|
|
|
# Pull environment variables from .env before any credential lookups below.
load_dotenv()

# Shared Azure OpenAI client for the whole module; endpoint and key are
# supplied via the environment (AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY).
model = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_version="2024-12-01-preview",
)
|
|
|
|
|
|
class Chatbot:
    """Interactive chat agent that routes Azure OpenAI tool calls to MCP servers."""

    def __init__(self):
        self.model = model  # module-level AzureOpenAI client
        self.sessions = []  # one ClientSession per connected MCP server
        self.available_tools = []  # OpenAI "tools" payload aggregated across servers
        # Maps a tool name to the MCP session that serves it.
        self.tool_to_session: Dict[str, ClientSession] = {}
        self.messages = []  # running transcript of user/assistant/tool turns
        self.exit_stack = AsyncExitStack()  # owns every transport/session lifetime

    async def connect_to_server(self, server_config: dict) -> None:
        """Connect one MCP server and register its tools.

        Args:
            server_config: one entry from server_config.json; must contain a
                "type" of "uv" (stdio) or "sse", and for SSE a "url".

        Raises:
            ValueError: if the "type" field is not "uv" or "sse".
        """
        server_type = server_config["type"]

        if server_type == "uv":
            # Strip the routing-only "type" key so only genuine stdio
            # parameters (command, args, env, ...) reach StdioServerParameters.
            params = {k: v for k, v in server_config.items() if k != "type"}
            transport = await self.exit_stack.enter_async_context(
                stdio_client(StdioServerParameters(**params))
            )
        elif server_type == "sse":
            transport = await self.exit_stack.enter_async_context(
                sse_client(url=server_config["url"])
            )
        else:
            raise ValueError(f"Unsupported server type: {server_config['type']}")

        read, write = transport
        session = await self.exit_stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        self.sessions.append(session)

        # Advertise this server's tools to the model in OpenAI function format,
        # remembering which session serves each tool name.
        response = await session.list_tools()
        for tool in response.tools:
            self.tool_to_session[tool.name] = session
            self.available_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    },
                }
            )

        print(
            f"Loaded {len(self.available_tools)} tools: "
            f"{[tool['function']['name'] for tool in self.available_tools]}"
        )

    async def connect_to_servers(self):
        """Read server_config.json and connect every server under "mcp_servers".

        Raises:
            Exception: re-raises any configuration or connection error after
                printing it, so the caller can abort startup.
        """
        try:
            with open("server_config.json", "r") as file:
                data = json.load(file)

            servers = data.get("mcp_servers", {})
            for server_name, server_config in servers.items():
                await self.connect_to_server(server_config)
        except Exception as e:
            print(f"Error loading server configuration: {e}")
            raise

    async def cleanup(self):
        """Cleanly close all resources using AsyncExitStack."""
        await self.exit_stack.aclose()

    async def handle_tool_call(self, agent_response):
        """Execute every tool call in *agent_response* via its MCP session.

        Args:
            agent_response: a chat-completions response whose first choice
                contains ``message.tool_calls``.

        Returns:
            list[dict]: one OpenAI-format "tool" message per tool call, in order.
        """
        tool_messages = []

        for tool in agent_response.choices[0].message.tool_calls:
            print(f"\nTool call detected: {tool.id}")

            session = self.tool_to_session.get(tool.function.name)
            if session is None:
                # The model asked for a tool no connected server provides;
                # report it back instead of crashing on a None session.
                tool_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool.id,
                        "content": f"Error: unknown tool '{tool.function.name}'",
                    }
                )
                continue

            result = await session.call_tool(
                tool.function.name,
                arguments=json.loads(tool.function.arguments),
            )

            # Flatten MCP content blocks to plain text. Default to "" so a
            # result without content can never leave this name unbound.
            tool_result_content = ""
            if hasattr(result, "content") and result.content:
                if isinstance(result.content, list):
                    tool_result_content = "\n".join(
                        (str(item.text) if hasattr(item, "text") else str(item))
                        for item in result.content
                    )
                else:
                    tool_result_content = str(result.content)

            tool_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool.id,
                    "content": tool_result_content,
                }
            )

        return tool_messages

    def call_model(self):
        """Call the model with the current transcript and registered tools."""
        return model.chat.completions.create(
            model="gpt-4o",
            messages=self.messages,
            max_tokens=2024,
            tools=self.available_tools,
        )

    async def run(self):
        """Interactive loop: alternate user input, model calls, and tool calls."""
        try:
            user_next = True  # False while the model still owes a post-tool reply
            while True:
                if user_next:
                    user_input = input("Enter your question (or 'exit' to quit): ")

                    if user_input.lower() == "exit":
                        print("Exiting the agent.")
                        break

                    self.messages.append({"role": "user", "content": user_input})

                try:
                    # The SDK call is synchronous: run it off the event loop and
                    # bound it so a hung request cannot freeze the REPL forever.
                    agent_response = await asyncio.wait_for(
                        asyncio.to_thread(self.call_model), timeout=60
                    )

                    message = agent_response.choices[0].message

                    if message.tool_calls:
                        # Record the assistant turn that requested the tools.
                        self.messages.append(
                            {
                                "role": "assistant",
                                "content": message.content,
                                "tool_calls": message.tool_calls,
                            }
                        )

                        tool_messages = await self.handle_tool_call(agent_response)
                        self.messages.extend(tool_messages)

                        # After tool call, the assistant gets the next turn, not the user
                        user_next = False
                        continue

                    # Plain assistant reply: record it and hand the turn back.
                    self.messages.append(
                        {
                            "role": "assistant",
                            "content": message.content,
                        }
                    )
                    print(f"\nAgent: {message.content}")
                    user_next = True

                except asyncio.TimeoutError:
                    print("\nAgent response timed out after 60 seconds.")
                    # Drop the turn that timed out (only if it was a user turn)
                    # so a retry starts clean, and hand control back to the user.
                    if self.messages and self.messages[-1].get("role") == "user":
                        self.messages.pop()
                    user_next = True
                except Exception as e:
                    print(f"\nError occurred: {e}")
                    # Remove the failed user message — but only when the last
                    # entry really is a user turn, never a tool/assistant one.
                    if self.messages and self.messages[-1].get("role") == "user":
                        self.messages.pop()
                    user_next = True  # Reset to get user input again after an error
        finally:
            # Cleanup session and connections
            await self.cleanup()
|
|
|
|
|
|
async def main():
    """Build a Chatbot, connect the configured MCP servers, and run the loop.

    Always closes the bot's resources on the way out, even when connection
    or the chat loop fails.
    """
    bot = Chatbot()
    try:
        await bot.connect_to_servers()
        await bot.run()
    except Exception as exc:
        print(f"An error occurred: {exc}")
    finally:
        await bot.cleanup()
|
|
|
|
|
|
# Script entry point: drive the async chat agent to completion.
if __name__ == "__main__":
    asyncio.run(main())
|