mirror of https://github.com/explosion/spaCy.git
Fix defaults for ud-train
This commit is contained in:
parent
59cf533879
commit
3eb9f3e2b8
|
@ -300,17 +300,25 @@ def initialize_pipeline(nlp, docs, golds, config, device):
|
|||
########################
|
||||
|
||||
class Config(object):
|
||||
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True,
|
||||
multitask_sent=True, multitask_dep=True, multitask_vectors=False,
|
||||
nr_epoch=30, min_batch_size=1, max_batch_size=16, batch_by_words=False,
|
||||
dropout=0.2, conv_depth=4, subword_features=True):
|
||||
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=False,
|
||||
multitask_sent=False, multitask_dep=False, multitask_vectors=None,
|
||||
nr_epoch=30, min_batch_size=100, max_batch_size=1000,
|
||||
batch_by_words=True, dropout=0.2, conv_depth=4, subword_features=True,
|
||||
vectors_dir=None):
|
||||
if vectors_dir is not None:
|
||||
if vectors is None:
|
||||
vectors = True
|
||||
if multitask_vectors is None:
|
||||
multitask_vectors = True
|
||||
for key, value in locals().items():
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def load(cls, loc):
|
||||
def load(cls, loc, vectors_dir=None):
|
||||
with Path(loc).open('r', encoding='utf8') as file_:
|
||||
cfg = json.load(file_)
|
||||
if vectors_dir is not None:
|
||||
cfg['vectors_dir'] = vectors_dir
|
||||
return cls(**cfg)
|
||||
|
||||
|
||||
|
@ -353,16 +361,16 @@ class TreebankPaths(object):
|
|||
vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/",
|
||||
"option", "v", Path),
|
||||
)
|
||||
def main(ud_dir, parses_dir, config=None, corpus, limit=0, use_gpu=-1, vectors_dir=None,
|
||||
def main(ud_dir, parses_dir, corpus, config=None, limit=0, use_gpu=-1, vectors_dir=None,
|
||||
use_oracle_segments=False):
|
||||
spacy.util.fix_random_seed()
|
||||
lang.zh.Chinese.Defaults.use_jieba = False
|
||||
lang.ja.Japanese.Defaults.use_janome = False
|
||||
|
||||
if config is not None:
|
||||
config = Config.load(config)
|
||||
config = Config.load(config, vectors_dir=vectors_dir)
|
||||
else:
|
||||
config = Config()
|
||||
config = Config(vectors_dir=vectors_dir)
|
||||
paths = TreebankPaths(ud_dir, corpus)
|
||||
if not (parses_dir / corpus).exists():
|
||||
(parses_dir / corpus).mkdir()
|
||||
|
|
Loading…
Reference in New Issue