import streamlit as st
import asyncio
import os
import json
import time

from state import ReportState, Section
from configuration import SearchAPI
from nodes import (
    generate_report_plan,
    think_and_generate_queries,
    search_web,
    write_section,
    compile_final_report,
    process_citations,
    save_final_report,
)

# Read the application configuration from config.json (resolved against the
# current working directory).
with open("config.json", "r", encoding="utf-8") as f:
    config_data = json.load(f)

# Export the API keys from the config file as environment variables.
os.environ["DEEPSEEK_API_KEY"] = config_data["environment_variables"]["DEEPSEEK_API_KEY"]
os.environ["TAVILY_API_KEY"] = config_data["environment_variables"]["TAVILY_API_KEY"]

# The configuration is used as a plain dictionary.
config_dict = config_data["config_dict"]

# Replace the API-key placeholder in the search configuration (when the config
# declares one) with the key exported above.
if "search_api_config" in config_dict["configurable"] and "api_key" in config_dict["configurable"]["search_api_config"]:
    config_dict["configurable"]["search_api_config"]["api_key"] = os.environ.get("TAVILY_API_KEY", "")

# --- Streamlit App ---
st.set_page_config(page_title="DeepResearch Stream", layout="wide", page_icon="🔬")

# Simplified CSS styling; margins that could cause blank gaps were removed.
st.markdown(""" """, unsafe_allow_html=True)


# --- History feature ---
def _history_file_path():
    """Return the path of ``reports/history.json`` under the working directory.

    The ``reports`` directory is created if it does not exist yet, so callers
    can open the returned path for writing straight away.
    """
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    return os.path.join(reports_dir, "history.json")


def load_history():
    """Load the report history.

    Returns:
        The list stored in ``reports/history.json``, or an empty list when the
        file is missing, unreadable as UTF-8 JSON, or does not hold a list.
    """
    try:
        with open(_history_file_path(), "r", encoding="utf-8") as f:
            history = json.load(f)
    except FileNotFoundError:
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A truncated or hand-edited history file must not abort the caller:
        # add_to_history() runs right after a report has been written, and a
        # crash here would make a successful generation look like a failure.
        return []
    # add_to_history() inserts at index 0, so anything but a list is unusable.
    return history if isinstance(history, list) else []


def save_history(history):
    """Write *history* to ``reports/history.json`` as indented UTF-8 JSON."""
    with open(_history_file_path(), "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)


def add_to_history(topic, file_path):
    """Record a generated report at the front of the history.

    Args:
        topic: Topic of the report.
        file_path: Path of the saved report file.
    """
    history = load_history()
    history.insert(0, {
        "topic": topic,
        "file_path": file_path,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    })
    # Keep only the 50 most recent entries.
    save_history(history[:50])


# Sidebar content
with st.sidebar:
    st.markdown('', unsafe_allow_html=True)
# Initialise session state: every flag read further down must exist before
# its first use in a script run.
_SESSION_DEFAULTS = {
    'report_state': None,
    'generating': False,
    'outline_generated': False,
    'report_generated': False,
    'feedback_given': False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Characters that are path separators, reserved in file names on common
# platforms, or whitespace; each one is replaced by "_" in report file names.
_UNSAFE_FILENAME_CHARS = str.maketrans({ch: "_" for ch in '\\/:*?"<>| \t\r\n'})


# --- Helper Functions for Streamlit ---
def _topic_to_slug(topic, max_length=30):
    """Turn a report topic into a fragment that is safe inside a file name.

    The topic is lower-cased, unsafe characters are replaced by underscores and
    the result is truncated to *max_length* characters.
    """
    return topic.lower().translate(_UNSAFE_FILENAME_CHARS)[:max_length]


async def run_generate_report_plan(state, config):
    """Await ``generate_report_plan`` and return its result unchanged."""
    return await generate_report_plan(state, config)


async def run_generate_full_report(state, config):
    """Generate every section, compile the report and save it to disk.

    The logic was migrated from main.py (generate_report_sections,
    compile_report, save_report) and renders progress into the Streamlit page
    while it runs.

    Args:
        state: Mutable report state. This function reads "sections",
            "current_section_index", "topic" and "final_report" and updates
            the mapping in place with the results returned by each node.
        config: Configuration passed through to the node functions.

    Returns:
        The same *state* object after all updates.

    Raises:
        TypeError: If ``write_section`` returns something that is neither a
            dict nor an object with an ``update`` attribute.

    Side effects:
        Writes the report as Markdown into ``reports/`` under the working
        directory, stores the path in ``st.session_state.report_file_path``
        and records the report via ``add_to_history``.
    """
    completed_sections = []
    total_sections = len(state["sections"])

    # Container holding the progress bar and the status line.
    progress_container = st.container()
    with progress_container:
        progress_bar = st.progress(0)
        status_text = st.empty()

    # One placeholder per section for that section's output.
    section_placeholders = [st.empty() for _ in range(total_sections)]

    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]

        # Update the status line.
        status_text.markdown(
            f"\n正在生成章节 {current_index + 1}/{total_sections}: {current_section.name}\n",
            unsafe_allow_html=True,
        )

        # UI container of the current section.
        section_container = section_placeholders[current_index]
        with section_container.container():
            st.markdown(
                f"\n\n章节 {current_index + 1}: {current_section.name}\n\n",
                unsafe_allow_html=True,
            )

            # 1. Show the thinking process.
            # NOTE(review): called without await, unlike the other nodes —
            # confirm think_and_generate_queries is a plain function.
            with st.expander("🧠 思考过程", expanded=False):
                think_result = think_and_generate_queries(state, config)
                state.update(think_result)
                st.write(think_result.get("thinking_process", "没有记录思考过程。"))

            # 2. Show the search queries generated for this section.
            with st.expander("🔍 搜索查询", expanded=False):
                search_queries = think_result.get("search_queries", {}).get(current_section.name, [])
                if search_queries:
                    for query in search_queries:
                        st.code(query.search_query, language="text")
                else:
                    st.write("没有为本章节生成搜索查询。")

            search_result = await search_web(state, config)
            state.update(search_result)

            # 3. Stream the section content into the page as it is produced.
            content_placeholder = st.empty()
            full_content_for_section = ""

            async def stream_callback_for_section(chunk):
                # Accumulate the text and re-render it with a trailing cursor.
                nonlocal full_content_for_section
                full_content_for_section += chunk
                content_placeholder.markdown(
                    f"\n{full_content_for_section} ▌\n",
                    unsafe_allow_html=True,
                )

            write_result = await write_section(
                state, config, stream_callback=stream_callback_for_section
            )

            # Writing finished: render once more without the cursor.
            content_placeholder.markdown(
                f"\n{full_content_for_section}\n",
                unsafe_allow_html=True,
            )

        if isinstance(write_result, dict):
            state.update(write_result)
            if 'completed_sections' in write_result:
                completed_sections.extend(write_result['completed_sections'])
            state["current_section_index"] += 1
        elif hasattr(write_result, 'update'):
            # Command-like result: the payload lives in its `update` attribute.
            # The index is not advanced here; presumably the payload carries
            # the new "current_section_index" — verify against write_section.
            update_data = write_result.update
            state.update(update_data)
            if 'completed_sections' in update_data:
                completed_sections.extend(update_data['completed_sections'])
            if hasattr(write_result, 'goto') and write_result.goto == "compile_final_report":
                break
        else:
            # Nothing would advance the index for such a result, so the loop
            # would regenerate the same section forever; fail loudly instead.
            raise TypeError(
                f"write_section returned unsupported result type: {type(write_result).__name__}"
            )

        progress = (current_index + 1) / total_sections
        progress_bar.progress(progress)

    state["completed_sections"] = completed_sections

    status_text.markdown("\n所有章节完成,正在编译最终报告...\n", unsafe_allow_html=True)
    compile_result = await compile_final_report(state)
    state.update(compile_result)

    # NOTE(review): process_citations is also called without await — confirm.
    citations_result = process_citations(state)
    state.update(citations_result)

    status_text.markdown("\n✅ 报告生成完毕!\n", unsafe_allow_html=True)

    # Save the report under ./reports. The topic is user input, so it is
    # sanitised before being embedded in the file name.
    topic_slug = _topic_to_slug(state["topic"])
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    reports_dir = os.path.join(os.getcwd(), "reports")
    os.makedirs(reports_dir, exist_ok=True)
    output_file = os.path.join(reports_dir, f"report_{topic_slug}_{timestamp}.md")
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(state["final_report"])
    st.session_state.report_file_path = output_file

    # Record the report in the history.
    add_to_history(state["topic"], output_file)

    return state


# --- UI Components ---
# Content is added directly to the main area, without extra containers.

# Main page title
st.markdown("""
🔬 AutomaSynth 探策矩阵
""", unsafe_allow_html=True)
# NOTE: a studio credit line used to be rendered here; it is disabled.

# 1. Topic Input
st.markdown("\n", unsafe_allow_html=True)
st.markdown("#### 📝 报告主题")
topic = st.text_input(
    "请输入您想要研究的报告主题",
    value="例如:人工智能在医疗领域的应用",
    placeholder="例如:人工智能在医疗领域的应用",
)

col1, col2 = st.columns([1, 4])
with col1:
    generate_outline_btn = st.button(
        "生成报告大纲",
        disabled=st.session_state.generating,
        type="primary",
        use_container_width=True,
    )
st.markdown("\n", unsafe_allow_html=True)

if generate_outline_btn:
    st.session_state.generating = True
    st.session_state.outline_generated = False
    st.session_state.report_generated = False
    st.session_state.feedback_given = False

    initial_state = {
        "topic": topic,
        "feedback_on_report_plan": "",
        "sections": [],
        "thinking_process": [],
        "completed_sections": [],
        "current_section_index": 0,
        "final_report": "",
        "search_queries": {},
        "sources": {},
        "source_str": "",
    }

    outline_ready = False
    with st.spinner("正在生成报告大纲..."):
        try:
            plan_result = asyncio.run(run_generate_report_plan(initial_state, config_dict))
            initial_state.update(plan_result)
            st.session_state.report_state = initial_state
            st.session_state.outline_generated = True
            outline_ready = True
        except Exception as e:
            st.error(f"生成报告大纲时出错:{str(e)}")
        finally:
            st.session_state.generating = False
    # Rerun only on success: st.rerun() halts the current run and redraws the
    # page, which would discard the error message shown above.
    if outline_ready:
        st.rerun(scope="app")

# 2. Outline Display and Feedback
if st.session_state.outline_generated and not st.session_state.report_generated:
    st.markdown("\n", unsafe_allow_html=True)
    st.markdown("#### 📋 报告大纲 (可在此直接编辑)")

    with st.form(key='outline_form'):
        # Add a field for the main report title
        st.session_state.report_state['topic'] = st.text_input(
            "报告总标题",
            value=st.session_state.report_state.get('topic', '人工智能在医疗领域的应用'),
        )

        sections = st.session_state.report_state.get("sections", [])

        # Create UI for each section to be editable
        for i, section in enumerate(sections):
            st.markdown("\n", unsafe_allow_html=True)
            st.markdown(f"##### 章节 {i+1}")
            # Use unique keys for each widget to avoid Streamlit's DuplicateWidgetID error
            section.name = st.text_input("章节名称", value=section.name, key=f"name_{i}")
            section.description = st.text_area(
                "章节描述", value=section.description, key=f"desc_{i}", height=100
            )
            st.markdown("\n", unsafe_allow_html=True)

        col1, col2 = st.columns([1, 4])
        with col1:
            submitted = st.form_submit_button(
                "确认修改并生成完整报告",
                disabled=st.session_state.generating,
                type="primary",
                use_container_width=True,
            )
    st.markdown("\n", unsafe_allow_html=True)

    if submitted:
        # User has confirmed the (potentially edited) outline, proceed to generate full report
        st.session_state.generating = True
        st.session_state.feedback_given = True  # Prevent resubmission

        report_ready = False
        with st.spinner("正在生成完整报告..."):
            try:
                final_state = asyncio.run(
                    run_generate_full_report(st.session_state.report_state, config_dict)
                )
                st.session_state.report_state = final_state
                st.session_state.report_generated = True
                report_ready = True
            except Exception as e:
                st.error(f"生成报告时出错:{str(e)}")
            finally:
                st.session_state.generating = False
        # As above: keep the error visible by not rerunning after a failure.
        if report_ready:
            st.rerun(scope="app")

# 3. Final Report Display
if st.session_state.report_generated:
    st.markdown("\n", unsafe_allow_html=True)
    st.markdown("#### 📄 最终报告")
    final_report_content = st.session_state.report_state.get('final_report', '报告生成失败。')
    st.markdown(final_report_content)
    st.markdown("\n", unsafe_allow_html=True)

    report_path = st.session_state.get("report_file_path", "")
    if report_path and os.path.exists(report_path):
        with open(report_path, "r", encoding="utf-8") as f:
            st.download_button(
                label="📥 下载报告 (Markdown)",
                data=f.read(),
                file_name=os.path.basename(report_path),
                mime="text/markdown",
                type="primary",
            )

    # Bottom button: reset the session and start over.
    st.markdown("---")
    if st.button("🔄 开始新的报告", type="secondary"):
        st.session_state.clear()
        st.rerun(scope="app")