300 lines
11 KiB
Python
300 lines
11 KiB
Python
from typing import Literal, List
|
||
import aiofiles
|
||
from langchain.chat_models import init_chat_model
|
||
from langchain_core.messages import HumanMessage, SystemMessage
|
||
from langchain_core.runnables import RunnableConfig
|
||
from datetime import datetime
|
||
from langgraph.types import interrupt, Command
|
||
import os
|
||
import json
|
||
from state import (
|
||
Sections,
|
||
ReportState,
|
||
Queries,
|
||
SearchQuery,
|
||
Section
|
||
)
|
||
|
||
from prompts import (
|
||
report_planner_query_writer_instructions,
|
||
report_planner_instructions,
|
||
query_writer_instructions,
|
||
section_writer_instructions,
|
||
section_writer_system_prompt
|
||
)
|
||
|
||
from configuration import Configuration
|
||
from utils import (
|
||
compile_completed_sections,
|
||
get_config_value,
|
||
get_search_params,
|
||
select_and_execute_search,
|
||
format_sources
|
||
)
|
||
|
||
async def generate_report_plan(state: ReportState, config: RunnableConfig):
    """Draft the report outline, informed by a preliminary web search.

    1. The writer model turns ``state["topic"]`` into ``number_of_queries``
       search queries (structured output ``Queries``).
    2. Those queries are run through the configured search API and the
       results are formatted into a context string.
    3. The planner model produces a structured ``Sections`` outline from the
       topic, that context, and any reviewer feedback stored in
       ``state["feedback_on_report_plan"]``.

    This node does not pause for review itself: the ``human_feedback`` node
    interrupts the graph and, if the plan is rejected, routes back here with
    the reviewer's comments stored in ``feedback_on_report_plan``.

    Returns:
        A state update with the new ``sections`` and ``current_section_index``
        reset to 0, so section writing starts from the first section.
    """
    topic = state["topic"]
    # Absent on the first pass; set by human_feedback when a plan is rejected.
    feedback = state.get("feedback_on_report_plan", None)

    configurable = Configuration.from_runnable_config(config)
    number_of_queries = configurable.number_of_queries
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}  # default to an empty config dict
    params_to_pass = get_search_params(search_api, search_api_config)  # filter the config down to the params passed to the search API

    writer_provider = get_config_value(configurable.writer_provider)
    writer_model_name = get_config_value(configurable.writer_model)
    writer_model = init_chat_model(model=writer_model_name, model_provider=writer_provider)
    query_llm = writer_model.with_structured_output(Queries)

    # Step 1: ask the model which searches would help plan this topic.
    query_prompt = report_planner_query_writer_instructions.format(
        topic=topic, number_of_queries=number_of_queries
    )
    # Use ainvoke rather than invoke: this is a coroutine, and a blocking call
    # would stall the event loop for the whole LLM round-trip.
    results = await query_llm.ainvoke(query_prompt)
    query_list: list[str] = [query.search_query for query in results.queries]

    # Step 2: run the searches.
    # TODO: this duplicates the search logic in search_web; factor out a helper.
    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)
    source_str = format_sources(unique_sources)

    # Step 3: plan the sections, giving the search results to the model as context.
    sections_prompt = report_planner_instructions.format(
        topic=topic, context=source_str, feedback=feedback
    )
    planner_provider = get_config_value(configurable.planner_provider)
    planner_model = get_config_value(configurable.planner_model)
    planner_llm = init_chat_model(model=planner_model, model_provider=planner_provider)
    sections_llm = planner_llm.with_structured_output(Sections)
    report_sections = await sections_llm.ainvoke(sections_prompt)

    sections: list[Section] = report_sections.sections
    return {"sections": sections, "current_section_index": 0}
|
||
|
||
|
||
|
||
def human_feedback(state: ReportState, config: RunnableConfig) -> Command[Literal["generate_report_plan", "think_and_generate_queries"]]:
    """Show the current report plan to a human and route on their reply.

    Pauses the graph with ``interrupt`` and presents the numbered outline
    (name, description, and whether the section needs research). The resume
    value decides the next node:

    * ``True``, or the string ``"true"`` (case-insensitive, surrounding
      whitespace ignored), which is what the prompt asks the user to type,
      approves the plan and continues to ``think_and_generate_queries``.
    * Any other string is treated as revision feedback. It is stored under
      ``feedback_on_report_plan`` and the graph returns to
      ``generate_report_plan``.

    Raises:
        TypeError: if the resume value is neither ``True`` nor a string.
    """
    sections = state["sections"]
    sections_str = "\n\n".join(
        f"{number}. {section.name}\n"
        f"{section.description}\n"
        f"**是否需要搜索研究: {'是' if section.research else '否'}**\n"
        for number, section in enumerate(sections, 1)
    )

    interrupt_message = f"""请为这份大纲提供修改意见:
{sections_str}
这份大纲符合您的需求吗?,输入“true”以通过大纲,或者提供修改意见来修改大纲。"""

    feedback = interrupt(interrupt_message)

    # The prompt tells the user to type "true". A client that forwards the raw
    # text sends a str, so accept that as approval too. Otherwise "true" would
    # be treated as revision feedback and the plan regenerated forever.
    approved = feedback is True or (
        isinstance(feedback, str) and feedback.strip().lower() == "true"
    )
    if approved:
        return Command(goto="think_and_generate_queries")
    if isinstance(feedback, str):
        return Command(goto="generate_report_plan",
                       update={"feedback_on_report_plan": feedback})
    raise TypeError(f"Interrupt value of type {type(feedback)} is not supported.")
|
||
|
||
|
||
|
||
|
||
def think_and_generate_queries(state: ReportState, config: RunnableConfig):
    """Run the "think" step of the per-section loop.

    Asks the writer model (structured output ``Queries``) for search queries
    for the section at ``state["current_section_index"]`` and keeps the
    model's recorded reasoning (``thought``).

    Returns:
        A state update where ``search_queries`` maps the section's name to
        its generated ``SearchQuery`` list, and ``thinking_process`` is a
        one-element list holding the model's ``thought``.
    """
    target = state["sections"][state["current_section_index"]]
    target_name = target.name

    settings = Configuration.from_runnable_config(config)
    query_count = settings.number_of_queries

    provider = get_config_value(settings.writer_provider)
    model_name = get_config_value(settings.writer_model)
    query_generator = init_chat_model(
        model=model_name, model_provider=provider
    ).with_structured_output(Queries)

    prompt = query_writer_instructions.format(
        section=target_name,
        section_description=target.description,
        number_of_queries=query_count,
    )
    generated: Queries = query_generator.invoke(prompt)
    planned: List[SearchQuery] = generated.queries

    update = {
        "search_queries": {target_name: planned},
        "thinking_process": [generated.thought],
    }
    return update
|
||
|
||
|
||
async def search_web(state: ReportState, config: RunnableConfig):
    """Run the "act" step: search for the current section's queries.

    Looks up the queries generated for the section at
    ``state["current_section_index"]``, runs them through the configured
    search API, and merges the results into the accumulated sources.

    Returns:
        A state update with:

        * ``source_str``: the formatted results of *this* search only.
        * ``sources``: a copy of the existing ``sources`` mapping (empty if
          absent), updated with this search's results. Keys are URLs; each
          value keeps only ``title`` and ``content``.
    """
    import logging

    current_section_index = state["current_section_index"]
    section = state["sections"][current_section_index]
    section_name = section.name
    search_queries = state["search_queries"][section_name]

    configurable = Configuration.from_runnable_config(config)
    search_api = get_config_value(configurable.search_api)
    search_api_config = configurable.search_api_config or {}
    params_to_pass = get_search_params(search_api, search_api_config)
    query_list = [query.search_query for query in search_queries]
    # Diagnostic output goes through logging instead of a bare print() to stdout.
    logging.getLogger(__name__).debug(
        "Searching for section %r with queries: %s", section_name, query_list
    )

    unique_sources = await select_and_execute_search(search_api, query_list, params_to_pass)

    # Keep only the fields stored in state; title is what the citation list shows.
    current_sources = {
        url: {"title": source["title"], "content": source["content"]}
        for url, source in unique_sources.items()
    }

    # Copy rather than mutate the mapping already held in state.
    all_sources = state.get("sources", {}).copy()
    all_sources.update(current_sources)

    source_str = format_sources(unique_sources)
    return {"source_str": source_str, "sources": all_sources}
|
||
|
||
|
||
|
||
|
||
async def write_section(state: ReportState, config: RunnableConfig, stream_callback=None) -> Command[Literal["think_and_generate_queries", "compile_final_report"]]:
    """Write the current section by streaming the writer model's output.

    The prompt combines the section's name and description, the search
    results in ``state["source_str"]``, and the already completed sections
    as rendered by ``compile_completed_sections``.

    Each streamed chunk's content is awaited through ``stream_callback`` when
    one is given. Otherwise it is echoed to stdout as it arrives. The joined
    text is assigned to ``section.content``, which mutates the section object
    held in ``state["sections"]``.

    Returns:
        A ``Command`` that appends the section to ``completed_sections``.
        After the last section it goes to ``compile_final_report``.
        Otherwise it also advances ``current_section_index`` and goes back to
        ``think_and_generate_queries``.
    """
    index = state["current_section_index"]
    all_sections = state["sections"]
    section = all_sections[index]
    context = state["source_str"]
    cfg = Configuration.from_runnable_config(config)
    earlier_text = compile_completed_sections(state)

    user_prompt = section_writer_instructions.format(
        section_name=section.name,
        section_description=section.description,
        context=context,
        section_content=earlier_text,
    )

    provider = get_config_value(cfg.writer_provider)
    model_name = get_config_value(cfg.writer_model)
    llm = init_chat_model(model=model_name, model_provider=provider)
    messages = [
        SystemMessage(content=section_writer_system_prompt),
        HumanMessage(content=user_prompt),
    ]

    pieces = []
    async for chunk in llm.astream(messages):
        text = chunk.content
        pieces.append(text)
        if not stream_callback:
            print(text, end='', flush=True)
        else:
            await stream_callback(text)

    section.content = ''.join(pieces)

    update = {"completed_sections": [section]}
    is_last = index == len(all_sections) - 1
    if is_last:
        return Command(update=update, goto="compile_final_report")

    update["current_section_index"] = index + 1
    return Command(update=update, goto="think_and_generate_queries")
|
||
|
||
|
||
|
||
|
||
async def compile_final_report(state: ReportState):
    """Join the completed sections into a single Markdown document.

    The document starts with ``state["topic"]`` as a level-1 heading. The
    fallback when ``topic`` is absent is the literal "未命名报告"
    ("Untitled report"). Each completed section's ``content`` follows,
    separated by blank lines.

    Returns:
        ``{"final_report": <markdown text>}``.
    """
    title = state.get("topic", "未命名报告")
    body = "\n\n".join(part.content for part in state["completed_sections"])
    return {"final_report": f"# {title}\n\n{body}"}
|
||
|
||
|
||
async def save_final_report(state: ReportState):
    """Write the final report and its sources to a new timestamped directory.

    Creates ``<base>/report_<YYYY-MM-DD_HH-MM-SS>/`` containing:

    * ``report.md``: ``state["final_report"]``.
    * ``sources.json``: ``state["sources"]`` as indented UTF-8 JSON with
      non-ASCII characters kept as-is. An empty object is written when
      ``sources`` is absent.

    ``<base>`` is read from the ``DEEPRESEARCH_REPORTS_DIR`` environment
    variable. When it is unset, the previously hard-coded path is used so
    existing setups keep working.

    Returns:
        None.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Base directory for saved reports. The env var lets the graph run on
    # machines other than the original developer's; the fallback is the old
    # hard-coded path, kept for backward compatibility.
    base_reports_dir = os.environ.get(
        "DEEPRESEARCH_REPORTS_DIR",
        r"C:\Users\24019\Desktop\Deepresearch-stream\reports",
    )
    output_dir = os.path.join(base_reports_dir, f"report_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    async def write_file(filename: str, content: str):
        # Explicit UTF-8 so non-ASCII report text does not depend on the OS default encoding.
        async with aiofiles.open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
            await f.write(content)

    await write_file("report.md", state["final_report"])
    # Read with .get, matching search_web, so a state without "sources" does not raise KeyError.
    await write_file("sources.json", json.dumps(state.get("sources", {}), indent=2, ensure_ascii=False))
|
||
|
||
|
||
def process_citations(state: "ReportState"):
    """Turn inline ``[url:<link>]`` markers into numbered citations.

    Markers are handled as follows:

    * The first marker for each URL present in ``state["sources"]`` becomes
      ``[n]``, numbered in order of first appearance.
    * Later markers for an already-cited URL are removed, so each source is
      numbered only once.
    * Markers whose URL is not in ``state["sources"]`` are removed.

    When at least one citation was produced, a "## 参考链接:" (references)
    section listing ``[n] [title](url)`` is appended. The URL is used as the
    title when a source has no ``title``. A missing ``sources`` key is
    treated as no sources.

    Processing is best-effort: if anything unexpected fails, the error is
    logged and the report is returned unchanged.

    Returns:
        ``{"final_report": <processed report>}``.
    """
    import logging
    import re

    report = state["final_report"]
    sources = state.get("sources") or {}
    marker = re.compile(r"\[url:(https?://[^\]]+)\]")
    url_to_index = {}  # insertion order = citation order

    def _replace(match):
        url = match.group(1)
        if url not in sources:
            return ""  # unknown source: drop the marker
        if url in url_to_index:
            return ""  # already cited: only the first occurrence keeps its number
        url_to_index[url] = len(url_to_index) + 1
        return f"[{url_to_index[url]}]"

    try:
        processed = marker.sub(_replace, report)
        if not url_to_index:
            # Nothing cited: don't append an empty references heading.
            return {"final_report": processed}

        citation_list = "\n".join(
            f"[{idx}] [{sources[url].get('title') or url}]({url})"
            for url, idx in url_to_index.items()
        )
        return {"final_report": processed + "\n\n## 参考链接:\n" + citation_list}
    except Exception:
        logging.getLogger(__name__).exception(
            "Citation processing failed; returning report unchanged"
        )
        return {"final_report": report}