183 lines
5.2 KiB
Python
183 lines
5.2 KiB
Python
|
import zipfile
|
||
|
|
||
|
import pytorch_lightning as pl
|
||
|
import torch
|
||
|
import torch.optim as optim
|
||
|
from torch import FloatTensor, LongTensor
|
||
|
|
||
|
from bttr.datamodule import Batch, vocab
|
||
|
from bttr.model.bttr import BTTR
|
||
|
from bttr.utils import ExpRateRecorder, Hypothesis, ce_loss, to_bi_tgt_out
|
||
|
|
||
|
|
||
|
class LitBTTR(pl.LightningModule):
    """Lightning wrapper around the bidirectional BTTR model.

    Bundles the encoder/decoder network, bidirectional cross-entropy
    training, beam-search decoding for validation/test, expression-rate
    tracking, and the Adadelta + ReduceLROnPlateau optimization setup.
    """

    def __init__(
        self,
        d_model: int,
        # encoder
        growth_rate: int,
        num_layers: int,
        # decoder
        nhead: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        # beam search
        beam_size: int,
        max_len: int,
        alpha: float,
        # training
        learning_rate: float,
        patience: int,
    ):
        super().__init__()
        # Persists every constructor argument into self.hparams; the
        # beam-search and optimizer settings are read back from there.
        self.save_hyperparameters()

        self.bttr = BTTR(
            d_model=d_model,
            growth_rate=growth_rate,
            num_layers=num_layers,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        # Accumulates expression-level exact-match rate across an epoch.
        self.exprate_recorder = ExpRateRecorder()

    def forward(
        self, img: FloatTensor, img_mask: LongTensor, tgt: LongTensor
    ) -> FloatTensor:
        """run img and bi-tgt

        Parameters
        ----------
        img : FloatTensor
            [b, 1, h, w]
        img_mask: LongTensor
            [b, h, w]
        tgt : LongTensor
            [2b, l]

        Returns
        -------
        FloatTensor
            [2b, l, vocab_size]
        """
        return self.bttr(img, img_mask, tgt)

    def beam_search(
        self,
        img: FloatTensor,
        beam_size: int = 10,
        max_len: int = 200,
        alpha: float = 1.0,
    ) -> str:
        """for inference, one image at a time

        Parameters
        ----------
        img : FloatTensor
            [1, h, w]
        beam_size : int, optional
            by default 10
        max_len : int, optional
            by default 200
        alpha : float, optional
            length-penalty exponent for hypothesis scoring, by default 1.0

        Returns
        -------
        str
            LaTex string
        """
        assert img.dim() == 3
        # Single un-batched image: an all-zero mask means no padded pixels.
        img_mask = torch.zeros_like(img, dtype=torch.long)  # squeeze channel
        hyps = self.bttr.beam_search(img.unsqueeze(0), img_mask, beam_size, max_len)
        # Pick the hypothesis with the best length-normalized score.
        best_hyp = max(hyps, key=lambda h: h.score / (len(h) ** alpha))
        return vocab.indices2label(best_hyp.seq)

    def training_step(self, batch: Batch, _):
        # Bidirectional targets: tgt feeds the decoder, out is the shifted label.
        tgt, out = to_bi_tgt_out(batch.indices, self.device)
        out_hat = self(batch.imgs, batch.mask, tgt)

        loss = ce_loss(out_hat, out)
        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Batch, _):
        tgt, out = to_bi_tgt_out(batch.indices, self.device)
        out_hat = self(batch.imgs, batch.mask, tgt)

        loss = ce_loss(out_hat, out)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        # Decode with the hyperparameter beam settings; validation batches
        # are assumed to contain a single sample (batch.indices[0] is used).
        hyps = self.bttr.beam_search(
            batch.imgs, batch.mask, self.hparams.beam_size, self.hparams.max_len
        )
        best_hyp = max(hyps, key=lambda h: h.score / (len(h) ** self.hparams.alpha))

        # NOTE(review): removed leftover debug prints of prediction /
        # ground truth — they ran once per batch (and per rank under DDP).
        self.exprate_recorder(best_hyp.seq, batch.indices[0])
        self.log(
            "val_ExpRate",
            self.exprate_recorder,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )

    def test_step(self, batch: Batch, _):
        hyps = self.bttr.beam_search(
            batch.imgs, batch.mask, self.hparams.beam_size, self.hparams.max_len
        )
        best_hyp = max(hyps, key=lambda h: h.score / (len(h) ** self.hparams.alpha))
        self.exprate_recorder(best_hyp.seq, batch.indices[0])

        # Returned pairs are collected by Lightning for test_epoch_end.
        return batch.img_bases[0], vocab.indices2label(best_hyp.seq)

    def test_epoch_end(self, test_outputs) -> None:
        exprate = self.exprate_recorder.compute()
        print(f"ExpRate: {exprate}")

        print(f"length of total file: {len(test_outputs)}")
        # One "%name\n$latex$" text file per image, bundled into result.zip
        # (the CROHME submission format).
        with zipfile.ZipFile("result.zip", "w") as zip_f:
            for img_base, pred in test_outputs:
                content = f"%{img_base}\n${pred}$".encode()
                with zip_f.open(f"{img_base}.txt", "w") as f:
                    f.write(content)

    def configure_optimizers(self):
        optimizer = optim.Adadelta(
            self.parameters(),
            lr=self.hparams.learning_rate,
            eps=1e-6,
            weight_decay=1e-4,
        )

        reduce_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.1,
            # Patience is expressed in epochs; convert to number of
            # validation checks, since the scheduler steps per check.
            patience=self.hparams.patience // self.trainer.check_val_every_n_epoch,
        )
        scheduler = {
            "scheduler": reduce_scheduler,
            "monitor": "val_ExpRate",
            "interval": "epoch",
            "frequency": self.trainer.check_val_every_n_epoch,
            "strict": True,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
|