上传文件至 /
This commit is contained in:
271
main.py
Normal file
271
main.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from state import ReportState, Section
|
||||
from configuration import SearchAPI
|
||||
from nodes import (
|
||||
generate_report_plan,
|
||||
think_and_generate_queries,
|
||||
search_web,
|
||||
write_section,
|
||||
compile_final_report,
|
||||
process_citations,
|
||||
save_final_report
|
||||
)
|
||||
|
||||
# 设置环境变量 - API密钥
|
||||
os.environ["DEEPSEEK_API_KEY"] = "XXX"
|
||||
os.environ["TAVILY_API_KEY"] = "XXX"
|
||||
|
||||
# 使用字典作为配置
|
||||
config_dict = {
|
||||
"configurable": {
|
||||
"number_of_queries": 2,
|
||||
"planner_provider": "deepseek",
|
||||
"planner_model": "deepseek-chat",
|
||||
"writer_provider": "deepseek",
|
||||
"writer_model": "deepseek-chat",
|
||||
"search_api": SearchAPI.TAVILY.value,
|
||||
"search_api_config": {"api_key": os.environ.get("TAVILY_API_KEY", "")},
|
||||
"max_tokens": 4096
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def get_user_topic():
|
||||
print("请输入您想要研究的报告主题 (直接回车使用默认主题: 人工智能在医疗领域的应用)")
|
||||
topic = input("> ").strip()
|
||||
if not topic:
|
||||
topic = "人工智能在医疗领域的应用"
|
||||
print(f"使用默认主题: {topic}")
|
||||
return topic
|
||||
|
||||
|
||||
async def create_report_outline(state, config_dict):
|
||||
report_plan_result = await generate_report_plan(state, config_dict)
|
||||
if "sections" in report_plan_result and report_plan_result["sections"]:
|
||||
sections = report_plan_result["sections"]
|
||||
print(f"共 {len(sections)} 个章节:")
|
||||
for i, section in enumerate(sections):
|
||||
print(f" {i+1}. {section.name}")
|
||||
print(f" 描述: {section.description}")
|
||||
print(f" 需要研究: {'是' if section.research else '否'}")
|
||||
print()
|
||||
state.update(report_plan_result)
|
||||
return state
|
||||
|
||||
|
||||
async def get_outline_feedback(state):
|
||||
sections = state.get("sections", [])
|
||||
print("\n请提供您的反馈意见:")
|
||||
print("1. 输入 'accept' 表示接受大纲并继续生成报告")
|
||||
print("2. 输入 'regenerate' 表示重新生成大纲")
|
||||
print("3. 输入具体的修改建议,例如 '增加关于AI伦理的章节' 或 '第2章应该更详细'")
|
||||
feedback = input("> ")
|
||||
return feedback
|
||||
|
||||
|
||||
async def regenerate_outline_if_needed(state, config_dict, feedback, max_attempts=3, current_attempt=0):
|
||||
if current_attempt >= max_attempts:
|
||||
print(f"已达到最大重试次数({max_attempts}次),将使用当前大纲继续。")
|
||||
return state, True
|
||||
state["feedback_on_report_plan"] = feedback
|
||||
if feedback.lower() == 'accept':
|
||||
print("大纲已接受,将继续生成报告内容。")
|
||||
return state, True
|
||||
if feedback.lower() == 'regenerate':
|
||||
print("正在重新生成大纲...")
|
||||
state = await create_report_outline(state, config_dict)
|
||||
new_feedback = await get_outline_feedback(state)
|
||||
return await regenerate_outline_if_needed(state, config_dict, new_feedback, max_attempts, current_attempt + 1)
|
||||
print(f"收到反馈: {feedback}")
|
||||
state = await create_report_outline(state, config_dict)
|
||||
new_feedback = await get_outline_feedback(state)
|
||||
return await regenerate_outline_if_needed(state, config_dict, new_feedback, max_attempts, current_attempt + 1)
|
||||
|
||||
|
||||
async def generate_report_sections(state, config_dict):
|
||||
completed_sections = []
|
||||
total_sections = len(state["sections"])
|
||||
while state["current_section_index"] < total_sections:
|
||||
current_index = state["current_section_index"]
|
||||
current_section = state["sections"][current_index]
|
||||
print(
|
||||
f"\n生成章节 {current_index + 1}/{total_sections}: {current_section.name}")
|
||||
print(f"章节描述: {current_section.description}")
|
||||
|
||||
# 为当前章节生成搜索查询
|
||||
print(f"为章节 '{current_section.name}' 思考中...")
|
||||
if asyncio.iscoroutinefunction(think_and_generate_queries):
|
||||
think_result = await think_and_generate_queries(state, config_dict)
|
||||
else:
|
||||
think_result = think_and_generate_queries(state, config_dict)
|
||||
if "thinking_process" in think_result and think_result["thinking_process"]:
|
||||
print("\n思考过程:")
|
||||
for thought in enumerate(think_result["thinking_process"]):
|
||||
print(f"{thought}")
|
||||
state.update(think_result)
|
||||
|
||||
# 执行search_web搜索
|
||||
print(f"为章节 '{current_section.name}' 搜索中...")
|
||||
search_result = await search_web(state, config_dict)
|
||||
state.update(search_result)
|
||||
print(f"搜索完成,获取到 {len(state.get('sources', {}))} 条资料")
|
||||
print(state["source_str"])
|
||||
|
||||
# 生成章节内容
|
||||
print("正在生成内容...")
|
||||
start_time = time.time()
|
||||
if asyncio.iscoroutinefunction(write_section):
|
||||
write_result = await write_section(state, config_dict)
|
||||
else:
|
||||
write_result = write_section(state, config_dict)
|
||||
end_time = time.time()
|
||||
print(f"章节生成耗时: {end_time - start_time:.2f} 秒")
|
||||
print(f"章节处理结果类型: {type(write_result)}")
|
||||
has_update = False
|
||||
has_goto = False
|
||||
goto_node = None
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
print("结果是普通字典,直接更新状态")
|
||||
state.update(write_result)
|
||||
if 'completed_sections' in write_result:
|
||||
section_content = write_result['completed_sections'][0]
|
||||
completed_sections.append(section_content)
|
||||
print(
|
||||
f"章节内容完成,内容长度: {len(getattr(section_content, 'content', '')) if hasattr(section_content, 'content') else 0}")
|
||||
content = getattr(section_content, 'content', "")
|
||||
print(f"\n章节内容:\n{content}\n")
|
||||
state["current_section_index"] += 1
|
||||
continue
|
||||
|
||||
if hasattr(write_result, 'update'):
|
||||
has_update = True
|
||||
update_data = write_result.update
|
||||
for key, value in update_data.items():
|
||||
state[key] = value
|
||||
if 'completed_sections' in update_data:
|
||||
section_content = update_data['completed_sections'][0]
|
||||
completed_sections.append(section_content)
|
||||
print(
|
||||
f"章节内容完成,内容长度: {len(getattr(section_content, 'content', '')) if hasattr(section_content, 'content') else 0}")
|
||||
content = getattr(section_content, 'content', "")
|
||||
print(f"\n章节内容:\n{content}\n")
|
||||
|
||||
if hasattr(write_result, 'goto'):
|
||||
has_goto = True
|
||||
goto_node = write_result.goto
|
||||
if has_goto:
|
||||
if goto_node == "compile_final_report":
|
||||
print("所有章节完成,跳出循环")
|
||||
break
|
||||
elif not has_update:
|
||||
print("结果是简单值,直接作为当前章节内容")
|
||||
if current_section and hasattr(current_section, 'content'):
|
||||
current_section.content = str(write_result)
|
||||
completed_sections.append(current_section)
|
||||
content = current_section.content
|
||||
print(f"\n章节内容:\n{content}\n")
|
||||
state["current_section_index"] += 1
|
||||
state["completed_sections"] = completed_sections
|
||||
print(f"\n完成所有章节 ({len(completed_sections)}/{total_sections})")
|
||||
return state
|
||||
|
||||
|
||||
async def compile_report(state):
|
||||
compile_result = await compile_final_report(state)
|
||||
state.update(compile_result)
|
||||
print(f"报告编译完成,总字数: {len(state.get('final_report', ''))}")
|
||||
print("\n正在处理引用和引文...")
|
||||
if asyncio.iscoroutinefunction(process_citations):
|
||||
citations_result = await process_citations(state)
|
||||
else:
|
||||
citations_result = process_citations(state)
|
||||
state.update(citations_result)
|
||||
print("引用处理完成")
|
||||
return state
|
||||
|
||||
|
||||
async def save_report(state):
|
||||
topic_slug = state["topic"].lower().replace(" ", "_")[:30]
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_file = f"report_{topic_slug}_{timestamp}.md"
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(state["final_report"])
|
||||
print(f"报告已保存到文件: {output_file}")
|
||||
metadata_file = f"metadata_{topic_slug}_{timestamp}.json"
|
||||
metadata = {
|
||||
"topic": state["topic"],
|
||||
"timestamp": timestamp,
|
||||
"sections_count": len(state["sections"]),
|
||||
"completed_sections_count": len(state["completed_sections"]),
|
||||
"word_count": len(state["final_report"].split()),
|
||||
"character_count": len(state["final_report"]),
|
||||
"report_file": output_file,
|
||||
"sections": [
|
||||
{
|
||||
"name": section.name,
|
||||
"description": section.description,
|
||||
"content_length": len(getattr(section, 'content', '')) if hasattr(section, 'content') else 0
|
||||
}
|
||||
for section in state["sections"]
|
||||
]
|
||||
}
|
||||
with open(metadata_file, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
print(f"报告元数据已保存到文件: {metadata_file}")
|
||||
await save_final_report(state)
|
||||
preview_length = 500
|
||||
preview = state["final_report"][:preview_length] + "..." if len(
|
||||
state["final_report"]) > preview_length else state["final_report"]
|
||||
print("\n" + "=" * 80)
|
||||
print(f"报告预览 (前 {preview_length} 个字符):")
|
||||
print(preview)
|
||||
print("=" * 80)
|
||||
print(f"完整报告已保存到: {output_file}")
|
||||
return state
|
||||
|
||||
|
||||
async def main():
|
||||
topic = await get_user_topic()
|
||||
state = {
|
||||
"topic": topic,
|
||||
"feedback_on_report_plan": "",
|
||||
"sections": [],
|
||||
"thinking_process": [],
|
||||
"completed_sections": [],
|
||||
"current_section_index": 0,
|
||||
"final_report": "",
|
||||
"search_queries": {},
|
||||
"sources": {},
|
||||
"source_str": ""
|
||||
}
|
||||
state = await create_report_outline(state, config_dict)
|
||||
feedback = await get_outline_feedback(state)
|
||||
state, continue_generation = await regenerate_outline_if_needed(state, config_dict, feedback)
|
||||
if not continue_generation:
|
||||
print("根据用户指示,停止报告生成。")
|
||||
return state
|
||||
# 移除全局搜索步骤,改为在每个章节生成前单独搜索
|
||||
state = await generate_report_sections(state, config_dict)
|
||||
state = await compile_report(state)
|
||||
state = await save_report(state)
|
||||
final_report_state = ReportState(
|
||||
topic=state["topic"],
|
||||
feedback_on_report_plan=state["feedback_on_report_plan"],
|
||||
sections=state["sections"],
|
||||
thinking_process=state["thinking_process"],
|
||||
completed_sections=state["completed_sections"],
|
||||
current_section_index=state["current_section_index"],
|
||||
final_report=state["final_report"],
|
||||
search_queries=state["search_queries"],
|
||||
sources=state["sources"],
|
||||
source_str=state["source_str"]
|
||||
)
|
||||
return final_report_state
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user