上传文件至 /

This commit is contained in:
2025-08-20 20:12:32 +08:00
parent 159c4efd99
commit 30af25082d
3 changed files with 880 additions and 0 deletions

516
app.py Normal file
View File

@@ -0,0 +1,516 @@
import streamlit as st
import asyncio
import os
import json
import time
from state import ReportState, Section
from configuration import SearchAPI
from nodes import (
generate_report_plan,
think_and_generate_queries,
search_web,
write_section,
compile_final_report,
process_citations,
save_final_report
)
# 从 config.json 文件读取配置
import json
with open("config.json", "r", encoding="utf-8") as f:
config_data = json.load(f)
# 设置环境变量 - API密钥
os.environ["DEEPSEEK_API_KEY"] = config_data["environment_variables"]["DEEPSEEK_API_KEY"]
os.environ["TAVILY_API_KEY"] = config_data["environment_variables"]["TAVILY_API_KEY"]
# 使用字典作为配置
config_dict = config_data["config_dict"]
# 替换配置中的环境变量占位符
if "search_api_config" in config_dict["configurable"] and "api_key" in config_dict["configurable"]["search_api_config"]:
config_dict["configurable"]["search_api_config"]["api_key"] = os.environ.get("TAVILY_API_KEY", "")
# --- Streamlit App ---
st.set_page_config(page_title="DeepResearch Stream", layout="wide", page_icon="🔬")
# 简化CSS样式移除可能导致空白的margin
st.markdown("""
<style>
/* 移除默认的顶部margin */
.block-container {
padding-top: 1rem;
}
/* 主题颜色 */
:root {
--primary-color: #1f77b4;
--secondary-color: #ff7f0e;
--success-color: #2ca02c;
--warning-color: #d62728;
--background-color: #f5f5f5;
--card-background: #ffffff;
--text-color: #333333;
}
/* 侧边栏样式 */
[data-testid="stSidebar"] {
background-color: #2c3e50;
color: white;
}
/* 卡片样式 */
.report-card {
border: 1px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin-bottom: 20px;
background-color: var(--card-background);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.section-card {
border-left: 4px solid var(--primary-color);
margin-bottom: 15px;
}
.process-card {
background-color: #e8f4fd;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
}
/* 按钮样式 */
.stButton>button {
border-radius: 8px;
border: 1px solid transparent;
padding: 8px 16px;
font-size: 16px;
font-weight: 500;
transition: all 0.3s ease;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
/* 进度条样式 */
.stProgress > div > div {
background-color: var(--success-color);
}
/* 标题样式 */
h1, h2, h3 {
color: var(--primary-color);
margin-top: 0.5rem;
margin-bottom: 0.4rem;
}
/* 状态信息样式 */
.info-message {
padding: 10px;
border-radius: 5px;
margin: 10px 0;
}
.info-message.info {
background-color: #d1ecf1;
border: 1px solid #bee5eb;
color: #0c5460;
}
.info-message.success {
background-color: #d4edda;
border: 1px solid #c3e6cb;
color: #155724;
}
.info-message.warning {
background-color: #fff3cd;
border: 1px solid #ffeaa7;
color: #856404;
}
/* 侧边栏内容样式 */
.sidebar-content {
color: white;
padding: 10px;
}
.sidebar-content h3 {
color: #ecf0f1;
border-bottom: 1px solid #ecf0f1;
padding-bottom: 5px;
}
.sidebar-content ul {
padding-left: 20px;
}
.sidebar-content li {
margin-bottom: 10px;
}
/* 历史记录样式 */
.history-item {
background-color: #f8f9fa;
border-radius: 5px;
padding: 10px;
margin-bottom: 10px;
border-left: 3px solid #1f77b4;
}
.history-title {
font-weight: bold;
color: #1f77b4;
}
.history-time {
font-size: 0.8em;
color: #6c757d;
}
</style>
""", unsafe_allow_html=True)
# --- 历史记录功能 ---
def load_history():
"""加载历史记录"""
reports_dir = os.path.join(os.getcwd(), "reports")
os.makedirs(reports_dir, exist_ok=True)
history_file = os.path.join(reports_dir, "history.json")
if os.path.exists(history_file):
with open(history_file, "r", encoding="utf-8") as f:
return json.load(f)
return []
def save_history(history):
"""保存历史记录"""
reports_dir = os.path.join(os.getcwd(), "reports")
os.makedirs(reports_dir, exist_ok=True)
history_file = os.path.join(reports_dir, "history.json")
with open(history_file, "w", encoding="utf-8") as f:
json.dump(history, f, ensure_ascii=False, indent=2)
def add_to_history(topic, file_path):
"""添加到历史记录"""
history = load_history()
history.insert(0, {
"topic": topic,
"file_path": file_path,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
})
# 只保留最近50条记录
history = history[:50]
save_history(history)
# 侧边栏内容
with st.sidebar:
st.markdown('<div class="sidebar-content">', unsafe_allow_html=True)
st.title("🔬 AutomaSynth 探策矩阵")
st.markdown("### (星航电子工作室)")
st.markdown("#### 📋 使用说明")
st.markdown("""
1. 在主界面输入您想要研究的报告主题
2. 点击"生成报告大纲"按钮生成初始大纲
3. 编辑和确认大纲内容
4. 点击"确认修改并生成完整报告"开始生成报告
5. 等待报告生成完成,查看最终结果
""")
st.markdown("#### ⚙️ 功能特点")
st.markdown("""
- AI驱动的大纲生成
- 智能网络搜索
- 流式内容生成
- 实时进度跟踪
- 自动引用处理
""")
# 历史记录部分
st.markdown("#### 📚 历史记录")
history = load_history()
if history:
for item in history[:10]: # 显示最近10条
with st.expander(f"📝 {item['topic']}", expanded=False):
st.markdown(f"<div class='history-item'>", unsafe_allow_html=True)
st.markdown(f"<div class='history-title'>{item['topic']}</div>", unsafe_allow_html=True)
st.markdown(f"<div class='history-time'>{item['timestamp']}</div>", unsafe_allow_html=True)
if os.path.exists(item['file_path']):
with open(item['file_path'], "r", encoding="utf-8") as f:
st.download_button(
label="📥 下载报告",
data=f.read(),
file_name=os.path.basename(item['file_path']),
mime="text/markdown",
key=f"download_{item['file_path']}"
)
st.markdown("</div>", unsafe_allow_html=True)
else:
st.info("暂无历史记录")
st.markdown("#### 📞 联系我们")
st.markdown("如有问题,请联系技术支持。")
st.markdown('</div>', unsafe_allow_html=True)
# 初始化 session state
if 'report_state' not in st.session_state:
st.session_state.report_state = None
if 'generating' not in st.session_state:
st.session_state.generating = False
if 'outline_generated' not in st.session_state:
st.session_state.outline_generated = False
if 'report_generated' not in st.session_state:
st.session_state.report_generated = False
if 'feedback_given' not in st.session_state:
st.session_state.feedback_given = False
# --- Helper Functions for Streamlit ---
async def run_generate_report_plan(state, config):
return await generate_report_plan(state, config)
async def run_generate_full_report(state, config):
# 逻辑从 main.py 的 generate_report_sections, compile_report, save_report 迁移过来
completed_sections = []
total_sections = len(state["sections"])
# 创建一个容器来显示进度和状态
progress_container = st.container()
with progress_container:
progress_bar = st.progress(0)
status_text = st.empty()
# 为每个章节的输出创建一个容器
section_placeholders = [st.empty() for _ in range(total_sections)]
while state["current_section_index"] < total_sections:
current_index = state["current_section_index"]
current_section = state["sections"][current_index]
# 更新状态信息
status_text.markdown(f"<div class='info-message info'>正在生成章节 {current_index + 1}/{total_sections}: {current_section.name}</div>", unsafe_allow_html=True)
# 获取当前章节的UI容器
section_container = section_placeholders[current_index]
with section_container.container():
st.markdown(f"<div class='report-card section-card'><h3>章节 {current_index + 1}: {current_section.name}</h3></div>", unsafe_allow_html=True)
# 1. 展示思考过程
with st.expander("🧠 思考过程", expanded=False):
think_result = think_and_generate_queries(state, config)
state.update(think_result)
st.write(think_result.get("thinking_process", "没有记录思考过程。"))
# 2. 展示搜索查询
with st.expander("🔍 搜索查询", expanded=False):
search_queries = think_result.get("search_queries", {}).get(current_section.name, [])
if search_queries:
for query in search_queries:
st.code(query.search_query, language="text")
else:
st.write("没有为本章节生成搜索查询。")
search_result = await search_web(state, config)
state.update(search_result)
# 3. 流式展示生成的章节内容
content_placeholder = st.empty()
full_content_for_section = ""
async def stream_callback_for_section(chunk):
nonlocal full_content_for_section
full_content_for_section += chunk
content_placeholder.markdown(f"<div class='process-card'>{full_content_for_section} ▌</div>", unsafe_allow_html=True)
write_result = await write_section(state, config, stream_callback=stream_callback_for_section)
# 写入完成后,移除光标
content_placeholder.markdown(f"<div class='process-card'>{full_content_for_section}</div>", unsafe_allow_html=True)
if isinstance(write_result, dict):
state.update(write_result)
if 'completed_sections' in write_result:
completed_sections.extend(write_result['completed_sections'])
state["current_section_index"] += 1
elif hasattr(write_result, 'update'):
update_data = write_result.update
state.update(update_data)
if 'completed_sections' in update_data:
completed_sections.extend(update_data['completed_sections'])
if hasattr(write_result, 'goto') and write_result.goto == "compile_final_report":
break
progress = (current_index + 1) / total_sections
progress_bar.progress(progress)
state["completed_sections"] = completed_sections
status_text.markdown("<div class='info-message info'>所有章节完成,正在编译最终报告...</div>", unsafe_allow_html=True)
compile_result = await compile_final_report(state)
state.update(compile_result)
citations_result = process_citations(state)
state.update(citations_result)
status_text.markdown("<div class='info-message success'>✅ 报告生成完毕!</div>", unsafe_allow_html=True)
# 保存文件 - 使用相对路径
topic_slug = state["topic"].lower().replace(" ", "_")[:30]
timestamp = time.strftime("%Y%m%d_%H%M%S")
reports_dir = os.path.join(os.getcwd(), "reports")
os.makedirs(reports_dir, exist_ok=True)
output_file = os.path.join(reports_dir, f"report_{topic_slug}_{timestamp}.md")
with open(output_file, "w", encoding="utf-8") as f:
f.write(state["final_report"])
st.session_state.report_file_path = output_file
# 添加到历史记录
add_to_history(state["topic"], output_file)
return state
# --- UI Components ---
# 直接在主区域添加内容,不使用额外的容器
# 主要内容区域
# 主页面标题
st.markdown("""
<div style='
text-align: center;
margin: 20px 0;
font-size: 2.2em;
color: #1f77b4;
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
font-weight: bold;
'>
<span>🔬</span>
<span>AutomaSynth 探策矩阵</span>
</div>
""", unsafe_allow_html=True)
#st.markdown("<h3 style='text-align: center; color: #555;'>(星航电子工作室)</h3>", unsafe_allow_html=True)
# 1. Topic Input
st.markdown("<div class='report-card'>", unsafe_allow_html=True)
st.markdown("#### 📝 报告主题")
topic = st.text_input("请输入您想要研究的报告主题", value="例如:人工智能在医疗领域的应用",
placeholder="例如:人工智能在医疗领域的应用")
col1, col2 = st.columns([1, 4])
with col1:
generate_outline_btn = st.button("生成报告大纲", disabled=st.session_state.generating,
type="primary", use_container_width=True)
st.markdown("</div>", unsafe_allow_html=True)
if generate_outline_btn:
st.session_state.generating = True
st.session_state.outline_generated = False
st.session_state.report_generated = False
st.session_state.feedback_given = False
initial_state = {
"topic": topic,
"feedback_on_report_plan": "",
"sections": [],
"thinking_process": [],
"completed_sections": [],
"current_section_index": 0,
"final_report": "",
"search_queries": {},
"sources": {},
"source_str": ""
}
with st.spinner("正在生成报告大纲..."):
try:
plan_result = asyncio.run(run_generate_report_plan(initial_state, config_dict))
initial_state.update(plan_result)
st.session_state.report_state = initial_state
st.session_state.outline_generated = True
except Exception as e:
st.error(f"生成报告大纲时出错:{str(e)}")
finally:
st.session_state.generating = False
st.rerun(scope="app")
# 2. Outline Display and Feedback
if st.session_state.outline_generated and not st.session_state.report_generated:
st.markdown("<div class='report-card'>", unsafe_allow_html=True)
st.markdown("#### 📋 报告大纲 (可在此直接编辑)")
with st.form(key='outline_form'):
# Add a field for the main report title
st.session_state.report_state['topic'] = st.text_input(
"报告总标题",
value=st.session_state.report_state.get('topic', '人工智能在医疗领域的应用')
)
sections = st.session_state.report_state.get("sections", [])
# Create UI for each section to be editable
for i, section in enumerate(sections):
st.markdown(f"<div class='section-card'>", unsafe_allow_html=True)
st.markdown(f"##### 章节 {i+1}")
# Use unique keys for each widget to avoid Streamlit's DuplicateWidgetID error
section.name = st.text_input("章节名称", value=section.name, key=f"name_{i}")
section.description = st.text_area("章节描述", value=section.description, key=f"desc_{i}", height=100)
st.markdown("</div>", unsafe_allow_html=True)
col1, col2 = st.columns([1, 4])
with col1:
submitted = st.form_submit_button("确认修改并生成完整报告",
disabled=st.session_state.generating,
type="primary", use_container_width=True)
st.markdown("</div>", unsafe_allow_html=True)
if submitted:
# User has confirmed the (potentially edited) outline, proceed to generate full report
st.session_state.generating = True
st.session_state.feedback_given = True # Prevent resubmission
with st.spinner("正在生成完整报告..."):
try:
final_state = asyncio.run(run_generate_full_report(st.session_state.report_state, config_dict))
st.session_state.report_state = final_state
st.session_state.report_generated = True
except Exception as e:
st.error(f"生成报告时出错:{str(e)}")
finally:
st.session_state.generating = False
st.rerun(scope="app")
# 3. Final Report Display
if st.session_state.report_generated:
st.markdown("<div class='report-card'>", unsafe_allow_html=True)
st.markdown("#### 📄 最终报告")
final_report_content = st.session_state.report_state.get('final_report', '报告生成失败。')
st.markdown(final_report_content)
st.markdown("</div>", unsafe_allow_html=True)
report_path = st.session_state.get("report_file_path", "")
if report_path and os.path.exists(report_path):
with open(report_path, "r", encoding="utf-8") as f:
st.download_button(
label="📥 下载报告 (Markdown)",
data=f.read(),
file_name=os.path.basename(report_path),
mime="text/markdown",
type="primary"
)
# 底部按钮
st.markdown("---")
if st.button("🔄 开始新的报告", type="secondary"):
st.session_state.clear()
st.rerun(scope="app")