167 lines
6.7 KiB
Java
167 lines
6.7 KiB
Java
package com.learn.controller;
|
||
|
||
import com.learn.dto.ApiResponse;
|
||
import com.learn.dto.ChatRequest;
|
||
import com.learn.dto.RAGQueryRequest;
|
||
import com.learn.service.AgentService;
|
||
import com.learn.service.ChatService;
|
||
import com.learn.service.PromptTemplateService;
|
||
import com.learn.service.RAGService;
|
||
import jakarta.validation.Valid;
|
||
import org.springframework.ai.document.Document;
|
||
import org.springframework.http.MediaType;
|
||
import org.springframework.http.codec.ServerSentEvent;
|
||
import org.springframework.web.bind.annotation.*;
|
||
import reactor.core.publisher.Flux;
|
||
import reactor.core.publisher.Mono;
|
||
|
||
import java.util.List;
|
||
import java.util.Map;
|
||
|
||
@RestController
|
||
public class ChatController {
|
||
|
||
private final ChatService chatService;
|
||
private final AgentService agentService;
|
||
private final RAGService ragService;
|
||
private final PromptTemplateService promptService;
|
||
|
||
public ChatController(ChatService chatService, AgentService agentService,
|
||
RAGService ragService, PromptTemplateService promptService) {
|
||
this.chatService = chatService;
|
||
this.agentService = agentService;
|
||
this.ragService = ragService;
|
||
this.promptService = promptService;
|
||
}
|
||
|
||
// ==================== Week 9 复用: 基础对话 ====================
|
||
|
||
@PostMapping("/api/chat")
|
||
public ApiResponse<Map<String, String>> chat(@Valid @RequestBody ChatRequest request) {
|
||
String convId = request.getConversationId() != null ? request.getConversationId() : "default";
|
||
String reply = chatService.chat(request.getMessage(), convId);
|
||
return ApiResponse.success(Map.of("reply", reply, "conversationId", convId));
|
||
}
|
||
|
||
@PostMapping(value = "/api/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||
public Flux<ServerSentEvent<String>> chatStream(@Valid @RequestBody ChatRequest request) {
|
||
return chatService.chatStream(request.getMessage(), request.getConversationId())
|
||
.map(token -> ServerSentEvent.<String>builder().data(token).build())
|
||
.concatWith(Mono.just(ServerSentEvent.<String>builder()
|
||
.event("done").data("[DONE]").build()));
|
||
}
|
||
|
||
@GetMapping("/api/chat/history")
|
||
public ApiResponse<?> getHistory(@RequestParam(defaultValue = "default") String conversationId) {
|
||
return ApiResponse.success(chatService.getHistory(conversationId));
|
||
}
|
||
|
||
@DeleteMapping("/api/chat/history")
|
||
public ApiResponse<?> clearHistory(@RequestParam(defaultValue = "default") String conversationId) {
|
||
chatService.clearHistory(conversationId);
|
||
return ApiResponse.success("已清除会话: " + conversationId);
|
||
}
|
||
|
||
// ==================== Day 3-4: Agent 任务(Function Calling) ====================
|
||
|
||
/**
|
||
* 同步 Agent 执行 —— 模型自动选择并调用工具。
|
||
*
|
||
* POST /api/agent
|
||
* Body: {"message": "北京天气如何?", "conversationId": "会话ID(可选)"}
|
||
*/
|
||
@PostMapping("/api/agent")
|
||
public ApiResponse<Map<String, Object>> agent(@Valid @RequestBody ChatRequest request) {
|
||
var result = agentService.execute(request.getMessage(), request.getConversationId());
|
||
return ApiResponse.success(result);
|
||
}
|
||
|
||
/**
|
||
* 流式 Agent 执行。
|
||
*
|
||
* POST /api/agent/stream
|
||
*/
|
||
@PostMapping(value = "/api/agent/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||
public Flux<ServerSentEvent<String>> agentStream(@Valid @RequestBody ChatRequest request) {
|
||
return agentService.executeStream(request.getMessage(), request.getConversationId())
|
||
.map(token -> ServerSentEvent.<String>builder().data(token).build())
|
||
.concatWith(Mono.just(ServerSentEvent.<String>builder()
|
||
.event("done").data("[DONE]").build()));
|
||
}
|
||
|
||
// ==================== Day 6: RAG 检索增强 ====================
|
||
|
||
/**
|
||
* RAG 检索(只检索不生成,看检索效果)。
|
||
*
|
||
* POST /api/rag/search
|
||
* Body: {"question": "什么是 Spring AI?", "topK": 3}
|
||
*/
|
||
@PostMapping("/api/rag/search")
|
||
public ApiResponse<List<Map<String, Object>>> ragSearch(@Valid @RequestBody RAGQueryRequest request) {
|
||
return ApiResponse.success(ragService.search(request.getQuestion(), request.getTopK()));
|
||
}
|
||
|
||
/**
|
||
* 向知识库添加文档。
|
||
*
|
||
* POST /api/rag/load
|
||
* Body: [{"text": "文档内容", "metadata": {"source": "spring-doc"}}]
|
||
*/
|
||
@PostMapping("/api/rag/load")
|
||
public ApiResponse<Map<String, Object>> ragLoad(@RequestBody List<Map<String, Object>> documents) {
|
||
var docs = documents.stream()
|
||
.map(m -> {
|
||
String text = (String) m.get("text");
|
||
@SuppressWarnings("unchecked")
|
||
var metadata = (Map<String, Object>) m.getOrDefault("metadata", Map.of());
|
||
return new Document(text, new java.util.HashMap<>(metadata));
|
||
})
|
||
.toList();
|
||
int count = ragService.addDocuments(docs);
|
||
return ApiResponse.created(Map.of("added", count));
|
||
}
|
||
|
||
// ==================== Day 7: Prompt 模板管理 ====================
|
||
|
||
/**
|
||
* 获取所有模板名称和预览。
|
||
*
|
||
* GET /api/prompts
|
||
*/
|
||
@GetMapping("/api/prompts")
|
||
public ApiResponse<List<Map<String, String>>> listPrompts() {
|
||
return ApiResponse.success(promptService.list());
|
||
}
|
||
|
||
/**
|
||
* 获取指定模板的原始内容。
|
||
*
|
||
* GET /api/prompts/{name}
|
||
*/
|
||
@GetMapping("/api/prompts/{name}")
|
||
public ApiResponse<Map<String, String>> getPrompt(@PathVariable String name) {
|
||
String content = promptService.getTemplate(name);
|
||
return ApiResponse.success(Map.of("name", name, "content", content));
|
||
}
|
||
|
||
/**
|
||
* 使用 Prompt 模板进行对话。
|
||
*
|
||
* POST /api/chat/with-template
|
||
* Body: {"message": "用户消息", "templateName": "java-expert", "params": {"key": "value"}, "conversationId": "xxx"}
|
||
*/
|
||
@PostMapping("/api/chat/with-template")
|
||
public ApiResponse<Map<String, String>> chatWithTemplate(@RequestBody Map<String, Object> body) {
|
||
String message = (String) body.get("message");
|
||
String templateName = (String) body.get("templateName");
|
||
String conversationId = (String) body.getOrDefault("conversationId", "default");
|
||
@SuppressWarnings("unchecked")
|
||
var params = (Map<String, String>) body.getOrDefault("params", Map.of());
|
||
|
||
String templateContent = promptService.apply(templateName, params);
|
||
String reply = chatService.chatWithTemplate(templateContent, message, conversationId);
|
||
return ApiResponse.success(Map.of("reply", reply, "conversationId", conversationId));
|
||
}
|
||
}
|