上传文件至 /
This commit is contained in:
271
main.py
Normal file
271
main.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from state import ReportState, Section
|
||||||
|
from configuration import SearchAPI
|
||||||
|
from nodes import (
|
||||||
|
generate_report_plan,
|
||||||
|
think_and_generate_queries,
|
||||||
|
search_web,
|
||||||
|
write_section,
|
||||||
|
compile_final_report,
|
||||||
|
process_citations,
|
||||||
|
save_final_report
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置环境变量 - API密钥
|
||||||
|
os.environ["DEEPSEEK_API_KEY"] = "XXX"
|
||||||
|
os.environ["TAVILY_API_KEY"] = "XXX"
|
||||||
|
|
||||||
|
# 使用字典作为配置
|
||||||
|
config_dict = {
|
||||||
|
"configurable": {
|
||||||
|
"number_of_queries": 2,
|
||||||
|
"planner_provider": "deepseek",
|
||||||
|
"planner_model": "deepseek-chat",
|
||||||
|
"writer_provider": "deepseek",
|
||||||
|
"writer_model": "deepseek-chat",
|
||||||
|
"search_api": SearchAPI.TAVILY.value,
|
||||||
|
"search_api_config": {"api_key": os.environ.get("TAVILY_API_KEY", "")},
|
||||||
|
"max_tokens": 4096
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_topic():
|
||||||
|
print("请输入您想要研究的报告主题 (直接回车使用默认主题: 人工智能在医疗领域的应用)")
|
||||||
|
topic = input("> ").strip()
|
||||||
|
if not topic:
|
||||||
|
topic = "人工智能在医疗领域的应用"
|
||||||
|
print(f"使用默认主题: {topic}")
|
||||||
|
return topic
|
||||||
|
|
||||||
|
|
||||||
|
async def create_report_outline(state, config_dict):
|
||||||
|
report_plan_result = await generate_report_plan(state, config_dict)
|
||||||
|
if "sections" in report_plan_result and report_plan_result["sections"]:
|
||||||
|
sections = report_plan_result["sections"]
|
||||||
|
print(f"共 {len(sections)} 个章节:")
|
||||||
|
for i, section in enumerate(sections):
|
||||||
|
print(f" {i+1}. {section.name}")
|
||||||
|
print(f" 描述: {section.description}")
|
||||||
|
print(f" 需要研究: {'是' if section.research else '否'}")
|
||||||
|
print()
|
||||||
|
state.update(report_plan_result)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
async def get_outline_feedback(state):
|
||||||
|
sections = state.get("sections", [])
|
||||||
|
print("\n请提供您的反馈意见:")
|
||||||
|
print("1. 输入 'accept' 表示接受大纲并继续生成报告")
|
||||||
|
print("2. 输入 'regenerate' 表示重新生成大纲")
|
||||||
|
print("3. 输入具体的修改建议,例如 '增加关于AI伦理的章节' 或 '第2章应该更详细'")
|
||||||
|
feedback = input("> ")
|
||||||
|
return feedback
|
||||||
|
|
||||||
|
|
||||||
|
async def regenerate_outline_if_needed(state, config_dict, feedback, max_attempts=3, current_attempt=0):
|
||||||
|
if current_attempt >= max_attempts:
|
||||||
|
print(f"已达到最大重试次数({max_attempts}次),将使用当前大纲继续。")
|
||||||
|
return state, True
|
||||||
|
state["feedback_on_report_plan"] = feedback
|
||||||
|
if feedback.lower() == 'accept':
|
||||||
|
print("大纲已接受,将继续生成报告内容。")
|
||||||
|
return state, True
|
||||||
|
if feedback.lower() == 'regenerate':
|
||||||
|
print("正在重新生成大纲...")
|
||||||
|
state = await create_report_outline(state, config_dict)
|
||||||
|
new_feedback = await get_outline_feedback(state)
|
||||||
|
return await regenerate_outline_if_needed(state, config_dict, new_feedback, max_attempts, current_attempt + 1)
|
||||||
|
print(f"收到反馈: {feedback}")
|
||||||
|
state = await create_report_outline(state, config_dict)
|
||||||
|
new_feedback = await get_outline_feedback(state)
|
||||||
|
return await regenerate_outline_if_needed(state, config_dict, new_feedback, max_attempts, current_attempt + 1)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_report_sections(state, config_dict):
|
||||||
|
completed_sections = []
|
||||||
|
total_sections = len(state["sections"])
|
||||||
|
while state["current_section_index"] < total_sections:
|
||||||
|
current_index = state["current_section_index"]
|
||||||
|
current_section = state["sections"][current_index]
|
||||||
|
print(
|
||||||
|
f"\n生成章节 {current_index + 1}/{total_sections}: {current_section.name}")
|
||||||
|
print(f"章节描述: {current_section.description}")
|
||||||
|
|
||||||
|
# 为当前章节生成搜索查询
|
||||||
|
print(f"为章节 '{current_section.name}' 思考中...")
|
||||||
|
if asyncio.iscoroutinefunction(think_and_generate_queries):
|
||||||
|
think_result = await think_and_generate_queries(state, config_dict)
|
||||||
|
else:
|
||||||
|
think_result = think_and_generate_queries(state, config_dict)
|
||||||
|
if "thinking_process" in think_result and think_result["thinking_process"]:
|
||||||
|
print("\n思考过程:")
|
||||||
|
for thought in enumerate(think_result["thinking_process"]):
|
||||||
|
print(f"{thought}")
|
||||||
|
state.update(think_result)
|
||||||
|
|
||||||
|
# 执行search_web搜索
|
||||||
|
print(f"为章节 '{current_section.name}' 搜索中...")
|
||||||
|
search_result = await search_web(state, config_dict)
|
||||||
|
state.update(search_result)
|
||||||
|
print(f"搜索完成,获取到 {len(state.get('sources', {}))} 条资料")
|
||||||
|
print(state["source_str"])
|
||||||
|
|
||||||
|
# 生成章节内容
|
||||||
|
print("正在生成内容...")
|
||||||
|
start_time = time.time()
|
||||||
|
if asyncio.iscoroutinefunction(write_section):
|
||||||
|
write_result = await write_section(state, config_dict)
|
||||||
|
else:
|
||||||
|
write_result = write_section(state, config_dict)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"章节生成耗时: {end_time - start_time:.2f} 秒")
|
||||||
|
print(f"章节处理结果类型: {type(write_result)}")
|
||||||
|
has_update = False
|
||||||
|
has_goto = False
|
||||||
|
goto_node = None
|
||||||
|
|
||||||
|
if isinstance(write_result, dict):
|
||||||
|
print("结果是普通字典,直接更新状态")
|
||||||
|
state.update(write_result)
|
||||||
|
if 'completed_sections' in write_result:
|
||||||
|
section_content = write_result['completed_sections'][0]
|
||||||
|
completed_sections.append(section_content)
|
||||||
|
print(
|
||||||
|
f"章节内容完成,内容长度: {len(getattr(section_content, 'content', '')) if hasattr(section_content, 'content') else 0}")
|
||||||
|
content = getattr(section_content, 'content', "")
|
||||||
|
print(f"\n章节内容:\n{content}\n")
|
||||||
|
state["current_section_index"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hasattr(write_result, 'update'):
|
||||||
|
has_update = True
|
||||||
|
update_data = write_result.update
|
||||||
|
for key, value in update_data.items():
|
||||||
|
state[key] = value
|
||||||
|
if 'completed_sections' in update_data:
|
||||||
|
section_content = update_data['completed_sections'][0]
|
||||||
|
completed_sections.append(section_content)
|
||||||
|
print(
|
||||||
|
f"章节内容完成,内容长度: {len(getattr(section_content, 'content', '')) if hasattr(section_content, 'content') else 0}")
|
||||||
|
content = getattr(section_content, 'content', "")
|
||||||
|
print(f"\n章节内容:\n{content}\n")
|
||||||
|
|
||||||
|
if hasattr(write_result, 'goto'):
|
||||||
|
has_goto = True
|
||||||
|
goto_node = write_result.goto
|
||||||
|
if has_goto:
|
||||||
|
if goto_node == "compile_final_report":
|
||||||
|
print("所有章节完成,跳出循环")
|
||||||
|
break
|
||||||
|
elif not has_update:
|
||||||
|
print("结果是简单值,直接作为当前章节内容")
|
||||||
|
if current_section and hasattr(current_section, 'content'):
|
||||||
|
current_section.content = str(write_result)
|
||||||
|
completed_sections.append(current_section)
|
||||||
|
content = current_section.content
|
||||||
|
print(f"\n章节内容:\n{content}\n")
|
||||||
|
state["current_section_index"] += 1
|
||||||
|
state["completed_sections"] = completed_sections
|
||||||
|
print(f"\n完成所有章节 ({len(completed_sections)}/{total_sections})")
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
async def compile_report(state):
|
||||||
|
compile_result = await compile_final_report(state)
|
||||||
|
state.update(compile_result)
|
||||||
|
print(f"报告编译完成,总字数: {len(state.get('final_report', ''))}")
|
||||||
|
print("\n正在处理引用和引文...")
|
||||||
|
if asyncio.iscoroutinefunction(process_citations):
|
||||||
|
citations_result = await process_citations(state)
|
||||||
|
else:
|
||||||
|
citations_result = process_citations(state)
|
||||||
|
state.update(citations_result)
|
||||||
|
print("引用处理完成")
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
async def save_report(state):
|
||||||
|
topic_slug = state["topic"].lower().replace(" ", "_")[:30]
|
||||||
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
output_file = f"report_{topic_slug}_{timestamp}.md"
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(state["final_report"])
|
||||||
|
print(f"报告已保存到文件: {output_file}")
|
||||||
|
metadata_file = f"metadata_{topic_slug}_{timestamp}.json"
|
||||||
|
metadata = {
|
||||||
|
"topic": state["topic"],
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"sections_count": len(state["sections"]),
|
||||||
|
"completed_sections_count": len(state["completed_sections"]),
|
||||||
|
"word_count": len(state["final_report"].split()),
|
||||||
|
"character_count": len(state["final_report"]),
|
||||||
|
"report_file": output_file,
|
||||||
|
"sections": [
|
||||||
|
{
|
||||||
|
"name": section.name,
|
||||||
|
"description": section.description,
|
||||||
|
"content_length": len(getattr(section, 'content', '')) if hasattr(section, 'content') else 0
|
||||||
|
}
|
||||||
|
for section in state["sections"]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
with open(metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"报告元数据已保存到文件: {metadata_file}")
|
||||||
|
await save_final_report(state)
|
||||||
|
preview_length = 500
|
||||||
|
preview = state["final_report"][:preview_length] + "..." if len(
|
||||||
|
state["final_report"]) > preview_length else state["final_report"]
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print(f"报告预览 (前 {preview_length} 个字符):")
|
||||||
|
print(preview)
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"完整报告已保存到: {output_file}")
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
topic = await get_user_topic()
|
||||||
|
state = {
|
||||||
|
"topic": topic,
|
||||||
|
"feedback_on_report_plan": "",
|
||||||
|
"sections": [],
|
||||||
|
"thinking_process": [],
|
||||||
|
"completed_sections": [],
|
||||||
|
"current_section_index": 0,
|
||||||
|
"final_report": "",
|
||||||
|
"search_queries": {},
|
||||||
|
"sources": {},
|
||||||
|
"source_str": ""
|
||||||
|
}
|
||||||
|
state = await create_report_outline(state, config_dict)
|
||||||
|
feedback = await get_outline_feedback(state)
|
||||||
|
state, continue_generation = await regenerate_outline_if_needed(state, config_dict, feedback)
|
||||||
|
if not continue_generation:
|
||||||
|
print("根据用户指示,停止报告生成。")
|
||||||
|
return state
|
||||||
|
# 移除全局搜索步骤,改为在每个章节生成前单独搜索
|
||||||
|
state = await generate_report_sections(state, config_dict)
|
||||||
|
state = await compile_report(state)
|
||||||
|
state = await save_report(state)
|
||||||
|
final_report_state = ReportState(
|
||||||
|
topic=state["topic"],
|
||||||
|
feedback_on_report_plan=state["feedback_on_report_plan"],
|
||||||
|
sections=state["sections"],
|
||||||
|
thinking_process=state["thinking_process"],
|
||||||
|
completed_sections=state["completed_sections"],
|
||||||
|
current_section_index=state["current_section_index"],
|
||||||
|
final_report=state["final_report"],
|
||||||
|
search_queries=state["search_queries"],
|
||||||
|
sources=state["sources"],
|
||||||
|
source_str=state["source_str"]
|
||||||
|
)
|
||||||
|
return final_report_state
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Reference in New Issue
Block a user