#!/usr/bin/lua5.4

--[[
Certificate Authority tool for Dynamic Multipoint VPN

Copyright (c) 2014-2025 Kaarle Ritvanen
Copyright (c) 2015 Timo Teräs
Copyright (c) 2017 Natanael Copa

See LICENSE file for license details
]]--

-- Load the YAML configuration. The path comes from the DMVPN_CA_CONF
-- environment variable and falls back to /etc/dmvpn-ca.conf. A missing or
-- unreadable file now aborts with the message reported by io.open instead of
-- failing with an attempt to index a nil value.
conf_file = assert(
	io.open(os.getenv('DMVPN_CA_CONF') or '/etc/dmvpn-ca.conf')
)
-- If the loader yields no value (presumably the case for an empty
-- document), fall back to an empty table so the defaults can be applied.
config = require('lyaml').load(conf_file:read('*a')) or {}
conf_file:close()

asn1 = require('asn1')
rfc3779 = require('asn1.rfc3779')
rfc5280 = require('asn1.rfc5280')

bit32 = require('bit32')
dmvpn = require('dmvpn')

pkcs12 = require('openssl.pkcs12')
pkey = require('openssl.pkey')
x509 = require('openssl.x509')
x509an = require('openssl.x509.altname')
x509ch = require('openssl.x509.chain')
x509crl = require('openssl.x509.crl')
x509ext = require('openssl.x509.extension')
x509n = require('openssl.x509.name')

stat = require('posix.sys.stat')
unistd = require('posix.unistd')

stringy = require('stringy')

-- Seed the pseudo-random number generator with four bytes read from
-- /dev/urandom. Failures to open or read the device are reported
-- explicitly, and the temporaries are confined to a do-block so they do not
-- leak into the global namespace.
do
	local dev = assert(io.open('/dev/urandom', 'rb'))
	local s = dev:read(4)
	dev:close()
	assert(s and s:len() == 4)

	-- combine the bytes into one big-endian 32-bit integer
	local seed = 0
	for i = 1, 4 do seed = seed * 256 + s:byte(i) end
	math.randomseed(seed)
end


--- Recursively fill in configuration values that are not set.
-- Every key of `defaults` whose value in `config` is nil or false is
-- assigned the default value (tables are assigned by reference). When the
-- key is already set and the default is a table, the function descends
-- into both tables.
-- @param config table to be completed in place
-- @param defaults table of default values
function set_config_defaults(config, defaults)
	for key, default in pairs(defaults) do
		local current = config[key]
		if current then
			if type(default) == 'table' then
				set_config_defaults(current, default)
			end
		else
			config[key] = default
		end
	end
end

-- Built-in defaults, applied to every setting missing from the
-- configuration file.
local builtin_defaults = {
	db={file='/var/lib/misc/dmvpn-ca.sqlite3'},
	cert={
		lifetime=365 * 24 * 60 * 60,
		renewal=30 * 24 * 60 * 60,
		['hash-alg']='SHA256',
		key={type='EC', curve='secp384r1'}
	},
	ca={dn='O=DMVPN'},
	hub={
		dn='$ROOT,OU=Hubs,CN=$NAME',
		['default-name']='Hub-$ID',
		subnets={'10.0.0.0/8'}
	},
	spoke={
		dn='$ROOT,OU=$SITE,CN=$NAME',
		['default-name']='VPNc-$ID'
	},
	crl={},
	['password-length']=16,
	['default-asn']=65000
}
set_config_defaults(config, builtin_defaults)

-- Each certificate type inherits the generic certificate parameters
-- unless it overrides them.
for _, cert_type in ipairs{'ca', 'hub', 'spoke', 'crl'} do
	set_config_defaults(config[cert_type], config.cert)
end

-- A bare boolean true for encrypt-keys selects the aes128 cipher.
if config.db['encrypt-keys'] == true then
	config.db['encrypt-keys'] = 'aes128'
end


-- Reference timestamp used for all validity and expiry comparisons made
-- during this invocation.
now = os.time()

-- Set the file creation mask so that the group and other permission bits
-- are masked out.
stat.umask(bit32.bor(stat.S_IRWXG, stat.S_IRWXO))


--- Raise an error caused by invalid user input.
-- The error value is the table {false, msg}; the false flag tells the
-- top-level handler not to append the command syntax to the message.
-- @param msg message to be shown to the user
function user_error(msg)
	error({false, msg})
end
--- Raise an error caused by a malformed command line.
-- The error value is the table {true, msg}; the true flag makes the
-- top-level handler append the command syntax to the message.
-- @param msg message to be shown to the user
function syntax_error(msg)
	error({true, msg})
end


--- Return the database connection, opening it on first use.
-- The LuaSQL environment and the connection are kept in the globals `sql`
-- and `conn`. Autocommit is switched off on the new connection.
-- @return connection object
function connection()
	if sql then return conn end

	sql = require('luasql.sqlite3').sqlite3()
	conn = sql:connect(config.db.file)
	if not conn then
		error('Cannot open database file: '..config.db.file)
	end
	conn:setautocommit(false)
	return conn
end

--- Execute an SQL statement on the shared connection.
-- @param statement SQL text
-- @return the first value returned by the driver (a cursor or a row count)
-- Raises an error with the driver's message when the statement fails.
function execute(statement)
	local result, message = connection():execute(statement)
	if not result then error(message) end
	return result
end

--- Convert a Lua value to an SQL literal.
-- nil and false become NULL, strings are escaped by the connection and
-- wrapped in single quotes, and any other value is returned unchanged.
-- @param s value to be embedded in a statement
-- @return SQL literal
function escape(s)
	if not s then return 'NULL' end
	if type(s) ~= 'string' then return s end
	return "'"..connection():escape(s).."'"
end

--- Build a list of `column = literal` terms from a column/value table.
-- @param values table mapping column names to values
-- @return array of strings, in the (unspecified) order of pairs()
function value_list(values)
	local terms = {}
	for column, value in pairs(values) do
		terms[#terms + 1] = column..' = '..escape(value)
	end
	return terms
end

--- Insert a row into a table.
-- @param tbl table name
-- @param values table mapping column names to values
function insert(tbl, values)
	local columns, literals = {}, {}
	for column, value in pairs(values) do
		columns[#columns + 1] = column
		literals[#literals + 1] = escape(value)
	end

	local column_list = table.concat(columns, ',')
	local literal_list = table.concat(literals, ',')
	execute(
		'INSERT INTO '..tbl..'('..column_list..') VALUES ('..
			literal_list..')'
	)
end

--- Build the WHERE clause of a statement.
-- @param filter nil for no clause, a string holding a ready-made
-- condition, or a table of column/value pairs to be combined with AND
-- @return clause with a leading space, or the empty string
function where(filter)
	if not filter then return '' end

	local condition = filter
	if type(filter) == 'table' then
		local terms = value_list(filter)
		if next(terms) == nil then return '' end
		condition = table.concat(terms, ' AND ')
	end
	return ' WHERE '..condition
end

--- Update the rows matching a filter.
-- @param tbl table name
-- @param values table mapping column names to their new values
-- @param filter condition accepted by where(); nil updates every row
function update(tbl, values, filter)
	local assignments = table.concat(value_list(values), ', ')
	local condition = where(filter)
	execute('UPDATE '..tbl..' SET '..assignments..condition)
end

--- Delete the rows matching a filter.
-- @param tbl table name
-- @param filter condition accepted by where(); nil deletes every row
function delete(tbl, filter)
	local condition = where(filter)
	execute('DELETE FROM '..tbl..condition)
end

--- Run a SELECT statement and return an iterator over its rows.
-- @param what column list
-- @param from table expression
-- @param filter condition accepted by where()
-- @param mode fetch mode passed to the cursor ('n' or 'a'); when nil the
-- iterator returns the column values directly instead of a row table
-- @param order optional ORDER BY expression
-- @return iterator function; each call fetches the next row
function select_many(what, from, filter, mode, order)
	local cur = execute(
		'SELECT '..what..' FROM '..from..where(filter)..
			(order and ' ORDER BY '..order or '')
	)

	-- Sentinel whose finalizer closes the cursor once the iterator has
	-- become garbage.
	local gc = setmetatable({}, {__gc=function() cur:close() end})

	return function()
		-- Referencing the sentinel makes it an upvalue of the iterator,
		-- tying its lifetime to the iterator's. Without the reference
		-- the sentinel is collectable as soon as select_many returns,
		-- and a garbage collection cycle in the middle of a loop would
		-- close the cursor while rows are still being fetched.
		local _ = gc
		return cur:fetch(mode and {}, mode)
	end
end

--- Iterate over certificate rows.
-- @param filter condition accepted by where()
-- @param order optional ORDER BY expression
-- @return iterator yielding rows as tables keyed by column name
function select_certs(filter, order)
	local columns, tbl, mode = '*', 'certificate', 'a'
	return select_many(columns, tbl, filter, mode, order)
end

--- Fetch the single row matching a query.
-- Takes the same arguments as select_many().
-- @return the first value of the only row, or nil when nothing matches
-- Fails an assertion if the query yields more than one row.
function select_one(...)
	local fetch = select_many(...)

	local first = fetch()
	if first == nil then return nil end

	-- a second row violates the caller's uniqueness assumption
	local second = fetch()
	assert(second == nil)
	return first
end

--- Verify that a row matching the filter exists.
-- @param tbl table name
-- @param filter table of column/value pairs; when `active` is given, an
-- `active` entry ('1' or '0') is added to this table in place
-- @param msg user error message raised when no row matches
-- @param active optional boolean restricting the match by activity state
function check_exists(tbl, filter, msg, active)
	if active ~= nil then
		if active then
			filter.active = '1'
		else
			filter.active = '0'
		end
	end
	if select_one('0', tbl, filter) then return end
	user_error(msg)
end

--- Compute the next free value of a numeric key column.
-- @param tbl table name
-- @param key column name
-- @param filter condition accepted by where()
-- @return the current maximum plus one, or 1 when no row matches
function next_key(tbl, key, filter)
	local highest = select_one('MAX('..key..')', tbl, filter)
	return (highest or 0) + 1
end

--- Fetch a certificate row by serial number.
-- @param serial serial number
-- @return row as a table keyed by column name, or nil if not found
function select_cert(serial)
	local by_serial = {serial=serial}
	return select_one('*', 'certificate', by_serial, 'a')
end


--- Convert a string to an integer within a range.
-- @param s value to convert
-- @param min smallest accepted value
-- @param max largest accepted value, or nil for no upper bound
-- @param desc optional description; when given, an invalid value raises
-- a user error mentioning it, otherwise nothing is returned
-- @return the number when it is integral and within the range
function toint(s, min, max, desc)
	local value = tonumber(s)

	local valid = value ~= nil and value == math.floor(value) and
		value >= min
	if valid and max then valid = value <= max end
	if valid then return value end

	if desc then user_error('Invalid '..desc..': '..s) end
end

--- Convert a string to an object ID, i.e. an integer of at least 1.
-- Raises a user error for anything else.
-- @param s value to convert
-- @return the ID as a number
function toid(s)
	local lowest, highest = 1, nil
	return toint(s, lowest, highest, 'ID')
end


--- Determine the address family of an IP address.
-- The address is first added to a throw-away alternative name object
-- (presumably so that the library rejects malformed input; confirm
-- against the luaossl documentation).
-- @param s IP address as a string
-- @return 2 if the address contains a colon, otherwise 1
function detect_afi(s)
	local scratch = x509an.new()
	scratch:add('IP', s)

	if s:find(':') then return 2 end
	return 1
end

--- Validate a network address in address/length notation.
-- @param s network address, e.g. 10.0.0.0/8
-- @return address family as returned by detect_afi()
-- Raises a user error unless there is exactly one slash and the prefix
-- length is an integer between 0 and 32 (family 1) or 128 (family 2).
function detect_prefix_afi(s)
	local function reject()
		user_error('Invalid network address: '..s)
	end

	local parts = stringy.split(s, '/')
	if #parts ~= 2 then reject() end

	local max_length = {32, 128}
	local afi = detect_afi(parts[1])
	if not afi then
		reject()
	elseif not toint(parts[2], 0, max_length[afi]) then
		reject()
	end

	return afi
end


-- Passwords obtained so far, keyed by password ID.
passwords = {}

--- Return a password, asking for it only once per ID.
-- @param new flag forwarded to dmvpn.get_password (presumably selects
-- the prompt for choosing a new password; confirm against that module)
-- @param id password ID, 'default' when omitted
-- @return the cached password
function get_password(new, id)
	local slot = id or 'default'
	local cached = passwords[slot]
	if not cached then
		cached = dmvpn.get_password(new)
		passwords[slot] = cached
	end
	return cached
end

--- Load a private key stored in PEM format.
-- get_password is supplied as the password callback.
-- @param key PEM text
-- @return key object
function decrypt_key(key)
	local options = {
		format='PEM',
		type='private',
		password=get_password
	}
	return pkey.new(key, options)
end

--- Load the certificate and private key stored in a certificate row.
-- @param row certificate row with `data` and `privateKey` fields
-- @return certificate object, key object
function load_cert(row)
	local cert = x509.new(row.data, 'PEM')
	local key = decrypt_key(row.privateKey)
	return cert, key
end

--- Return the CA certificate (serial 0) and its private key.
-- The objects are loaded on first use and cached in the globals
-- `ca_cert` and `ca_key`.
-- @return certificate object, key object
function load_ca_cert()
	if ca_cert then return ca_cert, ca_key end

	ca_cert, ca_key = load_cert(select_cert(0))
	return ca_cert, ca_key
end

--- Serialize a private key to PEM.
-- The cipher is taken from the `encrypt-keys` database setting; a false
-- or missing setting passes no cipher.
-- @param key key object
-- @param new_pw flag forwarded to get_password()
-- @param pw_id password ID forwarded to get_password()
-- @return PEM text
function encrypt_key(key, new_pw, pw_id)
	local function ask_password()
		return get_password(new_pw, pw_id)
	end

	local options = {
		type='private',
		cipher=config.db['encrypt-keys'] or nil,
		password=ask_password
	}
	return key:toPEM(options)
end

--- Sign a certificate or CRL.
-- Sets the issuer of the object to the subject of the signing
-- certificate and adds an authorityKeyIdentifier extension derived from
-- the signer's subjectKeyIdentifier extension.
-- @param object certificate or CRL object to be signed
-- @param hash_alg hash algorithm name
-- @param cert signing certificate; the CA certificate when omitted
-- @param key signing key; replaced by the CA key when cert is omitted
-- @return the result of the object's sign method
function sign(object, hash_alg, cert, key)
	if not cert then cert, key = load_ca_cert() end
	object:setIssuer(cert:getSubject())

	local signer_key_id = rfc5280.KeyIdentifier.decode(
		cert:getExtension('subjectKeyIdentifier'):getData()
	)
	local authority_key_id = rfc5280.AuthorityKeyIdentifier.encode{
		keyIdentifier=signer_key_id
	}
	object:addExtension(
		x509ext.new('authorityKeyIdentifier', 'DER', authority_key_id)
	)

	return object:sign(key, hash_alg)
end

-- Issue a certificate and store it in the certificate table.
--
-- attrs is a table with the following fields:
--   dn      subject DN as comma-separated ATTR=value components
--   params  certificate parameters (key, lifetime, hash-alg)
--   usage   optional table of key usage flags; its presence also sets
--           the CA basic constraint
--   other fields (e.g. site, vpnc) are stored as table columns
-- func is an optional callback invoked as func(cert, attrs) just before
-- signing, allowing the caller to add further extensions.
--
-- attrs is modified in place, inserted into the certificate table and
-- returned.
function issue_cert(attrs, func)
	local key = pkey.new(attrs.params.key)

	-- the certificate with the keyCertSign usage is the root certificate
	-- and always gets serial number 0
	local ca = attrs.usage and attrs.usage.keyCertSign
	attrs.serial = ca and 0 or next_key('certificate', 'serial')

	local cert = x509.new()
	cert:setVersion(3)
	cert:setSerial(attrs.serial)
	cert:setPublicKey(key)

	-- build the subject name from the comma-separated RDN list
	local dn = x509n.new()
	for _, rdn in ipairs(stringy.split(attrs.dn, ',')) do
		local attr, value = string.match(rdn, '^%s*(%a+)=(.*)$')
		if not attr then user_error('Invalid RDN: '..rdn) end
		dn:add(attr, value)
	end
	cert:setSubject(dn)

	cert:setBasicConstraints{CA=attrs.usage}
	cert:setBasicConstraintsCritical(true)

	-- validity starts at the first timestamp reported by the new
	-- certificate object and lasts for the configured lifetime
	local issued = cert:getLifetime()
	local expires = issued + attrs.params.lifetime
	cert:setLifetime(issued, expires)

	attrs.issued = issued
	attrs.expires = expires
	-- keys of the root certificate and of certificates without a usage
	-- table go through encrypt_key; the remaining keys (certificates
	-- with a usage table but without keyCertSign) are stored as plain PEM
	attrs.privateKey = (ca or not attrs.usage) and encrypt_key(key, ca) or
		key:toPEM{type='private'}

	cert:addExtension(
		x509ext.new(
			'subjectKeyIdentifier',
			'DER',
			rfc5280.KeyIdentifier.encode(cert:getPublicKeyDigest())
		)
	)

	if attrs.usage then
		cert:addExtension(
			x509ext.new(
				'keyUsage',
				'DER',
				rfc5280.KeyUsage.encode(attrs.usage)
			)
		)
		-- usage is not a column of the certificate table
		attrs.usage = nil
	end

	-- advertise the CRL distribution point if one is configured
	local crl_dp = config.crl['dist-point']
	if crl_dp then
		cert:addExtension(
			x509ext.new(
				'crlDistributionPoints',
				'DER',
				rfc5280.CRLDistributionPoints.encode{
					{
						distributionPoint={
							fullName={{uniformResourceIdentifier=crl_dp}}
						}
					}
				}
			)
		)
	end

	if func then func(cert, attrs) end

	-- the root certificate is signed with its own key; for other
	-- certificates sign() falls back to the CA certificate and key
	sign(cert, attrs.params['hash-alg'], ca and cert, ca and key)

	-- keep only the fields that are columns of the certificate table
	attrs.data = tostring(cert)
	attrs.dn = tostring(dn)
	attrs.params = nil
	attrs.sname = nil
	attrs.vname = nil
	insert('certificate', attrs)
	return attrs
end

--- Issue a certificate carrying the CA's DN and parameters.
-- @param usage table of key usage flags
-- @return the attribute table returned by issue_cert()
function issue_ca_cert(usage)
	local attrs = {dn=config.ca.dn, params=config.ca, usage=usage}
	return issue_cert(attrs)
end

-- Generate a random password consisting of digits and ASCII letters.
-- The characters are derived directly from /dev/urandom rather than from
-- math.random, whose state is seeded with only 32 bits and which is not
-- designed for generating secrets.
local function random_password(length)
	local dev = assert(io.open('/dev/urandom', 'rb'))
	local chars = {}
	while #chars < length do
		local byte = assert(
			dev:read(1), 'Cannot read /dev/urandom'
		):byte()
		-- 248 = 4 * 62: bytes above that are discarded so that each of
		-- the 62 characters is equally likely
		if byte < 248 then
			local r = byte % 62
			if r > 35 then r = r + 61      -- 36..61 -> a-z
			elseif r > 9 then r = r + 55   -- 10..35 -> A-Z
			else r = r + 48 end            -- 0..9   -> 0-9
			chars[#chars + 1] = string.char(r)
		end
	end
	dev:close()
	return table.concat(chars)
end

-- Export a certificate and its private key as a PKCS#12 file in the
-- current directory.
--
-- The file is protected with a freshly generated random password of the
-- configured length and is named <site>_<vpnc>.<password>.pfx. The chain
-- stored in the file contains the CA certificate, the CRL signing
-- certificate when key encryption is enabled, and the certificate itself.
--
-- cert is a certificate row (site, vpnc, data, privateKey).
-- Raises a user error if the output file cannot be created.
function export_cert(cert)
	local password = random_password(config['password-length'])

	local chain = x509ch.new()
	chain:add(load_ca_cert())

	if config.db['encrypt-keys'] then chain:add(load_crl_cert()) end

	chain:add(x509.new(cert.data, 'PEM'))

	-- build the archive before touching the file system so that a
	-- failure here does not leave an empty file behind
	local data = tostring(
		pkcs12.new{
			key=decrypt_key(cert.privateKey),
			certs=chain,
			password=password
		}
	)

	local file, err = io.open(
		('%s_%s.%s.pfx'):format(cert.site, cert.vpnc, password), 'w'
	)
	if not file then
		user_error('Cannot create file: '..tostring(err))
	end
	file:write(data)
	file:close()
end


--- Prepend a header row to a non-empty table.
-- @param tbl array of rows, modified in place
-- @param columns header row
-- @return tbl; an empty table is returned without a header
function add_header(tbl, columns)
	if not tbl[1] then return tbl end
	table.insert(tbl, 1, columns)
	return tbl
end


--- Convert certificate rows into a printable table.
-- @param certs array of certificate rows
-- @return array of {serial, dn, issued, validity} rows preceded by a
-- header row; an empty array when there are no certificates
function format_cert_info(certs)
	local function stamp(timestamp)
		return os.date('%x %X', timestamp)
	end

	local rows = {}
	for _, cert in ipairs(certs) do
		local validity
		if cert.revoked then
			validity = 'Revoked '..stamp(cert.revoked)
		elseif cert.expires < now then
			validity = 'Expired '..stamp(cert.expires)
		else
			validity = 'Expires '..stamp(cert.expires)
		end
		rows[#rows + 1] = {
			cert.serial, cert.dn, stamp(cert.issued), validity
		}
	end

	return add_header(rows, {'serial', 'dn', 'issued', 'validity'})
end

--- Print a certificate in textual form.
-- @param cert certificate row holding the PEM text in `data`
function print_cert(cert)
	local parsed = x509.new(cert.data, 'PEM')
	print(parsed:text{'ext_parse'})
end

--- Tell whether a certificate is neither revoked nor about to expire.
-- @param cert certificate row
-- @param margin optional number of seconds before expiry at which the
-- certificate is already considered invalid
-- @return boolean
function is_valid(cert, margin)
	if cert.revoked then return false end
	return now < cert.expires - (margin or 0)
end

--- Revoke all valid certificates matching a filter.
-- Each revoked row gets its `revoked` column set to the current time.
-- If anything was revoked, the expiry time of the stored CRL is set to
-- the current time as well.
-- @param filter condition accepted by where()
-- @return array of the revoked certificate rows
function revoke(filter)
	local revoked = {}

	for cert in select_certs(filter) do
		if is_valid(cert) then
			local by_serial = {serial=cert.serial}
			update('certificate', {revoked=now}, by_serial)
			cert.revoked = now
			revoked[#revoked + 1] = cert
		end
	end

	if #revoked ~= 0 then update('crl', {expires=now}) end
	return revoked
end


--- Consume the next command line argument.
-- @param desc optional description of the expected argument; when given
-- and no argument is left, a syntax error is raised
-- @return the argument, or nothing when none is left
function scan_next(desc)
	if #arg > 0 then return table.remove(arg, 1) end
	if desc then syntax_error('Missing '..desc) end
end

--- Consume an IP address argument.
-- The address is passed through detect_afi() before being returned.
-- @return the address as given
function scan_addr()
	local address = scan_next('IP address')
	detect_afi(address)
	return address
end

--- Consume a network address argument.
-- The address is passed through detect_prefix_afi() before being returned.
-- @return the address as given
function scan_subnet()
	local network = scan_next('network address')
	detect_prefix_afi(network)
	return network
end

--- Consume the literal attribute keyword `name`.
-- Raises a syntax error for any other token.
function scan_name_token()
	local token = scan_next('attribute')
	if token == 'name' then return end
	syntax_error('Invalid attribute: '..token)
end

--- Consume a `keyword value` pair from the command line.
-- @param choices accepted keyword or array of accepted keywords
-- @param desc description of the keyword used in error messages
-- @param optional when true, a missing keyword is not an error
-- @return keyword, value; nothing when no argument is left and the
-- parameter is optional
-- Raises a syntax error for an unknown keyword or a missing value.
function scan_param(choices, desc, optional)
	local allowed = choices
	if type(choices) == 'string' then allowed = {choices} end

	local keyword = scan_next(not optional and desc)
	if not keyword then return end

	for i = 1, #allowed do
		if allowed[i] == keyword then
			return keyword, scan_next(keyword..' specification')
		end
	end
	syntax_error('Invalid '..desc..': '..keyword)
end


--- Consume a hub number and verify that the hub exists.
-- Hubs are the VPNc rows with an empty site code.
-- @param active optional boolean restricting the check by activity state
-- @return the hub number
function scan_hub(active)
	local id = toid(scan_next('hub number'))
	local hub = {site='', id=id}
	check_exists('vpnc', hub, 'Invalid hub number: '..id, active)
	return id
end


--- Consume a site code and convert it to upper case.
-- @return the upper-case site code
function scan_site_code()
	local code = scan_next('site code')
	return code:upper()
end

--- Verify that a site exists.
-- @param code site code in any letter case
-- @param active optional boolean restricting the check by activity state
-- @return the upper-case site code
function validate_site(code, active)
	local upper = code:upper()
	check_exists('site', {code=upper}, 'Invalid site code: '..upper, active)
	return upper
end

--- Consume a site code and verify that the site exists.
-- @param active optional boolean restricting the check by activity state
-- @return the upper-case site code
function scan_site(active)
	local code = scan_site_code()
	return validate_site(code, active)
end

--- Consume a `site <abbr>` selector.
-- @param options optional table with the fields
--   param     keyword to expect instead of `site`
--   required  set to false to make the selector optional
--   active    activity state required of the site
-- @return the validated site code, or nothing when the selector is absent
function scan_site_selector(options)
	local opts = options or {}
	local keyword = opts.param or 'site'
	local optional = opts.required == false

	local _, code = scan_param(keyword, 'selector', optional)
	if not code then return end
	return validate_site(code, opts.active)
end


--- Consume an optional site selector and turn it into a filter.
-- @param options optional table as accepted by scan_site_selector();
-- its `required` field is set to false in place
-- @return {code=<site code>}, or nothing when no selector was given
function scan_site_filter(options)
	options = options or {}
	options.required = false

	local code = scan_site_selector(options)
	if not code then return end
	return {code=code}
end

--- Convert a site filter into a condition on a given column.
-- @param filter table with a `code` field, or nil
-- @param column column name, `site` when omitted
-- @return {[column]=<site code>}, or nothing when filter is nil
function site_filter(filter, column)
	if not filter then return end
	return {[column or 'site']=filter.code}
end


-- Consume a VPNc selector: `hub <id>` or `site <abbr> [vpnc <id>]`.
--
-- options is an optional table with the fields
--   multiple  the selector may match several VPNcs: the keyword `hubs`
--             selects all hubs, the VPNc number of a site may be left
--             out, and the whole selector is optional
--   id_attr   keyword introducing the VPNc number, `vpnc` by default
--   active    activity state required of the site and the VPNc
--
-- Returns a table with the fields site and id (site is the empty string
-- for hubs), or nothing if an optional selector is absent. If an active
-- state is requested, check_exists also adds an `active` field to the
-- returned table.
function scan_vpnc(options)
	options = options or {}

	-- note that the `hubs` keyword is left in the argument list
	if options.multiple and arg[1] == 'hubs' then return {site=''} end

	local k, v = scan_param({'hub', 'site'}, 'selector', options.multiple)
	if not k then return end

	local res, obj

	if k == 'hub' then
		res = {site='', id=v}
		obj = 'hub'
	else
		local _, vpnc = scan_param(
			options.id_attr or 'vpnc', 'selector', true
		)
		-- a single VPNc is required: default to the first one of the site
		if not (vpnc or options.multiple) then vpnc = 1 end

		res = {site=validate_site(v, options.active), id=vpnc}
		obj = 'VPNc'
	end

	-- without an ID, all VPNcs of the site are selected and there is
	-- no individual row to check
	if res.id then
		check_exists(
			'vpnc',
			res,
			'Invalid '..obj..' number: '..res.id,
			options.active
		)
	end

	return res
end

--- Convert a VPNc selector into a condition on given columns.
-- @param filter table with `site` and `id` fields, or nil
-- @param site_column column holding the site code, `site` when omitted
-- @param id_column column holding the VPNc number, `vpnc` when omitted
-- @return condition table, or nothing when filter is nil
function vpnc_filter(filter, site_column, id_column)
	if not filter then return end
	return {
		[site_column or 'site']=filter.site,
		[id_column or 'vpnc']=filter.id
	}
end


--- Verify that all command line arguments have been consumed.
-- Raises a syntax error naming the first surplus argument.
function scan_finished()
	if #arg == 0 then return end
	syntax_error('Invalid parameter: '..scan_next())
end


--- Translate the stored activity flag into a display string.
-- @param value '1' or '0'
-- @return 'yes', 'no', or nil for any other value
function display_active(value)
	if value == '1' then return 'yes' end
	if value == '0' then return 'no' end
	return nil
end

--- Produce a printable listing of a database table.
-- @param tbl table name
-- @param columns array of column names; also used as the header row
-- @param filter condition accepted by where()
-- @param mangle optional function called with each row before output
-- @return array of rows starting with the header; NULL values are shown
-- as empty strings
function show(tbl, columns, filter, mangle)
	local output = {columns}
	local column_list = table.concat(columns, ',')

	for row in select_many(column_list, tbl, filter, 'a') do
		if row.active then row.active = display_active(row.active) end
		if mangle then mangle(row) end

		local line = {}
		for i, column in ipairs(columns) do
			line[i] = row[column] or ''
		end
		output[#output + 1] = line
	end

	return output
end

--- Produce a printable listing of sites.
-- @param filter condition accepted by where()
-- @return array of rows starting with the header
function show_site(filter)
	local columns = {'code', 'name', 'asn', 'active'}
	return show('site', columns, filter)
end

-- Produce a printable listing of VPNcs with their GRE addresses.
--
-- site is an optional site filter ({code=<site code>}); without it, the
-- VPNcs of all sites with a non-empty code are listed together with a
-- site column. A filter with an empty code lists the hubs.
-- vpnc optionally restricts the listing to a single VPNc number; it is
-- only meaningful together with a site filter.
--
-- Returns an array of rows preceded by a header row, or an empty array
-- if nothing matches. Every row has a value in each column: NULL
-- columns are shown as empty strings.
function show_vpnc(site, vpnc)
	local filter = site and site_filter(site, 'v.site') or "v.site > ''"
	if vpnc then filter['v.id'] = vpnc end

	local columns = {
		'v.id', 'v.name', 'v.active', 'a1.address', 'a2.address'
	}
	local titles = {'vpnc', 'name', 'active','ipv4Address', 'ipv6Address'}
	-- index of the name column; the active column follows it
	local i = 2
	if not site then
		table.insert(columns, 1, 'v.site')
		table.insert(titles, 1, 'site')
		i = 3
	elseif site.code == '' then titles[1] = 'hub' end

	local output = {}

	for row in select_many(
		table.concat(columns, ', '),
		[[
			vpnc v
				LEFT JOIN greAddress a1
					ON v.site = a1.site AND
						v.id = a1.vpnc AND a1.family = 1
				LEFT JOIN greAddress a2
					ON v.site = a2.site AND
						v.id = a2.vpnc AND a2.family = 2
		]],
		filter,
		'n'
	) do
		row[i + 1] = display_active(row[i + 1])

		-- Replace every missing value, not just the name: a VPNc
		-- without an IPv4 address but with an IPv6 address would
		-- otherwise leave a hole in the row, making its length and
		-- its traversal by print_table unreliable.
		for column = 1, #columns do
			if not row[column] then row[column] = '' end
		end

		table.insert(output, row)
	end

	return add_header(output, titles)
end


--- Assign a subnet to a site.
-- @param site site code
-- @param addr network address in address/length notation; validated by
-- detect_prefix_afi(), which also supplies the address family
function add_subnet(site, addr)
	local family = detect_prefix_afi(addr)
	insert('subnet', {site=site, family=family, address=addr})
end


--- Deactivate one VPNc, or all VPNcs of a site.
-- Revokes the valid certificates, deletes the GRE addresses and marks
-- the matching VPNc rows inactive.
-- @param site site code
-- @param id VPNc number; all VPNcs of the site when omitted
-- @return printable table of the revoked certificates
function deactivate_vpnc(site, id)
	local target = {site=site, id=id}
	local by_vpnc = vpnc_filter(target)

	local revoked = revoke(by_vpnc)
	delete('greAddress', by_vpnc)
	update('vpnc', {active='0'}, target)

	return format_cert_info(revoked)
end


--- Create a site.
-- @param attrs column values of the new row; a missing AS number is
-- filled in from the `default-asn` setting (attrs is modified in place)
function add_site(attrs)
	attrs.asn = attrs.asn or config['default-asn']
	insert('site', attrs)
end


--- Fetch the certificate with a given serial number.
-- Also verifies that no command line arguments are left.
-- @param serial serial number as a string or number
-- @return certificate row
-- Raises a user error for a malformed or unknown serial number.
function get_cert_by_serial(serial)
	scan_finished()

	local number = toint(serial, 0, nil, 'serial number')
	local cert = select_cert(number)
	if cert then return cert end

	user_error('Invalid serial number')
end

-- Return an iterator over the certificates selected on the command line.
--
-- Supported selectors:
--   serial <num>   the single certificate with that serial number; the
--                  iterator yields the row followed by the string
--                  'serial', letting callers tell an explicitly named
--                  certificate from a bulk selection
--   crl            certificates that are not bound to a site, excluding
--                  serial 0
--   VPNc selector  as accepted by scan_vpnc{multiple=true}
--   (none)         all certificates except serial 0
function get_certs()
	if arg[1] == 'serial' then
		local _, serial = scan_param('serial', 'serial number')
		-- Stateless iterator: the certificate is the invariant state
		-- and is yielded once, while the control variable is still nil.
		-- The second value was missing before, so the check for
		-- selector == 'serial' in `cert revoke` could never match.
		return function(cert, v)
			if not v then return cert, 'serial' end
		end, get_cert_by_serial(serial)
	end

	if arg[1] == 'crl' then
		return select_certs('site IS NULL AND serial > 0')
	end

	return select_certs(
		vpnc_filter(scan_vpnc{multiple=true}) or 'serial > 0'
	)
end


--- Load the certificate and key used for signing CRLs, if available.
-- Without key encryption the CA certificate is used. Otherwise the
-- valid certificate with the highest serial number among those not bound
-- to a site (excluding serial 0) is selected and cached in the globals
-- `crl_cert` and `crl_key`.
-- @return certificate object, key object; nil values if none is valid
function try_load_crl_cert()
	if not config.db['encrypt-keys'] then
		return table.unpack{load_ca_cert()}
	end

	if crl_cert then return crl_cert, crl_key end

	local candidates = select_certs(
		'site IS NULL and serial > 0', 'serial DESC'
	)
	-- the loop always runs to completion so that the cursor is exhausted
	for row in candidates do
		if not crl_cert then
			local cert, key = load_cert(row)
			if is_valid(row) then
				crl_cert, crl_key = cert, key
			end
		end
	end
	return crl_cert, crl_key
end

--- Load the certificate and key used for signing CRLs.
-- @return certificate object, key object
-- Raises a user error if no valid CRL signing certificate exists.
function load_crl_cert()
	local cert, key = try_load_crl_cert()
	if cert then return cert, key end
	user_error('No valid CRL signing key found')
end

-- Generate a new CRL, store it in the database in place of the previous
-- one, and return the CRL object.
function generate_crl()
	local crl = x509crl.new()
	crl:setVersion(2)

	-- the CRL number is one higher than that of the stored CRL, or 1 if
	-- no CRL has been stored
	local old_serial = select_one('serial', 'crl')
	local new_serial = (old_serial or 0) + 1

	crl:addExtension(
		x509ext.new(
			'crlNumber', 'DER', rfc5280.CRLNumber.encode(new_serial)
		)
	)

	-- the next update is due one configured CRL lifetime after the last
	-- update time reported by the new CRL object
	local timestamp = crl:getLastUpdate()
	local expires = timestamp + config.crl.lifetime
	crl:setNextUpdate(expires)

	-- list the revoked certificates that have not expired yet
	for cert in select_certs() do
		if cert.expires > timestamp and cert.revoked then
			crl:add(cert.serial, cert.revoked)
		end
	end

	sign(crl, config.crl['hash-alg'], table.unpack{load_crl_cert()})

	-- store the new CRL before removing the one it replaces
	insert('crl', {serial=new_serial, expires=expires, data=tostring(crl)})
	if old_serial then delete('crl', {serial=old_serial}) end

	return crl
end

--- Return the current CRL, regenerating it when necessary.
-- The stored CRL is reused as long as its expiry time is more than the
-- configured renewal period away; otherwise a new one is generated.
-- @return CRL object
function get_crl()
	local row = select_one('expires, data', 'crl', nil, 'n')

	if row and now < row[1] - config.crl.renewal then
		local stored = x509crl.new(row[2])
		if stored then return stored end
	end
	return generate_crl()
end


--- Print a table as left-aligned columns on standard output.
-- Columns are separated by two spaces and padded to the width of their
-- widest cell; the last cell of a row is not padded.
-- @param tbl array of rows, each an array of strings or numbers
function print_table(tbl)
	-- first pass: determine the width of each column
	local widths = {}
	for _, row in ipairs(tbl) do
		for i, cell in ipairs(row) do
			local length = string.len(cell)
			if not widths[i] or length > widths[i] then
				widths[i] = length
			end
		end
	end

	-- second pass: write the padded cells
	for _, row in ipairs(tbl) do
		local last = #row
		for i = 1, last do
			local cell = row[i]
			if i > 1 then io.write('  ') end
			io.write(cell)
			if i < last then
				io.write(
					string.rep(' ', widths[i] - string.len(cell))
				)
			end
		end
		io.write('\n')
	end
end

--- Ask the user for confirmation when running interactively.
-- Nothing happens unless standard input is a terminal. Otherwise func is
-- called to describe the pending action, and the program exits unless
-- the user answers `y` (in either letter case).
-- @param func function printing the description of the action
function simple_confirm(func)
	if not unistd.isatty(0) then return end

	func()
	io.write('\nContinue (y/n)? ')
	io.stdout:flush()

	local answer = io.read()
	local accepted = answer and answer:lower() == 'y'
	if not accepted then os.exit() end

	io.write('\n')
end

--- Show the objects affected by an action and ask for confirmation.
-- @param object plural or singular noun describing the objects
-- @param action past participle describing the action
-- @param tbl printable table of the affected objects
-- Raises a user error if the table is empty.
function confirm(object, action, tbl)
	if #tbl == 0 then user_error('No such '..object..' to be '..action) end

	local function describe()
		io.write('The following '..object..' will be '..action..':\n\n')
		print_table(tbl)
	end
	simple_confirm(describe)
end


-- Syntax description shared by the cert list, show and revoke commands.
cert_select_syntax = '[serial <num>|hubs|hub <id>|site <abbr> [vpnc <id>]|crl]'

commands = {
	-- Site management commands. Each entry is a pair consisting of the
	-- syntax description of the arguments and the handler function.
	site={
		-- create a site with the default AS number
		add={
			'<abbr>',
			function()
				local site = scan_site_code()
				scan_finished()

				add_site{code=site}
			end
		},
		-- set the AS number or the name of a site
		set={
			'{asn|name} <value> <abbr>',
			function()
				local k, v = scan_param(
					{'asn', 'name'}, 'attribute'
				)
				local site = scan_site()
				scan_finished()

				if k == 'asn' then
					-- The upper bound is the largest 16-bit
					-- value. It is written as a literal since
					-- math.pow is a deprecated function that
					-- is missing from Lua 5.4 builds without
					-- the 5.3 compatibility option.
					v = toint(v, 0, 0xFFFF, 'AS number')
				end

				update('site', {[k]=v}, {code=site})
			end
		},
		-- clear the name of a site (false is stored as NULL)
		unset={
			'name <abbr>',
			function()
				scan_name_token()
				local site = scan_site()
				scan_finished()

				update('site', {name=false}, {code=site})
			end
		},
		-- list a single site or all sites with a non-empty code
		show={
			'[code <abbr>]',
			function()
				local site = scan_site_filter{param='code'}
				scan_finished()

				return show_site(site or "code > ''")
			end
		},
		-- deactivate an active site after confirmation: its subnets
		-- are deleted and all its VPNcs are deactivated; returns the
		-- listing of the revoked certificates
		deactivate={
			'<abbr>',
			function()
				local site = scan_site(true)
				scan_finished()
				confirm(
					'site',
					'deactivated',
					show_site{code=site}
				)

				delete('subnet', {site=site})
				local output = deactivate_vpnc(site)
				update('site', {active='0'}, {code=site})
				return output
			end
		},
		-- mark an inactive site active again
		reactivate={
			'<abbr>',
			function()
				local site = scan_site(false)
				scan_finished()

				update('site', {active='1'}, {code=site})
			end
		}
	},
	-- Commands managing the subnets assigned to sites.
	subnet={
		-- assign a subnet to an active site
		add={
			'<addr> site <abbr>',
			function()
				local addr = scan_subnet()
				local site = scan_site_selector{active=true}
				scan_finished()

				add_subnet(site, addr)
			end
		},
		-- list the subnets of one active site or of all sites with a
		-- non-empty code
		show={
			'[site <abbr>]',
			function()
				local site = scan_site_filter{active=true}
				scan_finished()

				return show(
					'subnet',
					{'site', 'address'},
					site and site_filter(site) or
						"site > ''"
				)
			end
		},
		-- remove a subnet, optionally only from the given site
		del={
			'<addr> [site <abbr>]',
			function()
				local addr = scan_subnet()
				local filter = site_filter(
					scan_site_filter{active=true}
				)
				scan_finished()

				if not filter then filter = {} end
				filter.address = addr
				check_exists(
					'subnet',
					filter,
					'Address does not exist: '..addr
				)

				delete('subnet', filter)
			end
		}
	},
	-- Commands managing the VPN concentrators (VPNcs) of sites.
	vpnc={
		-- create a VPNc for an active site; without an explicit
		-- number the next free number of the site is used
		create={
			'site <abbr> [id <num>]',
			function()
				local site = scan_site_selector{active=true}
				local _, id = scan_param('id', 'VPNc ID', true)
				scan_finished()

				insert(
					'vpnc',
					{
						site=site,
						id=id and toid(id) or
							next_key(
								'vpnc',
								'id',
								{site=site}
							)
					}
				)
			end
		},
		-- set the name of a VPNc
		set={
			'name <value> site <abbr> id <num>',
			function()
				local _, name = scan_param('name', 'attribute')
				local vpnc = scan_vpnc{id_attr='id'}
				scan_finished()

				update('vpnc', {name=name}, vpnc)
			end
		},
		-- clear the name of a VPNc (false is stored as NULL)
		unset={
			'name site <abbr> id <num>',
			function()
				scan_name_token()
				local vpnc = scan_vpnc{id_attr='id'}
				scan_finished()

				update('vpnc', {name=false}, vpnc)
			end
		},
		-- list the VPNcs of one site or of all sites
		show={
			'[site <abbr>]',
			function()
				local site = scan_site_filter()
				scan_finished()

				return show_vpnc(site)
			end
		},
		-- deactivate a VPNc of an active site after confirmation;
		-- returns the listing of the revoked certificates
		deactivate={
			'<id> site <abbr>',
			function()
				local id = toid(scan_next('VPNc number'))
				local site = scan_site_selector{active=true}
				scan_finished()
				confirm(
					'VPNc',
					'deactivated',
					show_vpnc({code=site}, id)
				)

				return deactivate_vpnc(site, id)
			end
		},
		-- mark an inactive VPNc of an active site active again
		reactivate={
			'<id> site <abbr>',
			function()
				local id = toid(scan_next('VPNc number'))
				local site = scan_site_selector{active=true}
				scan_finished()

				local filter = {site=site, id=id}
				-- check_exists adds active = '0' to the filter,
				-- which is then reused for the update
				check_exists(
					'vpnc',
					filter,
					'Invalid VPNc number: '..id,
					false
				)
				update('vpnc', {active='1'}, filter)
			end
		}
	},
	-- Commands managing the GRE addresses of hubs and VPNcs.
	['gre-addr']={
		-- assign an address to an active hub or VPNc
		add={
			'<addr> {hub <id>|site <abbr> [vpnc <id>]}',
			function()
				local addr = scan_addr()
				local vpnc = scan_vpnc{active=true}
				scan_finished()

				insert(
					'greAddress',
					{
						site=vpnc.site,
						vpnc=vpnc.id,
						family=detect_afi(addr),
						address=addr
					}
				)
			end
		},
		-- remove an address; no check is made that it exists
		del={
			'<addr>',
			function()
				local addr = scan_addr()
				scan_finished()

				delete('greAddress', {address=addr})
			end
		}
	},
	-- Commands managing hubs, which are stored as the VPNcs of the
	-- site with the empty code.
	hub={
		-- create a hub; without an explicit number the next free
		-- hub number is used
		create={
			'[<id>]',
			function()
				local id = scan_next()
				scan_finished()

				insert(
					'vpnc',
					{
						site='',
						id=id and toid(id) or
							next_key(
								'vpnc',
								'id',
								{site=''}
							)
					}
				)
			end
		},
		-- set the name of a hub
		set={
			'name <value> <id>',
			function()
				local _, name = scan_param('name', 'attribute')
				local id = scan_hub()
				scan_finished()

				update('vpnc', {name=name}, {site='', id=id})
			end
		},
		-- clear the name of a hub (false is stored as NULL)
		unset={
			'name <id>',
			function()
				scan_name_token()
				local id = scan_hub()
				scan_finished()

				update('vpnc', {name=false}, {site='', id=id})
			end
		},
		-- list all hubs
		show={
			'',
			function()
				scan_finished()

				return show_vpnc{code=''}
			end
		},
		-- deactivate an active hub after confirmation; returns the
		-- listing of the revoked certificates
		deactivate={
			'<id>',
			function()
				local id = scan_hub(true)
				scan_finished()
				confirm(
					'hub',
					'deactivated',
					show_vpnc({code=''}, id)
				)

				return deactivate_vpnc('', id)
			end
		},
		-- mark an inactive hub active again
		reactivate={
			'<id>',
			function()
				local id = scan_hub(false)
				scan_finished()

				update('vpnc', {active='1'}, {site='', id=id})
			end
		}
	},
	-- Commands managing the root certificate.
	['root-cert']={
		-- initialize a fresh database and issue the root certificate
		generate={
			'',
			function()
				scan_finished()

				-- warn before destroying an existing database
				if stat.stat(config.db.file) then
					simple_confirm(
						function()
							io.write('Current configuration and all keys will be lost.\n')
						end
					)
				end

				-- start from scratch and create the schema
				os.remove(config.db.file)
				for _, statement in ipairs(
					{
						[[
							CREATE TABLE site (
								code VARCHAR(16) NOT NULL PRIMARY KEY,
								asn INTEGER NOT NULL,
								name VARCHAR(32),
								active CHAR(1) DEFAULT '1'
							)
						]],
						[[
							CREATE TABLE subnet (
								site VARCHAR(16) NOT NULL REFERENCES site(code),
								family INTEGER NOT NULL,
								address CHAR(14) NOT NULL,
								PRIMARY KEY(site, family, address)
							)
						]],
						[[
							CREATE TABLE vpnc (
								site VARCHAR(16) NOT NULL REFERENCES site(code),
								id INTEGER NOT NULL,
								name VARCHAR(32),
								active CHAR(1) DEFAULT '1',
								PRIMARY KEY(site, id)
							)
						]],
						[[
							CREATE TABLE greAddress (
								site VARCHAR(16) NOT NULL,
								vpnc INTEGER NOT NULL,
								family INTEGER NOT NULL,
								address CHAR(14) NOT NULL,
								PRIMARY KEY(site, vpnc, family),
								FOREIGN KEY(site, vpnc) REFERENCES vpnc(site, id),
								UNIQUE(family, address)
							)
						]],
						[[
							CREATE TABLE certificate (
								serial INTEGER NOT NULL PRIMARY KEY,
								site VARCHAR(16),
								vpnc INTEGER,
								dn VARCHAR(128) NOT NULL,
								issued DATETIME NOT NULL,
								expires DATETIME NOT NULL,
								revoked DATETIME,
								privateKey TEXT NOT NULL,
								data TEXT NOT NULL,
								FOREIGN KEY(site, vpnc) REFERENCES vpnc(site, id)
							)
						]],
						[[
							CREATE TABLE crl (
								serial INTEGER NOT NULL PRIMARY KEY,
								expires DATETIME NOT NULL,
								data TEXT NOT NULL
							)
						]]
					}
				) do execute(statement) end

				-- the site with the empty code holds the hubs
				-- and their subnets
				add_site{code=''}
				for _, subnet in ipairs(config.hub.subnets) do
					add_subnet('', subnet)
				end

				-- the root certificate signs CRLs itself only
				-- when key encryption is not in use
				issue_ca_cert{
					keyCertSign=true,
					cRLSign=not config.db['encrypt-keys']
				}
			end
		},
		-- print the root certificate (serial 0) in textual form
		show={
			'',
			function()
				scan_finished()
				print_cert(get_cert_by_serial(0))
			end
		}
	},
	-- Commands managing the certificates of hubs and VPNcs and the CRL
	-- signing certificate.
	cert={
		-- Issue certificates. With a VPNc selector, a certificate is
		-- issued for every selected active VPNc. Without a selector,
		-- certificates are issued for the active VPNcs lacking a
		-- certificate that stays valid beyond the renewal period, and
		-- a CRL signing certificate is issued if no valid one can be
		-- loaded. The `crl` keyword limits the operation to the CRL
		-- signing certificate. Each VPNc certificate is exported as a
		-- PKCS#12 file. Returns the listing of issued certificates.
		generate={
			'[hubs|hub <id>|site <abbr> [vpnc <id>]|crl]',
			function()
				local crl = arg[1] == 'crl'
				local vpnc = not crl and vpnc_filter(
					scan_vpnc{active=true, multiple=true},
					's.code',
					'v.id'
				)

				-- DNs of the certificates to be issued, shown in
				-- the confirmation prompt
				local dns = {}
				local gen_crl_cert
				if not vpnc then
					gen_crl_cert = not try_load_crl_cert()
					if gen_crl_cert then
						table.insert(
							dns, {config.ca.dn}
						)
					end
				end

				-- attribute tables of the VPNc certificates to
				-- be issued
				local subjects = {}

				if not crl then
					local filter = vpnc or {}
					filter['s.active'] = '1'
					filter['v.active'] = '1'

					for row in select_many(
						's.code, v.id, s.asn, s.name, v.name',
						'site s INNER JOIN vpnc v ON s.code = v.site',
						filter,
						'n'
					) do
						local attrs = {
							site=row[1],
							vpnc=row[2],
							asn=row[3],
							sname=row[4],
							vname=row[5]
						}

						-- hubs belong to the site with the
						-- empty code
						attrs.params = config[
							attrs.site == '' and
								'hub' or 'spoke'
						]

						-- expand the DN template and queue the
						-- subject (this local function shadows
						-- the global insert within this scope)
						local function insert()
							attrs.dn = attrs.params.dn:gsub(
								'%$(%u+)',
								{
									ROOT=config.ca.dn,
									SITE=attrs.sname or
										attrs.site,
									NAME=attrs.vname or
										attrs.params[
											'default-name'
										]:gsub(
											'$ID',
											attrs.vpnc
										)
								}
							)

							table.insert(
								subjects, attrs
							)
							table.insert(
								dns,
								{tostring(attrs.dn)}
							)
						end

						-- explicitly selected VPNcs always
						-- get a new certificate
						if vpnc then insert()

						else
							local valid
							for cert in select_certs{
								site=row[1],
								vpnc=row[2]
							} do
								if is_valid(
									cert,
									attrs.params.renewal
								) then
									valid = true
								end
							end
							if not valid then
								insert()
							end
						end
					end
				end

				if #dns == 0 then return end
				confirm('certificates', 'issued', dns)

				local issued = {}

				if gen_crl_cert then
					table.insert(
						issued,
						issue_ca_cert{cRLSign=true}
					)
				end

				for _, attrs in ipairs(subjects) do
					-- asn is not a column of the certificate
					-- table; keep it for the extension only
					local asn = attrs.asn
					attrs.asn = nil

					local cert = issue_cert(
						attrs,
						function(cert, attrs)

							local function get_subnets()
								return select_many(
									'family, address',
									'subnet',
									{site=attrs.site},
									'a'
								)
							end

							local function get_gre_addrs()
								return select_many(
									'family, address',
									'greAddress',
									{site=attrs.site, vpnc=attrs.vpnc},
									'a'
								)
							end

							-- flag telling whether the
							-- certificate belongs to a hub
							cert:addExtension(
								x509ext.new(
									dmvpn.OID_IS_HUB,
									'critical,DER',
									asn1.boolean.encode(attrs.site == '')
								)
							)

							-- configured hub host names
							local hosts = config.hub.hosts
							if hosts then
								cert:addExtension(
									x509ext.new(
										dmvpn.OID_HUB_HOSTS,
										'DER',
										asn1.sequence_of(asn1.ia5string).encode(hosts)
									)
								)
							end

							-- subnets of the site grouped
							-- by address family
							local net_config = {}
							local pr_config = {}
							for subnet in get_subnets() do
								local f = subnet.family
								if not pr_config[f] then
									pr_config[f] = {}
									table.insert(
										net_config,
										{
											addressFamily={afi=f},
											ipAddressChoice={
												addressesOrRanges=pr_config[f]
											}
										}
									)
								end
								table.insert(
									pr_config[f],
									{addressPrefix=subnet.address}
								)
							end
							if net_config[1] then
								cert:addExtension(
									x509ext.new(
										'sbgp-ipAddrBlock',
										'critical,DER',
										rfc3779.IPAddrBlocks.encode(net_config)
									)
								)
							end

							-- GRE addresses as subject
							-- alternative names
							local san
							for ga in get_gre_addrs() do
								if not san then
									san = x509an.new()
								end
								san:add(
									'IP',
									ga.address
								)
							end
							if san then
								cert:setSubjectAlt(san)
							end

							-- AS number of the site
							cert:addExtension(
								x509ext.new(
									'sbgp-autonomousSysNum',
									'critical,DER',
									rfc3779.ASIdentifiers.encode{
										asnum={asIdsOrRanges={{id=asn}}}
									}
								)
							)
						end
					)
					export_cert(cert)
					table.insert(issued, cert)
				end

				return format_cert_info(issued)
			end
		},
		-- list the selected certificates
		list={
			cert_select_syntax,
			function()
				local certs = {}
				for cert in get_certs() do
					table.insert(certs, cert)
				end
				return format_cert_info(certs)
			end
		},
		-- print the selected certificates in textual form
		show={
			cert_select_syntax,
			function()
				for cert in get_certs() do print_cert(cert) end
			end
		},
		-- revoke the valid certificates among the selected ones
		-- after confirmation; returns the listing of the revoked
		-- certificates
		revoke={
			cert_select_syntax,
			function()
				local certs = {}
				for cert, selector in get_certs() do
					local valid = is_valid(cert)
					-- an explicitly named certificate that
					-- is no longer valid is an error
					if selector == 'serial' and
						not valid then

						user_error('Certificate already expired or revoked')
					end
					if valid then
						table.insert(certs, cert)
					end
				end
				confirm(
					'certificates',
					'revoked',
					format_cert_info(certs)
				)

				local revoked = {}
				for _, cert in ipairs(certs) do
					table.insert(
						revoked,
						revoke{serial=cert.serial}[1]
					)
				end
				return format_cert_info(revoked)
			end
		},
		-- export a VPNc certificate as a PKCS#12 file; certificates
		-- that are not bound to a site cannot be exported
		export={
			'serial <num>',
			function()
				local _, serial = scan_param(
					'serial', 'selector'
				)
				local cert = get_cert_by_serial(serial)
				if not cert.site then
					user_error(
						'Cannot export CA certificate'
					)
				end
				export_cert(cert)
			end
		}
	},
	-- Commands managing the certificate revocation list.
	crl={
		-- generate a new CRL unconditionally and write its string
		-- form to standard output
		generate={
			'',
			function()
				scan_finished()
				io.write(tostring(generate_crl()))
			end
		},
		-- print the current CRL in textual form, regenerating it
		-- first if it is due for renewal
		show={
			'',
			function()
				scan_finished()
				io.write(get_crl():text())
			end
		},
		-- write the string form of the current CRL to standard
		-- output, regenerating it first if it is due for renewal
		export={
			'',
			function()
				scan_finished()
				io.write(tostring(get_crl()))
			end
		}
	},
	-- Commands managing the password protecting the private keys.
	password={
		-- re-encrypt the private keys of the root certificate and
		-- of all site-bound certificates; the keys are decrypted
		-- with the default password and encrypted with the password
		-- cached under the ID 'new'
		set={
			'',
			function()
				for row in select_many(
					'serial, privateKey',
					'certificate',
					'serial = 0 OR site IS NOT NULL',
					'n'
				) do
					update(
						'certificate',
						{
							privateKey=encrypt_key(
								decrypt_key(row[2]),
								true,
								'new'
							)
						},
						{serial=row[1]}
					)
				end
			end
		}
	}
}

--- Look up the command given on the command line.
-- Called without an argument, the function consumes the object type and
-- recurses to consume the command name.
-- @param obj_type object type already consumed, nil on the first call
-- @return the full syntax description and the handler of the command;
-- nothing if the object type or command is unknown, in which case the
-- available alternatives are printed
function scan_command(obj_type)
	local prefix = 'dmvpn-ca '
	local choices = commands
	if obj_type then
		prefix = prefix..obj_type..' '
		choices = commands[obj_type] or commands
	end

	local cmd = scan_next()
	local entry = choices[cmd]
	if entry then
		if not obj_type then return scan_command(cmd) end
		return prefix..cmd..' '..entry[1], entry[2]
	end

	print('Available commands:')
	for name, spec in pairs(choices) do
		local line = '\t'..prefix..name
		if obj_type then line = line..' '..spec[1] end
		print(line)
	end
end

-- Main program: look up the command and run its handler.
syntax, func = scan_command()

if syntax then
	status, output = xpcall(
		func,
		-- Convert the error value into a printable message. Plain
		-- string errors are unexpected and get a traceback. Table
		-- errors come from user_error and syntax_error; their first
		-- element tells whether to append the command syntax.
		function(err)
			if type(err) == 'string' then
				return err..'\n'..debug.traceback()
			end
			return err[2]..(err[1] and '\n\nSyntax: '..syntax or '')
		end
	)
	if status then
		if output then print_table(output) end
		-- changes are committed only if the handler succeeded
		if conn then conn:commit() end
	else print(output) end
end

if sql then
	if conn then
		-- run pending finalizers so that cursors left open by
		-- select_many are closed before the connection
		collectgarbage()
		conn:close()
	end
	sql:close()
end

-- the exit status is 1 both on errors and when no command was found
os.exit(status and 0 or 1)
