上传文件至 /

This commit is contained in:
2025-08-20 20:12:32 +08:00
parent 159c4efd99
commit 30af25082d
3 changed files with 880 additions and 0 deletions

300
nodes.py Normal file
View File

@@ -0,0 +1,300 @@
from typing import Literal, List
import aiofiles
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from datetime import datetime
from langgraph.types import interrupt, Command
import os
import json
from state import (
Sections,
ReportState,
Queries,
SearchQuery,
Section
)
from prompts import (
report_planner_query_writer_instructions,
report_planner_instructions,
query_writer_instructions,
section_writer_instructions,
section_writer_system_prompt
)
from configuration import Configuration
from utils import (
compile_completed_sections,
get_config_value,
get_search_params,
select_and_execute_search,
format_sources
)
async def generate_report_plan(state: ReportState, config: RunnableConfig):
"""用于生成报告大纲,同时进行网络搜索帮助自己更好地规划大纲内容。
会自动中断等待人类反馈,若不通过则根据反馈重新生成大纲。直到通过并跳转到章节生成部分
"""
topic = state["topic"]
feedback = state.get("feedback_on_report_plan", None)
configurable = Configuration.from_runnable_config(config)
number_of_queries = configurable.number_of_queries
search_api = get_config_value(configurable.search_api)
search_api_config = configurable.search_api_config or {} # Get the config dict, default to empty
params_to_pass = get_search_params(search_api, search_api_config) # Filter parameters
writer_provider = get_config_value(configurable.writer_provider)
writer_model_name = get_config_value(configurable.writer_model)
writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
structured_llm = writer_model.with_structured_output(Queries)
# 根据主题先生成查询Prompt让模型生成一系列查询等待API搜索从而获取信息
query_prompt = report_planner_query_writer_instructions.format(
topic=topic, number_of_queries=number_of_queries)
results = structured_llm.invoke(query_prompt)
query_list: list[str] = [query.search_query for query in results.queries]
# TODO:这里和search_web重复了需要做一个封装节省代码
unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
source_str = format_sources(unique_sources)
# 将搜索到的内容作为上下文提供给了模型
sections_prompt = report_planner_instructions.format(
topic=topic, context=source_str, feedback=feedback)
planner_provider = get_config_value(configurable.planner_provider)
planner_model = get_config_value(configurable.planner_model)
planner_llm = init_chat_model(model=planner_model,
model_provider=planner_provider)
structured_llm = planner_llm.with_structured_output(Sections)
report_sections = structured_llm.invoke(sections_prompt)
# 获取写好的sections
sections: list[Section] = report_sections.sections
return {"sections": sections, "current_section_index": 0}
def human_feedback(state: ReportState, config: RunnableConfig) -> Command[Literal["generate_report_plan", "think_and_generate_queries"]]:
"""
获取人类反馈来修改大纲若通过则进入到ReAct部分。
"""
sections = state['sections']
sections_str = "\n\n".join(
f"{number}. {section.name}\n"
f"{section.description}\n"
f"**是否需要搜索研究: {'' if section.research else ''}**\n"
for number, section in enumerate(sections, 1)
)
interrupt_message = f"""请为这份大纲提供修改意见:
{sections_str}
这份大纲符合您的需求吗输入“true”以通过大纲或者提供修改意见来修改大纲。"""
feedback = interrupt(interrupt_message)
if isinstance(feedback, bool) and feedback is True:
return Command(goto="think_and_generate_queries")
elif isinstance(feedback, str):
return Command(goto="generate_report_plan",
update={"feedback_on_report_plan": feedback})
else:
raise TypeError(f"Interrupt value of type {type(feedback)} is not supported.")
def think_and_generate_queries(state: ReportState, config: RunnableConfig):
"""Think部分
思考目前内容,生成查询语句,同时记录思考过程
"""
current_section_index = state["current_section_index"]
section = state["sections"][current_section_index]
section_name = section.name
configurable = Configuration.from_runnable_config(config)
number_of_queries = configurable.number_of_queries
writer_provider = get_config_value(configurable.writer_provider)
writer_model_name = get_config_value(configurable.writer_model)
writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
structured_llm = writer_model.with_structured_output(Queries)
system_instructions = query_writer_instructions.format(section=section_name,
section_description=section.description,
number_of_queries=number_of_queries)
queries: Queries = structured_llm.invoke(system_instructions)
search_queries: List[SearchQuery] = queries.queries
return {"search_queries": {section_name: search_queries},
"thinking_process": [queries.thought],
}
async def search_web(state: ReportState, config: RunnableConfig):
"""Acting部分进行搜索并返回搜索内容
"""
current_section_index = state["current_section_index"]
section = state["sections"][current_section_index]
section_name = section.name
search_queries = state["search_queries"][section_name]
configurable = Configuration.from_runnable_config(config)
search_api = get_config_value(configurable.search_api)
search_api_config = configurable.search_api_config or {}
params_to_pass = get_search_params(search_api, search_api_config)
query_list = [query.search_query for query in search_queries]
print(query_list)
unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
current_sources = {
url: {
"title": source['title'],
"content": source["content"]
} for url, source in unique_sources.items()
}
all_sources = state.get("sources", {}).copy()
all_sources.update(current_sources)
source_str = format_sources(unique_sources)
return {"source_str": source_str, "sources": all_sources}
async def write_section(state: ReportState, config: RunnableConfig, stream_callback=None) -> Command[Literal["think_and_generate_queries", "compile_final_report"]]:
"""完成一个章节的内容
"""
current_section_index = state["current_section_index"]
section = state["sections"][current_section_index]
section_name = section.name
source_str = state["source_str"]
configurable = Configuration.from_runnable_config(config)
section_content = compile_completed_sections(state)
section_writer_inputs = section_writer_instructions.format(
section_name=section_name,
section_description=section.description,
context=source_str,
section_content=section_content
)
# Generate section
writer_provider = get_config_value(configurable.writer_provider)
writer_model_name = get_config_value(configurable.writer_model)
writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
prompt = [SystemMessage(content=section_writer_system_prompt),
HumanMessage(content=section_writer_inputs)]
writer_result = writer_model.astream(prompt)
content_parts = []
async for chunk in writer_result:
content_parts.append(chunk.content)
if stream_callback:
await stream_callback(chunk.content)
else:
print(chunk.content, end='', flush=True)
section.content = ''.join(content_parts)
if current_section_index == len(state["sections"]) - 1:
return Command(
update={"completed_sections": [section]},
goto="compile_final_report"
)
else:
return Command(
update={"completed_sections": [section], "current_section_index": current_section_index + 1},
goto="think_and_generate_queries"
)
async def compile_final_report(state: ReportState):
topic = state.get("topic", "未命名报告")
sections = state["completed_sections"]
final_report = f"# {topic}\n\n" + "\n\n".join([section.content for section in sections])
return {"final_report": final_report}
async def save_final_report(state: ReportState):
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# 定义报告保存的基础目录
base_reports_dir = r"C:\Users\24019\Desktop\Deepresearch-stream\reports"
output_dir = os.path.join(base_reports_dir, f"report_{timestamp}")
os.makedirs(output_dir, exist_ok=True)
async def write_file(filename: str, content: str):
async with aiofiles.open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
await f.write(content)
await write_file("report.md", state["final_report"])
await write_file("sources.json", json.dumps(state["sources"], indent=2, ensure_ascii=False))
def process_citations(state: ReportState):
"""处理引用
替换正文中的 [url:<link>] → [1](只保留首次引用)...
多次引用同一链接,后续位置直接删除,不重复编号
若某个 URL 不在 sources 中,则忽略
"""
from collections import OrderedDict
import re
url_to_index = OrderedDict()
pattern = r"\[url:(https?://[^\]]+)\]"
matches = list(re.finditer(pattern, state["final_report"]))
url_title_map = state["sources"]
index = 1
try:
for match in matches:
url = match.group(1)
if url not in url_title_map:
continue
if url not in url_to_index:
url_to_index[url] = index
index += 1
# 标记哪些URL已经在正文中被替换过后面都删掉
replaced_urls = set()
def replacer(match):
url = match.group(1)
if url not in url_title_map:
return "" # 无效的URL删掉
if url in url_to_index and url not in replaced_urls:
replaced_urls.add(url)
return f"[{url_to_index[url]}]" # 首次替换
else:
return ""
processed_text = re.sub(pattern, replacer, state["final_report"])
citation_lines = []
for url, idx in url_to_index.items():
title = url_title_map[url]["title"]
citation_lines.append(f"[{idx}] [{title}]({url})")
citation_list = "\n".join(citation_lines)
final_report = processed_text + "\n\n## 参考链接:\n" + citation_list
except Exception as e:
final_report = state["final_report"]
return {"final_report": final_report}