mirror of https://github.com/explosion/spaCy.git
small fixes to pretrain config, init_tok2vec TODO
This commit is contained in:
parent
ddf8244df9
commit
4ed6278663
|
@ -45,12 +45,16 @@ eps = 1e-8
|
|||
learn_rate = 0.001
|
||||
|
||||
[pretraining]
|
||||
max_epochs = 100
|
||||
max_epochs = 1000
|
||||
start_epoch = 0
|
||||
min_length = 5
|
||||
max_length = 500
|
||||
dropout = 0.2
|
||||
n_save_every = null
|
||||
batch_size = 3000
|
||||
seed = ${training:seed}
|
||||
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
|
||||
init_tok2vec = null
|
||||
|
||||
[pretraining.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
|
|
|
@ -16,14 +16,15 @@ from ..tokens import Doc
|
|||
from ..attrs import ID, HEAD
|
||||
from .. import util
|
||||
from ..gold import Example
|
||||
from .deprecated_pretrain import _load_pretrained_tok2vec # TODO
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
# fmt: off
|
||||
texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
|
||||
vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str),
|
||||
config_path=("Path to config file", "positional", None, Path),
|
||||
output_dir=("Directory to write models to on each epoch", "positional", None, Path),
|
||||
config_path=("Path to config file", "positional", None, Path),
|
||||
use_gpu=("Use GPU", "option", "g", int),
|
||||
# fmt: on
|
||||
)
|
||||
|
@ -60,8 +61,8 @@ def pretrain(
|
|||
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
util.fix_random_seed(config["training"]["seed"])
|
||||
if config["training"]["use_pytorch_for_gpu_memory"]:
|
||||
util.fix_random_seed(config["pretraining"]["seed"])
|
||||
if config["pretraining"]["use_pytorch_for_gpu_memory"]:
|
||||
use_pytorch_for_gpu_memory()
|
||||
|
||||
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
||||
|
@ -100,8 +101,33 @@ def pretrain(
|
|||
tok2vec = pretrain_config["model"]
|
||||
model = create_pretraining_model(nlp, tok2vec)
|
||||
optimizer = pretrain_config["optimizer"]
|
||||
init_tok2vec = pretrain_config["init_tok2vec"]
|
||||
epoch_start = pretrain_config["epoch_start"]
|
||||
|
||||
# Load in pretrained weights - TODO test
|
||||
if init_tok2vec is not None:
|
||||
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
|
||||
msg.text(f"Loaded pretrained tok2vec for: {components}")
|
||||
# Parse the epoch number from the given weight file
|
||||
model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
|
||||
if model_name:
|
||||
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
||||
epoch_start = int(model_name.group(0)[5:][:-4]) + 1
|
||||
else:
|
||||
if not epoch_start:
|
||||
msg.fail(
|
||||
"You have to use the epoch_start setting when using a renamed weight file for init_tok2vec",
|
||||
exits=True,
|
||||
)
|
||||
elif epoch_start < 0:
|
||||
msg.fail(
|
||||
f"The setting epoch_start has to be greater or equal to 0. {epoch_start} is invalid",
|
||||
exits=True,
|
||||
)
|
||||
else:
|
||||
# Without 'init-tok2vec' the 'epoch_start' setting is ignored
|
||||
epoch_start = 0
|
||||
|
||||
epoch_start = 0 # TODO
|
||||
tracker = ProgressTracker(frequency=10000)
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
|
||||
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
|
||||
|
|
Loading…
Reference in New Issue