package com.learn.controller;

import com.learn.dto.ApiResponse;
import com.learn.dto.ChatRequest;
import com.learn.dto.RAGQueryRequest;
import com.learn.service.AgentService;
import com.learn.service.ChatService;
import com.learn.service.PromptTemplateService;
import com.learn.service.RAGService;
import jakarta.validation.Valid;
import org.springframework.ai.document.Document;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;

/**
 * REST entry points for chat, agent execution, RAG retrieval and prompt-template management.
 *
 * <p>The controller only adapts HTTP payloads; all work is delegated to the injected services.
 *
 * <p>NOTE(review): the generic type arguments in this class were reconstructed from a copy of the
 * file in which they had been stripped. Where a type depends on a service signature that is not
 * visible here, the declaration carries its own NOTE(review) — confirm those against the services.
 */
@RestController
public class ChatController {

    /** Conversation ID used when the client does not supply one. */
    private static final String DEFAULT_CONVERSATION_ID = "default";

    private final ChatService chatService;
    private final AgentService agentService;
    private final RAGService ragService;
    private final PromptTemplateService promptService;

    /**
     * Creates the controller with its collaborating services.
     *
     * @param chatService   plain and streaming chat, plus conversation history
     * @param agentService  agent execution
     * @param ragService    knowledge-base search and document loading
     * @param promptService prompt-template lookup and rendering
     */
    public ChatController(ChatService chatService,
                          AgentService agentService,
                          RAGService ragService,
                          PromptTemplateService promptService) {
        this.chatService = chatService;
        this.agentService = agentService;
        this.ragService = ragService;
        this.promptService = promptService;
    }

    // ==================== Week 9 reuse: basic chat ====================

    /**
     * Synchronous chat.
     *
     * <p>POST /api/chat
     *
     * @param request validated chat request; a missing conversation ID falls back to
     *                {@code "default"}
     * @return the reply together with the conversation ID that was actually used
     */
    @PostMapping("/api/chat")
    public ApiResponse<Map<String, String>> chat(@Valid @RequestBody ChatRequest request) {
        String convId = resolveConversationId(request.getConversationId());
        String reply = chatService.chat(request.getMessage(), convId);
        return ApiResponse.success(Map.of("reply", reply, "conversationId", convId));
    }

    /**
     * Streaming chat over Server-Sent Events.
     *
     * <p>POST /api/chat/stream
     *
     * <p>Each token is emitted as one SSE data frame; a final {@code done} event carrying
     * {@code [DONE]} marks the end of the stream.
     *
     * @param request validated chat request; a missing conversation ID falls back to
     *                {@code "default"}, matching {@code /api/chat} and the history endpoints
     * @return the SSE stream
     */
    @PostMapping(value = "/api/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> chatStream(@Valid @RequestBody ChatRequest request) {
        // Previously a null ID was forwarded as-is, so a streamed conversation was not
        // guaranteed to share the "default" bucket that the sibling chat endpoints use.
        String convId = resolveConversationId(request.getConversationId());
        return toSse(chatService.chatStream(request.getMessage(), convId));
    }

    /**
     * Returns the stored history of a conversation.
     *
     * <p>GET /api/chat/history
     *
     * @param conversationId conversation to read; defaults to {@code "default"}
     * @return whatever {@code ChatService.getHistory} yields, wrapped in a success response
     */
    @GetMapping("/api/chat/history")
    public ApiResponse<?> getHistory(
            @RequestParam(defaultValue = DEFAULT_CONVERSATION_ID) String conversationId) {
        return ApiResponse.success(chatService.getHistory(conversationId));
    }

    /**
     * Clears the stored history of a conversation.
     *
     * <p>DELETE /api/chat/history
     *
     * @param conversationId conversation to clear; defaults to {@code "default"}
     * @return a confirmation message naming the cleared conversation
     */
    @DeleteMapping("/api/chat/history")
    public ApiResponse<String> clearHistory(
            @RequestParam(defaultValue = DEFAULT_CONVERSATION_ID) String conversationId) {
        chatService.clearHistory(conversationId);
        return ApiResponse.success("已清除会话: " + conversationId);
    }

    // ==================== Day 3-4: agent tasks (function calling) ====================

    /**
     * Synchronous agent execution; per the original notes the model picks and calls tools itself.
     *
     * <p>POST /api/agent
     * <br>Body: {@code {"message": "What is the weather in Beijing?", "conversationId": "optional"}}
     *
     * @param request validated chat request; the conversation ID is forwarded unchanged
     *                (it may be {@code null})
     * @return the result produced by {@code AgentService.execute}
     */
    @PostMapping("/api/agent")
    // NOTE(review): Map<String, Object> is a reconstruction; confirm against AgentService.execute.
    public ApiResponse<Map<String, Object>> agent(@Valid @RequestBody ChatRequest request) {
        var result = agentService.execute(request.getMessage(), request.getConversationId());
        return ApiResponse.success(result);
    }

    /**
     * Streaming agent execution over Server-Sent Events.
     *
     * <p>POST /api/agent/stream
     *
     * @param request validated chat request; the conversation ID is forwarded unchanged
     * @return token frames followed by a final {@code done} event carrying {@code [DONE]}
     */
    @PostMapping(value = "/api/agent/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> agentStream(@Valid @RequestBody ChatRequest request) {
        return toSse(agentService.executeStream(request.getMessage(), request.getConversationId()));
    }

    // ==================== Day 6: retrieval-augmented generation ====================

    /**
     * Retrieval only (no generation), useful for inspecting retrieval quality.
     *
     * <p>POST /api/rag/search
     * <br>Body: {@code {"question": "What is Spring AI?", "topK": 3}}
     *
     * @param request validated query carrying the question and the number of hits wanted
     * @return the hits returned by {@code RAGService.search}
     */
    @PostMapping("/api/rag/search")
    // NOTE(review): element type is a reconstruction; confirm against RAGService.search.
    public ApiResponse<List<Map<String, Object>>> ragSearch(
            @Valid @RequestBody RAGQueryRequest request) {
        return ApiResponse.success(ragService.search(request.getQuestion(), request.getTopK()));
    }

    /**
     * Adds documents to the knowledge base.
     *
     * <p>POST /api/rag/load
     * <br>Body: {@code [{"text": "document content", "metadata": {"source": "spring-doc"}}]}
     *
     * @param documents one JSON object per document; {@code text} is required and must be a
     *                  string, {@code metadata} is optional and must be a JSON object when present
     * @return the number of documents reported as added by {@code RAGService.addDocuments}
     * @throws IllegalArgumentException if an entry is {@code null}, has no string {@code text},
     *                                  or has a {@code metadata} value that is not a JSON object
     */
    @PostMapping("/api/rag/load")
    public ApiResponse<Map<String, Integer>> ragLoad(
            @RequestBody List<Map<String, Object>> documents) {
        List<Document> docs = IntStream.range(0, documents.size())
                .mapToObj(i -> toDocument(documents.get(i), i))
                .toList();
        int count = ragService.addDocuments(docs);
        return ApiResponse.created(Map.of("added", count));
    }

    // ==================== Day 7: prompt-template management ====================

    /**
     * Lists all templates (names and previews, per the original notes).
     *
     * <p>GET /api/prompts
     *
     * @return the listing returned by {@code PromptTemplateService.list}
     */
    @GetMapping("/api/prompts")
    // NOTE(review): element type is a reconstruction; confirm against PromptTemplateService.list.
    public ApiResponse<List<Map<String, String>>> listPrompts() {
        return ApiResponse.success(promptService.list());
    }

    /**
     * Returns the raw content of one template.
     *
     * <p>GET /api/prompts/{name}
     *
     * @param name template name taken from the path
     * @return the template name and its content
     */
    @GetMapping("/api/prompts/{name}")
    public ApiResponse<Map<String, String>> getPrompt(@PathVariable String name) {
        String content = promptService.getTemplate(name);
        return ApiResponse.success(Map.of("name", name, "content", content));
    }

    /**
     * Chats using a rendered prompt template.
     *
     * <p>POST /api/chat/with-template
     * <br>Body: {@code {"message": "user message", "templateName": "java-expert",
     * "params": {"key": "value"}, "conversationId": "xxx"}}
     *
     * @param body raw JSON body; {@code message} and {@code templateName} are required strings,
     *             {@code params} is an optional JSON object (empty when absent or null) and
     *             {@code conversationId} is an optional string ({@code "default"} when absent
     *             or null)
     * @return the reply together with the conversation ID that was used
     * @throws IllegalArgumentException if a required field is missing or any field has the
     *                                  wrong JSON type
     */
    @PostMapping("/api/chat/with-template")
    public ApiResponse<Map<String, String>> chatWithTemplate(@RequestBody Map<String, Object> body) {
        String message = requireString(body, "message");
        String templateName = requireString(body, "templateName");
        // Map.getOrDefault would hand back null for an explicit JSON null, so absent and null
        // are both resolved here instead.
        String conversationId = optionalString(body, "conversationId", DEFAULT_CONVERSATION_ID);
        // NOTE(review): Map<String, Object> is a reconstruction; confirm against
        // PromptTemplateService.apply.
        Map<String, Object> params = asObjectMap(body.get("params"), "params");

        String templateContent = promptService.apply(templateName, params);
        String reply = chatService.chatWithTemplate(templateContent, message, conversationId);
        return ApiResponse.success(Map.of("reply", reply, "conversationId", conversationId));
    }

    // ==================== private helpers ====================

    /** Returns {@code conversationId}, or {@code "default"} when it is {@code null}. */
    private static String resolveConversationId(String conversationId) {
        return conversationId != null ? conversationId : DEFAULT_CONVERSATION_ID;
    }

    /**
     * Wraps each token in an SSE data frame and appends a closing {@code done} event.
     *
     * <p>NOTE(review): assumes both streaming services emit {@code Flux<String>} — confirm.
     */
    private static Flux<ServerSentEvent<String>> toSse(Flux<String> tokens) {
        return tokens
                .map(token -> ServerSentEvent.<String>builder().data(token).build())
                .concatWith(Mono.just(ServerSentEvent.<String>builder()
                        .event("done")
                        .data("[DONE]")
                        .build()));
    }

    /** Converts one entry of the {@code /api/rag/load} payload into a {@link Document}. */
    private static Document toDocument(Map<String, Object> entry, int index) {
        String where = "documents[" + index + "]";
        if (entry == null) {
            throw new IllegalArgumentException(where + " must be a JSON object");
        }
        if (!(entry.get("text") instanceof String text)) {
            throw new IllegalArgumentException(where + ".text is required and must be a string");
        }
        Map<String, Object> metadata = asObjectMap(entry.get("metadata"), where + ".metadata");
        // Copy so the Document always receives its own mutable map, as the original code did.
        return new Document(text, new HashMap<>(metadata));
    }

    /** Reads a required string field, rejecting missing, null and non-string values. */
    private static String requireString(Map<String, Object> body, String field) {
        if (body.get(field) instanceof String value) {
            return value;
        }
        throw new IllegalArgumentException("Field '" + field + "' is required and must be a string");
    }

    /** Reads an optional string field; absent or null yields {@code fallback}. */
    private static String optionalString(Map<String, Object> body, String field, String fallback) {
        Object value = body.get(field);
        if (value == null) {
            return fallback;
        }
        if (value instanceof String text) {
            return text;
        }
        throw new IllegalArgumentException("Field '" + field + "' must be a string");
    }

    /**
     * Interprets an optional JSON-object value: {@code null} yields an empty immutable map, a map
     * is returned as-is, anything else is rejected.
     */
    @SuppressWarnings("unchecked") // presumably keys are Strings because the map comes from a JSON body
    private static Map<String, Object> asObjectMap(Object value, String field) {
        if (value == null) {
            return Map.of();
        }
        if (value instanceof Map<?, ?> map) {
            return (Map<String, Object>) map;
        }
        throw new IllegalArgumentException("Field '" + field + "' must be a JSON object");
    }
}