ltp.train.callback.tensorboard 源代码

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
import os
from . import Callback

default_main_tag = "tensorboard"
default_save_directory = './log'


class TensorboardCallback(Callback, alias="tensorboard"):
    """Callback that logs scalar values from the trainer state to TensorBoard.

    Each invocation extracts a ``{tag: value}`` mapping from the trainer
    state and writes it with ``SummaryWriter.add_scalars`` under
    ``main_tag``, using the state's ``global_step`` as the step.
    """

    def __init__(self, iteration, epoch,
                 save_directory=default_save_directory,
                 main_tag=default_main_tag,
                 data_extractor=lambda state: {"loss": state.last_train_loss}):
        """Create the log directory and open a TensorBoard writer on it.

        Args:
            iteration: Forwarded unchanged to ``Callback.__init__``.
                NOTE(review): presumably the iteration trigger interval --
                confirm against the base class.
            epoch: Forwarded unchanged to ``Callback.__init__``.
                NOTE(review): presumably the epoch trigger interval --
                confirm against the base class.
            save_directory: Directory the event files are written to.
            main_tag: Parent tag passed to ``add_scalars``.
            data_extractor: Callable taking the trainer state and returning
                the mapping handed to ``add_scalars``. The default logs
                ``state.last_train_loss`` under the key ``"loss"``.
        """
        super(TensorboardCallback, self).__init__(iteration, epoch)
        # Imported here rather than at module level, so importing this module
        # does not itself require torch's tensorboard support.
        from torch.utils.tensorboard import SummaryWriter

        self.main_tag = main_tag
        self.save_directory = save_directory
        self.data_extractor = data_extractor
        # Make sure the target directory exists *before* the writer is
        # opened on it (the original only did this afterwards).
        os.makedirs(save_directory, exist_ok=True)
        self.writer = SummaryWriter(self.save_directory)

    def call(self, executor):
        """Log the scalars for the current state of ``executor.trainer``."""
        self.__save(executor.trainer.state)

    def __save(self, executor_state):
        """Write the extracted scalars at ``executor_state.global_step``."""
        self.writer.add_scalars(
            self.main_tag,
            self.data_extractor(executor_state),
            global_step=executor_state.global_step,
        )