diff --git a/src/ssh_audit/ssh_socket.py b/src/ssh_audit/ssh_socket.py index 9e6ca38..5b8169c 100644 --- a/src/ssh_audit/ssh_socket.py +++ b/src/ssh_audit/ssh_socket.py @@ -75,24 +75,26 @@ class SSH_Socket(ReadBuf, WriteBuf): self.client_port = None def _resolve(self) -> Iterable[Tuple[int, Tuple[Any, ...]]]: + """Resolves a hostname into a list of IPs + Raises + ------ + socket.gaierror [Errno -2] + If the hostname cannot be resolved. + """ # If __ip_version_preference has only one entry, then it means that ONLY that IP version should be used. if len(self.__ip_version_preference) == 1: family = socket.AF_INET if self.__ip_version_preference[0] == 4 else socket.AF_INET6 else: family = socket.AF_UNSPEC - try: - stype = socket.SOCK_STREAM - r = socket.getaddrinfo(self.__host, self.__port, family, stype) + stype = socket.SOCK_STREAM + r = socket.getaddrinfo(self.__host, self.__port, family, stype) - # If the user has a preference for using IPv4 over IPv6 (or vice-versa), then sort the list returned by getaddrinfo() so that the preferred address type comes first. - if len(self.__ip_version_preference) == 2: - r = sorted(r, key=lambda x: x[0], reverse=(self.__ip_version_preference[0] == 6)) # pylint: disable=superfluous-parens - for af, socktype, _proto, _canonname, addr in r: - if socktype == socket.SOCK_STREAM: - yield af, addr - except socket.error as e: - self.__outputbuffer.fail('[exception] {}'.format(e)).write() - sys.exit(exitcodes.CONNECTION_ERROR) + # If the user has a preference for using IPv4 over IPv6 (or vice-versa), then sort the list returned by getaddrinfo() so that the preferred address type comes first. + if len(self.__ip_version_preference) == 2: + r = sorted(r, key=lambda x: x[0], reverse=(self.__ip_version_preference[0] == 6)) # pylint: disable=superfluous-parens + for af, socktype, _proto, _canonname, addr in r: + if socktype == socket.SOCK_STREAM: + yield af, addr # Listens on a server socket and accepts one connection (used for # auditing client connections). @@ -152,18 +154,18 @@ class SSH_Socket(ReadBuf, WriteBuf): def connect(self) -> Optional[str]: '''Returns None on success, or an error string.''' err = None - for af, addr in self._resolve(): - s = None - try: + s = None + try: + for af, addr in self._resolve(): s = socket.socket(af, socket.SOCK_STREAM) s.settimeout(self.__timeout) self.__outputbuffer.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True) s.connect(addr) self.__sock = s return None - except socket.error as e: - err = e - self._close_socket(s) + except socket.error as e: + err = e + self._close_socket(s) if err is None: errm = 'host {} has no DNS records'.format(self.__host) else: