diff --git a/proxy/common/flags.py b/proxy/common/flags.py index 7f92fd58..12318c97 100644 --- a/proxy/common/flags.py +++ b/proxy/common/flags.py @@ -518,7 +518,7 @@ class Flags: } for plugin_ in plugins: if isinstance(plugin_, type): - module_name = None + module_name = '__main__' klass = plugin_ else: plugin = text_(plugin_.strip()) @@ -536,7 +536,8 @@ class Flags: while next(iterator) is not abc.ABC: pass base_klass = next(iterator) - p[bytes_(base_klass.__name__)].append(klass) + if klass not in p[bytes_(base_klass.__name__)]: + p[bytes_(base_klass.__name__)].append(klass) logger.info('Loaded plugin %s.%s', module_name, klass.__name__) return p diff --git a/tests/common/test_flags.py b/tests/common/test_flags.py index 2cc10fd0..93774fc3 100644 --- a/tests/common/test_flags.py +++ b/tests/common/test_flags.py @@ -10,26 +10,29 @@ """ import unittest +from typing import List, Dict + from proxy.common.flags import Flags from proxy.http.proxy import HttpProxyPlugin from proxy.plugin import CacheResponsesPlugin from proxy.plugin import FilterByUpstreamHostPlugin + class TestFlags(unittest.TestCase): - def assert_plugins(self, expected): + def assert_plugins(self, expected: Dict[str, List[type]]) -> None: for k in expected: self.assertIn(k.encode(), self.flags.plugins) for p in expected[k]: self.assertIn(p, self.flags.plugins[k.encode()]) self.assertEqual(len([o for o in self.flags.plugins[k.encode()] if o == p]), 1) - def test_load_plugin_from_bytes(self): + def test_load_plugin_from_bytes(self) -> None: self.flags = Flags.initialize([], plugins=[ b'proxy.plugin.CacheResponsesPlugin', ]) self.assert_plugins({'HttpProxyBasePlugin': [CacheResponsesPlugin]}) - def test_load_plugins_from_bytes(self): + def test_load_plugins_from_bytes(self) -> None: self.flags = Flags.initialize([], plugins=[ b'proxy.plugin.CacheResponsesPlugin', b'proxy.plugin.FilterByUpstreamHostPlugin', @@ -39,13 +42,13 @@ class TestFlags(unittest.TestCase): FilterByUpstreamHostPlugin, ]}) - def test_load_plugin_from_args(self): + def test_load_plugin_from_args(self) -> None: self.flags = Flags.initialize([ '--plugins', 'proxy.plugin.CacheResponsesPlugin', ]) self.assert_plugins({'HttpProxyBasePlugin': [CacheResponsesPlugin]}) - def test_load_plugins_from_args(self): + def test_load_plugins_from_args(self) -> None: self.flags = Flags.initialize([ '--plugins', 'proxy.plugin.CacheResponsesPlugin,proxy.plugin.FilterByUpstreamHostPlugin', ]) @@ -54,13 +57,13 @@ class TestFlags(unittest.TestCase): FilterByUpstreamHostPlugin, ]}) - def test_load_plugin_from_class(self): + def test_load_plugin_from_class(self) -> None: self.flags = Flags.initialize([], plugins=[ CacheResponsesPlugin, ]) self.assert_plugins({'HttpProxyBasePlugin': [CacheResponsesPlugin]}) - def test_load_plugins_from_class(self): + def test_load_plugins_from_class(self) -> None: self.flags = Flags.initialize([], plugins=[ CacheResponsesPlugin, FilterByUpstreamHostPlugin, @@ -70,7 +73,7 @@ class TestFlags(unittest.TestCase): FilterByUpstreamHostPlugin, ]}) - def test_load_plugins_from_bytes_and_class(self): + def test_load_plugins_from_bytes_and_class(self) -> None: self.flags = Flags.initialize([], plugins=[ CacheResponsesPlugin, b'proxy.plugin.FilterByUpstreamHostPlugin', @@ -80,8 +83,7 @@ class TestFlags(unittest.TestCase): FilterByUpstreamHostPlugin, ]}) - @unittest.expectedFailure - def test_unique_plugin_from_bytes(self): + def test_unique_plugin_from_bytes(self) -> None: self.flags = Flags.initialize([], plugins=[ b'proxy.http.proxy.HttpProxyPlugin', ]) @@ -89,8 +91,7 @@ class TestFlags(unittest.TestCase): HttpProxyPlugin, ]}) - @unittest.expectedFailure - def test_unique_plugin_from_args(self): + def test_unique_plugin_from_args(self) -> None: self.flags = Flags.initialize([ '--plugins', 'proxy.http.proxy.HttpProxyPlugin', ]) @@ -98,8 +99,7 @@ class TestFlags(unittest.TestCase): HttpProxyPlugin, ]}) - @unittest.expectedFailure - def test_unique_plugin_from_class(self): + def test_unique_plugin_from_class(self) -> None: self.flags = Flags.initialize([], plugins=[ HttpProxyPlugin, ]) @@ -107,5 +107,6 @@ class TestFlags(unittest.TestCase): HttpProxyPlugin, ]}) + if __name__ == '__main__': unittest.main()