上传文件至 /
This commit is contained in:
516
app.py
Normal file
516
app.py
Normal file
@@ -0,0 +1,516 @@
|
||||
import streamlit as st
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from state import ReportState, Section
|
||||
from configuration import SearchAPI
|
||||
from nodes import (
|
||||
generate_report_plan,
|
||||
think_and_generate_queries,
|
||||
search_web,
|
||||
write_section,
|
||||
compile_final_report,
|
||||
process_citations,
|
||||
save_final_report
|
||||
)
|
||||
|
||||
# Load application settings from config.json
import json

with open("config.json", "r", encoding="utf-8") as config_file:
    config_data = json.load(config_file)

# Export the API keys stored in the config file as environment variables
_env_settings = config_data["environment_variables"]
for _env_name in ("DEEPSEEK_API_KEY", "TAVILY_API_KEY"):
    os.environ[_env_name] = _env_settings[_env_name]

# The pipeline configuration is passed around as a plain dict
config_dict = config_data["config_dict"]

# Replace the api_key placeholder of the search API config with the real key
_configurable = config_dict["configurable"]
if "search_api_config" in _configurable:
    _search_api_config = _configurable["search_api_config"]
    if "api_key" in _search_api_config:
        _search_api_config["api_key"] = os.environ.get("TAVILY_API_KEY", "")
|
||||
|
||||
# --- Streamlit App ---
|
||||
|
||||
st.set_page_config(page_title="DeepResearch Stream", layout="wide", page_icon="🔬")
|
||||
|
||||
# 简化CSS样式,移除可能导致空白的margin
|
||||
st.markdown("""
|
||||
<style>
|
||||
/* 移除默认的顶部margin */
|
||||
.block-container {
|
||||
padding-top: 1rem;
|
||||
}
|
||||
|
||||
/* 主题颜色 */
|
||||
:root {
|
||||
--primary-color: #1f77b4;
|
||||
--secondary-color: #ff7f0e;
|
||||
--success-color: #2ca02c;
|
||||
--warning-color: #d62728;
|
||||
--background-color: #f5f5f5;
|
||||
--card-background: #ffffff;
|
||||
--text-color: #333333;
|
||||
}
|
||||
|
||||
/* 侧边栏样式 */
|
||||
[data-testid="stSidebar"] {
|
||||
background-color: #2c3e50;
|
||||
color: white;
|
||||
}
|
||||
|
||||
/* 卡片样式 */
|
||||
.report-card {
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 10px;
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
background-color: var(--card-background);
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.section-card {
|
||||
border-left: 4px solid var(--primary-color);
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.process-card {
|
||||
background-color: #e8f4fd;
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
/* 按钮样式 */
|
||||
.stButton>button {
|
||||
border-radius: 8px;
|
||||
border: 1px solid transparent;
|
||||
padding: 8px 16px;
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.stButton>button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
/* 进度条样式 */
|
||||
.stProgress > div > div {
|
||||
background-color: var(--success-color);
|
||||
}
|
||||
|
||||
/* 标题样式 */
|
||||
h1, h2, h3 {
|
||||
color: var(--primary-color);
|
||||
margin-top: 0.5rem;
|
||||
margin-bottom: 0.4rem;
|
||||
}
|
||||
|
||||
/* 状态信息样式 */
|
||||
.info-message {
|
||||
padding: 10px;
|
||||
border-radius: 5px;
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.info-message.info {
|
||||
background-color: #d1ecf1;
|
||||
border: 1px solid #bee5eb;
|
||||
color: #0c5460;
|
||||
}
|
||||
|
||||
.info-message.success {
|
||||
background-color: #d4edda;
|
||||
border: 1px solid #c3e6cb;
|
||||
color: #155724;
|
||||
}
|
||||
|
||||
.info-message.warning {
|
||||
background-color: #fff3cd;
|
||||
border: 1px solid #ffeaa7;
|
||||
color: #856404;
|
||||
}
|
||||
|
||||
/* 侧边栏内容样式 */
|
||||
.sidebar-content {
|
||||
color: white;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.sidebar-content h3 {
|
||||
color: #ecf0f1;
|
||||
border-bottom: 1px solid #ecf0f1;
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
|
||||
.sidebar-content ul {
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
.sidebar-content li {
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
/* 历史记录样式 */
|
||||
.history-item {
|
||||
background-color: #f8f9fa;
|
||||
border-radius: 5px;
|
||||
padding: 10px;
|
||||
margin-bottom: 10px;
|
||||
border-left: 3px solid #1f77b4;
|
||||
}
|
||||
|
||||
.history-title {
|
||||
font-weight: bold;
|
||||
color: #1f77b4;
|
||||
}
|
||||
|
||||
.history-time {
|
||||
font-size: 0.8em;
|
||||
color: #6c757d;
|
||||
}
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
# --- 历史记录功能 ---
|
||||
def load_history():
    """Load the saved report history.

    Makes sure ``<cwd>/reports`` exists and reads ``history.json`` from it.

    Returns:
        list: The stored history entries, or an empty list when the file is
        missing, unreadable, not valid JSON, or does not hold a JSON list.
    """
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    history_file = os.path.join(reports_dir, "history.json")
    try:
        with open(history_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # The history is read on every page run, so a damaged file must
        # degrade to "no history" instead of crashing the whole app.
        return []
    # Callers iterate and slice the result; anything but a list is unusable
    return history if isinstance(history, list) else []
|
||||
|
||||
def save_history(history):
    """Write the history entries to ``reports/history.json`` under the cwd.

    Args:
        history: JSON-serialisable history entries to persist.
    """
    target_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, "history.json")
    with open(target_path, mode="w", encoding="utf-8") as out:
        # ensure_ascii=False keeps non-ASCII topics readable in the file
        json.dump(history, out, ensure_ascii=False, indent=2)
|
||||
|
||||
def add_to_history(topic, file_path):
    """Put a newly generated report at the front of the stored history.

    Args:
        topic: Report topic to record.
        file_path: Path of the saved report file.
    """
    entries = load_history()
    new_entry = {
        "topic": topic,
        "file_path": file_path,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    entries.insert(0, new_entry)
    # Keep only the 50 most recent records
    save_history(entries[:50])
|
||||
|
||||
# Sidebar content
import html  # escapes stored history values before they are embedded in raw HTML

with st.sidebar:
    st.markdown('<div class="sidebar-content">', unsafe_allow_html=True)
    st.title("🔬 AutomaSynth 探策矩阵")
    st.markdown("### (星航电子工作室)")

    st.markdown("#### 📋 使用说明")
    st.markdown("""
1. 在主界面输入您想要研究的报告主题
2. 点击"生成报告大纲"按钮生成初始大纲
3. 编辑和确认大纲内容
4. 点击"确认修改并生成完整报告"开始生成报告
5. 等待报告生成完成,查看最终结果
""")

    st.markdown("#### ⚙️ 功能特点")
    st.markdown("""
- AI驱动的大纲生成
- 智能网络搜索
- 流式内容生成
- 实时进度跟踪
- 自动引用处理
""")

    # History section
    st.markdown("#### 📚 历史记录")
    history = load_history()
    if history:
        for history_index, item in enumerate(history[:10]):  # show the 10 most recent
            if not isinstance(item, dict):
                # A malformed entry must not take the whole page down
                continue
            item_topic = str(item.get("topic", ""))
            item_time = str(item.get("timestamp", ""))
            item_path = item.get("file_path", "")
            with st.expander(f"📝 {item_topic}", expanded=False):
                st.markdown("<div class='history-item'>", unsafe_allow_html=True)
                # Topic text is user input: escape it because the markup is
                # rendered with unsafe_allow_html=True.
                st.markdown(f"<div class='history-title'>{html.escape(item_topic)}</div>", unsafe_allow_html=True)
                st.markdown(f"<div class='history-time'>{html.escape(item_time)}</div>", unsafe_allow_html=True)
                if item_path and os.path.exists(item_path):
                    with open(item_path, "r", encoding="utf-8") as f:
                        st.download_button(
                            label="📥 下载报告",
                            data=f.read(),
                            file_name=os.path.basename(item_path),
                            mime="text/markdown",
                            # The index keeps keys unique even if two entries
                            # point at the same file.
                            key=f"download_{history_index}_{item_path}"
                        )
                st.markdown("</div>", unsafe_allow_html=True)
    else:
        st.info("暂无历史记录")

    st.markdown("#### 📞 联系我们")
    st.markdown("如有问题,请联系技术支持。")
    st.markdown('</div>', unsafe_allow_html=True)
|
||||
|
||||
# Initialise session state: set a default only for keys that are not present yet
for _state_key, _default_value in (
    ("report_state", None),
    ("generating", False),
    ("outline_generated", False),
    ("report_generated", False),
    ("feedback_given", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
|
||||
|
||||
# --- Helper Functions for Streamlit ---
|
||||
|
||||
async def run_generate_report_plan(state, config):
    """Await ``generate_report_plan(state, config)`` and return its result."""
    plan_update = await generate_report_plan(state, config)
    return plan_update
|
||||
|
||||
async def run_generate_full_report(state, config):
    """Write all remaining sections, compile the report and save it to disk.

    Logic ported from main.py (generate_report_sections, compile_report,
    save_report). Starting at ``state["current_section_index"]`` each section
    goes through think -> search -> write while progress is rendered with
    Streamlit. Afterwards the final report is compiled, citations are
    processed, the Markdown file is written to ``<cwd>/reports`` and the
    report is added to the history.

    Args:
        state: Mutable report state dict; it is updated in place.
        config: Configuration dict forwarded to the node functions.

    Returns:
        The same ``state`` object, now containing ``final_report``.
    """
    # Start from whatever is already recorded so that a retry after a failed
    # run does not drop sections that were finished earlier (empty on a fresh run).
    completed_sections = list(state.get("completed_sections") or [])
    total_sections = len(state["sections"])

    # Container for the progress bar and the status line
    progress_container = st.container()
    with progress_container:
        progress_bar = st.progress(0)
        status_text = st.empty()

    # One placeholder per section for its output
    section_placeholders = [st.empty() for _ in range(total_sections)]

    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]

        # Update the status line
        status_text.markdown(f"<div class='info-message info'>正在生成章节 {current_index + 1}/{total_sections}: {current_section.name}</div>", unsafe_allow_html=True)

        # UI container of the current section
        section_container = section_placeholders[current_index]

        with section_container.container():
            st.markdown(f"<div class='report-card section-card'><h3>章节 {current_index + 1}: {current_section.name}</h3></div>", unsafe_allow_html=True)

            # 1. Show the thinking process
            with st.expander("🧠 思考过程", expanded=False):
                think_result = think_and_generate_queries(state, config)
                state.update(think_result)
                st.write(think_result.get("thinking_process", "没有记录思考过程。"))

            # 2. Show the search queries
            with st.expander("🔍 搜索查询", expanded=False):
                search_queries = think_result.get("search_queries", {}).get(current_section.name, [])
                if search_queries:
                    for query in search_queries:
                        st.code(query.search_query, language="text")
                else:
                    st.write("没有为本章节生成搜索查询。")

            search_result = await search_web(state, config)
            state.update(search_result)

            # 3. Stream the generated section content
            content_placeholder = st.empty()
            full_content_for_section = ""

            async def stream_callback_for_section(chunk):
                nonlocal full_content_for_section
                full_content_for_section += chunk
                content_placeholder.markdown(f"<div class='process-card'>{full_content_for_section} ▌</div>", unsafe_allow_html=True)

            write_result = await write_section(state, config, stream_callback=stream_callback_for_section)

            # Writing finished: redraw without the cursor
            content_placeholder.markdown(f"<div class='process-card'>{full_content_for_section}</div>", unsafe_allow_html=True)

        if isinstance(write_result, dict):
            _absorb_section_update(state, write_result, completed_sections)
            state["current_section_index"] += 1
        elif hasattr(write_result, 'update'):
            # Command-style result: its update may already carry the next index
            _absorb_section_update(state, write_result.update or {}, completed_sections)

        # Report progress before a possible break so the bar reaches 100%
        progress = (current_index + 1) / total_sections
        progress_bar.progress(progress)

        if getattr(write_result, 'goto', None) == "compile_final_report":
            break

    state["completed_sections"] = completed_sections
    status_text.markdown("<div class='info-message info'>所有章节完成,正在编译最终报告...</div>", unsafe_allow_html=True)

    compile_result = await compile_final_report(state)
    state.update(compile_result)

    citations_result = process_citations(state)
    state.update(citations_result)

    status_text.markdown("<div class='info-message success'>✅ 报告生成完毕!</div>", unsafe_allow_html=True)

    # Save the file under the reports directory of the working directory
    topic_slug = _safe_topic_slug(state["topic"])
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    output_file = os.path.join(reports_dir, f"report_{topic_slug}_{timestamp}.md")
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(state["final_report"])

    st.session_state.report_file_path = output_file

    # Add to the history
    add_to_history(state["topic"], output_file)

    return state


def _absorb_section_update(state, update_data, completed_sections):
    """Apply one write_section update and keep completed sections cumulative."""
    state.update(update_data)
    if 'completed_sections' in update_data:
        completed_sections.extend(update_data['completed_sections'])
    # state.update() replaced the list with just the newest section; restore
    # the full list so later steps see every section written so far.
    state["completed_sections"] = list(completed_sections)


def _safe_topic_slug(topic, max_length=30):
    """Turn a topic into a filename-safe slug (letters, digits, '-' and '_')."""
    lowered = str(topic).lower()
    # Path separators, colons, question marks etc. would make open() fail
    slug = "".join(ch if (ch.isalnum() or ch in "-_") else "_" for ch in lowered)
    return slug[:max_length] or "report"
|
||||
|
||||
# --- UI Components ---
|
||||
|
||||
# Main content area: rendered directly on the page, without an extra container
# Page title banner
_TITLE_BANNER_HTML = """
<div style='
text-align: center;
margin: 20px 0;
font-size: 2.2em;
color: #1f77b4;
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
font-weight: bold;
'>
<span>🔬</span>
<span>AutomaSynth 探策矩阵</span>
</div>
"""
st.markdown(_TITLE_BANNER_HTML, unsafe_allow_html=True)
#st.markdown("<h3 style='text-align: center; color: #555;'>(星航电子工作室)</h3>", unsafe_allow_html=True)


# 1. Topic Input
st.markdown("<div class='report-card'>", unsafe_allow_html=True)
st.markdown("#### 📝 报告主题")
# The same example text is used as the initial value and as the placeholder
_EXAMPLE_TOPIC = "例如:人工智能在医疗领域的应用"
topic = st.text_input(
    "请输入您想要研究的报告主题",
    value=_EXAMPLE_TOPIC,
    placeholder=_EXAMPLE_TOPIC,
)

col1, col2 = st.columns([1, 4])
with col1:
    generate_outline_btn = st.button(
        "生成报告大纲",
        disabled=st.session_state.generating,
        type="primary",
        use_container_width=True,
    )
st.markdown("</div>", unsafe_allow_html=True)
|
||||
|
||||
# Handle a click on the "generate outline" button
if generate_outline_btn and not topic.strip():
    # Do not spend search/LLM calls on a blank topic, and keep any existing outline
    st.warning("请先输入报告主题")
elif generate_outline_btn:
    st.session_state.generating = True
    st.session_state.outline_generated = False
    st.session_state.report_generated = False
    st.session_state.feedback_given = False

    # Fresh report state for this topic
    initial_state = {
        "topic": topic,
        "feedback_on_report_plan": "",
        "sections": [],
        "thinking_process": [],
        "completed_sections": [],
        "current_section_index": 0,
        "final_report": "",
        "search_queries": {},
        "sources": {},
        "source_str": ""
    }

    outline_ready = False
    with st.spinner("正在生成报告大纲..."):
        try:
            plan_result = asyncio.run(run_generate_report_plan(initial_state, config_dict))
            initial_state.update(plan_result)
            st.session_state.report_state = initial_state
            st.session_state.outline_generated = True
            outline_ready = True
        except Exception as e:
            # Top-level UI boundary: show the failure to the user
            st.error(f"生成报告大纲时出错:{str(e)}")
        finally:
            st.session_state.generating = False

    # Rerun only on success. An unconditional rerun would redraw the page
    # immediately and wipe the error message shown above.
    if outline_ready:
        st.rerun(scope="app")
|
||||
|
||||
# 2. Outline Display and Feedback
if st.session_state.outline_generated and not st.session_state.report_generated:
    st.markdown("<div class='report-card'>", unsafe_allow_html=True)
    st.markdown("#### 📋 报告大纲 (可在此直接编辑)")

    with st.form(key='outline_form'):
        # Field for the main report title
        st.session_state.report_state['topic'] = st.text_input(
            "报告总标题",
            value=st.session_state.report_state.get('topic', '人工智能在医疗领域的应用')
        )

        sections = st.session_state.report_state.get("sections", [])

        # Editable widgets for every section; edits are written back onto the
        # section objects held in the session state
        for i, section in enumerate(sections):
            st.markdown("<div class='section-card'>", unsafe_allow_html=True)
            st.markdown(f"##### 章节 {i+1}")
            # Unique keys per widget avoid Streamlit's DuplicateWidgetID error
            section.name = st.text_input("章节名称", value=section.name, key=f"name_{i}")
            section.description = st.text_area("章节描述", value=section.description, key=f"desc_{i}", height=100)
            st.markdown("</div>", unsafe_allow_html=True)

        col1, col2 = st.columns([1, 4])
        with col1:
            submitted = st.form_submit_button("确认修改并生成完整报告",
                                              disabled=st.session_state.generating,
                                              type="primary", use_container_width=True)
    st.markdown("</div>", unsafe_allow_html=True)

    if submitted:
        # The user confirmed the (possibly edited) outline: generate the full report
        st.session_state.generating = True
        st.session_state.feedback_given = True  # Prevent resubmission

        report_ready = False
        with st.spinner("正在生成完整报告..."):
            try:
                final_state = asyncio.run(run_generate_full_report(st.session_state.report_state, config_dict))
                st.session_state.report_state = final_state
                st.session_state.report_generated = True
                report_ready = True
            except Exception as e:
                # Top-level UI boundary: show the failure to the user
                st.error(f"生成报告时出错:{str(e)}")
            finally:
                st.session_state.generating = False

        # Rerun only on success. An unconditional rerun would redraw the page
        # immediately and wipe the error message shown above.
        if report_ready:
            st.rerun(scope="app")
|
||||
|
||||
# 3. Final Report Display
if st.session_state.report_generated:
    st.markdown("<div class='report-card'>", unsafe_allow_html=True)
    st.markdown("#### 📄 最终报告")
    st.markdown(st.session_state.report_state.get('final_report', '报告生成失败。'))
    st.markdown("</div>", unsafe_allow_html=True)

    # Offer the saved Markdown file for download when it still exists on disk
    report_path = st.session_state.get("report_file_path", "")
    if report_path and os.path.exists(report_path):
        with open(report_path, "r", encoding="utf-8") as report_file:
            report_markdown = report_file.read()
        st.download_button(
            label="📥 下载报告 (Markdown)",
            data=report_markdown,
            file_name=os.path.basename(report_path),
            mime="text/markdown",
            type="primary",
        )
|
||||
|
||||
# 底部按钮
|
||||
st.markdown("---")
|
||||
if st.button("🔄 开始新的报告", type="secondary"):
|
||||
st.session_state.clear()
|
||||
st.rerun(scope="app")
|
||||
300
nodes.py
Normal file
300
nodes.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from typing import Literal, List
|
||||
import aiofiles
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from datetime import datetime
|
||||
from langgraph.types import interrupt, Command
|
||||
import os
|
||||
import json
|
||||
from state import (
|
||||
Sections,
|
||||
ReportState,
|
||||
Queries,
|
||||
SearchQuery,
|
||||
Section
|
||||
)
|
||||
|
||||
from prompts import (
|
||||
report_planner_query_writer_instructions,
|
||||
report_planner_instructions,
|
||||
query_writer_instructions,
|
||||
section_writer_instructions,
|
||||
section_writer_system_prompt
|
||||
)
|
||||
|
||||
from configuration import Configuration
|
||||
from utils import (
|
||||
compile_completed_sections,
|
||||
get_config_value,
|
||||
get_search_params,
|
||||
select_and_execute_search,
|
||||
format_sources
|
||||
)
|
||||
|
||||
async def generate_report_plan(state: ReportState, config: RunnableConfig):
    """Generate the report outline, using a web search for planning context.

    The writer model first proposes search queries for the topic, the
    configured search API is run on them, and the planner model then produces
    the section list from the topic, the search results and any feedback
    stored in ``feedback_on_report_plan``.

    Args:
        state: Report state; reads ``topic`` and ``feedback_on_report_plan``.
        config: Runnable config parsed with ``Configuration.from_runnable_config``.

    Returns:
        dict: ``{"sections": [...], "current_section_index": 0}``.
    """
    topic = state["topic"]
    # An absent/None feedback must not end up as the literal text "None" in the prompt
    feedback = state.get("feedback_on_report_plan") or ""

    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}  # Get the config dict, default to empty
    params_to_pass = get_search_params(search_api, search_api_config)  # Filter parameters
    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)

    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    structured_llm = writer_model.with_structured_output(Queries)

    # Build the query prompt from the topic so the model produces the search
    # queries that are then sent to the search API
    query_prompt = report_planner_query_writer_instructions.format(
        topic=topic, number_of_queries=number_of_queries)

    # This is a coroutine: use the async API so the event loop is not blocked
    results = await structured_llm.ainvoke(query_prompt)

    query_list: list[str] = [query.search_query for query in results.queries]

    # TODO: duplicated with search_web; factor out a shared helper
    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
    source_str = format_sources(unique_sources)

    # The search results are handed to the planner model as context
    sections_prompt = report_planner_instructions.format(
        topic=topic, context=source_str, feedback=feedback)

    planner_provider = get_config_value(configurable.planner_provider)
    planner_model = get_config_value(configurable.planner_model)
    planner_llm = init_chat_model(model=planner_model,
                                  model_provider=planner_provider)

    structured_llm = planner_llm.with_structured_output(Sections)
    report_sections = await structured_llm.ainvoke(sections_prompt)

    # The planned sections
    sections: list[Section] = report_sections.sections
    return {"sections": sections, "current_section_index": 0}
|
||||
|
||||
|
||||
|
||||
def human_feedback(state: ReportState, config: RunnableConfig) -> Command[Literal["generate_report_plan", "think_and_generate_queries"]]:
    """Ask a human to approve the outline or to give revision feedback.

    Interrupts with a rendering of the current outline. The resumed value
    decides the next node.

    Args:
        state: Report state; reads ``sections``.
        config: Runnable config (not used here).

    Returns:
        Command: goto ``think_and_generate_queries`` on approval (``True`` or
        the text "true", as the prompt instructs); otherwise goto
        ``generate_report_plan`` with the text stored in
        ``feedback_on_report_plan``.

    Raises:
        TypeError: If the resumed value is neither ``True`` nor a string.
    """

    sections = state['sections']
    sections_str = "\n\n".join(
        f"{number}. {section.name}\n"
        f"{section.description}\n"
        f"**是否需要搜索研究: {'是' if section.research else '否'}**\n"
        for number, section in enumerate(sections, 1)
    )

    interrupt_message = f"""请为这份大纲提供修改意见:
{sections_str}
这份大纲符合您的需求吗?,输入“true”以通过大纲,或者提供修改意见来修改大纲。"""

    feedback = interrupt(interrupt_message)

    # The prompt asks the user to type "true"; a text front end delivers that
    # as a string, which must count as approval rather than as feedback.
    approved = feedback is True or (
        isinstance(feedback, str) and feedback.strip().lower() == "true"
    )
    if approved:
        return Command(goto="think_and_generate_queries")
    if isinstance(feedback, str):
        return Command(goto="generate_report_plan",
                       update={"feedback_on_report_plan": feedback})
    raise TypeError(f"Interrupt value of type {type(feedback)} is not supported.")
|
||||
|
||||
|
||||
|
||||
|
||||
def think_and_generate_queries(state: ReportState, config: RunnableConfig):
    """Think step: produce search queries for the current section.

    Asks the writer model (structured as ``Queries``) for queries about the
    section at ``current_section_index`` and records its thought.

    Returns:
        dict: ``search_queries`` mapping the section name to its queries, and
        ``thinking_process`` holding the model's thought in a one-item list.
    """
    section = state["sections"][state["current_section_index"]]
    section_name = section.name

    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries

    provider = get_config_value(configurable.writer_provider)
    model_name = get_config_value(configurable.writer_model)
    query_llm = init_chat_model(
        model=model_name, model_provider=provider
    ).with_structured_output(Queries)

    instructions = query_writer_instructions.format(
        section=section_name,
        section_description=section.description,
        number_of_queries=number_of_queries,
    )

    generated: Queries = query_llm.invoke(instructions)
    return {
        "search_queries": {section_name: generated.queries},
        "thinking_process": [generated.thought],
    }
|
||||
|
||||
|
||||
async def search_web(state: ReportState, config: RunnableConfig):
    """Acting step: run the current section's queries through the search API.

    Returns:
        dict: ``source_str`` (formatted results of this search) and
        ``sources`` (the state's previous sources plus the new ones, keyed by
        URL with ``title`` and ``content``).
    """
    section_name = state["sections"][state["current_section_index"]].name
    pending_queries = state["search_queries"][section_name]

    configurable = Configuration.from_runnable_config(config)
    search_api = get_config_value(configurable.search_api)
    params_to_pass = get_search_params(search_api, configurable.search_api_config or {})

    query_list = []
    for query in pending_queries:
        query_list.append(query.search_query)
    print(query_list)

    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)

    # Keep only title and content of every hit
    current_sources = {}
    for url, source in unique_sources.items():
        current_sources[url] = {"title": source['title'], "content": source["content"]}

    # Merge into a copy so the dict held by the state is not mutated here
    all_sources = state.get("sources", {}).copy()
    for url, details in current_sources.items():
        all_sources[url] = details

    source_str = format_sources(unique_sources)
    return {"source_str": source_str, "sources": all_sources}
|
||||
|
||||
|
||||
|
||||
|
||||
async def write_section(state: ReportState, config: RunnableConfig, stream_callback=None) -> Command[Literal["think_and_generate_queries", "compile_final_report"]]:
    """Write the content of the current section with the writer model.

    The model output is streamed: each chunk goes to ``stream_callback`` when
    one is given, otherwise it is printed. The joined text is stored on the
    section object.

    Returns:
        Command: carries the finished section in ``completed_sections``; goes
        to ``compile_final_report`` after the last section, otherwise advances
        ``current_section_index`` and goes to ``think_and_generate_queries``.
    """
    section_index = state["current_section_index"]
    section = state["sections"][section_index]
    context = state["source_str"]
    configurable = Configuration.from_runnable_config(config)
    previous_content = compile_completed_sections(state)

    human_text = section_writer_instructions.format(
        section_name=section.name,
        section_description=section.description,
        context=context,
        section_content=previous_content,
    )

    # Generate section
    provider = get_config_value(configurable.writer_provider)
    model_name = get_config_value(configurable.writer_model)
    writer_model = init_chat_model(model=model_name, model_provider=provider)
    messages = [
        SystemMessage(content=section_writer_system_prompt),
        HumanMessage(content=human_text),
    ]

    pieces = []
    async for chunk in writer_model.astream(messages):
        piece = chunk.content
        pieces.append(piece)
        if stream_callback:
            await stream_callback(piece)
        else:
            print(piece, end='', flush=True)

    section.content = ''.join(pieces)

    update = {"completed_sections": [section]}
    if section_index == len(state["sections"]) - 1:
        return Command(update=update, goto="compile_final_report")

    update["current_section_index"] = section_index + 1
    return Command(update=update, goto="think_and_generate_queries")
|
||||
|
||||
|
||||
|
||||
|
||||
async def compile_final_report(state: ReportState):
    """Assemble the final report from the completed sections.

    The report is a ``# <topic>`` heading followed by the section contents
    separated by blank lines.

    Returns:
        dict: ``{"final_report": <markdown text>}``.
    """
    title = state.get("topic", "未命名报告")
    body = "\n\n".join(section.content for section in state["completed_sections"])
    return {"final_report": f"# {title}\n\n{body}"}
|
||||
|
||||
|
||||
async def save_final_report(state: ReportState, base_reports_dir=None):
    """Write the final report and its sources into a timestamped directory.

    Creates ``<base_reports_dir>/report_<timestamp>/`` containing
    ``report.md`` and ``sources.json``.

    Args:
        state: Report state; reads ``final_report`` and ``sources``.
        base_reports_dir: Directory under which the report directory is
            created. Defaults to ``<cwd>/reports`` so the function works on
            any machine instead of relying on one developer's absolute path.

    Returns:
        str: Path of the directory the files were written to.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Base directory for saved reports
    if base_reports_dir is None:
        base_reports_dir = os.path.join(os.getcwd(), "reports")
    output_dir = os.path.join(base_reports_dir, f"report_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    async def write_file(filename: str, content: str):
        async with aiofiles.open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
            await f.write(content)

    await write_file("report.md", state["final_report"])
    await write_file("sources.json", json.dumps(state.get("sources", {}), indent=2, ensure_ascii=False))
    return output_dir
|
||||
|
||||
|
||||
def process_citations(state: ReportState):
    """Turn ``[url:<link>]`` markers in the final report into numbered citations.

    * The first marker of every URL found in ``sources`` becomes ``[n]``,
      numbered by order of first appearance.
    * Later markers of the same URL are removed, not renumbered.
    * Markers whose URL is not in ``sources`` are removed.
    * A reference list is appended when at least one citation was produced.

    Args:
        state: Report state; reads ``final_report`` and ``sources`` (URL ->
            dict with a ``title``). A missing ``sources`` counts as empty.

    Returns:
        dict: ``{"final_report": <processed text>}``.
    """
    import re

    report_text = state["final_report"]
    url_title_map = state.get("sources") or {}
    marker = re.compile(r"\[url:(https?://[^\]]+)\]")

    # Number every known URL once, in order of first appearance
    url_to_index = {}
    for match in marker.finditer(report_text):
        url = match.group(1)
        if url in url_title_map and url not in url_to_index:
            url_to_index[url] = len(url_to_index) + 1

    # URLs whose first marker has already been replaced
    replaced_urls = set()

    def replacer(match):
        url = match.group(1)
        if url not in url_to_index or url in replaced_urls:
            return ""  # unknown URL or repeated citation: drop the marker
        replaced_urls.add(url)
        return f"[{url_to_index[url]}]"  # first citation of this URL

    processed_text = marker.sub(replacer, report_text)

    if not url_to_index:
        # Nothing was cited: do not append an empty reference section
        return {"final_report": processed_text}

    citation_lines = []
    for url, idx in url_to_index.items():
        source = url_title_map[url]
        # Fall back to the URL itself when a source has no usable title,
        # instead of abandoning the whole citation pass
        title = source.get("title") if isinstance(source, dict) else None
        citation_lines.append(f"[{idx}] [{title or url}]({url})")

    citation_list = "\n".join(citation_lines)
    return {"final_report": processed_text + "\n\n## 参考链接:\n" + citation_list}
|
||||
64
state.py
Normal file
64
state.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Annotated, List, TypedDict, Literal, Dict
|
||||
from pydantic import BaseModel, Field
|
||||
import operator
|
||||
|
||||
|
||||
# Schema of a single report section: the outline data (name, description,
# research flag) plus the written text.
class Section(BaseModel):
    # Section title
    name: str = Field(
        description="章节的标题,应简洁明了地概括本章节的主题或主旨。",
    )
    # Short summary of what the section will cover
    description: str = Field(
        description="简要说明这一章节将围绕哪些内容进行展开,字数在50-100字。描述应突出章节的写作重点,明确要讨论的问题、角度或结构。",
    )
    # Whether writing this section needs a web search
    research: bool = Field(
        description="判断该章节的写作是否需要进行网络搜索。若该部分需要引用外部资料、数据、案例或是知识储备不足以独立完成,则应标记为需要研究。"
    )
    # Main text of the section; no default is declared, so a value is required
    content: str = Field(
        description="章节的主要写作内容。暂时留空,后续将用于填写具体文本内容"
    )
|
||||
|
||||
|
||||
# Wrapper schema holding the list of sections that make up a report.
class Sections(BaseModel):
    # The report's sections
    sections: List[Section] = Field(
        description="包含报告的各个章节",
    )
|
||||
|
||||
|
||||
# Schema of one web search query.
class SearchQuery(BaseModel):
    # Query text; annotated as str but the declared default is None
    search_query: str = Field(None, description="网络搜索的查询")
|
||||
|
||||
|
||||
# Schema of a query-generation result: a thought plus the search queries.
class Queries(BaseModel):
    # Reasoning text; annotated as str but the declared default is None
    thought: str = Field(None, description="思考过程")
    # The generated search queries
    queries: List[SearchQuery] = Field(
        description="包含网络搜索的查询列表",
    )
|
||||
|
||||
|
||||
'''作为整张图的状态输入输出,输入是报告主题,输出是最终报告'''
|
||||
|
||||
|
||||
# Input shape of the whole graph: only the report topic.
class ReportStateInput(TypedDict):
    topic: str  # Report topic
|
||||
|
||||
|
||||
# Output shape of the whole graph: only the final report.
class ReportStateOutput(TypedDict):
    final_report: str  # Final report text
|
||||
|
||||
|
||||
'''管理报告的状态,有报告的各个章节'''
|
||||
|
||||
|
||||
# Working state of a report run. Fields wrapped in Annotated carry an operator
# (operator.add / operator.or_) as metadata next to their type.
class ReportState(TypedDict):
    topic: str  # Report topic
    feedback_on_report_plan: str  # Feedback text on the outline
    sections: list[Section]  # Sections of the outline
    # Keeps a chain of reasoning so later model steps can decide what to do
    thinking_process: Annotated[list[str], operator.add]
    completed_sections: Annotated[list, operator.add]  # Sections already written
    current_section_index: int  # Index of the section currently being researched
    final_report: str  # Final report
    search_queries: Annotated[Dict[str, List[SearchQuery]], operator.or_]  # Queries per key (str)
    # Keys are the sources; values are dicts keyed by title, url, etc.
    sources: Annotated[Dict[str, Dict[str, str]], operator.or_]
    source_str: str  # Search results as a single string
|
||||
Reference in New Issue
Block a user