"""Interactive command-line driver for generating a research report.

The script asks for a topic, builds an outline with ``generate_report_plan``,
lets the user accept or revise it, then researches and writes each section,
compiles the final report, and saves it (plus a metadata JSON file) to the
current working directory.
"""

import asyncio
import json
import os
import re
import sys
import time

from state import ReportState, Section
from configuration import SearchAPI
from nodes import (
    generate_report_plan,
    think_and_generate_queries,
    search_web,
    write_section,
    compile_final_report,
    process_citations,
    save_final_report
)

# Set environment variables - API keys.
# NOTE(review): these assignments overwrite any value already present in the
# environment; replace the placeholders or load the keys from a secure source
# before running, and never commit real keys.
os.environ["DEEPSEEK_API_KEY"] = "XXX"
os.environ["TAVILY_API_KEY"] = "XXX"

# Use a plain dict as the configuration object passed to every node.
config_dict = {
    "configurable": {
        "number_of_queries": 2,
        "planner_provider": "deepseek",
        "planner_model": "deepseek-chat",
        "writer_provider": "deepseek",
        "writer_model": "deepseek-chat",
        "search_api": SearchAPI.TAVILY.value,
        "search_api_config": {"api_key": os.environ.get("TAVILY_API_KEY", "")},
        "max_tokens": 4096
    }
}

# Characters that are path separators or reserved in file names on common
# platforms; they are replaced before the topic is used in a file name.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\x00-\x1f]')


def _topic_slug(topic, max_length=30):
    """Return a lower-cased, file-name-safe slug of *topic*.

    Spaces and unsafe characters become underscores; the result is truncated
    to *max_length* characters.
    """
    slug = topic.lower().replace(" ", "_")
    return _UNSAFE_FILENAME_CHARS.sub("_", slug)[:max_length]


async def _call_node(node, *args):
    """Call *node* with *args*, awaiting it only if it is a coroutine function."""
    if asyncio.iscoroutinefunction(node):
        return await node(*args)
    return node(*args)


def _record_completed_section(section_content, completed_sections):
    """Append a finished section to *completed_sections* and print its content."""
    completed_sections.append(section_content)
    # ``or ""`` keeps the length report working when ``content`` is None.
    content = getattr(section_content, "content", "") or ""
    print(f"章节内容完成,内容长度: {len(content)}")
    print(f"\n章节内容:\n{content}\n")


async def get_user_topic():
    """Prompt for the report topic; an empty answer selects the default topic.

    Returns:
        The topic string entered by the user, or the built-in default.
    """
    print("请输入您想要研究的报告主题 (直接回车使用默认主题: 人工智能在医疗领域的应用)")
    topic = input("> ").strip()
    if not topic:
        topic = "人工智能在医疗领域的应用"
        print(f"使用默认主题: {topic}")
    return topic


async def create_report_outline(state, config_dict):
    """Generate a report outline, print it, and merge it into *state*.

    Args:
        state: Mutable mapping holding the report state.
        config_dict: Configuration passed through to ``generate_report_plan``.

    Returns:
        The same *state* object, updated with the planner's result.
    """
    report_plan_result = await generate_report_plan(state, config_dict)

    if "sections" in report_plan_result and report_plan_result["sections"]:
        sections = report_plan_result["sections"]
        print(f"共 {len(sections)} 个章节:")
        for i, section in enumerate(sections):
            print(f"  {i+1}. {section.name}")
            print(f"     描述: {section.description}")
            print(f"     需要研究: {'是' if section.research else '否'}")
            print()

    # The planner result is merged unconditionally, mirroring how the other
    # node results are merged into the state in this file.
    state.update(report_plan_result)
    return state


async def get_outline_feedback(state):
    """Ask the user what to do with the current outline.

    Args:
        state: Report state (accepted for interface symmetry; not read here).

    Returns:
        The raw line typed by the user.
    """
    print("\n请提供您的反馈意见:")
    print("1. 输入 'accept' 表示接受大纲并继续生成报告")
    print("2. 输入 'regenerate' 表示重新生成大纲")
    print("3. 输入具体的修改建议,例如 '增加关于AI伦理的章节' 或 '第2章应该更详细'")
    feedback = input("> ")
    return feedback


async def regenerate_outline_if_needed(state, config_dict, feedback,
                                       max_attempts=3, current_attempt=0):
    """Regenerate the outline until the user accepts it or attempts run out.

    Args:
        state: Mutable mapping holding the report state.
        config_dict: Configuration passed through to the planner.
        feedback: The user's first reply to the outline.
        max_attempts: Maximum number of regeneration rounds.
        current_attempt: Number of rounds already used.

    Returns:
        A ``(state, True)`` tuple; the flag always signals that report
        generation should continue.
    """
    attempt = current_attempt
    while True:
        if attempt >= max_attempts:
            print(f"已达到最大重试次数({max_attempts}次),将使用当前大纲继续。")
            return state, True

        state["feedback_on_report_plan"] = feedback
        # Tolerate surrounding whitespace so "accept " is still a command.
        command = feedback.strip().lower()

        if command == 'accept':
            print("大纲已接受,将继续生成报告内容。")
            return state, True

        if command == 'regenerate':
            print("正在重新生成大纲...")
        else:
            print(f"收到反馈: {feedback}")

        state = await create_report_outline(state, config_dict)
        feedback = await get_outline_feedback(state)
        attempt += 1


async def generate_report_sections(state, config_dict):
    """Research and write every section of the outline in order.

    For each section the function generates search queries, runs the web
    search, and calls ``write_section``. The writer's result may be a plain
    dict, an object exposing ``update`` / ``goto`` attributes (presumably a
    LangGraph ``Command`` -- verify against ``nodes``), or a simple value that
    is used directly as the section text.

    Args:
        state: Mutable mapping holding the report state.
        config_dict: Configuration passed through to every node.

    Returns:
        The same *state*, with ``completed_sections`` set to the sections
        collected during this run.
    """
    completed_sections = []
    total_sections = len(state["sections"])

    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]
        print(
            f"\n生成章节 {current_index + 1}/{total_sections}: {current_section.name}")
        print(f"章节描述: {current_section.description}")

        # Generate search queries for the current section.
        print(f"为章节 '{current_section.name}' 思考中...")
        think_result = await _call_node(
            think_and_generate_queries, state, config_dict)

        if "thinking_process" in think_result and think_result["thinking_process"]:
            print("\n思考过程:")
            # Unpack the index so the thought text is printed, not a tuple.
            for number, thought in enumerate(think_result["thinking_process"], 1):
                print(f"{number}. {thought}")

        state.update(think_result)

        # Run the web search for this section.
        print(f"为章节 '{current_section.name}' 搜索中...")
        search_result = await search_web(state, config_dict)
        state.update(search_result)
        print(f"搜索完成,获取到 {len(state.get('sources', {}))} 条资料")
        print(state["source_str"])

        # Write the section content.
        print("正在生成内容...")
        start_time = time.time()
        write_result = await _call_node(write_section, state, config_dict)
        end_time = time.time()
        print(f"章节生成耗时: {end_time - start_time:.2f} 秒")
        print(f"章节处理结果类型: {type(write_result)}")

        if isinstance(write_result, dict):
            print("结果是普通字典,直接更新状态")
            state.update(write_result)
            # Guard against a missing or empty list before taking element 0.
            if write_result.get('completed_sections'):
                _record_completed_section(
                    write_result['completed_sections'][0], completed_sections)
            state["current_section_index"] += 1
            continue

        has_update = hasattr(write_result, 'update')
        if has_update:
            # ``update`` may legitimately be None; treat that as "no changes".
            update_data = write_result.update or {}
            state.update(update_data)
            if update_data.get('completed_sections'):
                _record_completed_section(
                    update_data['completed_sections'][0], completed_sections)

        has_goto = hasattr(write_result, 'goto')
        if has_goto:
            if write_result.goto == "compile_final_report":
                print("所有章节完成,跳出循环")
                break
        elif not has_update:
            print("结果是简单值,直接作为当前章节内容")
            if current_section and hasattr(current_section, 'content'):
                current_section.content = str(write_result)
                completed_sections.append(current_section)
                content = current_section.content
                print(f"\n章节内容:\n{content}\n")

        state["current_section_index"] += 1

    state["completed_sections"] = completed_sections
    print(f"\n完成所有章节 ({len(completed_sections)}/{total_sections})")
    return state


async def compile_report(state):
    """Compile the final report and then process its citations.

    Args:
        state: Mutable mapping holding the report state.

    Returns:
        The same *state*, updated with the compile and citation results.
    """
    compile_result = await compile_final_report(state)
    state.update(compile_result)
    print(f"报告编译完成,总字数: {len(state.get('final_report', ''))}")

    print("\n正在处理引用和引文...")
    citations_result = await _call_node(process_citations, state)
    state.update(citations_result)
    print("引用处理完成")
    return state


async def save_report(state):
    """Write the report and a metadata JSON file, then print a preview.

    Both files are created in the current working directory and named after
    a slug of the topic plus a timestamp. ``save_final_report`` from ``nodes``
    is awaited afterwards.

    Args:
        state: Mapping holding at least ``topic``, ``final_report``,
            ``sections`` and ``completed_sections``.

    Returns:
        The same *state*, unchanged by this function itself.
    """
    topic_slug = _topic_slug(state["topic"])
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    final_report = state["final_report"]

    output_file = f"report_{topic_slug}_{timestamp}.md"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(final_report)
    print(f"报告已保存到文件: {output_file}")

    metadata_file = f"metadata_{topic_slug}_{timestamp}.json"
    metadata = {
        "topic": state["topic"],
        "timestamp": timestamp,
        "sections_count": len(state["sections"]),
        "completed_sections_count": len(state["completed_sections"]),
        "word_count": len(final_report.split()),
        "character_count": len(final_report),
        "report_file": output_file,
        "sections": [
            {
                "name": section.name,
                "description": section.description,
                "content_length": len(getattr(section, 'content', '') or ''),
            }
            for section in state["sections"]
        ]
    }
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"报告元数据已保存到文件: {metadata_file}")

    await save_final_report(state)

    preview_length = 500
    if len(final_report) > preview_length:
        preview = final_report[:preview_length] + "..."
    else:
        preview = final_report
    print("\n" + "=" * 80)
    print(f"报告预览 (前 {preview_length} 个字符):")
    print(preview)
    print("=" * 80)
    print(f"完整报告已保存到: {output_file}")
    return state


async def main():
    """Run the full interactive workflow from topic prompt to saved report.

    Returns:
        A ``ReportState`` built from the final state, or the raw state dict
        if generation is stopped after the outline step.
    """
    topic = await get_user_topic()
    state = {
        "topic": topic,
        "feedback_on_report_plan": "",
        "sections": [],
        "thinking_process": [],
        "completed_sections": [],
        "current_section_index": 0,
        "final_report": "",
        "search_queries": {},
        "sources": {},
        "source_str": ""
    }

    state = await create_report_outline(state, config_dict)
    feedback = await get_outline_feedback(state)
    state, continue_generation = await regenerate_outline_if_needed(
        state, config_dict, feedback)

    if not continue_generation:
        print("根据用户指示,停止报告生成。")
        return state

    # The global search step was removed; each section is searched
    # separately right before it is written.
    state = await generate_report_sections(state, config_dict)
    state = await compile_report(state)
    state = await save_report(state)

    final_report_state = ReportState(
        topic=state["topic"],
        feedback_on_report_plan=state["feedback_on_report_plan"],
        sections=state["sections"],
        thinking_process=state["thinking_process"],
        completed_sections=state["completed_sections"],
        current_section_index=state["current_section_index"],
        final_report=state["final_report"],
        search_queries=state["search_queries"],
        sources=state["sources"],
        source_str=state["source_str"]
    )
    return final_report_state


if __name__ == "__main__":
    asyncio.run(main())