# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import pickle
from argparse import ArgumentParser, Namespace
from unittest import mock

import pytest

import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import argparse


@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse, tmpdir):
    """Build a Trainer from a namespace that holds the Trainer's default attributes.

    ``ArgumentParser.parse_args`` is patched to return a ``Namespace`` filled from
    ``Trainer.default_attributes()``. The test overrides ``logger`` and ``max_epochs`` on it and
    checks that ``Trainer.from_argparse_args`` honours the override.
    """
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # test logger rooted in tmpdir, created by the shared test helper
    default_logger = tutils.get_default_logger(tmpdir)

    namespace = ArgumentParser(add_help=False).parse_args()
    namespace.logger = default_logger
    namespace.max_epochs = 5

    trainer = Trainer.from_argparse_args(namespace)
    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5


@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
def test_add_argparse_args_redefined(cli_args: list):
    """Parse CLI args with the Trainer-augmented parser and build a Trainer from the result.

    Also checks that the parsed namespace and the resulting Trainer are both picklable, and that
    a few deprecated argument names were not registered on the parser.
    """
    trainer_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser(add_help=False))
    namespace = trainer_parser.parse_args(cli_args)

    # the parsed namespace must survive pickling
    pickle.dumps(namespace)

    # none of these deprecated argument names may appear in the namespace
    deprecated_names = ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs")
    for name in deprecated_names:
        assert name not in namespace

    trainer = Trainer.from_argparse_args(args=namespace)
    pickle.dumps(trainer)
    assert isinstance(trainer, Trainer)


@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
def test_add_argparse_args(cli_args: list):
    """Check ``Trainer.add_argparse_args`` with its default grouping and with ``use_argument_group=False``.

    For each variant a fresh parser is built, ``cli_args`` are parsed, and a Trainer must be
    constructible from the resulting namespace.
    """
    # {} exercises the default behaviour; the second entry disables argument grouping
    for add_kwargs in ({}, {"use_argument_group": False}):
        trainer_parser = Trainer.add_argparse_args(ArgumentParser(add_help=False), **add_kwargs)
        parsed = trainer_parser.parse_args(cli_args)
        assert Trainer.from_argparse_args(parsed)


def test_get_init_arguments_and_types():
    """Compare ``get_init_arguments_and_types(Trainer)`` with ``inspect.signature(Trainer)``.

    The number of reported arguments must match the signature. Each reported default (third
    tuple element) must equal the signature's default. A Trainer must be constructible from those
    defaults.
    """
    init_args = argparse.get_init_arguments_and_types(Trainer)
    signature_params = inspect.signature(Trainer).parameters
    assert len(signature_params) == len(init_args)

    default_kwargs = {}
    for arg_name, _, arg_default, *_ in init_args:
        assert signature_params[arg_name].default == arg_default
        default_kwargs[arg_name] = arg_default

    trainer = Trainer(**default_kwargs)
    assert isinstance(trainer, Trainer)


@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
def test_add_argparse_args_redefined_error(cli_args: list, monkeypatch):
    """Assert that parsing fails when the command line contains arguments the Trainer parser does not accept.

    ``ArgumentParser.exit`` (which ends the process via ``sys.exit`` on a parse error) is
    monkeypatched to raise a private exception instead. The failure then surfaces as a catchable
    exception inside the test.
    """

    class _UnkArgError(Exception):
        pass

    def _raise(*args, **kwargs):
        # accept whatever (status, message) arguments ``exit`` is invoked with
        raise _UnkArgError

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parent_parser=parser)

    monkeypatch.setattr(parser, "exit", _raise, raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)


@pytest.mark.parametrize(
    ["cli_args", "expected"],
    [
        ("--auto_lr_find --auto_scale_batch_size power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
        (
            "--auto_lr_find any_string --auto_scale_batch_size",
            {"auto_lr_find": "any_string", "auto_scale_batch_size": True},
        ),
        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
        ("--auto_lr_find t --auto_scale_batch_size ON", {"auto_lr_find": True, "auto_scale_batch_size": True}),
        ("--auto_lr_find 0 --auto_scale_batch_size n", {"auto_lr_find": False, "auto_scale_batch_size": False}),
        (
            "",
            {
                # These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
                # They should not be changed by the argparse interface.
                "min_steps": None,
                "accelerator": None,
                "profiler": None,
            },
        ),
    ],
)
def test_argparse_args_parsing(cli_args, expected):
    """Parse flags that accept either a bool-like token or a plain string, using a patched ``sys.argv``.

    Bare flags and tokens such as ``TRUE``/``t``/``ON``/``0``/``n`` must become booleans. Other
    strings are kept as-is. Optional parameters left off the command line must stay ``None``.
    """
    argv = ["any.py"]
    if cli_args:
        argv += cli_args.split(" ")

    with mock.patch("argparse._sys.argv", argv):
        trainer_parser = Trainer.add_argparse_args(parent_parser=ArgumentParser(add_help=False))
        parsed = Trainer.parse_argparser(trainer_parser)

    for attr_name, expected_value in expected.items():
        assert getattr(parsed, attr_name) == expected_value
    assert Trainer.from_argparse_args(parsed)


@pytest.mark.parametrize(
    "cli_args,expected",
    [("", False), ("--fast_dev_run=0", False), ("--fast_dev_run=True", True), ("--fast_dev_run 2", 2)],
)
def test_argparse_args_parsing_fast_dev_run(cli_args, expected):
    """Test that ``--fast_dev_run`` is parsed into either a bool or an int.

    Both the value and its exact type are checked. In Python ``0 == False`` and ``1 == True``, so
    an equality check alone could not distinguish a bool from an int.
    """
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)
    # Compare the type explicitly instead of using ``is``: the identity of an int such as ``2``
    # depends on CPython's small-integer cache, which is an implementation detail.
    assert type(args.fast_dev_run) is type(expected)
    assert args.fast_dev_run == expected


@pytest.mark.parametrize(
    ["cli_args", "expected_parsed"],
    [("", None), ("--accelerator gpu --devices 1", "1"), ("--accelerator gpu --devices 0,", "0,")],
)
def test_argparse_args_parsing_devices(cli_args, expected_parsed, cuda_count_1):
    """Test that ``--devices`` is kept as the raw command-line string, or ``None`` when omitted.

    A Trainer must also be constructible from the parsed namespace. ``cuda_count_1`` is requested
    only for its side effect. NOTE(review): presumably it mocks a single visible CUDA device so that
    ``--accelerator gpu`` can be instantiated; confirm in conftest.
    """
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)

    assert args.devices == expected_parsed
    assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(
    ["cli_args", "extra_args"],
    [
        ({}, {}),
        ({"logger": False}, {}),
        ({"logger": False}, {"logger": True}),
        ({"logger": False}, {"enable_checkpointing": True}),
    ],
)
def test_init_from_argparse_args(cli_args, extra_args):
    """Check how ``Trainer.from_argparse_args`` forwards a namespace plus explicit kwargs to ``__init__``.

    Unknown attributes on the namespace are ignored. Explicit keyword arguments take precedence
    over namespace values. An unknown keyword argument passed explicitly raises ``TypeError``.
    """
    unknown_args = {"unknown_arg": 0}

    # namespace attributes that Trainer.__init__ does not accept must be filtered out
    with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as mocked_init:
        namespace = Namespace(**cli_args, **unknown_args)
        trainer = Trainer.from_argparse_args(namespace, **extra_args)
        # explicit kwargs override values coming from the namespace
        expected_kwargs = {**cli_args, **extra_args}
        mocked_init.assert_called_with(trainer, **expected_kwargs)

    # an unknown keyword given explicitly (not via the namespace) must be rejected
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)