diff --git a/tests/server.py b/tests/server.py index fbd7e37..e7542dc 100644 --- a/tests/server.py +++ b/tests/server.py @@ -208,12 +208,10 @@ class ServerTestCase(AsyncTestCase): cls._server = yield from cls.start_server() sock = cls._server.sockets[0] - cls._client_host, _ = yield from cls.loop.getnameinfo(('127.0.0.1', 0)) cls._server_addr = '127.0.0.1' cls._server_port = sock.getsockname()[1] - host = '[%s]:%d,%s ' % (cls._server_addr, cls._server_port, - cls._client_host) + host = '[%s]:%d,localhost ' % (cls._server_addr, cls._server_port) with open('known_hosts', 'w') as known_hosts: known_hosts.write(host) diff --git a/tests/test_auth_keys.py b/tests/test_auth_keys.py index 1d625ef..72a49f7 100644 --- a/tests/test_auth_keys.py +++ b/tests/test_auth_keys.py @@ -13,13 +13,13 @@ """Unit tests for matching against authorized_keys file""" import unittest -from unittest.mock import patch import asyncssh -from .util import TempDirTestCase, x509_available +from .util import TempDirTestCase, patch_getnameinfo, x509_available +@patch_getnameinfo class _TestAuthorizedKeys(TempDirTestCase): """Unit tests for auth_keys module""" @@ -69,36 +69,22 @@ class _TestAuthorizedKeys(TempDirTestCase): def match_keys(self, tests, x509=False): """Match against authorized keys""" - def getnameinfo(sockaddr, flags): - """Mock reverse DNS lookup of client address""" - - # pylint: disable=unused-argument - - host, port = sockaddr - - if host == '127.0.0.1': - return ('localhost', port) - else: - return sockaddr - - with patch('socket.getnameinfo', getnameinfo): - for keys, matches in tests: - auth_keys = self.build_keys(keys, x509) - for (msg, keynum, client_addr, - cert_principals, match) in matches: - with self.subTest(msg, x509=x509): - if x509: - result, trusted_cert = auth_keys.validate_x509( - self.imported_certlist[keynum], client_addr) - if (trusted_cert and trusted_cert.subject != - self.imported_certlist[keynum].subject): - result = None - else: - result = auth_keys.validate( - self.imported_keylist[keynum], client_addr, - cert_principals, keynum == 1) - - self.assertEqual(result is not None, match) + for keys, matches in tests: + auth_keys = self.build_keys(keys, x509) + for (msg, keynum, client_addr, cert_principals, match) in matches: + with self.subTest(msg, x509=x509): + if x509: + result, trusted_cert = auth_keys.validate_x509( + self.imported_certlist[keynum], client_addr) + if (trusted_cert and trusted_cert.subject != + self.imported_certlist[keynum].subject): + result = None + else: + result = auth_keys.validate( + self.imported_keylist[keynum], client_addr, + cert_principals, keynum == 1) + + self.assertEqual(result is not None, match) def test_matches(self): """Test authorized keys matching""" diff --git a/tests/test_connection_auth.py b/tests/test_connection_auth.py index 3da8a5b..ff3e3cc 100644 --- a/tests/test_connection_auth.py +++ b/tests/test_connection_auth.py @@ -23,8 +23,8 @@ from asyncssh.packet import String from asyncssh.public_key import CERT_TYPE_USER, CERT_TYPE_HOST from .server import Server, ServerTestCase -from .util import asynctest, gss_available, patch_gss, make_certificate -from .util import x509_available +from .util import asynctest, gss_available, patch_getnameinfo, patch_gss +from .util import make_certificate, x509_available class _FailValidateHostSSHServerConnection(asyncssh.SSHServerConnection): @@ -455,6 +455,7 @@ class _TestGSSFQDN(ServerTestCase): yield from conn.wait_closed() +@patch_getnameinfo class _TestHostBasedAuth(ServerTestCase): """Unit tests for host-based authentication""" @@ -579,7 +580,7 @@ class _TestHostBasedAuth(ServerTestCase): """Test stripping of trailing dot from client host""" with (yield from self.connect(username='user', client_host_keys='skey', - client_host=self._client_host + '.', + client_host='localhost.', client_username='user')) as conn: pass @@ -667,6 +668,7 @@ class _TestHostBasedAsyncServerAuth(_TestHostBasedAuth): client_username='user') +@patch_getnameinfo class _TestLimitedHostBasedSignatureAlgs(ServerTestCase): """Unit tests for limited host key signature algorithms""" diff --git a/tests/util.py b/tests/util.py index 42bb596..4d92ec3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -84,6 +84,24 @@ def asynctest35(func): return async_wrapper +def patch_getnameinfo(cls): + """Decorator for patching socket.getnameinfo""" + + def getnameinfo(sockaddr, flags): + """Mock reverse DNS lookup of client address""" + + # pylint: disable=unused-argument + + host, port = sockaddr + + if host == '127.0.0.1': + return ('localhost', port) + else: + return sockaddr + + return patch('socket.getnameinfo', getnameinfo)(cls) + + def patch_gss(cls): """Decorator for patching GSSAPI classes"""