Updated model building after suggestion from Matthew

Signed-off-by: Avadh Patel <avadh4all@gmail.com>
This commit is contained in:
Avadh Patel 2018-01-18 06:51:57 -06:00
parent 2146faffee
commit 75903949da
1 changed files with 3 additions and 5 deletions

View File

@ -259,10 +259,6 @@ cdef class Parser:
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0)) zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
) )
# TODO: This is an unfortunate hack atm!
# Used to set input dimensions in network.
if not cfg.get('from_disk', False):
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
cfg = { cfg = {
'nr_class': nr_class, 'nr_class': nr_class,
'hidden_depth': depth, 'hidden_depth': depth,
@ -812,6 +808,8 @@ cdef class Parser:
self.model, cfg = self.Model(self.moves.n_moves, **cfg) self.model, cfg = self.Model(self.moves.n_moves, **cfg)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
self.model[1].begin_training(
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg) self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
self.cfg.update(cfg) self.cfg.update(cfg)
@ -865,7 +863,7 @@ cdef class Parser:
path = util.ensure_path(path) path = util.ensure_path(path)
if self.model is True: if self.model is True:
self.cfg['pretrained_dims'] = self.vocab.vectors_length self.cfg['pretrained_dims'] = self.vocab.vectors_length
self.model, cfg = self.Model(from_disk=True, **self.cfg) self.model, cfg = self.Model(**self.cfg)
else: else:
cfg = {} cfg = {}
with (path / 'tok2vec_model').open('rb') as file_: with (path / 'tok2vec_model').open('rb') as file_: