# -*- coding: utf-8 -*-

# Author: Natalia B. Bidart <natalia.bidart@canonical.com>
#
# Copyright 2011 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Tests for the oauth_headers helper function."""

from twisted.trial.unittest import TestCase

from ubuntu_sso.utils import oauth, oauth_headers
from ubuntu_sso.tests import TOKEN


class FakedOAuthRequest(object):
    """Test double standing in for oauth.OAuthRequest.

    Every keyword argument handed to from_consumer_and_token is merged into
    the class-level 'params' dict so tests can inspect it afterwards. The
    returned instance signs nothing and produces an empty header dict.
    """

    # Shared by all instances and all calls; tests read it directly.
    params = {}

    def sign_request(self, *args, **kwargs):
        """Accept any signing arguments and do nothing."""
        return None

    def to_header(self, *args, **kwargs):
        """Return an empty header mapping."""
        return {}

    @staticmethod
    def from_consumer_and_token(oauth_consumer, **kwargs):
        """Merge 'kwargs' into the shared 'params' and return a new fake."""
        recorded = FakedOAuthRequest.params
        for name, value in kwargs.items():
            recorded[name] = value
        return FakedOAuthRequest()


class SignWithCredentialsTestCase(TestCase):
    """Test suite for the oauth_headers helper function."""

    url = u'http://example.com'

    def build_header(self, url, http_method='GET'):
        """Build a reference OAuth header for comparison.

        Signs a request for 'url' and 'http_method' with HMAC-SHA1, using the
        consumer and token credentials from TOKEN. Returns the request's
        to_header() result.
        """
        consumer = oauth.OAuthConsumer(TOKEN['consumer_key'],
                                       TOKEN['consumer_secret'])
        token = oauth.OAuthToken(TOKEN['token'],
                                 TOKEN['token_secret'])
        get_request = oauth.OAuthRequest.from_consumer_and_token
        oauth_req = get_request(oauth_consumer=consumer, token=token,
                                http_method=http_method, http_url=url)
        oauth_req.sign_request(oauth.OAuthSignatureMethod_HMAC_SHA1(),
                               consumer, token)
        return oauth_req.to_header()

    def dictify_header(self, header):
        """Convert an OAuth header into a dict.

        'header' is split on ', ' into 'key=value' fields. Each value has its
        surrounding double quotes stripped.
        """
        result = {}
        for field in header.split(', '):
            # Split on the first '=' only, so a value that itself contains
            # '=' does not break the two-name unpacking.
            key, value = field.split('=', 1)
            result[key] = value.strip('"')

        return result

    def assert_header_equal(self, expected, actual):
        """Assert 'expected' and 'actual' carry the same Authorization data.

        The nonce, timestamp and signature fields are removed from both
        sides before comparing.
        """
        expected = self.dictify_header(expected['Authorization'])
        actual = self.dictify_header(actual['Authorization'])
        for header in (expected, actual):
            header.pop('oauth_nonce')
            header.pop('oauth_timestamp')
            header.pop('oauth_signature')

        self.assertEqual(expected, actual)

    def assert_method_called(self, path, query_str='', http_method='GET'):
        """Assert oauth_headers signs the url built from 'path'.

        The url is self.url + 'path' + 'query_str'. A reference header is
        built for its utf-8 encoded form and compared with the header that
        oauth_headers returns for the same url.
        """
        full_url = ''.join((self.url, path, query_str))
        expected_header = self.build_header(full_url.encode('utf8'),
                                            http_method=http_method)
        # NOTE(review): 'http_method' only affects the reference header;
        # oauth_headers is called without it, so a non-GET value presumably
        # needs oauth_headers support first. Confirm before using one.
        actual_header = oauth_headers(url=full_url, credentials=TOKEN)
        self.assert_header_equal(expected_header, actual_header)

    def test_call(self):
        """oauth_headers matches a reference signed GET request."""
        path = u'/test/'
        self.assert_method_called(path)

    def test_quotes_path(self):
        """A path with spaces, a comma and non-ASCII chars is handled."""
        path = u'/test me more, sí!/'
        self.assert_method_called(path)

    def test_adds_parameters_to_oauth_request(self):
        """The query string from the path is used in the oauth request."""
        self.patch(oauth, 'OAuthRequest', FakedOAuthRequest)
        # FakedOAuthRequest.params is a class-level dict that accumulates
        # across calls. Start from an empty one so calls made elsewhere
        # cannot satisfy the assertions below. patch() reverts it at the end
        # of the test.
        self.patch(FakedOAuthRequest, 'params', {})

        path = u'/test/something?foo=bar'
        oauth_headers(url=self.url + path, credentials=TOKEN)

        self.assertIn('parameters', FakedOAuthRequest.params)
        self.assertEqual(FakedOAuthRequest.params['parameters'],
                         {'foo': 'bar'})
