#!/usr/bin/python
# Generate SSHFP DNS records (RFC4255) from knownhosts files or ssh-keyscan
# Copyright by Xelerance http://www.xelerance.com/
# Paul Wouters <paul@xelerance.com> and Jake Appelbaum <jacob@appelbaum.net>
# License: GNU GENERAL PUBLIC LICENSE Version 2

import base64
import commands
import getopt
import os
import sha
import subprocess
import sys
import time
# www.dnspython.org
try:
	import dns.resolver
	import dns.query
	import dns.zone
except:
	print "sshfp requires the python-dns package from http://www.pythondns.org/"
	print "Fedora: yum install python-dns"
	print "Debian: apt-get install python-dnspython   (NOT python-dns!)"
	sys.exit()

# Module-wide settings.  main() (re)assigns these from the command line; the
# defaults below mirror main()'s so the helper functions can also be used
# without going through main().  (A "global" statement at module level does
# not define anything, so the names have to be assigned here.)
all_hosts = 0		# 1: emit records for every host, not just the wanted ones
khfile = ""
version = "1.1.3"
hostnames = ()
trailing = 0		# 1: append a trailing "." to multi-label owner names
nameserver = ""
quiet = 0
port = 22
timeout = "5"		# ssh-keyscan -T value, kept as a string
algo = "dsa,rsa"	# ssh-keyscan -t value; "dsa,rsa" means no key type filter

def usage():

	print "usage: sshfp -k [-d] [-t algo] [-T timeout] [ knownhosts_file ] [-o output] [ [-a] | host1 [host2 ... ] ]"
	print "       sshfp -s [-d] [-t algo] [-T timeout] [ [-p port] [-a domainname] [-o output] | host1 [host2 .. ] ] [@ns]"
	print "examples:"
	print "          sshfp -k www.xelerance.com"
	print "          sshfp -k ~paul/.ssh/known_hosts -a "
	print "          sshfp -s www.openswan.org"
	print "          sshfp -s -p 2222 xelerance.com"
	print "          sshfp -s -T 30 -t rsa -d -a -o sshfp.txt xelerance.com @ns0.xelerance.net"

def create_sshfp(hostname,keytype,keyblob):
	global trailing
	if keytype == "ssh-rsa":
		keytype = "1"
	else:
		if keytype == "ssh-dss":
			keytype = "2"
		else:
			return ""
	try:
		rawkey = base64.b64decode(keyblob)
	except TypeError:
		print "FAILED on hostname "+hostname+" with keyblob "+keyblob
		return "ERROR"
	fpsha1 = sha.new(rawkey).hexdigest()
	# check for Reverse entries
	reverse = 1
	parts = hostname.split(".",3)
	if parts[0] != hostname:
		for octet in parts:
			if not octet.isdigit():
				reverse = 0
		if reverse:
			hostname = parts[3] + "." + parts[2] + "." + parts[1] + "." + parts[0] + ".in-addr.arpa."
	# we don't know wether we need a trailing dot :(
	# eg if someone did "ssh ns.foo" we don't know if this really is "ns.foo." or "ns.foo" plus resolv.conf domainname
	if trailing and not reverse:
		if hostname[-1:] != ".":
			hostname = hostname + "."
	return hostname + " IN SSHFP " + keytype + " 1 " + fpsha1

def sshfpFromFile(khfile,wantedHosts):
	# ok, let's do it
	known_hosts = os.path.expanduser(khfile)
	try:
		khfp = open(known_hosts)
	except IOError:
		print "Failed to open file "+ known_hosts
		sys.exit()
	entries = khfp.read()
	khfp.close()
	return processRaw(entries,wantedHosts)

def processRaw(entries,wantedHosts):
	"""Convert known_hosts-format text into sorted SSHFP record lines.

	entries     -- text in known_hosts / ssh-keyscan output format, one key per
	               line: "host[,host...] keytype base64key [comment]"
	wantedHosts -- hosts to emit records for: a list of names, or a
	               space-separated string (then also matched as a substring);
	               ignored when the global all_hosts is set

	The global algo filters on key type ("dsa,rsa" means no filter).  Blank
	lines, "#" comments, hashed "|..." entries and lines with fewer than three
	fields are skipped; duplicate records are emitted once.
	Returns the records joined by newlines plus a final newline, or 0 when no
	record was produced.
	"""
	global algo
	global all_hosts
	allrecords = []
	seen = set()
	for line in entries.split("\n"):
		line = line.strip()
		if line == "" or line[0] == "#" or line[0] == "|":
			continue
		fields = line.split()
		if len(fields) < 3:
			# need at least "hosts keytype key"
			continue
		keytype = fields[1]
		key64blob = fields[2]
		# note that ssh-keyscan and known_hosts don't use the same string
		# for all algos, eg ssh-dss vs -t dsa, match on all but last char
		if (algo != "dsa,rsa") and keytype[:-1] != "ssh-%s"%algo[:-1]:
			# skip just this key; later lines may still match the filter
			continue
		for hostname in fields[0].split(","):
			# keys for a non-standard port are listed as "[host]:port"; SSHFP
			# owner names carry no port, so use the bare host name
			if hostname[:1] == "[" and "]:" in hostname:
				hostname = hostname[1:hostname.index("]:")]
			if not hostname:
				continue
			if not (all_hosts or (hostname in wantedHosts) or (hostname == wantedHosts)):
				continue
			try:
				record = create_sshfp(hostname,keytype,key64blob)
			except IndexError:
				# a malformed host name must not abort the whole run
				continue
			# "" means unsupported key type, "ERROR" an undecodable key blob;
			# neither belongs in the generated zone data
			if record and record != "ERROR" and record not in seen:
				seen.add(record)
				allrecords.append(record)
	if allrecords:
		allrecords.sort()
		# join records, dnssigner wants a newline at end of file, so add one.
		return "\n".join(allrecords)+"\n"
	return 0

def getRecord(domain,type):
	try:
		answers = dns.resolver.query(domain, type)
	except dns.resolver.NXDOMAIN:
		#print "NXdomain: "+domain
		return 0
	except dns.resolver.NoAnswer:
		#print "NoAnswer: "+domain
		return 0
	for rdata in answers:
		# just return first entry we got, answers[0].target does not work
		if type == "A":
			return rdata
        	if type == "NS":
			return str(rdata.target)
		else:
			print "error in getRecord, unknown type "+type
			sys.exit()

def getAXFRrecord(domain,ns):
	"""Zone-transfer domain from server ns and return the zone from dns.zone.from_xfr.

	A dns.exception.FormError raised during the transfer is replaced by a new
	FormError whose argument is the domain name, so callers can report it.
	"""
	try:
		transfer = dns.query.xfr(ns,domain)
		zone = dns.zone.from_xfr(transfer)
	except dns.exception.FormError:
		raise dns.exception.FormError(domain)
	return zone

def sshfpFromAXFR(domain,nameserver):
	"""Generate SSHFP records for every host with an A record in a zone.

	domain     -- zone name; one trailing space is dropped, any other space
	              is rejected (prints an error and exits)
	nameserver -- server to transfer the zone from; when empty, the first NS
	              record of domain is used

	Without an NS record, domain itself is scanned as a host instead.  Owner
	names "localhost" are skipped and each name is scanned once.
	Returns sshfpFromDNS()'s result, or 0 when the zone has no usable A records.
	"""
	if domain[-1:] == " ":
		domain = domain[:-1]
	if " " in domain:
		print "error: space in domain '"+domain+"' can't be right, aborted"
		sys.exit()
	if not nameserver:
		nameserver = getRecord(domain,"NS")
		if not nameserver:
			print "warning: no NS record found for domain "+domain+". trying as host record instead"
			# better than nothing
			return sshfpFromDNS(domain)
	try:
		axfr = getAXFRrecord(domain,nameserver)
	except dns.exception.FormError, badDomain:
		# badDomain is an exception object; str() it before concatenating
		print "AXFR error: " + nameserver + " - No permission or not authorative for " + str(badDomain) + "; aborting"
		sys.exit()

	# absolute zone origin, whether or not the user typed the trailing dot
	if domain[-1:] == ".":
		origin = domain
	else:
		origin = domain + "."
	hosts = []
	for (name, ttl, rdata) in axfr.iterate_rdatas('A'):
		name = str(name)
		if "@" in name:
			fqdn = origin
		elif name == "localhost":
			continue
		else:
			fqdn = name + "." + origin
		# a name with several A records is still scanned only once
		if fqdn not in hosts:
			hosts.append(fqdn)
	if not hosts:
		print "warning: no A records found in zone "+domain
		return 0
	return sshfpFromDNS(" ".join(hosts))

def sshfpFromDNS(hosts):
	global quiet
	global port
	if hosts[-1] == " ":
		hosts = hosts[:-1]
	global timeout
	global algo
	cmd = "ssh-keyscan -p %s -T %s -t %s %s" % (port, timeout, algo, hosts)
	if quiet:
		cmd = cmd + " 2>/dev/null"
	tochild, fromchild, childerror = os.popen3(cmd, 'r')
        err = childerror.readlines()
        khdns = "\n".join(fromchild.readlines())
        for e in err:
                if e[0] != "#":
                        print >>sys.stderr, e
	return processRaw(khdns,hosts)
	
def main(argv=None):
	"""Command-line entry point.

	Parses argv (sys.argv when None).  Then it either reads a known_hosts file
	(-k, also the default when no hosts are given) or scans hosts with
	ssh-keyscan (-s, the default when hosts are given).  With -s -a, every A
	record of the named zone is scanned, using a zone transfer from "@server"
	or from the zone's NS.  The records are printed to stdout, or written to
	the file given with -o.  Errors are reported and end the program through
	sys.exit().
	"""
	global all_hosts
	global trailing
	global nameserver
	global quiet
	global port
	global timeout
	global algo

	if argv is None:
		argv = sys.argv
	try:
		# gnu_getopt lets options follow host names.  -a is a plain flag; the
		# zone for "-s -a" is taken from the host arguments, as in usage().
		# Long options taking a value must be declared with "=", not ":".
		opts, args = getopt.gnu_getopt(argv[1:], "qhdvsaT:t:o:k:p:", ["quiet","help","trailing-dot","version","scan","timeout=","type=","all","output=","knownhosts=", "port="])
	except getopt.error, msg:
		print >>sys.stderr, "ERROR parsing options: " + str(msg)
		usage()
		sys.exit(2)

	# parse options
	default_khfile = "~/.ssh/known_hosts"
	khfile = ""
	dodns = 0
	dofile = 0
	nameserver = ""
	output = ""
	quiet = 0
	version = "1.1.3"
	data = ""
	trailing = 0
	timeout = "5"
	algo = "dsa,rsa"
	all_hosts = 0
	port = 22

	for o, a in opts:
		if o in ("-v", "--version"):
			print "sshfp version: "+version
			print "Authors:\n Paul Wouters <paul@xelerance.com>\n Jake Appelbaum <jacob@appelbaum.net>"
			print "Source : http://www.xelerance.com/software/sshfp/"
			sys.exit()
		if o in ("-h", "--help"):
			usage()
			sys.exit()
		if o in ("-d", "--trailing-dot"):
			trailing = 1
		if o in ("-T", "--timeout"):
			if not a:
				print "error: no timeout specified"
				sys.exit()
			try:
				timeout = str(int(a))
			except ValueError:
				print "error: timeout not specified in seconds"
				sys.exit()
		if o in ("-t", "--type"):
			if not a:
				print "error: no type specified"
				sys.exit()
			if a in ("rsa", "dsa"):
				algo = a
			else:
				print "error: invalid type"
				sys.exit()
		if o in ("-q", "--quiet"):
			quiet = 1
		if o in ("-p", "--port"):
			if a:
				try:
					port = int(a)
				except ValueError:
					print "error: port must be a number"
					sys.exit()
				if not quiet and port != 22:
					print "WARNING: non-standard port numbers are not designated in SSHFP records"
		if o in ("-a", "--all"):
			all_hosts = 1
		if o in ("-o", "--output"):
			if not a:
				print "error: no output file specified"
				sys.exit()
			output = a
		if o in ("-k", "--knownhosts"):
			dofile = 1
			# -k's argument is really optional, which getopt cannot express, so
			# -k always swallows the next word; work out what that word was.
			if not a:
				khfile = default_khfile
			elif os.path.isfile(os.path.expanduser(a)):
				khfile = a
			else:
				try:
					arec = getRecord(a,"A")
				except Exception:
					arec = 0
				if arec:
					# it's really a hostname argument, not a known_hosts file.
					args.append(a)
					khfile = default_khfile
				elif a[0] == "-":
					# another option was taken as -k's argument, eg "sshfp -k -a".
					# getopt will not look at it again, so honour the flag-only
					# options here and warn about anything else.
					khfile = default_khfile
					if a in ("-a", "--all"):
						all_hosts = 1
					elif a in ("-d", "--trailing-dot"):
						trailing = 1
					elif a in ("-q", "--quiet"):
						quiet = 1
					else:
						print >>sys.stderr, "warning: option "+a+" was taken as the argument of -k and is ignored"
				else:
					print "error: "+a+" is neither a known_hosts file or hostname"
					sys.exit()
		if o in ("-s", "--scan"):
			dodns = 1

	if (not dodns and not dofile):
		if not args:
			# use default
			all_hosts = 1
			dofile = 1
			trailing = 1
			if not khfile:
				khfile = default_khfile
		else:
			dodns = 1

	if (dodns and dofile):
		print "use either -k or -s"
		usage()
		sys.exit()

	if dodns:
		# "@server" names the server to transfer the zone from; every other
		# argument is a host name (or, with -a, the zone name).
		newargs = ""
		for arg in args:
			if arg == "":
				continue
			if arg[0] == "-":
				# an option ended up among the arguments; don't guess what was meant
				usage()
				sys.exit()
			if arg[0] == "@":
				nameserver = arg[1:]
				if not all_hosts:
					print "WARNING: ssh-keyscan does not support @nameserver syntax, ignoring"
			else:
				newargs = newargs + arg + " "
		if not newargs:
			print "error: No hostnames specified"
			sys.exit()
		if all_hosts:
			data = sshfpFromAXFR(newargs,nameserver)
			# only label real output; data is 0 when no record was produced
			if data and not quiet:
				data = ";\n; Generated by sshfp "+ version +" from " + nameserver + " at "+ time.ctime() +"\n;\n" + data
		else:
			data = sshfpFromDNS(newargs)

	if dofile:
		data = sshfpFromFile(khfile,args)

	if not data:
		sys.exit()

	if output:
		try:
			fp = open(output,"w")
		except IOError:
			print "error: can't open '"+output+"' for writing"
			sys.exit()
		try:
			fp.write(data)
		finally:
			fp.close()
	else:
		# data ends in a newline and print adds another; drop one
		print data[:-1]

# Script entry point.  main() has no return statement, so on normal
# completion sys.exit(None) ends the program with status 0.
if __name__ == "__main__":
	sys.exit(main())
