Add option to `predict.py` to directly run HF hub models
parent 554ec07610
commit 7d1d52c30d
@@ -35,7 +35,7 @@ import torch

 from . import models
 from .calibrate import ConfidenceEstimator
-from .util import load_config_json
+from .util import load_config_file_to_args

 logger = logging.getLogger(__name__)
@@ -51,7 +51,7 @@ def parse_argv(parser):

 def main(args):
     os.makedirs(args.output, exist_ok=True)
-    load_config_json(args)
+    load_config_file_to_args(args)

     # load everything - this will ensure that we initialize the numericalizer correctly
     Model = getattr(models, args.model)
@@ -130,6 +130,8 @@ class TransformerSeq2Seq(GenieModelForGeneration):
         # remove BOS from the answer to BART-Large because BART-Large was not trained to predict BOS
         # (unlike BART-Base or mBART)
         #
+        # NOTE: this change for some reason does not change the outputs of fine-tuned bart-large models
+        # like `stanford-oval/paraphaser-bart-large`
         # NOTE: various people at Huggingface and elsewhere have tried to conclusively ascertain
         # whether BOS should be there or not, and the answer seems to be that BOS should not be there
         # at all, either in input or in the output
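The BOS handling these comments describe boils down to dropping the leading BOS token from the target sequence before it reaches BART-Large. As a minimal sketch with illustrative names only (not the actual genienlp code):

    # Sketch only: `answer` holds the target token ids and `tokenizer` is the
    # model's tokenizer. BART-Large was not trained to predict BOS, so strip
    # it from the target before computing the loss.
    if len(answer) > 0 and answer[0] == tokenizer.bos_token_id:
        answer = answer[1:]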
@@ -56,7 +56,7 @@ from .util import (
     combine_folders_on_disk,
     get_devices,
     get_part_path,
-    load_config_json,
+    load_config_file_to_args,
     log_model_size,
     make_data_loader,
     set_seed,
@@ -67,7 +67,21 @@ logger = logging.getLogger(__name__)


 def parse_argv(parser):
-    parser.add_argument('--path', type=str, required=True, help='Folder to load the model from')
+    parser.add_argument('--is_hf_model', action='store_true',
+                        help='Whether the model should be directly loaded from HuggingFace model hub. If True, `--path` is the full model name.')
+    parser.add_argument(
+        '--model',
+        type=str,
+        choices=[
+            'TransformerLSTM',
+            'TransformerSeq2Seq',
+            'TransformerForTokenClassification',
+            'TransformerForSequenceClassification',
+        ],
+        default=None,
+        help='which model to import',
+    )
+    parser.add_argument('--path', '--model_name_or_path', type=str, required=True, help='Folder to load the model from')
     parser.add_argument(
         '--evaluate',
         type=str,
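For illustration, a hypothetical invocation of the new code path could look like this (the model name and directories are placeholders, and flags other than `--is_hf_model`, `--model`, and `--path` may differ between genienlp versions):

    genienlp predict \
        --is_hf_model \
        --path facebook/bart-large \
        --model TransformerSeq2Seq \
        --tasks almond \
        --pred_languages en \
        --eval_dir ./eval_output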
@@ -127,7 +141,7 @@ def parse_argv(parser):
     parser.add_argument(
         '--val_batch_size',
         nargs='+',
-        default=None,
+        default=[4000],
         type=int,
         help='Batch size for validation corresponding to tasks in val tasks',
     )
@@ -163,9 +177,10 @@ def parse_argv(parser):
         default=[0],
         help='ngrams of this size cannot be repeated in the output. 0 disables it.',
     )
-    parser.add_argument('--max_output_length', type=int, help='maximum output length for generation')
+    parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
     parser.add_argument(
         '--min_output_length',
         default=3,
         type=int,
         help='minimum output length for generation; '
+        'default is 3 for most multilingual models: BOS, language code, and one token. otherwise it is 2',
@@ -315,6 +330,10 @@ def check_args(args):
         setattr(args, 'pred_src_languages', [args.eval_src_languages])
     if not args.pred_tgt_languages:
         setattr(args, 'pred_tgt_languages', [args.eval_tgt_languages])

+    if args.is_hf_model and (not args.pred_src_languages or not args.model):
+        # because for HF models we are not getting these values from genienlp's training script
+        raise ValueError('You need to specify --pred_languages and --model when directly loading a HuggingFace model.')
+
     if len(args.task_names) != len(args.pred_src_languages):
         raise ValueError('You have to define prediction languages for each task in the same order you provided the tasks.')
@@ -337,6 +356,7 @@ def check_args(args):
         raise ValueError('Please remove --main_metric_only from your arguments so the requested extra metrics can be shown.')

+

 def prepare_data(args):
     # TODO handle multiple languages
     src_lang = args.pred_src_languages[0]
@@ -456,18 +476,27 @@ def get_metrics_to_compute(args, task):


 def run(args, device):
-
-    # TODO handle multiple languages
-    Model = getattr(models, args.model)
-    model, _ = Model.load(
-        args.path,
-        model_checkpoint_file=args.checkpoint_name,
-        args=args,
-        device=device,
-        tasks=args.tasks,
-        src_lang=args.pred_src_languages[0],
-        tgt_lang=args.pred_tgt_languages[0],
-    )
+    print(args.model)
+    model_class = getattr(models, args.model)
+    if args.is_hf_model:
+        logger.info(f'Loading model {args.path} from HuggingFace model hub')
+        model = model_class(args=args,
+                            vocab_sets=None,
+                            tasks=args.tasks,
+                            src_lang=args.pred_src_languages[0],
+                            tgt_lang=args.pred_tgt_languages[0]
+                            )
+    else:
+        # TODO handle multiple languages
+        model, _ = model_class.load(
+            args.path,
+            model_checkpoint_file=args.checkpoint_name,
+            args=args,
+            device=device,
+            tasks=args.tasks,
+            src_lang=args.pred_src_languages[0],
+            tgt_lang=args.pred_tgt_languages[0],
+        )

     val_sets = prepare_data(args)
     model.add_new_vocab_from_data(args.tasks)
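In the `is_hf_model` branch the model is constructed directly instead of being `load()`-ed from a local checkpoint; this path ultimately relies on HuggingFace's `from_pretrained`, which resolves the model name against the hub and downloads the weights. Roughly, as a sketch of the idea rather than the actual `TransformerSeq2Seq` internals:

    # Sketch only: what direct construction implies for a hub model.
    # `args.pretrained_model` is set to `args.path` by load_config_file_to_args.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)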
@@ -589,7 +618,7 @@ def update_metrics(args):


 def main(args):
-    load_config_json(args)
+    load_config_file_to_args(args)
     check_and_update_generation_args(args)
     check_args(args)
     set_default_values(args)
@@ -36,7 +36,7 @@ import torch
 from . import models
 from .arguments import check_and_update_generation_args
 from .tasks.registry import get_tasks
-from .util import get_devices, load_config_json, set_seed
+from .util import get_devices, load_config_file_to_args, set_seed

 logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def init(args):
     devices = get_devices()
     device = devices[0]  # server only runs on a single device

-    load_config_json(args)
+    load_config_file_to_args(args)
     check_and_update_generation_args(args)

     if not args.src_locale:
@@ -45,7 +45,7 @@ from .calibrate import ConfidenceEstimator
 from .data_utils.example import Example, NumericalizedExamples
 from .ned.ned_utils import init_ned_model
 from .tasks.registry import get_tasks
-from .util import adjust_language_code, get_devices, load_config_json, log_model_size, set_seed
+from .util import adjust_language_code, get_devices, load_config_file_to_args, log_model_size, set_seed

 logger = logging.getLogger(__name__)
@@ -327,7 +327,7 @@ class Server(object):


 def init(args):
-    load_config_json(args)
+    load_config_file_to_args(args)
     check_and_update_generation_args(args)
     if not args.src_locale:
         args.src_locale = args.eval_src_languages
genienlp/util.py (382 lines changed)
@@ -585,201 +585,207 @@ def have_multilingual(task_names):
     return any(['multilingual' in name for name in task_names])


-def load_config_json(args):
-    args.almond_type_embeddings = False
-    with open(os.path.join(args.path, 'config.json')) as config_file:
-        config = json.load(config_file)
+def load_config_file_to_args(args) -> bool:
+    if args.is_hf_model:
+        # no config file found, treat `args.path` as a model name on HuggingFace model hub
+        args.pretrained_model = args.path
+        args.override_question = ""  # because HF models are trained without a separate question
+        config = vars(args).copy()
+    else:
+        with open(os.path.join(args.path, 'config.json')) as config_file:
+            config = json.load(config_file)

     retrieve = [
         'model',
         'pretrained_model',
         'rnn_dimension',
         'rnn_layers',
         'rnn_zero_state',
         'max_generative_vocab',
         'lower',
         'trainable_decoder_embeddings',
         'override_context',
         'override_question',
         'almond_lang_as_question',
         'almond_has_multiple_programs',
         'almond_detokenize_sentence',
         'preprocess_special_tokens',
         'dropper_ratio',
         'dropper_min_count',
         'label_smoothing',
         'use_encoder_loss',
         'num_workers',
         'no_fast_tokenizer',
         'force_fast_tokenizer',
         'add_entities_to_text',
         'entity_attributes',
         'max_qids_per_entity',
         'max_types_per_qid',
         'do_ned',
         'database_type',
         'min_entity_len',
         'max_entity_len',
         'entity_type_agg_method',
         'entity_word_embeds_dropout',
         'num_db_types',
         'db_unk_id',
         'ned_retrieve_method',
         'ned_domains',
         'almond_type_mapping_path',
         'max_features_size',
         'bootleg_output_dir',
         'bootleg_model',
         'bootleg_prob_threshold',
         'ned_normalize_types',
         'att_pooling',
         'no_separator',
         'num_labels',
         'crossner_domains',
         'override_valid_metrics',
         'eval_src_languages',
         'eval_tgt_languages',
         'log_n_longest',
     ]

     # train and predict scripts have these arguments in common. We use the values from train only if they are not provided in predict.
     # NOTE: do not set default values for these arguments in predict cause the defaults will always override training arguments
     overwrite = [
+        'model',
         'val_batch_size',
         'num_beams',
         'num_beam_groups',
         'diversity_penalty',
         'num_outputs',
         'no_repeat_ngram_size',
         'top_p',
         'top_k',
         'repetition_penalty',
         'temperature',
         'align_span_symbol',
         'max_output_length',
         'min_output_length',
         'reduce_metrics',
         'database_dir',
         'e2e_dialogue_valid_subtasks',
         'e2e_dialogue_valid_submetrics',
         'e2e_dialogue_valid_subweights',
     ]
     for o in overwrite:
         if o not in args or getattr(args, o) is None:
             retrieve.append(o)

     # these are true/false arguments
     overwrite_actions = [
         'do_alignment',
         'align_preserve_input_quotation',
         'align_remove_output_quotation',
         'e2e_dialogue_evaluation',
         'filter_long_inputs',
     ]
     for o in overwrite_actions:
         # if argument is True in predict, overwrite train; if False, retrieve from train
         if not getattr(args, o, False):
             retrieve.append(o)

     for r in retrieve:
         if r in config:
             setattr(args, r, config[r])
         # These are for backward compatibility with models that were trained before we added these arguments
         elif r in (
             'do_ned',
             'do_alignment',
             'align_preserve_input_quotation',
             'align_remove_output_quotation',
             'use_encoder_loss',
             'almond_has_multiple_programs',
             'almond_lang_as_question',
             'preprocess_special_tokens',
             'no_fast_tokenizer',
             'force_fast_tokenizer',
         ):
             setattr(args, r, False)
         elif r in ('ned_normalize_types'):
             setattr(args, r, 'off')
         elif r in ('num_db_types', 'db_unk_id', 'num_workers'):
             setattr(args, r, 0)
         elif r in ('entity_word_embeds_dropout'):
             setattr(args, r, 0.0)
         elif r in ('num_beams', 'num_outputs', 'top_p', 'repetition_penalty'):
             setattr(args, r, [1])
         elif r in ('no_repeat_ngram_size', 'top_k', 'temperature'):
             setattr(args, r, [0])
         elif r in ['override_valid_metrics']:
             setattr(args, r, [])
         elif r == 'align_span_symbol':
             setattr(args, r, '"')
         elif r == 'log_n_longest':
             setattr(args, r, 3)
         elif r == 'database_type':
             setattr(args, r, 'json')
         elif r == 'att_pooling':
             setattr(args, r, 'max')
         elif r == 'min_entity_len':
             setattr(args, r, 2)
         elif r == 'max_entity_len':
             setattr(args, r, 4)
         elif r == 'ned_retrieve_method':
             setattr(args, r, 'naive')
         elif r == 'locale':
             setattr(args, r, 'en')
         elif r == 'num_beam_groups':
             setattr(args, r, [1])
         elif r == 'diversity_penalty':
             setattr(args, r, [0.0])
         elif r == 'dropper_ratio':
             setattr(args, r, 0.0)
         elif r == 'dropper_min_count':
             setattr(args, r, 10000)
         elif r == 'label_smoothing':
             setattr(args, r, 0.0)
         elif r == 'min_output_length':
             setattr(args, r, 3)
         elif r == 'no_separator':
             setattr(args, r, True)  # old models don't use a separator
         else:
             # use default value
             setattr(args, r, None)

     if args.e2e_dialogue_valid_subtasks is None:
         setattr(args, 'e2e_dialogue_valid_subtasks', ['dst', 'api', 'da', 'rg'])
     if args.e2e_dialogue_valid_submetrics is None:
         setattr(args, 'e2e_dialogue_valid_submetrics', ['dst_em', 'em', 'da_em', 'casedbleu'])
     if args.e2e_dialogue_valid_subweights is None:
         setattr(args, 'e2e_dialogue_valid_subweights', [1.0, 1.0, 1.0, 1.0])

     # backward compatibility for models trained with genienlp before NED Refactoring (2)
     if args.max_features_size is None:
         if hasattr(args, 'ned_features_size'):
             setattr(args, 'max_features_size', args.ned_features_size)
         else:
             setattr(args, 'max_features_size', 0)
     if args.ned_domains is None:
         if hasattr(args, 'almond_domains'):
             setattr(args, 'ned_domains', args.almond_domains)
         else:
             setattr(args, 'ned_domains', [])
     if args.add_entities_to_text is None:
         if hasattr(args, 'add_types_to_text'):
             setattr(args, 'add_entities_to_text', args.add_types_to_text)
         else:
             setattr(args, 'add_entities_to_text', 'off')
     if args.entity_attributes is None:
         if hasattr(args, 'ned_features'):
             setattr(args, 'entity_attributes', args.ned_features)
         else:
             setattr(args, 'entity_attributes', [])
     if args.ned_normalize_types is None:
         if hasattr(args, 'bootleg_post_process_types') and args.bootleg_post_process_types:
             setattr(args, 'ned_normalize_types', 'soft')
         else:
             setattr(args, 'ned_normalize_types', 'off')

     args.dropout_ratio = 0.0
     args.verbose = False

     args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
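The `config = vars(args).copy()` line is what lets the rest of the function run unchanged for hub models: every key the retrieval loop reads from `config` is simply the value already present on `args`, so retrieval becomes a no-op instead of a KeyError or an unwanted override. A small self-contained illustration with hypothetical values:

    import argparse

    args = argparse.Namespace(model='TransformerSeq2Seq', pretrained_model='facebook/bart-large')
    config = vars(args).copy()  # snapshot the namespace as a dict

    # Retrieving from `config` just writes back the values args already had.
    for r in ('model', 'pretrained_model'):
        if r in config:
            setattr(args, r, config[r])

    assert args.model == 'TransformerSeq2Seq'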