从零开始认识langchain(十)定制中文聊天模型
西西嘛呦:从零开始认识langchain(一)初识langchain
西西嘛呦:从零开始认识langchain(二)组件-数据连接(data connection)
西西嘛呦:从零开始认识langchain(三)组件-模型IO(model I/O)
西西嘛呦:从零开始认识langchain(四)组件-链(chains)
西西嘛呦:从零开始认识langchain(五)组件-代理(agents)
西西嘛呦:从零开始认识langchain(六)组件-内存(memory)
西西嘛呦:从零开始认识langchain(七)组件-回调(callbacks)
西西嘛呦:从零开始认识langchain(八)在langchain中使用中文模型
西西嘛呦:从零开始认识langchain(九)基于openai的聊天模型背后原理
你也可以从github上获取相关代码:
Part1前言
之前我们已经讲过在langchain中怎么使用中文模型:
https://zhuanlan.zhihu.com/p/641631547
也讲过langchain中使用基于openai的聊天模型的原理:
https://zhuanlan.zhihu.com/p/641823532
还未了解的可以先去看看。本文将讲解的是如何在langchain中使用中文的聊天模型,它和中文模型不太一样。
Part2定制中文聊天模型
首先我们得看下cpm-bee-1b是怎么使用的:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True).cuda()
result = model.generate({"input": "今天天气真", "<ans>": ""}, tokenizer)
print(result)
result = model.generate({"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}, tokenizer)
print(result)
输入是一个字典,而且有两种方式:
- 带有question,根据input的内容进行回答。
- 不带有question,根据input继续生成文本。
另外,还有一个键,生成的结果会存到它的值里面。上述结果:
[{'input': '今天天气真', '<ans>': '今天天气真好'}]
[{'input': '今天天气真不错', 'question': '今天天气怎么样?', '<ans>': '好'}]
要在langchain使用中文模型,我们要继承langchain.chat_models.base中的SimpleChatModel类,它位于from langchain.chat_models.base import SimpleChatModel
,然后重写_llm_type、_call、_identifying_params方法。
- _llm_type:用于标识模型名称
- _call:里面实现推理逻辑,既可以是原生的模型推理,也可以是接口。(这个是必须的)
- _identifying_params:用于帮助我们打印类的一些属性。
接下来看完整代码:
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
class ModelLoader:
def __init__(self, model_name_or_path):
self.model_name_or_path = model_name_or_path
self.model, self.tokenizer = self.load()
def load(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True).cuda()
return model, tokenizer
modelLoader = ModelLoader("openbmb/cpm-bee-1b")
from typing import Any, List, Mapping, Optional
from langchain.llms.base import LLM
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import HumanMessage
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Mapping, Optional, Sequence
from pydantic import Field, root_validator
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd, dumps
from langchain.schema import (
ChatGeneration,
ChatResult,
LLMResult,
PromptValue,
RunInfo,
)
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.chat_models.base import BaseChatModel
class ChatCpmBee(SimpleChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Simpler interface."""
generations = []
for mes in messages:
prompt = json.loads(mes.content)
tokenizer = modelLoader.tokenizer
model = modelLoader.model
result = model.generate(prompt, tokenizer)
if len(result) == 1:
return result[0]["<ans>"]
return "对不起,我没有理解你的意思"
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
@property
def _identifying_params(self) -> Mapping[str, Any]:
params_dict = {
"test": "这里是参数字典",
}
return params_dict
@property
def _llm_type(self) -> str:
return "chat-cpm-bee-1B"
prompt = {"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}
messages = [
HumanMessage(content=json.dumps(prompt, ensure_ascii=False))
]
chatCpmBee = ChatCpmBee()
print(chatCpmBee)
chatCpmBee(messages)
"""
cache=None verbose=False callbacks=None callback_manager=None tags=None metadata=None
AIMessage(content='好', additional_kwargs={}, example=False)
"""
需要注意,这里我们要传入的是一个message的列表,什么是message呢?message用于标识传过来的角色是什么,有AIMessage, BaseMessage, HumanMessage等,你也可以自定义角色。
- AIMessage:模型返回的结果
- HumanMessage:我们传入的文本消息
- SystemMessage:前提消息,比如:你是一个有用的助手,接下来是你的任务描述。
这里我们模拟传入一个消息,在_call
中最后会返回一个字符串。基于openai的聊天模型比这个简单的示例要复杂,具体可以看之前所讲解的。
最后看到,返回给我们的是一个AIMessage,输出的content为好。正好是我们的输出。
与普通的中文模型相比。有以下不同:
- 普通的中文模型继承LLM,而聊天模型继承SimpleChatModel。
- 聊天模型传入的是一个消息列表,而不是prompt(提示),输出的也是消息,而不是直接的字符串。
相同之处是都要实现_call
方法。
到这里,你已经基本了解在langchain中怎么构建中文聊天模型了。