从零开始认识langchain(八)在langchain中使用中文模型

6 人赞同了该文章

西西嘛呦:从零开始认识langchain(一)初识langchain

西西嘛呦:从零开始认识langchain(二)组件-数据连接(data connection)

西西嘛呦:从零开始认识langchain(三)组件-模型IO(model I/O)

西西嘛呦:从零开始认识langchain(四)组件-链(chains)

西西嘛呦:从零开始认识langchain(五)组件-代理(agents)

西西嘛呦:从零开始认识langchain(六)组件-内存(memory)

西西嘛呦:从零开始认识langchain(七)组件-回调(callbacks)

你也可以从github上获取相关代码:


Part1前言

目前langchain都是基于openai的模型进行的,本文将讲解下怎么定制化使用中文的模型。为了方便起见,这里使用的模型为cpm-bee-1b。

Part2定制中文模型

首先我们得看下cpm-bee-1b是怎么使用的:

from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True).cuda()
result = model.generate({"input": "今天天气真", "<ans>": ""}, tokenizer)
print(result)
result = model.generate({"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}, tokenizer)
print(result)

输入是一个字典,而且有两种方式:

  • 带有question,根据input的内容进行回答。
  • 不带有question,根据input继续生成文本。

另外,还有一个键,生成的结果会存到它的值里面。上述结果:

[{'input': '今天天气真', '<ans>': '今天天气真好'}]
[{'input': '今天天气真不错', 'question': '今天天气怎么样?', '<ans>': '好'}]  

要在langchain使用中文模型,我们要继承langchain中的LLM类,它位于from langchain.llms.base import LLM,然后重写_llm_type、_call、_identifying_params方法。

  • _llm_type:用于标识模型名称
  • _call:里面实现推理逻辑,既可以是原生的模型推理,也可以是接口。
  • _identifying_params:用于帮助我们打印类的一些属性。

接下来看完整代码:

# 使用langchain加载中文模型
# 继承LLM,并重写_llm_type、_call、_identifying_params方法
import json
from transformers import AutoModelForCausalLM, AutoTokenizer

class ModelLoader:
  def __init__(self, model_name_or_path):
    self.model_name_or_path = model_name_or_path
    self.model, self.tokenizer = self.load()
  
  def load(self):
    tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-bee-1b", trust_remote_code=True).cuda()
    return model, tokenizer

modelLoader = ModelLoader("openbmb/cpm-bee-1b")

from typing import Any, List, Mapping, Optional
from langchain.llms.base import LLM
class CpmBee(LLM):
  @property
  def _llm_type(self) -> str:
    return "cpm-bee-1B"

  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    print(prompt)
    prompt = json.loads(prompt)
    tokenizer = modelLoader.tokenizer
    model = modelLoader.model
    result = model.generate(prompt, tokenizer)
    if len(result) == 1:
      return result[0]["<ans>"]
    return "对不起,我没有理解你的意思"
  
  @property
  def _identifying_params(self) -> Mapping[str, Any]:
    params_dict = {
      "test": "这里是参数字典",
    }
    return params_dict

prompt = {"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}
cpmBee = CpmBee()

print(cpmBee)

print(cpmBee(json.dumps(prompt, ensure_ascii=False)))
  
"""
CpmBee
Params: {'test': '这里是参数字典'}
{"input": "今天天气真不错", "question": "今天天气怎么样?", "<ans>": ""}
"""

我们需要注意的几点:

  • prompt必须为一个字符串,而cpm-bee-1b的输入有点特殊,需要是一个字典,可能内部有对其进行转换,这里不作探讨。因此,我们在_call里面将其转换为字典。
  • 输出也要是一个字符串,因此,我们从cpm-bee-1b的结果中提取结果。

Part3总结

以上虽然只是一个简单的例子,但是也足够我们完成各种传统NLP的任务了。当然,更加复杂的一些任务我们还是需要借助langchain的其它一些特性的。