From 84ac5a30abee62efd72e6e881b3fb0cf28feedc9 Mon Sep 17 00:00:00 2001 From: Andris Raugulis Date: Fri, 7 Oct 2016 19:55:31 +0300 Subject: [PATCH] Decouple AuditConf from Output. --- ssh-audit.py | 155 ++++++++++++++++++++--------------------- test/test_auditconf.py | 50 +++++++++++++ 2 files changed, 125 insertions(+), 80 deletions(-) create mode 100644 test/test_auditconf.py diff --git a/ssh-audit.py b/ssh-audit.py index b5d22cd..d31d226 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -31,8 +31,6 @@ VERSION = 'v1.5.1.dev' def usage(err=None): p = os.path.basename(sys.argv[0]) - out.batch = False - out.minlevel = 'info' out.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) if err is not None: out.fail('\n' + err) @@ -49,43 +47,78 @@ def usage(err=None): class AuditConf(object): - def __init__(self): - self.__host = None - self.__port = 22 - self.__ssh1 = False - self.__ssh2 = False + def __init__(self, host=None, port=22): + self.host = host + self.port = port + self.ssh1 = True + self.ssh2 = True + self.batch = False + self.colors = True + self.verbose = False + self.minlevel = 'info' - @property - def host(self): - return self.__host + def __setattr__(self, name, value): + valid = False + if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: + valid, value = True, True if value else False + elif name == 'port': + valid, port = True, utils.parse_int(value) + if port < 1 or port > 65535: + raise ValueError('invalid port: {0}'.format(value)) + value = port + elif name in ['minlevel']: + if value not in ('info', 'warn', 'fail'): + raise ValueError('invalid level: {0}'.format(value)) + valid = True + elif name == 'host': + valid = True + if valid: + object.__setattr__(self, name, value) - @host.setter - def host(self, v): - self.__host = v - - @property - def port(self): - return self.__port - - @port.setter - def port(self, v): - self.__port = v - - @property - def ssh1(self): - return self.__ssh1 - - @ssh1.setter - def ssh1(self, v): - self.__ssh1 = v - - @property - def ssh2(self): - return self.__ssh2 - - @ssh2.setter - def ssh2(self, v): - self.__ssh2 = v + @classmethod + def from_cmdline(cls, args, usage_cb): + conf = cls() + try: + sopts = 'h12bnvl:' + lopts = ['help', 'ssh1', 'ssh2', 'batch', + 'no-colors', 'verbose', 'level='] + opts, args = getopt.getopt(args, sopts, lopts) + except getopt.GetoptError as err: + usage_cb(str(err)) + conf.ssh1, conf.ssh2 = False, False + for o, a in opts: + if o in ('-h', '--help'): + usage_cb() + elif o in ('-1', '--ssh1'): + conf.ssh1 = True + elif o in ('-2', '--ssh2'): + conf.ssh2 = True + elif o in ('-b', '--batch'): + conf.batch = True + conf.verbose = True + elif o in ('-n', '--no-colors'): + conf.colors = False + elif o in ('-v', '--verbose'): + conf.verbose = True + elif o in ('-l', '--level'): + if a not in ('info', 'warn', 'fail'): + usage_cb('level {0} is not valid'.format(a)) + conf.minlevel = a + if len(args) == 0: + usage_cb() + s = args[0].split(':') + host, port = s[0].strip(), 22 + if len(s) > 1: + port = utils.parse_int(s[1]) + if not host: + usage_cb('host is empty') + if port <= 0 or port > 65535: + usage_cb('port {0} is not valid'.format(s[1])) + conf.host = host + conf.port = port + if not (conf.ssh1 or conf.ssh2): + conf.ssh1, conf.ssh2 = True, True + return conf class Output(object): @@ -1563,49 +1596,11 @@ class Utils(object): return 0 -def parse_args(): - conf = AuditConf() - try: - sopts = 'h12bnvl:' - lopts = ['help', 'ssh1', 'ssh2', 'batch', 'no-colors', 'verbose', 'level='] - opts, args = getopt.getopt(sys.argv[1:], sopts, lopts) - except getopt.GetoptError as err: - usage(str(err)) - for o, a in opts: - if o in ('-h', '--help'): - usage() - elif o in ('-1', '--ssh1'): - conf.ssh1 = True - elif o in ('-2', '--ssh2'): - conf.ssh2 = True - elif o in ('-b', '--batch'): - out.batch = True - out.verbose = True - elif o in ('-n', '--no-colors'): - out.colors = False - elif o in ('-v', '--verbose'): - out.verbose = True - elif o in ('-l', '--level'): - if a not in ('info', 'warn', 'fail'): - usage('level ' + a + ' is not valid') - out.minlevel = a - if len(args) == 0: - usage() - s = args[0].split(':') - host, port = s[0].strip(), 22 - if len(s) > 1: - port = utils.parse_int(s[1]) - if not host or port <= 0: - usage('port {0} is not valid'.format(port)) - conf.host = host - conf.port = port - if not (conf.ssh1 or conf.ssh2): - conf.ssh1 = True - conf.ssh2 = True - return conf - - def audit(conf, sshv=None): + out.batch = conf.batch + out.colors = conf.colors + out.verbose = conf.verbose + out.minlevel = conf.minlevel s = SSH.Socket(conf.host, conf.port) if sshv is None: sshv = 2 if conf.ssh2 else 1 @@ -1646,5 +1641,5 @@ def audit(conf, sshv=None): utils = Utils.wrap() if __name__ == '__main__': out = Output() - conf = parse_args() + conf = AuditConf.from_cmdline(sys.argv[1:], usage) audit(conf) diff --git a/test/test_auditconf.py b/test/test_auditconf.py new file mode 100644 index 0000000..7b21551 --- /dev/null +++ b/test/test_auditconf.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import pytest + + +class TestAuditConf(object): + @pytest.fixture(autouse=True) + def init(self, ssh_audit): + self.AuditConf = ssh_audit.AuditConf + + def test_audit_conf_defaults(self): + conf = self.AuditConf() + assert conf.host is None + assert conf.port == 22 + assert conf.ssh1 is True + assert conf.ssh2 is True + assert conf.batch is False + assert conf.colors is True + assert conf.verbose is False + assert conf.minlevel == 'info' + + def test_audit_conf_booleans(self): + conf = self.AuditConf() + for p in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: + for v in [True, 1]: + setattr(conf, p, v) + assert getattr(conf, p) is True + for v in [False, 0]: + setattr(conf, p, v) + assert getattr(conf, p) is False + + def test_audit_conf_port(self): + conf = self.AuditConf() + for port in [22, 2222]: + conf.port = port + assert conf.port == port + for port in [-1, 0, 65536, 99999]: + with pytest.raises(ValueError) as excinfo: + conf.port = port + excinfo.match(r'.*invalid port.*') + + def test_audit_conf_minlevel(self): + conf = self.AuditConf() + for level in ['info', 'warn', 'fail']: + conf.minlevel = level + assert conf.minlevel == level + for level in ['head', 'good', 'unknown', None]: + with pytest.raises(ValueError) as excinfo: + conf.minlevel = level + excinfo.match(r'.*invalid level.*')