diff --git a/ssh-audit.py b/ssh-audit.py index 229a006..90a99c3 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -77,36 +77,64 @@ class Kex(object): @classmethod def parse(cls, payload): kex = cls() - buf = io.BytesIO(payload) - kex.cookie = buf.read(16) - kex.kex_algorithms = read_list(buf) - kex.key_algorithms = read_list(buf) - kex.client.encryption = read_list(buf) - kex.server.encryption = read_list(buf) - kex.client.mac = read_list(buf) - kex.server.mac = read_list(buf) - kex.client.compression = read_list(buf) - kex.server.compression = read_list(buf) - kex.client.languages = read_list(buf) - kex.server.languages = read_list(buf) - kex.follows = read_bool(buf) - kex.unused = read_int(buf) + buf = ReadBuf(payload) + kex.cookie = buf.read_raw(16) + kex.kex_algorithms = buf.read_list() + kex.key_algorithms = buf.read_list() + kex.client.encryption = buf.read_list() + kex.server.encryption = buf.read_list() + kex.client.mac = buf.read_list() + kex.server.mac = buf.read_list() + kex.client.compression = buf.read_list() + kex.server.compression = buf.read_list() + kex.client.languages = buf.read_list() + kex.server.languages = buf.read_list() + kex.follows = buf.read_bool() + kex.unused = buf.read_int() return kex -def read_int(buf): - return struct.unpack('>I', buf.read(4))[0] +class ReadBuf(object): + def __init__(self, data = None): + self._buf = io.BytesIO(data) if data else io.BytesIO() + self._len = len(data) if data else 0 + + @property + def unread_len(self): + return self._len - self._buf.tell() + + def read_raw(self, size): + return self._buf.read(size) + + def read_line(self): + return self._buf.readline().rstrip().decode('utf-8') + + def read_int(self): + return struct.unpack('>I', self._buf.read(4))[0] + + def read_bool(self): + return struct.unpack('b', self._buf.read(1))[0] != 0 + + def read_list(self): + list_size = self.read_int() + return self._buf.read(list_size).decode().split(',') -def read_bool(buf): - return struct.unpack('b', buf.read(1))[0] != 0 +class SockBuf(ReadBuf): + def __init__(self, s): + super(SockBuf, self).__init__() + self.__sock = s + + def recv(self, size = 2048): + data = self.__sock.recv(size) + pos = self._buf.tell() + self._buf.seek(0, 2) + self._buf.write(data) + self._len += len(data) + self._buf.seek(pos, 0) -def read_list(buf): - list_size = read_int(buf) - return buf.read(list_size).decode().split(',') def get_ssh_ver(v): return 'available since OpenSSH {0}'.format(v) - WARN_OPENSSH72_LEGACY = 'removed (in client) since OpenSSH 7.2, legacy algorithm' WARN_OPENSSH70_LEGACY = 'removed since OpenSSH 7.0, legacy algorithm' FAIL_OPENSSH70_WEAK = 'removed (in server) and disabled (in client) since OpenSSH 7.0, weak algorithm' @@ -246,9 +274,9 @@ def process_kex(kex): process_algorithms('mac', kex.server.mac, maxlen) out.sep() -def read_ssh_packet(s): +def read_ssh_packet(sbuf): block_size = 8 - header = s.recv(block_size) + header = sbuf.read_raw(block_size) packet_size = struct.unpack('>I', header[:4])[0] rest = header[4:] lrest = len(rest) @@ -257,7 +285,7 @@ def read_ssh_packet(s): if (packet_size - lrest) % block_size != 0: out.fail('[exception] invalid ssh packet (block size)') sys.exit(1) - buf = s.recv(packet_size - lrest) + buf = sbuf.read_raw(packet_size - lrest) packet = rest[2:] + buf[0:packet_size - lrest] payload = packet[0:packet_size - padding] return packet_type, payload @@ -291,13 +319,17 @@ def main(): try: s = socket.create_connection((host, port), SOCK_CONN_TIMEOUT) s.settimeout(SOCK_READ_TIMEOUT) - banner = s.recv(1024).strip() - out.head('# general') - out.good('[info] banner: ' + banner.decode()) - if banner.decode().startswith('SSH-1.99-'): - out.fail('[fail] protocol SSH1 enabled') + sbuf = SockBuf(s) s.send(SSH_BANNER.encode() + b'\r\n') - packet_type, payload = read_ssh_packet(s) + sbuf.recv() + banner = sbuf.read_line() + out.head('# general') + out.good('[info] banner: ' + banner) + if banner.startswith('SSH-1.99-'): + out.fail('[fail] protocol SSH1 enabled') + if sbuf.unread_len == 0: + sbuf.recv() + packet_type, payload = read_ssh_packet(sbuf) if packet_type != 20: out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type)) sys.exit(1)