
Implementing an MCP Server and MCP Client in Python

Below is a Python implementation of an MCP server and an MCP client that communicate over the SSE transport.

MCP server code

server_sse.py:

import json
import anyio
import logging
import traceback
import asyncio
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator

import uvicorn
import mcp.types as types
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Route


logger = logging.getLogger(__file__)


@asynccontextmanager
async def lifespan(server: Server) -> AsyncIterator[dict]:
    """Initialize configuration on startup and release resources on shutdown."""
    try:
        # Whatever is yielded here becomes available to handlers via the lifespan context
        yield {
            "client_mysql": 'xxx',
            "client_redis": 'xxx'
        }
    finally:
        pass


server = Server("demo", lifespan=lifespan)
sse = SseServerTransport("/messages/")


async def handle_sse(request):
    """Handle the long-lived SSE connection."""
    try:
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server.run(streams[0], streams[1], server.create_initialization_options())
    except asyncio.exceptions.CancelledError:
        pass
    except anyio.WouldBlock:
        pass
    except Exception:
        logger.error(traceback.format_exc())


async def handle_messages(request):
    """Handle client-to-server POST messages."""
    try:
        await sse.handle_post_message(request.scope, request.receive, request._send)
    except Exception:
        logger.error(traceback.format_exc())


class Tools:
    # Tool implementations

    async def get_weather(self, city: str, date: str) -> dict:
        """Get the weather for a city on a given date.

        Args:
            city: the city to query
            date: the date to query
        """
        return {
            "city": city,
            "date": date,
            "sunny": True,
            "temp": "20"
        }

    async def book_hotel(self, username: str) -> dict:
        """Book a hotel for a user.

        Args:
            username: the user's name
        """
        return {
            "username": username,
            "book_status": "success"
        }


@server.list_tools()
async def list_tools():
    ctx = server.request_context
    client_mysql = ctx.lifespan_context["client_mysql"]
    print(f"=====client mysql: {client_mysql}")

    return [
        types.Tool(
            name="get_weather",
            description="Get the weather for a city on a given date",
            inputSchema={
                "type": "object",
                "required": ["city", "date"],
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "city name",
                    },
                    "date": {
                        "type": "string",
                        "description": "date, format: yyyy-mm-dd",
                    }
                },
            },
        ),
        types.Tool(
            name="book_hotel",
            description="Book a hotel for a user",
            inputSchema={
                "type": "object",
                "required": ["username"],
                "properties": {
                    "username": {
                        "type": "string",
                        "description": "username",
                    }
                },
            },
        ),
    ]


@server.call_tool()
async def call_tool(name, args):
    try:
        print(f"=====func name: {name}, request args: {args}")

        tools = Tools()
        func = getattr(tools, name, None)  # default=None so an unknown tool name doesn't raise
        if func:
            result = await func(**args)
            result_text = json.dumps({"code": 0, "msg": "success", "data": result}, ensure_ascii=False)
        else:
            result_text = json.dumps({"code": -1, "msg": "function not found", "data": {}}, ensure_ascii=False)

    except Exception as e:
        logger.error(traceback.format_exc())
        result_text = json.dumps({"code": -1, "msg": str(e), "data": {}}, ensure_ascii=False)

    return [types.TextContent(type='text', text=result_text)]


@server.list_resources()
async def list_resources():
    """Define some resources."""
    return [
        types.Resource(
            uri='file:///doc01/file01.log',
            name="Document 01",
            description="Document 01",
            mimeType="text/plain",
            size=10 * 1024 * 1024
        ),
        types.Resource(
            uri='file:///doc01/file02.pdf',
            name="Document 02",
            description="Document 02",
            mimeType="application/pdf",
            size=10 * 1024 * 1024
        ),
        # types.Resource(
        #     uri='file:///doc01/file03.png',
        #     name="Document 03",
        #     description="Document 03",
        #     mimeType="image/png",
        #     size=10 * 1024 * 1024
        # ),
        # types.Resource(
        #     uri='file:///doc01/file04.json',
        #     name="Document 04",
        #     description="Document 04",
        #     mimeType="application/json",
        #     size=10 * 1024 * 1024
        # ),
        # types.Resource(
        #     uri='file:///doc01/file05.docx',
        #     name="Document 05",
        #     description="Document 05",
        #     mimeType="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        #     size=10 * 1024 * 1024
        # )
    ]


@server.read_resource()
async def read_resource(uri):
    str_uri = str(uri)
    if str_uri == "file:///doc01/file01.log":
        return "file01 log"
    if str_uri == "file:///doc01/file02.pdf":
        return "file02 pdf"

    raise ValueError("Resource not found")


if __name__ == "__main__":
    starlette_app = Starlette(
        debug=True,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Route("/messages/", endpoint=handle_messages, methods=['POST']),
        ],
    )
    uvicorn.run(starlette_app, host="0.0.0.0", port=8099)

The server code above defines two tools and implements the tool-call handler; it also defines two resources and implements reads against them.
Run python server_sse.py to start the MCP server.
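
Before wiring an LLM into the loop, you can sanity-check the running server with a minimal standalone client. This is just a sketch: it assumes the server is listening on port 8099 as configured above, the tool arguments are made-up examples, and it relies only on the same sse_client and ClientSession APIs that client_sse.py uses later in this post.

import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def smoke_test():
    # Connect to the /sse endpoint exposed by server_sse.py
    async with sse_client(url="http://localhost:8099/sse") as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()

            # List the tools registered via @server.list_tools()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])

            # Call a tool directly, with no LLM in the loop
            result = await session.call_tool("get_weather", {"city": "Beijing", "date": "2025-01-01"})
            print(result.content[0].text)


if __name__ == "__main__":
    asyncio.run(smoke_test())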

requirements.txt:

mcp[cli]
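
Install the dependency with pip install -r requirements.txt (ideally inside a virtual environment) before starting the server.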

MCP client code

client_sse.py:

import sys
import asyncio
import logging
import traceback
from typing import Optional, Union
from contextlib import AsyncExitStack
from dataclasses import dataclass, field

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client

from llm import LLM_RESOURCES, LLMDeepseek, LLMAliYun, LLMHuoShan, LLMZhipu, LLMOpenAI


logger = logging.getLogger(__name__)


CHAT_MAX_ROUND = 5


@dataclass
class MCPClient:

    llm: Union[LLMDeepseek, LLMAliYun, LLMHuoShan, LLMZhipu, LLMOpenAI] = None

    # init=False: these fields are not set through __init__
    session: Optional[ClientSession] = field(init=False, default=None)
    exit_stack: AsyncExitStack = field(init=False, default_factory=AsyncExitStack)
    tools: list = field(init=False, default=None)
    prompts: list = field(init=False, default=None)
    resources: list = field(init=False, default=None)
    resource_templates: list = field(init=False, default=None)
    server_url: str = field(init=False, default="http://localhost:3000")

    async def connect_to_server_by_stdio(self, server_script_path: str):
        """Communicate over stdio.

        Args:
            server_script_path: path to the MCP server script (.py or .js)
        """
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            raise ValueError("Server script must be a .py or .js file")

        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None
        )

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        await self.session.initialize()

    async def connect_to_server_by_sse(self, server_url: str):
        """Communicate over SSE.

        A plain `async with sse_client(...)` block would close the connection
        as soon as this method returned, so the context managers are entered
        manually here and exited later in cleanup().
        """
        # Create the SSE connection context manager
        self._streams_context = sse_client(url=server_url)
        # Enter it to obtain the read/write streams
        streams = await self._streams_context.__aenter__()

        # Create an MCP client session on top of the streams
        self._session_context = ClientSession(*streams)
        self.session: ClientSession = await self._session_context.__aenter__()

        # Perform the MCP protocol initialization handshake
        await self.session.initialize()

    async def get_server_resources(self):
        """Fetch the tools and resources exposed by the MCP server."""
        list_tools = await self.session.list_tools()
        if list_tools and list_tools.tools:
            await self.handle_mcp_server_tools(list_tools.tools)
            print(f"mcp server tools: {self.tools}")

        list_resources = await self.session.list_resources()
        if list_resources and list_resources.resources:
            await self.handle_mcp_server_resources(list_resources.resources)
            print(f"mcp server resources: {self.resources}")

    async def process_query_by_sse(self, query, resource_uris=None):
        messages = [
            {"role": "user", "content": query}
        ]

        chat_messages = {
            "messages": messages,
            "stream": False
        }

        if self.tools:
            chat_messages['tools'] = self.tools

        if resource_uris:
            resource_list = []
            if isinstance(resource_uris, str):
                resource_uris = [resource_uris]
            if isinstance(resource_uris, list):
                for uri in resource_uris:
                    resource = await self.session.read_resource(uri)
                    if resource:
                        contents = resource.contents
                        for content in contents:
                            resource_list.append(content.text)

            if resource_list:
                # TODO: attach the resource contents to messages
                pass

        curr_round, response = 1, None
        while curr_round <= CHAT_MAX_ROUND:
            response = self.llm.chat_completions(chat_messages)
            if not response:
                continue

            result = await self.llm.handle_completions(self.session, response, chat_messages)
            if not isinstance(result, dict):
                return result
            else:
                chat_messages = result

            curr_round += 1

        return response

    async def process_query_by_sse_stream(self, query, resource_uris=None):
        messages = [
            {"role": "user", "content": query}
        ]

        chat_messages = {
            "messages": messages,
            "stream": True
        }

        if self.tools:
            chat_messages['tools'] = self.tools

        if resource_uris:
            resource_list = []
            if isinstance(resource_uris, str):
                resource_uris = [resource_uris]
            if isinstance(resource_uris, list):
                for uri in resource_uris:
                    resource = await self.session.read_resource(uri)
                    if resource:
                        contents = resource.contents
                        for content in contents:
                            resource_list.append(content.text)

            if resource_list:
                # TODO: attach the resource contents to messages
                pass

        curr_round, continue_loop = 1, True
        while curr_round <= CHAT_MAX_ROUND and continue_loop:
            response = self.llm.chat_completions(chat_messages)
            if not response:
                continue

            async for result in self.llm.handle_completions_stream(self.session, response, chat_messages):
                if not isinstance(result, dict):
                    continue_loop = False
                    yield result
                else:
                    chat_messages = result

            curr_round += 1

    async def handle_mcp_server_tools(self, resource_tools):
        if self.tools:
            return

        server_tools = []
        if resource_tools:
            for tool in resource_tools:
                server_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema
                    }
                })

        self.tools = server_tools

    async def handle_mcp_server_resources(self, resources):
        if self.resources:
            return

        server_resources = []
        if resources:
            for resource in resources:
                server_resources.append({
                    "uri": resource.uri,
                    "name": resource.name,
                    "description": resource.description,
                    "mimeType": resource.mimeType,
                    "size": resource.size,
                    "annotations": resource.annotations
                })
        self.resources = server_resources

    async def handle_mcp_server_resource_templates(self, resource_templates):
        resources = []
        if resource_templates:
            for temp in resource_templates:
                uri = temp.uriTemplate
                name = temp.name

                resources.append({
                    "uri": uri,
                    "name": name
                })

        self.resource_templates = resources

    async def chat_loop(self, stream=True):
        print("\nMCP Client Started!")
        print("Type your queries or 'quit' to exit.")

        # Fetch the MCP server's tools and resources
        await self.get_server_resources()

        resource_uris = []
        if self.resources:
            for resource in self.resources:
                resource_uris.append(resource.get('uri'))

        while True:
            query = input("\nQuery: ").strip()
            if not query:
                continue
            if query.lower() == 'quit':
                break

            try:
                if stream is True:
                    async for response in self.process_query_by_sse_stream(query, resource_uris):
                        if not isinstance(response, str):
                            if not response.choices:
                                continue

                            choice = response.choices[0]
                            if choice.delta.content:
                                # Print streamed tokens without a trailing newline
                                print(choice.delta.content, end='', flush=True)
                else:
                    response = await self.process_query_by_sse(query, resource_uris)
                    print(response.choices[0].message.content)

            except Exception:
                logger.error(traceback.format_exc())

    async def cleanup(self):
        # Exit the manually entered SSE contexts, then close the stack
        if getattr(self, '_session_context', None):
            await self._session_context.__aexit__(None, None, None)
        if getattr(self, '_streams_context', None):
            await self._streams_context.__aexit__(None, None, None)
        await self.exit_stack.aclose()


async def main():
    if len(sys.argv) < 3:
        print("Usage: python client_sse.py <sse server url> <llm>")
        sys.exit(1)

    llm = sys.argv[2]
    if llm not in LLM_RESOURCES:
        print("llm must be one of [deepseek, aliyun, volcengine, zhipu, openai]")
        sys.exit(1)

    try:
        llm_instance = LLM_RESOURCES.get(llm)()
        client = MCPClient(llm=llm_instance)
    except Exception:
        logger.info("Failed to initialize the MCP client")
        logger.error(traceback.format_exc())
        sys.exit(1)

    try:
        await client.connect_to_server_by_sse(sys.argv[1])
        await client.chat_loop()
    except Exception:
        logger.info("Failed to connect to the MCP server, or the session errored")
        logger.error(traceback.format_exc())
    finally:
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())

llm.py:

import os
import json
from dataclasses import dataclass

from openai import OpenAI
from volcenginesdkarkruntime import Ark
from zhipuai import ZhipuAI
from dotenv import load_dotenv

from model import FinishReason


load_dotenv()


@dataclass
class LLMBase:

    client = None
    default_model = None

    def chat_completions(self, chat_messages):
        if 'model' not in chat_messages:
            chat_messages['model'] = self.default_model

        response = self.client.chat.completions.create(**chat_messages)
        return response

    async def handle_completions(self, client_session, response, chat_messages):
        """Handle a non-streaming chat response."""
        choice = response.choices[0]
        finish_reason = choice.finish_reason
        if finish_reason == FinishReason.STOP:
            return response

        if finish_reason == FinishReason.TOOL_CALLS:

            tool_calls_message_info = choice.message
            if isinstance(self, (LLMHuoShan, LLMZhipu)):
                tool_calls_message_info = choice.message.dict()

            chat_messages.get('messages').extend([
                tool_calls_message_info
            ])

            tool_calls = choice.message.tool_calls
            for tool in tool_calls:
                tool_info = {
                    "role": 'tool',
                    "tool_call_id": tool.id
                }

                tool_name = tool.function.name
                tool_args = json.loads(tool.function.arguments) if tool.function.arguments else {}

                # Invoke the tool on the MCP server
                resp = await client_session.call_tool(tool_name, tool_args)
                func_call_content = resp.content[0] if resp and resp.content else None
                if func_call_content:
                    text_data = json.loads(func_call_content.text)
                    call_result = text_data.get('data', None)
                    call_result = json.dumps(call_result) if call_result else ''
                else:
                    call_result = None
                tool_info['content'] = call_result

                chat_messages.get('messages').extend([
                    tool_info
                ])

            return chat_messages

        return response

    async def handle_completions_stream(self, client_session, response, chat_messages):
        if isinstance(self, (LLMDeepseek, LLMAliYun, LLMOpenAI)):
            async for result in self.handle_completions_stream_t1(client_session, response, chat_messages):
                yield result

        if isinstance(self, (LLMHuoShan, LLMZhipu)):
            async for result in self.handle_completions_stream_t2(client_session, response, chat_messages):
                yield result

    async def handle_completions_stream_t1(self, client_session, response, chat_messages):
        """Handle a streaming chat response (OpenAI-style chunking)."""
        finish_reason = ''
        tool_infos = {}
        for resp in response:
            choice = resp.choices[0]
            finish_reason = choice.finish_reason
            tool_calls = choice.delta.tool_calls
            content = choice.delta.content

            if all([finish_reason is None, tool_calls is None, content in ['', None]]):
                continue

            if finish_reason == FinishReason.TOOL_CALLS:
                break

            if not tool_calls:
                # Not a function call
                if finish_reason is None or finish_reason == FinishReason.STOP:
                    yield resp
            else:
                # Collect function-call info
                tool = tool_calls[0]
                tool_idx = tool.index
                tool_id = tool.id
                arguments = tool.function.arguments
                if tool_idx not in tool_infos:
                    tool_infos[tool_idx] = {
                        'tool_id': tool_id,
                        'tool_name': tool.function.name,
                        'arguments': arguments
                    }
                else:
                    if arguments:
                        # Concatenate the streamed argument fragments
                        tool_infos[tool_idx]['arguments'] += arguments

        if finish_reason == FinishReason.TOOL_CALLS and tool_infos:
            chat_messages = await self._handle_function_call_stream_t1(client_session, tool_infos, chat_messages)
            yield chat_messages

    @staticmethod
    async def _handle_function_call_stream_t1(client_session, tool_infos, chat_messages):
        tool_call_list, tool_call_results = [], []
        for idx, val in tool_infos.items():
            tool_id = val.get('tool_id')
            tool_name = val.get('tool_name')
            arguments = val.get('arguments')
            func_args = json.loads(arguments) if arguments else {}

            tool_call_list.append(
                {
                    "id": tool_id,
                    "function": {
                        "arguments": arguments,
                        "name": tool_name,
                    },
                    "type": 'function',
                    "index": idx
                }
            )

            # Invoke the tool on the MCP server
            resp = await client_session.call_tool(tool_name, func_args)
            func_call_content = resp.content[0] if resp and resp.content else None
            if func_call_content:
                text_data = json.loads(func_call_content.text)
                call_result = text_data.get('data', None)
                call_result = json.dumps(call_result) if call_result else ''

                tool_call_results.extend(
                    [
                        {
                            "role": 'tool',
                            "tool_call_id": tool_id,
                            "content": call_result,
                        },
                    ]
                )

        chat_messages.get('messages').extend([{
            "role": 'assistant',
            "content": '',
            "tool_calls": tool_call_list
        }])
        chat_messages.get('messages').extend(tool_call_results)

        return chat_messages

    async def handle_completions_stream_t2(self, client_session, response, chat_messages):
        """Handle a streaming chat response (Volcengine/Zhipu chunking)."""
        for resp in response:
            choice = resp.choices[0]
            finish_reason = choice.finish_reason
            if finish_reason is None or finish_reason == FinishReason.STOP:
                yield resp

            if finish_reason == FinishReason.TOOL_CALLS and choice.delta.tool_calls:
                chat_messages = await self._handle_function_call_stream_t2(client_session, resp, chat_messages)
                yield chat_messages

    @staticmethod
    async def _handle_function_call_stream_t2(client_session, response, chat_messages):
        if response and chat_messages:
            choice = response.choices[0]
            dump_content = choice.delta
            tool_calls = dump_content.tool_calls
            for tool in tool_calls:
                tool_id = tool.id
                func_name = tool.function.name
                func_args = json.loads(tool.function.arguments)

                resp = await client_session.call_tool(func_name, func_args)
                func_call_content = resp.content[0] if resp and resp.content else None
                if func_call_content:
                    text_data = json.loads(func_call_content.text)
                    call_result = text_data.get('data', None)
                    call_result = json.dumps(call_result) if call_result else ''

                    chat_messages.get('messages').extend(
                        [
                            dump_content.model_dump(),
                            {
                                "role": 'tool',
                                "tool_call_id": tool_id,
                                "content": call_result,
                                "name": func_name
                            }
                        ]
                    )
                else:
                    return chat_messages
        return chat_messages


@dataclass
class LLMDeepseek(LLMBase):

    client = OpenAI(
        base_url=os.getenv("DEEPSEEK_BASE_URL"),
        api_key=os.getenv("DEEPSEEK_API_KEY"),
    )

    default_model = os.getenv("DEEPSEEK_DEFAULT_MODEL")


@dataclass
class LLMAliYun(LLMBase):

    client = OpenAI(
        base_url=os.getenv("ALI_YUN_BASE_URL"),
        api_key=os.getenv("ALI_YUN_API_KEY"),
    )

    default_model = os.getenv("ALI_YUN_DEFAULT_MODEL")


@dataclass
class LLMHuoShan(LLMBase):

    client = Ark(
        api_key=os.getenv("HUO_SHAN_API_KEY")
    )

    default_model = os.getenv("HUO_SHAN_DEFAULT_MODEL")


@dataclass
class LLMZhipu(LLMBase):

    client = ZhipuAI(api_key=os.getenv("ZHIPU_API_KEY"))

    default_model = os.getenv("ZHIPU_DEFAULT_MODEL")


@dataclass
class LLMOpenAI(LLMBase):

    client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    default_model = os.getenv("OPENAI_DEFAULT_MODEL")


LLM_RESOURCES = {
    "deepseek": LLMDeepseek,
    "aliyun": LLMAliYun,
    "volcengine": LLMHuoShan,
    "zhipu": LLMZhipu,
    "openai": LLMOpenAI
}

llm.py implements model calls for five providers: DeepSeek, Alibaba Cloud (Aliyun), Volcengine (Doubao), Zhipu, and OpenAI.
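
Because each provider class only supplies a client and a default_model, plugging in another OpenAI-compatible provider is mostly boilerplate. A minimal sketch follows; LLMCustom, the "custom" registry key, and the CUSTOM_* environment variables are hypothetical names, not part of the code above:

import os
from dataclasses import dataclass

from openai import OpenAI

from llm import LLMBase, LLM_RESOURCES


# Hypothetical provider: any endpoint that speaks the OpenAI-compatible chat API
@dataclass
class LLMCustom(LLMBase):

    client = OpenAI(
        base_url=os.getenv("CUSTOM_BASE_URL"),
        api_key=os.getenv("CUSTOM_API_KEY"),
    )

    default_model = os.getenv("CUSTOM_DEFAULT_MODEL")


# Register it so client_sse.py can select it by name on the command line
LLM_RESOURCES["custom"] = LLMCustom

One caveat: handle_completions_stream dispatches on the concrete class, so a new OpenAI-style provider would also need to be added to the isinstance tuple in the t1 branch (or subclass LLMOpenAI) for streaming to work.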

model.py:

from enum import Enum


class FinishReason(str, Enum):
    STOP = 'stop'
    TOOL_CALLS = 'tool_calls'
    CONTENT_FILTER = 'content_filter'
    LENGTH = 'length'
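
Because FinishReason mixes in str, its members compare equal to the plain strings that SDK responses carry in finish_reason, which is exactly what the comparisons in llm.py rely on. A quick illustration:

from model import FinishReason

# SDK chunks carry finish_reason as a plain string; the str mixin
# makes direct comparison work without casting either side.
assert 'tool_calls' == FinishReason.TOOL_CALLS
assert FinishReason.STOP == 'stop'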

.env configuration:

DEEPSEEK_BASE_URL=https://api.deepseek.com
DEEPSEEK_API_KEY=xxx
DEEPSEEK_DEFAULT_MODEL=deepseek-chat

ALI_YUN_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
ALI_YUN_API_KEY=xxx
ALI_YUN_DEFAULT_MODEL=qwen-max

HUO_SHAN_API_KEY=xxx
HUO_SHAN_DEFAULT_MODEL=xxx

ZHIPU_API_KEY=xxx
ZHIPU_DEFAULT_MODEL=glm-4-plus

OPENAI_API_KEY=xxx
OPENAI_DEFAULT_MODEL=gpt-4o-mini
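
llm.py reads these values at import time via load_dotenv(), so keep the .env file next to llm.py (python-dotenv typically searches from the calling module's directory upward).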

requirements.txt:

mcp
openai
python-dotenv
volcengine-python-sdk[ark]
zhipuai
httpx[socks]

Once everything is in place, run python client_sse.py http://localhost:8099/sse deepseek
to connect to the MCP server and chat through the DeepSeek model; replace deepseek with aliyun, volcengine, zhipu, or openai to switch providers.