##################
Organize Your Code
##################

Any raw PyTorch code can be converted to Fabric with zero refactoring required, giving you maximum flexibility in how you organize your projects.

However, when developing a project in a team or sharing the code publicly, it can be beneficial to follow a standard format for organizing the core pieces of the code.
This is what the `LightningModule `_ was made for!

Here is how you can neatly separate the research code (model, loss, optimization, etc.) from the "trainer" code (training loop, checkpointing, logging, etc.).


----


*************************************************
Step 1: Move your code into LightningModule hooks
*************************************************

Take these main ingredients and put them in a LightningModule:

- The PyTorch model(s) as an attribute (e.g. ``self.model``)
- The forward pass, including loss computation, goes into ``training_step()``
- Setup of optimizer(s) goes into ``configure_optimizers()``
- Setup of the training data loader goes into ``train_dataloader()``

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader

    import lightning as L


    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = ...
            self.loss_fn = ...  # e.g. torch.nn.CrossEntropyLoss()
            self.accuracy = ...  # a metric of your choice

        def training_step(self, batch, batch_idx):
            # Main forward, loss computation, and metrics go here
            x, y = batch
            y_hat = self.model(x)
            # PyTorch losses take (prediction, target)
            loss = self.loss_fn(y_hat, y)
            acc = self.accuracy(y_hat, y)
            ...
            return loss

        def configure_optimizers(self):
            # Return one or several optimizers
            return torch.optim.Adam(self.parameters(), ...)

        def train_dataloader(self):
            # Return your dataloader for training
            return DataLoader(...)

        def on_train_start(self):
            # Do something at the beginning of training
            ...

        def any_hook_you_like(self, *args, **kwargs):
            # Custom hooks are plain methods that you call from your loop
            ...

This is a minimal LightningModule, but there are `many other useful hooks `_ you can use.
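
For illustration, here is one way the placeholders in ``LitModel`` could be filled in. This is a minimal sketch under assumptions: the ``LitRegressor`` name, the toy target ``y = 2x + 1``, the single ``nn.Linear`` layer, the MSE loss, plain SGD, and all hyperparameters are hypothetical choices that are not part of the original example. Because the hooks are ordinary methods, a plain PyTorch loop can call them directly, which is a quick way to sanity-check a module before handing it to Fabric.

.. code-block:: python

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    import lightning as L


    class LitRegressor(L.LightningModule):
        # Hypothetical example: learn the toy target y = 2x + 1
        def __init__(self, lr: float = 0.1):
            super().__init__()
            self.model = nn.Linear(1, 1)  # the PyTorch model as an attribute
            self.loss_fn = nn.MSELoss()
            self.lr = lr

        def training_step(self, batch, batch_idx):
            # Forward pass and loss computation
            x, y = batch
            y_hat = self.model(x)
            return self.loss_fn(y_hat, y)

        def configure_optimizers(self):
            # One optimizer over all parameters of this module
            return torch.optim.SGD(self.parameters(), lr=self.lr)

        def train_dataloader(self):
            # Synthetic, noise-free data (an assumption made for this sketch)
            x = torch.linspace(-1, 1, 64).unsqueeze(1)
            return DataLoader(TensorDataset(x, 2 * x + 1), batch_size=16, shuffle=True)


    # The hooks are ordinary methods, so plain PyTorch can drive them too
    torch.manual_seed(0)
    module = LitRegressor()
    optimizer = module.configure_optimizers()
    train_dataloader = module.train_dataloader()
    for epoch in range(50):
        for batch_idx, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            loss = module.training_step(batch, batch_idx)
            loss.backward()
            optimizer.step()

    print(f"weight={module.model.weight.item():.3f}, bias={module.model.bias.item():.3f}")

After training, the learned weight and bias should end up close to 2 and 1, the parameters of the synthetic target.
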

----


****************************************
Step 2: Call hooks from your Fabric code
****************************************

In your Fabric training loop, you can now call the hooks of the LightningModule interface.
It is up to you to call each hook at the right place.

.. code-block:: python

    import lightning as L

    fabric = L.Fabric(...)

    # Instantiate the LightningModule
    model = LitModel()

    # Get the optimizer(s) from the LightningModule
    optimizer = model.configure_optimizers()

    # Get the training data loader from the LightningModule
    train_dataloader = model.train_dataloader()

    # Set up objects
    model, optimizer = fabric.setup(model, optimizer)
    train_dataloader = fabric.setup_dataloaders(train_dataloader)

    # Call the hooks at the right time
    model.on_train_start()

    model.train()
    for epoch in range(num_epochs):
        for i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            loss = model.training_step(batch, i)
            fabric.backward(loss)
            optimizer.step()

            # Control when hooks are called
            if condition:
                model.any_hook_you_like()

Your code is now modular: you can swap the entire LightningModule implementation for another one without touching the training loop:

.. code-block:: diff

      # Instantiate the LightningModule
    - model = LitModel()
    + model = DopeModel()

      ...
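
To make the swap concrete, the loop can be wrapped in a function that accepts any LightningModule. The sketch below is illustrative only: it assumes Fabric runs on a single CPU device, and the ``fit`` helper, the ``LinearRegressor`` and ``MLPRegressor`` modules, the synthetic data, and the hyperparameters are hypothetical, not part of the Lightning API. Replacing ``LinearRegressor`` with ``MLPRegressor`` plays the same role as the ``DopeModel`` swap in the diff.

.. code-block:: python

    from typing import List

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    import lightning as L


    class LinearRegressor(L.LightningModule):
        # Hypothetical module: fits the toy target y = 2x + 1 with one linear layer
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(1, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.mse_loss(self.model(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.05)

        def train_dataloader(self):
            x = torch.linspace(-1, 1, 64).unsqueeze(1)
            return DataLoader(TensorDataset(x, 2 * x + 1), batch_size=16, shuffle=True)


    class MLPRegressor(LinearRegressor):
        # Hypothetical drop-in replacement: same hooks, different network
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(nn.Linear(1, 8), nn.Tanh(), nn.Linear(8, 1))


    def fit(module: L.LightningModule, num_epochs: int = 30) -> List[float]:
        # Generic Fabric loop: it only talks to the module through its hooks
        fabric = L.Fabric(accelerator="cpu", devices=1)
        optimizer = module.configure_optimizers()
        train_dataloader = module.train_dataloader()
        model, optimizer = fabric.setup(module, optimizer)
        train_dataloader = fabric.setup_dataloaders(train_dataloader)

        model.on_train_start()
        model.train()
        epoch_losses = []
        for epoch in range(num_epochs):
            total, num_batches = 0.0, 0
            for i, batch in enumerate(train_dataloader):
                optimizer.zero_grad()
                loss = model.training_step(batch, i)
                fabric.backward(loss)
                optimizer.step()
                total += loss.item()
                num_batches += 1
            # Record the mean training loss of this epoch
            epoch_losses.append(total / num_batches)
        return epoch_losses


    torch.manual_seed(0)
    linear_losses = fit(LinearRegressor())
    mlp_losses = fit(MLPRegressor())  # swapped module, untouched loop

Only the module changes between the two calls. ``fit`` stays the same because it depends solely on ``configure_optimizers()``, ``train_dataloader()``, ``on_train_start()``, and ``training_step()``.
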