上传文件至 /
This commit is contained in:
300
nodes.py
Normal file
300
nodes.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from typing import Literal, List
|
||||
import aiofiles
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from datetime import datetime
|
||||
from langgraph.types import interrupt, Command
|
||||
import os
|
||||
import json
|
||||
from state import (
|
||||
Sections,
|
||||
ReportState,
|
||||
Queries,
|
||||
SearchQuery,
|
||||
Section
|
||||
)
|
||||
|
||||
from prompts import (
|
||||
report_planner_query_writer_instructions,
|
||||
report_planner_instructions,
|
||||
query_writer_instructions,
|
||||
section_writer_instructions,
|
||||
section_writer_system_prompt
|
||||
)
|
||||
|
||||
from configuration import Configuration
|
||||
from utils import (
|
||||
compile_completed_sections,
|
||||
get_config_value,
|
||||
get_search_params,
|
||||
select_and_execute_search,
|
||||
format_sources
|
||||
)
|
||||
|
||||
async def generate_report_plan(state: ReportState, config: RunnableConfig):
    """Generate the report outline, using web search to inform the plan.

    The graph later interrupts for human feedback (see ``human_feedback``);
    if the plan is rejected, this node re-runs with the feedback included in
    the planner prompt until the outline is approved.

    Args:
        state: Graph state; reads ``topic`` and the optional
            ``feedback_on_report_plan`` left by a previous review round.
        config: Runnable config carrying the ``Configuration`` values.

    Returns:
        dict with the planned ``sections`` and ``current_section_index``
        reset to 0.
    """
    topic = state["topic"]
    feedback = state.get("feedback_on_report_plan", None)

    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}  # default to empty config dict
    params_to_pass = get_search_params(search_api, search_api_config)  # filter unsupported params

    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)
    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    structured_llm = writer_model.with_structured_output(Queries)

    # Ask the writer model for search queries about the topic; the search
    # results give the planner real context to build the outline from.
    query_prompt = report_planner_query_writer_instructions.format(
        topic=topic, number_of_queries=number_of_queries)

    # Use the async API so the LLM call does not block the event loop
    # (this is an async node; the sync .invoke() would stall other tasks).
    results = await structured_llm.ainvoke(query_prompt)

    query_list: list[str] = [query.search_query for query in results.queries]

    # TODO: this duplicates search_web; factor the search step into a shared helper.
    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
    source_str = format_sources(unique_sources)

    # Provide the search results to the planner model as context.
    sections_prompt = report_planner_instructions.format(
        topic=topic, context=source_str, feedback=feedback)

    planner_provider = get_config_value(configurable.planner_provider)
    planner_model = get_config_value(configurable.planner_model)
    planner_llm = init_chat_model(model=planner_model,
                                  model_provider=planner_provider)

    structured_llm = planner_llm.with_structured_output(Sections)
    report_sections = await structured_llm.ainvoke(sections_prompt)

    # Extract the planned sections.
    sections: list[Section] = report_sections.sections
    return {"sections": sections, "current_section_index": 0}
def human_feedback(state: ReportState, config: RunnableConfig) -> Command[Literal["generate_report_plan", "think_and_generate_queries"]]:
    """Interrupt for human review of the report outline.

    Approval routes to the ReAct section loop (``think_and_generate_queries``);
    any other string is treated as revision feedback and routes back to
    ``generate_report_plan``.

    Raises:
        TypeError: if the resumed interrupt value is neither a bool nor a str.
    """
    sections = state['sections']
    sections_str = "\n\n".join(
        f"{number}. {section.name}\n"
        f"{section.description}\n"
        f"**是否需要搜索研究: {'是' if section.research else '否'}**\n"
        for number, section in enumerate(sections, 1)
    )

    interrupt_message = f"""请为这份大纲提供修改意见:
{sections_str}
这份大纲符合您的需求吗?,输入“true”以通过大纲,或者提供修改意见来修改大纲。"""

    feedback = interrupt(interrupt_message)

    # Approval: either a literal True, or the string "true" that the prompt
    # above asks the user to type (previously a typed "true" was wrongly
    # treated as revision feedback and looped back to plan generation).
    if feedback is True or (isinstance(feedback, str) and feedback.strip().lower() == "true"):
        return Command(goto="think_and_generate_queries")
    elif isinstance(feedback, str):
        return Command(goto="generate_report_plan",
                       update={"feedback_on_report_plan": feedback})
    else:
        raise TypeError(f"Interrupt value of type {type(feedback)} is not supported.")
def think_and_generate_queries(state: ReportState, config: RunnableConfig):
    """Think step of the ReAct loop.

    Reflects on the current section, produces search queries for it, and
    records the model's reasoning in ``thinking_process``.
    """
    idx = state["current_section_index"]
    current_section = state["sections"][idx]
    name = current_section.name

    cfg = Configuration.from_runnable_config(config)

    provider = get_config_value(cfg.writer_provider)
    model_name = get_config_value(cfg.writer_model)
    llm = init_chat_model(model=model_name, model_provider=provider)
    structured_llm = llm.with_structured_output(Queries)

    instructions = query_writer_instructions.format(
        section=name,
        section_description=current_section.description,
        number_of_queries=cfg.number_of_queries,
    )

    result: Queries = structured_llm.invoke(instructions)
    generated: List[SearchQuery] = result.queries

    return {
        "search_queries": {name: generated},
        "thinking_process": [result.thought],
    }
async def search_web(state: ReportState, config: RunnableConfig):
    """Act step of the ReAct loop: run the pending searches for the current section.

    Returns:
        dict with ``source_str`` (formatted search context for the writer)
        and the accumulated ``sources`` map (url -> {title, content}).
    """
    current_section_index = state["current_section_index"]
    section = state["sections"][current_section_index]
    section_name = section.name
    search_queries = state["search_queries"][section_name]

    configurable = Configuration.from_runnable_config(config)
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}
    params_to_pass = get_search_params(search_api, search_api_config)
    query_list = [query.search_query for query in search_queries]
    # (removed leftover debug print of query_list)

    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)

    # Keep only the fields downstream consumers need.
    current_sources = {
        url: {
            "title": source['title'],
            "content": source["content"]
        } for url, source in unique_sources.items()
    }

    # Merge into the running source map without mutating the prior state value.
    all_sources = state.get("sources", {}).copy()
    all_sources.update(current_sources)

    source_str = format_sources(unique_sources)

    return {"source_str": source_str, "sources": all_sources}
async def write_section(state: ReportState, config: RunnableConfig, stream_callback=None) -> Command[Literal["think_and_generate_queries", "compile_final_report"]]:
    """Write the content of the current section, streaming tokens as they arrive.

    Chunks are forwarded through ``stream_callback`` when provided, otherwise
    echoed to stdout. Routes to the final compile step after the last section,
    or back to the think step for the next one.
    """
    idx = state["current_section_index"]
    section = state["sections"][idx]
    cfg = Configuration.from_runnable_config(config)

    # Previously completed sections are passed so the writer stays coherent.
    human_prompt = section_writer_instructions.format(
        section_name=section.name,
        section_description=section.description,
        context=state["source_str"],
        section_content=compile_completed_sections(state),
    )

    # Generate section
    provider = get_config_value(cfg.writer_provider)
    model_name = get_config_value(cfg.writer_model)
    model = init_chat_model(model=model_name, model_provider=provider)
    messages = [SystemMessage(content=section_writer_system_prompt),
                HumanMessage(content=human_prompt)]

    pieces = []
    async for chunk in model.astream(messages):
        pieces.append(chunk.content)
        if stream_callback:
            await stream_callback(chunk.content)
        else:
            print(chunk.content, end='', flush=True)

    section.content = ''.join(pieces)

    if idx == len(state["sections"]) - 1:
        # Last section: move on to assembling the final report.
        return Command(
            update={"completed_sections": [section]},
            goto="compile_final_report"
        )
    # Otherwise advance the index and loop back to the think step.
    return Command(
        update={"completed_sections": [section], "current_section_index": idx + 1},
        goto="think_and_generate_queries"
    )
||||
async def compile_final_report(state: ReportState):
    """Join all completed sections into one markdown report titled by topic."""
    title = state.get("topic", "未命名报告")
    body = "\n\n".join(sec.content for sec in state["completed_sections"])
    return {"final_report": f"# {title}\n\n{body}"}
async def save_final_report(state: ReportState):
    """Persist the final report and its sources to a timestamped directory.

    The base directory comes from the ``DEEPRESEARCH_REPORTS_DIR`` environment
    variable, defaulting to ``./reports`` (replaces the previous hard-coded
    absolute path under one user's Desktop, which was not portable).

    Writes ``report.md`` and ``sources.json`` inside
    ``<base>/report_<timestamp>/``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Base directory for saved reports; overridable via environment.
    base_reports_dir = os.environ.get("DEEPRESEARCH_REPORTS_DIR", "reports")
    output_dir = os.path.join(base_reports_dir, f"report_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    async def write_file(filename: str, content: str):
        # aiofiles keeps the event loop free while writing to disk.
        async with aiofiles.open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
            await f.write(content)

    await write_file("report.md", state["final_report"])
    await write_file("sources.json", json.dumps(state["sources"], indent=2, ensure_ascii=False))
||||
def process_citations(state: ReportState):
    """Convert inline ``[url:<link>]`` markers into numbered citations.

    - The first occurrence of each known URL becomes ``[n]``, numbered in
      order of first appearance; later occurrences of the same URL are
      removed rather than re-numbered.
    - Markers whose URL is absent from ``state["sources"]`` are removed.
    - A numbered reference list is appended under a "## 参考链接:" heading.

    Best-effort: on any failure the report is returned unchanged (but the
    error is now logged instead of silently swallowed).
    """
    import logging
    import re
    from collections import OrderedDict

    report = state["final_report"]
    url_title_map = state["sources"]
    pattern = r"\[url:(https?://[^\]]+)\]"

    try:
        # First pass: assign a stable number to each known URL in order of
        # first appearance.
        url_to_index = OrderedDict()
        for match in re.finditer(pattern, report):
            url = match.group(1)
            if url in url_title_map and url not in url_to_index:
                url_to_index[url] = len(url_to_index) + 1

        replaced_urls = set()  # URLs already replaced once in the body

        def replacer(match):
            url = match.group(1)
            if url not in url_title_map:
                return ""  # unknown URL: drop the marker
            if url in url_to_index and url not in replaced_urls:
                replaced_urls.add(url)
                return f"[{url_to_index[url]}]"  # first occurrence gets the number
            return ""  # repeated occurrence: drop, keep a single citation

        processed_text = re.sub(pattern, replacer, report)

        citation_lines = [
            f"[{idx}] [{url_title_map[url]['title']}]({url})"
            for url, idx in url_to_index.items()
        ]
        final_report = processed_text + "\n\n## 参考链接:\n" + "\n".join(citation_lines)
    except Exception:
        # Best-effort: never fail the pipeline over citation formatting,
        # but record the failure for debugging.
        logging.getLogger(__name__).exception("citation processing failed")
        final_report = report

    return {"final_report": final_report}
Reference in New Issue
Block a user