# Files
# Deepresearch/app.py
# 2025-08-20 20:12:32 +08:00
#
# 516 lines
# 18 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import html
import json
import os
import re
import tempfile
import time

import streamlit as st

from configuration import SearchAPI
from nodes import (
    generate_report_plan,
    think_and_generate_queries,
    search_web,
    write_section,
    compile_final_report,
    process_citations,
    save_final_report,
)
from state import ReportState, Section
# 从 config.json 文件读取配置
import json
# Load the runtime configuration from config.json in the working directory.
with open("config.json", "r", encoding="utf-8") as config_file:
    config_data = json.load(config_file)

# Export the API keys as environment variables for the downstream clients.
_env_vars = config_data["environment_variables"]
for _key_name in ("DEEPSEEK_API_KEY", "TAVILY_API_KEY"):
    os.environ[_key_name] = _env_vars[_key_name]

# The "config_dict" section is the configuration handed to the report nodes.
config_dict = config_data["config_dict"]

# Overwrite the search API key placeholder with the key from the environment.
_configurable = config_dict["configurable"]
if "search_api_config" in _configurable:
    _search_cfg = _configurable["search_api_config"]
    if "api_key" in _search_cfg:
        _search_cfg["api_key"] = os.environ.get("TAVILY_API_KEY", "")
# --- Streamlit App ---
st.set_page_config(
    page_title="DeepResearch Stream",
    layout="wide",
    page_icon="🔬",
)

# Global stylesheet injected on every run: trims the default top padding of the
# main block container and defines the card / status-message / history classes
# referenced by the raw-HTML snippets rendered elsewhere in this app.
_APP_CSS = """
<style>
/* 移除默认的顶部margin */
.block-container {
padding-top: 1rem;
}
/* 主题颜色 */
:root {
--primary-color: #1f77b4;
--secondary-color: #ff7f0e;
--success-color: #2ca02c;
--warning-color: #d62728;
--background-color: #f5f5f5;
--card-background: #ffffff;
--text-color: #333333;
}
/* 侧边栏样式 */
[data-testid="stSidebar"] {
background-color: #2c3e50;
color: white;
}
/* 卡片样式 */
.report-card {
border: 1px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin-bottom: 20px;
background-color: var(--card-background);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.section-card {
border-left: 4px solid var(--primary-color);
margin-bottom: 15px;
}
.process-card {
background-color: #e8f4fd;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
}
/* 按钮样式 */
.stButton>button {
border-radius: 8px;
border: 1px solid transparent;
padding: 8px 16px;
font-size: 16px;
font-weight: 500;
transition: all 0.3s ease;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
/* 进度条样式 */
.stProgress > div > div {
background-color: var(--success-color);
}
/* 标题样式 */
h1, h2, h3 {
color: var(--primary-color);
margin-top: 0.5rem;
margin-bottom: 0.4rem;
}
/* 状态信息样式 */
.info-message {
padding: 10px;
border-radius: 5px;
margin: 10px 0;
}
.info-message.info {
background-color: #d1ecf1;
border: 1px solid #bee5eb;
color: #0c5460;
}
.info-message.success {
background-color: #d4edda;
border: 1px solid #c3e6cb;
color: #155724;
}
.info-message.warning {
background-color: #fff3cd;
border: 1px solid #ffeaa7;
color: #856404;
}
/* 侧边栏内容样式 */
.sidebar-content {
color: white;
padding: 10px;
}
.sidebar-content h3 {
color: #ecf0f1;
border-bottom: 1px solid #ecf0f1;
padding-bottom: 5px;
}
.sidebar-content ul {
padding-left: 20px;
}
.sidebar-content li {
margin-bottom: 10px;
}
/* 历史记录样式 */
.history-item {
background-color: #f8f9fa;
border-radius: 5px;
padding: 10px;
margin-bottom: 10px;
border-left: 3px solid #1f77b4;
}
.history-title {
font-weight: bold;
color: #1f77b4;
}
.history-time {
font-size: 0.8em;
color: #6c757d;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# --- 历史记录功能 ---
def load_history():
    """Load the report history list from ``reports/history.json``.

    The ``reports`` directory under the current working directory is created
    if it does not exist yet.

    Returns:
        The list stored in the history file. An empty list is returned when
        the file is missing, cannot be read or decoded, is not valid JSON, or
        holds something other than a JSON array.
    """
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    history_file = os.path.join(reports_dir, "history.json")
    if not os.path.exists(history_file):
        return []
    try:
        with open(history_file, "r", encoding="utf-8") as f:
            history = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError. A
        # damaged history file should degrade to "no history" rather than
        # raise out of every page render.
        return []
    # Callers insert into and slice the result, so anything that is not a
    # list (e.g. a hand-edited JSON object) is treated as empty history.
    return history if isinstance(history, list) else []
def save_history(history):
    """Persist *history* to ``reports/history.json`` atomically.

    The data is serialized into a temporary file in the same directory, which
    then replaces the real file via ``os.replace``. A failed or interrupted
    write therefore leaves the previous history file untouched instead of a
    truncated one.

    Args:
        history: JSON-serializable list of history entries.

    Raises:
        TypeError / ValueError: if *history* is not JSON-serializable (the
            existing file is left unchanged).
    """
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    history_file = os.path.join(reports_dir, "history.json")
    fd, tmp_path = tempfile.mkstemp(dir=reports_dir, prefix=".history-", suffix=".json.tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, history_file)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def add_to_history(topic, file_path):
    """Record a freshly generated report at the front of the history list.

    Only the 50 most recent entries are kept.

    Args:
        topic: Report topic to show in the sidebar history.
        file_path: Path of the saved Markdown report.
    """
    existing = load_history()
    entry = {
        "topic": topic,
        "file_path": file_path,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    # Newest first, capped at 50 entries.
    save_history([entry, *existing][:50])
# Sidebar: branding, usage notes, feature list and recent report history.
with st.sidebar:
    # NOTE(review): each st.markdown call renders as its own element, so this
    # opening <div> does not wrap the sidebar content that follows it.
    st.markdown('<div class="sidebar-content">', unsafe_allow_html=True)
    st.title("🔬 AutomaSynth 探策矩阵")
    st.markdown("### (星航电子工作室)")
    st.markdown("#### 📋 使用说明")
    st.markdown("""
1. 在主界面输入您想要研究的报告主题
2. 点击"生成报告大纲"按钮生成初始大纲
3. 编辑和确认大纲内容
4. 点击"确认修改并生成完整报告"开始生成报告
5. 等待报告生成完成,查看最终结果
""")
    st.markdown("#### ⚙️ 功能特点")
    st.markdown("""
- AI驱动的大纲生成
- 智能网络搜索
- 流式内容生成
- 实时进度跟踪
- 自动引用处理
""")
    # Recent report history (10 most recent entries).
    st.markdown("#### 📚 历史记录")
    history = load_history()
    if history:
        for idx, item in enumerate(history[:10]):
            topic_text = str(item.get("topic", ""))
            timestamp_text = str(item.get("timestamp", ""))
            file_path = item.get("file_path", "")
            with st.expander(f"📝 {topic_text}", expanded=False):
                # Build the whole card in ONE markdown call: a <div> opened in
                # one st.markdown call and closed in another wraps nothing.
                # The topic is user input persisted to disk, so escape it
                # before embedding it in raw HTML.
                st.markdown(
                    "<div class='history-item'>"
                    f"<div class='history-title'>{html.escape(topic_text)}</div>"
                    f"<div class='history-time'>{html.escape(timestamp_text)}</div>"
                    "</div>",
                    unsafe_allow_html=True,
                )
                if file_path and os.path.exists(file_path):
                    with open(file_path, "r", encoding="utf-8") as f:
                        report_text = f.read()
                    st.download_button(
                        label="📥 下载报告",
                        data=report_text,
                        file_name=os.path.basename(file_path),
                        mime="text/markdown",
                        # The index keeps keys unique even if the same path
                        # appears twice in the history (DuplicateWidgetID).
                        key=f"download_{idx}_{file_path}",
                    )
    else:
        st.info("暂无历史记录")
    st.markdown("#### 📞 联系我们")
    st.markdown("如有问题,请联系技术支持。")
    st.markdown('</div>', unsafe_allow_html=True)
# Seed the per-session flags on the first run; later reruns keep their values.
_SESSION_DEFAULTS = {
    "report_state": None,
    "generating": False,
    "outline_generated": False,
    "report_generated": False,
    "feedback_given": False,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# --- Helper Functions for Streamlit ---
async def run_generate_report_plan(state, config):
    """Await the ``generate_report_plan`` node and hand back its result.

    Thin async adapter so the Streamlit script can drive the node with
    ``asyncio.run``.

    Args:
        state: Report state dict (topic, sections, ...).
        config: Configuration mapping forwarded to the node.

    Returns:
        Whatever ``generate_report_plan`` returns.
    """
    plan = await generate_report_plan(state, config)
    return plan
def _apply_write_result(state, write_result, completed_sections):
    """Merge a ``write_section`` result into *state*; return True to stop writing sections.

    Two result shapes are handled: a plain dict (merged, then the section
    cursor advances) and an object exposing ``update`` (a dict of state
    changes) plus optionally ``goto`` (the next node's name).
    """
    if isinstance(write_result, dict):
        state.update(write_result)
        if "completed_sections" in write_result:
            completed_sections.extend(write_result["completed_sections"])
        state["current_section_index"] += 1
        return False
    if hasattr(write_result, "update"):
        update_data = write_result.update
        state.update(update_data)
        if "completed_sections" in update_data:
            completed_sections.extend(update_data["completed_sections"])
        return getattr(write_result, "goto", None) == "compile_final_report"
    return False


def _report_output_path(topic):
    """Return a filesystem-safe ``reports/report_<slug>_<timestamp>.md`` path for *topic*."""
    # The topic is free user input. Replace every run of characters that is
    # neither a word character nor '-' (path separators, ':', '?', '.', ...)
    # with '_' so the file name is valid and cannot leave the reports directory.
    topic_slug = re.sub(r"[^\w\-]+", "_", topic.lower().replace(" ", "_"))[:30]
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    return os.path.join(reports_dir, f"report_{topic_slug}_{timestamp}.md")


async def run_generate_full_report(state, config):
    """Write every outlined section, compile the report and save it to disk.

    Logic ported from ``generate_report_sections`` / ``compile_report`` /
    ``save_report`` in ``main.py``. For each section it shows the model's
    thinking process and search queries in expanders, runs the web search,
    and streams the written section into the page.

    Args:
        state: Report state dict with at least ``topic``, ``sections`` and
            ``current_section_index``; it is updated in place.
        config: Configuration mapping forwarded to the nodes.

    Returns:
        The same ``state`` dict, now containing ``final_report``.

    Side effects:
        Writes the Markdown report under ``reports/``, stores its path in
        ``st.session_state.report_file_path`` and adds it to the history.
    """
    completed_sections = []
    total_sections = len(state["sections"])

    # Progress bar and status line shown above the per-section output.
    progress_container = st.container()
    with progress_container:
        progress_bar = st.progress(0)
        status_text = st.empty()
    # One placeholder per section so sections render in outline order.
    section_placeholders = [st.empty() for _ in range(total_sections)]

    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]
        # Section names come from the model / the user's edits; escape them
        # before embedding them in raw HTML.
        section_title = html.escape(str(current_section.name))
        status_text.markdown(
            f"<div class='info-message info'>正在生成章节 {current_index + 1}/{total_sections}: {section_title}</div>",
            unsafe_allow_html=True,
        )

        with section_placeholders[current_index].container():
            st.markdown(
                f"<div class='report-card section-card'><h3>章节 {current_index + 1}: {section_title}</h3></div>",
                unsafe_allow_html=True,
            )
            # 1. Thinking process.
            # NOTE(review): called without await, unlike the other nodes —
            # confirm think_and_generate_queries is synchronous.
            with st.expander("🧠 思考过程", expanded=False):
                think_result = think_and_generate_queries(state, config)
                state.update(think_result)
                st.write(think_result.get("thinking_process", "没有记录思考过程。"))
            # 2. Search queries generated for this section.
            with st.expander("🔍 搜索查询", expanded=False):
                search_queries = think_result.get("search_queries", {}).get(current_section.name, [])
                if search_queries:
                    for query in search_queries:
                        st.code(query.search_query, language="text")
                else:
                    st.write("没有为本章节生成搜索查询。")
            search_result = await search_web(state, config)
            state.update(search_result)
            # 3. Stream the section text as it is generated.
            content_placeholder = st.empty()
            full_content_for_section = ""

            async def stream_callback_for_section(chunk):
                nonlocal full_content_for_section
                full_content_for_section += chunk
                content_placeholder.markdown(f"<div class='process-card'>{full_content_for_section} ▌</div>", unsafe_allow_html=True)

            write_result = await write_section(state, config, stream_callback=stream_callback_for_section)
            # Writing finished: redraw without the trailing cursor.
            content_placeholder.markdown(f"<div class='process-card'>{full_content_for_section}</div>", unsafe_allow_html=True)

        finished = _apply_write_result(state, write_result, completed_sections)
        # Guard against an endless loop: if the result did not move the cursor
        # (and did not ask to stop), advance to the next section ourselves.
        if not finished and state["current_section_index"] == current_index:
            state["current_section_index"] = current_index + 1
        progress_bar.progress((current_index + 1) / total_sections)
        if finished:
            break

    state["completed_sections"] = completed_sections
    status_text.markdown("<div class='info-message info'>所有章节完成,正在编译最终报告...</div>", unsafe_allow_html=True)
    compile_result = await compile_final_report(state)
    state.update(compile_result)
    citations_result = process_citations(state)
    state.update(citations_result)
    status_text.markdown("<div class='info-message success'>✅ 报告生成完毕!</div>", unsafe_allow_html=True)

    # Save the report under ./reports (relative to the working directory).
    output_file = _report_output_path(state["topic"])
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(state["final_report"])
    st.session_state.report_file_path = output_file
    add_to_history(state["topic"], output_file)
    return state
# --- UI Components ---
# 直接在主区域添加内容,不使用额外的容器
# 主要内容区域
# 主页面标题
# Centered page title (icon + product name) rendered as raw HTML.
# The studio subtitle is shown in the sidebar rather than here.
_PAGE_TITLE_HTML = """
<div style='
text-align: center;
margin: 20px 0;
font-size: 2.2em;
color: #1f77b4;
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
font-weight: bold;
'>
<span>🔬</span>
<span>AutomaSynth 探策矩阵</span>
</div>
"""
st.markdown(_PAGE_TITLE_HTML, unsafe_allow_html=True)
# 1. Topic input and outline generation.
st.markdown("<div class='report-card'>", unsafe_allow_html=True)
st.markdown("#### 📝 报告主题")
# Start empty so the example only appears as a placeholder hint. Using the
# example as the default *value* meant that clicking the button straight away
# researched the literal text including its "例如:" ("e.g.:") prefix.
topic = st.text_input("请输入您想要研究的报告主题", value="",
                      placeholder="例如:人工智能在医疗领域的应用")
col1, col2 = st.columns([1, 4])
with col1:
    generate_outline_btn = st.button("生成报告大纲", disabled=st.session_state.generating,
                                     type="primary", use_container_width=True)
st.markdown("</div>", unsafe_allow_html=True)

if generate_outline_btn:
    if not topic.strip():
        st.warning("请先输入报告主题。")
    else:
        st.session_state.generating = True
        st.session_state.outline_generated = False
        st.session_state.report_generated = False
        st.session_state.feedback_given = False
        initial_state = {
            "topic": topic.strip(),
            "feedback_on_report_plan": "",
            "sections": [],
            "thinking_process": [],
            "completed_sections": [],
            "current_section_index": 0,
            "final_report": "",
            "search_queries": {},
            "sources": {},
            "source_str": ""
        }
        with st.spinner("正在生成报告大纲..."):
            try:
                plan_result = asyncio.run(run_generate_report_plan(initial_state, config_dict))
                initial_state.update(plan_result)
                st.session_state.report_state = initial_state
                st.session_state.outline_generated = True
            except Exception as e:
                # Top-level UI boundary: show the error and do NOT rerun —
                # an immediate rerun would wipe the message before it is seen.
                st.error(f"生成报告大纲时出错:{str(e)}")
            else:
                # Rerun only on success so the page redraws with the outline.
                st.rerun(scope="app")
            finally:
                # Runs before the rerun takes effect, on both paths.
                st.session_state.generating = False
# 2. Editable outline and full-report generation.
if st.session_state.outline_generated and not st.session_state.report_generated:
    st.markdown("<div class='report-card'>", unsafe_allow_html=True)
    st.markdown("#### 📋 报告大纲 (可在此直接编辑)")
    with st.form(key='outline_form'):
        # Editable main report title.
        st.session_state.report_state['topic'] = st.text_input(
            "报告总标题",
            value=st.session_state.report_state.get('topic', '人工智能在医疗领域的应用')
        )
        sections = st.session_state.report_state.get("sections", [])
        # One editable name/description pair per section; unique keys avoid
        # Streamlit's DuplicateWidgetID error.
        for i, section in enumerate(sections):
            st.markdown("<div class='section-card'>", unsafe_allow_html=True)
            st.markdown(f"##### 章节 {i+1}")
            section.name = st.text_input("章节名称", value=section.name, key=f"name_{i}")
            section.description = st.text_area("章节描述", value=section.description, key=f"desc_{i}", height=100)
            st.markdown("</div>", unsafe_allow_html=True)
        col1, col2 = st.columns([1, 4])
        with col1:
            submitted = st.form_submit_button("确认修改并生成完整报告",
                                              disabled=st.session_state.generating,
                                              type="primary", use_container_width=True)
    st.markdown("</div>", unsafe_allow_html=True)

    if submitted:
        # The (possibly edited) outline is confirmed: generate the full report.
        st.session_state.generating = True
        st.session_state.feedback_given = True  # Prevent resubmission
        # run_generate_full_report mutates the state it is given (it advances
        # current_section_index). Run on a copy with the cursor reset so a retry
        # after a failed run starts again from the first section instead of
        # silently skipping the sections before the failure.
        run_state = dict(st.session_state.report_state)
        run_state["current_section_index"] = 0
        run_state["completed_sections"] = []
        with st.spinner("正在生成完整报告..."):
            try:
                final_state = asyncio.run(run_generate_full_report(run_state, config_dict))
                st.session_state.report_state = final_state
                st.session_state.report_generated = True
            except Exception as e:
                # Show the error and stay on this page; rerunning here would
                # clear the message immediately.
                st.error(f"生成报告时出错:{str(e)}")
            else:
                st.rerun(scope="app")
            finally:
                st.session_state.generating = False
# 3. Show the finished report plus a download button for the saved file.
if st.session_state.report_generated:
    st.markdown("<div class='report-card'>", unsafe_allow_html=True)
    st.markdown("#### 📄 最终报告")
    st.markdown(st.session_state.report_state.get('final_report', '报告生成失败。'))
    st.markdown("</div>", unsafe_allow_html=True)

    saved_path = st.session_state.get("report_file_path", "")
    if saved_path and os.path.exists(saved_path):
        with open(saved_path, "r", encoding="utf-8") as report_file:
            saved_markdown = report_file.read()
        st.download_button(
            label="📥 下载报告 (Markdown)",
            data=saved_markdown,
            file_name=os.path.basename(saved_path),
            mime="text/markdown",
            type="primary",
        )
# Bottom-of-page reset: discard all session state and start over.
# NOTE(review): indentation was lost in this copy; rendered at top level here —
# confirm it was not meant to appear only once a report has been generated.
st.markdown("---")
reset_clicked = st.button("🔄 开始新的报告", type="secondary")
if reset_clicked:
    st.session_state.clear()
    st.rerun(scope="app")