mirror of https://github.com/explosion/spaCy.git
Test and fix #1919: Error resuming training
This commit is contained in:
parent
6b1126c312
commit
f74a802d09
|
@ -827,6 +827,7 @@ cdef class Parser:
|
|||
for action, labels in actions.items():
|
||||
for label in labels:
|
||||
self.moves.add_action(action, label)
|
||||
cfg.setdefault('token_vector_width', 128)
|
||||
if self.model is True:
|
||||
cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
||||
|
@ -836,11 +837,12 @@ cdef class Parser:
|
|||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||
self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
self.cfg.update(cfg)
|
||||
elif sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
self.model[1].begin_training(
|
||||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||
else:
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
self.model[1].begin_training(
|
||||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||
self.cfg.update(cfg)
|
||||
return sgd
|
||||
|
||||
def add_multitask_objective(self, target):
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
'''Test that nlp.begin_training() doesn't require missing cfg properties.'''
|
||||
from __future__ import unicode_literals
|
||||
import pytest
|
||||
from ... import load as load_spacy
|
||||
|
||||
@pytest.mark.models('en')
|
||||
def test_issue1919():
|
||||
nlp = load_spacy('en')
|
||||
opt = nlp.begin_training()
|
||||
|
Loading…
Reference in New Issue