模型

模型

class Model[源代码]

基类:torch.nn.modules.module.Module

模型基础类,从本类继承的模型自动注册

create_decoder(input_size, label_num, dropout=0.1, **kwargs)[源代码]

封装了各种解码器

forward(*args, **kwargs)[源代码]

前向传播函数

classmethod by_name(name)

通过注册的名字取得实际的类型

参数

name – 注册的名字

返回

Type[T] 使用 name 注册的子类

返回类型

class

引发

RegistrationError – 如果 name 未被注册

classmethod from_params(__config, *args, extra=None, **kwargs)

使用 config 生成对象

参数
  • cls – 类型

  • __config – 配置项,通常为字典,形如 {‘class’:’ClassName’, ‘init’:{ ‘arg’: arg } }

  • *args – 直接传入的arg

  • extra – 根据需要传入的数据

  • **kwargs – 其他参数

返回

根据参数生成的对象

classmethod hook(hook)

函数装饰器,给某个类注册装饰器

classmethod is_registered(name)

如果 name 在类中已经注册,则返回 True

classmethod iter_registered()

迭代已经注册的名字和对象

classmethod list_available()

列出所有的注册子类

classmethod register(name, override=False, hooks=None)

装饰器 Class decorator for registering a subclass.

参数
  • name – 注册名

  • override (bool) – 当name已经注册时,是否进行覆盖

  • hooks (List[HookType]) – 在注册时会被执行的Hook函数

引发

RegistrationError – 如果 override 为 false 并且 name 已经被注册

classmethod weak_register(name, subclass, override=False, hooks=None)

用于手动对子类进行注册

参数
  • name (str) – 子类的引用名

  • subclass – 子类类型

  • override (bool) – 当name已经注册时,是否进行覆盖

  • hooks – 在注册时会被执行的Hook函数

引发

RegistrationError – 如果 override 为 false 并且 name 已经被注册

序列标注模型

class SequenceTaggingModel[源代码]

基类:ltp.models.model.Model

基本序列标注模型,封装了各种解码器,将来可能要进行解耦,如果使用crf解码器,需要换用支持 CRF 的 Trainer

使用 init_decoder 函数初始化解码器

参数
  • decoder – [None=Linear,lan=lan decoder] 默认是简单线性层分类,目前支持 lan decoder

  • hidden_size – decoder = lan 时必填, lan hidden size

  • num_heads – decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5

  • num_layers – decoder = lan 时使用, lan decoder 层数,默认为 3

  • lan – decoder = lan 时使用,lan decoder 其他参数

  • arc_hidden_size – decoder = biaffine 时必填

  • rel_hidden_size – decoder = biaffine 时必填

  • rel_num – decoder = biaffine 时默认为label_num,rel 数目

  • bias – decoder = Linear 时,传入Linear层

init_decoder(input_size, label_num, dropout=0.1, **kwargs)[源代码]

基本序列标注模型,封装了各种解码器

forward(text, gold=None)[源代码]

前向传播函数

class SimpleTaggingModel(label_num, pretrained=None, config=None, dropout=0.1, word_base=False, use_cls=False, freeze=False, **kwargs)[源代码]

基类:ltp.models.seq_tag_model.SequenceTaggingModel

基本序列标注模型

参数
  • pretrained – 预训练模型路径或名称,参照 huggingface/transformers

  • config – 预训练模型路径或名称,参照 huggingface/transformers

  • label_num – 分类标签数目

  • dropout – pretrained Embedding dropout,默认0.1

  • word_base – 是否是以词为基础的模型,如果是,输入时需要传入 word index

  • decoder – [None=Linear,lan=lan decoder] 默认是简单线性层分类,目前支持 lan decoder

  • hidden_size – decoder = lan 时必填, lan hidden size

  • num_heads – decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5

  • num_layers – decoder = lan 时使用, lan decoder 层数,,默认为 3

  • lan – decoder = lan 时使用,lan decoder 其他参数

  • arc_hidden_size – decoder = graph 时必填

  • rel_hidden_size – decoder = graph 时必填

  • rel_num – decoder = graph 时默认为label_num,rel 数目

  • bias – decoder = Linear 时,传入Linear层

多任务模型

class MultiTaskModel[源代码]

基类:ltp.models.model.Model

基本多任务序列标注模型

create_decoder(input_size, label_num, dropout=0.1, **kwargs)[源代码]

封装了各种解码器 :param decoder: [None=Linear,lan=lan decoder,crf] 默认是简单线性层分类,目前支持 lan decoder

参数
  • hidden_size – decoder = lan 时必填, lan hidden size

  • num_heads – decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5

  • num_layers – decoder = lan 时使用, lan decoder 层数,默认为 3

  • lan – decoder = lan 时使用,lan decoder 其他参数

  • arc_hidden_size – decoder = biaffine 时必填

  • rel_hidden_size – decoder = biaffine 时必填

  • rel_num – decoder = biaffine 时必填,rel 数目

  • bias – decoder = Linear 时,传入Linear层

class SimpleMultiTaskModel(pretrained=None, config=None, dropout=0.1, freeze=False, **kwargs)[源代码]

基类:ltp.models.multi_task_model.MultiTaskModel

forward(text, *args, **kwargs)[源代码]

前向传播函数