/*  Copyright (c) 2005 Romain BONDUE
    This file is part of RutilT.

    RutilT is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    RutilT is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with RutilT; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
/** \file RTDrivers.cxx
    \author Romain BONDUE
    \date 30/08/2005 */
#include <cstring>
#include <sstream>

extern "C"{
#include <sys/ioctl.h> // SIOCDEVPRIVATE
}

#include "RTDrivers.h"
#include "ErrorsCode.h"



namespace
{
    /* One entry of the scan list filled by the Ralink drivers' private "list"
     * ioctl (see CRTDriver::GetScanResult()). The data members' types and
     * order reproduce the layout written by the driver and must not be
     * changed. Entries are variable-sized : GetLength() gives the offset of
     * the next one. */
    class CScanResult
    {
      protected :
        struct SSSid
        {
            unsigned long Length;
            unsigned char Text [32];

        }; // SSSid

        struct SConfigFH
        {
            unsigned long m1;
            unsigned long m2;
            unsigned long m3;
            unsigned long m4;

        }; // SConfigFH

        struct SConfig
        {
            unsigned long m1;
            unsigned long m2;
            unsigned long m3;
            unsigned long Frequency;
            SConfigFH m5;

        }; // SConfig

        struct SEncryptionD
        {
            unsigned char Id;
            unsigned char Length;
            unsigned char Data [32];

        }; // SEncryptionD

        struct SEncryptionOffset
        {
            unsigned char m1 [8];
            unsigned short m2;
            unsigned short m3;

        }; // SEncryptionOffset


      private :
        static const unsigned NbRates = 16;

        unsigned long m_Length;
        char m_MacAddress [6];
        unsigned char m_Reserved [2];
        SSSid m_SSID;
        unsigned long m_Privacy;
        long m_SignalStrength;
        int m_NetworkType;
        SConfig m_Config;
        int m_NetworkInfrastructureType;
        unsigned char m_SupportedRates [NbRates];
        unsigned long m_IELength;
            // A fixed part (SEncryptionOffset) followed by elements.
        unsigned char m_pIEs [sizeof (SEncryptionOffset) +
                                                  sizeof (SEncryptionD) * 4];


      public :
        CScanResult () throw() {memset (this, 0, sizeof (CScanResult));}

            // Size in bytes of this entry, as written by the driver.
        unsigned GetLength () const throw() {return m_Length;}

        nsWireless::CMacAddress GetAPMacAddress () const throw (std::bad_alloc)
        {
            return nsWireless::CMacAddress (m_MacAddress);

        } // GetAPMacAddress()

            // Driver value 0 : ad-hoc, 1 : managed, anything else : auto.
        nsWireless::Mode_e GetMode () const throw()
        {
            switch (m_NetworkInfrastructureType)
            {
              case 0 : return nsWireless::AdHoc;
              case 1 : return nsWireless::Managed;
              default : return nsWireless::Auto;
            }

        } // GetMode()

            // Most things here are magic, ask Ralink...
        nsWireless::CEncryptionD GetEncryptionD () const throw()
        {
            nsWireless::CEncryptionD Descriptor;
            if (m_Privacy)
            {       // Check if it's WPA : first element after the fixed part.
                const SEncryptionD* const pDescriptor
                                (reinterpret_cast<const SEncryptionD*>
                                        (m_pIEs + sizeof (SEncryptionOffset)));
                    /* Vendor specific element (Id 221) whose body starts with
                     * the bytes 00 50 F2 01. They are compared byte by byte :
                     * reading them through an unsigned long would be an
                     * unaligned, endian-dependent access, 8 bytes wide (and
                     * thus never matching) on LP64 systems. */
                const unsigned char WPAMagic [] = {0x00, 0x50, 0xF2, 0x01};
                if (pDescriptor->Id == 221 && pDescriptor->Length >= 16 &&
                    !memcmp (pDescriptor->Data, WPAMagic, sizeof (WPAMagic)))
                {
                        // Assumed to be the group cipher suite type.
                    switch (pDescriptor->Data [9])
                    {
                      case 2 :
                        Descriptor.SetEncrypt (nsWireless::TKIP);
                      break;

                      case 3 : // AES-WRAP
                      case 4 : // AES-CCMP
                        Descriptor.SetEncrypt (nsWireless::AES);
                      //break;
                    }
                        // Ad-hoc networks (driver value 0) use WPA-None.
                    if (m_NetworkInfrastructureType)
                        Descriptor.SetAuth (nsWireless::WPAPSK);
                    else
                        Descriptor.SetAuth (nsWireless::WPANONE);
                }
                else Descriptor.SetEncrypt (nsWireless::WEP);
                    /* TODO We don't know if it's "shared" or not.
                     *      The standard GetEncryptionD() could be called to
                     *      know. Dig more in Ralink stuffs? */
            }
            return Descriptor;

        } // GetEncryptionD()

            /* SSID of at most sizeof (Text) bytes, cut at the first '\0'.
             * Text needn't be '\0'-terminated (a 32 bytes SSID fills it), so
             * it is never read past its end. */
        std::string GetSSID () const throw()
        {
            const char* const pText
                            (reinterpret_cast<const char*> (m_SSID.Text));
            const unsigned long MaxLength (sizeof (m_SSID.Text));
            const unsigned long Length (m_SSID.Length < MaxLength ?
                                                m_SSID.Length : MaxLength);
            const void* const pNul (memchr (pText, '\0', Length));
            const std::string::size_type Size (pNul ?
                        std::string::size_type (static_cast<const char*> (pNul)
                                                                    - pText)
                                                    : Length);
            return std::string (pText, Size);

        } // GetSSID()

            // Driver frequency divided by 10^6 (presumably kHz -> GHz,
            // TODO confirm the unit expected by the callers).
        double GetFrequency () const throw()
        {
            return m_Config.Frequency / 1000000.0;

        } // GetFrequency()

            // Only the signal strength is available : first value is 0.
        nsWireless::CQuality GetQuality () const throw()
        {
            return nsWireless::CQuality (0, m_SignalStrength);

        } // GetQuality()

            // Not available in this structure : always 0.
        unsigned GetTxRate () const throw()
        {
            return 0;

        } // GetTxRate()

    }; // CScanResult


    /* Buffer handed to the driver's scan-list ioctl : an entry count followed
     * by room for MaxSize entries. m_Size starts at MaxSize and is expected
     * to be overwritten by the driver with the actual count. */
    template<unsigned MaxSize>
    class CScanResultsTab
    {
      public :
        CScanResultsTab () throw() : m_Size (MaxSize)
        {
                // Start from zeroed entries.
            memset (m_Data, 0, sizeof (m_Data));

        } // CScanResultsTab()

            // Number of entries (MaxSize until the driver writes it).
        unsigned Size () const throw() {return m_Size;}

            // First entry; the next ones are reached through GetLength().
        const CScanResult* Get () const throw() {return &m_Data [0];}


      private :
        unsigned long m_Size;
        CScanResult m_Data [MaxSize];

    }; // CScanResultsTab


    /* STA configuration as written by the driver (read in the CRT2500Driver
     * constructor through Ralink OID 0x217); its layout must match the
     * driver's. All data members share the same access level : with mixed
     * access specifiers the standard guarantees neither declaration-order
     * layout nor a standard-layout type. */
    struct SPrivateConfig
    {
        SPrivateConfig () throw() {memset (this, 0, sizeof (SPrivateConfig));}

        unsigned long EnableTxBurst;
        unsigned long EnableTurboRate;
            // Ralink : 0-AUTO, 1-always ON, 2-always OFF.
        unsigned long UseBGProtection;
            // Ralink : 0-no use, 1-use 9-use short slot time when applicable.
        unsigned long UseShortSlotTime;
        unsigned long UseOfdmRatesIn11gAdhoc;
            // Not read in this file : only there to match the driver's size.
        unsigned long Reserved1;
        unsigned long Reserved2;
        unsigned long Reserved3;

    }; // SPrivateConfig


    /* Values used by the Ralink drivers themselves : they are read raw from
     * ioctl answers, so every enumerator's value matters. */
    enum DriverWirelessMode_e {BGMixed = 0, B = 1, A = 2, ABGMixed = 3};

    enum DriverEncrypt_e {WEP = 0, TKIP = 4, AES = 6};

        // Not sure about WPANone...
    enum DriverAuth_e {Open = 0, Shared = 1, WPA = 3, WPAPSK = 4, WPANone = 5};

} // anonymous namespace



void nsWireless::CRTDriver::GetScanResult (int ListOid,
                                           std::vector<CCell>& CellVec) const
                                throw (nsErrors::CException, std::bad_alloc)
{
        /* Fetches the driver's scan list (ListOid is the Ralink request code)
         * and appends one CCell per entry to CellVec. */
    const unsigned NbMaxResults (16);
    CScanResultsTab<NbMaxResults> Results;
    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&Results);
    m_Data.data.length = sizeof (Results);
    m_Data.data.flags = ListOid;
    Ioctl (m_PrivateIoctl,
           "Can't get scanning results through special ioctl.");

        /* Both the number of entries and each entry's length come from the
         * driver : stop as soon as they would make us read past the end of
         * the buffer, or loop on the same entry (0 length). An entry is only
         * read when a whole CScanResult fits in the remaining space, since
         * its accessors may read anywhere within it. */
    const char* const pFirst (reinterpret_cast<const char*> (Results.Get()));
    const unsigned long Capacity (sizeof (CScanResult) * NbMaxResults);
    unsigned long Offset (0);
    for (unsigned i (0) ; i < Results.Size() &&
                          Offset + sizeof (CScanResult) <= Capacity ; ++i)
    {
        const CScanResult* const pResult
                    (reinterpret_cast<const CScanResult*> (pFirst + Offset));
        unsigned Channel (0);
        try
        {
            Channel = GetMatchingFreq (0,
                                       pResult->GetFrequency()).GetChannel();
        }
        catch (const nsErrors::CException& Exc)
        {       // An unknown frequency only leaves the channel at 0.
            if (Exc.GetCode() != nsErrors::NoMatchingFreq) throw;
        }
        CellVec.push_back (CCell (pResult->GetAPMacAddress(),
                                  pResult->GetMode(), pResult->GetSSID(),
                                  pResult->GetEncryptionD(), Channel,
                                  pResult->GetQuality(),
                                  pResult->GetTxRate()));
            // Offset <= Capacity here, so the subtraction can't wrap.
        const unsigned long Length (pResult->GetLength());
        if (!Length || Length > Capacity - Offset) break;
        Offset += Length;
    }

} // GetScanResult()


nsWireless::CRT2400Driver::CRT2400Driver (const std::string& DeviceName)
                                                throw (nsErrors::CException)
    : CRTDriver (DeviceName, SIOCDEVPRIVATE)
{
        // Nothing else to set up : SIOCDEVPRIVATE is handed to CRTDriver as
        // the private ioctl (Scan() uses the same request number).

} // CRT2400Driver()


void nsWireless::CRT2400Driver::Scan () throw (nsErrors::CSystemExc)
{
        // No payload : only the Ralink request code (8002) is meaningful.
    m_Data.data.flags = 8002;
    m_Data.data.length = 0;
    m_Data.data.pointer = 0;
    Ioctl (SIOCDEVPRIVATE, "Can't trigger scan through special ioctl.");

} // Scan()


void nsWireless::CRT2500Driver::GetScanResult (std::vector<CCell>& CellVec)
                            const throw (nsErrors::CException, std::bad_alloc)
{
        /* Appends the scan results to CellVec. The standard (CWE17Driver)
         * results are used first; if that call fails, the Ralink list is
         * used alone. */
    const int ScanResultMagicNumber (0x116); // Another Ralink magic number.
        // Index of the first cell appended by this call.
    const std::vector<CCell>::size_type FirstNewCell (CellVec.size());
    try{CWE17Driver::GetScanResult (CellVec);}
    catch (...)
    {
        CRTDriver::GetScanResult (ScanResultMagicNumber, CellVec);
        return;
    }
    try
    {       // Overwrite the descriptor of cells the Ralink list reports as
            // WPAPSK or WPANONE.
        std::vector<CCell> RalinkCellVec;
        CRTDriver::GetScanResult (ScanResultMagicNumber, RalinkCellVec);
            // Both lists may differ in size : never index past either one.
        for (std::vector<CCell>::size_type i (0) ;
             i < RalinkCellVec.size() && FirstNewCell + i < CellVec.size() ;
             ++i)
            if (RalinkCellVec [i].GetEncryptionD().GetAuth() == WPAPSK ||
                RalinkCellVec [i].GetEncryptionD().GetAuth() == WPANONE)
                    // FIXME Are both results always in the same order ?
                CellVec [FirstNewCell + i].GetEncryptionD() =
                                            RalinkCellVec [i].GetEncryptionD();
    }
    catch (...) {} // We have some results anyway.

} // GetScanResult()


nsWireless::CEncryptionD nsWireless::CRT2500Driver::GetEncryption ()
                                throw (nsErrors::CSystemExc, std::bad_alloc)
{
        /* Returns the standard (CWE17Driver) descriptor. When encryption is
         * on, it is refined with the cipher (TKIP/AES) and the authentication
         * (WPAPSK/WPANONE) reported by the driver through Ralink OIDs. */
    CEncryptionD Descriptor (CWE17Driver::GetEncryption());
    if (Descriptor.GetEncrypt() != None)
    {       /* The driver writes each answer into the variable whose ADDRESS
             * is passed : passing the variable's value would hand it a bogus
             * (null) pointer. */
        DriverEncrypt_e Encrypt (::WEP);
        m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&Encrypt);
        m_Data.data.length = sizeof (Encrypt);
        m_Data.data.flags = 0x11A; // Magic...
        Ioctl (GetPrivateIoctl(), "Can't get cipher through special ioctl.");
        DriverAuth_e Auth (::Open);
        m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&Auth);
        m_Data.data.length = sizeof (Auth);
        m_Data.data.flags = 0x117; // Magic...
        Ioctl (GetPrivateIoctl(),
               "Can't get authentication type through special ioctl.");
            // Other values keep what the standard call reported.
        switch (Encrypt)
        {
          case ::TKIP : Descriptor.SetEncrypt (nsWireless::TKIP);
          break;

          case ::AES : Descriptor.SetEncrypt (nsWireless::AES);
          break;

          default : break;
        }
        switch (Auth)
        {
          case ::WPANone :
            Descriptor.SetAuth (nsWireless::WPANONE);
          break;

          case ::WPAPSK :
            Descriptor.SetAuth (nsWireless::WPAPSK);
          break;

          default : break;
        }
    }
    return Descriptor;

} // GetEncryption()


bool nsWireless::CRT2500Driver::GetRfmontxFromDriver () const
                                                throw (nsErrors::CSystemExc)
{
        /* Reads the injection (rfmontx) state : the driver writes it as a
           character into State. A zero length means this call must not
           change anything. Returns true when the driver answered '1'. */
    char State ('0');
    m_Data.data.flags = 0;
    m_Data.data.length = 0; // sizeof (State) => EINVAL
    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&State);
    Ioctl (m_RfmontxIoctl, "Can't get injection mode through private ioctl.");
    const bool IsEnabled (State == '1');
    return IsEnabled;

} // GetRfmontxFromDriver()


void nsWireless::CRT2500Driver::SetRfmontx (bool B)
                                                throw (nsErrors::CSystemExc)
{
        /* Enables (B true) or disables injection through the "rfmontx"
         * private ioctl, then records the new state in m_Flags. */
    char Buffer (B ? 1 : 0);
    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&Buffer);
    m_Data.data.length = sizeof (Buffer);
    m_Data.data.flags = 0;
    Ioctl (m_RfmontxIoctl, "Can't set injection mode.");
        // Keep the cached flag in sync with the requested state in both
        // directions (disabling must clear it).
    if (B)
        m_Flags |= Rfmontx;
    else
        m_Flags &= ~Rfmontx;

} // SetRfmontx()


void nsWireless::CRT2500Driver::SetEncryption (const CEncryptionD& Descriptor)
                                throw (nsErrors::CSystemExc, std::bad_alloc)
{
        /* Open/Shared go through the standard (CWE17Driver) call. Other
         * modes use the driver's "set" ioctl : AuthMode, EncrypType, then
         * the key. */
    if (Descriptor.GetAuth() == Open || Descriptor.GetAuth() == Shared)
        CWE17Driver::SetEncryption (Descriptor);
    else
    {       /* Keep our own copy of the name : a pointer from c_str() on
             * GetAuthName()'s result would dangle if it returns by value. */
        const std::string AuthName (GetAuthName (Descriptor.GetAuth()));
        SetIoctl ("AuthMode", AuthName, "Can't set authentication mode : ");
        SetIoctl ("EncrypType", GetEncryptName (Descriptor.GetEncrypt()),
                  "Can't set encryption type : ");
            /* The key is sent with the authentication name as command.
             * NOTE(review) : SetIoctl() appends the parameter to the error
             * message, so a failure here puts the key in the exception. */
        SetIoctl (AuthName.c_str(), std::string (Descriptor.GetKey().Get(), 0,
                                                 Descriptor.GetKey().Size()),
                  "Can't set key : ");
    }

} // SetEncryption()


nsWireless::CRT2500Driver::CRT2500Driver (const std::string& DeviceName)
                                                throw (nsErrors::CException)
    : CRTDriver (DeviceName, SIOCIWFIRSTPRIV + 1),
      m_SetIoctl (InvalidIoctl), m_RfmontxIoctl (InvalidIoctl), m_Flags (0),
      m_BGProtection (AutoProtection), m_TxPreamble (AutoPreamble)
{
        /* Looks up the "set" (mandatory) and "rfmontx" (optional) private
         * ioctls, then caches the driver's STA config and wireless mode. */
    const unsigned TabArgsSize (16);
    ::iw_priv_args TabArgs [TabArgsSize];
    for (unsigned i (GetPrivateIoctls (TabArgs, TabArgsSize)) ; i-- ; )
    {
        if (!strcmp (TabArgs [i].name, "set"))
            m_SetIoctl = TabArgs [i].cmd;
        else if (!strcmp (TabArgs [i].name, "rfmontx"))
            m_RfmontxIoctl = TabArgs [i].cmd;
    }

    if (m_SetIoctl == InvalidIoctl)
        throw nsErrors::CException ("Can't find \"set\" private ioctl.",
                                    nsErrors::RT2500SetIoctlNotFound);
        // Get STA config :
    SPrivateConfig PrivateConfig;
    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&PrivateConfig);
    m_Data.data.length = sizeof (PrivateConfig);
    m_Data.data.flags = 0x217; // Magic Ralink number : STAConfigOid.
    Ioctl (GetPrivateIoctl(), "Can't get STA config through special ioctl.");
    if (PrivateConfig.EnableTxBurst) m_Flags |= TxBurst;
    if (PrivateConfig.EnableTurboRate) m_Flags |= TurboRate;
    m_BGProtection = BGProtection_e (PrivateConfig.UseBGProtection);
    if (!PrivateConfig.UseShortSlotTime)
        m_TxPreamble = Long;
    else if (PrivateConfig.UseShortSlotTime == 1)
        m_TxPreamble = Short;
    // else it's AutoPreamble, the default value for m_TxPreamble.
        /* FIXME Set the preamble doesn't work, but it doesn't work with
         *       RaConfig too. Simply remove it? */
    if (PrivateConfig.UseOfdmRatesIn11gAdhoc) m_Flags |= AdHocOFDM;

        // Get wireless mode. Start from a defined value (not B only) so
        // nothing indeterminate is read if the driver writes nothing.
    DriverWirelessMode_e Mode (::BGMixed);
    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (&Mode);
    m_Data.data.length = sizeof (Mode);
    m_Data.data.flags = 0x212;// Magic Ralink number : PhyModeOid.
    Ioctl (GetPrivateIoctl(),
           "Cant get the wireless mode through special ioctl.");
    if (Mode == B) m_Flags |= IsB_Only;

    if (IsRfmontxSupported() && GetRfmontxFromDriver())
        m_Flags |= Rfmontx;

} // CRT2500Driver()


void nsWireless::CRT2500Driver::SetIoctl (const char* Command,
                                          const std::string& Parameter,
                                          const std::string& ErrorMsg) const
                                                throw (nsErrors::CSystemExc)
{
        /* Sends the '\0'-terminated text "Command=Parameter" through the
         * driver's "set" private ioctl. Command is cut at CstBufferSize
         * characters, and Parameter stops once the text reaches
         * CstBufferSize characters; the extra room keeps '=' and '\0'. */
    const unsigned CstBufferSize (256); // Should be enough.
    char Buffer [CstBufferSize + 2]; // '=' and '\0'.
    unsigned Size (0);

    for ( ; Size < CstBufferSize && *Command ; ++Command)
        Buffer [Size++] = *Command;
    Buffer [Size++] = '=';
    for (std::string::size_type Pos (0) ;
         Size < CstBufferSize && Pos < Parameter.size() ; ++Pos)
        Buffer [Size++] = Parameter [Pos];
    Buffer [Size++] = '\0';

    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (Buffer);
    m_Data.data.length = Size;
    m_Data.data.flags = 0;
    Ioctl (m_SetIoctl, ErrorMsg + Parameter);

} // SetIoctl()


void nsWireless::CRT2500Driver::SetIoctl (const char* Command, int Value,
                                          const std::string& ErrorMsg) const
                                                throw (nsErrors::CSystemExc)
{
        // Value is sent as its decimal text through the string overload.
    std::ostringstream Converter;
    Converter << Value;
    const std::string Text (Converter.str());
    SetIoctl (Command, Text, ErrorMsg);

} // SetIoctl()


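    // Presumably the RT2570 values of the Ralink OIDs named in the
    // CRT2500Driver constructor (not used in this file) :
    // STAConfigOid : 0x50D     PhyModeOid : 0x50C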
nsWireless::CRT2570Driver::CRT2570Driver (const std::string& DeviceName)
                                                throw (nsErrors::CException)
    : CWE17Driver (DeviceName), m_AuthIoctl (InvalidIoctl),
      m_EncIoctl (InvalidIoctl), m_KeyIoctl (InvalidIoctl),
      m_AdHocModeIoctl (InvalidIoctl), m_RfmontxIoctl (InvalidIoctl),
      m_PrismHeaderIoctl (InvalidIoctl)
{
        /* Looks up the driver's private ioctls by name. "rfmontx" and
         * "forceprismheader" are optional : their codes stay InvalidIoctl
         * when absent (hence the explicit initialization above). The other
         * four are required. */
    const unsigned TabArgsSize (16);
    ::iw_priv_args TabArgs [TabArgsSize];

    for (unsigned i (GetPrivateIoctls (TabArgs, TabArgsSize)) ; i-- ; )
    {
        if (!strcmp (TabArgs [i].name, "auth"))
            m_AuthIoctl = TabArgs [i].cmd;
        else if (!strcmp (TabArgs [i].name, "enc"))
            m_EncIoctl = TabArgs [i].cmd;
        else if (!strcmp (TabArgs [i].name, "wpapsk"))
            m_KeyIoctl = TabArgs [i].cmd;
        else if (!strcmp (TabArgs [i].name, "adhocmode"))
            m_AdHocModeIoctl = TabArgs [i].cmd;
        else if (!strcmp (TabArgs [i].name, "rfmontx"))
            m_RfmontxIoctl = TabArgs [i].cmd;
        else if (!strcmp (TabArgs [i].name, "forceprismheader"))
            m_PrismHeaderIoctl = TabArgs [i].cmd;
    }

    if (m_AdHocModeIoctl == InvalidIoctl)
        throw nsErrors::CException ("Can't find \"adhocmode\" private ioctl.",
                                    nsErrors::RT2570AdHocModeIoctlNotFound);
    if (m_EncIoctl == InvalidIoctl)
        throw nsErrors::CException ("Can't find \"enc\" private ioctl.",
                                    nsErrors::RT2570EncIoctlNotFound);
    if (m_AuthIoctl == InvalidIoctl)
        throw nsErrors::CException ("Can't find \"auth\" private ioctl.",
                                    nsErrors::RT2570AuthIoctlNotFound);
        // The key is set through the ioctl named "wpapsk".
    if (m_KeyIoctl == InvalidIoctl)
        throw nsErrors::CException ("Can't find \"wpapsk\" private ioctl.",
                                    nsErrors::RT2570KeyIoctlNotFound);

} // CRT2570Driver()


void nsWireless::CRT2570Driver::SetEncryption (const CEncryptionD& Descriptor)
                                throw (nsErrors::CSystemExc, std::bad_alloc)
{
        /* Anything other than Open/Shared goes through the driver's private
         * "auth", "enc" and "wpapsk" ioctls; Open/Shared use the standard
         * (CWE17Driver) call. */
    if (Descriptor.GetAuth() != Open && Descriptor.GetAuth() != Shared)
    {
        /* TODO Check if there's an order that works better, this is the
                same order as rt2500 currently. */
            // WPAPSK : 2       WPANONE : 4
        // FIXME Check if WPAPSK == 2 and WPANONE == 4.
        const int AuthCode (Descriptor.GetAuth() == WPAPSK ? 2 : 4);
        PrivateIoctl (m_AuthIoctl, AuthCode, "Can't set wpa authentication.");
            // The cipher is always set to code 3.
        PrivateIoctl (m_EncIoctl, 3, "Can't set tkip encryption.");
        PrivateIoctl (m_KeyIoctl, Descriptor.GetKey().GetStr(),
                      "Can't set wpa key.");
    }
    else
        CWE17Driver::SetEncryption (Descriptor);

} // SetEncryption()


void nsWireless::CRT2570Driver::PrivateIoctl (int IoctlCode,
                        const std::string& Value, const std::string& ErrorMsg)
                                                throw (nsErrors::CSystemExc)
{
        /* Sends Value's characters through the IoctlCode private ioctl. The
         * length excludes the terminating '\0'. The const_cast assumes the
         * driver only reads the buffer (TODO confirm). */
    char* const pText (const_cast<char*> (Value.c_str()));
    m_Data.data.flags = 0;
    m_Data.data.length = Value.size();
    m_Data.data.pointer = reinterpret_cast< ::caddr_t> (pText);
    Ioctl (IoctlCode, ErrorMsg);

} // PrivateIoctl()


void nsWireless::CRT2570Driver::GetSupportedRates (std::vector<int>& RatesVec)
                                                                const throw()
{
        /* Rates reported by the standard call; if there are none, a fixed
         * list is used (54000 = 54 Mb/s, fastest first). */
    CWE17Driver::GetSupportedRates (RatesVec);
    if (!RatesVec.empty()) return;

        // The driver doesn't report values but support rate changing :
    static const int DefaultRates [] = {54000, 48000, 36000, 24000, 18000,
                                        11000, 2000, 1000};
    RatesVec.assign (DefaultRates, DefaultRates +
                            sizeof (DefaultRates) / sizeof (DefaultRates [0]));

} // GetSupportedRates()
