Merge pull request #88 from stanford-oval/wip/no_grad

add torch.no_grad to kfserver
jgd5 2021-02-04 13:34:18 -08:00 committed by GitHub
commit ca293b8814
1 changed file with 60 additions and 60 deletions


@@ -64,69 +64,70 @@ class Server:
         return NumericalizedExamples.collate_batches(all_features, self.numericalizer, device=self.device)

     def handle_request(self, line):
-        if isinstance(line, dict):
-            request = line
-        else:
-            request = json.loads(line)
+        with torch.no_grad():
+            if isinstance(line, dict):
+                request = line
+            else:
+                request = json.loads(line)

-        task_name = request['task'] if 'task' in request else 'generic'
-        if task_name in self._cached_tasks:
-            task = self._cached_tasks[task_name]
-        else:
-            task = list(get_tasks([task_name], self.args).values())[0]
-            self._cached_tasks[task_name] = task
+            task_name = request['task'] if 'task' in request else 'generic'
+            if task_name in self._cached_tasks:
+                task = self._cached_tasks[task_name]
+            else:
+                task = list(get_tasks([task_name], self.args).values())[0]
+                self._cached_tasks[task_name] = task

-        if 'instances' in request:
-            examples = []
-            # request['instances'] is an array of {context, question, answer, example_id}
-            for instance in request['instances']:
-                example_id, context, question, answer = instance.get('example_id', ''), instance['context'], instance['question'], instance.get('answer', '')
-                if not context:
-                    context = task.default_context
-                if not question:
-                    question = task.default_question
-                ex = Example.from_raw(str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
-                examples.append(ex)
-            self.model.add_new_vocab_from_data([task])
-            batch = self.numericalize_examples(examples)
-            # it is a single batch, so wrap it in []
-            if self.args.calibrator_path is not None:
-                output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
-                                             output_predictions_only=True,
-                                             confidence_estimator=self.confidence_estimator)
-                response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0], 'score': float(s)} for (p, s) in zip(output.predictions, output.confidence_scores)]})
-            else:
-                output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
-                                             output_predictions_only=True)
-                response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0]} for p in output.predictions]})
-            return response + '\n'
-        else:
-            context = request['context']
-            if not context:
-                context = task.default_context
-            question = request['question']
-            if not question:
-                question = task.default_question
-            answer = ''
-            ex = Example.from_raw(str(request['id']), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
-            self.model.add_new_vocab_from_data([task])
-            batch = self.numericalize_examples([ex])
-            if self.args.calibrator_path is not None:
-                output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
-                                             output_predictions_only=True,
-                                             confidence_estimator=self.confidence_estimator)
-                response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=float(output.confidence_scores[0])))
-            else:
-                output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
-                                             output_predictions_only=True)
-                response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0]))
-            return response + '\n'
+            if 'instances' in request:
+                examples = []
+                # request['instances'] is an array of {context, question, answer, example_id}
+                for instance in request['instances']:
+                    example_id, context, question, answer = instance.get('example_id', ''), instance['context'], instance['question'], instance.get('answer', '')
+                    if not context:
+                        context = task.default_context
+                    if not question:
+                        question = task.default_question
+                    ex = Example.from_raw(str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
+                    examples.append(ex)
+                self.model.add_new_vocab_from_data([task])
+                batch = self.numericalize_examples(examples)
+                # it is a single batch, so wrap it in []
+                if self.args.calibrator_path is not None:
+                    output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
+                                                 output_predictions_only=True,
+                                                 confidence_estimator=self.confidence_estimator)
+                    response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0], 'score': float(s)} for (p, s) in zip(output.predictions, output.confidence_scores)]})
+                else:
+                    output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
+                                                 output_predictions_only=True)
+                    response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0]} for p in output.predictions]})
+                return response + '\n'
+            else:
+                context = request['context']
+                if not context:
+                    context = task.default_context
+                question = request['question']
+                if not question:
+                    question = task.default_question
+                answer = ''
+                ex = Example.from_raw(str(request['id']), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
+                self.model.add_new_vocab_from_data([task])
+                batch = self.numericalize_examples([ex])
+                if self.args.calibrator_path is not None:
+                    output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
+                                                 output_predictions_only=True,
+                                                 confidence_estimator=self.confidence_estimator)
+                    response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=float(output.confidence_scores[0])))
+                else:
+                    output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
+                                                 output_predictions_only=True)
+                    response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0]))
+                return response + '\n'

     async def handle_client(self, client_reader, client_writer):
         try:
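
Why the wrapper matters (explanatory note, not part of the diff): inside torch.no_grad(), PyTorch skips recording the autograd graph during the forward pass, so inference uses less memory per request and returns tensors with requires_grad=False. A minimal, self-contained sketch of the effect; the Linear model and tensor sizes are illustrative stand-ins, not code from this repo:

    import torch

    model = torch.nn.Linear(128, 64)  # stand-in for the real model served here
    x = torch.randn(8, 128)

    y = model(x)
    print(y.requires_grad)  # True: an autograd graph was recorded

    with torch.no_grad():
        y = model(x)
    print(y.requires_grad)  # False: no graph is built, inference is cheaper
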
@@ -169,11 +170,10 @@ class Server:
         self.model.to(self.device)
         self.model.eval()
-        with torch.no_grad():
-            if self.args.stdin:
-                self._run_stdin()
-            else:
-                self._run_tcp()
+        if self.args.stdin:
+            self._run_stdin()
+        else:
+            self._run_tcp()


 def parse_argv(parser):
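
The second hunk is the mirror image of the first: the old with torch.no_grad() around the dispatch in run() is removed, since gradient tracking is now disabled inside handle_request itself. This way both entry points (TCP and stdin), and any other caller of handle_request, get no-grad inference without having to remember the wrapper themselves.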