updated lib name
This commit is contained in:
parent
0e82428eb9
commit
2117485550
|
@ -0,0 +1,167 @@
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from test_tube import HyperOptArgumentParser
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from sklearn.metrics import confusion_matrix, f1_score
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class BiLSTMPack(nn.Module):
|
||||
"""
|
||||
Sample model to show how to define a template
|
||||
"""
|
||||
def __init__(self, hparams):
|
||||
# init superclass
|
||||
super(BiLSTMPack, self).__init__(hparams)
|
||||
|
||||
self.hidden = None
|
||||
|
||||
# trigger tag building
|
||||
self.ner_tagset = {'O': 0, 'I-Bio': 1}
|
||||
self.nb_tags = len(self.ner_tagset)
|
||||
|
||||
# build model
|
||||
print('building model...')
|
||||
if hparams.model_load_weights_path is None:
|
||||
self.__build_model()
|
||||
print('model built')
|
||||
else:
|
||||
self = BiLSTMPack.load(hparams.model_load_weights_path, hparams.on_gpu, hparams)
|
||||
print('model loaded from: {}'.format(hparams.model_load_weights_path))
|
||||
|
||||
def __build_model(self):
|
||||
"""
|
||||
Layout model
|
||||
:return:
|
||||
"""
|
||||
# design the number of final units
|
||||
self.output_dim = self.hparams.nb_lstm_units
|
||||
|
||||
# when it's bidirectional our weights double
|
||||
if self.hparams.bidirectional:
|
||||
self.output_dim *= 2
|
||||
|
||||
# total number of words
|
||||
total_words = len(self.tng_dataloader.dataset.words_token_to_idx)
|
||||
|
||||
# word embeddings
|
||||
self.word_embedding = nn.Embedding(
|
||||
num_embeddings=total_words + 1,
|
||||
embedding_dim=self.hparams.embedding_dim,
|
||||
padding_idx=0
|
||||
)
|
||||
|
||||
# design the LSTM
|
||||
self.lstm = nn.LSTM(
|
||||
self.hparams.embedding_dim,
|
||||
self.hparams.nb_lstm_units,
|
||||
num_layers=self.hparams.nb_lstm_layers,
|
||||
bidirectional=self.hparams.bidirectional,
|
||||
dropout=self.hparams.drop_prob,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
# map to tag space
|
||||
self.fc_out = nn.Linear(self.output_dim, self.out_dim)
|
||||
self.hidden_to_tag = nn.Linear(self.output_dim, self.nb_tags)
|
||||
|
||||
|
||||
def init_hidden(self, batch_size):
|
||||
|
||||
# the weights are of the form (nb_layers * 2 if bidirectional, batch_size, nb_lstm_units)
|
||||
mult = 2 if self.hparams.bidirectional else 1
|
||||
hidden_a = torch.randn(self.hparams.nb_layers * mult, batch_size, self.nb_rnn_units)
|
||||
hidden_b = torch.randn(self.hparams.nb_layers * mult, batch_size, self.nb_rnn_units)
|
||||
|
||||
if self.hparams.on_gpu:
|
||||
hidden_a = hidden_a.cuda()
|
||||
hidden_b = hidden_b.cuda()
|
||||
|
||||
hidden_a = Variable(hidden_a)
|
||||
hidden_b = Variable(hidden_b)
|
||||
|
||||
return (hidden_a, hidden_b)
|
||||
|
||||
def forward(self, model_in):
|
||||
# layout data (expand it, etc...)
|
||||
# x = sequences
|
||||
x, seq_lengths = model_in
|
||||
batch_size, seq_len = x.size()
|
||||
|
||||
# reset RNN hidden state
|
||||
self.hidden = self.init_hidden(batch_size)
|
||||
|
||||
# embed
|
||||
x = self.word_embedding(x)
|
||||
|
||||
# run through rnn using packed sequences
|
||||
x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True)
|
||||
x, self.hidden = self.lstm(x, self.hidden)
|
||||
x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
||||
|
||||
# if asked for only last state, use the h_n which is the same as out(t=n)
|
||||
if not self.return_sequence:
|
||||
# pull out hidden states
|
||||
# h_n = (nb_directions * nb_layers, batch_size, emb_size)
|
||||
nb_directions = 2 if self.bidirectional else 1
|
||||
(h_n, _) = self.hidden
|
||||
|
||||
# reshape to make indexing easier
|
||||
# forward = 0, backward = 1 (of nb_directions)
|
||||
h_n = h_n.view(self.nb_layers, nb_directions, batch_size, self.nb_rnn_units)
|
||||
|
||||
# pull out last forward
|
||||
forward_h_n = h_n[-1, 0, :, :]
|
||||
x = forward_h_n
|
||||
|
||||
# if bidirectional, also pull out the last hidden of backward network
|
||||
if self.bidirectional:
|
||||
backward_h_n = h_n[-1, 1, :, :]
|
||||
x = torch.cat([forward_h_n, backward_h_n], dim=1)
|
||||
|
||||
# project to tag space
|
||||
x = x.contiguous()
|
||||
x = x.view(-1, self.output_dim)
|
||||
x = self.hidden_to_tag(x)
|
||||
|
||||
return x
|
||||
|
||||
def loss(self, model_out):
|
||||
# cross entropy loss
|
||||
logits, y = model_out
|
||||
y, y_lens = y
|
||||
|
||||
# flatten y and logits
|
||||
y = y.view(-1)
|
||||
logits = logits.view(-1, self.nb_tags)
|
||||
|
||||
# calculate a mask to remove padding tokens
|
||||
mask = (y >= 0).float()
|
||||
|
||||
# count how many tokens we have
|
||||
num_tokens = int(torch.sum(mask).data[0])
|
||||
|
||||
# pick the correct values and mask out
|
||||
logits = logits[range(logits.shape[0]), y] * mask
|
||||
|
||||
# compute the ce loss
|
||||
ce_loss = -torch.sum(logits)/num_tokens
|
||||
|
||||
return ce_loss
|
||||
|
||||
def pull_out_last_embedding(self, x, seq_lengths, batch_size, on_gpu):
|
||||
# grab only the last activations from the non-padded ouput
|
||||
x_last = torch.zeros([batch_size, 1, x.size(-1)])
|
||||
for i, seq_len in enumerate(seq_lengths):
|
||||
x_last[i, :, :] = x[i, seq_len-1, :]
|
||||
|
||||
# put on gpu when requested
|
||||
if on_gpu:
|
||||
x_last = x_last.cuda()
|
||||
|
||||
# turn into torch var
|
||||
x_last = Variable(x_last)
|
||||
|
||||
return x_last
|
|
@ -0,0 +1,203 @@
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from pytorch_lightning.root_module.root_module import RootModule
|
||||
from test_tube import HyperOptArgumentParser
|
||||
from torchvision.datasets import MNIST
|
||||
import torchvision.transforms as transforms
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ExampleModel1(RootModule):
|
||||
"""
|
||||
Sample model to show how to define a template
|
||||
"""
|
||||
|
||||
def __init__(self, hparams):
|
||||
# init superclass
|
||||
super(ExampleModel1, self).__init__(hparams)
|
||||
|
||||
self.batch_size = hparams.batch_size
|
||||
|
||||
# build model
|
||||
self.__build_model()
|
||||
|
||||
# ---------------------
|
||||
# MODEL SETUP
|
||||
# ---------------------
|
||||
def __build_model(self):
|
||||
"""
|
||||
Layout model
|
||||
:return:
|
||||
"""
|
||||
self.c_d1 = nn.Linear(in_features=self.hparams.in_features, out_features=self.hparams.hidden_dim)
|
||||
self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
|
||||
self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
|
||||
|
||||
self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, out_features=self.hparams.out_features)
|
||||
|
||||
# ---------------------
|
||||
# TRAINING
|
||||
# ---------------------
|
||||
def forward(self, x):
|
||||
x = self.c_d1(x)
|
||||
x = F.tanh(x)
|
||||
x = self.c_d1_bn(x)
|
||||
x = self.c_d1_drop(x)
|
||||
|
||||
x = self.c_d2(x)
|
||||
logits = F.log_softmax(x, dim=1)
|
||||
|
||||
return logits
|
||||
|
||||
def loss(self, labels, logits):
|
||||
nll = F.nll_loss(logits, labels)
|
||||
return nll
|
||||
|
||||
def training_step(self, data_batch):
|
||||
"""
|
||||
Called inside the training loop
|
||||
:param data_batch:
|
||||
:return:
|
||||
"""
|
||||
# forward pass
|
||||
x, y = data_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
# calculate loss
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
||||
tqdm_dic = {'jefe': 1}
|
||||
return loss_val, tqdm_dic
|
||||
|
||||
def validation_step(self, data_batch):
|
||||
"""
|
||||
Called inside the validation loop
|
||||
:param data_batch:
|
||||
:return:
|
||||
"""
|
||||
x, y = data_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
|
||||
output = {'y_hat': y_hat, 'val_loss': loss_val.item(), 'val_acc': val_acc}
|
||||
return output
|
||||
|
||||
def validation_end(self, outputs):
|
||||
"""
|
||||
Called at the end of validation to aggregate outputs
|
||||
:param outputs: list of individual outputs of each validation step
|
||||
:return:
|
||||
"""
|
||||
val_loss_mean = 0
|
||||
accs = []
|
||||
for output in outputs:
|
||||
val_loss_mean += output['val_loss']
|
||||
accs.append(output['val_acc'])
|
||||
|
||||
val_loss_mean /= len(outputs)
|
||||
tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': np.mean(accs)}
|
||||
return tqdm_dic
|
||||
|
||||
def update_tng_log_metrics(self, logs):
|
||||
return logs
|
||||
|
||||
# ---------------------
|
||||
# MODEL SAVING
|
||||
# ---------------------
|
||||
def get_save_dict(self):
|
||||
checkpoint = {
|
||||
'state_dict': self.state_dict(),
|
||||
}
|
||||
|
||||
return checkpoint
|
||||
|
||||
def load_model_specific(self, checkpoint):
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
pass
|
||||
|
||||
# ---------------------
|
||||
# TRAINING SETUP
|
||||
# ---------------------
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
return whatever optimizers we want here
|
||||
:return: list of optimizers
|
||||
"""
|
||||
optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
|
||||
self.optimizers = [optimizer]
|
||||
return self.optimizers
|
||||
|
||||
def __dataloader(self, train):
|
||||
# init data generators
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
|
||||
|
||||
dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
return loader
|
||||
|
||||
@property
|
||||
def tng_dataloader(self):
|
||||
if self._tng_dataloader is None:
|
||||
try:
|
||||
self._tng_dataloader = self.__dataloader(train=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
return self._tng_dataloader
|
||||
|
||||
@property
|
||||
def val_dataloader(self):
|
||||
if self._val_dataloader is None:
|
||||
try:
|
||||
self._val_dataloader = self.__dataloader(train=False)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
return self._val_dataloader
|
||||
|
||||
@property
|
||||
def test_dataloader(self):
|
||||
if self._test_dataloader is None:
|
||||
try:
|
||||
self._test_dataloader = self.__dataloader(train=False)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
return self._test_dataloader
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parent_parser):
|
||||
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
|
||||
|
||||
# param overwrites
|
||||
# parser.set_defaults(gradient_clip=5.0)
|
||||
|
||||
# network params
|
||||
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
|
||||
parser.add_argument('--in_features', default=28*28)
|
||||
parser.add_argument('--hidden_dim', default=500)
|
||||
parser.add_argument('--out_features', default=10)
|
||||
|
||||
# data
|
||||
parser.add_argument('--data_root', default='/Users/williamfalcon/Developer/personal/research_lib/research_proj/datasets/mnist', type=str)
|
||||
|
||||
# training params (opt)
|
||||
parser.opt_list('--learning_rate', default=0.001, type=float, options=[0.0001, 0.0005, 0.001, 0.005],
|
||||
tunable=False)
|
||||
parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
|
||||
parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
|
||||
return parser
|
|
@ -0,0 +1,409 @@
|
|||
import torch
|
||||
import tqdm
|
||||
import numpy as np
|
||||
from research_lib.root_module.memory import get_gpu_memory_map
|
||||
import traceback
|
||||
from research_lib.root_module.model_saving import TrainerIO
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
|
||||
|
||||
class Trainer(TrainerIO):
|
||||
|
||||
def __init__(self,
|
||||
experiment,
|
||||
cluster,
|
||||
checkpoint_callback, early_stop_callback,
|
||||
process_position=0,
|
||||
current_gpu_name=0,
|
||||
on_gpu=False,
|
||||
enable_tqdm=True,
|
||||
overfit_pct=None,
|
||||
track_grad_norm=-1,
|
||||
check_val_every_n_epoch=1,
|
||||
fast_dev_run=False,
|
||||
accumulate_grad_batches=False,
|
||||
enable_early_stop=True, max_nb_epochs=5, min_nb_epochs=1,
|
||||
train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95,
|
||||
log_save_interval=1, add_log_row_interval=1,
|
||||
lr_scheduler_milestones=None,
|
||||
nb_sanity_val_steps=5):
|
||||
|
||||
# Transfer params
|
||||
self.check_val_every_n_epoch = check_val_every_n_epoch
|
||||
self.enable_early_stop = enable_early_stop
|
||||
self.track_grad_norm = track_grad_norm
|
||||
self.fast_dev_run = fast_dev_run
|
||||
self.on_gpu = on_gpu
|
||||
self.enable_tqdm = enable_tqdm
|
||||
self.experiment = experiment
|
||||
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
|
||||
self.cluster = cluster
|
||||
self.process_position = process_position
|
||||
self.current_gpu_name = current_gpu_name
|
||||
self.checkpoint_callback = checkpoint_callback
|
||||
self.checkpoint_callback.save_function = self.save_checkpoint
|
||||
self.early_stop = early_stop_callback
|
||||
self.model = None
|
||||
self.max_nb_epochs = max_nb_epochs
|
||||
self.accumulate_grad_batches = accumulate_grad_batches
|
||||
self.early_stop_callback = early_stop_callback
|
||||
self.min_nb_epochs = min_nb_epochs
|
||||
self.nb_sanity_val_steps = nb_sanity_val_steps
|
||||
self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
|
||||
self.lr_schedulers = []
|
||||
|
||||
# training state
|
||||
self.optimizers = None
|
||||
self.prog_bar = None
|
||||
self.global_step = 0
|
||||
self.current_epoch = 0
|
||||
self.total_batches = 0
|
||||
|
||||
# logging
|
||||
self.log_save_interval = log_save_interval
|
||||
self.val_check_interval = val_check_interval
|
||||
self.add_log_row_interval = add_log_row_interval
|
||||
|
||||
# dataloaders
|
||||
self.tng_dataloader = None
|
||||
self.test_dataloader = None
|
||||
self.val_dataloader = None
|
||||
|
||||
# how much of the data to use
|
||||
self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
|
||||
print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
|
||||
|
||||
def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
|
||||
"""
|
||||
Use less data for debugging purposes
|
||||
"""
|
||||
self.train_percent_check = train_percent_check
|
||||
self.val_percent_check = val_percent_check
|
||||
self.test_percent_check = test_percent_check
|
||||
if overfit_pct > 0:
|
||||
self.train_percent_check = overfit_pct
|
||||
self.val_percent_check = overfit_pct
|
||||
self.test_percent_check = overfit_pct
|
||||
|
||||
def __is_function_implemented(self, f_name):
|
||||
f_op = getattr(self, f_name, None)
|
||||
return callable(f_op)
|
||||
|
||||
@property
|
||||
def __tng_tqdm_dic(self):
|
||||
tqdm_dic = {
|
||||
'tng_loss': '{0:.3f}'.format(self.avg_loss),
|
||||
'gpu': '{}'.format(self.current_gpu_name),
|
||||
'v_nb': '{}'.format(self.experiment.version),
|
||||
'epoch': '{}'.format(self.current_epoch),
|
||||
'batch_nb':'{}'.format(self.batch_nb),
|
||||
}
|
||||
tqdm_dic.update(self.tqdm_metrics)
|
||||
return tqdm_dic
|
||||
|
||||
def __layout_bookeeping(self):
|
||||
# training bookeeping
|
||||
self.total_batch_nb = 0
|
||||
self.running_loss = []
|
||||
self.avg_loss = 0
|
||||
self.batch_nb = 0
|
||||
self.tqdm_metrics = {}
|
||||
|
||||
# determine number of training batches
|
||||
nb_tng_batches = self.model.nb_batches(self.tng_dataloader)
|
||||
self.nb_tng_batches = int(nb_tng_batches * self.train_percent_check)
|
||||
|
||||
# determine number of validation batches
|
||||
nb_val_batches = self.model.nb_batches(self.val_dataloader)
|
||||
nb_val_batches = int(nb_val_batches * self.val_percent_check)
|
||||
nb_val_batches = max(1, nb_val_batches)
|
||||
self.nb_val_batches = nb_val_batches
|
||||
|
||||
# determine number of test batches
|
||||
nb_test_batches = self.model.nb_batches(self.test_dataloader)
|
||||
self.nb_test_batches = int(nb_test_batches * self.test_percent_check)
|
||||
|
||||
# determine when to check validation
|
||||
self.val_check_batch = int(nb_tng_batches * self.val_check_interval)
|
||||
|
||||
def __add_tqdm_metrics(self, metrics):
|
||||
for k, v in metrics.items():
|
||||
self.tqdm_metrics[k] = v
|
||||
|
||||
def validate(self, model, dataloader, max_batches):
|
||||
"""
|
||||
Run validation code
|
||||
:param model: PT model
|
||||
:param dataloader: PT dataloader
|
||||
:param max_batches: Scalar
|
||||
:return:
|
||||
"""
|
||||
print('validating...')
|
||||
|
||||
# enable eval mode
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
|
||||
# disable gradients to save memory
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# bookkeeping
|
||||
outputs = []
|
||||
|
||||
# run training
|
||||
for i, data_batch in enumerate(dataloader):
|
||||
|
||||
if data_batch is None:
|
||||
continue
|
||||
|
||||
# stop short when on fast dev run
|
||||
if max_batches is not None and i >= max_batches:
|
||||
break
|
||||
|
||||
# -----------------
|
||||
# RUN VALIDATION STEP
|
||||
# -----------------
|
||||
output = model.validation_step(data_batch)
|
||||
outputs.append(output)
|
||||
|
||||
# batch done
|
||||
if self.enable_tqdm and self.prog_bar is not None:
|
||||
self.prog_bar.update(1)
|
||||
|
||||
# give model a chance to do something with the outputs
|
||||
val_results = model.validation_end(outputs)
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
|
||||
# enable gradients to save memory
|
||||
torch.set_grad_enabled(True)
|
||||
return val_results
|
||||
|
||||
def __get_dataloaders(self, model):
|
||||
"""
|
||||
Dataloaders are provided by the model
|
||||
:param model:
|
||||
:return:
|
||||
"""
|
||||
self.tng_dataloader = model.tng_dataloader
|
||||
self.test_dataloader = model.test_dataloader
|
||||
self.val_dataloader = model.val_dataloader
|
||||
|
||||
# -----------------------------
|
||||
# MODEL TRAINING
|
||||
# -----------------------------
|
||||
def fit(self, model):
|
||||
self.model = model
|
||||
|
||||
# transfer data loaders from model
|
||||
self.__get_dataloaders(model)
|
||||
|
||||
# init training constants
|
||||
self.__layout_bookeeping()
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||
self.optimizers = model.configure_optimizers()
|
||||
|
||||
# add lr schedulers
|
||||
if self.lr_scheduler_milestones is not None:
|
||||
for optimizer in self.optimizers:
|
||||
scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones)
|
||||
self.lr_schedulers.append(scheduler)
|
||||
|
||||
# print model summary
|
||||
model.summarize()
|
||||
|
||||
# put on gpu if needed
|
||||
if self.on_gpu:
|
||||
model = model.cuda()
|
||||
|
||||
# run tiny validation to make sure program won't crash during val
|
||||
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
|
||||
|
||||
# save exp to get started
|
||||
self.experiment.save()
|
||||
|
||||
# enable cluster checkpointing
|
||||
self.enable_auto_hpc_walltime_manager()
|
||||
|
||||
# ---------------------------
|
||||
# CORE TRAINING LOOP
|
||||
# ---------------------------
|
||||
self.__train()
|
||||
|
||||
def __train(self):
|
||||
# run all epochs
|
||||
for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
|
||||
# update the lr scheduler
|
||||
for lr_scheduler in self.lr_schedulers:
|
||||
lr_scheduler.step()
|
||||
|
||||
self.model.current_epoch = epoch_nb
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_epoch_start'):
|
||||
self.model.on_epoch_start()
|
||||
|
||||
self.current_epoch = epoch_nb
|
||||
self.total_batches = self.nb_tng_batches + self.nb_val_batches
|
||||
self.batch_loss_value = 0 # accumulated grads
|
||||
|
||||
# init progbar when requested
|
||||
if self.enable_tqdm:
|
||||
self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)
|
||||
|
||||
for batch_nb, data_batch in enumerate(self.tng_dataloader):
|
||||
self.batch_nb = batch_nb
|
||||
self.global_step += 1
|
||||
self.model.global_step = self.global_step
|
||||
|
||||
# stop when the flag is changed or we've gone past the amount requested in the batches
|
||||
self.total_batch_nb += 1
|
||||
met_batch_limit = batch_nb > self.nb_tng_batches
|
||||
if met_batch_limit:
|
||||
break
|
||||
|
||||
# ---------------
|
||||
# RUN TRAIN STEP
|
||||
# ---------------
|
||||
self.__run_tng_batch(data_batch)
|
||||
|
||||
# ---------------
|
||||
# RUN VAL STEP
|
||||
# ---------------
|
||||
is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
|
||||
if self.fast_dev_run or is_val_check_batch:
|
||||
self.__run_validation()
|
||||
|
||||
# when batch should be saved
|
||||
if (batch_nb + 1) % self.log_save_interval == 0:
|
||||
self.experiment.save()
|
||||
|
||||
# when metrics should be logged
|
||||
if batch_nb % self.add_log_row_interval == 0:
|
||||
# count items in memory
|
||||
# nb_params, nb_tensors = count_mem_items()
|
||||
|
||||
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
|
||||
# add gpu memory
|
||||
if self.on_gpu:
|
||||
mem_map = get_gpu_memory_map()
|
||||
metrics.update(mem_map)
|
||||
|
||||
# add norms
|
||||
if self.track_grad_norm > 0:
|
||||
grad_norm_dic = self.model.grad_norm(self.track_grad_norm)
|
||||
metrics.update(grad_norm_dic)
|
||||
|
||||
# log metrics
|
||||
self.experiment.log(metrics)
|
||||
self.experiment.save()
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_end'):
|
||||
self.model.on_batch_end()
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_epoch_end'):
|
||||
self.model.on_epoch_end()
|
||||
|
||||
# early stopping
|
||||
if self.enable_early_stop:
|
||||
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb, logs=self.__tng_tqdm_dic)
|
||||
met_min_epochs = epoch_nb > self.min_nb_epochs
|
||||
|
||||
# stop training
|
||||
stop = should_stop and met_min_epochs
|
||||
if stop:
|
||||
return
|
||||
|
||||
def __run_tng_batch(self, data_batch):
|
||||
if data_batch is None:
|
||||
return
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_start'):
|
||||
self.model.on_batch_start()
|
||||
|
||||
if self.enable_tqdm:
|
||||
self.prog_bar.update(1)
|
||||
|
||||
# forward pass
|
||||
# return a scalar value and a dic with tqdm metrics
|
||||
loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch)
|
||||
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
|
||||
|
||||
# backward pass
|
||||
loss.backward()
|
||||
self.batch_loss_value += loss.item()
|
||||
|
||||
# gradient update with accumulated gradients
|
||||
if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
|
||||
|
||||
# update gradients across all optimizers
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.step()
|
||||
|
||||
# clear gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# queuing loss across batches blows it up proportionally... divide out the number accumulated
|
||||
self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches
|
||||
|
||||
# track loss
|
||||
self.running_loss.append(self.batch_loss_value)
|
||||
self.batch_loss_value = 0
|
||||
self.avg_loss = np.mean(self.running_loss[-100:])
|
||||
|
||||
# update progbar
|
||||
if self.enable_tqdm:
|
||||
# add model specific metrics
|
||||
tqdm_metrics = self.__tng_tqdm_dic
|
||||
self.prog_bar.set_postfix(**tqdm_metrics)
|
||||
|
||||
# activate batch end hook
|
||||
if self.__is_function_implemented('on_batch_end'):
|
||||
self.model.on_batch_end()
|
||||
|
||||
def __run_validation(self):
|
||||
# decide if can check epochs
|
||||
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
|
||||
if self.fast_dev_run:
|
||||
print('skipping to check performance bc of --fast_dev_run')
|
||||
elif not can_check_epoch:
|
||||
return
|
||||
|
||||
try:
|
||||
# hook
|
||||
if self.__is_function_implemented('on_pre_performance_check'):
|
||||
self.model.on_pre_performance_check()
|
||||
|
||||
# use full val set on end of epoch
|
||||
# use a small portion otherwise
|
||||
max_batches = None if not self.fast_dev_run else 1
|
||||
model_specific_tqdm_metrics_dic = self.validate(
|
||||
self.model,
|
||||
self.val_dataloader,
|
||||
max_batches
|
||||
)
|
||||
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
|
||||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_post_performance_check'):
|
||||
self.model.on_post_performance_check()
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(traceback.print_exc())
|
||||
|
||||
if self.enable_tqdm:
|
||||
# add model specific metrics
|
||||
tqdm_metrics = self.__tng_tqdm_dic
|
||||
self.prog_bar.set_postfix(**tqdm_metrics)
|
||||
|
||||
# model checkpointing
|
||||
print('save callback...')
|
||||
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
|
|
@ -0,0 +1,40 @@
|
|||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
"""
|
||||
Module to describe gradients
|
||||
"""
|
||||
|
||||
|
||||
class GradInformation(nn.Module):
|
||||
|
||||
def grad_norm(self, norm_type):
|
||||
results = {}
|
||||
total_norm = 0
|
||||
for i, p in enumerate(self.parameters()):
|
||||
if p.requires_grad:
|
||||
try:
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm ** norm_type
|
||||
norm = param_norm ** (1 / norm_type)
|
||||
|
||||
results['grad_{}_norm_{}'.format(norm_type, i)] = round(norm.data.cpu().numpy().flatten()[0], 3)
|
||||
except Exception as e:
|
||||
# this param had no grad
|
||||
pass
|
||||
|
||||
total_norm = total_norm ** (1. / norm_type)
|
||||
results['grad_{}_norm_total'.format(norm_type)] = round(total_norm.data.cpu().numpy().flatten()[0], 3)
|
||||
return results
|
||||
|
||||
|
||||
def describe_grads(self):
|
||||
for p in self.parameters():
|
||||
g = p.grad.data.numpy().flatten()
|
||||
print(np.max(g), np.min(g), np.mean(g))
|
||||
|
||||
|
||||
def describe_params(self):
|
||||
for p in self.parameters():
|
||||
g = p.data.numpy().flatten()
|
||||
print(np.max(g), np.min(g), np.mean(g))
|
|
@ -0,0 +1,20 @@
|
|||
import torch
|
||||
|
||||
class ModelHooks(torch.nn.Module):
|
||||
def on_batch_start(self):
|
||||
pass
|
||||
|
||||
def on_batch_end(self):
|
||||
pass
|
||||
|
||||
def on_epoch_start(self):
|
||||
pass
|
||||
|
||||
def on_epoch_end(self):
|
||||
pass
|
||||
|
||||
def on_pre_performance_check(self):
|
||||
pass
|
||||
|
||||
def on_post_performance_check(self):
|
||||
pass
|
|
@ -0,0 +1,180 @@
|
|||
import torch
|
||||
import gc
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
'''
|
||||
Generates a summary of a model's layers and dimensionality
|
||||
'''
|
||||
|
||||
|
||||
class ModelSummary(object):
|
||||
|
||||
def __init__(self, model):
|
||||
'''
|
||||
Generates summaries of model layers and dimensions.
|
||||
'''
|
||||
self.model = model
|
||||
self.in_sizes = []
|
||||
self.out_sizes = []
|
||||
|
||||
self.summarize()
|
||||
|
||||
def __str__(self):
|
||||
return self.summary.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
return self.summary.__str__()
|
||||
|
||||
def get_variable_sizes(self):
|
||||
'''Run sample input through each layer to get output sizes'''
|
||||
mods = list(self.model.modules())
|
||||
in_sizes = []
|
||||
out_sizes = []
|
||||
input_ = self.example_input_array
|
||||
for i in range(1, len(mods)):
|
||||
m = mods[i]
|
||||
if type(input_) is list or type(input_) is tuple:
|
||||
out = m(*input_)
|
||||
else:
|
||||
out = m(input_)
|
||||
|
||||
if type(input_) is tuple or type(input_) is list:
|
||||
in_size = []
|
||||
for x in input_:
|
||||
if type(x) is list:
|
||||
in_size.append(len(x))
|
||||
else:
|
||||
in_size.append(x.size())
|
||||
else:
|
||||
in_size = np.array(input_.size())
|
||||
|
||||
in_sizes.append(in_size)
|
||||
|
||||
if type(out) is tuple or type(out) is list:
|
||||
out_size = np.asarray([x.size() for x in out])
|
||||
else:
|
||||
out_size = np.array(out.size())
|
||||
|
||||
out_sizes.append(out_size)
|
||||
input_ = out
|
||||
|
||||
self.in_sizes = in_sizes
|
||||
self.out_sizes = out_sizes
|
||||
return
|
||||
|
||||
def get_layer_names(self):
|
||||
'''Collect Layer Names'''
|
||||
mods = list(self.model.named_modules())
|
||||
names = []
|
||||
layers = []
|
||||
for m in mods[1:]:
|
||||
names += [m[0]]
|
||||
layers += [str(m[1].__class__)]
|
||||
|
||||
layer_types = [x.split('.')[-1][:-2] for x in layers]
|
||||
|
||||
self.layer_names = names
|
||||
self.layer_types = layer_types
|
||||
return
|
||||
|
||||
def get_parameter_sizes(self):
|
||||
'''Get sizes of all parameters in `model`'''
|
||||
mods = list(self.model.modules())
|
||||
sizes = []
|
||||
|
||||
for i in range(1,len(mods)):
|
||||
m = mods[i]
|
||||
p = list(m.parameters())
|
||||
modsz = []
|
||||
for j in range(len(p)):
|
||||
modsz.append(np.array(p[j].size()))
|
||||
sizes.append(modsz)
|
||||
|
||||
self.param_sizes = sizes
|
||||
return
|
||||
|
||||
def get_parameter_nums(self):
|
||||
'''Get number of parameters in each layer'''
|
||||
param_nums = []
|
||||
for mod in self.param_sizes:
|
||||
all_params = 0
|
||||
for p in mod:
|
||||
all_params += np.prod(p)
|
||||
param_nums.append(all_params)
|
||||
self.param_nums = param_nums
|
||||
return
|
||||
|
||||
def make_summary(self):
|
||||
'''
|
||||
Makes a summary listing with:
|
||||
|
||||
Layer Name, Layer Type, Input Size, Output Size, Number of Parameters
|
||||
'''
|
||||
|
||||
df = pd.DataFrame( np.zeros( (len(self.layer_names), 3) ) )
|
||||
df.columns = ['Name', 'Type', 'Params']
|
||||
|
||||
df['Name'] = self.layer_names
|
||||
df['Type'] = self.layer_types
|
||||
df['Params'] = self.param_nums
|
||||
|
||||
self.summary = df
|
||||
return
|
||||
|
||||
def summarize(self):
|
||||
self.get_layer_names()
|
||||
self.get_parameter_sizes()
|
||||
self.get_parameter_nums()
|
||||
self.make_summary()
|
||||
|
||||
|
||||
def print_mem_stack():
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
|
||||
print(type(obj), obj.size())
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def count_mem_items():
|
||||
nb_params = 0
|
||||
nb_tensors = 0
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
|
||||
obj_type = str(type(obj))
|
||||
if 'parameter' in obj_type:
|
||||
nb_params += 1
|
||||
else:
|
||||
nb_tensors += 1
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return nb_params, nb_tensors
|
||||
|
||||
|
||||
def get_gpu_memory_map():
|
||||
"""Get the current gpu usage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
usage: dict
|
||||
Keys are device ids as integers.
|
||||
Values are memory usage as integers in MB.
|
||||
"""
|
||||
result = subprocess.check_output(
|
||||
[
|
||||
'nvidia-smi', '--query-gpu=memory.used',
|
||||
'--format=csv,nounits,noheader'
|
||||
], encoding='utf-8')
|
||||
# Convert lines into a dictionary
|
||||
gpu_memory = [int(x) for x in result.strip().split('\n')]
|
||||
gpu_memory_map = {}
|
||||
for k, v in zip(range(len(gpu_memory)), gpu_memory):
|
||||
k = f'gpu_{k}'
|
||||
gpu_memory_map[k] = v
|
||||
return gpu_memory_map
|
|
@ -0,0 +1,168 @@
|
|||
import torch
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
class ModelIO(object):
|
||||
|
||||
def load_model_specific(self, checkpoint):
|
||||
"""
|
||||
Do something with the checkpoint
|
||||
:param checkpoint:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_save_dict(self):
|
||||
"""
|
||||
Return specific things for the model
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TrainerIO(object):
|
||||
|
||||
# --------------------
|
||||
# MODEL SAVE CHECKPOINT
|
||||
# --------------------
|
||||
def save_checkpoint(self, filepath):
|
||||
checkpoint = self.dump_checkpoint()
|
||||
|
||||
# do the actual save
|
||||
torch.save(checkpoint, filepath)
|
||||
|
||||
def dump_checkpoint(self):
|
||||
checkpoint = {
|
||||
'epoch': self.current_epoch,
|
||||
'checkpoint_callback_best': self.checkpoint_callback.best,
|
||||
'early_stop_callback_wait': self.early_stop_callback.wait,
|
||||
'early_stop_callback_patience': self.early_stop_callback.patience,
|
||||
'global_step': self.global_step
|
||||
}
|
||||
|
||||
optimizer_states = []
|
||||
for i, optimizer in enumerate(self.optimizers):
|
||||
optimizer_states.append(optimizer.state_dict())
|
||||
|
||||
checkpoint['optimizer_states'] = optimizer_states
|
||||
|
||||
# request what to save from the model
|
||||
checkpoint_dict = self.model.get_save_dict()
|
||||
|
||||
# merge trainer and model saving items
|
||||
checkpoint.update(checkpoint_dict)
|
||||
return checkpoint
|
||||
|
||||
# --------------------
|
||||
# HPC IO
|
||||
# --------------------
|
||||
def enable_auto_hpc_walltime_manager(self):
|
||||
if self.cluster is None:
|
||||
return
|
||||
|
||||
# allow test tube to handle model check pointing automatically
|
||||
self.cluster.set_checkpoint_save_function(
|
||||
self.hpc_save,
|
||||
kwargs={
|
||||
'folderpath': self.checkpoint_callback.filepath,
|
||||
'experiment': self.experiment
|
||||
}
|
||||
)
|
||||
self.cluster.set_checkpoint_load_function(
|
||||
self.hpc_load,
|
||||
kwargs={
|
||||
'folderpath': self.checkpoint_callback.filepath,
|
||||
'on_gpu': self.on_gpu
|
||||
}
|
||||
)
|
||||
|
||||
def restore_training_state(self, checkpoint):
|
||||
"""
|
||||
Restore trainer state.
|
||||
Model will get its change to update
|
||||
:param checkpoint:
|
||||
:return:
|
||||
"""
|
||||
self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
|
||||
self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
|
||||
self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']
|
||||
self.global_step = checkpoint['global_step']
|
||||
|
||||
# restore the optimizers
|
||||
optimizer_states = checkpoint['optimizer_states']
|
||||
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
|
||||
optimizer.load_state_dict(opt_state)
|
||||
|
||||
# ----------------------------------
|
||||
# PRIVATE OPS
|
||||
# ----------------------------------
|
||||
def hpc_save(self, folderpath, experiment):
|
||||
# save exp to make sure we get all the metrics
|
||||
experiment.save()
|
||||
|
||||
ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
|
||||
|
||||
if not os.path.exists(folderpath):
|
||||
os.makedirs(folderpath, exist_ok=True)
|
||||
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
|
||||
|
||||
# request what to save from the model
|
||||
checkpoint_dict = self.dump_checkpoint()
|
||||
|
||||
# do the actual save
|
||||
torch.save(checkpoint_dict, filepath)
|
||||
|
||||
def hpc_load(self, folderpath, on_gpu):
|
||||
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
|
||||
|
||||
if on_gpu:
|
||||
checkpoint = torch.load(filepath)
|
||||
else:
|
||||
checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
|
||||
|
||||
# load training state
|
||||
self.restore_training_state(checkpoint)
|
||||
|
||||
# load model state
|
||||
self.model.load_model_specific(checkpoint)
|
||||
|
||||
def max_ckpt_in_folder(self, path):
|
||||
files = os.listdir(path)
|
||||
ckpt_vs = []
|
||||
for name in files:
|
||||
name = name.split('ckpt_')[-1]
|
||||
name = re.sub('[^0-9]', '', name)
|
||||
ckpt_vs.append(int(name))
|
||||
|
||||
return max(ckpt_vs)
|
||||
|
||||
|
||||
def load_hparams_from_tags_csv(tags_csv):
|
||||
from argparse import Namespace
|
||||
import pandas as pd
|
||||
|
||||
tags_df = pd.read_csv(tags_csv)
|
||||
dic = tags_df.to_dict(orient='records')
|
||||
|
||||
ns_dict = {row['key']: convert(row['value']) for row in dic}
|
||||
|
||||
ns = Namespace(**ns_dict)
|
||||
return ns
|
||||
|
||||
|
||||
def convert(val):
|
||||
constructors = [int, float, str]
|
||||
|
||||
if type(val) is str:
|
||||
if val.lower() == 'true':
|
||||
return True
|
||||
if val.lower() == 'false':
|
||||
return False
|
||||
|
||||
for c in constructors:
|
||||
try:
|
||||
return c(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return val
|
|
@ -0,0 +1,22 @@
|
|||
from torch import nn
|
||||
from torch import optim
|
||||
|
||||
|
||||
class OptimizerConfig(nn.Module):
|
||||
|
||||
def choose_optimizer(self, optimizer, params, optimizer_params, opt_name_key):
|
||||
if optimizer == 'adam':
|
||||
optimizer = optim.Adam(params, **optimizer_params)
|
||||
if optimizer == 'sparse_adam':
|
||||
optimizer = optim.SparseAdam(params, **optimizer_params)
|
||||
if optimizer == 'sgd':
|
||||
optimizer = optim.SGD(params, **optimizer_params)
|
||||
if optimizer == 'adadelta':
|
||||
optimizer = optim.Adadelta(params, **optimizer_params)
|
||||
|
||||
# transfer opt state if loaded
|
||||
if opt_name_key in self.loaded_optimizer_states_dict:
|
||||
state = self.loaded_optimizer_states_dict[opt_name_key]
|
||||
optimizer.load_state_dict(state)
|
||||
|
||||
return optimizer
|
|
@ -0,0 +1,167 @@
|
|||
import os
|
||||
import torch
|
||||
import math
|
||||
|
||||
from pytorch_lightning.root_module.memory import ModelSummary
|
||||
from pytorch_lightning.root_module.grads import GradInformation
|
||||
from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
|
||||
from pytorch_lightning.root_module.optimization import OptimizerConfig
|
||||
from pytorch_lightning.root_module.hooks import ModelHooks
|
||||
|
||||
|
||||
class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(RootModule, self).__init__()
|
||||
self.hparams = hparams
|
||||
self.on_gpu = hparams.on_gpu
|
||||
self.dtype = torch.FloatTensor
|
||||
self.exp_save_path = None
|
||||
self.current_epoch = 0
|
||||
self.global_step = 0
|
||||
self.loaded_optimizer_states_dict = {}
|
||||
self.fast_dev_run = hparams.fast_dev_run
|
||||
self.overfit = hparams.overfit
|
||||
self.gradient_clip = hparams.gradient_clip
|
||||
self.num = 2
|
||||
|
||||
# computed vars for the dataloaders
|
||||
self._tng_dataloader = None
|
||||
self._val_dataloader = None
|
||||
self._test_dataloader = None
|
||||
|
||||
if self.on_gpu:
|
||||
print('running on gpu...')
|
||||
self.dtype = torch.cuda.FloatTensor
|
||||
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Expand model in into whatever you need.
|
||||
Also need to return the target
|
||||
:param x:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_step(self, data_batch):
|
||||
"""
|
||||
return whatever outputs will need to be aggregated in validation_end
|
||||
:param data_batch:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_end(self, outputs):
|
||||
"""
|
||||
Outputs has the appended output after each validation step
|
||||
:param outputs:
|
||||
:return: dic_with_metrics for tqdm
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, data_batch):
|
||||
"""
|
||||
return loss, dict with metrics for tqdm
|
||||
:param data_batch:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Return array of optimizers
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_tng_log_metrics(self, logs):
|
||||
"""
|
||||
Chance to update metrics to be logged for training step.
|
||||
For example, add music, images, etc... to log
|
||||
:param logs:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def loss(self, *args, **kwargs):
|
||||
"""
|
||||
Expand model_out into your components
|
||||
:param model_out:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def summarize(self):
|
||||
model_summary = ModelSummary(self)
|
||||
print(model_summary)
|
||||
|
||||
def nb_batches(self, dataloader):
|
||||
a = math.ceil(float(len(dataloader.dataset) / self.batch_size))
|
||||
return int(a)
|
||||
|
||||
def freeze(self):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def unfreeze(self):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
@property
|
||||
def tng_dataloader(self):
|
||||
"""
|
||||
Implement a function to load an h5py of this data
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def test_dataloader(self):
|
||||
"""
|
||||
Implement a function to load an h5py of this data
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def val_dataloader(self):
|
||||
"""
|
||||
Implement a function to load an h5py of this data
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_process_position(gpus):
|
||||
try:
|
||||
current_gpu = os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
gpu_ids = gpus.split(';')
|
||||
process_position = gpu_ids.index(current_gpu)
|
||||
return process_position, current_gpu
|
||||
except Exception as e:
|
||||
return 0, 0
|
||||
|
||||
@classmethod
|
||||
def load_from_metrics(cls, weights_path, tags_csv, on_gpu):
|
||||
"""
|
||||
Primary way of loading model from csv weights path
|
||||
:param weights_path:
|
||||
:param tags_csv:
|
||||
:param on_gpu:
|
||||
:return:
|
||||
"""
|
||||
hparams = load_hparams_from_tags_csv(tags_csv)
|
||||
hparams.__setattr__('on_gpu', on_gpu)
|
||||
|
||||
if on_gpu:
|
||||
checkpoint = torch.load(weights_path)
|
||||
else:
|
||||
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
|
||||
|
||||
model = cls(hparams)
|
||||
|
||||
# allow model to load
|
||||
model.load_model_specific(checkpoint)
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
return model
|
|
@ -0,0 +1,215 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster
|
||||
from pytorch_lightning.models.trainer import Trainer
|
||||
from pytorch_lightning.utils.arg_parse import add_default_args
|
||||
from time import sleep
|
||||
|
||||
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
SEED = 2334
|
||||
torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
|
||||
# ---------------------
|
||||
# DEFINE MODEL HERE
|
||||
# ---------------------
|
||||
from pytorch_lightning.models.sample_model_template.model_template import ExampleModel1
|
||||
# ---------------------
|
||||
|
||||
AVAILABLE_MODELS = {
|
||||
'model_1': ExampleModel1
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
Allows training by using command line arguments
|
||||
|
||||
Run by:
|
||||
# TYPE YOUR RUN COMMAND HERE
|
||||
"""
|
||||
|
||||
|
||||
def main_local(hparams):
|
||||
main(hparams, None, None)
|
||||
|
||||
|
||||
def main(hparams, cluster, results_dict):
|
||||
"""
|
||||
Main training routine specific for this project
|
||||
:param hparams:
|
||||
:return:
|
||||
"""
|
||||
on_gpu = torch.cuda.is_available()
|
||||
if hparams.disable_cuda:
|
||||
on_gpu = False
|
||||
|
||||
device = 'cuda' if on_gpu else 'cpu'
|
||||
hparams.__setattr__('device', device)
|
||||
hparams.__setattr__('on_gpu', on_gpu)
|
||||
hparams.__setattr__('nb_gpus', torch.cuda.device_count())
|
||||
hparams.__setattr__('inference_mode', hparams.model_load_weights_path is not None)
|
||||
|
||||
# delay each training start to not overwrite logs
|
||||
process_position, current_gpu = TRAINING_MODEL.get_process_position(hparams.gpus)
|
||||
sleep(process_position + 1)
|
||||
|
||||
# init experiment
|
||||
exp = Experiment(
|
||||
name=hparams.tt_name,
|
||||
debug=hparams.debug,
|
||||
save_dir=hparams.tt_save_path,
|
||||
version=hparams.hpc_exp_number,
|
||||
autosave=False,
|
||||
description=hparams.tt_description
|
||||
)
|
||||
|
||||
exp.argparse(hparams)
|
||||
exp.save()
|
||||
|
||||
# build model
|
||||
print('loading model...')
|
||||
model = TRAINING_MODEL(hparams)
|
||||
print('model built')
|
||||
|
||||
# callbacks
|
||||
early_stop = EarlyStopping(
|
||||
monitor=hparams.early_stop_metric,
|
||||
patience=hparams.early_stop_patience,
|
||||
verbose=True,
|
||||
mode=hparams.early_stop_mode
|
||||
)
|
||||
|
||||
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
|
||||
checkpoint = ModelCheckpoint(
|
||||
filepath=model_save_path,
|
||||
save_function=None,
|
||||
save_best_only=True,
|
||||
verbose=True,
|
||||
monitor=hparams.model_save_monitor_value,
|
||||
mode=hparams.model_save_monitor_mode
|
||||
)
|
||||
|
||||
# configure trainer
|
||||
trainer = Trainer(
|
||||
experiment=exp,
|
||||
on_gpu=on_gpu,
|
||||
cluster=cluster,
|
||||
enable_tqdm=hparams.enable_tqdm,
|
||||
overfit_pct=hparams.overfit,
|
||||
track_grad_norm=hparams.track_grad_norm,
|
||||
fast_dev_run=hparams.fast_dev_run,
|
||||
check_val_every_n_epoch=hparams.check_val_every_n_epoch,
|
||||
accumulate_grad_batches=hparams.accumulate_grad_batches,
|
||||
process_position=process_position,
|
||||
current_gpu_name=current_gpu,
|
||||
checkpoint_callback=checkpoint,
|
||||
early_stop_callback=early_stop,
|
||||
enable_early_stop=hparams.enable_early_stop,
|
||||
max_nb_epochs=hparams.max_nb_epochs,
|
||||
min_nb_epochs=hparams.min_nb_epochs,
|
||||
train_percent_check=hparams.train_percent_check,
|
||||
val_percent_check=hparams.val_percent_check,
|
||||
test_percent_check=hparams.test_percent_check,
|
||||
val_check_interval=hparams.val_check_interval,
|
||||
log_save_interval=hparams.log_save_interval,
|
||||
add_log_row_interval=hparams.add_log_row_interval,
|
||||
lr_scheduler_milestones=hparams.lr_scheduler_milestones
|
||||
)
|
||||
|
||||
# train model
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def get_default_parser(strategy, root_dir):
|
||||
|
||||
possible_model_names = list(AVAILABLE_MODELS.keys())
|
||||
parser = HyperOptArgumentParser(strategy=strategy, add_help=False)
|
||||
add_default_args(parser, root_dir, possible_model_names, SEED)
|
||||
return parser
|
||||
|
||||
|
||||
def get_model_name(args):
|
||||
for i, arg in enumerate(args):
|
||||
if 'model_name' in arg:
|
||||
return args[i+1]
|
||||
|
||||
|
||||
def optimize_on_cluster(hyperparams):
|
||||
# enable cluster training
|
||||
cluster = SlurmCluster(
|
||||
hyperparam_optimizer=hyperparams,
|
||||
log_path=hyperparams.tt_save_path,
|
||||
test_tube_exp_name=hyperparams.tt_name
|
||||
)
|
||||
|
||||
# email for cluster coms
|
||||
cluster.notify_job_status(email='add_email_here', on_done=True, on_fail=True)
|
||||
|
||||
# configure cluster
|
||||
cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus
|
||||
cluster.job_time = '48:00:00'
|
||||
cluster.gpu_type = '1080ti'
|
||||
cluster.memory_mb_per_node = 48000
|
||||
|
||||
# any modules for code to run in env
|
||||
cluster.add_command('source activate pytorch_lightning')
|
||||
|
||||
# name of exp
|
||||
job_display_name = hyperparams.tt_name.split('_')[0]
|
||||
job_display_name = job_display_name[0:3]
|
||||
|
||||
# run hopt
|
||||
print('submitting jobs...')
|
||||
cluster.optimize_parallel_cluster_gpu(
|
||||
main,
|
||||
nb_trials=hyperparams.nb_hopt_trials,
|
||||
job_name=job_display_name
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model_name = get_model_name(sys.argv)
|
||||
|
||||
# use default args
|
||||
root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0]
|
||||
parent_parser = get_default_parser(strategy='random_search', root_dir=root_dir)
|
||||
|
||||
# allow model to overwrite or extend args
|
||||
TRAINING_MODEL = AVAILABLE_MODELS[model_name]
|
||||
parser = TRAINING_MODEL.add_model_specific_args(parent_parser)
|
||||
parser.json_config('-c', '--config', default=root_dir + '/run_configs/local.json')
|
||||
hyperparams = parser.parse_args()
|
||||
|
||||
# format GPU layout
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
gpu_ids = hyperparams.gpus.split(';')
|
||||
|
||||
# RUN TRAINING
|
||||
if hyperparams.on_cluster:
|
||||
print('RUNNING ON SLURM CLUSTER')
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
|
||||
optimize_on_cluster(hyperparams)
|
||||
|
||||
elif hyperparams.single_run_gpu:
|
||||
print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}')
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0]
|
||||
main(hyperparams, None, None)
|
||||
|
||||
elif hyperparams.local or hyperparams.single_run:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
||||
print('RUNNING LOCALLY')
|
||||
main(hyperparams, None, None)
|
||||
|
||||
else:
|
||||
print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
|
||||
hyperparams.optimize_parallel_gpu(
|
||||
main_local,
|
||||
gpu_ids=gpu_ids,
|
||||
nb_trials=hyperparams.nb_hopt_trials,
|
||||
nb_workers=len(gpu_ids)
|
||||
)
|
|
@ -0,0 +1,67 @@
|
|||
def add_default_args(parser, root_dir, possible_model_names, rand_seed):
|
||||
|
||||
# tng, test, val check intervals
|
||||
parser.add_argument('--eval_test_set', dest='eval_test_set', action='store_true', help='true = run test set also')
|
||||
parser.add_argument('--check_val_every_n_epoch', default=1, type=int, help='check val every n epochs')
|
||||
parser.opt_list('--accumulate_grad_batches', default=1, type=int, tunable=False,
|
||||
help='accumulates gradients k times before applying update. Simulates huge batch size')
|
||||
parser.add_argument('--max_nb_epochs', default=200, type=int, help='cap epochs')
|
||||
parser.add_argument('--min_nb_epochs', default=2, type=int, help='min epochs')
|
||||
parser.add_argument('--train_percent_check', default=1.0, type=float, help='how much of tng set to check')
|
||||
parser.add_argument('--val_percent_check', default=1.0, type=float, help='how much of val set to check')
|
||||
parser.add_argument('--test_percent_check', default=1.0, type=float, help='how much of test set to check')
|
||||
|
||||
parser.add_argument('--val_check_interval', default=0.95, type=float, help='how much within 1 epoch to check val')
|
||||
parser.add_argument('--log_save_interval', default=100, type=int, help='how many batches between log saves')
|
||||
parser.add_argument('--add_log_row_interval', default=100, type=int, help='add log every k batches')
|
||||
|
||||
# early stopping
|
||||
parser.add_argument('--disable_early_stop', dest='enable_early_stop', action='store_false')
|
||||
parser.add_argument('--early_stop_metric', default='val_acc', type=str)
|
||||
parser.add_argument('--early_stop_mode', default='min', type=str)
|
||||
parser.add_argument('--early_stop_patience', default=3, type=int, help='number of epochs until stop')
|
||||
|
||||
# gradient handling
|
||||
parser.add_argument('--gradient_clip', default=-1, type=int)
|
||||
parser.add_argument('--track_grad_norm', default=-1, type=int, help='if > 0, will track this grad norm')
|
||||
|
||||
# model saving
|
||||
parser.add_argument('--model_save_path', default=root_dir + '/model_weights')
|
||||
parser.add_argument('--model_save_monitor_value', default='val_acc')
|
||||
parser.add_argument('--model_save_monitor_mode', default='max')
|
||||
|
||||
# model paths
|
||||
parser.add_argument('--model_load_weights_path', default=None, type=str)
|
||||
parser.add_argument('--model_name', default='', help=','.join(possible_model_names))
|
||||
|
||||
# test_tube settings
|
||||
parser.add_argument('-en', '--tt_name', default='r_lib_')
|
||||
parser.add_argument('-td', '--tt_description', default='test research lib')
|
||||
parser.add_argument('--tt_save_path', default=root_dir + '/test_tube_logs', help='logging dir')
|
||||
parser.add_argument('--enable_single_run', dest='single_run', action='store_true')
|
||||
parser.add_argument('--nb_hopt_trials', default=1, type=int)
|
||||
parser.add_argument('--log_stdout', dest='log_stdout', action='store_true')
|
||||
|
||||
# GPU
|
||||
parser.add_argument('--per_experiment_nb_gpus', default=1, type=int)
|
||||
parser.add_argument('--gpus', default='0', type=str)
|
||||
parser.add_argument('--single_run_gpu', dest='single_run_gpu', action='store_true')
|
||||
parser.add_argument('--disable_cuda', dest='disable_cuda', action='store_true')
|
||||
|
||||
# run on hpc
|
||||
parser.add_argument('--on_cluster', dest='on_cluster', action='store_true')
|
||||
|
||||
# FAST training
|
||||
# use these settings to make sure network has no bugs without running a full dataset
|
||||
parser.add_argument('--fast_dev_run', dest='fast_dev_run', default=False, action='store_true', help='runs validation after 1 tng step')
|
||||
parser.add_argument('--enable_tqdm', dest='enable_tqdm', default=False, action='store_true', help='false removes the prog bar')
|
||||
parser.add_argument('--overfit', default=-1, type=float, help='% of dataset to use with this option. float, or -1 for none')
|
||||
|
||||
# debug args
|
||||
parser.add_argument('--random_seed', default=rand_seed, type=int)
|
||||
parser.add_argument('--live', dest='live', action='store_true', help='runs on gpu without cluster')
|
||||
parser.add_argument('--enable_debug', dest='debug', action='store_true', help='enables/disables test tube')
|
||||
parser.add_argument('--enable_local', dest='local', action='store_true', help='enables local tng')
|
||||
|
||||
# optimizer
|
||||
parser.add_argument('--lr_scheduler_milestones', default=None, type=str)
|
|
@ -0,0 +1,107 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class PretrainedEmbedding(torch.nn.Embedding):
|
||||
|
||||
def __init__(self, embedding_path, embedding_dim, task_vocab, freeze=True, *args, **kwargs):
|
||||
"""
|
||||
Loads a prebuilt pytorch embedding from any embedding formated file.
|
||||
Padding=0 by default.
|
||||
|
||||
>>> emb = PretrainedEmbedding(embedding_path='glove.840B.300d.txt',embedding_dim=300, task_vocab={'hello': 1, 'world': 2})
|
||||
>>> data = torch.Tensor([[0, 1], [0, 2]]).long()
|
||||
>>> embedded = emb(data)
|
||||
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
|
||||
[ 0.2523, 0.1018, -0.6748, ..., 0.1787, -0.5192, 0.3359]],
|
||||
|
||||
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
|
||||
[-0.0067, 0.2224, 0.2771, ..., 0.0594, 0.0014, 0.0987]]])
|
||||
|
||||
|
||||
:param embedding_path:
|
||||
:param emb_dim:
|
||||
:param task_vocab:
|
||||
:param freeze:
|
||||
:return:
|
||||
"""
|
||||
# count the vocab
|
||||
self.vocab_size = max(task_vocab.values()) + 1
|
||||
super(PretrainedEmbedding, self).__init__(self.vocab_size, embedding_dim, padding_idx=0, *args, **kwargs)
|
||||
|
||||
# load pretrained embeddings
|
||||
new_emb = self.__load_task_specific_embeddings(deepcopy(task_vocab), embedding_path, embedding_dim, freeze)
|
||||
|
||||
# transfer weights
|
||||
self.weight = new_emb.weight
|
||||
|
||||
# apply freeze
|
||||
self.weight.requires_grad = not freeze
|
||||
|
||||
def __load_task_specific_embeddings(self, vocab_words, embedding_path, emb_dim, freeze):
|
||||
"""
|
||||
Iterates embedding file to only pull out task specific embeddings
|
||||
:param vocab_words:
|
||||
:param embedding_path:
|
||||
:param emb_dim:
|
||||
:param freeze:
|
||||
:return:
|
||||
"""
|
||||
|
||||
# holds final embeddings for relevant words
|
||||
embeddings = np.zeros(shape=(self.vocab_size, emb_dim))
|
||||
|
||||
# load embedding line by line and extract relevant embeddings
|
||||
with open(embedding_path, encoding='utf-8') as f:
|
||||
for line in f:
|
||||
tokens = line.split(' ')
|
||||
word = tokens[0]
|
||||
embedding = tokens[1:]
|
||||
embedding[-1] = embedding[-1][:-1] # remove last new line
|
||||
|
||||
if word in vocab_words:
|
||||
vocab_word_i = vocab_words[word]
|
||||
|
||||
# skip words that try to overwrite pad idx
|
||||
if vocab_word_i == 0:
|
||||
del vocab_words[word]
|
||||
continue
|
||||
|
||||
emb_vals = np.asarray([float(x) for x in embedding])
|
||||
embeddings[vocab_word_i] = emb_vals
|
||||
|
||||
# remove vocab word to early terminate
|
||||
del vocab_words[word]
|
||||
|
||||
# early break
|
||||
if len(vocab_words) == 0:
|
||||
break
|
||||
|
||||
# add random vectors for the non-pretrained words
|
||||
# these are vocab words NOT found in the pretrained embeddings
|
||||
for w, i in vocab_words.items():
|
||||
# skip words that try to overwrite pad idx
|
||||
if i == 0:
|
||||
continue
|
||||
|
||||
embedding = np.random.normal(size=emb_dim)
|
||||
embeddings[i] = embedding
|
||||
|
||||
# turn into pt embedding
|
||||
embeddings = torch.FloatTensor(embeddings)
|
||||
embeddings = torch.nn.Embedding.from_pretrained(embeddings, freeze=freeze)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
emb = PretrainedEmbedding(
|
||||
embedding_path='/Users/waf/Developer/NGV/research-fermat/fermat/.vector_cache/glove.840B.300d.txt',
|
||||
embedding_dim=300,
|
||||
task_vocab={'hello': 1, 'world': 2}
|
||||
)
|
||||
|
||||
data = torch.Tensor([[0, 1], [0, 2]]).long()
|
||||
embedded = emb(data)
|
||||
print(embedded)
|
|
@ -0,0 +1,28 @@
|
|||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
np.seterr(divide='ignore', invalid='ignore')
|
||||
|
||||
|
||||
def plot_confusion_matrix(cm,
|
||||
save_path,
|
||||
normalize=False,
|
||||
title='Confusion matrix',
|
||||
ylabel='y',
|
||||
xlabel='x'):
|
||||
"""
|
||||
This function prints and plots the confusion matrix.
|
||||
Normalization can be applied by setting `normalize=True`.
|
||||
"""
|
||||
if normalize:
|
||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
print("Normalized confusion matrix")
|
||||
else:
|
||||
print('Confusion matrix, without normalization')
|
||||
|
||||
fig = plt.figure()
|
||||
plt.matshow(cm)
|
||||
plt.title(title)
|
||||
plt.colorbar()
|
||||
plt.ylabel(ylabel)
|
||||
plt.xlabel(xlabel)
|
||||
plt.savefig(save_path)
|
|
@ -0,0 +1,261 @@
|
|||
import numpy as np
|
||||
import os, shutil
|
||||
|
||||
|
||||
class Callback(object):
|
||||
"""Abstract base class used to build new callbacks.
|
||||
# Properties
|
||||
params: dict. Training parameters
|
||||
(eg. verbosity, batch size, number of epochs...).
|
||||
model: instance of `keras.models.Model`.
|
||||
Reference of the model being trained.
|
||||
The `logs` dictionary that callback methods
|
||||
take as argument will contain keys for quantities relevant to
|
||||
the current batch or epoch.
|
||||
Currently, the `.fit()` method of the `Sequential` model class
|
||||
will include the following quantities in the `logs` that
|
||||
it passes to its callbacks:
|
||||
on_epoch_end: logs include `acc` and `loss`, and
|
||||
optionally include `val_loss`
|
||||
(if validation is enabled in `fit`), and `val_acc`
|
||||
(if validation and accuracy monitoring are enabled).
|
||||
on_batch_begin: logs include `size`,
|
||||
the number of samples in the current batch.
|
||||
on_batch_end: logs include `loss`, and optionally `acc`
|
||||
(if accuracy monitoring is enabled).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.validation_data = None
|
||||
self.model = None
|
||||
|
||||
def set_params(self, params):
|
||||
self.params = params
|
||||
|
||||
def set_model(self, model):
|
||||
self.model = model
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
pass
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
pass
|
||||
|
||||
def on_batch_begin(self, batch, logs=None):
|
||||
pass
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
pass
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
pass
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
pass
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
"""Stop training when a monitored quantity has stopped improving.
|
||||
# Arguments
|
||||
monitor: quantity to be monitored.
|
||||
min_delta: minimum change in the monitored quantity
|
||||
to qualify as an improvement, i.e. an absolute
|
||||
change of less than min_delta, will count as no
|
||||
improvement.
|
||||
patience: number of epochs with no improvement
|
||||
after which training will be stopped.
|
||||
verbose: verbosity mode.
|
||||
mode: one of {auto, min, max}. In `min` mode,
|
||||
training will stop when the quantity
|
||||
monitored has stopped decreasing; in `max`
|
||||
mode it will stop when the quantity
|
||||
monitored has stopped increasing; in `auto`
|
||||
mode, the direction is automatically inferred
|
||||
from the name of the monitored quantity.
|
||||
"""
|
||||
|
||||
def __init__(self, monitor='val_loss',
|
||||
min_delta=0.0, patience=0, verbose=0, mode='auto'):
|
||||
super(EarlyStopping, self).__init__()
|
||||
|
||||
self.monitor = monitor
|
||||
self.patience = patience
|
||||
self.verbose = verbose
|
||||
self.min_delta = min_delta
|
||||
self.wait = 0
|
||||
self.stopped_epoch = 0
|
||||
|
||||
if mode not in ['auto', 'min', 'max']:
|
||||
print('EarlyStopping mode %s is unknown, fallback to auto mode.' % mode)
|
||||
mode = 'auto'
|
||||
|
||||
if mode == 'min':
|
||||
self.monitor_op = np.less
|
||||
elif mode == 'max':
|
||||
self.monitor_op = np.greater
|
||||
else:
|
||||
if 'acc' in self.monitor:
|
||||
self.monitor_op = np.greater
|
||||
else:
|
||||
self.monitor_op = np.less
|
||||
|
||||
if self.monitor_op == np.greater:
|
||||
self.min_delta *= 1
|
||||
else:
|
||||
self.min_delta *= -1
|
||||
|
||||
self.on_train_begin()
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
# Allow instances to be re-used
|
||||
self.wait = 0
|
||||
self.stopped_epoch = 0
|
||||
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
current = logs.get(self.monitor)
|
||||
stop_training = False
|
||||
if current is None:
|
||||
print('Early stopping conditioned on metric `%s` ''which is not available. Available metrics are: %s' %
|
||||
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
|
||||
)
|
||||
exit(-1)
|
||||
|
||||
if self.monitor_op(current - self.min_delta, self.best):
|
||||
self.best = current
|
||||
self.wait = 0
|
||||
else:
|
||||
self.wait += 1
|
||||
if self.wait >= self.patience:
|
||||
self.stopped_epoch = epoch
|
||||
stop_training = True
|
||||
self.on_train_end()
|
||||
|
||||
return stop_training
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
if self.stopped_epoch > 0 and self.verbose > 0:
|
||||
print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
"""Save the model after every epoch.
|
||||
`filepath` can contain named formatting options,
|
||||
which will be filled the value of `epoch` and
|
||||
keys in `logs` (passed in `on_epoch_end`).
|
||||
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
|
||||
then the model checkpoints will be saved with the epoch number and
|
||||
the validation loss in the filename.
|
||||
# Arguments
|
||||
filepath: string, path to save the model file.
|
||||
monitor: quantity to monitor.
|
||||
verbose: verbosity mode, 0 or 1.
|
||||
save_best_only: if `save_best_only=True`,
|
||||
the latest best model according to
|
||||
the quantity monitored will not be overwritten.
|
||||
mode: one of {auto, min, max}.
|
||||
If `save_best_only=True`, the decision
|
||||
to overwrite the current save file is made
|
||||
based on either the maximization or the
|
||||
minimization of the monitored quantity. For `val_acc`,
|
||||
this should be `max`, for `val_loss` this should
|
||||
be `min`, etc. In `auto` mode, the direction is
|
||||
automatically inferred from the name of the monitored quantity.
|
||||
save_weights_only: if True, then only the model's weights will be
|
||||
saved (`model.save_weights(filepath)`), else the full model
|
||||
is saved (`model.save(filepath)`).
|
||||
period: Interval (number of epochs) between checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, filepath, save_function, monitor='val_loss', verbose=0,
|
||||
save_best_only=False, save_weights_only=False,
|
||||
mode='auto', period=1, prefix=''):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self.monitor = monitor
|
||||
self.save_function = save_function
|
||||
self.verbose = verbose
|
||||
self.filepath = filepath
|
||||
self.save_best_only = save_best_only
|
||||
self.save_weights_only = save_weights_only
|
||||
self.period = period
|
||||
self.epochs_since_last_save = 0
|
||||
self.prefix = prefix
|
||||
|
||||
if mode not in ['auto', 'min', 'max']:
|
||||
print('ModelCheckpoint mode %s is unknown, '
|
||||
'fallback to auto mode.' % (mode),
|
||||
RuntimeWarning)
|
||||
mode = 'auto'
|
||||
|
||||
if mode == 'min':
|
||||
self.monitor_op = np.less
|
||||
self.best = np.Inf
|
||||
elif mode == 'max':
|
||||
self.monitor_op = np.greater
|
||||
self.best = -np.Inf
|
||||
else:
|
||||
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
|
||||
self.monitor_op = np.greater
|
||||
self.best = -np.Inf
|
||||
else:
|
||||
self.monitor_op = np.less
|
||||
self.best = np.Inf
|
||||
|
||||
def save_model(self, filepath, overwrite):
|
||||
dirpath = '/'.join(filepath.split('/')[:-1])
|
||||
|
||||
# make paths
|
||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||
|
||||
if overwrite:
|
||||
for filename in os.listdir(dirpath):
|
||||
if self.prefix in filename:
|
||||
path_to_delete = os.path.join(dirpath, filename)
|
||||
try:
|
||||
shutil.rmtree(path_to_delete)
|
||||
except OSError:
|
||||
os.remove(path_to_delete)
|
||||
|
||||
# delegate the saving to the model
|
||||
self.save_function(filepath)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
logs = logs or {}
|
||||
self.epochs_since_last_save += 1
|
||||
if self.epochs_since_last_save >= self.period:
|
||||
self.epochs_since_last_save = 0
|
||||
filepath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, epoch + 1)
|
||||
if self.save_best_only:
|
||||
current = logs.get(self.monitor)
|
||||
if current is None:
|
||||
print('Can save best model only with %s available, '
|
||||
'skipping.' % (self.monitor), RuntimeWarning)
|
||||
else:
|
||||
if self.monitor_op(current, self.best):
|
||||
if self.verbose > 0:
|
||||
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
|
||||
' saving model to %s'
|
||||
% (epoch + 1, self.monitor, self.best,
|
||||
current, filepath))
|
||||
self.best = current
|
||||
self.save_model(filepath, overwrite=True)
|
||||
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
print('\nEpoch %05d: %s did not improve' %
|
||||
(epoch + 1, self.monitor))
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
|
||||
self.save_model(filepath, overwrite=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
|
||||
losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
|
||||
for i, loss in enumerate(losses):
|
||||
should_stop = c.on_epoch_end(i, logs={'val_loss': loss})
|
||||
print(loss)
|
||||
if should_stop:
|
||||
break
|
||||
|
Loading…
Reference in New Issue