# Written by Bram Cohen
# Modified by Cameron Dale
# see LICENSE.txt for license information
#
# $Id: Connecter.py 266 2007-08-18 02:06:35Z camrdale-guest $

"""For maintaining connections to peers.

@type logger: C{logging.Logger}
@var logger: the logger to send all log messages to for this module
@type CHOKE: C{char}
@var CHOKE: the code for choke messages
@type UNCHOKE: C{char}
@var UNCHOKE: the code for unchoke messages
@type INTERESTED: C{char}
@var INTERESTED: the code for interested messages
@type NOT_INTERESTED: C{char}
@var NOT_INTERESTED: the code for not interested messages
@type HAVE: C{char}
@var HAVE: the code for have messages
@type BITFIELD: C{char}
@var BITFIELD: the code for bitfield messages
@type REQUEST: C{char}
@var REQUEST: the code for request messages
@type PIECE: C{char}
@var PIECE: the code for piece messages
@type CANCEL: C{char}
@var CANCEL: the code for cancel messages

"""

from DebTorrent.bitfield import Bitfield
from DebTorrent.clock import clock
from binascii import b2a_hex
import struct
import logging

logger = logging.getLogger('DebTorrent.BT1.Connecter')

# Single-byte message type codes. The payload layouts noted below are the
# ones this module packs (Connection.send_*) and unpacks
# (Connecter.got_message); every integer is a 4-byte big-endian signed int.

# no payload
CHOKE = chr(0)
UNCHOKE = chr(1)
INTERESTED = chr(2)
NOT_INTERESTED = chr(3)
# index
HAVE = chr(4)
# bitfield
BITFIELD = chr(5)
# index, begin, length
REQUEST = chr(6)
# index, begin, piece data
PIECE = chr(7)
# index, begin, length
CANCEL = chr(8)

class Connection:
    """A single peer connection speaking the BitTorrent wire protocol.

    Wraps an L{Encrypter.Connection}. Outgoing messages are length-prefixed
    here. L{PIECE} data is streamed out in caller-sized slices by
    L{send_partial}. Any other message produced while a piece is only partly
    written is parked in C{outqueue} until that piece has been written
    completely.

    @type connection: L{Encrypter.Connection}
    @ivar connection: the underlying connection that raw bytes are written to
    @type connecter: L{Connecter}
    @ivar connecter: the owning collection of all connections
    @type ccount: C{int}
    @ivar ccount: the sequence number assigned to this connection
    @type got_anything: C{boolean}
    @ivar got_anything: whether any message has been received yet
    @ivar next_upload: starts as C{None} and is not modified by this class
        (presumably maintained by the rate limiter's upload queue -- verify)
    @type outqueue: C{list} of C{string}
    @ivar outqueue: framed messages held back until the piece currently
        being sent has been written completely
    @type partial_message: C{string}
    @ivar partial_message: the unsent remainder of the framed L{PIECE}
        message in progress, or C{None} when no piece is being sent
    @ivar upload: the per-connection upload object, created by the
        connecter's C{make_upload}
    @ivar download: the per-connection download object, created by the
        downloader's C{make_download}
    @type send_choke_queued: C{boolean}
    @ivar send_choke_queued: whether a L{CHOKE} is waiting for the current
        piece to finish; a later unchoke cancels it instead of being sent
    @type just_unchoked: C{int}
    @ivar just_unchoked: C{None} before the first unchoke has been sent;
        otherwise either the L{clock} time of an unchoke whose request
        latency is being measured, or 0 when no measurement is pending

    """

    def __init__(self, connection, connecter, ccount):
        """Set up the per-connection state; nothing is sent yet.

        @type connection: L{Encrypter.Connection}
        @param connection: the underlying connection
        @type connecter: L{Connecter}
        @param connecter: the owning collection of all connections
        @type ccount: C{int}
        @param ccount: the sequence number for this connection

        """

        self.connection, self.connecter, self.ccount = connection, connecter, ccount
        self.got_anything = False
        self.next_upload = None
        # Nothing is in flight yet, so nothing needs to be held back.
        self.outqueue = []
        self.partial_message = None
        self.send_choke_queued = False
        # Both are filled in by Connecter.connection_made after construction.
        self.upload = None
        self.download = None
        # None (rather than 0) marks "never unchoked"; see send_unchoke.
        self.just_unchoked = None

    def get_ip(self, real=False):
        """Return the peer's IP address as reported by the underlying connection.

        @type real: C{boolean}
        @param real: passed straight through to the underlying connection's
            C{get_ip} (optional, defaults to False)
        @rtype: C{string}
        @return: the peer's address

        """

        underlying = self.connection
        return underlying.get_ip(real)

    def get_id(self):
        """Return the peer ID reported by the underlying connection.

        @rtype: C{string}
        @return: the peer's ID

        """

        underlying = self.connection
        return underlying.get_id()

    def get_readable_id(self):
        """Return a printable form of the peer ID from the underlying connection.

        @rtype: C{string}
        @return: the peer's ID in human-readable form

        """

        underlying = self.connection
        return underlying.get_readable_id()

    def close(self):
        """Log the closure and close the underlying connection."""
        logger.debug(self.get_ip() + ': connection closed')
        self.connection.close()

    def is_locally_initiated(self):
        """Report whether this side opened the connection.

        @rtype: C{boolean}
        @return: True if the connection was established by this client

        """

        underlying = self.connection
        return underlying.is_locally_initiated()

    def is_encrypted(self):
        """Report whether the underlying connection is encrypted.

        @rtype: C{boolean}
        @return: True if traffic on this connection is encrypted

        """

        underlying = self.connection
        return underlying.is_encrypted()

    def send_interested(self):
        """Tell the peer we want pieces from it (L{INTERESTED})."""
        self._send_message(INTERESTED)

    def send_not_interested(self):
        """Tell the peer we no longer want pieces from it (L{NOT_INTERESTED})."""
        self._send_message(NOT_INTERESTED)

    def send_choke(self):
        """Send L{CHOKE}, or defer it while a piece is still being written."""
        if not self.partial_message:
            self._send_message(CHOKE)
            self.upload.choke_sent()
            self.just_unchoked = 0
            return
        # send_partial emits the choke once the current piece is complete.
        self.send_choke_queued = True

    def send_unchoke(self):
        """Send L{UNCHOKE}, or cancel a still-pending deferred choke instead.

        When the unchoke is actually sent, the current L{clock} time is
        recorded in C{just_unchoked} only if the latency to the peer's next
        request can be measured cleanly. That requires all of the following:
        no piece is mid-send, this is not the very first unchoke, the peer is
        interested, and no download requests are active. Otherwise
        C{just_unchoked} is set to 0.

        """

        if self.send_choke_queued:
            # The choke never reached the wire, so dropping it leaves the
            # peer unchoked without sending anything.
            self.send_choke_queued = False
            logger.info(self.get_ip() + ': UNCHOKE SUPPRESSED')
            return

        self._send_message(UNCHOKE)
        if (not self.partial_message and self.just_unchoked is not None
                and self.upload.interested
                and not self.download.active_requests):
            self.just_unchoked = clock()
        else:
            self.just_unchoked = 0

    def _send_range(self, code, label, index, begin, length):
        """Send a message carrying (index, begin, length) and log it.

        @type code: C{char}
        @param code: the message type, L{REQUEST} or L{CANCEL}
        @type label: C{string}
        @param label: the word used in the debug log line
        @type index: C{int}
        @param index: the piece index
        @type begin: C{int}
        @param begin: the starting offset within the piece
        @type length: C{int}
        @param length: the number of bytes covered

        """

        self._send_message(code + struct.pack('>iii', index, begin, length))
        logger.debug(self.get_ip() + ': sent ' + label + ' ' + str(index) +
                     ', ' + str(begin) + '-' + str(begin + length))

    def send_request(self, index, begin, length):
        """Ask the peer for part of a piece (L{REQUEST}).

        @type index: C{int}
        @param index: the piece to request some of
        @type begin: C{int}
        @param begin: the starting offset within the piece
        @type length: C{int}
        @param length: the number of bytes wanted

        """

        self._send_range(REQUEST, 'request', index, begin, length)

    def send_cancel(self, index, begin, length):
        """Withdraw an earlier request for part of a piece (L{CANCEL}).

        @type index: C{int}
        @param index: the piece that was requested
        @type begin: C{int}
        @param begin: the starting offset within the piece
        @type length: C{int}
        @param length: the number of bytes that were requested

        """

        self._send_range(CANCEL, 'cancel', index, begin, length)

    def send_bitfield(self, bitfield):
        """Advertise the pieces we hold (L{BITFIELD}).

        @type bitfield: C{string}
        @param bitfield: the already-encoded bitfield

        """

        self._send_message(BITFIELD + bitfield)

    def send_have(self, index):
        """Announce that we now hold a piece (L{HAVE}).

        @type index: C{int}
        @param index: the piece index now available

        """

        self._send_message(HAVE + struct.pack('>i', index))

    def send_keepalive(self):
        """Send an empty (zero-length) keepalive message."""
        self._send_message('')

    def _send_message(self, s):
        """Length-prefix a message and send it, or queue it behind a partial piece.

        @type s: C{string}
        @param s: the message type code plus payload, or '' for a keepalive

        """

        prefix = self.get_ip() + ': SENDING MESSAGE '
        if s:
            logger.debug(prefix + str(ord(s[0])) + ' (' + str(len(s)) + ')')
        else:
            logger.debug(prefix + 'keepalive (0)')

        framed = struct.pack('>i', len(s)) + s
        if self.partial_message:
            # Bytes must not be interleaved with a piece already on the wire.
            self.outqueue.append(framed)
        else:
            self.connection.send_message_raw(framed)

    def send_partial(self, bytes):
        """Write up to C{bytes} bytes of the current (or next) L{PIECE} message.

        When no piece is in progress, the next chunk is fetched from the
        upload object and framed. Once the piece has been written completely,
        any deferred L{CHOKE} and every held-back message are flushed in the
        same write.

        @type bytes: C{int}
        @param bytes: the maximum number of piece-message bytes to write now
        @rtype: C{int}
        @return: the number of bytes handed to the connection. This is 0 if
            the connection is closed or there is nothing to upload. When the
            piece completes, it includes the flushed queued messages.

        """

        if self.connection.closed:
            return 0

        if self.partial_message is None:
            chunk = self.upload.get_upload_chunk()
            if chunk is None:
                return 0
            index, begin, piece = chunk
            header = (struct.pack('>i', len(piece) + 9) + PIECE +
                      struct.pack('>ii', index, begin))
            self.partial_message = header + piece.tostring()
            logger.debug(self.get_ip() + ': sending chunk ' + str(index) + ', ' +
                         str(begin) + '-' + str(begin + len(piece)))

        pending = self.partial_message
        if bytes < len(pending):
            self.connection.send_message_raw(pending[:bytes])
            self.partial_message = pending[bytes:]
            return bytes

        # The rest of the piece fits: finish it and release everything
        # that was waiting behind it.
        self.partial_message = None
        if self.send_choke_queued:
            self.send_choke_queued = False
            self.outqueue.append(struct.pack('>i', 1) + CHOKE)
            self.upload.choke_sent()
            self.just_unchoked = 0
        data = ''.join([pending] + self.outqueue)
        self.outqueue = []
        self.connection.send_message_raw(data)
        return len(data)

    def get_upload(self):
        """Return the per-connection upload object.

        @return: the object stored in C{upload}

        """

        return self.upload

    def get_download(self):
        """Return the per-connection download object.

        @return: the object stored in C{download}

        """

        return self.download

    def set_download(self, download):
        """Replace the per-connection download object.

        @param download: the new download object

        """

        self.download = download

    def backlogged(self):
        """Report whether the underlying connection still has unflushed data.

        @rtype: C{boolean}
        @return: True when the connection is not yet flushed

        """

        flushed = self.connection.is_flushed()
        return not flushed

    def got_request(self, i, p, l):
        """Hand a peer's block request to the upload object.

        If a latency measurement was started by L{send_unchoke}, the time
        since that unchoke is reported to the rate limiter, and the
        measurement is then cleared.

        @type i: C{int}
        @param i: the piece index
        @type p: C{int}
        @param p: the offset to start at
        @type l: C{int}
        @param l: the number of bytes requested

        """

        logger.debug(self.get_ip() + ': got request ' + str(i) + ', ' +
                     str(p) + '-' + str(p + l))
        self.upload.got_request(i, p, l)
        if self.just_unchoked:
            elapsed = clock() - self.just_unchoked
            self.connecter.ratelimiter.ping(elapsed)
            self.just_unchoked = 0
    



class Connecter:
    """A collection of all connections to peers.

    Wraps each new underlying connection in a L{Connection} and dispatches
    every received message to that peer's upload or download object. Any
    connection that sends a malformed message is closed.

    @type downloader: L{Downloader.Downloader}
    @ivar downloader: the Downloader instance to use
    @type make_upload: C{method}
    @ivar make_upload: the method to create a new L{Uploader.Upload}
    @type choker: L{Choker.Choker}
    @ivar choker: the Choker instance to use
    @type numpieces: C{int}
    @ivar numpieces: the number of pieces in the download; piece indices
        received from peers must lie in C{0 .. numpieces - 1}
    @type config: C{dictionary}
    @ivar config: the configuration information
    @type ratelimiter: L{RateLimiter.RateLimiter}
    @ivar ratelimiter: the RateLimiter instance to use
    @type rate_capped: C{boolean}
    @ivar rate_capped: not used
    @type sched: C{method}
    @ivar sched: the method to call to schedule future actions (not used)
    @type totalup: L{DebTorrent.CurrentRateMeasure.Measure}
    @ivar totalup: the Measure instance passed to every new upload
    @type connections: C{dictionary}
    @ivar connections: the open connections, mapping each underlying
        L{Encrypter.Connection} to its L{Connection} wrapper
    @type external_connection_made: C{int}
    @ivar external_connection_made: greater than 0 if there have been external connections
    @type ccount: C{int}
    @ivar ccount: the largest connection number used

    """

    def __init__(self, make_upload, downloader, choker, numpieces,
            totalup, config, ratelimiter, sched = None):
        """Initialize the collection with no open connections.

        @type make_upload: C{method}
        @param make_upload: the method to create a new L{Uploader.Upload}
        @type downloader: L{Downloader.Downloader}
        @param downloader: the Downloader instance to use
        @type choker: L{Choker.Choker}
        @param choker: the Choker instance to use
        @type numpieces: C{int}
        @param numpieces: the number of pieces in the download
        @type totalup: L{DebTorrent.CurrentRateMeasure.Measure}
        @param totalup: the Measure instance to use
        @type config: C{dictionary}
        @param config: the configuration information
        @type ratelimiter: L{RateLimiter.RateLimiter}
        @param ratelimiter: the RateLimiter instance to use
        @type sched: C{method}
        @param sched: the method to call to schedule future actions
            (optional, default is None)

        """

        self.downloader = downloader
        self.make_upload = make_upload
        self.choker = choker
        self.numpieces = numpieces
        self.config = config
        self.ratelimiter = ratelimiter
        self.rate_capped = False
        self.sched = sched
        self.totalup = totalup
        self.connections = {}
        self.external_connection_made = 0
        self.ccount = 0

    def how_many_connections(self):
        """Get the number of currently open connections.

        @rtype: C{int}
        @return: the number of open connections

        """

        return len(self.connections)

    def connection_made(self, connection):
        """Wrap a new underlying connection and register it everywhere.

        @type connection: L{Encrypter.Connection}
        @param connection: the new connection to make
        @rtype: L{Connection}
        @return: the new connection

        """

        self.ccount += 1
        c = Connection(connection, self, self.ccount)
        logger.debug(c.get_ip()+': connection made')
        self.connections[connection] = c
        c.upload = self.make_upload(c, self.ratelimiter, self.totalup)
        c.download = self.downloader.make_download(c)
        self.choker.connection_made(c)
        return c

    def connection_lost(self, connection):
        """Forget a lost connection and notify its download object and the choker.

        @type connection: L{Encrypter.Connection}
        @param connection: the connection that was lost
        @raise KeyError: if the connection was never registered

        """

        c = self.connections[connection]
        logger.debug(c.get_ip()+': connection lost')
        del self.connections[connection]
        if c.download:
            c.download.disconnected()
        self.choker.connection_lost(c)

    def connection_flushed(self, connection):
        """Queue a flushed connection with the rate limiter if it has data to upload.

        @type connection: L{Encrypter.Connection}
        @param connection: the connection that was flushed

        """

        conn = self.connections[connection]
        if conn.next_upload is None and (conn.partial_message is not None
               or len(conn.upload.buffer) > 0):
            self.ratelimiter.queue(conn)

    def got_piece(self, i):
        """Announce a newly received piece to every open connection.

        @type i: C{int}
        @param i: the piece index that was received

        """

        # Iterate over a snapshot: if sending were ever to drop a connection
        # synchronously, the dict could change size mid-loop.
        for co in list(self.connections.values()):
            co.send_have(i)

    def _valid_index(self, i):
        """Check that a piece index received from a peer is in range.

        Indices are unpacked as signed integers, so negative values must be
        rejected too; otherwise they would reach list indexing downstream
        and silently wrap around.

        @type i: C{int}
        @param i: the piece index to check
        @rtype: C{boolean}
        @return: whether C{0 <= i < numpieces}

        """

        return 0 <= i < self.numpieces

    def _close_bad(self, c, connection, reason):
        """Log why a peer is being dropped and close its connection.

        @type c: L{Connection}
        @param c: the wrapper, used for its IP in the log message
        @type connection: L{Encrypter.Connection}
        @param connection: the underlying connection to close
        @type reason: C{string}
        @param reason: a short description of the protocol violation

        """

        logger.warning(c.get_ip()+': '+reason+', closing connection')
        connection.close()

    def got_message(self, connection, message):
        """Validate a received message and dispatch it by type.

        Malformed messages close the connection. These are: a bitfield that
        is not the first message, a wrong length, an out-of-range piece
        index, an invalid bitfield, or an unknown type. A zero-length
        message is a keepalive and is ignored.

        @type connection: L{Encrypter.Connection}
        @param connection: the connection that the message was received on
        @type message: C{string}
        @param message: the message type code followed by its payload

        """

        c = self.connections[connection]
        if not message:
            # Keepalive: carries no type byte and does not count as the
            # first message for the bitfield-ordering rule.
            return
        t = message[0]
        logger.debug(c.get_ip()+': message received '+str(ord(t))+' ('+str(len(message))+')')
        if t == BITFIELD and c.got_anything:
            self._close_bad(c, connection, 'misplaced bitfield')
            return
        c.got_anything = True
        if (t in (CHOKE, UNCHOKE, INTERESTED, NOT_INTERESTED) and
                len(message) != 1):
            self._close_bad(c, connection, 'bad message length')
            return
        if t == CHOKE:
            c.download.got_choke()
        elif t == UNCHOKE:
            c.download.got_unchoke()
        elif t == INTERESTED:
            # Interest is not passed on if download.have reports complete().
            if not c.download.have.complete():
                c.upload.got_interested()
        elif t == NOT_INTERESTED:
            c.upload.got_not_interested()
        elif t == HAVE:
            if len(message) != 5:
                self._close_bad(c, connection, 'bad message length')
                return
            i = struct.unpack('>i', message[1:])[0]
            if not self._valid_index(i):
                self._close_bad(c, connection, 'bad piece number')
                return
            if c.download.got_have(i):
                c.upload.got_not_interested()
        elif t == BITFIELD:
            try:
                b = Bitfield(self.numpieces, message[1:])
            except ValueError:
                self._close_bad(c, connection, 'bad bitfield')
                return
            if c.download.got_have_bitfield(b):
                c.upload.got_not_interested()
        elif t in (REQUEST, CANCEL):
            # Both carry exactly index, begin, length.
            if len(message) != 13:
                self._close_bad(c, connection, 'bad message length')
                return
            i, begin, length = struct.unpack('>iii', message[1:])
            if not self._valid_index(i):
                self._close_bad(c, connection, 'bad piece number')
                return
            if t == REQUEST:
                c.got_request(i, begin, length)
            else:
                c.upload.got_cancel(i, begin, length)
        elif t == PIECE:
            # index and begin, plus at least one byte of piece data.
            if len(message) <= 9:
                self._close_bad(c, connection, 'bad message length')
                return
            i, begin = struct.unpack('>ii', message[1:9])
            if not self._valid_index(i):
                self._close_bad(c, connection, 'bad piece number')
                return
            if c.download.got_piece(i, begin, message[9:]):
                self.got_piece(i)
        else:
            self._close_bad(c, connection, 'unknown message type')
