/*
 * Drizzle Client & Protocol Library
 *
 * Copyright (C) 2008 Eric Day (eday@oddments.org)
 * All rights reserved.
 *
 * Use and distribution licensed under the BSD license.  See
 * the COPYING file in this directory for full text.
 */

/**
 * @file
 * @brief Handshake definitions
 */

#include "common.h"

/*
 * Client definitions
 */

/*
 * Client side: read and parse the server's initial handshake packet.
 *
 * When no states are pending, the stack is seeded so that a raw packet is
 * read first (drizzle_state_packet_read, pushed last so it runs first) and
 * then parsed by drizzle_state_server_handshake_read. If states are already
 * pending from an earlier call, they are simply resumed.
 *
 * Returns whatever drizzle_state_loop() returns.
 */
drizzle_return_t drizzle_server_handshake_read(drizzle_con_st *con)
{
  if (DRIZZLE_STATE_NONE(con))
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_server_handshake_read)
    DRIZZLE_STATE_PUSH(con, drizzle_state_packet_read)
    return drizzle_state_loop(con);
  }

  /* Continue with the states left over from a previous call. */
  return drizzle_state_loop(con);
}

/*
 * Client side: pack and send the client's handshake response.
 *
 * When no states are pending, drizzle_state_client_handshake_write (pushed
 * last, so it runs first) packs the response into the buffer, and
 * drizzle_state_write then sends it. If states are already pending, they
 * are resumed instead.
 *
 * Returns whatever drizzle_state_loop() returns.
 */
drizzle_return_t drizzle_client_handshake_write(drizzle_con_st *con)
{
  if (DRIZZLE_STATE_NONE(con))
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_write)
    DRIZZLE_STATE_PUSH(con, drizzle_state_client_handshake_write)
    return drizzle_state_loop(con);
  }

  /* Continue with the states left over from a previous call. */
  return drizzle_state_loop(con);
}

/*
 * Server definitions
 */

/*
 * Server side: pack and send the initial handshake (greeting) packet.
 *
 * When no states are pending, drizzle_state_server_handshake_write (pushed
 * last, so it runs first) packs the greeting, and drizzle_state_write then
 * sends it. If states are already pending, they are resumed instead.
 *
 * Returns whatever drizzle_state_loop() returns.
 */
drizzle_return_t drizzle_server_handshake_write(drizzle_con_st *con)
{
  if (DRIZZLE_STATE_NONE(con))
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_write)
    DRIZZLE_STATE_PUSH(con, drizzle_state_server_handshake_write)
    return drizzle_state_loop(con);
  }

  /* Continue with the states left over from a previous call. */
  return drizzle_state_loop(con);
}

/*
 * Server side: read and parse the client's handshake response.
 *
 * When no states are pending, a raw packet is read first
 * (drizzle_state_packet_read, pushed last so it runs first) and then parsed
 * by drizzle_state_client_handshake_read. If states are already pending,
 * they are resumed instead.
 *
 * Returns whatever drizzle_state_loop() returns.
 */
drizzle_return_t drizzle_client_handshake_read(drizzle_con_st *con)
{
  if (DRIZZLE_STATE_NONE(con))
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_client_handshake_read)
    DRIZZLE_STATE_PUSH(con, drizzle_state_packet_read)
    return drizzle_state_loop(con);
  }

  /* Continue with the states left over from a previous call. */
  return drizzle_state_loop(con);
}

/*
 * Internal state functions.
 */

/*
 * Client side: parse the server's initial handshake packet, which
 * drizzle_state_packet_read has placed at con->buffer_ptr.
 *
 * Greeting payload (46 fixed bytes plus the version string):
 *    1  protocol version (10; 255 means an error packet instead)
 *    n  server version, null-terminated
 *    4  thread id
 *    8  first part of the scramble, then 1 filler byte
 *    2  capabilities (lower 16 bits only)
 *    1  charset
 *    2  status, then 13 unused bytes
 *   12  second part of the scramble, then 1 null byte
 *
 * An error packet is laid out as: 0xff, a 2-byte error code, and the
 * message text.
 *
 * On success, the parsed fields are stored in con and this state is popped.
 * Unless DRIZZLE_CON_RAW_PACKET is set, the client handshake write and the
 * read of the server's reply are then queued.
 *
 * Returns:
 *   DRIZZLE_RETURN_OK              on success
 *   DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET
 *   DRIZZLE_RETURN_AUTH_FAILED     if the server sent an error packet
 *   DRIZZLE_RETURN_PROTOCOL_NOT_SUPPORTED
 *   DRIZZLE_RETURN_UNEXPECTED_DATA if bytes follow the packet in the buffer
 */
drizzle_return_t drizzle_state_server_handshake_read(drizzle_con_st *con)
{
  uint8_t *ptr;

  PDEBUG("drizzle_state_server_handshake_read", "%5zu %5zu", con->buffer_size,
         con->packet_size)

  /* Assume the entire handshake packet will fit in the buffer. */
  if (con->buffer_size < con->packet_size)
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_read)
    return DRIZZLE_RETURN_OK;
  }

  /* The protocol version byte must exist before it can be inspected. */
  if (con->packet_size == 0)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "bad packet size:>=46:%zu", con->packet_size)
    return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
  }

  con->protocol_version= con->buffer_ptr[0];
  con->buffer_ptr++;

  /* This is a special case where the server determines that authentication
     will be impossible and denies any attempt right away. An error packet
     can be much shorter than a greeting, so handle it before enforcing the
     46-byte greeting minimum; otherwise the server's message would be lost
     behind a "bad packet size" error. */
  if (con->protocol_version == 255)
  {
    /* 0xff + 2-byte error code; anything shorter is malformed, and a
       negative %.*s precision would print until a stray NUL. */
    if (con->packet_size < 3)
    {
      DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                        "bad packet size:>=3:%zu", con->packet_size)
      return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
    }

    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "%.*s", (int)(con->packet_size - 3),
                      con->buffer_ptr + 2)
    return DRIZZLE_RETURN_AUTH_FAILED;
  }

  if (con->packet_size < 46)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "bad packet size:>=46:%zu", con->packet_size)
    return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
  }

  if (con->protocol_version != 10)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "protocol version not supported:%d",
                      con->protocol_version)
    return DRIZZLE_RETURN_PROTOCOL_NOT_SUPPORTED;
  }

  /* Look for null-terminated server version string. Only search the rest of
     this packet, not any data that may follow it in the buffer. */
  ptr= memchr(con->buffer_ptr, 0, con->packet_size - 1);
  if (ptr == NULL)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "server version string not found")
    return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
  }

  /* The fixed part of the greeting is 46 bytes; the rest is the version. */
  if (con->packet_size != (46 + (size_t)(ptr - con->buffer_ptr)))
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "bad packet size:%zu:%zu",
                      (46 + (size_t)(ptr - con->buffer_ptr)), con->packet_size)
    return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
  }

  strncpy(con->server_version, (char *)con->buffer_ptr,
          DRIZZLE_MAX_SERVER_VERSION_SIZE);
  con->server_version[DRIZZLE_MAX_SERVER_VERSION_SIZE - 1]= 0;
  con->buffer_ptr+= ((ptr - con->buffer_ptr) + 1);

  con->thread_id= (uint32_t)DRIZZLE_GET_BYTE4(con->buffer_ptr);
  con->buffer_ptr+= 4;

  con->scramble= con->scramble_buffer;
  memcpy(con->scramble, con->buffer_ptr, 8);
  /* Skip scramble and filler. */
  con->buffer_ptr+= 9;

  /* Even though drizzle_capabilities is more than 2 bytes, the protocol only
     allows for 2. This means some capabilities are not possible during this
     handshake step. The options beyond 2 bytes are for client response only. */
  con->capabilities= (drizzle_capabilities_t)DRIZZLE_GET_BYTE2(con->buffer_ptr);
  con->buffer_ptr+= 2;

  if (con->options & DRIZZLE_CON_MYSQL &&
      !(con->capabilities & DRIZZLE_CAPABILITIES_PROTOCOL_41))
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "protocol version not supported, must be MySQL 4.1+")
    return DRIZZLE_RETURN_PROTOCOL_NOT_SUPPORTED;
  }

  con->charset= con->buffer_ptr[0];
  con->buffer_ptr+= 1;

  con->status= DRIZZLE_GET_BYTE2(con->buffer_ptr);
  /* Skip status and filler. */
  con->buffer_ptr+= 15;

  /* Second part of the scramble, followed by its null byte. */
  memcpy(con->scramble + 8, con->buffer_ptr, 12);
  con->buffer_ptr+= 13;

  con->buffer_size-= con->packet_size;
  if (con->buffer_size != 0)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_read",
                      "unexpected data after packet:%zu", con->buffer_size)
    return DRIZZLE_RETURN_UNEXPECTED_DATA;
  }

  con->buffer_ptr= con->buffer;

  DRIZZLE_STATE_POP(con)

  /* Queue the rest of the handshake; the state pushed last runs first. */
  if (!(con->options & DRIZZLE_CON_RAW_PACKET))
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_handshake_result_read)
    DRIZZLE_STATE_PUSH(con, drizzle_state_packet_read)
    DRIZZLE_STATE_PUSH(con, drizzle_state_write)
    DRIZZLE_STATE_PUSH(con, drizzle_state_client_handshake_write)
  }

  return DRIZZLE_RETURN_OK;
}

/*
 * Server side: pack the initial handshake (greeting) packet, including its
 * 4-byte header, into con->buffer_ptr.
 *
 * Payload written, in order:
 *   - protocol version
 *   - server version, null-terminated
 *   - 4-byte thread id
 *   - 8 scramble bytes, then a zero filler
 *   - low 16 bits of the capabilities (PROTOCOL_41 is added first for
 *     DRIZZLE_CON_MYSQL connections)
 *   - charset
 *   - 2-byte status
 *   - 13 zero bytes
 *   - 12 more scramble bytes, then a trailing zero
 *
 * Both scramble parts are zero-filled when con->scramble is NULL. The header
 * carries sequence number 0, and con->packet_number is set to 1.
 *
 * Returns DRIZZLE_RETURN_OK, or DRIZZLE_RETURN_INTERNAL_ERROR if the packet
 * would exceed DRIZZLE_MAX_BUFFER_SIZE or the packed length does not match
 * the computed one.
 */
drizzle_return_t drizzle_state_server_handshake_write(drizzle_con_st *con)
{
  uint8_t *start;
  uint8_t *out;
  size_t version_len;

  version_len= strlen(con->server_version);

  /* Exact payload length; the 4-byte packet header is not included. */
  con->packet_size= 1   /* Protocol version */
                  + version_len + 1
                  + 4   /* Thread ID */
                  + 8   /* Scramble */
                  + 1   /* NULL */
                  + 2   /* Capabilities */
                  + 1   /* Language */
                  + 2   /* Status */
                  + 13  /* Unused */
                  + 12  /* Scramble */
                  + 1;  /* NULL */

  /* Header plus payload must fit in the connection buffer. */
  if ((con->packet_size + 4) > DRIZZLE_MAX_BUFFER_SIZE)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_write",
                      "buffer too small:%zu", con->packet_size + 4)
    return DRIZZLE_RETURN_INTERNAL_ERROR;
  }

  start= con->buffer_ptr;
  out= start;

  /* Header: 3-byte payload length and sequence number 0. */
  DRIZZLE_SET_BYTE3(out, con->packet_size)
  out[3]= 0;
  con->packet_number= 1;
  out+= 4;

  *out++= con->protocol_version;

  /* strlen() stopped at the terminator, so copying one extra byte writes
     the version string's trailing null as well. */
  memcpy(out, con->server_version, version_len + 1);
  out+= version_len + 1;

  DRIZZLE_SET_BYTE4(out, con->thread_id)
  out+= 4;

  if (con->scramble != NULL)
    memcpy(out, con->scramble, 8);
  else
    memset(out, 0, 8);
  out+= 8;

  *out++= 0;

  if (con->options & DRIZZLE_CON_MYSQL)
    con->capabilities|= DRIZZLE_CAPABILITIES_PROTOCOL_41;

  /* Only two bytes of capabilities fit here; this is a protocol limitation. */
  DRIZZLE_SET_BYTE2(out, con->capabilities)
  out+= 2;

  *out++= con->charset;

  DRIZZLE_SET_BYTE2(out, con->status)
  out+= 2;

  memset(out, 0, 13);
  out+= 13;

  if (con->scramble != NULL)
    memcpy(out, con->scramble + 8, 12);
  else
    memset(out, 0, 12);
  out+= 12;

  *out++= 0;

  con->buffer_size+= (4 + con->packet_size);

  /* Sanity check: the bytes written must match the computed length. */
  if ((size_t)(out - start) != (4 + con->packet_size))
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_server_handshake_write",
                      "error packing server handshake:%zu:%zu",
                      (size_t)(out - start), 4 + con->packet_size)
    return DRIZZLE_RETURN_INTERNAL_ERROR;
  }

  DRIZZLE_STATE_POP(con)
  return DRIZZLE_RETURN_OK;
}

/*
 * Server side: parse the client's handshake response packet, which
 * drizzle_state_packet_read has placed at con->buffer_ptr.
 *
 * Payload layout (34 bytes minimum):
 *    4  capabilities
 *    4  max packet size
 *    1  charset
 *   23  unused
 *    n  user name, null-terminated
 *    1  scramble length (0 or DRIZZLE_MAX_SCRAMBLE_SIZE), then the scramble
 *    m  optional database name, null-terminated
 *
 * Every variable-length field is checked against con->packet_size before it
 * is read. A client-supplied length therefore cannot make this function read
 * past the current packet.
 *
 * On success, the fields are stored in con and this state is popped.
 *
 * Returns DRIZZLE_RETURN_OK, DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET,
 * DRIZZLE_RETURN_PROTOCOL_NOT_SUPPORTED, or DRIZZLE_RETURN_UNEXPECTED_DATA.
 */
drizzle_return_t drizzle_state_client_handshake_read(drizzle_con_st *con)
{
  size_t real_size;
  uint8_t *ptr;
  uint8_t scramble_size;

  PDEBUG("drizzle_state_client_handshake_read", "%5zu %5zu", con->buffer_size,
         con->packet_size)

  /* Assume the entire handshake packet will fit in the buffer. */
  if (con->buffer_size < con->packet_size)
  {
    DRIZZLE_STATE_PUSH(con, drizzle_state_read)
    return DRIZZLE_RETURN_OK;
  }

  /* This is the minimum packet size. */
  if (con->packet_size < 34)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                      "bad packet size:>=34:%zu", con->packet_size)
    return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
  }

  /* Number of packet bytes accounted for so far. The 34 covers the 32 fixed
     bytes plus the user terminator and the scramble length byte, which are
     always present. Variable-length parts are added as they are parsed. */
  real_size= 34;

  con->capabilities= DRIZZLE_GET_BYTE4(con->buffer_ptr);
  con->buffer_ptr+= 4;

  if (con->options & DRIZZLE_CON_MYSQL &&
      !(con->capabilities & DRIZZLE_CAPABILITIES_PROTOCOL_41))
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                      "protocol version not supported, must be MySQL 4.1+")
    return DRIZZLE_RETURN_PROTOCOL_NOT_SUPPORTED;
  }

  con->max_packet_size= (uint32_t)DRIZZLE_GET_BYTE4(con->buffer_ptr);
  con->buffer_ptr+= 4;

  con->charset= con->buffer_ptr[0];
  con->buffer_ptr+= 1;

  /* Skip unused. */
  con->buffer_ptr+= 23;

  /* Look for null-terminated user string, searching only the remainder of
     this packet (not data that may follow it in the buffer). */
  ptr= memchr(con->buffer_ptr, 0, con->packet_size - 32);
  if (ptr == NULL)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                      "user string not found")
    return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
  }

  if (con->buffer_ptr == ptr)
  {
    con->user[0]= 0;
    con->buffer_ptr++;
  }
  else
  {
    real_size+= (size_t)(ptr - con->buffer_ptr);
    /* The scramble length byte must still lie inside the packet. */
    if (con->packet_size < real_size)
    {
      DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                        "bad packet size:>=%zu:%zu", real_size,
                        con->packet_size)
      return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
    }

    strncpy(con->user, (char *)con->buffer_ptr, DRIZZLE_MAX_USER_SIZE);
    con->user[DRIZZLE_MAX_USER_SIZE - 1]= 0;
    con->buffer_ptr+= ((ptr - con->buffer_ptr) + 1);
  }

  scramble_size= con->buffer_ptr[0];
  con->buffer_ptr+= 1;

  if (scramble_size == 0)
    con->scramble= NULL;
  else
  {
    if (scramble_size != DRIZZLE_MAX_SCRAMBLE_SIZE)
    {
      DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                        "wrong scramble size")
      return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
    }

    real_size+= scramble_size;
    /* The length byte comes from the client; make sure the scramble really
       fits in the packet before copying it. */
    if (con->packet_size < real_size)
    {
      DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                        "bad packet size:>=%zu:%zu", real_size,
                        con->packet_size)
      return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
    }

    con->scramble= con->scramble_buffer;
    memcpy(con->scramble, con->buffer_ptr, DRIZZLE_MAX_SCRAMBLE_SIZE);

    con->buffer_ptr+= DRIZZLE_MAX_SCRAMBLE_SIZE;
  }

  /* Whatever remains of the packet is the null-terminated db string. The
     bytes actually consumed (real_size) are used here rather than
     strlen(con->user), which is shorter if the user name was truncated. */
  if (real_size == con->packet_size)
    con->db[0]= 0;
  else
  {
    /* real_size < packet_size here, so this length cannot underflow. */
    ptr= memchr(con->buffer_ptr, 0, con->packet_size - real_size);
    if (ptr == NULL)
    {
      DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                        "db string not found")
      return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
    }

    /* The db terminator must be the packet's last byte. */
    real_size+= ((size_t)(ptr - con->buffer_ptr) + 1);
    if (con->packet_size != real_size)
    {
      DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                        "bad packet size:%zu:%zu", real_size, con->packet_size)
      return DRIZZLE_RETURN_BAD_HANDSHAKE_PACKET;
    }

    if (con->buffer_ptr == ptr)
    {
      con->db[0]= 0;
      con->buffer_ptr++;
    }
    else
    {
      strncpy(con->db, (char *)con->buffer_ptr, DRIZZLE_MAX_DB_SIZE);
      con->db[DRIZZLE_MAX_DB_SIZE - 1]= 0;
      con->buffer_ptr+= ((ptr - con->buffer_ptr) + 1);
    }
  }

  con->buffer_size-= con->packet_size;
  if (con->buffer_size != 0)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_read",
                      "unexpected data after packet:%zu", con->buffer_size)
    return DRIZZLE_RETURN_UNEXPECTED_DATA;
  }

  con->buffer_ptr= con->buffer;

  DRIZZLE_STATE_POP(con)
  return DRIZZLE_RETURN_OK;
}

/*
 * Client side: pack the handshake response packet into con->buffer_ptr.
 *
 * The payload has two parts:
 *   - fixed part: 4-byte capabilities, 4-byte max packet size, 1-byte
 *     charset, and 23 zero bytes;
 *   - authentication part (user, scramble, db), written by
 *     drizzle_pack_auth().
 *
 * con->packet_size is first computed as an upper bound that assumes a full
 * DRIZZLE_MAX_SCRAMBLE_SIZE scramble. The final size may differ, presumably
 * because drizzle_pack_auth() adjusts con->packet_size (confirm). For that
 * reason the 3-byte length header is written last, after the packed length
 * has been verified.
 *
 * The capabilities sent are con->capabilities masked to
 * DRIZZLE_CAPABILITIES_CLIENT. COMPRESS and SSL are always cleared, and
 * CONNECT_WITH_DB is cleared when no database is set. PROTOCOL_41 is first
 * added to con->capabilities for DRIZZLE_CON_MYSQL connections.
 *
 * Returns DRIZZLE_RETURN_OK, an error from drizzle_pack_auth(), or
 * DRIZZLE_RETURN_INTERNAL_ERROR on a size problem.
 */
drizzle_return_t drizzle_state_client_handshake_write(drizzle_con_st *con)
{
  uint8_t *out;
  drizzle_capabilities_t sent_caps;
  drizzle_return_t ret;

  /* Upper bound on the payload length (header excluded). */
  con->packet_size= 4   /* Capabilities */
                  + 4   /* Max packet size */
                  + 1   /* Charset */
                  + 23  /* Unused */
                  + strlen(con->user) + 1        /* User + terminator */
                  + 1   /* Scramble size */
                  + DRIZZLE_MAX_SCRAMBLE_SIZE    /* Scramble */
                  + strlen(con->db) + 1;         /* Database + terminator */

  if ((con->packet_size + 4) > DRIZZLE_MAX_BUFFER_SIZE)
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_write",
                      "buffer too small:%zu", con->packet_size + 4)
    return DRIZZLE_RETURN_INTERNAL_ERROR;
  }

  out= con->buffer_ptr;

  /* Only the sequence number goes into the header for now. */
  out[3]= con->packet_number++;
  out+= 4;

  if (con->options & DRIZZLE_CON_MYSQL)
    con->capabilities|= DRIZZLE_CAPABILITIES_PROTOCOL_41;

  sent_caps= con->capabilities & DRIZZLE_CAPABILITIES_CLIENT &
             ~(DRIZZLE_CAPABILITIES_COMPRESS | DRIZZLE_CAPABILITIES_SSL);
  if (con->db[0] == 0)
    sent_caps&= ~DRIZZLE_CAPABILITIES_CONNECT_WITH_DB;

  DRIZZLE_SET_BYTE4(out, sent_caps)
  out+= 4;

  DRIZZLE_SET_BYTE4(out, con->max_packet_size)
  out+= 4;

  *out++= con->charset;

  memset(out, 0, 23);
  out+= 23;

  out= drizzle_pack_auth(con, out, &ret);
  if (ret != DRIZZLE_RETURN_OK)
    return ret;

  con->buffer_size+= (4 + con->packet_size);

  /* Sanity check: bytes written must match the (possibly adjusted) size. */
  if ((size_t)(out - con->buffer_ptr) != (4 + con->packet_size))
  {
    DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_client_handshake_write",
                      "error packing client handshake:%zu:%zu",
                      (size_t)(out - con->buffer_ptr), 4 + con->packet_size)
    return DRIZZLE_RETURN_INTERNAL_ERROR;
  }

  /* The final length is known now; fill in the header. */
  DRIZZLE_SET_BYTE3(con->buffer_ptr, con->packet_size)

  DRIZZLE_STATE_POP(con)
  return DRIZZLE_RETURN_OK;
}

/*
 * Client side: read the server's reply to the client handshake response and
 * turn it into the handshake outcome.
 *
 * A temporary result structure is created on this function's stack,
 * attached to con->result, and filled in by drizzle_state_result_read().
 * Once DRIZZLE_STATE_NONE(con) reports no pending states, a successful read
 * is interpreted as follows:
 *   - an EOF result is rejected with DRIZZLE_RETURN_AUTH_FAILED, because it
 *     is treated as a request for the old insecure authentication mechanism,
 *     which is not supported;
 *   - any other result sets DRIZZLE_CON_READY on the connection.
 * The temporary result is always freed before returning.
 *
 * Returns the reader's return code, except that DRIZZLE_RETURN_ERROR_CODE is
 * reported as DRIZZLE_RETURN_HANDSHAKE_FAILED.
 */
drizzle_return_t drizzle_state_handshake_result_read(drizzle_con_st *con)
{
  drizzle_return_t ret;
  drizzle_result_st result;

  if (drizzle_result_create(con, &result) == NULL)
    return DRIZZLE_RETURN_MEMORY;

  /* NOTE(review): con->result is left pointing at this stack-local result
     after the function returns. Confirm that nothing dereferences
     con->result before it is reassigned. */
  con->result= &result;

  ret= drizzle_state_result_read(con);
  /* Only interpret the result once the reader has no more work queued. If
     it pushed further states (e.g. to read more data), the result is still
     freed below and presumably recreated on re-entry — TODO confirm the
     reader keeps no partial state inside the result. */
  if (DRIZZLE_STATE_NONE(con))
  {
    if (ret == DRIZZLE_RETURN_OK)
    {
      if (drizzle_result_eof(&result))
      {
        DRIZZLE_ERROR_SET(con->drizzle, "drizzle_state_handshake_result_read",
                          "old insecure authentication mechanism not supported")
        ret= DRIZZLE_RETURN_AUTH_FAILED;
      }
      else
        con->options|= DRIZZLE_CON_READY;
    }
  }

  drizzle_result_free(&result);

  /* A server error reply during the handshake is reported as a failed
     handshake. */
  if (ret == DRIZZLE_RETURN_ERROR_CODE)
    return DRIZZLE_RETURN_HANDSHAKE_FAILED;

  return ret;
}
