0%

大模型聊天应用中会话取消功能的实现

在基于大模型的聊天应用中,会话取消或者中断是一个基本功能,下面是基于 asyncio.Event() 实现的会话取消/中断功能。

在大模型聊天应用中,会话取消功能的实现有多种方式:第一种是仅前端取消会话,后端不做处理;第二种是基于 asyncio.Event() 实现取消会话功能。下面看一个基于 asyncio.Event() 实现取消会话功能的例子。

创建一个会话应用时,chat 接口和 cancel_chat 接口都必须定义为异步接口,接口内部对大模型的调用也必须使用异步方式。如果 chat 接口中包含同步(阻塞)调用,它会阻塞事件循环,cancel_chat 请求在此期间得不到处理,也就无法取消会话。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# chat.py
import json
import asyncio
import logging
import traceback
from http import HTTPStatus

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

from llm import LLMClient
from memory import RedisClient
from session import SessionManager, SessionStatus, SessionCacheTime
from model import ChatCompletionRequest, ChatCancelRequest, ChatCancelResponse, ChatStatusResponse


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Placeholder values for the LLM endpoint and credential.
BASE_URL = 'xxx'
API_KEY = 'xxx'
# Number of most-recent messages kept in the history sent to the model.
HISTORY_MSG_LIMIT = 10


# FastAPI application and the session manager shared by all endpoints.
app = FastAPI()
session_manager = SessionManager()


def _gen_session_response(session_id: str,
                          curr_session: dict,
                          status: SessionStatus,
                          message: str | None = None) -> dict:
    """Record *status* on the session and build the event payload for it.

    The session dict is updated in place. The message text is only carried
    for the running status; every other status yields ``message=None``.

    Args:
        session_id: Identifier placed in the payload.
        curr_session: Session dict whose ``status`` entry is overwritten.
        status: New session status.
        message: Content chunk to forward while the session is running.

    Returns:
        A dict with the keys ``session_id``, ``status`` and ``message``.
    """
    curr_session['status'] = status
    forwarded = None if status != SessionStatus.RUNNING else message
    return {"session_id": session_id, "status": status, "message": forwarded}


@app.post("/api/v1/chat/completion")
async def chat_completion_stream(chat: ChatCompletionRequest):
    """Stream a chat completion to the client as server-sent events.

    Every event is a ``data:`` frame holding a JSON object with the keys
    ``session_id``, ``status`` and ``message``. The stream ends with one
    event whose status is completed, cancelled or error.

    Args:
        chat: Request carrying the optional session id, the user prompt and
            the model name.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    def sse_event(payload: dict) -> str:
        # Emit real JSON: formatting the dict directly would send its Python
        # repr (single quotes, enum reprs), which clients cannot parse.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    # Create a session unless the client supplied the id of a known one.
    session_id = chat.session_id
    if not session_id or session_id not in session_manager.sessions:
        # An empty id is passed on as None so it is never used as a session key.
        session_id = session_manager.create_session(chat.session_id or None)

    curr_session = session_manager.get_session(session_id)
    # The cancel flag outlives the turn that set it; reset it so a session
    # that was cancelled earlier can be used for a new turn.
    curr_session['cancel_event'].clear()

    # Load the message history from the cache.
    # NOTE(review): a new RedisClient (and its connection pool) is built on
    # every request - consider sharing a single instance.
    redis_client = RedisClient()
    key_session_messages = f"key_session_messages_{session_id}"
    cached = await redis_client.get(key_session_messages)
    messages = None
    if cached:
        try:
            messages = json.loads(cached)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Unreadable cache entry: fall back to the in-memory history.
            messages = curr_session.get('messages')
    if not isinstance(messages, list):
        # Cache miss or unusable payload: start from an empty history.
        messages = []

    messages.append({"role": "user", "content": chat.prompt})
    messages = messages[-HISTORY_MSG_LIMIT:]
    curr_session['messages'] = messages

    # LLM client and call arguments.
    llm_client = LLMClient(base_url=BASE_URL, api_key=API_KEY)
    chat_args = {
        "model": chat.model,
        "messages": messages,
    }

    async def generate():
        cancel_event = curr_session['cancel_event']
        try:
            full_content = ''

            # Mark the session as running; the returned payload is not sent.
            _gen_session_response(session_id, curr_session, SessionStatus.RUNNING)

            # The model is called asynchronously so the cancel endpoint can
            # be served while this stream is in progress.
            async for resp in await llm_client.chat_stream(**chat_args):
                # Stop consuming the model stream once the session is cancelled.
                if cancel_event.is_set():
                    break

                if resp.status_code == HTTPStatus.OK and resp.output.choices:
                    curr_content = resp.output.choices[0].message.content
                    full_content += curr_content
                    session_msg = _gen_session_response(
                        session_id, curr_session, SessionStatus.RUNNING, curr_content)
                    yield sse_event(session_msg)

            if cancel_event.is_set():
                # Checked after the loop as well, so a cancel that arrives
                # after the last chunk still produces a terminal event.
                session_msg = _gen_session_response(session_id, curr_session, SessionStatus.CANCELLED)
                yield sse_event(session_msg)
            else:
                # Not cancelled: persist the assistant reply and finish normally.
                curr_session['messages'].append({"role": "assistant", "content": full_content})
                await redis_client.set(
                    key_session_messages,
                    json.dumps(curr_session['messages'], ensure_ascii=False),
                    int(SessionCacheTime.MSG_CACHE_TIME))

                session_msg = _gen_session_response(session_id, curr_session, SessionStatus.COMPLETED)
                yield sse_event(session_msg)
        except asyncio.CancelledError:
            # Task cancellation must propagate: record the status and
            # re-raise instead of yielding from a cancelled generator.
            _gen_session_response(session_id, curr_session, SessionStatus.CANCELLED)
            raise
        except Exception:
            logger.error(traceback.format_exc())
            session_msg = _gen_session_response(session_id, curr_session, SessionStatus.ERROR)
            yield sse_event(session_msg)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )


@app.post("/api/v1/chat/cancel", response_model=ChatCancelResponse)
async def cancel_chat(chat: ChatCancelRequest):
    """Cancel a running chat session.

    Responds with 404 when the session id is unknown and with 400 when the
    session is not in the running status; otherwise asks the session manager
    to cancel it and reports the cancelled status.
    """
    session_id = chat.session_id
    target = session_manager.sessions.get(session_id)

    # Unknown session.
    if not target:
        not_found = f"会话 {session_id} 不存在"
        logger.info(not_found)
        raise HTTPException(status_code=404, detail=not_found)

    # Only a running session can be cancelled.
    if target['status'] != SessionStatus.RUNNING:
        rejected = f"会话 {session_id} 当前状态是: {target['status']}, 不支持取消"
        logger.info(rejected)
        raise HTTPException(status_code=400, detail=rejected)

    logger.info(f"会话 {session_id} 正在取消")
    await session_manager.cancel_session(session_id)

    confirmation = f"会话 {session_id} 已经取消"
    return ChatCancelResponse(
        session_id=session_id,
        status=SessionStatus.CANCELLED,
        message=confirmation,
    )


@app.get("/api/v1/chat/status/{session_id}", response_model=ChatStatusResponse)
async def chat_status(session_id: str):
    """Report the current status of a session.

    Responds with 404 when the session id is unknown.
    """
    found = session_manager.sessions.get(session_id)
    if found:
        return ChatStatusResponse(session_id=session_id, status=found['status'])

    missing = f"会话 {session_id} 不存在"
    logger.info(missing)
    raise HTTPException(status_code=404, detail=missing)


if __name__ == "__main__":
    # Start the uvicorn server on all interfaces, port 8011, when this file
    # is executed as a script.
    server_options = {"host": "0.0.0.0", "port": 8011, "log_level": "info"}
    uvicorn.run(app, **server_options)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# session.py
import asyncio
import uuid
from typing import Dict
from datetime import datetime
import logging
from enum import Enum


# Configure logging at INFO level when this module is imported, and create
# the module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SessionStatus(str, Enum):
    """Status of a chat session; each member is also a plain string."""

    CREATED = "created"
    RUNNING = "running"
    ERROR = "error"
    CANCELLED = "cancelled"
    COMPLETED = "completed"


class SessionCacheTime(int, Enum):
    """Cache lifetimes; each member is also a plain integer."""

    # Seven days' worth of 3600-unit hours.
    MSG_CACHE_TIME = 3600 * 24 * 7


class SessionManager:
    """In-memory registry of chat sessions keyed by session id.

    Each session is a dict with the keys ``status``, ``start_time``,
    ``end_time``, ``messages`` and ``cancel_event`` (an ``asyncio.Event``
    that is set when the session is cancelled).
    """

    # Statuses that mark a session as finished and stamp its end_time.
    _TERMINAL_STATUSES = (
        SessionStatus.COMPLETED,
        SessionStatus.CANCELLED,
        SessionStatus.ERROR,
    )

    def __init__(self):
        self.sessions: Dict[str, dict] = {}

    def create_session(self, session_id: str | None = None) -> str:
        """Register a new session and return its id.

        Args:
            session_id: Id to register. When missing or empty, a random
                UUID4 string is generated so that an empty string never
                becomes a session key. An existing entry with the same id
                is replaced.

        Returns:
            The id under which the session was stored.
        """
        if not session_id:
            session_id = str(uuid.uuid4())

        # Store the initial session record.
        self.sessions[session_id] = {
            "status": SessionStatus.CREATED,
            "start_time": datetime.now().isoformat(),
            "end_time": None,
            "messages": [],
            "cancel_event": asyncio.Event(),
        }

        logger.info(f"会话 {session_id} 创建成功")
        return session_id

    def get_session(self, session_id: str) -> dict | None:
        """Return the session dict for *session_id*, or ``None`` if unknown."""
        return self.sessions.get(session_id)

    def update_session_status(self, session_id: str, status: SessionStatus):
        """Set the status of a known session; unknown ids are ignored.

        A terminal status (completed, cancelled, error) also records the
        current time as ``end_time``.
        """
        session = self.sessions.get(session_id)
        if session is None:
            return

        session["status"] = status
        if status in self._TERMINAL_STATUSES:
            session["end_time"] = datetime.now().isoformat()
        logger.info(f"会话 {session_id} 状态更新为: {status}")

    async def cancel_session(self, session_id: str) -> bool:
        """Cancel a session.

        Sets the session's cancel event first, then marks the session as
        cancelled through ``update_session_status``.

        Returns:
            ``True`` when the session exists, ``False`` otherwise.
        """
        session = self.get_session(session_id)
        if not session:
            return False

        session["cancel_event"].set()
        self.update_session_status(session_id, SessionStatus.CANCELLED)
        return True
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# llm.py
import logging
import traceback

from dashscope.aigc.generation import AioGeneration


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class LLMClient:
    """Async client that requests streamed chat output via ``AioGeneration``."""

    def __init__(self, base_url, api_key):
        """Store the endpoint and credential.

        Args:
            base_url: Service base URL.
            api_key: API key passed to every generation call.
        """
        # NOTE(review): base_url is stored but never passed to the call in
        # chat_stream - confirm whether it should be applied.
        self.base_url: str = base_url
        self.api_key: str = api_key

    async def chat_stream(self, model, messages):
        """Start a streaming, incremental chat generation.

        Args:
            model: Model name.
            messages: Chat history forwarded as ``messages``.

        Returns:
            Whatever ``AioGeneration.call`` returns for ``stream=True``; the
            caller iterates it with ``async for``.

        Raises:
            Exception: Any error raised by the call is logged and re-raised,
                so the caller sees the real failure instead of ``None``.
        """
        try:
            # Awaited call, so the endpoint is not blocked while waiting.
            return await AioGeneration.call(
                api_key=self.api_key,
                model=model,
                stream=True,
                messages=messages,
                result_format="message",
                incremental_output=True,
            )
        except Exception:
            logger.error(traceback.format_exc())
            raise
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# memory.py
import redis.asyncio as redis


class RedisClient:
    """Small async wrapper exposing ``set`` and ``get`` on a Redis client."""

    def __init__(self, host='localhost', port=6380, db=0):
        """Build a connection pool (at most 100 connections) and a client on it."""
        pool_settings = dict(host=host, port=port, db=db, max_connections=100)
        self.pool = redis.ConnectionPool(**pool_settings)
        self.client = redis.Redis(connection_pool=self.pool)

    async def set(self, key, val, expire=24 * 3600):
        """Store *val* under *key*, passing *expire* as the ``ex`` argument."""
        await self.client.set(key, val, ex=expire)

    async def get(self, key):
        """Return what the underlying client returns for *key*."""
        stored = await self.client.get(key)
        return stored
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# model.py
from typing import Optional

from pydantic import BaseModel, Field

from session import SessionStatus


class ChatCompletionRequest(BaseModel):
    # Request body for a chat completion: optional session id, required
    # prompt, model name (default "qwen-max") and a stream flag (default False).
    session_id: Optional[str] = Field(
        default=None,
        description="会话ID",
    )
    prompt: str = Field(
        ...,
        description="用户消息",
    )
    model: str = Field(
        default="qwen-max",
        description="模型名称",
    )
    stream: bool = Field(
        default=False,
        description="是否流式返回",
    )


class ChatCompletionResponse(BaseModel):
    # Completion response schema; all four fields are required.
    session_id: str  # session the message belongs to
    message: str     # message text
    model: str       # model name
    completed: bool  # completion flag


class ChatCancelRequest(BaseModel):
    # Request body for cancelling a session; the session id is required.
    session_id: str = Field(
        ...,
        description="会话ID",
    )


class ChatCancelResponse(BaseModel):
    """Cancel response model."""

    session_id: str        # session the response refers to
    status: SessionStatus  # status reported for that session
    message: str           # human-readable result text


class ChatStatusResponse(BaseModel):
    # Status response schema: the session id and its current status.
    session_id: str        # session the response refers to
    status: SessionStatus  # status reported for that session