Make the server actually work

This commit is contained in:
Giovanni Campagna 2019-03-01 12:32:54 -08:00
parent a0ff18f1fe
commit c66dde4ca0
1 changed files with 9 additions and 4 deletions

View File

@ -62,6 +62,10 @@ class Server():
setattr(processed, f'{name}_limited', limited_entry) setattr(processed, f'{name}_limited', limited_entry)
setattr(processed, f'{name}_elmo', [[s.strip() for s in l] for l in raw]) setattr(processed, f'{name}_elmo', [[s.strip() for s in l] for l in raw])
processed.oov_to_limited_idx = self._oov_to_limited_idx
processed.limited_idx_to_full_idx = self._limited_idx_to_full_idx
return processed
async def handle_client(self, client_reader, client_writer): async def handle_client(self, client_reader, client_writer):
try: try:
request = json.loads(await client_reader.readline()) request = json.loads(await client_reader.readline())
@ -81,14 +85,14 @@ class Server():
ex = Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize) ex = Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize)
batch = self.numericalize_example(ex) batch = self.numericalize_example(ex)
_, prediction_batch = model(batch) _, prediction_batch = model(batch, iteration=0)
if task == 'almond': if task == 'almond':
predictions = field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x)) predictions = field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x))
else: else:
predictions = field.reverse(prediction_batch) predictions = field.reverse(prediction_batch)
await client_writer.writeln(json.dumps(dict(id=request['id'], answer=predictions[0]))) client_writer.write((json.dumps(dict(id=request['id'], answer=predictions[0])) + '\n').encode('utf-8'))
except IOError: except IOError:
logger.info('Connection to client_reader closed') logger.info('Connection to client_reader closed')
@ -142,7 +146,8 @@ def get_args():
'transformer_layers', 'rnn_layers', 'transformer_hidden', 'transformer_layers', 'rnn_layers', 'transformer_hidden',
'dimension', 'load', 'max_val_context_length', 'val_batch_size', 'dimension', 'load', 'max_val_context_length', 'val_batch_size',
'transformer_heads', 'max_output_length', 'max_generative_vocab', 'transformer_heads', 'max_output_length', 'max_generative_vocab',
'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char', 'use_maxmargin_loss'] 'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char', 'use_maxmargin_loss',
'reverse_task_bool']
for r in retrieve: for r in retrieve:
if r in config: if r in config:
setattr(args, r, config[r]) setattr(args, r, config[r])
@ -204,7 +209,7 @@ if __name__ == '__main__':
model.load_state_dict(model_dict) model.load_state_dict(model_dict)
server = Server(args, field, model) server = Server(args, field, model)
server.prepare(args) server.prepare_data()
model.set_embeddings(field.vocab.vectors) model.set_embeddings(field.vocab.vectors)
server.run() server.run()