模型¶
模型¶
-
class
Model
[源代码]¶ 基类:
torch.nn.modules.module.Module
模型基础类,从本类继承的模型自动注册
-
classmethod
by_name
(name)¶ 通过注册的名字取得实际的类型
- 参数
name – 注册的名字
- 返回
Type[T] 使用 name 注册的子类
- 返回类型
class
- 引发
RegistrationError – 如果 name 未被注册
-
classmethod
from_params
(__config, *args, extra=None, **kwargs)¶ 使用 config 生成对象
- 参数
cls – 类型
__config – 配置项,通常为字典,形如 {‘class’:’ClassName’, ‘init’:{ ‘arg’: arg } }
*args – 直接传入的arg
extra – 根据需要传入的数据
**kwargs – 其他参数
- 返回
根据参数生成的对象
-
classmethod
hook
(hook)¶ 函数装饰器,给某个类注册装饰器
-
classmethod
is_registered
(name)¶ 如果 name 在类中已经注册,则返回 True
-
classmethod
iter_registered
()¶ 迭代已经注册的名字和对象
-
classmethod
list_available
()¶ 列出所有的注册子类
-
classmethod
register
(name, override=False, hooks=None)¶ 装饰器 Class decorator for registering a subclass.
- 参数
name – 注册名
override (bool) – 当name已经注册时,是否进行覆盖
hooks (List[HookType]) – 在注册时会被执行的Hook函数
- 引发
RegistrationError – 如果 override 为 false 并且 name 已经被注册
-
classmethod
weak_register
(name, subclass, override=False, hooks=None)¶ 用于手动对子类进行注册
- 参数
name (str) – 子类的引用名
subclass – 子类类型
override (bool) – 当name已经注册时,是否进行覆盖
hooks – 在注册时会被执行的Hook函数
- 引发
RegistrationError – 如果 override 为 false 并且 name 已经被注册
-
classmethod
序列标注模型¶
-
class
SequenceTaggingModel
[源代码]¶ -
基本序列标注模型,封装了各种解码器,将来可能要进行解耦,如果使用crf解码器,需要换用支持 CRF 的 Trainer
使用 init_decoder 函数初始化解码器
- 参数
decoder – [None=Linear,lan=lan decoder] 默认是简单线性层分类,目前支持 lan decoder
hidden_size – decoder = lan 时必填, lan hidden size
num_heads – decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5
num_layers – decoder = lan 时使用, lan decoder 层数,默认为 3
lan – decoder = lan 时使用,lan decoder 其他参数
arc_hidden_size – decoder = biaffine 时必填
rel_hidden_size – decoder = biaffine 时必填
rel_num – decoder = biaffine 时默认为label_num,rel 数目
bias – decoder = Linear 时,传入Linear层
-
class
SimpleTaggingModel
(label_num, pretrained=None, config=None, dropout=0.1, word_base=False, use_cls=False, freeze=False, **kwargs)[源代码]¶ 基类:
ltp.models.seq_tag_model.SequenceTaggingModel
基本序列标注模型
- 参数
pretrained – 预训练模型路径或名称,参照 huggingface/transformers
config – 预训练模型路径或名称,参照 huggingface/transformers
label_num – 分类标签数目
dropout – pretrained Embedding dropout,默认0.1
word_base – 是否是以词为基础的模型,如果是,输入时需要传入 word index
decoder – [None=Linear,lan=lan decoder] 默认是简单线性层分类,目前支持 lan decoder
hidden_size – decoder = lan 时必填, lan hidden size
num_heads – decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5
num_layers – decoder = lan 时使用, lan decoder 层数,,默认为 3
lan – decoder = lan 时使用,lan decoder 其他参数
arc_hidden_size – decoder = graph 时必填
rel_hidden_size – decoder = graph 时必填
rel_num – decoder = graph 时默认为label_num,rel 数目
bias – decoder = Linear 时,传入Linear层
多任务模型¶
-
class
MultiTaskModel
[源代码]¶ -
基本多任务序列标注模型
-
create_decoder
(input_size, label_num, dropout=0.1, **kwargs)[源代码]¶ 封装了各种解码器 :param decoder: [None=Linear,lan=lan decoder,crf] 默认是简单线性层分类,目前支持 lan decoder
- 参数
hidden_size – decoder = lan 时必填, lan hidden size
num_heads – decoder = lan 时使用, lan 多头注意力模型 heads,默认为 5
num_layers – decoder = lan 时使用, lan decoder 层数,默认为 3
lan – decoder = lan 时使用,lan decoder 其他参数
arc_hidden_size – decoder = biaffine 时必填
rel_hidden_size – decoder = biaffine 时必填
rel_num – decoder = biaffine 时必填,rel 数目
bias – decoder = Linear 时,传入Linear层
-