diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 31db48d64..01e8fbea0 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -82,7 +82,8 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000): output_dir = Path(output_dir) if not output_dir.exists(): output_dir.mkdir() - nlp.to_disk(output_dir) + with nlp.use_params(optimizer.averages): + nlp.to_disk(output_dir) print("Saved model to", output_dir) # test the saved model