diff --git a/spacy/cli/train_from_config.py b/spacy/cli/train_from_config.py index eeb21c10c..c0e3bd169 100644 --- a/spacy/cli/train_from_config.py +++ b/spacy/cli/train_from_config.py @@ -7,7 +7,7 @@ from pathlib import Path from wasabi import msg import thinc import thinc.schedules -from thinc.api import Model +from thinc.api import Model, use_pytorch_for_gpu_memory import random from ..gold import GoldCorpus @@ -171,6 +171,8 @@ def train_from_config( msg.info(f"Loading config from: {config_path}") config = util.load_config(config_path, create_objects=False) util.fix_random_seed(config["training"]["seed"]) + if config["training"]["use_pytorch_for_gpu_memory"]: + use_pytorch_for_gpu_memory() nlp_config = config["nlp"] config = util.load_config(config_path, create_objects=True) msg.info("Creating nlp from config")