diff --git a/server.py b/server.py
index 63a183a4..97bcc6a0 100644
--- a/server.py
+++ b/server.py
@@ -62,6 +62,10 @@ class Server():
             setattr(processed, f'{name}_limited', limited_entry)
             setattr(processed, f'{name}_elmo', [[s.strip() for s in l] for l in raw])
 
+        processed.oov_to_limited_idx = self._oov_to_limited_idx
+        processed.limited_idx_to_full_idx = self._limited_idx_to_full_idx
+        return processed
+
     async def handle_client(self, client_reader, client_writer):
         try:
             request = json.loads(await client_reader.readline())
@@ -81,14 +85,14 @@ class Server():
                 ex = Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize)
 
                 batch = self.numericalize_example(ex)
-                _, prediction_batch = model(batch)
+                _, prediction_batch = model(batch, iteration=0)
                 if task == 'almond':
                     predictions = field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x))
                 else:
                     predictions = field.reverse(prediction_batch)
 
-                await client_writer.writeln(json.dumps(dict(id=request['id'], answer=predictions[0])))
+                client_writer.write((json.dumps(dict(id=request['id'], answer=predictions[0])) + '\n').encode('utf-8'))
 
         except IOError:
             logger.info('Connection to client_reader closed')
@@ -142,7 +146,8 @@ def get_args():
                 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'dimension',
                 'load', 'max_val_context_length', 'val_batch_size', 'transformer_heads',
                 'max_output_length', 'max_generative_vocab',
-                'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char', 'use_maxmargin_loss']
+                'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char', 'use_maxmargin_loss',
+                'reverse_task_bool']
     for r in retrieve:
         if r in config:
             setattr(args, r, config[r])
@@ -204,7 +209,7 @@ if __name__ == '__main__':
     model.load_state_dict(model_dict)
 
     server = Server(args, field, model)
-    server.prepare(args)
+    server.prepare_data()
     model.set_embeddings(field.vocab.vectors)
     server.run()
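
Note on the response-writing change: asyncio's StreamWriter has no writeln method, so the old await client_writer.writeln(...) call could never have worked; the patch replaces it with an explicit write() of a UTF-8 encoded, newline-terminated JSON string, matching the readline()-based framing on the request side. A minimal client sketch of that newline-delimited JSON protocol follows. The host, port, and every request field except 'id' are illustrative assumptions, not taken from this diff.

# Minimal sketch of a client for the newline-delimited JSON protocol above.
# Hypothetical: the host/port and the 'task'/'context'/'question' request
# fields are assumptions for illustration; only 'id' and 'answer' appear
# in the diff itself.
import asyncio
import json

async def query(context, question, task='almond', host='127.0.0.1', port=8200):
    reader, writer = await asyncio.open_connection(host, port)
    # One JSON object per line, matching the server's readline()-based framing.
    request = dict(id=1, task=task, context=context, question=question)
    writer.write((json.dumps(request) + '\n').encode('utf-8'))
    await writer.drain()
    # The server replies with one newline-terminated JSON object per request.
    response = json.loads(await reader.readline())
    writer.close()
    return response['answer']

if __name__ == '__main__':
    print(asyncio.run(query('show me restaurants nearby', 'translate to thingtalk')))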