mirror of https://github.com/explosion/spaCy.git
cleanup and formatting
This commit is contained in:
parent 0c35885751
commit 427dbecdd6
@@ -71,9 +71,7 @@ def pretrain_cli(
     with show_validation_error(config_path):
         config = util.load_config(
-            config_path,
-            overrides=config_overrides,
-            interpolate=True
+            config_path, overrides=config_overrides, interpolate=True
         )
     if not config.get("pretraining"):
         # TODO: What's the solution here? How do we handle optional blocks?
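Collapsing the keyword arguments onto one line is purely cosmetic; the call is unchanged. For reference, a minimal sketch of the same util.load_config call used outside the CLI, assuming a spaCy v3 checkout (the override key mirrors the system.seed lookup later in this diff):

from spacy import util

# Apply CLI-style dotted overrides while loading; interpolate=True
# substitutes ${...} variable references with concrete values.
config = util.load_config(
    "config.cfg", overrides={"system.seed": 0}, interpolate=True
)
print(config["system"]["seed"])  # -> 0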
@@ -84,7 +82,7 @@ def pretrain_cli(
     config.to_disk(output_dir / "config.cfg")
     msg.good("Saved config file in the output directory")
     pretrain(
         config,
         output_dir,
@@ -99,7 +97,7 @@ def pretrain(
     output_dir: Path,
     resume_path: Optional[Path] = None,
     epoch_resume: Optional[int] = None,
-    use_gpu: int=-1
+    use_gpu: int = -1,
 ):
     if config["system"].get("seed") is not None:
         fix_random_seed(config["system"]["seed"])
@@ -107,7 +105,7 @@ def pretrain(
         use_pytorch_for_gpu_memory()
     nlp, config = util.load_model_from_config(config)
     P_cfg = config["pretraining"]
-    corpus = dot_to_object(config, config["pretraining"]["corpus"])
+    corpus = dot_to_object(config, P_cfg["corpus"])
     batcher = P_cfg["batcher"]
     model = create_pretraining_model(nlp, config["pretraining"])
     optimizer = config["pretraining"]["optimizer"]
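The second argument to dot_to_object is a dotted path string stored in the config, not the corpus itself; the helper resolves that path against the loaded config. A standalone sketch of the lookup, assuming plain nested dicts and made-up path and file names (the real helper in spacy.util adds friendlier error reporting):

def dot_to_object(config, path):
    # Walk a dotted path such as "corpora.pretrain" through nested sections.
    obj = config
    for key in path.split("."):
        obj = obj[key]
    return obj

cfg = {"corpora": {"pretrain": {"path": "raw_text.jsonl"}}}
assert dot_to_object(cfg, "corpora.pretrain") == {"path": "raw_text.jsonl"}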
@@ -148,9 +146,7 @@ def pretrain(
             progress = tracker.update(epoch, loss, docs)
             if progress:
                 msg.row(progress, **row_settings)
-            if P_cfg["n_save_every"] and (
-                batch_id % P_cfg["n_save_every"] == 0
-            ):
+            if P_cfg["n_save_every"] and (batch_id % P_cfg["n_save_every"] == 0):
                 _save_model(epoch, is_temp=True)
         _save_model(epoch)
         tracker.epoch_loss = 0.0
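Folding the modulo test onto one line keeps the guard readable as a single condition: a falsy n_save_every disables intra-epoch checkpoints, otherwise a temporary model is saved every N batches. A toy sketch of the same condition, with values assumed:

def should_checkpoint(batch_id, n_save_every):
    # None or 0 disables intra-epoch checkpoints; short-circuiting
    # avoids the modulo when the setting is unset.
    return bool(n_save_every) and batch_id % n_save_every == 0

assert should_checkpoint(2000, 1000)
assert not should_checkpoint(2500, 1000)
assert not should_checkpoint(2000, None)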
@@ -93,8 +93,8 @@ def train(
     raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
     T_cfg = config["training"]
     optimizer = T_cfg["optimizer"]
-    train_corpus = dot_to_object(config, config["training"]["train_corpus"])
-    dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
+    train_corpus = dot_to_object(config, T_cfg["train_corpus"])
+    dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
     batcher = T_cfg["batcher"]
     train_logger = T_cfg["logger"]
     # Components that shouldn't be updated during training
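As in pretrain, train_corpus and dev_corpus hold dotted references that dot_to_object resolves, rather than inline reader blocks. A sketch of the config layout they would point at, with section contents assumed from spaCy v3 defaults:

from thinc.api import Config

config = Config().from_str("""
[corpora]

[corpora.train]
@readers = "spacy.Corpus.v1"
path = "train.spacy"

[corpora.dev]
@readers = "spacy.Corpus.v1"
path = "dev.spacy"

[training]
train_corpus = "corpora.train"
dev_corpus = "corpora.dev"
""")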
@@ -104,7 +104,7 @@ class TokenPatternOperator(str, Enum):
 StringValue = Union[TokenPatternString, StrictStr]
 NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
 UnderscoreValue = Union[
-    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
+    TokenPatternString, TokenPatternNumber, str, int, float, list, bool
 ]
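UnderscoreValue is the set of types accepted under the "_" key of a token pattern, which matches on custom extension attributes. A minimal sketch with a made-up extension name:

import spacy
from spacy.matcher import Matcher
from spacy.tokens import Token

# "is_brand" is a hypothetical custom attribute; any plain str/int/float/
# bool/list value, or a string/number predicate dict, validates here.
Token.set_extension("is_brand", default=False)

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)
matcher.add("BRAND", [[{"_": {"is_brand": True}}]])

doc = nlp("I bought acme shoes")
doc[2]._.is_brand = True
print(matcher(doc))  # one match spanning the "acme" token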
@@ -26,12 +26,15 @@ def test_readers():
     [components.textcat]
     factory = "textcat"
     """

     @registry.readers.register("myreader.v1")
     def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
         annots = {"cats": {"POS": 1.0, "NEG": 0.0}}

         def reader(nlp: Language):
             doc = nlp.make_doc(f"This is an example")
             return [Example.from_dict(doc, annots)]

         return {"train": reader, "dev": reader, "extra": reader, "something": reader}

     config = Config().from_str(config_string)
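Because myreader returns a dict keyed by corpus name, a single registered function can supply every corpus the config refers to. A sketch of the corpora block that config_string presumably contains (the block contents are assumed, not shown in this diff):

from thinc.api import Config

corpora_cfg = Config().from_str("""
[corpora]
@readers = "myreader.v1"
""")
# When this block is resolved, myreader() runs once and its
# "train"/"dev"/"extra"/"something" entries become the named corpora.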