/** @file uphpmvault.cc

    written by Marc Singer
    8 Jun 2008

    Copyright (C) 2008 Marc Singer

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License
    version 2 as published by the Free Software Foundation.  Please
    refer to the file debian/copyright for further details.

    @brief Upload recovery firmware to an HP Media Vault (MV2).  The
    U-BOOT boot loader on the Media Vault will enter a recovery mode
    if the firmware fails to load or if the device is powered on while
    the reset button is pressed.  This program watches for 'beacon'
    broadcast message UDP packets from Media Vault devices on the
    local Ethernet network.  It will upload a firmware image using
    TFTP [1] to the first found MediaVault.

    [1] http://www.faqs.org/rfcs/rfc1350.html

    NOTES
    -----

    o The address received from the beacon packet is not the address
      we use to send the TFTP commands.  Instead, we use the address
      in the UDP header since that will be most reliable.  Still, we
      require that the IP address in the beacon be non-zero.

    o The main loop should wait for beacon packets for a little more
      than 5 seconds before initiating a recovery.  If there are
      several devices waiting for recovery, we'll have seen them all.
      If there is more than one, a command line switch will
      distinguish between them.

    o poll vs. epoll.  We're using poll because we aren't a high
      performance application and the semantics of poll are simpler
      for our application.  With epoll, we'd have to read from the
      beacon socket until we saw an EWOULDBLOCK before entering epoll.
      With poll, we can simply call poll and wait.

    o Need to add a check for an application timeout if we don't
      receive ACKs or when there are no Beacons.

    o Note that the number of blocks sent will include an empty block
      if the file is an even multiple of 512 bytes, per the RFC.

    o There is no real retry logic.  We should at least be able to
      cope with dropped acknowledgements.

    o alloca is used in favor of stack arrays of characters to
      guarantee alignment.

*/

#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdint.h>
#include <fcntl.h>
#include <errno.h>
#include <memory.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <poll.h>
#include <argp.h>

#include "dumpw.h"
#include "exception.h"
#include "oprintf.h"

/** Program version string.  glibc's argp looks for a symbol of this
    name to implement the --version option. */
const char* argp_program_version = "uphpmvault 0.5";

/** Command line settings.  A pointer to an instance is handed to
    argp_parse() and filled in by arg_parser(). */

struct arguments {
  arguments () : szImageFile (NULL), verbose (false), quiet (false) {}

  const char* szImageFile;	///< Path of the image file; NULL until parsed
  bool verbose;			///< Set when key 'v' is parsed
  bool quiet;			///< Set when key 'q' is parsed; silences output
};

/** Options understood by the argp parser.  Each entry lists name,
    key, argument name, flags, help text and group; an all-zero entry
    ends the table.  The verbose switch is currently commented out. */
struct argp_option options[] = {
  // { "verbose", 'v', NULL, 0, "Verbose output, when available", 0 },
  { "quiet", 'q', NULL, 0, "Suppress output", 0 },
  { NULL, 0, NULL, 0, NULL, 0 }
};


// Forward declaration; the argp description below stores the address
// of the parser callback.
error_t arg_parser (int key, char* arg, struct argp_state* state);

/** Parser description passed to argp_parse() in main().  The
    initializers, in order, are the option table, the parser
    callback, the usage text for the non-option argument, and the help
    text.  A vertical tab in the help text separates what argp prints
    before the option list from what it prints after it. */
struct argp argp = {
  options, arg_parser,
  "FILE",
"  Upload recovery firmware to HP MediaVault (MV2) devices awaiting recovery.\n"
  "\v"
  "This program listens for broadcast beacon packets on any network\n"
  "interface from HP MediaVault devices in recovery mode.  On receipt of\n"
  "a valid beacon message, the application will upload the specified\n"
  "firmware image file to the device and then terminate.\n"
  "\n"
  "The format of the recovery image is a pair of U-BOOT image files,\n"
  "concatenated with padding between them such that the second image\n"
  "starts exactly 2MiB from the beginning of the combined image file.\n"
  "The first U-BOOT image is a Linux kernel and the second is an initrd\n"
  "or initramfs.\n"
};


/** argp callback invoked once per parsed key.

    @param key    Option key, or one of the ARGP_KEY_* values.
    @param arg    Text of the argument for ARGP_KEY_ARG.
    @param state  Parser state; state->input points at the struct
                  arguments being filled in.
    @return 0 when the key was handled, ARGP_ERR_UNKNOWN otherwise.

    Exactly one image file is required: a second one, or none at all
    by the end of parsing, is reported through argp_error(). */
error_t arg_parser (int key, char* arg, struct argp_state* state)
{
  struct arguments* pArgs = static_cast<struct arguments*> (state->input);

  if (key == 'v')
    pArgs->verbose = true;
  else if (key == 'q')
    pArgs->quiet = true;
  else if (key == ARGP_KEY_ARG) {
    if (pArgs->szImageFile)
      argp_error (state, "only one image file is permitted");
    pArgs->szImageFile = arg;
  }
  else if (key == ARGP_KEY_END) {
    if (!pArgs->szImageFile)
      argp_error (state, "you must specify an image file");
  }
  else
    return ARGP_ERR_UNKNOWN;

  return 0;
}

/** States of the upload state machine driven by server(). */

enum {
  Idle           = 0,		///< No beacon accepted yet
  BeaconReceived = 1,		///< Beacon accepted; write request not yet sent
  WRQSent        = 2,		///< Write request sent; waiting for its ACK
  Acked          = 3,		///< Last packet was acknowledged
  DataSent       = 4,		///< DATA packet sent; waiting for its ACK
  Done           = 5,		///< Every block was acknowledged
};

#define PACKET_TYPE_RECOVERY_MODE (0xde) ///< packet_type of a recovery beacon
#define PORT_BEACON		(8488)	///< UDP port the beacon socket binds to
#define VERSION_BEACON		(1)	///< The only beacon version recognized
#define STATE_RECOVERY		(1)	///< The only state value recognized

/** Layout of a beacon datagram as it is copied out of the receive
    buffer.  The structure is packed so the fields sit at the same
    offsets as in the datagram. */

struct Beacon {
  // header
  uint8_t packet_type;		///< Compared against PACKET_TYPE_RECOVERY_MODE
  uint8_t version;		///< Compared against VERSION_BEACON
  // payload
  uint16_t state;		///< Compared against STATE_RECOVERY
  uint8_t ip_addr[4];		///< IP address reported by the device
  uint8_t mac_addr[6];		///< MAC address reported by the device
  uint8_t szName[16];		///< Manufacturer name
  uint8_t szModel[50];		///< Model name

  /** Start with every byte of the structure zero. */
  Beacon () {
    bzero (this, sizeof (*this));
  }

  /** True when the header and state describe a recovery beacon. */
  bool is (void) {
    if (packet_type != PACKET_TYPE_RECOVERY_MODE)
      return false;
    if (version != VERSION_BEACON)
      return false;
    return state == STATE_RECOVERY;
  }

  /** True when the name field reads "HP P2 NAS", ignoring case. */
  bool is_mediavault2 (void) {
    const char* szMaker = (const char*) szName;
    return !strcasecmp (szMaker, "HP P2 NAS");
  }

  /** True when the reported IP address has at least one non-zero byte. */
  bool is_valid (void) {
    for (unsigned i = 0; i < sizeof (ip_addr); ++i)
      if (ip_addr[i])
        return true;
    return false;
  }

} __attribute__((packed));

/** Write a one-line summary of a beacon to a stream: the name field
    padded or cut to 16 characters, the MAC address in hex and the
    reported IP address in dotted decimal. */
std::ostream& operator<< (std::ostream& o, const Beacon& b) {
  const uint8_t* const mac = b.mac_addr;
  const uint8_t* const ip  = b.ip_addr;
  o << oprintf ("%-16.16s %02x:%02x:%02x:%02x:%02x:%02x (%d.%d.%d.%d)",
                b.szName,
                mac[0], mac[1], mac[2], mac[3], mac[4], mac[5],
                ip[0], ip[1], ip[2], ip[3]);
  return o;
}

// Port the TFTP write request is addressed to (see TFTP::bind).
#define PORT_TFTP		(69)

// TFTP opcodes.  Of these, the code in this file uses WRQ, DATA and ACK.
#define TFTP_OP_RRQ		1
#define TFTP_OP_WRQ		2
#define TFTP_OP_DATA		3
#define TFTP_OP_ACK		4
#define TFTP_OP_ERROR		5
#define TFTP_OP_OACK		6

// TFTP error codes.  Listed for reference; the code shown here does
// not inspect ERROR packets.
#define TFTP_ERROR_NONE		0
#define TFTP_ERROR_FILENOTFOUND	1
#define TFTP_ERROR_ACCESSERROR	2
#define TFTP_ERROR_DISKFULL	3
#define TFTP_ERROR_ILLEGALOP	4
#define TFTP_ERROR_UNKNOWNID	5
#define TFTP_ERROR_FILEEXISTS	6
#define TFTP_ERROR_NOUSER	7
#define TFTP_ERROR_OPTIONTERMINATE 8

// Tuning: microseconds slept between sendto attempts, number of
// sendto attempts per packet, bytes of file data per DATA packet, and
// the poll() timeout in milliseconds.
#define US_BETWEEN_RETRIES	(100*1000)
#define C_RETRIES		(5)
#define CB_DATA_PAYLOAD		(512) 		// Default DATA payload size
#define MS_TIMEOUT_POLL		(200)

/** Overlay for the start of a TFTP packet: a 16 bit opcode followed
    by the opcode specific bytes.  Used by casting a byte buffer. */
struct __attribute__((packed)) message_tftp {
  uint16_t opcode;		///< One of TFTP_OP_*, stored in network order
  uint8_t data[];		///< Rest of the packet
};


/** Berkeley UDP socket wrapper class.  It can open a listening socket
    for either a specific port, or an OS selected port. */

struct Socket {
  int m_port;			///< Port, host byte order; 0 lets the OS pick
  int m_s;			///< Socket descriptor; -1 while not open

  Socket ()         : m_port (0),    m_s (-1) {}
  Socket (int port) : m_port (port), m_s (-1) {}

  /** Close the descriptor, if there is one, and mark the socket as
      not open.  Used so that a failed open_udp() doesn't leak it. */
  void close_socket (void) {
    if (m_s != -1) {
      ::close (m_s);
      m_s = -1;
    }
  }

  /** Create a non-blocking UDP socket bound to m_port on all local
      addresses.  When m_port is zero, the port chosen by the OS is
      stored back into m_port.

      Throws Exception if any step fails.  In that case the descriptor
      has been closed and m_s is -1. */
  void open_udp (void) {
    struct sockaddr_in addr;
    memset (&addr, 0, sizeof (addr)); // Don't pass stack garbage to bind
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl (INADDR_ANY);
    addr.sin_port = htons (m_port);

    m_s = ::socket (AF_INET, SOCK_DGRAM, 0);
    if (m_s == -1)
      throw Exception ("unable to create socket");

    int options = ::fcntl (m_s, F_GETFL, 0);
    if (options == -1
        || ::fcntl (m_s, F_SETFL, options | O_NONBLOCK) != 0) {
      close_socket ();
      throw Exception ("unable to make socket non-blocking");
    }

    if (::bind (m_s, (struct sockaddr*) &addr, sizeof (addr)) != 0) {
      close_socket ();
      throw Exception ("unable to bind socket");
    }

		// Recover port if we've asked OS to pick one
    if (m_port == 0) {
      socklen_t cbAddr = sizeof (addr);
      if (::getsockname (m_s, (struct sockaddr*) &addr, &cbAddr) != 0) {
        close_socket ();
        throw Exception ("unable to get socket address");
      }
      m_port = ntohs (addr.sin_port);
    }
  }
};


/** Simple TFTP serving class.  The base class is the UDP socket
    class to handle opening the listening socket.  There are methods
    to send the important TFTP protocol elements. */

struct TFTP : Socket {
  struct sockaddr m_addr;	///< Destination of every packet sent

  TFTP () {
    memset (&m_addr, 0, sizeof (m_addr)); }
  TFTP (int port) : Socket (port) {
    memset (&m_addr, 0, sizeof (m_addr)); }

  /** Select the peer.  The address is copied and its port replaced
      with the TFTP port, which is where the write request goes. */

  void bind (struct sockaddr addr) {
    m_addr = addr;
    ((struct sockaddr_in*) &m_addr)->sin_port = htons (PORT_TFTP); }

  /** Change only the destination port (host byte order) of the
      peer address. */

  void set_port (int port) {
    ((struct sockaddr_in*) &m_addr)->sin_port = htons (port); }

  /** Send one datagram to the peer, making up to C_RETRIES attempts
      with a pause of US_BETWEEN_RETRIES between them.  Returns true
      as soon as an attempt sends all cb bytes.  On a false return
      there was no pause after the last attempt, so errno is whatever
      the final sendto left. */

  bool send_packet (const void* pv, size_t cb) {
    for (int retries = C_RETRIES; retries--; ) {
      ssize_t cbSent = sendto (m_s, pv, cb, 0, &m_addr, sizeof (m_addr));
      if (cbSent == (ssize_t) cb)
        return true;
      if (retries)
        usleep (US_BETWEEN_RETRIES);
    }
    return false;
  }

  /** Send write request for named file.  The file type will be
      'octet'.  Throws ResultException if the packet cannot be
      sent. */

  void wrq (const char* szFile) {
    static const char szMode[] = "octet";
    // Opcode, then the file name and the mode, each with its NUL
    size_t cbName = strlen (szFile) + 1;
    size_t cbMessage = sizeof (struct message_tftp) + cbName
      + sizeof (szMode);
    char* rgb = (char*) alloca (cbMessage);
    struct message_tftp& m = *(struct message_tftp*) rgb;
    m.opcode = htons (TFTP_OP_WRQ);
    memcpy (m.data, szFile, cbName);
    memcpy (m.data + cbName, szMode, sizeof (szMode));
    if (!send_packet (rgb, cbMessage))
      throw ResultException (errno, "error sending WRQ on %d (%s)",
			     m_s, strerror (errno));
  }

  /** Send write data.  The block index starts with 1.  The size of
      the buffer to send may reference all of the data remaining in
      the file.  This function will only send at most CB_DATA_PAYLOAD
      bytes of it; cb may be zero, which sends an empty DATA packet.
      The return value is the number of data bytes transmitted.
      Throws ResultException if the packet cannot be sent. */

  size_t wdata (int block, const char* rgb, size_t cb) {
    size_t cbSend = cb > CB_DATA_PAYLOAD ? CB_DATA_PAYLOAD : cb;
    size_t cbPacket = sizeof (message_tftp) + sizeof (uint16_t) + cbSend;
    char* rgbPacket = (char*) alloca (cbPacket);
    message_tftp& msg = *(message_tftp*) rgbPacket;
    msg.opcode = htons (TFTP_OP_DATA);
    // The block number on the wire is 16 bits, network order
    uint16_t blockNet = htons ((uint16_t) block);
    memcpy (msg.data, &blockNet, sizeof (blockNet));
    if (cbSend)
      memcpy (msg.data + sizeof (blockNet), rgb, cbSend);
    if (!send_packet (rgbPacket, cbPacket))
      throw ResultException (errno, "error sending %d DATA on %d (%d '%s')",
			     (int) cbSend, m_s, errno, strerror (errno));
    return cbSend;
  }
};


/** UDP/TFTP server loop and state machine.  It waits for broadcast
    beacon messages and uploads the recovery image when it receives a
    valid beacon.

    The cBlocks calculation rounds up and may add a full extra block
    if the file to send is an even multiple of the block size, per the
    RFC.

    @param args    Parsed command line; only the quiet flag is read.
    @param pvFile  Start of the image to upload.
    @param cbFile  Length of the image in bytes.

    Exceptions thrown while opening the sockets or sending packets
    propagate to the caller.  The function returns once every block
    has been acknowledged.  There is no timeout and no
    retransmission, so it does not return if an acknowledgement is
    lost or no beacon ever arrives. */

void server (struct arguments& args, void* pvFile, size_t cbFile)
{
  Socket sBeacon = Socket (PORT_BEACON);
  sBeacon.open_udp ();
  TFTP tftp;
  Beacon b;
  tftp.open_udp ();

  struct pollfd fds[] = {
    { sBeacon.m_s, POLLIN, 0 },
    { tftp.m_s,    POLLIN, 0 },
  };

  int state = Idle;
  int block = 0;		// Block awaiting ACK; 0 is the WRQ itself
  int cBlocks = (cbFile + CB_DATA_PAYLOAD)/CB_DATA_PAYLOAD;
  size_t ibFile = 0;

  struct sockaddr addr;
  socklen_t cbAddr;

  while (state != Done) {
    char rgb[1500];
    ssize_t cb;

    switch (state) {

      // Read from our sockets
    case Idle:
    case WRQSent:
    case DataSent:
      if (poll (fds, sizeof (fds)/sizeof (*fds), MS_TIMEOUT_POLL) <= 0)
        break;

		// Read Beacon
      if (fds[0].revents & POLLIN) {
        cbAddr = sizeof (addr);	// Value-result; reset for every call
        cb = recvfrom (sBeacon.m_s, rgb, sizeof (rgb), 0, &addr, &cbAddr);
        // Once a recovery beacon has been latched in b, later
        // beacons are dropped without being reported.
        if (cb > 0 && !b.is ()) {
          // A short datagram must not pick up stale buffer bytes, so
          // start from zero and copy only what was received.
          size_t cbCopy = (size_t) cb < sizeof (b) ? (size_t) cb : sizeof (b);
          b = Beacon ();
          memcpy (&b, rgb, cbCopy);
          bool recover
            = state == Idle && b.is_mediavault2 () && b.is_valid ();
          if (!args.quiet)
            std::cout << b << (recover ? "" : " [ignoring]") << std::endl;
          if (recover) {
            if (!args.quiet) {
              printf ("Recovering %02x:%02x:%02x:%02x:%02x:%02x\n",
                      b.mac_addr[0], b.mac_addr[1], b.mac_addr[2],
                      b.mac_addr[3], b.mac_addr[4], b.mac_addr[5]);
              printf ("Sending file of %lu bytes in %d blocks of %d bytes\n",
                      (unsigned long) cbFile, cBlocks, CB_DATA_PAYLOAD);
            }
            // Use the datagram's source address, not the one in the
            // beacon payload.
            tftp.bind (addr);
            state = BeaconReceived;
          }
        }
      }

		// Read ACKs
      if (fds[1].revents & POLLIN) {
        cbAddr = sizeof (addr);
        cb = recvfrom (tftp.m_s, rgb, sizeof (rgb), 0, &addr, &cbAddr);
        // Only look at packets long enough to hold an opcode and a
        // block number, and only while we are waiting for an ACK.
        if (cb >= (ssize_t) (sizeof (message_tftp) + sizeof (uint16_t))
            && (state == WRQSent || state == DataSent)) {
          uint16_t opcode;
          memcpy (&opcode, &(*(message_tftp*) rgb).opcode, sizeof (opcode));
          opcode = ntohs (opcode);
          uint16_t blockAck;
          memcpy (&blockAck, &(*(message_tftp*) rgb).data, sizeof (blockAck));
          blockAck = ntohs (blockAck);
          // Block numbers are 16 bits on the wire; compare modulo 2^16
          // so images of more than 65535 blocks keep going.
          if (opcode == TFTP_OP_ACK && blockAck == (uint16_t) block) {
            // The reply to the WRQ comes from the port the peer
            // picked for this transfer; send the data there.
            if (state == WRQSent)
              tftp.set_port (ntohs (((sockaddr_in*)&addr)->sin_port));
            state = Acked;
            ++block;
          }
        }
      }
      break;

      // We've seen a beacon, send the write request
    case BeaconReceived:
      tftp.wrq ("uImage");      // The name, itself, is unimportant
      state = WRQSent;
      break;

      // We've been ACKd, send another block or terminate
    case Acked:
      if (block - 1 >= cBlocks) {
        state = Done;
        break;
      }
      ibFile += tftp.wdata (block, (const char*) pvFile + ibFile,
                            cbFile - ibFile);
      state = DataSent;
      break;
    }
  }
}

/** Program entry.  Parses the command line, maps the image file
    read-only and hands it to server().

    Exit status: 2 when the image file cannot be opened, examined or
    mapped, or is empty; 1 when server() throws; 0 when the upload
    completes.  argp reports command line errors itself. */

int main (int argc, char** argv)
{
  struct arguments args;
  argp_parse (&argp, argc, argv, 0, 0, &args);

  int fh = ::open (args.szImageFile, O_RDONLY);
  if (fh == -1) {
    printf ("unable to open file %s\n", args.szImageFile);
    exit (2);
  }
  struct stat st;
  if (::fstat (fh, &st) != 0) {
    printf ("unable to determine size of file %s\n", args.szImageFile);
    exit (2);
  }
  size_t cb = st.st_size;
  if (cb == 0) {		// mmap rejects a zero length mapping
    printf ("file %s is empty\n", args.szImageFile);
    exit (2);
  }
  void* pv = ::mmap (0, cb, PROT_READ, MAP_FILE | MAP_SHARED, fh, 0);
  if (pv == MAP_FAILED) {
    printf ("unable to map file %s\n", args.szImageFile);
    exit (2);
  }

  if (!args.quiet) {
    printf ("Recovery image %s is %lu bytes\n", args.szImageFile,
            (unsigned long) cb);
    printf ("Waiting for broadcast beacon packets.\n");
    printf ("Press ^C to cancel\n");
  }

  int result = 0;
  try {
    server (args, pv, cb);
  }
  catch (const Exception& e) {
    printf ("exception: %s\n", e.sz);
    result = 1;
  }
  catch (...) {
    printf ("exception: <unrecognized>\n");
    result = 1;
  }

  ::munmap (pv, cb);
  ::close (fh);

  return result;
}
