Files
gc-plan/week10/src/main/java/com/learn/controller/ChatController.java
2026-04-30 16:08:39 +08:00

167 lines
6.7 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package com.learn.controller;
import com.learn.dto.ApiResponse;
import com.learn.dto.ChatRequest;
import com.learn.dto.RAGQueryRequest;
import com.learn.service.AgentService;
import com.learn.service.ChatService;
import com.learn.service.PromptTemplateService;
import com.learn.service.RAGService;
import jakarta.validation.Valid;
import org.springframework.ai.document.Document;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * REST entry points for plain chat, tool-calling agent execution, RAG retrieval/ingestion,
 * and prompt-template management.
 *
 * <p>Streaming endpoints emit one server-sent event per token produced by the service,
 * followed by a terminal event named {@code done} whose data is {@code [DONE]}.
 */
@RestController
public class ChatController {

    /**
     * Conversation id used when a client omits one or sends a blank one. Kept in sync with the
     * {@code defaultValue} of the history endpoints so every endpoint agrees on the same key.
     */
    private static final String DEFAULT_CONVERSATION_ID = "default";

    private final ChatService chatService;
    private final AgentService agentService;
    private final RAGService ragService;
    private final PromptTemplateService promptService;

    public ChatController(ChatService chatService, AgentService agentService,
            RAGService ragService, PromptTemplateService promptService) {
        this.chatService = chatService;
        this.agentService = agentService;
        this.ragService = ragService;
        this.promptService = promptService;
    }

    // ==================== Week 9 reuse: basic chat ====================

    /**
     * Runs one blocking chat turn.
     *
     * <p>POST /api/chat
     *
     * @param request validated request carrying the user message and an optional conversation id
     * @return the model reply plus the conversation id that was actually used
     */
    @PostMapping("/api/chat")
    public ApiResponse<Map<String, String>> chat(@Valid @RequestBody ChatRequest request) {
        String convId = resolveConversationId(request.getConversationId());
        String reply = chatService.chat(request.getMessage(), convId);
        return ApiResponse.success(Map.of("reply", reply, "conversationId", convId));
    }

    /**
     * Streams one chat turn as server-sent events.
     *
     * <p>POST /api/chat/stream
     *
     * <p>The conversation id is defaulted exactly like {@link #chat}, so streamed turns land in
     * the same history that {@code GET /api/chat/history} reads.
     */
    @PostMapping(value = "/api/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> chatStream(@Valid @RequestBody ChatRequest request) {
        String convId = resolveConversationId(request.getConversationId());
        return toSse(chatService.chatStream(request.getMessage(), convId));
    }

    /**
     * Returns the stored history of a conversation.
     *
     * <p>GET /api/chat/history?conversationId=...
     */
    @GetMapping("/api/chat/history")
    public ApiResponse<?> getHistory(
            @RequestParam(defaultValue = DEFAULT_CONVERSATION_ID) String conversationId) {
        return ApiResponse.success(chatService.getHistory(conversationId));
    }

    /**
     * Clears the stored history of a conversation.
     *
     * <p>DELETE /api/chat/history?conversationId=...
     */
    @DeleteMapping("/api/chat/history")
    public ApiResponse<?> clearHistory(
            @RequestParam(defaultValue = DEFAULT_CONVERSATION_ID) String conversationId) {
        chatService.clearHistory(conversationId);
        return ApiResponse.success("已清除会话: " + conversationId);
    }

    // ==================== Day 3-4: Agent tasks (Function Calling) ====================

    /**
     * Synchronous agent execution, delegated to {@code AgentService.execute}. Per the original
     * notes, the model selects and calls tools itself.
     *
     * <p>POST /api/agent
     * <br>Body: {"message": "What's the weather in Beijing?", "conversationId": "optional conversation id"}
     *
     * <p>NOTE(review): the conversation id is forwarded as-is and may be null; unlike the chat
     * endpoints it is not defaulted here. Confirm AgentService handles null.
     */
    @PostMapping("/api/agent")
    public ApiResponse<Map<String, Object>> agent(@Valid @RequestBody ChatRequest request) {
        var result = agentService.execute(request.getMessage(), request.getConversationId());
        return ApiResponse.success(result);
    }

    /**
     * Streaming agent execution as server-sent events.
     *
     * <p>POST /api/agent/stream
     *
     * <p>NOTE(review): like {@link #agent}, the conversation id is forwarded as-is (may be null).
     */
    @PostMapping(value = "/api/agent/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> agentStream(@Valid @RequestBody ChatRequest request) {
        return toSse(agentService.executeStream(request.getMessage(), request.getConversationId()));
    }

    // ==================== Day 6: RAG (retrieval-augmented generation) ====================

    /**
     * Retrieval only, with no generation. Useful for inspecting what the retriever returns.
     *
     * <p>POST /api/rag/search
     * <br>Body: {"question": "What is Spring AI", "topK": 3}
     */
    @PostMapping("/api/rag/search")
    public ApiResponse<List<Map<String, Object>>> ragSearch(@Valid @RequestBody RAGQueryRequest request) {
        return ApiResponse.success(ragService.search(request.getQuestion(), request.getTopK()));
    }

    /**
     * Adds documents to the knowledge base.
     *
     * <p>POST /api/rag/load
     * <br>Body: [{"text": "document content", "metadata": {"source": "spring-doc"}}]
     *
     * <p>A missing or {@code null} {@code metadata} entry is treated as empty metadata.
     *
     * @throws IllegalArgumentException if a {@code metadata} entry is present but not a JSON object
     */
    @PostMapping("/api/rag/load")
    public ApiResponse<Map<String, Object>> ragLoad(@RequestBody List<Map<String, Object>> documents) {
        List<Document> docs = documents.stream()
                .map(m -> new Document((String) m.get("text"), toMetadata(m.get("metadata"))))
                .toList();
        int count = ragService.addDocuments(docs);
        return ApiResponse.created(Map.of("added", count));
    }

    // ==================== Day 7: Prompt template management ====================

    /**
     * Lists all template names with a preview of each.
     *
     * <p>GET /api/prompts
     */
    @GetMapping("/api/prompts")
    public ApiResponse<List<Map<String, String>>> listPrompts() {
        return ApiResponse.success(promptService.list());
    }

    /**
     * Returns the raw content of the named template.
     *
     * <p>GET /api/prompts/{name}
     *
     * <p>NOTE(review): {@code Map.of} rejects null, so if {@code getTemplate} returns null for an
     * unknown name this fails with a NullPointerException. Confirm the service throws a
     * meaningful "not found" error instead.
     */
    @GetMapping("/api/prompts/{name}")
    public ApiResponse<Map<String, String>> getPrompt(@PathVariable String name) {
        String content = promptService.getTemplate(name);
        return ApiResponse.success(Map.of("name", name, "content", content));
    }

    /**
     * Runs a chat turn using a prompt template.
     *
     * <p>POST /api/chat/with-template
     * <br>Body: {"message": "user message", "templateName": "java-expert", "params": {"key": "value"}, "conversationId": "xxx"}
     *
     * <p>A missing, {@code null}, or blank {@code conversationId} falls back to the default
     * conversation. A missing or {@code null} {@code params} is treated as empty. Non-string
     * param values (e.g. JSON numbers) are converted with {@link String#valueOf(Object)}.
     *
     * @throws IllegalArgumentException if {@code params} is present but not a JSON object
     */
    @PostMapping("/api/chat/with-template")
    public ApiResponse<Map<String, String>> chatWithTemplate(@RequestBody Map<String, Object> body) {
        String message = (String) body.get("message");
        String templateName = (String) body.get("templateName");
        String conversationId = resolveConversationId((String) body.get("conversationId"));
        Map<String, String> params = toTemplateParams(body.get("params"));
        String templateContent = promptService.apply(templateName, params);
        String reply = chatService.chatWithTemplate(templateContent, message, conversationId);
        return ApiResponse.success(Map.of("reply", reply, "conversationId", conversationId));
    }

    // ==================== Helpers ====================

    /** Maps null or blank ids to {@link #DEFAULT_CONVERSATION_ID}; other ids pass through. */
    private static String resolveConversationId(String conversationId) {
        return conversationId == null || conversationId.isBlank()
                ? DEFAULT_CONVERSATION_ID
                : conversationId;
    }

    /** Wraps each token in an SSE data event and appends the terminal {@code done} event. */
    private static Flux<ServerSentEvent<String>> toSse(Flux<String> tokens) {
        return tokens
                .map(token -> ServerSentEvent.<String>builder().data(token).build())
                .concatWith(Mono.just(ServerSentEvent.<String>builder()
                        .event("done").data("[DONE]").build()));
    }

    /** Copies a raw JSON metadata object into a fresh mutable map; null yields an empty map. */
    private static Map<String, Object> toMetadata(Object raw) {
        Map<String, Object> metadata = new HashMap<>();
        if (raw instanceof Map<?, ?> map) {
            map.forEach((key, value) -> metadata.put(String.valueOf(key), value));
        } else if (raw != null) {
            throw new IllegalArgumentException("metadata must be a JSON object");
        }
        return metadata;
    }

    /** Converts a raw JSON params object into string template parameters; null yields empty. */
    private static Map<String, String> toTemplateParams(Object raw) {
        if (raw == null) {
            return Map.of();
        }
        if (!(raw instanceof Map<?, ?> map)) {
            throw new IllegalArgumentException("params must be a JSON object");
        }
        Map<String, String> params = new HashMap<>();
        // Null values are kept as null rather than becoming the string "null".
        map.forEach((key, value) ->
                params.put(String.valueOf(key), value == null ? null : String.valueOf(value)));
        return params;
    }
}