import streamlit as st
import asyncio
import os
import json
import time
from state import ReportState, Section
from configuration import SearchAPI
from nodes import (
    generate_report_plan,
    think_and_generate_queries,
    search_web,
    write_section,
    compile_final_report,
    process_citations,
    save_final_report
)

# Load runtime configuration from config.json.
# (fix: `json` was imported twice in the original — once at the top of the
# file and once again right here; the single top-level import suffices.)
with open("config.json", "r", encoding="utf-8") as f:
    config_data = json.load(f)

# Export the API keys so the model / search clients can pick them up.
os.environ["DEEPSEEK_API_KEY"] = config_data["environment_variables"]["DEEPSEEK_API_KEY"]
os.environ["TAVILY_API_KEY"] = config_data["environment_variables"]["TAVILY_API_KEY"]

# The graph configuration is a plain dict taken verbatim from the file.
config_dict = config_data["config_dict"]
# Substitute the real Tavily key for the placeholder, if the slot exists.
if ("search_api_config" in config_dict["configurable"]
        and "api_key" in config_dict["configurable"]["search_api_config"]):
    config_dict["configurable"]["search_api_config"]["api_key"] = os.environ.get("TAVILY_API_KEY", "")

# --- Streamlit App ---

st.set_page_config(page_title="DeepResearch Stream", layout="wide", page_icon="🔬")

# Intentionally minimal CSS (custom styles were removed to avoid stray
# whitespace in the layout).
st.markdown("""
""", unsafe_allow_html=True)

# --- History helpers ---

def _reports_dir() -> str:
    """Return the cwd-relative reports directory, creating it if needed."""
    path = os.path.join(os.getcwd(), "reports")
    os.makedirs(path, exist_ok=True)
    return path


def load_history() -> list:
    """Load the generation history from reports/history.json ([] if absent)."""
    history_file = os.path.join(_reports_dir(), "history.json")
    if os.path.exists(history_file):
        with open(history_file, "r", encoding="utf-8") as f:
            return json.load(f)
    return []


def save_history(history: list) -> None:
    """Persist the history list as pretty-printed UTF-8 JSON."""
    history_file = os.path.join(_reports_dir(), "history.json")
    with open(history_file, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)


def add_to_history(topic: str, file_path: str) -> None:
    """Prepend a history entry; only the 50 most recent entries are kept."""
    history = load_history()
    history.insert(0, {
        "topic": topic,
        "file_path": file_path,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    })
    save_history(history[:50])


# Sidebar (content currently empty).
with st.sidebar:
    st.markdown('', unsafe_allow_html=True)

# Initialise the session-state flags that drive the multi-step UI.
for _key, _default in (
    ('report_state', None),
    ('generating', False),
    ('outline_generated', False),
    ('report_generated', False),
    ('feedback_given', False),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# --- Helper Functions for Streamlit ---

async def run_generate_report_plan(state, config):
    """Thin async wrapper so the plan step can be driven by asyncio.run()."""
    return await generate_report_plan(state, config)


async def run_generate_full_report(state, config):
    """Drive the full section-by-section generation loop with live UI updates.

    Logic migrated from main.py (generate_report_sections / compile_report /
    save_report).  Mutates and returns `state`.
    """
    completed_sections = []
    total_sections = len(state["sections"])

    # Container showing overall progress and per-section output slots.
    progress_container = st.container()
    with progress_container:
        progress_bar = st.progress(0)
        status_text = st.empty()
        section_placeholders = [st.empty() for _ in range(total_sections)]

    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]

        # NOTE(review): the original wrapped these status strings in styled
        # HTML divs that were lost in the patch; plain markdown is used here.
        status_text.markdown(
            f"正在生成章节 {current_index + 1}/{total_sections}: {current_section.name}",
            unsafe_allow_html=True)

        section_container = section_placeholders[current_index]
        with section_container.container():
            st.markdown(f"章节 {current_index + 1}: {current_section.name}",
                        unsafe_allow_html=True)

            # 1. Thinking step (synchronous node).
            with st.expander("🧠 思考过程", expanded=False):
                think_result = think_and_generate_queries(state, config)
                state.update(think_result)
                st.write(think_result.get("thinking_process", "没有记录思考过程。"))

            # 2. Search queries produced by the thinking step.
            with st.expander("🔍 搜索查询", expanded=False):
                search_queries = think_result.get("search_queries", {}).get(current_section.name, [])
                if search_queries:
                    for query in search_queries:
                        st.code(query.search_query, language="text")
                else:
                    st.write("没有为本章节生成搜索查询。")

            search_result = await search_web(state, config)
            state.update(search_result)

            # 3. Stream the section content into the UI as it is written.
            content_placeholder = st.empty()
            full_content_for_section = ""

            async def stream_callback_for_section(chunk):
                nonlocal full_content_for_section
                full_content_for_section += chunk
                content_placeholder.markdown(f"{full_content_for_section} ▌",
                                             unsafe_allow_html=True)

            write_result = await write_section(state, config,
                                               stream_callback=stream_callback_for_section)

            # Writing finished — re-render without the cursor glyph.
            content_placeholder.markdown(full_content_for_section,
                                         unsafe_allow_html=True)

            if isinstance(write_result, dict):
                state.update(write_result)
                if 'completed_sections' in write_result:
                    completed_sections.extend(write_result['completed_sections'])
                state["current_section_index"] += 1
            elif hasattr(write_result, 'update'):
                # langgraph Command: apply its state update, which already
                # advances current_section_index for non-final sections.
                update_data = write_result.update
                state.update(update_data)
                if 'completed_sections' in update_data:
                    completed_sections.extend(update_data['completed_sections'])

            if hasattr(write_result, 'goto') and write_result.goto == "compile_final_report":
                break

        progress_bar.progress((current_index + 1) / total_sections)

    state["completed_sections"] = completed_sections
    status_text.markdown("所有章节完成,正在编译最终报告...", unsafe_allow_html=True)

    compile_result = await compile_final_report(state)
    state.update(compile_result)

    citations_result = process_citations(state)
    state.update(citations_result)

    status_text.markdown("✅ 报告生成完毕!", unsafe_allow_html=True)

    # Persist the report under ./reports with a slug+timestamp file name.
    topic_slug = state["topic"].lower().replace(" ", "_")[:30]
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(_reports_dir(), f"report_{topic_slug}_{timestamp}.md")
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(state["final_report"])

    st.session_state.report_file_path = output_file
    add_to_history(state["topic"], output_file)

    return state

# --- UI Components ---

# Page title.  NOTE(review): the original styled HTML was lost in the patch;
# reconstructed as a minimal equivalent — confirm against the running app.
st.markdown("""
<div style="display:flex;align-items:center;gap:0.5rem;font-size:2rem;font-weight:700;">
🔬 AutomaSynth 探策矩阵
</div>
""", unsafe_allow_html=True)
# st.markdown("(星航电子工作室)", unsafe_allow_html=True)
", unsafe_allow_html=True) +st.markdown("#### 📝 报告主题") +topic = st.text_input("请输入您想要研究的报告主题", value="例如:人工智能在医疗领域的应用", + placeholder="例如:人工智能在医疗领域的应用") + +col1, col2 = st.columns([1, 4]) +with col1: + generate_outline_btn = st.button("生成报告大纲", disabled=st.session_state.generating, + type="primary", use_container_width=True) +st.markdown("
", unsafe_allow_html=True) + +if generate_outline_btn: + st.session_state.generating = True + st.session_state.outline_generated = False + st.session_state.report_generated = False + st.session_state.feedback_given = False + + initial_state = { + "topic": topic, + "feedback_on_report_plan": "", + "sections": [], + "thinking_process": [], + "completed_sections": [], + "current_section_index": 0, + "final_report": "", + "search_queries": {}, + "sources": {}, + "source_str": "" + } + + with st.spinner("正在生成报告大纲..."): + try: + plan_result = asyncio.run(run_generate_report_plan(initial_state, config_dict)) + initial_state.update(plan_result) + st.session_state.report_state = initial_state + st.session_state.outline_generated = True + except Exception as e: + st.error(f"生成报告大纲时出错:{str(e)}") + finally: + st.session_state.generating = False + st.rerun(scope="app") + +# 2. Outline Display and Feedback +if st.session_state.outline_generated and not st.session_state.report_generated: + st.markdown("
", unsafe_allow_html=True) + st.markdown("#### 📋 报告大纲 (可在此直接编辑)") + + with st.form(key='outline_form'): + # Add a field for the main report title + st.session_state.report_state['topic'] = st.text_input( + "报告总标题", + value=st.session_state.report_state.get('topic', '人工智能在医疗领域的应用') + ) + + sections = st.session_state.report_state.get("sections", []) + + # Create UI for each section to be editable + for i, section in enumerate(sections): + st.markdown(f"
", unsafe_allow_html=True) + st.markdown(f"##### 章节 {i+1}") + # Use unique keys for each widget to avoid Streamlit's DuplicateWidgetID error + section.name = st.text_input("章节名称", value=section.name, key=f"name_{i}") + section.description = st.text_area("章节描述", value=section.description, key=f"desc_{i}", height=100) + st.markdown("
", unsafe_allow_html=True) + + col1, col2 = st.columns([1, 4]) + with col1: + submitted = st.form_submit_button("确认修改并生成完整报告", + disabled=st.session_state.generating, + type="primary", use_container_width=True) + st.markdown("
", unsafe_allow_html=True) + + if submitted: + # User has confirmed the (potentially edited) outline, proceed to generate full report + st.session_state.generating = True + st.session_state.feedback_given = True # Prevent resubmission + + with st.spinner("正在生成完整报告..."): + try: + final_state = asyncio.run(run_generate_full_report(st.session_state.report_state, config_dict)) + st.session_state.report_state = final_state + st.session_state.report_generated = True + except Exception as e: + st.error(f"生成报告时出错:{str(e)}") + finally: + st.session_state.generating = False + st.rerun(scope="app") + +# 3. Final Report Display +if st.session_state.report_generated: + st.markdown("
", unsafe_allow_html=True) + st.markdown("#### 📄 最终报告") + final_report_content = st.session_state.report_state.get('final_report', '报告生成失败。') + st.markdown(final_report_content) + st.markdown("
", unsafe_allow_html=True) + + report_path = st.session_state.get("report_file_path", "") + if report_path and os.path.exists(report_path): + with open(report_path, "r", encoding="utf-8") as f: + st.download_button( + label="📥 下载报告 (Markdown)", + data=f.read(), + file_name=os.path.basename(report_path), + mime="text/markdown", + type="primary" + ) + +# 底部按钮 +st.markdown("---") +if st.button("🔄 开始新的报告", type="secondary"): + st.session_state.clear() + st.rerun(scope="app") \ No newline at end of file diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000..fac7a2a --- /dev/null +++ b/nodes.py @@ -0,0 +1,300 @@ +from typing import Literal, List +import aiofiles +from langchain.chat_models import init_chat_model +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.runnables import RunnableConfig +from datetime import datetime +from langgraph.types import interrupt, Command +import os +import json +from state import ( + Sections, + ReportState, + Queries, + SearchQuery, + Section +) + +from prompts import ( + report_planner_query_writer_instructions, + report_planner_instructions, + query_writer_instructions, + section_writer_instructions, + section_writer_system_prompt +) + +from configuration import Configuration +from utils import ( + compile_completed_sections, + get_config_value, + get_search_params, + select_and_execute_search, + format_sources +) + +async def generate_report_plan(state: ReportState, config: RunnableConfig): + """用于生成报告大纲,同时进行网络搜索帮助自己更好地规划大纲内容。 + 会自动中断等待人类反馈,若不通过则根据反馈重新生成大纲。直到通过并跳转到章节生成部分 + """ + + topic = state["topic"] + feedback = state.get("feedback_on_report_plan", None) + + configurable = Configuration.from_runnable_config(config) + number_of_queries = configurable.number_of_queries + search_api = get_config_value(configurable.search_api) + search_api_config = configurable.search_api_config or {} # Get the config dict, default to empty + params_to_pass = get_search_params(search_api, 
async def generate_report_plan(state: ReportState, config: RunnableConfig):
    """Generate the report outline, using a web search to inform the plan.

    The graph interrupts afterwards for human feedback; if the plan is
    rejected, this node re-runs with the feedback injected into the prompt.
    Returns the planned sections and resets the section cursor.
    """
    topic = state["topic"]
    feedback = state.get("feedback_on_report_plan", None)

    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}  # default to empty config
    params_to_pass = get_search_params(search_api, search_api_config)  # filter parameters
    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)

    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    structured_llm = writer_model.with_structured_output(Queries)

    # Ask the writer model for search queries about the topic.
    query_prompt = report_planner_query_writer_instructions.format(
        topic=topic, number_of_queries=number_of_queries)
    results = structured_llm.invoke(query_prompt)
    query_list: list[str] = [query.search_query for query in results.queries]

    # TODO: duplicated with search_web — extract a shared search helper.
    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
    source_str = format_sources(unique_sources)

    # Feed the search results to the planner model as context.
    sections_prompt = report_planner_instructions.format(
        topic=topic, context=source_str, feedback=feedback)

    planner_provider = get_config_value(configurable.planner_provider)
    planner_model = get_config_value(configurable.planner_model)
    planner_llm = init_chat_model(model=planner_model,
                                  model_provider=planner_provider)

    structured_llm = planner_llm.with_structured_output(Sections)
    report_sections = structured_llm.invoke(sections_prompt)

    sections: list[Section] = report_sections.sections
    return {"sections": sections, "current_section_index": 0}


def human_feedback(state: ReportState, config: RunnableConfig) -> Command[Literal["generate_report_plan", "think_and_generate_queries"]]:
    """Interrupt the graph to collect human feedback on the outline.

    `True` approves the plan and moves on to section generation; a string is
    treated as revision feedback and routes back to the planner.
    """
    sections = state['sections']
    sections_str = "\n\n".join(
        f"{number}. {section.name}\n"
        f"{section.description}\n"
        f"**是否需要搜索研究: {'是' if section.research else '否'}**\n"
        for number, section in enumerate(sections, 1)
    )

    interrupt_message = f"""请为这份大纲提供修改意见:
    {sections_str}
    这份大纲符合您的需求吗?,输入“true”以通过大纲,或者提供修改意见来修改大纲。"""

    feedback = interrupt(interrupt_message)

    if isinstance(feedback, bool) and feedback is True:
        return Command(goto="think_and_generate_queries")
    elif isinstance(feedback, str):
        return Command(goto="generate_report_plan",
                       update={"feedback_on_report_plan": feedback})
    else:
        raise TypeError(f"Interrupt value of type {type(feedback)} is not supported.")


def think_and_generate_queries(state: ReportState, config: RunnableConfig):
    """ReAct "Think" step: reason about the current section and emit queries.

    Returns the queries keyed by section name plus the recorded reasoning.
    """
    current_section_index = state["current_section_index"]
    section = state["sections"][current_section_index]
    section_name = section.name
    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries

    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)
    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    structured_llm = writer_model.with_structured_output(Queries)

    system_instructions = query_writer_instructions.format(
        section=section_name,
        section_description=section.description,
        number_of_queries=number_of_queries)

    queries: Queries = structured_llm.invoke(system_instructions)
    search_queries: List[SearchQuery] = queries.queries
    return {"search_queries": {section_name: search_queries},
            "thinking_process": [queries.thought],
            }


async def search_web(state: ReportState, config: RunnableConfig):
    """ReAct "Act" step: run the queries for the current section.

    Merges the newly found sources into the accumulated source map and
    returns the formatted source context for the writer.
    """
    current_section_index = state["current_section_index"]
    section = state["sections"][current_section_index]
    section_name = section.name
    search_queries = state["search_queries"][section_name]

    configurable = Configuration.from_runnable_config(config)
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}
    params_to_pass = get_search_params(search_api, search_api_config)
    query_list = [query.search_query for query in search_queries]
    # (fix: removed a stray debug `print(query_list)` that leaked to stdout)

    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)

    current_sources = {
        url: {
            "title": source['title'],
            "content": source["content"]
        } for url, source in unique_sources.items()
    }

    # Merge into a copy so the previous state dict is not mutated in place.
    all_sources = state.get("sources", {}).copy()
    all_sources.update(current_sources)

    source_str = format_sources(unique_sources)

    return {"source_str": source_str, "sources": all_sources}


async def write_section(state: ReportState, config: RunnableConfig, stream_callback=None) -> Command[Literal["think_and_generate_queries", "compile_final_report"]]:
    """Write the current section, streaming chunks to `stream_callback`.

    Falls back to printing chunks to stdout when no callback is given.
    Routes to compilation after the last section, otherwise advances the
    section cursor and loops back to the thinking step.
    """
    current_section_index = state["current_section_index"]
    section = state["sections"][current_section_index]
    section_name = section.name
    source_str = state["source_str"]
    configurable = Configuration.from_runnable_config(config)
    # Previously completed sections give the writer continuity context.
    section_content = compile_completed_sections(state)
    section_writer_inputs = section_writer_instructions.format(
        section_name=section_name,
        section_description=section.description,
        context=source_str,
        section_content=section_content
    )

    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)
    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    prompt = [SystemMessage(content=section_writer_system_prompt),
              HumanMessage(content=section_writer_inputs)]

    content_parts = []
    async for chunk in writer_model.astream(prompt):
        content_parts.append(chunk.content)
        if stream_callback:
            await stream_callback(chunk.content)
        else:
            print(chunk.content, end='', flush=True)

    section.content = ''.join(content_parts)

    if current_section_index == len(state["sections"]) - 1:
        return Command(
            update={"completed_sections": [section]},
            goto="compile_final_report"
        )
    return Command(
        update={"completed_sections": [section],
                "current_section_index": current_section_index + 1},
        goto="think_and_generate_queries"
    )
async def compile_final_report(state: ReportState):
    """Concatenate the completed sections under a top-level title."""
    topic = state.get("topic", "未命名报告")
    sections = state["completed_sections"]
    final_report = f"# {topic}\n\n" + "\n\n".join([section.content for section in sections])
    return {"final_report": final_report}


async def save_final_report(state: ReportState):
    """Write report.md and sources.json into a timestamped reports subdir.

    Fix: the output directory used to be a hard-coded absolute Windows path
    (C:\\Users\\...\\reports); it is now the cwd-relative "reports" directory,
    matching the rest of the application.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    base_reports_dir = os.path.join(os.getcwd(), "reports")
    output_dir = os.path.join(base_reports_dir, f"report_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    async def write_file(filename: str, content: str):
        # Async file write so saving does not block the event loop.
        async with aiofiles.open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
            await f.write(content)

    await write_file("report.md", state["final_report"])
    await write_file("sources.json", json.dumps(state["sources"], indent=2, ensure_ascii=False))


def process_citations(state: ReportState):
    """Rewrite inline `[url:...]` markers into numbered citations.

    - The first occurrence of each URL present in state["sources"] becomes
      `[n]`; later occurrences of the same URL are deleted (no re-numbering).
    - URLs missing from sources are removed entirely.
    - A "参考链接" list of all cited sources is appended.
    On any error the report is returned unchanged (best effort).
    """
    from collections import OrderedDict
    import re

    url_to_index = OrderedDict()
    pattern = r"\[url:(https?://[^\]]+)\]"
    matches = list(re.finditer(pattern, state["final_report"]))

    url_title_map = state["sources"]
    index = 1

    try:
        # First pass: number each distinct known URL in order of appearance.
        for match in matches:
            url = match.group(1)
            if url not in url_title_map:
                continue
            if url not in url_to_index:
                url_to_index[url] = index
                index += 1

        # Track which URLs have already been replaced in the body.
        replaced_urls = set()

        def replacer(match):
            url = match.group(1)
            if url not in url_title_map:
                return ""  # unknown URL: drop the marker
            if url in url_to_index and url not in replaced_urls:
                replaced_urls.add(url)
                return f"[{url_to_index[url]}]"  # first occurrence
            return ""  # repeated occurrence: drop, keep the original number

        processed_text = re.sub(pattern, replacer, state["final_report"])

        citation_lines = []
        for url, idx in url_to_index.items():
            title = url_title_map[url]["title"]
            citation_lines.append(f"[{idx}] [{title}]({url})")

        citation_list = "\n".join(citation_lines)
        final_report = processed_text + "\n\n## 参考链接:\n" + citation_list
    except Exception:
        # Never let citation post-processing destroy the finished report.
        final_report = state["final_report"]

    return {"final_report": final_report}
from typing import Annotated, List, TypedDict, Literal, Dict
from pydantic import BaseModel, Field
import operator


class Section(BaseModel):
    """One planned section of the report."""
    name: str = Field(
        description="章节的标题,应简洁明了地概括本章节的主题或主旨。",
    )
    description: str = Field(
        description="简要说明这一章节将围绕哪些内容进行展开,字数在50-100字。描述应突出章节的写作重点,明确要讨论的问题、角度或结构。",
    )
    research: bool = Field(
        description="判断该章节的写作是否需要进行网络搜索。若该部分需要引用外部资料、数据、案例或是知识储备不足以独立完成,则应标记为需要研究。"
    )
    # Fix: the content is documented as "left empty for now", so give it an
    # explicit empty-string default instead of making it a required field.
    content: str = Field(
        default="",
        description="章节的主要写作内容。暂时留空,后续将用于填写具体文本内容"
    )


class Sections(BaseModel):
    """Container for the full list of planned sections."""
    sections: List[Section] = Field(
        description="包含报告的各个章节",
    )


class SearchQuery(BaseModel):
    """A single web-search query string."""
    # NOTE(review): default is None although the field is typed `str`; kept
    # for backward compatibility with existing structured outputs — confirm
    # whether `str | None` is the intended schema.
    search_query: str = Field(default=None, description="网络搜索的查询")


class Queries(BaseModel):
    """The model's reasoning plus the search queries it decided to run."""
    # NOTE(review): same str/None mismatch as SearchQuery.search_query.
    thought: str = Field(default=None, description="思考过程")
    queries: List[SearchQuery] = Field(
        description="包含网络搜索的查询列表",
    )


# Graph-level input: just the report topic.
class ReportStateInput(TypedDict):
    topic: str  # 报告主题


# Graph-level output: the finished report.
class ReportStateOutput(TypedDict):
    final_report: str  # 最终报告


class ReportState(TypedDict):
    """Mutable state shared by all graph nodes while the report is produced."""
    topic: str
    feedback_on_report_plan: str
    sections: list[Section]
    # Chain of thought accumulated across sections (list-concat reducer).
    thinking_process: Annotated[list[str], operator.add]
    completed_sections: Annotated[list, operator.add]
    current_section_index: int  # index of the section currently in progress
    final_report: str  # final compiled report text
    # Queries keyed by section name (dict-merge reducer).
    search_queries: Annotated[Dict[str, List[SearchQuery]], operator.or_]
    # Keyed by source URL; values hold title/content entries.
    sources: Annotated[Dict[str, Dict[str, str]], operator.or_]
    source_str: str