Merge pull request #88 from stanford-oval/wip/no_grad
add torch.no_grad to kfserver
This commit is contained in:
commit
ca293b8814
|
@ -64,69 +64,70 @@ class Server:
|
|||
return NumericalizedExamples.collate_batches(all_features, self.numericalizer, device=self.device)
|
||||
|
||||
def handle_request(self, line):
|
||||
if isinstance(line, dict):
|
||||
request = line
|
||||
else:
|
||||
request = json.loads(line)
|
||||
with torch.no_grad():
|
||||
if isinstance(line, dict):
|
||||
request = line
|
||||
else:
|
||||
request = json.loads(line)
|
||||
|
||||
task_name = request['task'] if 'task' in request else 'generic'
|
||||
if task_name in self._cached_tasks:
|
||||
task = self._cached_tasks[task_name]
|
||||
else:
|
||||
task = list(get_tasks([task_name], self.args).values())[0]
|
||||
self._cached_tasks[task_name] = task
|
||||
task_name = request['task'] if 'task' in request else 'generic'
|
||||
if task_name in self._cached_tasks:
|
||||
task = self._cached_tasks[task_name]
|
||||
else:
|
||||
task = list(get_tasks([task_name], self.args).values())[0]
|
||||
self._cached_tasks[task_name] = task
|
||||
|
||||
if 'instances' in request:
|
||||
examples = []
|
||||
# request['instances'] is an array of {context, question, answer, example_id}
|
||||
for instance in request['instances']:
|
||||
example_id, context, question, answer = instance.get('example_id', ''), instance['context'], instance['question'], instance.get('answer', '')
|
||||
if 'instances' in request:
|
||||
examples = []
|
||||
# request['instances'] is an array of {context, question, answer, example_id}
|
||||
for instance in request['instances']:
|
||||
example_id, context, question, answer = instance.get('example_id', ''), instance['context'], instance['question'], instance.get('answer', '')
|
||||
if not context:
|
||||
context = task.default_context
|
||||
if not question:
|
||||
question = task.default_question
|
||||
|
||||
ex = Example.from_raw(str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
|
||||
examples.append(ex)
|
||||
|
||||
self.model.add_new_vocab_from_data([task])
|
||||
batch = self.numericalize_examples(examples)
|
||||
# it is a single batch, so wrap it in []
|
||||
if self.args.calibrator_path is not None:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True,
|
||||
confidence_estimator=self.confidence_estimator)
|
||||
|
||||
response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0], 'score': float(s)} for (p, s) in zip(output.predictions, output.confidence_scores)]})
|
||||
else:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True)
|
||||
|
||||
response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0]} for p in output.predictions]})
|
||||
return response + '\n'
|
||||
else:
|
||||
context = request['context']
|
||||
if not context:
|
||||
context = task.default_context
|
||||
question = request['question']
|
||||
if not question:
|
||||
question = task.default_question
|
||||
answer = ''
|
||||
|
||||
ex = Example.from_raw(str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
|
||||
examples.append(ex)
|
||||
ex = Example.from_raw(str(request['id']), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
|
||||
|
||||
self.model.add_new_vocab_from_data([task])
|
||||
batch = self.numericalize_examples(examples)
|
||||
# it is a single batch, so wrap it in []
|
||||
if self.args.calibrator_path is not None:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True,
|
||||
confidence_estimator=self.confidence_estimator)
|
||||
|
||||
response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0], 'score': float(s)} for (p, s) in zip(output.predictions, output.confidence_scores)]})
|
||||
else:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True)
|
||||
|
||||
response = json.dumps({ 'id': request['id'], 'instances': [{ 'answer': p[0]} for p in output.predictions]})
|
||||
return response + '\n'
|
||||
else:
|
||||
context = request['context']
|
||||
if not context:
|
||||
context = task.default_context
|
||||
question = request['question']
|
||||
if not question:
|
||||
question = task.default_question
|
||||
answer = ''
|
||||
|
||||
ex = Example.from_raw(str(request['id']), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
|
||||
|
||||
self.model.add_new_vocab_from_data([task])
|
||||
batch = self.numericalize_examples([ex])
|
||||
if self.args.calibrator_path is not None:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True,
|
||||
confidence_estimator=self.confidence_estimator)
|
||||
response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=float(output.confidence_scores[0])))
|
||||
else:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True)
|
||||
response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0]))
|
||||
return response + '\n'
|
||||
self.model.add_new_vocab_from_data([task])
|
||||
batch = self.numericalize_examples([ex])
|
||||
if self.args.calibrator_path is not None:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True,
|
||||
confidence_estimator=self.confidence_estimator)
|
||||
response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=float(output.confidence_scores[0])))
|
||||
else:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True)
|
||||
response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0]))
|
||||
return response + '\n'
|
||||
|
||||
async def handle_client(self, client_reader, client_writer):
|
||||
try:
|
||||
|
@ -169,11 +170,10 @@ class Server:
|
|||
self.model.to(self.device)
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
if self.args.stdin:
|
||||
self._run_stdin()
|
||||
else:
|
||||
self._run_tcp()
|
||||
if self.args.stdin:
|
||||
self._run_stdin()
|
||||
else:
|
||||
self._run_tcp()
|
||||
|
||||
|
||||
def parse_argv(parser):
|
||||
|
|
Loading…
Reference in New Issue