Kindai-OCR/transformer/lit_bttr.py

import zipfile

import pytorch_lightning as pl
import torch
import torch.optim as optim
from torch import FloatTensor, LongTensor

from bttr.datamodule import Batch, vocab
from bttr.model.bttr import BTTR
from bttr.utils import ExpRateRecorder, Hypothesis, ce_loss, to_bi_tgt_out


class LitBTTR(pl.LightningModule):
    def __init__(
        self,
        d_model: int,
        # encoder
        growth_rate: int,
        num_layers: int,
        # decoder
        nhead: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        # beam search
        beam_size: int,
        max_len: int,
        alpha: float,
        # training
        learning_rate: float,
        patience: int,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.bttr = BTTR(
            d_model=d_model,
            growth_rate=growth_rate,
            num_layers=num_layers,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        self.exprate_recorder = ExpRateRecorder()

    def forward(
        self, img: FloatTensor, img_mask: LongTensor, tgt: LongTensor
    ) -> FloatTensor:
        """run img and bi-tgt

        Parameters
        ----------
        img : FloatTensor
            [b, 1, h, w]
        img_mask : LongTensor
            [b, h, w]
        tgt : LongTensor
            [2b, l]

        Returns
        -------
        FloatTensor
            [2b, l, vocab_size]
        """
        return self.bttr(img, img_mask, tgt)
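
    # Shape sketch (follows the docstring above, nothing new assumed): with
    # batch size b, to_bi_tgt_out pairs every image with a left-to-right and
    # a right-to-left target, so tgt is [2b, l] and the returned logits are
    # [2b, l, vocab_size].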

    def beam_search(
        self,
        img: FloatTensor,
        beam_size: int = 10,
        max_len: int = 200,
        alpha: float = 1.0,
    ) -> str:
        """for inference, one image at a time

        Parameters
        ----------
        img : FloatTensor
            [1, h, w]
        beam_size : int, optional
            beam width, by default 10
        max_len : int, optional
            maximum decoded sequence length, by default 200
        alpha : float, optional
            length-penalty exponent for hypothesis scoring, by default 1.0

        Returns
        -------
        str
            LaTeX string
        """
        assert img.dim() == 3
        # All-zero mask of shape [1, h, w]: nothing is padded for a single
        # image, and the channel dim of img doubles as the batch dim here.
        img_mask = torch.zeros_like(img, dtype=torch.long)
        hyps = self.bttr.beam_search(img.unsqueeze(0), img_mask, beam_size, max_len)
        # Keep the hypothesis with the best length-normalized score.
        best_hyp = max(hyps, key=lambda h: h.score / (len(h) ** alpha))
        return vocab.indices2label(best_hyp.seq)
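
    # Minimal inference sketch (assumes a trained checkpoint; the checkpoint
    # path and the img tensor are hypothetical, not part of this file):
    #
    #   model = LitBTTR.load_from_checkpoint("best.ckpt")
    #   model.eval()
    #   with torch.no_grad():
    #       latex = model.beam_search(img)  # img: FloatTensor [1, h, w]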

    def training_step(self, batch: Batch, _):
        # Bi-directional (L2R + R2L) targets and outputs: [2b, l].
        tgt, out = to_bi_tgt_out(batch.indices, self.device)
        out_hat = self(batch.imgs, batch.mask, tgt)

        loss = ce_loss(out_hat, out)
        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

        return loss

    def validation_step(self, batch: Batch, _):
        tgt, out = to_bi_tgt_out(batch.indices, self.device)
        out_hat = self(batch.imgs, batch.mask, tgt)

        loss = ce_loss(out_hat, out)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        # Beam-search decode and track ExpRate; indexing batch.indices[0]
        # assumes the validation batch holds a single sample.
        hyps = self.bttr.beam_search(
            batch.imgs, batch.mask, self.hparams.beam_size, self.hparams.max_len
        )
        best_hyp = max(hyps, key=lambda h: h.score / (len(h) ** self.hparams.alpha))

        print('prediction ', best_hyp.seq)
        print('groundtruth ', batch.indices[0])

        self.exprate_recorder(best_hyp.seq, batch.indices[0])
        self.log(
            "val_ExpRate",
            self.exprate_recorder,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )

    def test_step(self, batch: Batch, _):
        hyps = self.bttr.beam_search(
            batch.imgs, batch.mask, self.hparams.beam_size, self.hparams.max_len
        )
        best_hyp = max(hyps, key=lambda h: h.score / (len(h) ** self.hparams.alpha))
        self.exprate_recorder(best_hyp.seq, batch.indices[0])
        # Return (image name, predicted LaTeX) for test_epoch_end to archive.
        return batch.img_bases[0], vocab.indices2label(best_hyp.seq)

    def test_epoch_end(self, test_outputs) -> None:
        exprate = self.exprate_recorder.compute()
        print(f"ExpRate: {exprate}")
        print(f"length of total file: {len(test_outputs)}")

        # Write one "%<img_base>" / "$<pred>$" text file per image into
        # result.zip (CROHME-style submission format).
        with zipfile.ZipFile("result.zip", "w") as zip_f:
            for img_base, pred in test_outputs:
                content = f"%{img_base}\n${pred}$".encode()
                with zip_f.open(f"{img_base}.txt", "w") as f:
                    f.write(content)

    def configure_optimizers(self):
        optimizer = optim.Adadelta(
            self.parameters(),
            lr=self.hparams.learning_rate,
            eps=1e-6,
            weight_decay=1e-4,
        )

        # val_ExpRate is an accuracy-style metric (higher is better), so the
        # plateau scheduler must run in "max" mode. patience is given in
        # epochs and scaled down because the scheduler only steps once per
        # validation run.
        reduce_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="max",
            factor=0.1,
            patience=self.hparams.patience // self.trainer.check_val_every_n_epoch,
        )
        scheduler = {
            "scheduler": reduce_scheduler,
            "monitor": "val_ExpRate",
            "interval": "epoch",
            "frequency": self.trainer.check_val_every_n_epoch,
            "strict": True,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler}