# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
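"""Train a small classifier on synthetic data, writing checkpoints under ``checkpoints/<version>``."""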
import os

import torch
import torch.nn.functional as F
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

PATH_LEGACY = os.path.dirname(__file__)


class SklearnDataset(Dataset):
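    """Map-style ``Dataset`` that converts numpy rows to tensors of the requested dtypes."""
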
    def __init__(self, x, y, x_type, y_type):
        self.x = x
        self.y = y
        self._x_type = x_type
        self._y_type = y_type

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx], dtype=self._x_type), torch.tensor(self.y[idx], dtype=self._y_type)

    def __len__(self):
        return len(self.y)


class SklearnDataModule(LightningDataModule):
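    """``LightningDataModule`` that splits an ``(x, y)`` pair into train/val/test dataloaders."""
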
    def __init__(self, sklearn_dataset, x_type, y_type, batch_size: int = 128):
        super().__init__()
        self.batch_size = batch_size
        self._x, self._y = sklearn_dataset
        self._split_data()
        self._x_type = x_type
        self._y_type = y_type

    def _split_data(self):
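        # Hold out 20% of the data for test, then 40% of the remainder for validation.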
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
            self._x, self._y, test_size=0.20, random_state=42
        )
        self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
            self.x_train, self.y_train, test_size=0.40, random_state=42
        )

    def train_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_train, self.y_train, self._x_type, self._y_type),
            shuffle=True,
            batch_size=self.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_valid, self.y_valid, self._x_type, self._y_type), batch_size=self.batch_size
        )

    def test_dataloader(self):
        return DataLoader(
            SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
        )


class ClassifDataModule(SklearnDataModule):
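    """DataModule serving a synthetic ``make_classification`` problem."""
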
    def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
        data = make_classification(
            n_samples=length,
            n_features=num_features,
            n_classes=num_classes,
            n_clusters_per_class=2,
            n_informative=num_features // num_classes,
            random_state=42,
        )
        super().__init__(data, x_type=torch.float32, y_type=torch.long, batch_size=batch_size)


class ClassificationModel(LightningModule):
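    """Small fully-connected classifier that tracks accuracy for each stage."""
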
    def __init__(self, num_features=24, num_classes=3, lr=0.01):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        # Three hidden Linear+ReLU blocks followed by a projection to ``num_classes`` logits.
        for i in range(3):
            setattr(self, f"layer_{i}", nn.Linear(num_features, num_features))
            setattr(self, f"layer_{i}a", nn.ReLU())
        self.layer_end = nn.Linear(num_features, num_classes)

        # torchmetrics >= 0.11 requires an explicit task; older releases accepted ``Accuracy()``.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        x = self.layer_0(x)
        x = self.layer_0a(x)
        x = self.layer_1(x)
        x = self.layer_1a(x)
        x = self.layer_2(x)
        x = self.layer_2a(x)
        # Return raw logits: ``F.cross_entropy`` applies log-softmax internally,
        # so an extra softmax here would double-normalize and dampen the gradients.
        return self.layer_end(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc(logits, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False)
        self.log("val_acc", self.valid_acc(logits, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False)
        self.log("test_acc", self.test_acc(logits, y), prog_bar=True)


def main_train(dir_path, max_epochs: int = 20):
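    """Seed, fit, and test the classifier, asserting convergence and early stopping."""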
    seed_everything(42)
    stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
    trainer = pl.Trainer(
        default_root_dir=dir_path,
        # ``gpus`` and ``checkpoint_callback`` were removed from ``Trainer``; select
        # hardware via accelerator/devices and use enable_checkpointing instead.
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision=(16 if torch.cuda.is_available() else 32),
        enable_checkpointing=True,
        callbacks=[stopping],
        min_epochs=3,
        max_epochs=max_epochs,
        accumulate_grad_batches=2,
        deterministic=True,
    )

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer.fit(model, datamodule=dm)
    res = trainer.test(model, datamodule=dm)
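    # The model should reach reasonable loss/accuracy and early-stop before the final epoch.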
    assert res[0]["test_loss"] <= 0.7
    assert res[0]["test_acc"] >= 0.85
    assert trainer.current_epoch < (max_epochs - 1)


if __name__ == "__main__":
    path_dir = os.path.join(PATH_LEGACY, "checkpoints", str(pl.__version__))
    main_train(path_dir)