diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 3567e7339..aec077eb7 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -71,9 +71,7 @@ def pretrain_cli(
 
     with show_validation_error(config_path):
         config = util.load_config(
-            config_path,
-            overrides=config_overrides,
-            interpolate=True
+            config_path, overrides=config_overrides, interpolate=True
         )
     if not config.get("pretraining"):
         # TODO: What's the solution here? How do we handle optional blocks?
@@ -84,7 +82,7 @@ def pretrain_cli(
 
     config.to_disk(output_dir / "config.cfg")
     msg.good("Saved config file in the output directory")
-    
+
     pretrain(
         config,
         output_dir,
@@ -99,7 +97,7 @@ def pretrain(
     output_dir: Path,
     resume_path: Optional[Path] = None,
     epoch_resume: Optional[int] = None,
-    use_gpu: int=-1
+    use_gpu: int = -1,
 ):
     if config["system"].get("seed") is not None:
         fix_random_seed(config["system"]["seed"])
@@ -107,7 +105,7 @@ def pretrain(
         use_pytorch_for_gpu_memory()
     nlp, config = util.load_model_from_config(config)
     P_cfg = config["pretraining"]
-    corpus = dot_to_object(config, config["pretraining"]["corpus"])
+    corpus = dot_to_object(config, P_cfg["corpus"])
     batcher = P_cfg["batcher"]
     model = create_pretraining_model(nlp, config["pretraining"])
     optimizer = config["pretraining"]["optimizer"]
@@ -148,9 +146,7 @@ def pretrain(
                 progress = tracker.update(epoch, loss, docs)
                 if progress:
                     msg.row(progress, **row_settings)
-                if P_cfg["n_save_every"] and (
-                    batch_id % P_cfg["n_save_every"] == 0
-                ):
+                if P_cfg["n_save_every"] and (batch_id % P_cfg["n_save_every"] == 0):
                     _save_model(epoch, is_temp=True)
         _save_model(epoch)
         tracker.epoch_loss = 0.0
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 15c745b69..50306b350 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -93,8 +93,8 @@ def train(
     raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
     T_cfg = config["training"]
    optimizer = T_cfg["optimizer"]
-    train_corpus = dot_to_object(config, config["training"]["train_corpus"])
-    dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
+    train_corpus = dot_to_object(config, T_cfg["train_corpus"])
+    dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
     batcher = T_cfg["batcher"]
     train_logger = T_cfg["logger"]
     # Components that shouldn't be updated during training
diff --git a/spacy/schemas.py b/spacy/schemas.py
index a530db3d0..06bc4beed 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -104,7 +104,7 @@ class TokenPatternOperator(str, Enum):
 StringValue = Union[TokenPatternString, StrictStr]
 NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
 UnderscoreValue = Union[
-    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
+    TokenPatternString, TokenPatternNumber, str, int, float, list, bool
 ]
 
 
diff --git a/spacy/tests/training/test_readers.py b/spacy/tests/training/test_readers.py
index 52a4abecc..898746c2a 100644
--- a/spacy/tests/training/test_readers.py
+++ b/spacy/tests/training/test_readers.py
@@ -26,12 +26,15 @@ def test_readers():
     [components.textcat]
     factory = "textcat"
     """
+
     @registry.readers.register("myreader.v1")
     def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
         annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
+
         def reader(nlp: Language):
             doc = nlp.make_doc(f"This is an example")
             return [Example.from_dict(doc, annots)]
+
         return {"train": reader, "dev": reader, "extra": reader, "something": reader}
     config = Config().from_str(config_string)