"""Interactive command-line driver for the report-generation pipeline.

The script asks the user for a topic, builds a report outline, lets the user
accept or refine that outline, then researches and writes each section in
turn, compiles the final report and saves it (plus a JSON metadata file) to
the current working directory.
"""
import os
import sys
import asyncio
import inspect
import json
import re
import time
from state import ReportState, Section
from configuration import SearchAPI
from nodes import (
    generate_report_plan,
    think_and_generate_queries,
    search_web,
    write_section,
    compile_final_report,
    process_citations,
    save_final_report
)

# API keys. ``setdefault`` keeps any real key that is already exported in the
# environment instead of overwriting it with the placeholder value.
os.environ.setdefault("DEEPSEEK_API_KEY", "XXX")
os.environ.setdefault("TAVILY_API_KEY", "XXX")

# Pipeline configuration, passed as a plain dict to the node functions.
config_dict = {
    "configurable": {
        "number_of_queries": 2,
        "planner_provider": "deepseek",
        "planner_model": "deepseek-chat",
        "writer_provider": "deepseek",
        "writer_model": "deepseek-chat",
        "search_api": SearchAPI.TAVILY.value,
        "search_api_config": {"api_key": os.environ.get("TAVILY_API_KEY", "")},
        "max_tokens": 4096
    }
}


async def _call_node(node, *args):
    """Call *node* with *args*, awaiting the result only if it is awaitable.

    This lets the driver work with both synchronous and asynchronous node
    implementations without checking each function up front.
    """
    result = node(*args)
    if inspect.isawaitable(result):
        result = await result
    return result


def _content_of(section):
    """Return ``section.content`` as a string, or ``""`` if missing/None."""
    return getattr(section, "content", "") or ""


def _slugify_topic(topic, max_length=30):
    """Turn a free-form topic into a string that is safe inside a filename.

    Every run of characters that is not a word character or ``-`` (spaces,
    path separators, punctuation, ...) is replaced by a single underscore.
    ``\\w`` is Unicode-aware, so CJK topics are preserved.
    """
    slug = re.sub(r"[^\w\-]+", "_", topic.strip().lower()).strip("_")
    return slug[:max_length] or "report"


def _record_completed_section(result, completed_sections):
    """Append the first section in ``result["completed_sections"]`` and print it.

    Does nothing when the key is missing or the list is empty.
    """
    finished = result.get("completed_sections")
    if not finished:
        return
    section = finished[0]
    completed_sections.append(section)
    content = _content_of(section)
    print(f"章节内容完成,内容长度: {len(content)}")
    print(f"\n章节内容:\n{content}\n")


async def get_user_topic():
    """Prompt for the report topic; fall back to a default on empty input.

    Returns:
        The topic string entered by the user, or the built-in default.
    """
    print("请输入您想要研究的报告主题 (直接回车使用默认主题: 人工智能在医疗领域的应用)")
    topic = input("> ").strip()
    if not topic:
        topic = "人工智能在医疗领域的应用"
        print(f"使用默认主题: {topic}")
    return topic


async def create_report_outline(state, config_dict):
    """Generate the report plan, print the outline and merge it into *state*.

    Args:
        state: Mutable pipeline state dict.
        config_dict: Configuration dict forwarded to ``generate_report_plan``.

    Returns:
        The same *state* object, updated with the planner's result.
    """
    report_plan_result = await _call_node(generate_report_plan, state, config_dict)
    sections = report_plan_result.get("sections")
    if sections:
        print(f"共 {len(sections)} 个章节:")
        for number, section in enumerate(sections, start=1):
            print(f" {number}. {section.name}")
            print(f" 描述: {section.description}")
            print(f" 需要研究: {'是' if section.research else '否'}")
            print()
    state.update(report_plan_result)
    return state


async def get_outline_feedback(state):
    """Show the feedback menu and return the raw line typed by the user.

    Args:
        state: Pipeline state (accepted for interface symmetry; not read).
    """
    print("\n请提供您的反馈意见:")
    print("1. 输入 'accept' 表示接受大纲并继续生成报告")
    print("2. 输入 'regenerate' 表示重新生成大纲")
    print("3. 输入具体的修改建议,例如 '增加关于AI伦理的章节' 或 '第2章应该更详细'")
    feedback = input("> ")
    return feedback


async def regenerate_outline_if_needed(state, config_dict, feedback, max_attempts=3, current_attempt=0):
    """Regenerate the outline until the user accepts it or attempts run out.

    Args:
        state: Mutable pipeline state dict.
        config_dict: Configuration dict forwarded to the planner.
        feedback: The user's most recent feedback string.
        max_attempts: Maximum number of regeneration rounds.
        current_attempt: Number of rounds already performed.

    Returns:
        A ``(state, continue_generation)`` tuple; the flag is always ``True``.
    """
    attempt = current_attempt
    while True:
        command = feedback.strip().lower()
        # Honour an explicit accept first, so accepting on the final round is
        # reported as an acceptance rather than as "max retries reached".
        if command == 'accept':
            state["feedback_on_report_plan"] = feedback
            print("大纲已接受,将继续生成报告内容。")
            return state, True
        if attempt >= max_attempts:
            print(f"已达到最大重试次数({max_attempts}次),将使用当前大纲继续。")
            return state, True
        state["feedback_on_report_plan"] = feedback
        if command == 'regenerate':
            print("正在重新生成大纲...")
        else:
            print(f"收到反馈: {feedback}")
        state = await create_report_outline(state, config_dict)
        feedback = await get_outline_feedback(state)
        attempt += 1


async def generate_report_sections(state, config_dict):
    """Research and write every section, starting at ``current_section_index``.

    For each section the function generates search queries, runs the web
    search, then calls ``write_section``. The writer's result may be a plain
    dict, an object carrying ``update`` / ``goto`` attributes, or a bare
    value that is used directly as the section text.

    Args:
        state: Mutable pipeline state dict.
        config_dict: Configuration dict forwarded to the node functions.

    Returns:
        *state*, with ``completed_sections`` set to the sections produced by
        this call.
    """
    completed_sections = []
    total_sections = len(state["sections"])
    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]
        print(
            f"\n生成章节 {current_index + 1}/{total_sections}: {current_section.name}")
        print(f"章节描述: {current_section.description}")

        # Generate the search queries for the current section.
        print(f"为章节 '{current_section.name}' 思考中...")
        think_result = await _call_node(think_and_generate_queries, state, config_dict)
        thoughts = think_result.get("thinking_process")
        if thoughts:
            print("\n思考过程:")
            for number, thought in enumerate(thoughts, start=1):
                print(f"{number}. {thought}")
        state.update(think_result)

        # Run the web search.
        print(f"为章节 '{current_section.name}' 搜索中...")
        search_result = await _call_node(search_web, state, config_dict)
        state.update(search_result)
        print(f"搜索完成,获取到 {len(state.get('sources', {}))} 条资料")
        print(state["source_str"])

        # Write the section content.
        print("正在生成内容...")
        start_time = time.time()
        write_result = await _call_node(write_section, state, config_dict)
        end_time = time.time()
        print(f"章节生成耗时: {end_time - start_time:.2f} 秒")
        print(f"章节处理结果类型: {type(write_result)}")

        if isinstance(write_result, dict):
            print("结果是普通字典,直接更新状态")
            state.update(write_result)
            _record_completed_section(write_result, completed_sections)
            # NOTE(review): assumes write_section does not itself advance
            # current_section_index in its result - confirm against nodes.py.
            state["current_section_index"] += 1
            continue

        # Non-dict result: look for ``update`` / ``goto`` attributes. ``update``
        # is only applied when it is an actual dict (it may be None or absent).
        update_data = getattr(write_result, 'update', None)
        has_update = isinstance(update_data, dict)
        if has_update:
            state.update(update_data)
            _record_completed_section(update_data, completed_sections)

        if hasattr(write_result, 'goto'):
            if write_result.goto == "compile_final_report":
                print("所有章节完成,跳出循环")
                break
        elif not has_update:
            print("结果是简单值,直接作为当前章节内容")
            if current_section and hasattr(current_section, 'content'):
                current_section.content = str(write_result)
                completed_sections.append(current_section)
                content = current_section.content
                print(f"\n章节内容:\n{content}\n")
        state["current_section_index"] += 1
    state["completed_sections"] = completed_sections
    print(f"\n完成所有章节 ({len(completed_sections)}/{total_sections})")
    return state


async def compile_report(state):
    """Compile the final report and post-process its citations.

    Args:
        state: Mutable pipeline state dict.

    Returns:
        *state*, updated with the compile and citation results.
    """
    compile_result = await _call_node(compile_final_report, state)
    state.update(compile_result)
    print(f"报告编译完成,总字数: {len(state.get('final_report', ''))}")
    print("\n正在处理引用和引文...")
    citations_result = await _call_node(process_citations, state)
    state.update(citations_result)
    print("引用处理完成")
    return state


async def save_report(state):
    """Write the report and a JSON metadata file, then print a preview.

    Both files are created in the current working directory and named from
    a sanitised topic slug plus a timestamp.

    Args:
        state: Pipeline state containing ``topic``, ``final_report``,
            ``sections`` and ``completed_sections``.

    Returns:
        *state*, unchanged by this function itself.
    """
    final_report = state["final_report"]
    topic_slug = _slugify_topic(state["topic"])
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_file = f"report_{topic_slug}_{timestamp}.md"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(final_report)
    print(f"报告已保存到文件: {output_file}")
    metadata_file = f"metadata_{topic_slug}_{timestamp}.json"
    metadata = {
        "topic": state["topic"],
        "timestamp": timestamp,
        "sections_count": len(state["sections"]),
        "completed_sections_count": len(state["completed_sections"]),
        # Whitespace-delimited count; for CJK text this is not a true word count.
        "word_count": len(final_report.split()),
        "character_count": len(final_report),
        "report_file": output_file,
        "sections": [
            {
                "name": section.name,
                "description": section.description,
                "content_length": len(_content_of(section))
            }
            for section in state["sections"]
        ]
    }
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"报告元数据已保存到文件: {metadata_file}")
    await _call_node(save_final_report, state)
    preview_length = 500
    if len(final_report) > preview_length:
        preview = final_report[:preview_length] + "..."
    else:
        preview = final_report
    print("\n" + "=" * 80)
    print(f"报告预览 (前 {preview_length} 个字符):")
    print(preview)
    print("=" * 80)
    print(f"完整报告已保存到: {output_file}")
    return state


async def main():
    """Run the full interactive pipeline and return the final report state."""
    topic = await get_user_topic()
    state = {
        "topic": topic,
        "feedback_on_report_plan": "",
        "sections": [],
        "thinking_process": [],
        "completed_sections": [],
        "current_section_index": 0,
        "final_report": "",
        "search_queries": {},
        "sources": {},
        "source_str": ""
    }
    state = await create_report_outline(state, config_dict)
    feedback = await get_outline_feedback(state)
    state, continue_generation = await regenerate_outline_if_needed(state, config_dict, feedback)
    if not continue_generation:
        print("根据用户指示,停止报告生成。")
        return state
    # There is no global search step; each section is searched individually
    # right before it is written.
    state = await generate_report_sections(state, config_dict)
    state = await compile_report(state)
    state = await save_report(state)
    final_report_state = ReportState(
        topic=state["topic"],
        feedback_on_report_plan=state["feedback_on_report_plan"],
        sections=state["sections"],
        thinking_process=state["thinking_process"],
        completed_sections=state["completed_sections"],
        current_section_index=state["current_section_index"],
        final_report=state["final_report"],
        search_queries=state["search_queries"],
        sources=state["sources"],
        source_str=state["source_str"]
    )
    return final_report_state


if __name__ == "__main__":
    asyncio.run(main())