tmp
This commit is contained in:
11
week9/src/main/java/com/learn/Week9Application.java
Normal file
11
week9/src/main/java/com/learn/Week9Application.java
Normal file
@@ -0,0 +1,11 @@
|
||||
package com.learn;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
|
||||
@SpringBootApplication
|
||||
public class Week9Application {
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(Week9Application.class, args);
|
||||
}
|
||||
}
|
||||
80
week9/src/main/java/com/learn/config/AIConfig.java
Normal file
80
week9/src/main/java/com/learn/config/AIConfig.java
Normal file
@@ -0,0 +1,80 @@
|
||||
package com.learn.config;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
/**
|
||||
* Spring AI 配置类(手动装配,便于理解各组件的关系)
|
||||
*
|
||||
* 架构:
|
||||
* ChatModel (OpenAiChatModel) ← 底层 HTTP 通信,由 application.yml 自动配置
|
||||
* ↓ 包装
|
||||
* ChatClient ← 高级 Fluent API,我们手动创建
|
||||
* ↓ 拦截
|
||||
* MessageChatMemoryAdvisor ← 自动注入对话历史(多轮对话)
|
||||
* ↓ 写入 / 读取
|
||||
* ChatMemory (MessageWindowChatMemory) ← 滑动窗口记忆(默认保留最近 20 条)
|
||||
* ↓ 存储
|
||||
* ChatMemoryRepository (InMemoryChatMemoryRepository) ← ConcurrentHashMap 存储
|
||||
*
|
||||
* Spring AI 1.0.6 的架构变化:
|
||||
* - 旧版 InMemoryChatMemory 被拆分为两个角色:
|
||||
* ChatMemory(记忆策略)+ ChatMemoryRepository(存储实现)
|
||||
* - MessageChatMemoryAdvisor 构造函数私有化,必须用 Builder 创建
|
||||
* - 会话 ID 常量从 AbstractChatMemoryAdvisor 迁移到 ChatMemory 接口
|
||||
*/
|
||||
@Configuration
|
||||
public class AIConfig {
|
||||
|
||||
/**
|
||||
* 对话记忆存储(内存实现,重启丢失)。
|
||||
*
|
||||
* InMemoryChatMemoryRepository = ConcurrentHashMap 包装
|
||||
* MessageWindowChatMemory = 滑动窗口策略(只保留最近 N 条消息)
|
||||
*
|
||||
* 生产环境可替换为 JdbcChatMemoryRepository 实现持久化。
|
||||
*/
|
||||
@Bean
|
||||
public ChatMemory chatMemory() {
|
||||
var repository = new InMemoryChatMemoryRepository();
|
||||
return MessageWindowChatMemory.builder()
|
||||
.chatMemoryRepository(repository)
|
||||
.maxMessages(20) // 每个会话最多保留 20 条历史消息
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* ChatClient —— Spring AI 的推荐入口
|
||||
*
|
||||
* 构建步骤:
|
||||
* 1. ChatClient.builder(chatModel) —— 绑定底层的 ChatModel 实现
|
||||
* 2. .defaultSystem(...) —— 设置系统级角色 Prompt
|
||||
* 3. .defaultAdvisors(...) —— 注册 Advisor 拦截器链(如记忆管理)
|
||||
* 4. .build() —— 创建不可变实例
|
||||
*
|
||||
* 关键点说明:
|
||||
* - ChatClient 是线程安全的,整个应用共用一个单例
|
||||
* - .defaultSystem() 对所有通过此 ChatClient 的请求生效
|
||||
* - MessageChatMemoryAdvisor 会在每次请求前自动注入历史消息
|
||||
* - 单次请求可通过 .system() 覆盖默认 System Prompt
|
||||
*/
|
||||
@Bean
|
||||
public ChatClient chatClient(ChatModel chatModel, ChatMemory chatMemory) {
|
||||
return ChatClient.builder(chatModel)
|
||||
.defaultSystem("""
|
||||
你是一个友好的 AI 助手,名字叫"小智"。
|
||||
请用中文回答用户的问题。
|
||||
如果你不知道答案,诚实地告诉用户,不要编造信息。
|
||||
""")
|
||||
.defaultAdvisors(
|
||||
MessageChatMemoryAdvisor.builder(chatMemory).build()
|
||||
)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
93
week9/src/main/java/com/learn/controller/ChatController.java
Normal file
93
week9/src/main/java/com/learn/controller/ChatController.java
Normal file
@@ -0,0 +1,93 @@
|
||||
package com.learn.controller;
|
||||
|
||||
import com.learn.dto.ApiResponse;
|
||||
import com.learn.dto.ChatRequest;
|
||||
import com.learn.service.ChatService;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.codec.ServerSentEvent;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat")
|
||||
public class ChatController {
|
||||
|
||||
private final ChatService chatService;
|
||||
|
||||
public ChatController(ChatService chatService) {
|
||||
this.chatService = chatService;
|
||||
}
|
||||
|
||||
// ==================== Day 5: 同步对话 ====================
|
||||
|
||||
/**
|
||||
* 同步聊天 —— 发送完整消息,等待完整回复。
|
||||
*
|
||||
* 请求: POST /api/chat
|
||||
* Body: {"message": "你好", "conversationId": "会话ID(可选)"}
|
||||
* 响应: {"code": 200, "data": {"reply": "你好!有什么可以帮你的?", "conversationId": "xxx"}}
|
||||
*/
|
||||
@PostMapping
|
||||
public ApiResponse<Map<String, String>> chat(@Valid @RequestBody ChatRequest request) {
|
||||
String conversationId = request.getConversationId() != null
|
||||
? request.getConversationId() : "default";
|
||||
String reply = chatService.chat(request.getMessage(), conversationId);
|
||||
return ApiResponse.success(Map.of(
|
||||
"reply", reply,
|
||||
"conversationId", conversationId
|
||||
));
|
||||
}
|
||||
|
||||
// ==================== Day 6: 流式对话 (SSE) ====================
|
||||
|
||||
/**
|
||||
* 流式聊天 —— 逐 token 返回,打字机效果。
|
||||
*
|
||||
* 请求: POST /api/chat/stream
|
||||
* Body: 同 /api/chat
|
||||
* 响应: text/event-stream
|
||||
* data: 你
|
||||
* data: 好
|
||||
* data: !
|
||||
* event: done
|
||||
* data: [DONE]
|
||||
*/
|
||||
@PostMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<ServerSentEvent<String>> chatStream(@Valid @RequestBody ChatRequest request) {
|
||||
return chatService.chatStream(request.getMessage(), request.getConversationId())
|
||||
.map(token -> ServerSentEvent.<String>builder()
|
||||
.data(token)
|
||||
.build())
|
||||
.concatWith(Mono.just(ServerSentEvent.<String>builder()
|
||||
.event("done")
|
||||
.data("[DONE]")
|
||||
.build()));
|
||||
}
|
||||
|
||||
// ==================== Day 7: 历史管理 ====================
|
||||
|
||||
/**
|
||||
* 查询某个会话的历史消息。
|
||||
*
|
||||
* GET /api/chat/history?conversationId=xxx
|
||||
*/
|
||||
@GetMapping("/history")
|
||||
public ApiResponse<?> getHistory(@RequestParam(defaultValue = "default") String conversationId) {
|
||||
return ApiResponse.success(chatService.getHistory(conversationId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除某个会话的历史。
|
||||
*
|
||||
* DELETE /api/chat/history?conversationId=xxx
|
||||
*/
|
||||
@DeleteMapping("/history")
|
||||
public ApiResponse<?> clearHistory(@RequestParam(defaultValue = "default") String conversationId) {
|
||||
chatService.clearHistory(conversationId);
|
||||
return ApiResponse.success("已清除会话: " + conversationId);
|
||||
}
|
||||
}
|
||||
61
week9/src/main/java/com/learn/dto/ApiResponse.java
Normal file
61
week9/src/main/java/com/learn/dto/ApiResponse.java
Normal file
@@ -0,0 +1,61 @@
|
||||
package com.learn.dto;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ApiResponse<T> {
|
||||
|
||||
private int code;
|
||||
private String message;
|
||||
private T data;
|
||||
private Map<String, Object> extra;
|
||||
|
||||
public ApiResponse() {}
|
||||
|
||||
public ApiResponse(int code, String message, T data) {
|
||||
this.code = code;
|
||||
this.message = message;
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
public static <T> ApiResponse<T> success(T data) {
|
||||
return new ApiResponse<>(200, "success", data);
|
||||
}
|
||||
|
||||
public static <T> ApiResponse<T> success(String message, T data) {
|
||||
return new ApiResponse<>(200, message, data);
|
||||
}
|
||||
|
||||
public static <T> ApiResponse<T> created(T data) {
|
||||
return new ApiResponse<>(201, "created", data);
|
||||
}
|
||||
|
||||
public static <T> ApiResponse<T> error(int code, String message) {
|
||||
return new ApiResponse<>(code, message, null);
|
||||
}
|
||||
|
||||
public static <T> ApiResponse<T> notFound(String message) {
|
||||
return new ApiResponse<>(404, message, null);
|
||||
}
|
||||
|
||||
public static <T> ApiResponse<T> badRequest(String message) {
|
||||
return new ApiResponse<>(400, message, null);
|
||||
}
|
||||
|
||||
public ApiResponse<T> put(String key, Object value) {
|
||||
if (this.extra == null) {
|
||||
this.extra = new HashMap<>();
|
||||
}
|
||||
this.extra.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public int getCode() { return code; }
|
||||
public void setCode(int code) { this.code = code; }
|
||||
public String getMessage() { return message; }
|
||||
public void setMessage(String message) { this.message = message; }
|
||||
public T getData() { return data; }
|
||||
public void setData(T data) { this.data = data; }
|
||||
public Map<String, Object> getExtra() { return extra; }
|
||||
public void setExtra(Map<String, Object> extra) { this.extra = extra; }
|
||||
}
|
||||
18
week9/src/main/java/com/learn/dto/ChatHistoryVo.java
Normal file
18
week9/src/main/java/com/learn/dto/ChatHistoryVo.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package com.learn.dto;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ChatHistoryVo {
|
||||
|
||||
private String conversationId;
|
||||
private int messageCount;
|
||||
private List<Map<String, Object>> messages;
|
||||
|
||||
public String getConversationId() { return conversationId; }
|
||||
public void setConversationId(String conversationId) { this.conversationId = conversationId; }
|
||||
public int getMessageCount() { return messageCount; }
|
||||
public void setMessageCount(int messageCount) { this.messageCount = messageCount; }
|
||||
public List<Map<String, Object>> getMessages() { return messages; }
|
||||
public void setMessages(List<Map<String, Object>> messages) { this.messages = messages; }
|
||||
}
|
||||
16
week9/src/main/java/com/learn/dto/ChatRequest.java
Normal file
16
week9/src/main/java/com/learn/dto/ChatRequest.java
Normal file
@@ -0,0 +1,16 @@
|
||||
package com.learn.dto;
|
||||
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
|
||||
public class ChatRequest {
|
||||
|
||||
@NotBlank(message = "消息不能为空")
|
||||
private String message;
|
||||
|
||||
private String conversationId;
|
||||
|
||||
public String getMessage() { return message; }
|
||||
public void setMessage(String message) { this.message = message; }
|
||||
public String getConversationId() { return conversationId; }
|
||||
public void setConversationId(String conversationId) { this.conversationId = conversationId; }
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.learn.exception;
|
||||
|
||||
import com.learn.dto.ApiResponse;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.MethodArgumentNotValidException;
|
||||
import org.springframework.web.bind.annotation.ExceptionHandler;
|
||||
import org.springframework.web.bind.annotation.RestControllerAdvice;
|
||||
|
||||
@RestControllerAdvice
|
||||
public class GlobalExceptionHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class);
|
||||
|
||||
@ExceptionHandler(MethodArgumentNotValidException.class)
|
||||
public ResponseEntity<ApiResponse<?>> handleValidation(MethodArgumentNotValidException ex) {
|
||||
var errors = new StringBuilder();
|
||||
for (var error : ex.getBindingResult().getFieldErrors()) {
|
||||
if (!errors.isEmpty()) errors.append("; ");
|
||||
errors.append(error.getField()).append(": ").append(error.getDefaultMessage());
|
||||
}
|
||||
return ResponseEntity.badRequest()
|
||||
.body(ApiResponse.badRequest("参数校验失败: " + errors));
|
||||
}
|
||||
|
||||
@ExceptionHandler(IllegalArgumentException.class)
|
||||
public ResponseEntity<ApiResponse<?>> handleIllegalArgument(IllegalArgumentException ex) {
|
||||
return ResponseEntity.badRequest()
|
||||
.body(ApiResponse.badRequest(ex.getMessage()));
|
||||
}
|
||||
|
||||
@ExceptionHandler(Exception.class)
|
||||
public ResponseEntity<ApiResponse<?>> handleAll(Exception ex) {
|
||||
log.error("Unexpected error", ex);
|
||||
return ResponseEntity.status(500)
|
||||
.body(ApiResponse.error(500, "服务器内部错误: " + ex.getMessage()));
|
||||
}
|
||||
}
|
||||
118
week9/src/main/java/com/learn/service/ChatService.java
Normal file
118
week9/src/main/java/com/learn/service/ChatService.java
Normal file
@@ -0,0 +1,118 @@
|
||||
package com.learn.service;
|
||||
|
||||
import com.learn.dto.ChatHistoryVo;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.stereotype.Service;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;
|
||||
|
||||
@Service
|
||||
public class ChatService {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ChatService.class);
|
||||
|
||||
private final ChatClient chatClient;
|
||||
private final ChatMemory chatMemory;
|
||||
|
||||
public ChatService(ChatClient chatClient, ChatMemory chatMemory) {
|
||||
this.chatClient = chatClient;
|
||||
this.chatMemory = chatMemory;
|
||||
}
|
||||
|
||||
// ==================== Day 5: 同步对话 ====================
|
||||
|
||||
/**
|
||||
* 一问一答,等待完整回复后返回。
|
||||
*/
|
||||
public String chat(String message) {
|
||||
return chatClient.prompt()
|
||||
.user(message)
|
||||
.call()
|
||||
.content();
|
||||
}
|
||||
|
||||
// ==================== Day 6: 流式对话 ====================
|
||||
|
||||
/**
|
||||
* 流式对话,逐 token 返回,实现"打字机效果"。
|
||||
*
|
||||
* @param message 用户消息
|
||||
* @param conversationId 会话 ID(默认 "default")
|
||||
* @return Flux<String> 逐 token 流
|
||||
*/
|
||||
public Flux<String> chatStream(String message, String conversationId) {
|
||||
String convId = defaultIfNull(conversationId);
|
||||
log.debug("Streaming chat for conversation: {}", convId);
|
||||
return chatClient.prompt()
|
||||
.user(message)
|
||||
.advisors(a -> a.param(CONVERSATION_ID, convId))
|
||||
.stream()
|
||||
.content();
|
||||
}
|
||||
|
||||
// ==================== Day 7: 多轮对话 ====================
|
||||
|
||||
/**
|
||||
* 带上下文的多轮对话(非流式版)。
|
||||
*/
|
||||
public String chat(String message, String conversationId) {
|
||||
String convId = defaultIfNull(conversationId);
|
||||
log.debug("Chat with memory, conversation: {}", convId);
|
||||
return chatClient.prompt()
|
||||
.user(message)
|
||||
.advisors(a -> a.param(CONVERSATION_ID, convId))
|
||||
.call()
|
||||
.content();
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询某个会话的历史消息。
|
||||
*/
|
||||
public ChatHistoryVo getHistory(String conversationId) {
|
||||
String convId = defaultIfNull(conversationId);
|
||||
var messages = chatMemory.get(convId);
|
||||
var list = new ArrayList<Map<String, Object>>();
|
||||
for (var msg : messages) {
|
||||
var item = new HashMap<String, Object>();
|
||||
item.put("role", msg.getMessageType().name().toLowerCase());
|
||||
item.put("content", msg.getText());
|
||||
list.add(item);
|
||||
}
|
||||
ChatHistoryVo vo = new ChatHistoryVo();
|
||||
vo.setConversationId(convId);
|
||||
vo.setMessageCount(list.size());
|
||||
vo.setMessages(list);
|
||||
return vo;
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除某个会话的所有历史。
|
||||
*/
|
||||
public void clearHistory(String conversationId) {
|
||||
String convId = defaultIfNull(conversationId);
|
||||
chatMemory.clear(convId);
|
||||
log.debug("Cleared history for conversation: {}", convId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有会话 ID 列表。
|
||||
*/
|
||||
public List<String> listConversations() {
|
||||
// InMemoryChatMemory 没有直接列出所有 key 的方法
|
||||
// 这里用一个简化方案:客户端自行管理会话 ID
|
||||
return List.of();
|
||||
}
|
||||
|
||||
private String defaultIfNull(String id) {
|
||||
return (id == null || id.isBlank()) ? "default" : id;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user