split commands into their own functions; improve logging and error reporting
[stack/cam.git] / cam / main.py
index a925df3..22b9758 100755 (executable)
@@ -9,14 +9,14 @@ from cam import config
 
 
 USAGE = '''cam [<OPTIONS>] <COMMAND> [<ARG>...]
-CAM v2.0 - (c)2012 by <ale@incal.net> 
-A Certification Authority manager for complex situations.
+CAM v%(version)s - (c)2012-2014 by <ale@incal.net> 
+Minimal X509 Certification Authority management tool.
 
 Known commands:
 
-  init [<RSA_CRT> [<DSA_CRT>]]
+  init [<RSA_CRT>]
     Initialize the environment and create a new CA certificate
-    (you can also import your own existing certificates)
+    (you can also import an existing certificate)
 
   gen <TAG>...
     Create (or re-create) the certificates corresponding
@@ -31,7 +31,7 @@ Known commands:
   fp [<TAG>...]
     Print SHA1/MD5 fingerprints of certificates
 
-  files <TAG>
+  files <TAG>...
     Dump all the certificate-related files of this TAG
 
   check 
@@ -39,11 +39,14 @@ Known commands:
     certificates are about to expire (controlled by the 'warning_days'
     parameter in the 'global' section of the configuration)
 
-The configuration file consists of a ini-style file, with one 'ca'
+The configuration file consists of a ini-style file, with a 'ca'
 section that specifies global CA parameters, and more sections for
-each tag with certificate-specific information. See the examples for
-more details on how to write your own configuration.
-'''
+each tag with certificate-specific information. See the documentation
+for more details on how to write your own configuration.
+
+Run `cam --help' to get a list of available command-line options.
+
+''' % {'version': '2.1'}
 
 
 def find_cert(certs, name):
@@ -53,76 +56,118 @@ def find_cert(certs, name):
     raise Exception('Certificate "%s" not found' % name)
 
 
+def cmd_init(global_config, ca, certs, args):
+    ca.create()
+
+
+def cmd_gen(global_config, ca, certs, args):
+    if len(args) < 1:
+        print 'Nothing to do.'
+    for tag in args:
+        ca.generate(find_cert(certs, tag))
+
+
+def cmd_gencrl(global_config, ca, certs, args):
+    ca.gencrl()
+
+
+def cmd_files(global_config, ca, certs, args):
+    if len(args) < 1:
+        print 'Nothing to do.'
+    for tag in args:
+        c = find_cert(certs, tag)
+        print c.public_key_file
+        print c.private_key_file
+
+
+def cmd_list(global_config, ca, certs, args):
+    now = time.time()
+    for cert in sorted(certs, key=lambda x: x.name):
+        expiry = cert.get_expiration_date()
+        state = 'OK'
+        expiry_str = ''
+        if not expiry:
+            state = 'MISSING'
+        else:
+            if expiry < now:
+                state = 'EXPIRED'
+            expiry_str = time.strftime('%Y/%m/%d', time.gmtime(expiry))
+        print cert.name, cert.cn, state, expiry_str
+
+
+def cmd_fingerprint(global_config, ca, certs, args):
+    if len(args) > 0:
+        certs = [find_cert(certs, x) for x in args]
+    for cert in certs:
+        print cert.name, cert.cn
+        print '  SHA1:', cert.get_fingerprint('sha1')
+        print '  MD5:', cert.get_fingerprint('md5')
+
+
+def cmd_check(global_config, ca, certs, args):
+    now = time.time()
+    warning_time = 86400 * int(global_config.get('warning_days', 15))
+    retval = 0
+    for cert in certs:
+        exp = cert.get_expiration_date()
+        if exp and (exp - now) < warning_time:
+            print '%s (%s) is about to expire.' % (cert.name, cert.cn)
+            retval = 1
+    return retval
+
+
+cmd_table = {
+    'init': cmd_init,
+    'gen': cmd_gen,
+    'gencrl': cmd_gencrl,
+    'files': cmd_files,
+    'list': cmd_list,
+    'fp': cmd_fingerprint,
+    'fingerprint': cmd_fingerprint,
+    'check': cmd_check,
+}
+
+
 def main():
     parser = optparse.OptionParser(usage=USAGE)
     parser.add_option('-d', '--debug', dest='debug', help='Be verbose',
                       action='store_true')
     parser.add_option('-c', '--config', dest='config', help='Config file')
     opts, args = parser.parse_args()
+
+    if len(args) > 0 and args[0] == 'help':
+        parser.print_help()
+        return 0
     if not opts.config:
         parser.error('Must specify --config')
     if len(args) < 1:
         parser.error('Must specify a command')
 
-    logging.basicConfig()
-    logging.getLogger().setLevel(opts.debug and logging.DEBUG or logging.INFO)
-
-    global_config, ca, certs = config.read_config(opts.config)
-
-    cmd, args = args[0], args[1:]
+    logging.basicConfig(
+        format='cam: %(levelname)s: %(message)s',
+        level=logging.DEBUG if opts.debug else logging.INFO)
 
     try:
-        if cmd == 'init':
-            ca.create()
-        elif cmd == 'gen':
-            if len(args) != 1:
-                parser.error('Wrong number of arguments')
-            ca.generate(find_cert(certs, args[0]))
-        elif cmd == 'gencrl':
-            ca.gencrl()
-        elif cmd == 'files':
-            if len(args) != 1:
-                parser.error('Wrong number of arguments')
-            c = find_cert(certs, args[0])
-            print c.public_key_file
-            print c.private_key_file
-        elif cmd == 'list':
-            now = time.time()
-            for cert in sorted(certs, key=lambda x: x.name):
-                expiry = cert.get_expiration_date()
-                state = 'OK'
-                expiry_str = ''
-                if not expiry:
-                    state = 'MISSING'
-                else:
-                    if expiry < now:
-                        state = 'EXPIRED'
-                    expiry_str = time.strftime('%Y/%m/%d', time.gmtime(expiry))
-                print cert.name, cert.cn, state, expiry_str
-        elif cmd == 'fp' or cmd == 'fingerprint':
-            if len(args) > 0:
-                certs = [find_cert(certs, x) for x in args]
-            for cert in certs:
-                print cert.name, cert.cn
-                print '  SHA1:', cert.get_fingerprint('sha1')
-                print '  MD5:', cert.get_fingerprint('md5')
-        elif cmd == 'check':
-            now = time.time()
-            warning_time = 86400 * int(global_config.get('warning_days', 15))
-            for cert in certs:
-                exp = cert.get_expiration_date()
-                if exp and (exp - now) < warning_time:
-                    print '%s (%s) is about to expire.' % (cert.name, cert.cn)
+        global_config, ca, certs = config.read_config(opts.config)
+        try:
+            cmd, args = args[0], args[1:]
+            if cmd not in cmd_table:
+                parser.error('unknown command "%s"' % cmd)
+            cmdfn = cmd_table[cmd]
+            return cmdfn(global_config, ca, certs, args)
+        finally:
+            ca.close()
+    except Exception as e:
+        if opts.debug:
+            logging.exception(e)
         else:
-            parser.error('unknown command "%s"' % cmd)
-    finally:
-        ca.close()
+            logging.error(e)
+        return 1
 
 
 def main_wrapper():
     try:
-        main()
-        return 0
+        return main()
     except Exception, e:
         logging.exception('uncaught exception')
         return 1