pipe_name instead of section in debug_model

This commit is contained in:
svlandeg 2020-07-30 20:06:28 +02:00
parent 3449c45fd9
commit 0b23594953
1 changed files with 5 additions and 5 deletions

View File

@ -16,7 +16,7 @@ def debug_model_cli(
# fmt: off
ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True),
section: str = Arg(..., help="Section that defines the model to be analysed"),
pipe_name: str = Arg(..., help="Name of the pipe of which the model should be analysed"),
layers: str = Opt("", "--layers", "-l", help="Comma-separated names of layer IDs to print"),
dimensions: bool = Opt(False, "--dimensions", "-DIM", help="Show dimensions"),
parameters: bool = Opt(False, "--parameters", "-PAR", help="Show parameters"),
@ -53,20 +53,20 @@ def debug_model_cli(
cfg = Config().from_disk(config_path)
with show_validation_error():
try:
_, config = util.load_model_from_config(cfg, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
except ValueError as e:
msg.fail(str(e), exits=1)
seed = config["pretraining"]["seed"]
seed = config.get("training", {}).get("seed", None)
if seed is not None:
msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed)
component = dot_to_object(config, section)
component = nlp.get_pipe(pipe_name)
if hasattr(component, "model"):
model = component.model
else:
msg.fail(
f"The section '{section}' does not specify an object that holds a Model.",
f"The component '{pipe_name}' does not specify an object that holds a Model.",
exits=1,
)
debug_model(model, print_settings=print_settings)