setup
This commit is contained in:
commit
f395d1ad81
2
.env
Normal file
2
.env
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
AZURE_OPENAI_API_KEY="DzAmkSsqjo1IAVoQ8kA03uqH5ULMBEjhco6KjCIcQHk44LWkcwlhJQQJ99BGACfhMk5XJ3w3AAAAACOGlyp4"
|
||||||
|
AZURE_OPENAI_ENDPOINT="https://apillar3-3312-resource.cognitiveservices.azure.com/"
|
||||||
13
.gitignore
vendored
Normal file
13
.gitignore
vendored
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
|
|
||||||
|
|
||||||
|
data/
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
3.11
|
||||||
15
Dockerfile
Normal file
15
Dockerfile
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
FROM python:3.12-slim-bookworm
|
||||||
|
|
||||||
|
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||||
|
|
||||||
|
# Copy the project into the image
|
||||||
|
ADD . /app
|
||||||
|
|
||||||
|
# Sync the project into a new environment, using the frozen lockfile
|
||||||
|
WORKDIR /app
|
||||||
|
RUN uv sync
|
||||||
|
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uv", "run", "mcp_server.py"]
|
||||||
7
README.md
Normal file
7
README.md
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# MCP Demo
|
||||||
|
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
`uv sync` and see the magic happen
|
||||||
|
|
||||||
BIN
data.sqlite
Normal file
BIN
data.sqlite
Normal file
Binary file not shown.
365
db_utils.py
Normal file
365
db_utils.py
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Database utility module for accessing and querying the SQLite database.
|
||||||
|
Provides convenient functions for data retrieval, analysis, and database operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DatabaseUtils:
|
||||||
|
"""Utility class for database operations."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str = "data.sqlite"):
|
||||||
|
"""
|
||||||
|
Initialize database utilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path (str): Path to the SQLite database file
|
||||||
|
"""
|
||||||
|
self.db_path = db_path
|
||||||
|
self._validate_database()
|
||||||
|
|
||||||
|
def _validate_database(self):
|
||||||
|
"""Validate that the database exists and is accessible."""
|
||||||
|
if not Path(self.db_path).exists():
|
||||||
|
raise FileNotFoundError(f"Database '{self.db_path}' does not exist")
|
||||||
|
|
||||||
|
def get_connection(self) -> sqlite3.Connection:
|
||||||
|
"""Get a database connection."""
|
||||||
|
return sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
|
def get_tables(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get list of all tables in the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of table names
|
||||||
|
"""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||||
|
return [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def get_table_info(self, table_name: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get detailed information about a table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): Name of the table
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Table information including columns, row count, etc.
|
||||||
|
"""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Get column info
|
||||||
|
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||||
|
columns = cursor.fetchall()
|
||||||
|
|
||||||
|
# Get row count
|
||||||
|
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
|
||||||
|
row_count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'table_name': table_name,
|
||||||
|
'row_count': row_count,
|
||||||
|
'column_count': len(columns),
|
||||||
|
'columns': [{'name': col[1], 'type': col[2], 'not_null': bool(col[3]),
|
||||||
|
'default_value': col[4], 'primary_key': bool(col[5])} for col in columns],
|
||||||
|
'column_names': [col[1] for col in columns]
|
||||||
|
}
|
||||||
|
|
||||||
|
def query(self, sql: str, params: Optional[Tuple] = None) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Execute a SQL query and return results as a DataFrame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql (str): SQL query string
|
||||||
|
params (Optional[Tuple]): Parameters for the query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: Query results
|
||||||
|
"""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
return pd.read_sql_query(sql, conn, params=params)
|
||||||
|
|
||||||
|
def execute(self, sql: str, params: Optional[Tuple] = None) -> int:
|
||||||
|
"""
|
||||||
|
Execute a SQL statement (INSERT, UPDATE, DELETE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sql (str): SQL statement
|
||||||
|
params (Optional[Tuple]): Parameters for the statement
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of affected rows
|
||||||
|
"""
|
||||||
|
with self.get_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(sql, params or ())
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
|
||||||
|
def get_table_data(self, table_name: str, limit: Optional[int] = None,
|
||||||
|
columns: Optional[List[str]] = None) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Get data from a specific table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): Name of the table
|
||||||
|
limit (Optional[int]): Limit number of rows returned
|
||||||
|
columns (Optional[List[str]]): Specific columns to select
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: Table data
|
||||||
|
"""
|
||||||
|
col_str = ', '.join(columns) if columns else '*'
|
||||||
|
sql = f"SELECT {col_str} FROM {table_name}"
|
||||||
|
if limit:
|
||||||
|
sql += f" LIMIT {limit}"
|
||||||
|
|
||||||
|
return self.query(sql)
|
||||||
|
|
||||||
|
def search_table(self, table_name: str, column: str, value: Any,
|
||||||
|
comparison: str = '=') -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Search for records in a table based on column value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): Name of the table
|
||||||
|
column (str): Column to search in
|
||||||
|
value (Any): Value to search for
|
||||||
|
comparison (str): Comparison operator (=, >, <, LIKE, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame: Matching records
|
||||||
|
"""
|
||||||
|
sql = f"SELECT * FROM {table_name} WHERE {column} {comparison} ?"
|
||||||
|
return self.query(sql, (value,))
|
||||||
|
|
||||||
|
def get_database_summary(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get a comprehensive summary of the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Database summary
|
||||||
|
"""
|
||||||
|
tables = self.get_tables()
|
||||||
|
summary = {
|
||||||
|
'database_path': self.db_path,
|
||||||
|
'total_tables': len(tables),
|
||||||
|
'tables': {}
|
||||||
|
}
|
||||||
|
|
||||||
|
total_rows = 0
|
||||||
|
for table in tables:
|
||||||
|
info = self.get_table_info(table)
|
||||||
|
summary['tables'][table] = info
|
||||||
|
total_rows += info['row_count']
|
||||||
|
|
||||||
|
summary['total_rows'] = total_rows
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def print_database_summary(self):
|
||||||
|
"""Print a formatted database summary."""
|
||||||
|
summary = self.get_database_summary()
|
||||||
|
|
||||||
|
print("="*60)
|
||||||
|
print("DATABASE SUMMARY")
|
||||||
|
print("="*60)
|
||||||
|
print(f"Database: {summary['database_path']}")
|
||||||
|
print(f"Total Tables: {summary['total_tables']}")
|
||||||
|
print(f"Total Rows: {summary['total_rows']:,}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
for table_name, info in summary['tables'].items():
|
||||||
|
print(f"Table: {table_name}")
|
||||||
|
print(f" Rows: {info['row_count']:,}")
|
||||||
|
print(f" Columns: {info['column_count']}")
|
||||||
|
print(f" Column Names: {', '.join(info['column_names'])}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Convenience functions for specific business logic
|
||||||
|
class BusinessAnalytics:
|
||||||
|
"""Business-specific analytics functions."""
|
||||||
|
|
||||||
|
def __init__(self, db_utils: DatabaseUtils):
|
||||||
|
self.db = db_utils
|
||||||
|
|
||||||
|
def get_customer_profile(self, customer_id: int) -> Dict[str, Any]:
|
||||||
|
"""Get complete customer profile including purchases and calls."""
|
||||||
|
# Get customer basic info
|
||||||
|
customer = self.db.query(
|
||||||
|
"SELECT * FROM customer_profile_dataset WHERE customer_id = ?",
|
||||||
|
(customer_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
if customer.empty:
|
||||||
|
return {"error": f"Customer {customer_id} not found"}
|
||||||
|
|
||||||
|
# Get purchase history
|
||||||
|
purchases = self.db.query(
|
||||||
|
"""SELECT p.*, pr.product_name, pr.category, pr.brand
|
||||||
|
FROM purchase_history_dataset_with_status p
|
||||||
|
LEFT JOIN products_dataset_with_descriptions pr ON p.product_id = pr.product_id
|
||||||
|
WHERE p.customer_id = ?
|
||||||
|
ORDER BY p.purchase_date DESC""",
|
||||||
|
(customer_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get call summaries
|
||||||
|
calls = self.db.query(
|
||||||
|
"SELECT * FROM customer_call_summaries WHERE customer_id = ?",
|
||||||
|
(customer_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"customer_info": customer.to_dict('records')[0],
|
||||||
|
"purchases": purchases.to_dict('records'),
|
||||||
|
"calls": calls.to_dict('records'),
|
||||||
|
"total_purchases": len(purchases),
|
||||||
|
"total_spent": purchases['total_amount'].sum() if not purchases.empty else 0,
|
||||||
|
"total_calls": len(calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_product_analytics(self, product_id: Optional[int] = None) -> pd.DataFrame:
|
||||||
|
"""Get product sales analytics."""
|
||||||
|
sql = """
|
||||||
|
SELECT
|
||||||
|
pr.product_id,
|
||||||
|
pr.product_name,
|
||||||
|
pr.category,
|
||||||
|
pr.brand,
|
||||||
|
pr.price_per_unit,
|
||||||
|
COUNT(p.purchase_id) as total_purchases,
|
||||||
|
SUM(p.quantity) as total_quantity_sold,
|
||||||
|
SUM(p.total_amount) as total_revenue,
|
||||||
|
AVG(p.total_amount) as avg_order_value,
|
||||||
|
COUNT(DISTINCT p.customer_id) as unique_customers
|
||||||
|
FROM products_dataset_with_descriptions pr
|
||||||
|
LEFT JOIN purchase_history_dataset_with_status p ON pr.product_id = p.product_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
if product_id:
|
||||||
|
sql += " WHERE pr.product_id = ?"
|
||||||
|
return self.db.query(sql + " GROUP BY pr.product_id", (product_id,))
|
||||||
|
else:
|
||||||
|
return self.db.query(sql + " GROUP BY pr.product_id ORDER BY total_revenue DESC")
|
||||||
|
|
||||||
|
def get_customer_analytics(self) -> pd.DataFrame:
|
||||||
|
"""Get customer analytics summary."""
|
||||||
|
return self.db.query("""
|
||||||
|
SELECT
|
||||||
|
c.customer_id,
|
||||||
|
c.first_name,
|
||||||
|
c.last_name,
|
||||||
|
c.city,
|
||||||
|
c.state,
|
||||||
|
COUNT(p.purchase_id) as total_purchases,
|
||||||
|
SUM(p.total_amount) as total_spent,
|
||||||
|
AVG(p.total_amount) as avg_order_value,
|
||||||
|
MIN(p.purchase_date) as first_purchase,
|
||||||
|
MAX(p.purchase_date) as last_purchase,
|
||||||
|
COUNT(DISTINCT p.product_id) as unique_products_bought,
|
||||||
|
COUNT(cs.customer_id) as total_calls
|
||||||
|
FROM customer_profile_dataset c
|
||||||
|
LEFT JOIN purchase_history_dataset_with_status p ON c.customer_id = p.customer_id
|
||||||
|
LEFT JOIN customer_call_summaries cs ON c.customer_id = cs.customer_id
|
||||||
|
GROUP BY c.customer_id
|
||||||
|
ORDER BY total_spent DESC
|
||||||
|
""")
|
||||||
|
|
||||||
|
def get_sales_by_category(self) -> pd.DataFrame:
|
||||||
|
"""Get sales summary by product category."""
|
||||||
|
return self.db.query("""
|
||||||
|
SELECT
|
||||||
|
pr.category,
|
||||||
|
COUNT(p.purchase_id) as total_orders,
|
||||||
|
SUM(p.quantity) as total_quantity,
|
||||||
|
SUM(p.total_amount) as total_revenue,
|
||||||
|
AVG(p.total_amount) as avg_order_value,
|
||||||
|
COUNT(DISTINCT p.customer_id) as unique_customers,
|
||||||
|
COUNT(DISTINCT pr.product_id) as unique_products
|
||||||
|
FROM products_dataset_with_descriptions pr
|
||||||
|
LEFT JOIN purchase_history_dataset_with_status p ON pr.product_id = p.product_id
|
||||||
|
GROUP BY pr.category
|
||||||
|
ORDER BY total_revenue DESC
|
||||||
|
""")
|
||||||
|
|
||||||
|
def get_top_customers(self, limit: int = 10) -> pd.DataFrame:
|
||||||
|
"""Get top customers by total spent."""
|
||||||
|
return self.db.query(f"""
|
||||||
|
SELECT
|
||||||
|
c.customer_id,
|
||||||
|
c.first_name || ' ' || c.last_name as customer_name,
|
||||||
|
c.email,
|
||||||
|
c.city,
|
||||||
|
c.state,
|
||||||
|
SUM(p.total_amount) as total_spent,
|
||||||
|
COUNT(p.purchase_id) as total_orders
|
||||||
|
FROM customer_profile_dataset c
|
||||||
|
JOIN purchase_history_dataset_with_status p ON c.customer_id = p.customer_id
|
||||||
|
GROUP BY c.customer_id
|
||||||
|
ORDER BY total_spent DESC
|
||||||
|
LIMIT {limit}
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Factory function to create database utilities
|
||||||
|
def create_db_utils(db_path: str = "data.sqlite") -> Tuple[DatabaseUtils, BusinessAnalytics]:
|
||||||
|
"""
|
||||||
|
Create database utilities and business analytics instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path (str): Path to the database
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[DatabaseUtils, BusinessAnalytics]: Database utilities and analytics instances
|
||||||
|
"""
|
||||||
|
db_utils = DatabaseUtils(db_path)
|
||||||
|
analytics = BusinessAnalytics(db_utils)
|
||||||
|
return db_utils, analytics
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example usage
|
||||||
|
try:
|
||||||
|
db, analytics = create_db_utils()
|
||||||
|
|
||||||
|
# Print database summary
|
||||||
|
db.print_database_summary()
|
||||||
|
|
||||||
|
# Example queries
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("SAMPLE ANALYTICS")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Top 5 customers
|
||||||
|
print("\nTop 5 Customers by Total Spent:")
|
||||||
|
top_customers = analytics.get_top_customers(5)
|
||||||
|
print(top_customers.to_string(index=False))
|
||||||
|
|
||||||
|
# Sales by category
|
||||||
|
print("\nSales by Category:")
|
||||||
|
category_sales = analytics.get_sales_by_category()
|
||||||
|
print(category_sales.to_string(index=False))
|
||||||
|
|
||||||
|
# Customer profile example (if customer 1 exists)
|
||||||
|
print("\nSample Customer Profile (Customer ID 1):")
|
||||||
|
profile = analytics.get_customer_profile(1)
|
||||||
|
if "error" not in profile:
|
||||||
|
print(f"Customer: {profile['customer_info']['first_name']} {profile['customer_info']['last_name']}")
|
||||||
|
print(f"Total Purchases: {profile['total_purchases']}")
|
||||||
|
print(f"Total Spent: ${profile['total_spent']:.2f}")
|
||||||
|
print(f"Total Calls: {profile['total_calls']}")
|
||||||
|
else:
|
||||||
|
print(profile["error"])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
230
mcp_access.py
Normal file
230
mcp_access.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
import json
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from typing import Dict, List
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
model = AzureOpenAI(
|
||||||
|
api_version="2024-12-01-preview",
|
||||||
|
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||||
|
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Chatbot:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = model
|
||||||
|
self.sessions = [] # multiple sessions for multiple servers
|
||||||
|
self.available_tools = []
|
||||||
|
self.tool_to_session: Dict[str, ClientSession] = (
|
||||||
|
{}
|
||||||
|
) # maps tool names to their respective sessions
|
||||||
|
self.messages = []
|
||||||
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
async def connect_to_server(self, server_config: dict) -> None:
|
||||||
|
|
||||||
|
if server_config['type'] == "uv":
|
||||||
|
server_params = StdioServerParameters(**server_config)
|
||||||
|
|
||||||
|
transport = await self.exit_stack.enter_async_context(
|
||||||
|
stdio_client(server_params)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif server_config['type'] == "sse":
|
||||||
|
transport = await self.exit_stack.enter_async_context(
|
||||||
|
sse_client(
|
||||||
|
url=server_config["url"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported server type: {server_config['type']}")
|
||||||
|
|
||||||
|
read, write = transport
|
||||||
|
|
||||||
|
session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||||
|
|
||||||
|
await session.initialize()
|
||||||
|
|
||||||
|
self.sessions.append(session)
|
||||||
|
|
||||||
|
# Load MCP tools
|
||||||
|
response = await session.list_tools()
|
||||||
|
tools = response.tools
|
||||||
|
|
||||||
|
# Clear previous tools and populate with new ones
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
self.tool_to_session[tool.name] = session
|
||||||
|
self.available_tools.append(
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.inputSchema,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loaded {len(self.available_tools)} tools: "
|
||||||
|
f"{[tool['function']['name'] for tool in self.available_tools]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def connect_to_servers(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open("server_config.json", "r") as file:
|
||||||
|
data = json.load(file)
|
||||||
|
|
||||||
|
servers = data.get("mcp_servers", {})
|
||||||
|
|
||||||
|
for server_name, server_config in servers.items():
|
||||||
|
await self.connect_to_server(server_config)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading server configuration: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cleanup(self): # new
|
||||||
|
"""Cleanly close all resources using AsyncExitStack."""
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
|
||||||
|
async def handle_tool_call(self, agent_response: dict):
|
||||||
|
"""Handle tool calls from the agent response."""
|
||||||
|
|
||||||
|
tool_messages = []
|
||||||
|
|
||||||
|
for tool in agent_response.choices[0].message.tool_calls:
|
||||||
|
|
||||||
|
print(f"\nTool call detected: {tool.id}")
|
||||||
|
session = self.tool_to_session.get(tool.function.name)
|
||||||
|
result = await session.call_tool(
|
||||||
|
tool.function.name,
|
||||||
|
arguments=json.loads(tool.function.arguments),
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(result, "content") and result.content:
|
||||||
|
if isinstance(result.content, list):
|
||||||
|
tool_result_content = "\n".join(
|
||||||
|
[
|
||||||
|
(str(item.text) if hasattr(item, "text") else str(item))
|
||||||
|
for item in result.content
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_result_content = str(result.content)
|
||||||
|
tool_messages.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool.id,
|
||||||
|
"content": tool_result_content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return tool_messages
|
||||||
|
|
||||||
|
def call_model(self):
|
||||||
|
"""Call the model with the given messages and tools."""
|
||||||
|
return model.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=self.messages,
|
||||||
|
max_tokens=2024,
|
||||||
|
tools=self.available_tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
# Initialize session and tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
user_next = True
|
||||||
|
while True:
|
||||||
|
|
||||||
|
if user_next:
|
||||||
|
# Get user input
|
||||||
|
user_input = input("Enter your question (or 'exit' to quit): ")
|
||||||
|
|
||||||
|
if user_input.lower() == "exit":
|
||||||
|
print("Exiting the agent.")
|
||||||
|
break
|
||||||
|
|
||||||
|
self.messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use asyncio.wait_for to add a timeout
|
||||||
|
agent_response = self.call_model()
|
||||||
|
|
||||||
|
# check if the model calls a tool
|
||||||
|
|
||||||
|
if agent_response.choices[0].message.tool_calls:
|
||||||
|
|
||||||
|
self.messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": agent_response.choices[0].message.content,
|
||||||
|
"tool_calls": agent_response.choices[
|
||||||
|
0
|
||||||
|
].message.tool_calls,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_messages = await self.handle_tool_call(agent_response)
|
||||||
|
|
||||||
|
self.messages.extend(tool_messages)
|
||||||
|
|
||||||
|
# After tool call, the assistant gets the next turn, not the user
|
||||||
|
user_next = False
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If no tool calls, just append the assistant's response
|
||||||
|
else:
|
||||||
|
self.messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": agent_response.choices[0].message.content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"\nAgent: {agent_response.choices[0].message.content}")
|
||||||
|
user_next = True
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("\nAgent response timed out after 60 seconds.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError occurred: {e}")
|
||||||
|
# Remove the last user message if there was an error
|
||||||
|
if self.messages:
|
||||||
|
self.messages.pop()
|
||||||
|
user_next = True # Reset to get user input again after an error
|
||||||
|
finally:
|
||||||
|
# Cleanup session and connections
|
||||||
|
await self.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
|
||||||
|
chat_bot = Chatbot()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await chat_bot.connect_to_servers()
|
||||||
|
await chat_bot.run()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
finally:
|
||||||
|
await chat_bot.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
20
mcp_server.py
Normal file
20
mcp_server.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from db_utils import DatabaseUtils
|
||||||
|
from datetime import datetime
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['FASTMCP_HOST'] = "0.0.0.0"
|
||||||
|
os.environ['FASTMCP_PORT'] = "8000"
|
||||||
|
|
||||||
|
mcp = FastMCP('Customer')
|
||||||
|
|
||||||
|
db = DatabaseUtils()
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def get_current_date() -> str:
|
||||||
|
"""Returns the current date in YYYY-MM-DD format."""
|
||||||
|
return datetime.now().strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mcp.run(transport='sse')
|
||||||
14
pyproject.toml
Normal file
14
pyproject.toml
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[project]
|
||||||
|
name = "mcp-demo"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"jupyter>=1.1.1",
|
||||||
|
"mcp[cli]>=1.10.1",
|
||||||
|
"openai>=1.93.0",
|
||||||
|
"openpyxl>=3.1.5",
|
||||||
|
"pandas>=2.3.0",
|
||||||
|
"ucimlrepo>=0.0.7",
|
||||||
|
]
|
||||||
18
sample.py
Normal file
18
sample.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
mcp = FastMCP('ExampleServer')
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def get_current_date() -> str:
|
||||||
|
"""Returns the current date in YYYY-MM-DD format."""
|
||||||
|
return datetime.now().strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mcp.run(transport='sse')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
11
server_config.json
Normal file
11
server_config.json
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"mcp_servers": {
|
||||||
|
"remote": {
|
||||||
|
"type": "sse",
|
||||||
|
"url": "http://0.0.0.0:8000/sse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
10
server_config_local.json
Normal file
10
server_config_local.json
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
|
||||||
|
"cmd": {
|
||||||
|
"type": "uv",
|
||||||
|
"command": "uv",
|
||||||
|
"args": ["run", "mcp_server.py"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user