# Deepresearch/main.py
# Snapshot: 2025-08-25 10:07:26 +08:00 (272 lines, 11 KiB, Python)

import asyncio
import inspect
import json
import os
import re
import sys
import time

from state import ReportState, Section
from configuration import SearchAPI
from nodes import (
    generate_report_plan,
    think_and_generate_queries,
    search_web,
    write_section,
    compile_final_report,
    process_citations,
    save_final_report,
)
# API keys. setdefault keeps any key already exported in the shell environment;
# the "XXX" placeholders only apply when nothing is set. Never commit real keys
# to this file.
os.environ.setdefault("DEEPSEEK_API_KEY", "XXX")
os.environ.setdefault("TAVILY_API_KEY", "XXX")

# Configuration passed as a plain dict to every node call in this script.
config_dict = {
    "configurable": {
        "number_of_queries": 2,
        "planner_provider": "deepseek",
        "planner_model": "deepseek-chat",
        "writer_provider": "deepseek",
        "writer_model": "deepseek-chat",
        "search_api": SearchAPI.TAVILY.value,
        # Read after the setdefault above, so this is the effective key.
        "search_api_config": {"api_key": os.environ.get("TAVILY_API_KEY", "")},
        "max_tokens": 4096,
    }
}
async def get_user_topic():
    """Ask on stdin for the report topic.

    Returns:
        The user's input with surrounding whitespace removed. If that is
        empty, returns the default topic "人工智能在医疗领域的应用" and prints
        a notice saying so.
    """
    default_topic = "人工智能在医疗领域的应用"
    print("请输入您想要研究的报告主题 (直接回车使用默认主题: 人工智能在医疗领域的应用)")
    answer = input("> ").strip()
    if answer:
        return answer
    print(f"使用默认主题: {default_topic}")
    return default_topic
async def create_report_outline(state, config_dict):
    """Generate a report outline and merge it into ``state``.

    Awaits ``generate_report_plan(state, config_dict)``. When the result holds
    a non-empty ``sections`` list, prints each section's name, its description
    and whether it needs research (是/否). Otherwise prints a warning.

    Args:
        state: Mutable report state dict, passed through to the planner.
        config_dict: Configuration passed through to the planner.

    Returns:
        The same ``state`` dict, updated with every key the planner returned.
    """
    report_plan_result = await generate_report_plan(state, config_dict)
    sections = report_plan_result.get("sections")
    if sections:
        print(f"{len(sections)} 个章节:")
        for number, section in enumerate(sections, start=1):
            # Show a real yes/no flag instead of an empty string for both cases.
            needs_research = "是" if section.research else "否"
            print(f" {number}. {section.name}")
            print(f" 描述: {section.description}")
            print(f" 需要研究: {needs_research}")
            print()
    else:
        # Without this, an empty plan silently leads to a report with 0 sections.
        print("未生成任何章节,请检查规划模型的输出。")
    state.update(report_plan_result)
    return state
async def get_outline_feedback(state):
    """Prompt the user on stdin for feedback on the current outline.

    The reply may be "accept", "regenerate", or a free-text revision request.

    Args:
        state: Not read here; kept so existing call sites stay unchanged.

    Returns:
        The user's reply with surrounding whitespace removed, so commands such
        as " accept " are recognised by exact comparison downstream.
    """
    print("\n请提供您的反馈意见:")
    print("1. 输入 'accept' 表示接受大纲并继续生成报告")
    print("2. 输入 'regenerate' 表示重新生成大纲")
    print("3. 输入具体的修改建议,例如 '增加关于AI伦理的章节' 或 '第2章应该更详细'")
    return input("> ").strip()
async def regenerate_outline_if_needed(state, config_dict, feedback, max_attempts=3, current_attempt=0):
    """Revise the outline from user feedback until it is accepted or retries run out.

    ``feedback`` is trimmed and compared case-insensitively:

    * ``"accept"``: store it in ``state["feedback_on_report_plan"]`` and stop.
    * ``"regenerate"`` or any other text: store it in
      ``state["feedback_on_report_plan"]``, rebuild the outline with
      ``create_report_outline`` (the planner presumably reads the stored
      feedback; confirm in ``generate_report_plan``), then prompt again.

    "accept" is checked before the retry limit, so an acceptance typed on the
    final prompt is reported as an acceptance.

    Args:
        state: Mutable report state dict.
        config_dict: Configuration forwarded to ``create_report_outline``.
        feedback: The user's latest reply; ``None`` is treated as empty.
        max_attempts: Maximum number of outline rebuilds.
        current_attempt: Rebuilds already performed before this call.

    Returns:
        ``(state, True)``. Every exit path continues with the current outline.
    """
    attempt = current_attempt
    while True:
        feedback = (feedback or "").strip()
        command = feedback.lower()

        if command == "accept":
            state["feedback_on_report_plan"] = feedback
            print("大纲已接受,将继续生成报告内容。")
            return state, True

        if attempt >= max_attempts:
            print(f"已达到最大重试次数({max_attempts}次),将使用当前大纲继续。")
            return state, True

        state["feedback_on_report_plan"] = feedback
        if command == "regenerate":
            print("正在重新生成大纲...")
        else:
            print(f"收到反馈: {feedback}")

        state = await create_report_outline(state, config_dict)
        feedback = await get_outline_feedback(state)
        attempt += 1
async def _run_node(node, *args):
    """Call ``node(*args)`` and await the result when the node is asynchronous.

    Checking the returned object also covers wrapped or partial coroutine
    functions, and avoids ``asyncio.iscoroutinefunction`` (deprecated since
    Python 3.14).
    """
    result = node(*args)
    if inspect.isawaitable(result):
        result = await result
    return result


def _record_completed_section(update, completed_sections):
    """Append the first entry of ``update["completed_sections"]`` and print it.

    Does nothing when the key is missing or its list is empty.
    """
    finished = update.get("completed_sections")
    if not finished:
        return
    section = finished[0]
    completed_sections.append(section)
    # Missing or None content counts as empty text.
    content = getattr(section, "content", None) or ""
    print(f"章节内容完成,内容长度: {len(content)}")
    print(f"\n章节内容:\n{content}\n")


async def generate_report_sections(state, config_dict):
    """Research and write each planned section in turn.

    Starting at ``state["current_section_index"]``, each section runs three
    nodes: ``think_and_generate_queries``, ``search_web`` and
    ``write_section``. Each node's result is merged into ``state``.

    ``write_section`` may return:

    * a plain dict, which is merged into ``state``;
    * an object with ``update`` and/or ``goto`` attributes (presumably a
      LangGraph ``Command``; confirm). ``update`` is merged into ``state``.
      ``goto == "compile_final_report"`` ends the loop early;
    * any other value, which is stored as the current section's text.

    Args:
        state: Mutable report state dict. Must contain ``sections`` and
            ``current_section_index``.
        config_dict: Configuration passed through to every node.

    Returns:
        The same ``state`` dict. ``completed_sections`` is replaced by the
        sections finished during this call.
    """
    completed_sections = []
    total_sections = len(state["sections"])
    while state["current_section_index"] < total_sections:
        current_index = state["current_section_index"]
        current_section = state["sections"][current_index]
        print(f"\n生成章节 {current_index + 1}/{total_sections}: {current_section.name}")
        print(f"章节描述: {current_section.description}")

        # 1) Plan search queries for this section.
        print(f"为章节 '{current_section.name}' 思考中...")
        think_result = await _run_node(think_and_generate_queries, state, config_dict)
        thinking_process = think_result.get("thinking_process")
        if thinking_process:
            print("\n思考过程:")
            # Print numbered thoughts, not raw (index, thought) tuples.
            for step_no, thought in enumerate(thinking_process, start=1):
                print(f"{step_no}. {thought}")
        state.update(think_result)

        # 2) Search the web for this section's queries.
        print(f"为章节 '{current_section.name}' 搜索中...")
        search_result = await _run_node(search_web, state, config_dict)
        state.update(search_result)
        print(f"搜索完成,获取到 {len(state.get('sources', {}))} 条资料")
        print(state.get("source_str", ""))

        # 3) Write the section text.
        print("正在生成内容...")
        start_time = time.perf_counter()
        write_result = await _run_node(write_section, state, config_dict)
        print(f"章节生成耗时: {time.perf_counter() - start_time:.2f}秒")
        print(f"章节处理结果类型: {type(write_result)}")

        if isinstance(write_result, dict):
            print("结果是普通字典,直接更新状态")
            state.update(write_result)
            _record_completed_section(write_result, completed_sections)
            state["current_section_index"] += 1
            continue

        # A present-but-None ``update`` attribute means "nothing to merge".
        update_data = getattr(write_result, "update", None)
        has_update = update_data is not None
        if has_update:
            state.update(update_data)
            _record_completed_section(update_data, completed_sections)

        if hasattr(write_result, "goto"):
            if write_result.goto == "compile_final_report":
                print("所有章节完成,跳出循环")
                break
        elif not has_update:
            print("结果是简单值,直接作为当前章节内容")
            if current_section and hasattr(current_section, "content"):
                current_section.content = str(write_result)
                completed_sections.append(current_section)
                print(f"\n章节内容:\n{current_section.content}\n")

        state["current_section_index"] += 1

    state["completed_sections"] = completed_sections
    print(f"\n完成所有章节 ({len(completed_sections)}/{total_sections})")
    return state
async def compile_report(state):
    """Assemble the final report, then process its citations.

    Awaits ``compile_final_report(state)`` and then runs
    ``process_citations(state)``, which may be sync or async. Each result
    dict is merged into ``state``.

    Args:
        state: Mutable report state dict holding the written sections.

    Returns:
        The same ``state`` dict, expected to now contain ``final_report``.
    """
    state.update(await compile_final_report(state))
    print(f"报告编译完成,总字数: {len(state.get('final_report', ''))}")

    print("\n正在处理引用和引文...")
    citations_result = process_citations(state)
    # Check the returned object rather than using asyncio.iscoroutinefunction
    # (deprecated since Python 3.14). This also covers wrapped coroutine functions.
    if inspect.isawaitable(citations_result):
        citations_result = await citations_result
    state.update(citations_result)
    print("引用处理完成")
    return state
def _topic_to_slug(topic, max_length=30):
    """Return a filesystem-safe file-name fragment derived from ``topic``.

    Lower-cases the topic and collapses whitespace, path separators and other
    characters illegal in Windows/POSIX file names into ``_``. Strips leading
    and trailing dots and underscores (blocks ``..`` traversal and hidden
    files), then truncates to ``max_length`` characters. Falls back to
    ``"report"`` if nothing is left.
    """
    slug = re.sub(r'[\\/:*?"<>|\s]+', "_", topic.lower())
    slug = slug.strip("._")[:max_length].rstrip("._")
    return slug or "report"


async def save_report(state):
    """Write the report and a JSON metadata file, then print a preview.

    Files are created in the current working directory:
    ``report_<slug>_<timestamp>.md`` and ``metadata_<slug>_<timestamp>.json``.
    The slug is a sanitised form of the user-supplied topic. Also awaits
    ``save_final_report(state)``.

    Args:
        state: Report state dict with ``topic``, ``final_report``, ``sections``
            and ``completed_sections``.

    Returns:
        The same ``state`` dict, unchanged by this function.
    """
    # The topic is free user input, so sanitise it before using it in a path.
    topic_slug = _topic_to_slug(state["topic"])
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_file = f"report_{topic_slug}_{timestamp}.md"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(state["final_report"])
    print(f"报告已保存到文件: {output_file}")

    metadata_file = f"metadata_{topic_slug}_{timestamp}.json"
    metadata = {
        "topic": state["topic"],
        "timestamp": timestamp,
        "sections_count": len(state["sections"]),
        "completed_sections_count": len(state["completed_sections"]),
        "word_count": len(state["final_report"].split()),
        "character_count": len(state["final_report"]),
        "report_file": output_file,
        "sections": [
            {
                "name": section.name,
                "description": section.description,
                # Missing or None content counts as length 0.
                "content_length": len(getattr(section, "content", None) or ""),
            }
            for section in state["sections"]
        ],
    }
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"报告元数据已保存到文件: {metadata_file}")

    await save_final_report(state)

    preview_length = 500
    report_text = state["final_report"]
    if len(report_text) > preview_length:
        preview = report_text[:preview_length] + "..."
    else:
        preview = report_text
    print("\n" + "=" * 80)
    print(f"报告预览 (前 {preview_length} 个字符):")
    print(preview)
    print("=" * 80)
    print(f"完整报告已保存到: {output_file}")
    return state
async def main():
    """Run the interactive research-report pipeline from topic to saved file.

    Steps:
    1. Prompt for a topic.
    2. Draft an outline and let the user accept or revise it.
    3. Write every section; each section does its own web search.
    4. Compile the report with citations and save it.

    Returns:
        The raw state dict if generation stops after outline review.
        Otherwise a ``ReportState`` built from the final state.
    """
    state = {
        "topic": await get_user_topic(),
        "feedback_on_report_plan": "",
        "sections": [],
        "thinking_process": [],
        "completed_sections": [],
        "current_section_index": 0,
        "final_report": "",
        "search_queries": {},
        "sources": {},
        "source_str": "",
    }

    # Outline phase: generate the outline, then iterate on user feedback.
    state = await create_report_outline(state, config_dict)
    state, continue_generation = await regenerate_outline_if_needed(
        state, config_dict, await get_outline_feedback(state)
    )
    if not continue_generation:
        print("根据用户指示,停止报告生成。")
        return state

    # Writing phase. There is no global search step: each section runs its own
    # web search right before it is written.
    state = await generate_report_sections(state, config_dict)
    state = await compile_report(state)
    state = await save_report(state)

    report_fields = (
        "topic",
        "feedback_on_report_plan",
        "sections",
        "thinking_process",
        "completed_sections",
        "current_section_index",
        "final_report",
        "search_queries",
        "sources",
        "source_str",
    )
    return ReportState(**{field: state[field] for field in report_fields})
if __name__ == "__main__":
    # Run the interactive pipeline. If the user presses Ctrl+C at a prompt,
    # exit with status 130 (the conventional 128 + SIGINT code) instead of
    # dumping a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n已取消报告生成。")
        sys.exit(130)