Source code for diaparser.utils.vocab

# -*- coding: utf-8 -*-

from collections import defaultdict
from collections.abc import Iterable

class Vocab(object):
    r"""
    Defines a vocabulary object that will be used to numericalize a field.

    Args:
        counter (~collections.Counter):
            :class:`~collections.Counter` object holding the frequencies of each value found in the data.
            Any mapping exposing ``items()`` that yields ``(token, frequency)`` pairs is accepted.
        min_freq (int):
            The minimum frequency needed to include a token in the vocabulary. Default: 1.
        specials (list[str]):
            The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary.
            Default: [].
        unk_index (int):
            The index of unk token. Default: 0.

    Attributes:
        itos:
            A list of token strings indexed by their numerical identifiers.
        stoi:
            A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers.
            Indexing it directly with an unseen token yields ``unk_index`` but also stores that token as a key;
            prefer ``vocab[token]``, which performs the same lookup without the side effect.
        unk_index:
            The index returned for tokens that are not in the vocabulary.
        n_init:
            The vocabulary size right after construction, i.e., before any later call to :meth:`extend`.
    """

    def __init__(self, counter, min_freq=1, specials=[], unk_index=0):
        # NOTE: `specials` is copied right away and never mutated, so the shared
        # list default is harmless; it is kept as-is to leave the signature untouched.
        self.unk_index = unk_index
        self.itos = list(specials)
        self.stoi = defaultdict(lambda: unk_index)
        self.stoi.update({token: i for i, token in enumerate(self.itos)})
        self.extend([token for token, freq in counter.items() if freq >= min_freq])
        self.n_init = len(self)

    def __len__(self):
        """Returns the number of entries in the vocabulary (special tokens included)."""
        return len(self.itos)

    def __getitem__(self, key):
        """
        Converts between tokens and indices.

        Args:
            key (str, int or iterable):
                A token string, an index, or an iterable of either.

        Returns:
            The index of ``key`` if it is a string (``unk_index`` for unknown tokens),
            the token string if it is an index,
            or a list with the element-wise conversion if it is an iterable.
            The direction of the conversion of an iterable is decided by the type of its first element.
            An empty iterable yields an empty list.
        """
        # Lookups go through `dict.get` rather than `defaultdict.__getitem__`: the latter
        # stores every unknown token as a new key, which makes `token in vocab` report
        # True afterwards and keeps `stoi` growing with each out-of-vocabulary query.
        if isinstance(key, str):
            return self.stoi.get(key, self.unk_index)
        if not isinstance(key, Iterable):
            return self.itos[key]
        # materialize once so that empty inputs and one-shot iterables (e.g., generators) are supported
        items = list(key)
        if not items:
            return []
        if isinstance(items[0], str):
            return [self.stoi.get(token, self.unk_index) for token in items]
        return [self.itos[index] for index in items]

    def __contains__(self, token):
        """Returns ``True`` if ``token`` is a key of :attr:`stoi`."""
        return token in self.stoi

    def __getstate__(self):
        """Returns the picklable state, in which :attr:`stoi` is exported as a regular dict."""
        # avoid pickling defaultdict: its default factory is a lambda, which pickle cannot serialize
        attrs = dict(self.__dict__)
        # cast to regular dict
        attrs['stoi'] = dict(self.stoi)
        return attrs

    def __setstate__(self, state):
        """Restores the pickled state and rebuilds :attr:`stoi` as a defaultdict falling back to ``unk_index``."""
        stoi = defaultdict(lambda: self.unk_index)
        stoi.update(state['stoi'])
        self.__dict__.update(state)
        # assigned after the bulk update so that the caller's `state` dict is left unmodified
        self.stoi = stoi

    def extend(self, tokens):
        """
        Appends the tokens that are not in the vocabulary yet, in sorted order.

        Args:
            tokens (iterable[str]):
                The candidate tokens; duplicates and tokens already listed in :attr:`itos` are ignored.
        """
        # Compare against `itos` rather than `stoi`: direct `stoi[token]` accesses may have
        # left unknown tokens behind as keys, and those must still be addable here.
        self.itos.extend(sorted(set(tokens).difference(self.itos)))
        self.stoi.update({token: i for i, token in enumerate(self.itos)})
class FieldVocab(dict):
    """Surrogate for missing vocab in certain Transformers tokenizers."""

    def __init__(self, unk_token_id, items):
        """
        Builds the mapping from ``items`` and records the fallback id.

        Args:
            unk_token_id:
                The value returned by item access for keys that are not in the mapping.
            items:
                The initial content, in any form accepted by the :class:`dict` constructor.
        """
        super().__init__(items)
        self.unk_token_id = unk_token_id

    def __getitem__(self, tok):
        """Returns the value stored for ``tok``, or ``unk_token_id`` if ``tok`` is absent."""
        fallback = self.unk_token_id
        return super().get(tok, fallback)