ltp.data.dataset.dataset — source listing

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>

import torch.utils.data
import random
from contextlib import contextmanager
from copy import deepcopy
import os

from ltp import Registrable


class Dataset(torch.utils.data.dataset.Dataset, metaclass=Registrable):
    """Abstract dataset built from Examples and Fields.

    A typical configuration-file entry looks like::

        [Dataset]
        class = "CTB"
        [Dataset.init]
        path="data/postag"
        train="train.txt"
        validation="dev.txt"
        test="test.txt"

    Attributes:
        sort_key (callable): key function used to order examples when
            batches are generated
        examples (List[Example]): the dataset's examples
        fields (dict[str, Field]): mapping from field name to Field; the
            same Field object shares its Vocab across datasets
    """

    sort_key = None

    def __init__(self, examples, fields, filter_pred=None):
        """Create a Dataset from examples and fields.

        Args:
            examples: list (or iterable) of Examples
            fields: List(tuple(str, Field))
            filter_pred: an Example is kept only when this predicate
                returns True for it
        """
        if filter_pred is not None:
            # Preserve the container kind: lists stay lists, other
            # iterables become lazy filter objects.
            if isinstance(examples, list):
                examples = [ex for ex in examples if filter_pred(ex)]
            else:
                examples = filter(filter_pred, examples)
        self.examples = examples

        self.fields = dict(fields)
        # A tuple key pairs several names with several fields; expand each
        # such entry into individual name -> field entries.
        tuple_keys = [key for key in self.fields if isinstance(key, tuple)]
        for key in tuple_keys:
            grouped = self.fields.pop(key)
            self.fields.update(zip(key, grouped))
[文档] @classmethod def splits(cls, path=None, root='.data', train=None, validation=None, test=None, **kwargs): """ 一次创建多个数据集 Args: path: 数据集路径前缀 root: 根目录 train: 训练集文件名 validation: 验证集文件名 test: 测试集文件名 **kwargs: 其他传给Dataset的参数 Returns: Tuple[Dataset] train, validation, test """ if path is None: path = cls.download(root) train_data = None if train is None else cls( os.path.join(path, train), **kwargs) val_data = None if validation is None else cls( os.path.join(path, validation), **kwargs) test_data = None if test is None else cls( os.path.join(path, test), **kwargs) return tuple(d for d in (train_data, val_data, test_data) if d is not None)
[文档] def split(self, split_ratio=0.7, stratified=False, strata_field='label', random_state=None): """通过切分当前数据集的Examples以获得多个数据集 Args: split_ratio (float/list(float)): 切分比例(默认为0.7) stratified: 是否进行分层采样 strata_field: 分层采样的field random_state: 用于Shuffle的随机种子`random.getstate()`的返回值 Returns: Tuple[Dataset] train, validation, test Dataset """ train_ratio, test_ratio, val_ratio = check_split_ratio(split_ratio) # For the permutations rnd = RandomShuffler(random_state) if not stratified: train_data, test_data, val_data = rationed_split(self.examples, train_ratio, test_ratio, val_ratio, rnd) else: if strata_field not in self.fields: raise ValueError("Invalid field name for strata_field {}" .format(strata_field)) strata = stratify(self.examples, strata_field) train_data, test_data, val_data = [], [], [] for group in strata: # Stratify each group and add together the indices. group_train, group_test, group_val = rationed_split(group, train_ratio, test_ratio, val_ratio, rnd) train_data += group_train test_data += group_test val_data += group_val splits = tuple(Dataset(d, self.fields) for d in (train_data, val_data, test_data) if d) # In case the parent sort key isn't none if self.sort_key: for subset in splits: subset.sort_key = self.sort_key return splits
def __getitem__(self, i): return self.examples[i] def __len__(self): try: return len(self.examples) except TypeError: return 2 ** 32 def __iter__(self): for x in self.examples: yield x def __getattr__(self, attr): if attr in self.fields: for x in self.examples: yield getattr(x, attr)
[文档] def filter_examples(self, field_names): """Remove unknown words from dataset examples with respect to given field. Arguments: field_names (list(str)): Within example only the parts with field names in field_names will have their unknown words deleted. """ for i, example in enumerate(self.examples): for field_name in field_names: vocab = set(self.fields[field_name].vocab.stoi) text = getattr(example, field_name) example_part = [word for word in text if word in vocab] setattr(example, field_name, example_part) self.examples[i] = example
def check_split_ratio(split_ratio):
    """Validate the split-ratio argument and normalize it.

    Accepts a single float (train fraction; the rest goes to test) or a
    list of 2 or 3 relative weights, which are rescaled to sum to 1.

    Returns:
        (train_ratio, test_ratio, val_ratio) as floats.
    """
    valid_ratio = 0.
    if isinstance(split_ratio, float):
        # A lone float is the train fraction; validation gets nothing.
        assert 0. < split_ratio < 1., (
            "Split ratio {} not between 0 and 1".format(split_ratio))
        return (split_ratio, 1. - split_ratio, valid_ratio)
    if isinstance(split_ratio, list):
        length = len(split_ratio)
        assert length in (2, 3), (
            "Length of split ratio list should be 2 or 3, got {}".format(split_ratio))
        # Rescale relative weights so they sum to one.
        total = sum(split_ratio)
        if not total == 1.:
            split_ratio = [float(ratio) / total for ratio in split_ratio]
        if length == 2:
            return tuple(split_ratio + [valid_ratio])
        return tuple(split_ratio)
    raise ValueError('Split ratio must be float or a list, got {}'
                     .format(type(split_ratio)))


def stratify(examples, strata_field):
    """Group examples by the (hashable) value of `strata_field`.

    Two passes: first collect the distinct labels, then bucket every
    example under its label.
    """
    labels = set(getattr(example, strata_field) for example in examples)
    buckets = {label: [] for label in labels}
    for example in examples:
        buckets[getattr(example, strata_field)].append(example)
    return list(buckets.values())


def rationed_split(examples, train_ratio, test_ratio, val_ratio, rnd):
    """Randomly permute `examples`, then cut the permutation by ratios.

    Arguments:
        examples: a list of data
        train_ratio, test_ratio, val_ratio: split fractions
        rnd: a RandomShuffler instance

    Returns:
        (train, test, validation) lists; validation may be empty.
    """
    total = len(examples)
    perm = rnd(range(total))
    n_train = int(round(train_ratio * total))
    if not val_ratio:
        # No validation split: give test the remainder so rounding
        # never loses an example.
        n_test = total - n_train
    else:
        n_test = int(round(test_ratio * total))
    index_groups = (perm[:n_train],                      # train
                    perm[n_train:n_train + n_test],      # test
                    perm[n_train + n_test:])             # validation
    return tuple([examples[i] for i in group] for group in index_groups)


class RandomShuffler(object):
    """Shuffle through a private, reproducible `random` state.

    The global `random` state is captured at construction (unless one is
    supplied) and restored around every shuffle, so callers' RNG streams
    are not disturbed.
    """

    def __init__(self, random_state=None):
        if random_state is None:
            random_state = random.getstate()
        self._random_state = random_state

    @contextmanager
    def use_internal_state(self):
        """Temporarily install this shuffler's RNG state."""
        saved = random.getstate()
        random.setstate(self._random_state)
        yield
        # Remember where our private stream left off, then restore the
        # caller's global state.
        self._random_state = random.getstate()
        random.setstate(saved)

    @property
    def random_state(self):
        return deepcopy(self._random_state)

    @random_state.setter
    def random_state(self, s):
        self._random_state = s

    def __call__(self, data):
        """Return a new, shuffled list of `data`."""
        with self.use_internal_state():
            return random.sample(data, len(data))