从零开始认识langchain(八)在langchain中使用中文模型
6 人赞同了该文章
西西嘛呦:从零开始认识langchain(一)初识langchain
西西嘛呦:从零开始认识langchain(二)组件-数据连接(data connection)
西西嘛呦:从零开始认识langchain(三)组件-模型IO(model I/O)
西西嘛呦:从零开始认识langchain(四)组件-链(chains)
西西嘛呦:从零开始认识langchain(五)组件-代理(agents)
西西嘛呦:从零开始认识langchain(六)组件-内存(memory)
西西嘛呦:从零开始认识langchain(七)组件-回调(callbacks)
你也可以从github上获取相关代码:
Part1前言
目前langchain都是基于openai的模型进行的,本文将讲解下怎么定制化使用中文的模型。为了方便起见,这里使用的模型为cpm-bee-1b。
Part2定制中文模型
首先我们得看下cpm-bee-1b是怎么使用的:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True).cuda()
result = model.generate({"input": "今天天气真", "<ans>": ""}, tokenizer)
print(result)
result = model.generate({"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}, tokenizer)
print(result)
输入是一个字典,而且有两种方式:
- 带有question,根据input的内容进行回答。
- 不带有question,根据input继续生成文本。
另外,还有一个键,生成的结果会存到它的值里面。上述结果:
[{'input': '今天天气真', '<ans>': '今天天气真好'}]
[{'input': '今天天气真不错', 'question': '今天天气怎么样?', '<ans>': '好'}]
要在langchain使用中文模型,我们要继承langchain中的LLM类,它位于from langchain.llms.base import LLM
,然后重写_llm_type、_call、_identifying_params方法。
- _llm_type:用于标识模型名称
- _call:里面实现推理逻辑,既可以是原生的模型推理,也可以是接口。
- _identifying_params:用于帮助我们打印类的一些属性。
接下来看完整代码:
# 使用langchain加载中文模型
# 继承LLM,并重写_llm_type、_call、_identifying_params方法
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
class ModelLoader:
def __init__(self, model_name_or_path):
self.model_name_or_path = model_name_or_path
self.model, self.tokenizer = self.load()
def load(self):
tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True).cuda()
return model, tokenizer
modelLoader = ModelLoader("openbmb/cpm-bee-1b")
from typing import Any, List, Mapping, Optional
from langchain.llms.base import LLM
class CpmBee(LLM):
@property
def _llm_type(self) -> str:
return "cpm-bee-1B"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
print(prompt)
prompt = json.loads(prompt)
tokenizer = modelLoader.tokenizer
model = modelLoader.model
result = model.generate(prompt, tokenizer)
if len(result) == 1:
return result[0]["<ans>"]
return "对不起,我没有理解你的意思"
@property
def _identifying_params(self) -> Mapping[str, Any]:
params_dict = {
"test": "这里是参数字典",
}
return params_dict
prompt = {"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}
cpmBee = CpmBee()
print(cpmBee)
print(cpmBee(json.dumps(prompt, ensure_ascii=False)))
"""
CpmBee
Params: {'test': '这里是参数字典'}
{"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}
好
"""
我们需要注意的几点:
- prompt必须为一个字符串,而cpm-bee-1b的输入有点特殊,需要是一个字典,可能内部有对其进行转换,这里不作探讨。因此,我们在_call里面将其转换为字典。
- 输出也要是一个字符串,因此,我们从cpm-bee-1b的结果中提取结果。
Part3总结
以上虽然只是一个简单的例子,但是也足够我们完成各种传统NLP的任务了。当然,更加复杂的一些任务我们还是需要借助langchain的其它一些特性的。