#!/usr/bin/python

from ldaptor.protocols.ldap import ldapclient, distinguishedname, ldapconnector
from ldaptor.protocols import pureber, pureldap
from ldaptor import usage, ldapfilter
from socket import inet_aton, inet_ntoa
import sys
from twisted.internet import protocol, defer, reactor

def my_aton_octets(ip):
    """Convert a dotted-quad IPv4 string to its 32-bit integer value.

    ip -- address string accepted by socket.inet_aton, e.g. '10.0.0.1'.

    Returns the address as an unsigned integer with the first octet in
    the most significant byte.  Whatever socket.inet_aton raises for an
    invalid address propagates to the caller.
    """
    # Local import keeps this function self-contained.
    import struct

    # inet_aton yields 4 bytes in network (big-endian) order; '!L'
    # decodes exactly that as one unsigned 32-bit integer, and works
    # the same whether the packed value is a Python 2 str or bytes.
    return struct.unpack('!L', inet_aton(ip))[0]

def my_aton_numbits(num):
    """Return the 32-bit netmask that has its `num` highest bits set.

    num -- prefix length (number of leading one bits).

    Values of zero or below give 0; values of 32 and above give
    0xFFFFFFFF, i.e. the result is always clamped to 32 bits.
    """
    if num <= 0:
        return 0
    if num >= 32:
        return 0xFFFFFFFF
    # Shift the all-ones word left to clear the host bits, then trim
    # whatever overflowed past bit 31.
    return (0xFFFFFFFF << (32 - num)) & 0xFFFFFFFF

def my_aton(ip):
    """Convert an address or a prefix length to a 32-bit integer.

    ip -- either something int() accepts (treated as a prefix length
          and turned into a netmask by my_aton_numbits) or a
          dotted-quad string (converted by my_aton_octets).

    Returns the integer form of the address or mask.
    """
    try:
        prefixLength = int(ip)
    except ValueError:
        # Not a plain number, so treat it as a dotted-quad address.
        return my_aton_octets(ip)
    return my_aton_numbits(prefixLength)

def my_ntoa(n):
    """Convert a 32-bit integer to a dotted-quad IPv4 string.

    n -- integer address; only the low 32 bits are used, so negative
         values (such as the result of ``ip | ~mask``) and values wider
         than 32 bits are accepted.

    Returns the address as a string such as '10.0.0.255'.
    """
    # Local import keeps this function self-contained.
    import struct

    # Mask to 32 bits first: struct refuses negative or oversized
    # values, and callers pass in results of bitwise NOT.
    packed = struct.pack('!L', n & 0xFFFFFFFF)
    return inet_ntoa(packed)

def printIPAddress(name, ip):
    """Print one address line for the zone file to standard output.

    name -- host name without the trailing '.%' suffix.
    ip   -- address text, written out verbatim.

    The line has the form ``A<name>.%|86400|<ip>``.
    NOTE(review): this looks like a MaraDNS csv1 A record with a TTL
    of 86400 -- confirm against the MaraDNS documentation.
    """
    # Single-argument print(...) behaves identically as a Python 2
    # print statement and as the Python 3 print function.
    print('A' + name + '.%|86400|' + ip)

def printPTR(name, ip):
    """Print one reverse-lookup line for the zone file to standard output.

    name -- host name without the trailing '.%' suffix.
    ip   -- dotted-quad address; its octets are reversed and placed
            under 'in-addr.arpa.'.

    The line has the form ``P<reversed ip>.in-addr.arpa.|86400|<name>.%``.
    NOTE(review): this looks like a MaraDNS csv1 PTR record with a TTL
    of 86400 -- confirm against the MaraDNS documentation.
    """
    labels = ip.split('.')
    labels.reverse()
    labels.append('in-addr.arpa.')
    # Single-argument print(...) behaves identically as a Python 2
    # print statement and as the Python 3 print function.
    print('P' + '.'.join(labels) + '|86400|' + name + '.%')

class HostIPAddress:
    """One IP address of a host, with a reference back to that host.

    host      -- the owning object; printZone reads its `dn` and
                 `name` attributes.
    ipAddress -- the address text.
    """

    def __init__(self, host, ipAddress):
        self.host = host
        self.ipAddress = ipAddress

    def printZone(self, domain):
        """Print the zone lines for this address within `domain`.

        Writes a comment line holding the host's DN, then the forward
        and reverse lines for ``<host name>.<domain>``.
        """
        fqdn = self.host.name + '.' + domain
        # Single-argument print(...) works as a Python 2 statement and
        # as the Python 3 function alike.
        print('#  ' + self.host.dn)
        printIPAddress(fqdn, self.ipAddress)
        printPTR(fqdn, self.ipAddress)

    def __repr__(self):
        # Only the id() of the host is shown: Host.__repr__ lists its
        # HostIPAddress objects, so printing the host itself here
        # would recurse without end.
        return '%s(host=%s, ipAddress=%r)' % (
            self.__class__.__name__, id(self.host), self.ipAddress)

class Host:
    """A host entry: its DN, its name and its IP addresses.

    dn          -- distinguished name text of the entry.
    name        -- host name.
    ipAddresses -- iterable of address texts; each one is wrapped in a
                   HostIPAddress that points back at this host.
    """

    def __init__(self, dn, name, ipAddresses):
        self.dn = dn
        self.name = name
        self.ipAddresses = [HostIPAddress(self, ip) for ip in ipAddresses]

    def __repr__(self):
        return '%s(dn=%r, name=%r, ipAddresses=%r)' % (
            self.__class__.__name__, self.dn, self.name, self.ipAddresses)

class Net:
    """An IP network and the host addresses that belong to it.

    dn      -- distinguished name text of the network entry.
    name    -- network name; used as the domain for its hosts.
    address -- network address as a dotted quad.
    mask    -- netmask, either as a dotted quad or as a prefix length
               (both forms are understood by my_aton).
    """

    def __init__(self, dn, name, address, mask):
        self.dn = dn
        self.name = name
        self.address = address
        self.mask = mask
        self.hosts = []

    def isInNet(self, ipAddress):
        """Return a true value when `ipAddress` lies inside this network."""
        net = my_aton(self.address)
        mask = my_aton(self.mask)
        ip = my_aton(ipAddress)
        return (ip & mask) == net

    def addHost(self, host):
        """Attach a HostIPAddress-like object to this network.

        The caller is expected to have checked membership already; the
        assert only guards that invariant.
        """
        assert self.isInNet(host.ipAddress)
        self.hosts.append(host)

    def printZone(self):
        """Print the zone lines for the network and all of its hosts.

        Output goes to standard output: a comment with the DN, lines
        for the network address, its netmask and its broadcast
        address, then every attached host between '# hosts begin' and
        '# hosts end', followed by an empty line.
        """
        # Single-argument print(...) works as a Python 2 statement and
        # as the Python 3 function alike.
        print('#' + self.dn)
        printIPAddress(self.name, self.address)
        printPTR(self.name, self.address)

        ip = my_aton(self.address)
        mask = my_aton(self.mask)
        # The mask may be stored as a prefix length such as '24';
        # convert it back to a dotted quad so the netmask line always
        # carries a real address.
        printIPAddress('netmask.' + self.name, my_ntoa(mask))
        # ~mask is negative in Python; my_ntoa only looks at the low
        # 32 bits, which are exactly the host bits set to one.
        broadcast = my_ntoa(ip | ~mask)
        printIPAddress('broadcast.' + self.name, broadcast)
        printPTR('broadcast.' + self.name, broadcast)

        print('# hosts begin')
        for host in self.hosts:
            host.printZone(self.name)
        print('# hosts end')
        print('')

    def __repr__(self):
        return '%s(dn=%r, name=%r, address=%r, mask=%r)' % (
            self.__class__.__name__,
            self.dn, self.name, self.address, self.mask)

class SearchHosts(ldapclient.LDAPSearch):
    """LDAP search that collects entries of objectClass 'ipHost' as Host objects.

    The deferred handed to the constructor gets a callback appended
    that replaces the search object with the list of collected hosts.
    """

    def __init__(self, deferred, client, base, filter):
        """Start the search.

        deferred -- Deferred passed on to LDAPSearch; its result is
                    mapped to `self.entries`.
        client   -- LDAP client passed on to LDAPSearch.
        base     -- search base, passed as `baseObject`.
        filter   -- optional extra LDAP filter object; when true it is
                    AND-ed with the objectClass=ipHost match.
        """
        self.entries = []
        isHost = pureldap.LDAPFilter_equalityMatch(
            attributeDesc=pureldap.LDAPAttributeDescription('objectClass'),
            assertionValue=pureber.BEROctetString('ipHost'))
        if filter:
            searchFilter = pureldap.LDAPFilter_and(value=(filter, isHost))
        else:
            searchFilter = isHost
        ldapclient.LDAPSearch.__init__(
            self, deferred, client,
            baseObject=base,
            filter=searchFilter,
            attributes=['ipHostNumber', 'cn'])
        deferred.addCallbacks(
            callback=lambda search: search.entries,
            errback=lambda failure: failure,
            )

    def handle_entry(self, objectName, attributes):
        """Record one search result as a Host.

        NOTE(review): presumably called by LDAPSearch once per entry
        with (attribute name, values) pairs -- confirm against ldaptor.
        """
        values = dict([(str(key), vals) for key, vals in attributes])

        names = values['cn']
        assert len(names)==1, \
            "object %s attribute 'cn' has multiple values: %s" \
            % (objectName, names)

        addresses = [str(ip) for ip in values['ipHostNumber']]
        self.entries.append(Host(str(objectName), str(names[0]), addresses))

class SearchNets(ldapclient.LDAPSearch):
    """LDAP search that collects network entries as Net objects.

    An entry qualifies when it has 'cn', 'ipNetworkNumber' and
    'ipNetmaskNumber' attributes.  The deferred handed to the
    constructor gets a callback appended that replaces the search
    object with the list of collected networks.
    """

    def __init__(self, deferred, client, base, filter):
        """Start the search.

        deferred -- Deferred passed on to LDAPSearch; its result is
                    mapped to `self.entries`.
        client   -- LDAP client passed on to LDAPSearch.
        base     -- search base, passed as `baseObject`.
        filter   -- optional extra LDAP filter object; when true it is
                    AND-ed with the attribute-presence checks.
        """
        self.entries = []
        hasNetAttributes = pureldap.LDAPFilter_and(value=(
            pureldap.LDAPFilter_present('cn'),
            pureldap.LDAPFilter_present('ipNetworkNumber'),
            pureldap.LDAPFilter_present('ipNetmaskNumber'),
            ))
        if filter:
            searchFilter = pureldap.LDAPFilter_and(
                value=(filter, hasNetAttributes))
        else:
            searchFilter = hasNetAttributes
        ldapclient.LDAPSearch.__init__(
            self, deferred, client,
            baseObject=base,
            filter=searchFilter,
            attributes=['cn', 'ipNetworkNumber', 'ipNetmaskNumber'])
        deferred.addCallbacks(
            callback=lambda search: search.entries,
            errback=lambda failure: failure,
            )

    def handle_entry(self, objectName, attributes):
        """Record one search result as a Net.

        Each of the three attributes must carry exactly one value.
        NOTE(review): presumably called by LDAPSearch once per entry
        with (attribute name, values) pairs -- confirm against ldaptor.
        """
        values = {}
        for key, vals in attributes:
            values[str(key)] = [str(v) for v in vals]

        names = values['cn']
        assert len(names)==1, \
            "object %s attribute 'cn' has multiple values: %s" \
            % (objectName, names)
        networks = values['ipNetworkNumber']
        assert len(networks)==1, \
            "object %s attribute 'ipNetworkNumber' has multiple values: %s" \
            % (objectName, networks)
        netmasks = values['ipNetmaskNumber']
        assert len(netmasks)==1, \
            "object %s attribute 'ipNetmaskNumber' has multiple values: %s" \
            % (objectName, netmasks)

        self.entries.append(Net(str(objectName),
                                str(names[0]),
                                str(networks[0]),
                                str(netmasks[0])))

class Search(ldapclient.LDAPClient):
    """LDAP client protocol that fetches networks and hosts and prints the zone.

    After connecting it binds, searches for networks, then for hosts,
    sorts every host address into the first network containing it and
    prints each network.  The outcome of the whole chain is forwarded
    to `self.factory.deferred`.
    """

    def connectionMade(self):
        """Bind, then start the search chain and hook it to the factory."""
        d = self.bind()
        d.addCallback(self._handle_bind_success)
        # NOTE(review): defer.logError presumably logs the failure and
        # passes it on -- confirm against the Twisted version in use.
        d.addErrback(defer.logError)
        d.chainDeferred(self.factory.deferred)

    def _handle_bind_success(self, x):
        """Start the network search; the bind result `x` is unused."""
        d1 = defer.Deferred()
        SearchNets(d1, self, self.factory.base, self.factory.filt)
        d1.addCallbacks(callback=self.haveNets,
                        errback=defer.logError)
        return d1

    def haveNets(self, nets):
        """Remember the networks found and start the host search."""
        self.nets = nets
        d = defer.Deferred()
        SearchHosts(d, self, self.factory.base, self.factory.filt)
        d.addCallbacks(callback=self.haveHosts,
                       errback=defer.logError)
        return d

    def haveHosts(self, hosts):
        """Assign host addresses to networks and print every network.

        Each address goes to the first network in `self.nets` that
        contains it.  Addresses matching no network are reported on
        standard error and dropped.
        """
        for host in hosts:
            for hostIP in host.ipAddresses:
                for net in self.nets:
                    if net.isInNet(hostIP.ipAddress):
                        net.addHost(hostIP)
                        break
                else:
                    # Report the address itself; the HostIPAddress
                    # repr carries only a memory id for its host.
                    sys.stderr.write(
                        "IP address %s is in no net, discarding.\n"
                        % hostIP.ipAddress)

        for net in self.nets:
            net.printZone()

class SearchFactory(protocol.ClientFactory):
    """Client factory that builds Search protocols.

    base     -- search base, read by Search as `factory.base`.
    filt     -- optional LDAP filter object, read as `factory.filt`.
    deferred -- Deferred that Search chains its result into; it is
                errbacked here when the connection attempt fails.
    """
    protocol = Search

    def __init__(self, base, filt, deferred):
        self.base, self.filt, self.deferred = base, filt, deferred

    def clientConnectionFailed(self, connector, reason):
        """Pass a failed connection attempt on to the deferred."""
        self.deferred.errback(reason)

# Process exit code; error() sets it to 1 and main() hands it to sys.exit.
exitStatus = 0

def error(fail):
    """Report a failure on standard error and mark the run as failed.

    fail -- object with a getErrorMessage() method (used as a Deferred
            errback in main).

    Returns None and sets the module-level exitStatus to 1.
    """
    global exitStatus
    # sys.stderr.write instead of ``print >>sys.stderr`` so the line
    # is valid on Python 2 and Python 3; the text is the same.
    sys.stderr.write('fail: %s\n' % (fail.getErrorMessage(),))
    exitStatus = 1

def main(base, serviceLocationOverride, filter_text):
    """Run the export and exit the process.

    base                    -- search base DN as text.
    serviceLocationOverride -- passed to LDAPConnector as `overrides`.
    filter_text             -- optional LDAP filter text; parsed with
                               ldapfilter.parseFilter when not None.

    Starts logging to standard error, connects, runs the reactor until
    the search chain finishes, then calls sys.exit(exitStatus).
    """
    from twisted.python import log
    log.startLogging(sys.stderr, setStdout=0)

    filt = None
    if filter_text is not None:
        filt = ldapfilter.parseFilter(filter_text)

    finished = defer.Deferred()
    factory = SearchFactory(base, filt, finished)
    finished.addErrback(error)
    # Stop the reactor whether the chain succeeded or failed.
    finished.addBoth(lambda _result: reactor.stop())

    baseDN = distinguishedname.DistinguishedName(stringValue=base)
    connector = ldapconnector.LDAPConnector(
        reactor, baseDN, factory, overrides=serviceLocationOverride)
    connector.connect()
    reactor.run()
    sys.exit(exitStatus)

# Command-line options for this script.  The script's entry point reads
# the 'base', 'service-location' and 'filter' keys from `opts`.
# NOTE(review): 'base' and 'service-location' presumably come from the
# Options_base / Options_service_location mixins -- confirm in ldaptor.usage.
class MyOptions(usage.Options, usage.Options_service_location, usage.Options_base):
    """LDAPtor maradns zone file exporter"""
    # Stores the single optional argument as opts['filter'] (None when it
    # is not given); main() later parses it as an LDAP filter.
    # NOTE(review): presumably called by parseOptions() with the
    # positional command-line arguments -- confirm against usage.Options.
    def parseArgs(self, filter=None):
	self.opts['filter'] = filter

if __name__ == "__main__":
    # Parse the command line; on a usage error print
    # "<program>: <message>" to standard error and exit with status 1.
    try:
        config = MyOptions()
        config.parseOptions()
    except usage.UsageError:
        # sys.exc_info() fetches the exception without the
        # Python-2-only ``except X, e`` spelling, so this parses on
        # every Python version.
        ue = sys.exc_info()[1]
        sys.stderr.write('%s: %s\n' % (sys.argv[0], ue))
        sys.exit(1)

    main(config.opts['base'],
         config.opts['service-location'],
         config.opts['filter'])
