AI大模型实战篇:LATS,可能是目前最强的AI Agent设计框架

9 人赞同了该文章

前言

在上篇文章《AI大模型实战篇:Reflexion,通过强化学习提升模型推理能力》中,风叔结合原理和具体源代码,详细介绍了Reflexion这种本质是强化学习的AI Agent设计模式。Reflexion已经算是一种非常高级的设计框架,在解决很多复杂问题时,也能有比较好的表现。

在这篇文章中,风叔将为大家介绍可能是目前最强大的AI Agent设计框架,集多种规划和反思技术的集大成者,LATS。文章内容会相对比较复杂难懂,值得收藏和反复研读。



LATS的概念

LATS,全称是Language Agent Tree Search,说的更直白一些,LATS = Tree search + ReAct + Plan&Execute + Reflexion。这么来看,LATS确实非常高级和复杂,下面我们根据上面的等式,先从宏观上拆解一下LATS。


1. Tree Search

Tree Search是一种树搜索算法,LATS 使用蒙特卡罗树搜索(MCTS)算法,通过平衡探索和利用,找到最优决策路径。

蒙特卡罗方法可能大家都比较熟悉了,是一种通过随机采样模拟来求解问题的方法。通过生成随机数,建立概率模型,以解决难以通过其他方法解决的数值问题。蒙特卡罗方法的一个典型应用是求定积分。假设我们要计算函数 f(x) 在[a, b]之间的积分,即阴影部分面积。



蒙特卡罗方法的解法如下:在[a, b]之间取一个随机数 x,用 f(x)⋅(b−a) 来估计阴影部分的面积。为了提高估计精度,可以取多个随机数 x,然后取这些估计值的平均值作为最终结果。当取的随机数 x 越多,结果将越准确,估计值将越接近真实值。




蒙特卡罗树搜索(MCTS)则是一种基于树结构的蒙特卡罗方法。它在整个 2^N(N 为决策次数,即树深度)空间中进行启发式搜索,通过反馈机制寻找最优路径。MCTS 的五个主要核心部分是:

  1. 树结构:每一个叶子节点到根节点的路径都对应一个解,解空间大小为 2^N。
  2. 蒙特卡罗方法:通过随机统计方法获取观测结果,驱动搜索过程。
  3. 损失评估函数:设计一个可量化的损失函数,提供反馈评估解的优劣。
  4. 反向传播线性优化:采用反向传播对路径上的所有节点进行优化。
  5. 启发式搜索策略:遵循损失最小化原则,在整个搜索空间上进行启发式搜索。


MCTS 的每个循环包括四个步骤:

  1. 选择(Selection):从根节点开始,按照最大化某种启发式价值选择子节点,直到到达叶子节点。使用上置信区间算法(UCB)选择子节点。
  2. 扩展(Expansion):如果叶子节点不是终止节点,扩展该节点,添加一个或多个子节点。
  3. 仿真(Simulation):从新扩展的节点开始,进行随机模拟,直到到达终止状态。
  4. 反向传播(Backpropagation):将模拟结果沿着路径反向传播,更新每个节点的统计信息。



2. ReAct

ReAct的概念和设计模式,风叔在此前的文章中《AI大模型实战篇:AI Agent设计模式 - ReAct》已做过详细介绍。

它的典型流程如下图所示,可以用一个有趣的循环来描述:思考(Thought)→ 行动(Action)→ 观察(Observation),简称TAO循环。

  • 思考(Thought):面对一个问题,我们需要进行深入的思考。这个思考过程是关于如何定义问题、确定解决问题所需的关键信息和推理步骤。
  • 行动(Action):确定了思考的方向后,接下来就是行动的时刻。根据我们的思考,采取相应的措施或执行特定的任务,以期望推动问题向解决的方向发展。
  • 观察(Observation):行动之后,我们必须仔细观察结果。这一步是检验我们的行动是否有效,是否接近了问题的答案。
  • 循环迭代




3. Plan & Execute

Plan & Execute的概念和设计模式,风叔同样在此前的文章中《AI大模型实战篇:AI Agent设计模式 - Plan & Execute》已做过详细介绍,因此不再赘述。

Plan-and-Execute这个方法的本质是先计划再执行,即先把用户的问题分解成一个个的子任务,然后再执行各个子任务,并根据执行情况调整计划。



4. Reflexion

Reflexion的概念和设计模式,风叔在上篇文章《AI大模型实战篇:Reflexion,通过强化学习提升模型推理能力》做了详细介绍。

Reflexion的本质是Basic Reflection加上强化学习,完整的Reflexion框架由三个部分组成:

  • 参与者(Actor):根据状态观测量生成文本和动作。参与者在环境中采取行动并接受观察结果,从而形成轨迹。前文所介绍的Reflexion Agent,其实指的就是这一块
  • 评估者(Evaluator):对参与者的输出进行评价。具体来说,它将生成的轨迹(也被称作短期记忆)作为输入并输出奖励分数。根据人物的不同,使用不同的奖励函数(决策任务使用LLM和基于规则的启发式奖励)。
  • 自我反思(Self-Reflection):这个角色由大语言模型承担,能够为未来的试验提供宝贵的反馈。自我反思模型利用奖励信号、当前轨迹和其持久记忆生成具体且相关的反馈,并存储在记忆组件中。智能体利用这些经验(存储在长期记忆中)来快速改进决策。


因此,融合了Tree Search、ReAct、Plan & Execute、Reflexion的能力于一身之后,LATS成为AI Agent设计模式中,集反思模式和规划模式的大成者。


LATS的工作流程

LATS的工作流程如下图所示,包括以下步骤:

  1. 选择 (Selection):即从根节点开始,使用上置信区树 (UCT) 算法选择具有最高 UCT 值的子节点进行扩展。
  2. 扩展 (Expansion):通过从预训练语言模型 (LM) 中采样 n 个动作扩展树,接收每个动作并返回反馈,然后增加 n 个新的子节点。
  3. 评估 (Evaluation):为每个新子节点分配一个标量值,以指导搜索算法前进,LATS 通过 LM 生成的评分和自一致性得分设计新的价值函数。
  4. 模拟 (Simulation):扩展当前选择的节点直到达到终端状态,优先选择最高价值的节点。
  5. 回溯 (Backpropagation):根据轨迹结果更新树的值,路径中的每个节点的值被更新以反映模拟结果。
  6. 反思 (Reflection):在遇到不成功的终端节点时,LM 生成自我反思,总结过程中的错误并提出改进方案。这些反思和失败轨迹在后续迭代中作为额外上下文整合,帮助提高模型的表现。



下图是在langchain中实现LATS的过程:

第一步,选择:根据下面步骤中的总奖励选择最佳的下一步行动,如果找到解决方案或达到最大搜索深度,做出响应;否则就继续搜索。

第二步,扩展和执行:生成N个潜在操作,并且并行执行。

第三步,反思和评估:观察行动的结果,并根据反思和外部反馈对决策评分。

第四步,反向传播:根据结果更新轨迹的分数。




LATS的实现过程

下面,风叔通过实际的源码,详细介绍LATS模式的实现方法,具体的源代码地址可以在文末获取。


第一步 构建树节点

LATS 基于蒙特卡罗树搜索。对于每个搜索步骤,它都会选择具有最高“置信上限”的节点,这是一个平衡开发(最高平均奖励)和探索(最低访问量)的指标。从该节点开始,它会生成 N(在本例中为 5)个新的候选操作,并将它们添加到树中。当它生成有效解决方案或达到最大次数(搜索树深度)时,会停止搜索。

在Node节点中,我们定义了几个关键的函数:

  • best_child:选择 UCT 最高的子项进行下一步搜索
  • best_child_score:返回具有最高价值的子项
  • height:检查已经推进的树的深度
  • upper_confidence_bound:返回 UCT 分数,平衡分支的探索与利用
  • backpropogate:利用反向传播,更新此节点及其父节点的分数
  • get_trajectory:获取代表此搜索分支的消息
  • get_best_solution:返回当前子树中的最佳解决方案
import math
from collections import deque
from typing import Optional
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

class Node:
    def __init__(
        self,
        messages: list[BaseMessage],
        reflection: Reflection,
        parent: Optional[Node] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            self._mark_tree_as_solved()
        self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def is_solved(self):
        """If any solutions exist, we can end the search."""
        return self._is_solved

    @property
    def is_terminal(self):
        return not self.children

    @property
    def best_child(self):
        """Select the child with the highest UCT to search next."""
        if not self.children:
            return None
        all_nodes = self._get_all_children()
        return max(all_nodes, key=lambda child: child.upper_confidence_bound())

    @property
    def best_child_score(self):
        """Return the child with the highest value."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)

    @property
    def height(self) -> int:
        """Check for how far we've rolled out the tree."""
        if self.children:
            return 1 + max([child.height for child in self.children])
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def backpropagate(self, reward: float):
        """Update the score of this node and its parents."""
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool = True):
        if include_reflections:
            return self.messages + [self.reflection.as_message()]
        return self.messages

    def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
        """Get messages representing this search branch."""
        messages = []
        node = self
        while node:
            messages.extend(
                node.get_messages(include_reflections=include_reflections)[::-1]
            )
            node = node.parent
        # Reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1]  # root solution, reflection, child 1, ...

    def _get_all_children(self):
        all_nodes = []
        nodes = deque()
        nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        return all_nodes

    def get_best_solution(self):
        """Return the best solution from within the current sub-tree."""
        all_nodes = [self] + self._get_all_children()
        best_node = max(
            all_nodes,
            # We filter out all non-terminal, non-solution trajectories
            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
        )
        return best_node

    def _mark_tree_as_solved(self):
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

第二步 构建Agent

Agent将主要处理三个事项:

  • 反思:根据工具执行响应的结果打分
  • 初始响应:创建根节点,并开始搜索
  • 扩展:从当前树中的最佳位置,生成5个候选的下一步

对于更多实际的应用,比如代码生成,可以将代码执行结果集成到反馈或奖励中,这种外部反馈对Agent效果的提升将非常有用。

对于Agent,首先构建工具Tools,我们只使用了一个搜索引擎工具。

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]
tool_executor = ToolExecutor(tools=tools)

然后,构建反射系统,反射系统将根据决策和工具使用结果,对Agent的输出进行打分,我们将在其他两个节点中调用此方法。

class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        gte=0,
        lte=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Reflect and grade the assistant response to the user question below.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="candidate"),
    ]
)

reflection_llm_chain = (
    prompt
    | llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
        run_name="Reflection"
    )
    | PydanticToolsParser(tools=[Reflection])
)

@as_runnable
def reflection_chain(inputs) -> Reflection:
    tool_choices = reflection_llm_chain.invoke(inputs)
    reflection = tool_choices[0]
    if not isinstance(inputs["candidate"][-1], AIMessage):
        reflection.found_solution = False
    return reflection


接下来,我们从根节点开始,根据用户输入进行响应

from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnableConfig

prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an AI assistant.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)

initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(
    run_name="GenerateInitialCandidate"
)

parser = JsonOutputToolsParser(return_id=True)
initial_response = initial_answer_chain.invoke(
    {"input": "Write a research report on lithium pollution."}
)


然后开始根节点,我们将候选节点生成和reflection打包到单个节点中。

import json

# Define the node we will add to the graph
def generate_initial_response(state: TreeState) -> dict:
    """Generate the initial candidate response."""
    res = initial_answer_chain.invoke({"input": state["input"]})
    parsed = parser.invoke(res)
    tool_responses = tool_executor.batch(
        [ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in parsed]
    )
    output_messages = [res] + [
        ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        for resp, tool_call in zip(tool_responses, parsed)
    ]
    reflection = reflection_chain.invoke(
        {"input": state["input"], "candidate": output_messages}
    )
    root = Node(output_messages, reflection=reflection)
    return {
        **state,
        "root": root,
    }


第三步 生成候选节点

对于每个节点,生成5个待探索的候选节点。

def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
    n = config["configurable"].get("N", 5)
    bound_kwargs = llm.bind_tools(tools=tools).kwargs
    chat_result = llm.generate(
        [messages.to_messages()],
        n=n,
        callbacks=config["callbacks"],
        run_name="GenerateCandidates",
        **bound_kwargs,
    )
    return [gen.message for gen in chat_result.generations[0]]

expansion_chain = prompt_template | generate_candidates
res = expansion_chain.invoke({"input": "Write a research report on lithium pollution."})

将候选节点生成和refleciton步骤打包在下面的扩展节点中,所有操作都以批处理的方式进行,以加快执行速度。

from collections import defaultdict

def expand(state: TreeState, config: RunnableConfig) -> dict:
    """Starting from the "best" node in the tree, generate N candidates for the next step."""
    root = state["root"]
    best_candidate: Node = root.best_child if root.children else root
    messages = best_candidate.get_trajectory()
    # Generate N candidates from the single child candidate
    new_candidates = expansion_chain.invoke(
        {"input": state["input"], "messages": messages}, config
    )
    parsed = parser.batch(new_candidates)
    flattened = [
        (i, tool_call)
        for i, tool_calls in enumerate(parsed)
        for tool_call in tool_calls
    ]
    tool_responses = tool_executor.batch(
        [
            ToolInvocation(tool=tool_call["type"], tool_input=tool_call["args"])
            for _, tool_call in flattened
        ]
    )
    collected_responses = defaultdict(list)
    for (i, tool_call), resp in zip(flattened, tool_responses):
        collected_responses[i].append(
            ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        )
    output_messages = []
    for i, candidate in enumerate(new_candidates):
        output_messages.append([candidate] + collected_responses[i])

    # Reflect on each candidate
    # For tasks with external validation, you'd add that here.
    reflections = reflection_chain.batch(
        [{"input": state["input"], "candidate": msges} for msges in output_messages],
        config,
    )
    # Grow tree
    child_nodes = [
        Node(cand, parent=best_candidate, reflection=reflection)
        for cand, reflection in zip(output_messages, reflections)
    ]
    best_candidate.children.extend(child_nodes)
    # We have already extended the tree directly, so we just return the state
    return state


第四步 构建流程图

下面,我们构建流程图,将根节点和扩展节点加入进来

from typing import Literal

from langgraph.graph import END, StateGraph, START

def should_loop(state: TreeState) -> Literal["expand", "__end__"]:
    """Determine whether to continue the tree search."""
    root = state["root"]
    if root.is_solved:
        return END
    if root.height > 5:
        return END
    return "expand"


builder = StateGraph(TreeState)
builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)
builder.add_edge(START, "start")

builder.add_conditional_edges(
    "start",
    # Either expand/rollout or finish
    should_loop,
)
builder.add_conditional_edges(
    "expand",
    # Either continue to rollout or finish
    should_loop,
)

graph = builder.compile()


至此,整个LATS的核心逻辑就介绍完了。可以私信风叔,回复关键词【LATS源码】,获取LATS设计模式的完整源代码。


总结

与其他基于树的方法相比,LATS实现了自我反思的推理步骤,显著提升了性能。当采取行动后,LATS不仅利用环境反馈,还结合来自语言模型的反馈,以判断推理中是否存在错误并提出替代方案。这种自我反思的能力与其强大的搜索算法相结合,使得LATS更适合处理一些相对复杂的任务。

然而,由于算法本身的复杂性以及涉及的反思步骤,LATS通常比其他单智能体方法使用更多的计算资源,并且完成任务所需的时间更长。


后记

这篇文章之后,整个《AI大模型实战篇》系列就全部介绍完了,这个系列一共包括八篇文章,从最经典的ReAct模式开始,沿着规划路线介绍了REWOO、Plan&Execute和LLM Compiler,沿着反思路线介绍了Basic Reflection、Self Discover和Reflexion,并以最强大的设计模式LATS作为收尾。整个系列基本上包含了目前AI大模型和AI Agent的全部主流设计框架,后续如果有新的前沿设计模式和具体案例,风叔还会零星做一些介绍。

但是,所有的这些设计模式,都只是在告诉AI Agent应该如何规划和思考,且只能依赖于大模型既有的知识储备。而实际应用中,我们往往更希望AI Agent结合我们给定的知识和信息,在更专业的垂直领域内进行规划和思考。

比如我们希望Agent帮我们做论文分析、书籍总结,或者在企业级场景中,让AI Agent写营销计划、内部知识问答、智能客服等等非常多的场景,只靠上面几种Agent设计模式是远远不够的,我们必须给大模型外挂知识库,并且通过工作流进一步约束和规范Agent的思考方向和行为模式。

解决这个问题的最佳方式是利用Rag技术,接下来我们正式开启《Rag系统实战篇》系列。在后续的几篇文章中,风叔将同样结合应用场景和源代码,详细介绍Rag系统的实现方式和优化技巧。

对于还不太了解Rag的读者,可以先参考这两篇文章进行预习。

聊聊炙手可热的Rag:产生原因、基本原理与实施路径

Rag系统的发展历程,从朴素、高级到模块化


更多精彩文章:

AI大模型实战篇:AI Agent设计模式 - ReAct

AI大模型实战篇:AI Agent设计模式 - REWOO

AI大模型实战篇:AI Agent设计模式 - Plan & Execute

AI大模型实战篇:AI Agent设计模式 - LLM Compiler,Agent的并行处理器

AI大模型实战篇:Basic Reflection,AI Agent的左右互搏之术

AI大模型实战篇:Self Discover框架,万万想不到Agent还能这样推理

AI大模型实战篇:Reflexion,通过强化学习提升模型推理能力