Files
Deepresearch/nodes.py
2025-08-20 20:12:32 +08:00

300 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Literal, List
import aiofiles
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from datetime import datetime
from langgraph.types import interrupt, Command
import os
import json
from state import (
Sections,
ReportState,
Queries,
SearchQuery,
Section
)
from prompts import (
report_planner_query_writer_instructions,
report_planner_instructions,
query_writer_instructions,
section_writer_instructions,
section_writer_system_prompt
)
from configuration import Configuration
from utils import (
compile_completed_sections,
get_config_value,
get_search_params,
select_and_execute_search,
format_sources
)
async def generate_report_plan(state: ReportState, config: RunnableConfig):
    """Generate the report outline, using web search results to inform the plan.

    After this node the graph interrupts for human feedback; if the plan is
    rejected this node runs again with the reviewer's feedback and regenerates
    the outline, until it is approved and the flow moves to section writing.

    Args:
        state: Graph state with the report ``topic`` and, on re-runs,
            ``feedback_on_report_plan`` from the reviewer.
        config: Runnable config carrying ``Configuration`` values (models,
            search API, number of queries).

    Returns:
        dict with the generated ``sections`` and ``current_section_index``
        reset to 0.
    """
    topic = state["topic"]
    # Present only on re-runs after the human rejected the previous plan.
    feedback = state.get("feedback_on_report_plan", None)
    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}  # Get the config dict, default to empty
    params_to_pass = get_search_params(search_api, search_api_config)  # Filter parameters
    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)
    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    structured_llm = writer_model.with_structured_output(Queries)
    # Have the model propose search queries for the topic so the planner has
    # real web context to work from.
    query_prompt = report_planner_query_writer_instructions.format(
        topic=topic, number_of_queries=number_of_queries)
    # FIX: use ainvoke — a blocking .invoke() inside an async node stalls the
    # event loop for the duration of the LLM call.
    results = await structured_llm.ainvoke(query_prompt)
    query_list: list[str] = [query.search_query for query in results.queries]
    # TODO: duplicated with search_web — factor into a shared helper.
    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
    source_str = format_sources(unique_sources)
    # Feed the search results to the planner model as context.
    sections_prompt = report_planner_instructions.format(
        topic=topic, context=source_str, feedback=feedback)
    planner_provider = get_config_value(configurable.planner_provider)
    planner_model = get_config_value(configurable.planner_model)
    planner_llm = init_chat_model(model=planner_model,
                                  model_provider=planner_provider)
    structured_llm = planner_llm.with_structured_output(Sections)
    report_sections = await structured_llm.ainvoke(sections_prompt)
    # Extract the planned sections.
    sections: list[Section] = report_sections.sections
    return {"sections": sections, "current_section_index": 0}
def human_feedback(state: ReportState, config: RunnableConfig) -> Command[Literal["generate_report_plan", "think_and_generate_queries"]]:
    """Pause for human review of the outline.

    Interrupts the graph and resumes with the reviewer's answer: approval
    routes to the ReAct loop (``think_and_generate_queries``); any other
    string is treated as revision feedback and routes back to plan
    regeneration.

    Raises:
        TypeError: if the resume value is neither a bool nor a str.
    """
    sections = state['sections']
    sections_str = "\n\n".join(
        f"{number}. {section.name}\n"
        f"{section.description}\n"
        # FIX: both ternary branches were empty strings (characters likely
        # lost in a copy/encoding step); restore the 是/否 (yes/no) labels.
        f"**是否需要搜索研究: {'是' if section.research else '否'}**\n"
        for number, section in enumerate(sections, 1)
    )
    interrupt_message = f"""请为这份大纲提供修改意见:
{sections_str}
这份大纲符合您的需求吗输入“true”以通过大纲或者提供修改意见来修改大纲。"""
    feedback = interrupt(interrupt_message)
    # FIX: the prompt asks the reviewer to type "true", but the original only
    # accepted a bool True — a typed "true" was wrongly treated as feedback.
    if feedback is True or (isinstance(feedback, str) and feedback.strip().lower() == "true"):
        return Command(goto="think_and_generate_queries")
    elif isinstance(feedback, str):
        return Command(goto="generate_report_plan",
                       update={"feedback_on_report_plan": feedback})
    else:
        raise TypeError(f"Interrupt value of type {type(feedback)} is not supported.")
def think_and_generate_queries(state: ReportState, config: RunnableConfig):
    """Think step of the ReAct loop.

    Reflects on the current section, produces search queries for it, and
    records the model's reasoning in the thinking process log.
    """
    idx = state["current_section_index"]
    current_section = state["sections"][idx]

    cfg = Configuration.from_runnable_config(config)
    llm = init_chat_model(
        model=get_config_value(cfg.writer_model),
        model_provider=get_config_value(cfg.writer_provider),
    )
    query_llm = llm.with_structured_output(Queries)

    prompt = query_writer_instructions.format(
        section=current_section.name,
        section_description=current_section.description,
        number_of_queries=cfg.number_of_queries,
    )
    result: Queries = query_llm.invoke(prompt)
    generated: List[SearchQuery] = result.queries

    return {
        "search_queries": {current_section.name: generated},
        "thinking_process": [result.thought],
    }
async def search_web(state: ReportState, config: RunnableConfig):
    """Acting step of the ReAct loop: run the current section's search queries.

    Executes the queries generated for the current section, accumulates the
    resulting sources (keyed by URL) into the running ``sources`` map, and
    returns a formatted source string for the writer.
    """
    current_section_index = state["current_section_index"]
    section = state["sections"][current_section_index]
    section_name = section.name
    search_queries = state["search_queries"][section_name]
    configurable = Configuration.from_runnable_config(config)
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}
    params_to_pass = get_search_params(search_api, search_api_config)
    query_list = [query.search_query for query in search_queries]
    # FIX: removed leftover debug print(query_list) — it polluted stdout,
    # which write_section also uses for streaming output.
    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
    # Keep only title/content per URL — this is what the citation step needs.
    current_sources = {
        url: {
            "title": source['title'],
            "content": source["content"]
        } for url, source in unique_sources.items()
    }
    # Merge into the accumulated sources without mutating the prior state value.
    all_sources = state.get("sources", {}).copy()
    all_sources.update(current_sources)
    source_str = format_sources(unique_sources)
    return {"source_str": source_str, "sources": all_sources}
async def write_section(state: ReportState, config: RunnableConfig, stream_callback=None) -> Command[Literal["think_and_generate_queries", "compile_final_report"]]:
    """Write one section of the report, streaming the model output.

    Each streamed chunk is forwarded to ``stream_callback`` when provided,
    otherwise printed to stdout. Routes to the next section's think step, or
    to final compilation once the last section is finished.
    """
    idx = state["current_section_index"]
    section = state["sections"][idx]
    configurable = Configuration.from_runnable_config(config)

    # Build the writer prompt: section spec + searched context + what has
    # already been written.
    writer_inputs = section_writer_instructions.format(
        section_name=section.name,
        section_description=section.description,
        context=state["source_str"],
        section_content=compile_completed_sections(state),
    )
    model = init_chat_model(
        model=get_config_value(configurable.writer_model),
        model_provider=get_config_value(configurable.writer_provider),
    )
    messages = [SystemMessage(content=section_writer_system_prompt),
                HumanMessage(content=writer_inputs)]

    # Stream the generation, collecting chunks for the final section text.
    pieces = []
    async for chunk in model.astream(messages):
        pieces.append(chunk.content)
        if stream_callback:
            await stream_callback(chunk.content)
        else:
            print(chunk.content, end='', flush=True)
    section.content = ''.join(pieces)

    if idx == len(state["sections"]) - 1:
        # Last section done — move on to assembling the full report.
        return Command(
            update={"completed_sections": [section]},
            goto="compile_final_report"
        )
    return Command(
        update={"completed_sections": [section],
                "current_section_index": idx + 1},
        goto="think_and_generate_queries"
    )
async def compile_final_report(state: ReportState):
    """Join all completed sections into one markdown document under a title."""
    title = state.get("topic", "未命名报告")
    body = "\n\n".join(sec.content for sec in state["completed_sections"])
    return {"final_report": f"# {title}\n\n{body}"}
async def save_final_report(state: ReportState):
    """Persist the final report and its sources to a timestamped directory.

    Writes ``report.md`` and ``sources.json`` under
    ``<base>/report_<timestamp>/``. The base directory comes from the
    ``DEEPRESEARCH_REPORTS_DIR`` environment variable, defaulting to a
    ``reports`` directory under the current working directory.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # FIX: the base dir was a hard-coded absolute Windows user path
    # (C:\Users\24019\...), which breaks on any other machine. Allow an env
    # override and default to a portable relative location.
    base_reports_dir = os.environ.get(
        "DEEPRESEARCH_REPORTS_DIR",
        os.path.join(os.getcwd(), "reports"),
    )
    output_dir = os.path.join(base_reports_dir, f"report_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    async def write_file(filename: str, content: str):
        # aiofiles keeps the event loop free during disk writes.
        async with aiofiles.open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
            await f.write(content)

    await write_file("report.md", state["final_report"])
    await write_file("sources.json", json.dumps(state["sources"], indent=2, ensure_ascii=False))
def process_citations(state: ReportState):
    """Replace inline ``[url:<link>]`` citations with numbered references.

    Only the first occurrence of each URL is replaced with ``[n]`` (numbered
    in order of first appearance); repeated citations are removed, as are
    citations whose URL is missing from ``sources``. A reference list is
    appended to the report. On any unexpected error the original report is
    returned unchanged (best effort — citations must never break delivery).
    """
    import re

    pattern = r"\[url:(https?://[^\]]+)\]"
    report = state["final_report"]
    url_title_map = state["sources"]
    try:
        # FIX: the original did a pre-pass to number URLs plus a second
        # replaced-set in the replacer; one pass assigning numbers on first
        # sight is equivalent (dicts preserve insertion order on 3.7+).
        url_to_index: dict[str, int] = {}

        def replacer(match):
            url = match.group(1)
            if url not in url_title_map:
                return ""  # unknown URL: drop the citation entirely
            if url in url_to_index:
                return ""  # repeat citation: keep only the first mention
            url_to_index[url] = len(url_to_index) + 1
            return f"[{url_to_index[url]}]"

        processed_text = re.sub(pattern, replacer, report)
        # FIX: don't append an empty "references" section when the report
        # contains no valid citations.
        if not url_to_index:
            return {"final_report": processed_text}
        citation_lines = [
            f"[{idx}] [{url_title_map[url]['title']}]({url})"
            for url, idx in url_to_index.items()
        ]
        final_report = processed_text + "\n\n## 参考链接:\n" + "\n".join(citation_lines)
    except Exception:
        # Best-effort fallback: deliver the unprocessed report rather than fail.
        final_report = report
    return {"final_report": final_report}