#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifdef HAVE_TLS

#include "IMTLS.hh"
#include "IMUtil.hh"
#include "IMLog.hh"
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <pthread.h>
#include <unistd.h>

#include <openssl/ssl.h>
#include <openssl/err.h>

// Singleton bookkeeping for IMTLS: nothing has been constructed yet.
IMTLS* IMTLS::ptls = NULL;
bool IMTLS::initialized(false);

// OpenSSL-backed implementation of the IMTLS singleton.  It owns one
// server-side SSL_CTX, configured through the set_* methods and setup(),
// and uses it to wrap accepted sockets in create_trans().
class IMTLSImpl : public IMTLS
{
    // Client-verification levels (not referenced elsewhere in this file).
    enum { NEVER_CHECK, ALLOW, TRY, REQUIRED };

    // Configuration consumed by setup().  Declaration order matters: it is
    // the order in which the constructor initializes these members.
    int verify_func;    // non-zero: setup() installs verify(), else verify_notfailed()
    int verify_client;  // SSL_VERIFY_* mode flags for SSL_CTX_set_verify()
    int verify_depth;   // chain depth limit, applied by setup() when > 0
    string cert_file;   // server certificate (PEM)
    string key_file;    // server private key (PEM)
    string ca_file;     // CA certificate file
    string ca_path;     // CA certificate directory
    SSL_CTX *ctx;       // server context handed to every TLS transport

    // Mutex table backing lock_callback(), indexed by OpenSSL lock number.
    static pthread_mutex_t locks[CRYPTO_NUM_LOCKS];

    // OpenSSL callbacks; static so they match the C callback signatures.
    static int verify(int ok, X509_STORE_CTX *ctx);
    static int verify_notfailed(int ok, X509_STORE_CTX *ctx);
    static RSA* get_rsa_cb(SSL *ssl, int is_export, int key_length);
    static void lock_callback(int mode, int type, const char *file, int line);
    static unsigned long id_callback();

  public:
    IMTLSImpl();
    ~IMTLSImpl();

    // Log and clear every pending OpenSSL error.
    static void tls_error();

    // Configuration setters; the stored values are consumed by setup().
    virtual bool set_certificate_file(const string& filename);
    virtual bool set_certificate_key_file(const string &filename);
    virtual bool set_cacertificate_file(const string &filename);
    virtual bool set_cacertificate_path(const string &dirname);
    virtual bool set_verify_client(const string &value);
    virtual bool set_verify_depth(const string &value);

    // Apply the collected configuration to the SSL context.
    virtual bool setup();

    // Wrap an accepted socket in a plain or TLS transport.
    virtual IMSocketTrans *create_trans(int fd, int x_trans);
};

// Socket transport whose traffic goes through an OpenSSL session.  The
// server-side handshake runs in the constructor; error() reports whether
// it failed.
class IMSocketTransTLS : public IMSocketTrans
{
    SSL *ssl;   // per-connection TLS session
    bool err;   // set by the constructor when the handshake fails

  public:
    IMSocketTransTLS(SSL_CTX *ctx, int fd);
    ~IMSocketTransTLS();

    // Encrypted I/O through the session.
    virtual int send(const void *, size_t n);
    virtual int recv(void *, size_t n);

    bool error() { return err; }

    // Inspect the peer's certificate after the handshake.
    bool check_peer();
};


// Return the process-wide IMTLS instance.  On the first call, create the
// OpenSSL-backed implementation and register it as the singleton.
//
// NOTE(review): `initialized` is not synchronized; this assumes the first
// call happens before other threads can call construct() -- confirm.
IMTLS*
IMTLS::construct()
{
    if (initialized) {
        return get_instance();
    }
    IMTLS *tls = new IMTLSImpl();
    tls->register_singleton();
    initialized = true;
    // The first caller also gets the new instance; previously this path
    // returned nothing, which is undefined behaviour for a non-void function.
    return tls;
}

// Parse the SSLVerifyDepth setting; setup() applies it when it is > 0.
//
// Returns false and logs a warning if VALUE is empty, contains anything but
// an integer, or does not fit in an int.  verify_depth is unchanged then.
bool
IMTLSImpl::set_verify_depth(
    const string &value
)
{
    const char *str = value.c_str();
    char *ptr;

    errno = 0;
    long num = strtol(str, &ptr, 10);
    // ptr == str means no digits were consumed (e.g. an empty string); strtol
    // then returns 0, which must not be mistaken for a real depth.
    if (ptr == str || *ptr != '\0') {
        LOG_WARNING("SSLVerifyDepth: (%s) is not an integer value.", value.c_str());
        return false;
    }
    // Reject values that overflow long or would be truncated by the int member.
    if (errno == ERANGE || num < INT_MIN || num > INT_MAX) {
        LOG_WARNING("SSLVerifyDepth: (%s) is out of range.", value.c_str());
        return false;
    }
    verify_depth = static_cast<int>(num);
    return true;
}

bool
IMTLSImpl::set_verify_client(
    const string &value
)
{
    if(value.compare("demand")) {
        verify_client = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
        verify_func = 1;
    } else if (value.compare("allow")) {
        verify_client = SSL_VERIFY_PEER;
    } else if (value.compare("never")) {
        verify_client = SSL_VERIFY_NONE;
    } else if (value.compare("try")) {
        verify_client = SSL_VERIFY_PEER;
        verify_func = 1;
    } else {
        LOG_WARNING("SSLVerifyClient should be one of (never|allow|try|demand)");
        return false;
    }
    return true;
}

// Load the server certificate (PEM) into the SSL context.
//
// The file name is remembered even if loading fails, because setup() may
// reuse it as the private-key file.  Returns false, after logging the
// OpenSSL errors, if the certificate could not be loaded.
bool
IMTLSImpl::set_certificate_file(
    const string &filename
)
{
    cert_file = filename;
    if (!SSL_CTX_use_certificate_file(ctx, filename.c_str(), SSL_FILETYPE_PEM)) {
        tls_error();
        return false;
    }
    return true;
}

// Load the server private key (PEM) into the SSL context.
//
// The file name is remembered even if loading fails; setup() uses it to
// decide whether to run the key/certificate consistency check.  Returns
// false, after logging the OpenSSL errors, if the key could not be loaded.
bool
IMTLSImpl::set_certificate_key_file(
    const string &filename
)
{
    key_file = filename;
    if (!SSL_CTX_use_PrivateKey_file(ctx, filename.c_str(), SSL_FILETYPE_PEM)) {
        tls_error();
        return false;
    }
    return true;
}

bool
IMTLSImpl::set_cacertificate_file(
    const string &filename
)
{
    ca_file = filename;
    return true;
}

bool
IMTLSImpl::set_cacertificate_path(
    const string &dirname
)
{
    ca_path = dirname;
    return true;
}

void
IMTLSImpl::tls_error()
{
    unsigned long err;
    char buf[BUFSIZ];

    while ((err = ERR_get_error()) != 0) {
        ERR_error_string(err, buf);
        LOG_WARNING("TLS error: %s", buf);
    }
}

// OpenSSL thread-id callback: identify the calling thread by pthread_self().
//
// NOTE(review): converting pthread_t to unsigned long assumes pthread_t is
// an integer or pointer type; this holds on common platforms but is not
// guaranteed by POSIX.
unsigned long
IMTLSImpl::id_callback()
{
    return (unsigned long) pthread_self();
}

// OpenSSL locking callback: lock or unlock mutex number TYPE depending on
// whether CRYPTO_LOCK is set in MODE.  FILE and LINE are unused.
void
IMTLSImpl::lock_callback(
    int mode,
    int type,
    const char *file,
    int line
)
{
    pthread_mutex_t *mtx = &locks[type];

    if (!(mode & CRYPTO_LOCK)) {
        pthread_mutex_unlock(mtx);
        return;
    }
    pthread_mutex_lock(mtx);
}

int
IMTLSImpl::verify(
    int ok,
    X509_STORE_CTX *ctx
)
{
    char buf[256];
    X509 *cert;
    int err;
    int depth;

    cert = X509_STORE_CTX_get_current_cert(ctx);
    err = X509_STORE_CTX_get_error(ctx);
    depth = X509_STORE_CTX_get_error_depth(ctx);

    X509_NAME_oneline(X509_get_subject_name(cert), buf, 256);

    LOG_DEBUG("%s(depth %d):%s", buf, depth, X509_verify_cert_error_string(err));
    return ok;
}

// Lenient verify callback: log through verify(), then accept the
// certificate whatever the verification outcome.
int
IMTLSImpl::verify_notfailed(
    int ok,
    X509_STORE_CTX *ctx
)
{
    (void) verify(ok, ctx);   // logging only; verdict deliberately ignored
    return 1;
}

// Temporary RSA key callback (used by export cipher suites).
//
// Keys for the usual export sizes (512 and 1024 bits) are generated once
// and then reused, following the OpenSSL SSL_CTX_set_tmp_rsa_callback
// example.  Generating a fresh key per handshake is expensive.  It also
// leaked a key each time, because OpenSSL takes its own reference to the
// returned key and never releases the one created here.
// NOTE(review): the ownership detail is based on OpenSSL 1.0.x behaviour;
// confirm for the linked version.
//
// Other sizes fall back to generating a new key per call, as before.
// Returns NULL if key generation fails, which makes OpenSSL abort the
// handshake.
RSA*
IMTLSImpl::get_rsa_cb(
    SSL *ssl,
    int is_export,
    int key_length
)
{
    static pthread_mutex_t cache_lock = PTHREAD_MUTEX_INITIALIZER;
    static RSA *rsa_512 = NULL;
    static RSA *rsa_1024 = NULL;

    RSA **slot = NULL;
    if (key_length == 512) {
        slot = &rsa_512;
    } else if (key_length == 1024) {
        slot = &rsa_1024;
    }

    if (slot == NULL) {
        return RSA_generate_key(key_length, RSA_F4, NULL, NULL);
    }

    // The callback can run on several handshake threads at once.
    pthread_mutex_lock(&cache_lock);
    if (*slot == NULL) {
        *slot = RSA_generate_key(key_length, RSA_F4, NULL, NULL);
    }
    RSA *rsa = *slot;
    pthread_mutex_unlock(&cache_lock);

    return rsa;
}

bool
IMTLSImpl::setup()
{

    // when user specifies a certificate file, but doesn't specify a private key file,
    // use the certificate file as a private key file.
    if (!cert_file.empty() && key_file.empty()) {
        set_certificate_key_file(cert_file);
    }

    // user specifies a CA file or CA dir.
    if (verify_depth > 0) SSL_CTX_set_verify_depth(ctx, verify_depth);
    if (!ca_file.empty() || !ca_path.empty()) {
        if (!SSL_CTX_load_verify_locations(ctx, 
                                ca_file.empty() ? NULL : ca_file.c_str(),
                                ca_path.empty() ? NULL : ca_path.c_str()) ||
            !SSL_CTX_set_default_verify_paths(ctx)) {
            tls_error();
        }
        STACK_OF(X509_NAME) *ca_list = NULL;
        // get CAs from the CA file.
        if (!ca_file.empty()) {
            ca_list = SSL_load_client_CA_file (ca_file.c_str());
            if (!ca_list) {
                tls_error();
            }
        }
        // set CAs from the specified dir.
        if (!ca_path.empty()) {
           if (!ca_list) ca_list = sk_X509_NAME_new_null();
           if (!SSL_add_dir_cert_subjects_to_stack(ca_list, ca_path.c_str())) {
               tls_error();
           }
        }
        if (ca_list) {
            // finally, set CAs to the context.
            SSL_CTX_set_client_CA_list(ctx, ca_list);
        }
    }

    if(!SSL_CTX_set_cipher_list(ctx, SSL_DEFAULT_CIPHER_LIST)) {
         tls_error();
    }

    if (!cert_file.empty() && !key_file.empty()) {
        if (!SSL_CTX_check_private_key (ctx)) {
            tls_error();
        }
    }

    SSL_CTX_set_tmp_rsa_callback (ctx, &IMTLSImpl::get_rsa_cb);

    SSL_CTX_set_verify(ctx, verify_client, verify_func ? &IMTLSImpl::verify : &IMTLSImpl::verify_notfailed);
    return true;
}

// Wrap an accepted socket X_FD in the transport selected by X_TRANS: a
// plain IMSocketTrans for NORMAL, or an IMSocketTransTLS (which performs
// the handshake on construction) for TLS.  Returns NULL when the TLS
// handshake fails or X_TRANS is neither kind.
IMSocketTrans *
IMTLSImpl::create_trans(
    int x_fd,
    int x_trans
)
{
    if (x_trans == IMSocketAddress::NORMAL) {
        return new IMSocketTrans(x_fd);
    }
    if (x_trans != IMSocketAddress::TLS) {
        return NULL;
    }

    IMSocketTransTLS *trans = new IMSocketTransTLS(ctx, x_fd);
    if (!trans->error()) {
        // Inspect the peer's certificate; its return value is not used here.
        trans->check_peer();
        return trans;
    }

    tls_error();
    delete trans;
    return NULL;
}

// Out-of-line empty destructor for the IMTLS base class.
IMTLS::~IMTLS() {}

// Storage for the mutex table used by lock_callback().
pthread_mutex_t IMTLSImpl::locks[CRYPTO_NUM_LOCKS];

// Initialize the OpenSSL library and create the server SSL_CTX.  Also set
// up the mutex table and install the locking and thread-id callbacks.
//
// Defaults: a peer certificate is requested (SSL_VERIFY_PEER), the verify
// depth is 1, and verify_func == 0, so setup() installs the lenient
// verify_notfailed() callback.
IMTLSImpl::IMTLSImpl() :
verify_func(0), verify_client(SSL_VERIFY_PEER), verify_depth(1), ctx(NULL)
{
    // The initializer list follows declaration order, which is the order
    // members are actually initialized in (avoids -Wreorder).
    SSL_load_error_strings();
    SSL_library_init();
    ctx = SSL_CTX_new(SSLv23_server_method());
    if (ctx == NULL) {
        // Log why, so later failures on the missing context can be traced.
        tls_error();
    }

    // Initialize one mutex per OpenSSL lock.
    for (int i = 0; i < CRYPTO_NUM_LOCKS; i++) {
        pthread_mutex_init(&locks[i], NULL);
    }
    // Install the locking and thread-id callbacks.
    CRYPTO_set_locking_callback(&IMTLSImpl::lock_callback);
    CRYPTO_set_id_callback(&IMTLSImpl::id_callback);
}

// Free the SSL context, then release OpenSSL's global cipher/digest tables,
// this thread's error state and the loaded error strings.
IMTLSImpl::~IMTLSImpl()
{
    if (ctx != NULL) {
        SSL_CTX_free(ctx);
    }
    ctx = NULL;

    EVP_cleanup();
    ERR_remove_state(0);
    ERR_free_strings();
}

// Write up to N bytes from P through the TLS session and return
// SSL_write()'s result.
//
// SSL_write() takes an int length, so N is clamped to INT_MAX instead of
// being narrowed (possibly to a negative value).  For larger buffers this
// is a short write that the caller must continue.
int
IMSocketTransTLS::send(
    const void *p,
    size_t n
)
{
    const int len = n > static_cast<size_t>(INT_MAX) ? INT_MAX : static_cast<int>(n);
    return SSL_write(ssl, p, len);
}

// Read up to N bytes from the TLS session into P and return SSL_read()'s
// result.
//
// SSL_read() takes an int length, so N is clamped to INT_MAX instead of
// being narrowed (possibly to a negative value).
int
IMSocketTransTLS::recv(
    void *p,
    size_t n
)
{
    const int len = n > static_cast<size_t>(INT_MAX) ? INT_MAX : static_cast<int>(n);
    return SSL_read(ssl, p, len);
}

// Create a TLS session from CTX on the accepted socket X_FD and perform the
// server-side handshake.  On any failure, err is set (see error()); the
// OpenSSL error queue is left for the caller to report.
IMSocketTransTLS::IMSocketTransTLS(
    SSL_CTX *ctx,
    int x_fd
) : IMSocketTrans(x_fd), err(false)
{
    ssl = SSL_new(ctx);
    // A NULL session or a failed fd binding must not reach SSL_accept().
    // The destructor skips a NULL session.
    if (ssl == NULL || !SSL_set_fd(ssl, get_fd()) || SSL_accept(ssl) <= 0) {
        err = true;
    }
}

// Shut down the TLS session, close its socket and free the session.
//
// NOTE(review): the fd is closed here; confirm the IMSocketTrans base
// destructor does not close it a second time.
IMSocketTransTLS::~IMSocketTransTLS()
{
    if (ssl) {
        int fd = SSL_get_fd(ssl);
        // SSL_shutdown() must not be called after a fatal error; err marks a
        // failed handshake, in which case no close_notify is sent.
        if (!err) {
            SSL_shutdown(ssl);
        }
        close(fd);
        SSL_free(ssl);
    }
}

bool
IMSocketTransTLS::check_peer()
{
   X509 *peer;
  
   if ((peer = SSL_get_peer_certificate(ssl)) != NULL) {
       if (SSL_get_verify_result(ssl) != X509_V_OK) {
         IMTLSImpl::tls_error();
       }
       X509_free (peer);
   }

}

#endif
