# Copyright 2014 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for oauth2_client and related classes.""" from __future__ import absolute_import import datetime import httplib2 import logging import mox import os import stat import sys import unittest from freezegun import freeze_time from gcs_oauth2_boto_plugin import oauth2_client LOG = logging.getLogger('test_oauth2_client') ACCESS_TOKEN = 'abc123' TOKEN_URI = 'https://provider.example.com/oauth/provider?mode=token' AUTH_URI = 'https://provider.example.com/oauth/provider?mode=authorize' DEFAULT_CA_CERTS_FILE = os.path.abspath( os.path.join('gslib', 'data', 'cacerts.txt')) IS_WINDOWS = 'win32' in str(sys.platform).lower() class MockDateTime(object): def __init__(self): self.mock_now = None def utcnow(self): # pylint: disable=invalid-name return self.mock_now class MockOAuth2ServiceAccountClient(oauth2_client.OAuth2ServiceAccountClient): """Mock service account client for testing OAuth2 with service accounts.""" def __init__(self, client_id, private_key, password, auth_uri, token_uri, datetime_strategy): super(MockOAuth2ServiceAccountClient, self).__init__( client_id, private_key, password, auth_uri=auth_uri, token_uri=token_uri, datetime_strategy=datetime_strategy, ca_certs_file=DEFAULT_CA_CERTS_FILE) self.Reset() def Reset(self): self.fetched_token = False def FetchAccessToken(self): self.fetched_token = True return oauth2_client.AccessToken( ACCESS_TOKEN, GetExpiry(self.datetime_strategy, 3600), datetime_strategy=self.datetime_strategy) class MockOAuth2UserAccountClient(oauth2_client.OAuth2UserAccountClient): """Mock user account client for testing OAuth2 with user accounts.""" def __init__(self, token_uri, client_id, client_secret, refresh_token, auth_uri, datetime_strategy): super(MockOAuth2UserAccountClient, self).__init__( token_uri, client_id, client_secret, refresh_token, auth_uri=auth_uri, datetime_strategy=datetime_strategy, ca_certs_file=DEFAULT_CA_CERTS_FILE) self.Reset() def Reset(self): self.fetched_token = False def FetchAccessToken(self): self.fetched_token = True return oauth2_client.AccessToken( ACCESS_TOKEN, GetExpiry(self.datetime_strategy, 3600), datetime_strategy=self.datetime_strategy) def GetExpiry(datetime_strategy, length_in_seconds): token_expiry = (datetime_strategy.utcnow() + datetime.timedelta(seconds=length_in_seconds)) return token_expiry def CreateMockUserAccountClient(mock_datetime): return MockOAuth2UserAccountClient( TOKEN_URI, 'clid', 'clsecret', 'ref_token_abc123', AUTH_URI, mock_datetime) def CreateMockServiceAccountClient(mock_datetime): return MockOAuth2ServiceAccountClient( 'clid', 'private_key', 'password', AUTH_URI, TOKEN_URI, mock_datetime) class OAuth2AccountClientTest(unittest.TestCase): """Unit tests for OAuth2UserAccountClient and OAuth2ServiceAccountClient.""" def setUp(self): self.tempdirs = [] self.mock_datetime = MockDateTime() self.start_time = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) self.mock_datetime.mock_now = self.start_time def testGetAccessTokenUserAccount(self): self.client = CreateMockUserAccountClient(self.mock_datetime) self._RunGetAccessTokenTest() def testGetAccessTokenServiceAccount(self): self.client = CreateMockServiceAccountClient(self.mock_datetime) self._RunGetAccessTokenTest() def _RunGetAccessTokenTest(self): """Tests access token gets with self.client.""" access_token_1 = 'abc123' self.assertFalse(self.client.fetched_token) token_1 = self.client.GetAccessToken() # There's no access token in the cache; verify that we fetched a fresh # token. self.assertTrue(self.client.fetched_token) self.assertEquals(access_token_1, token_1.token) self.assertEquals(self.start_time + datetime.timedelta(minutes=60), token_1.expiry) # Advance time by less than expiry time, and fetch another token. self.client.Reset() self.mock_datetime.mock_now = ( self.start_time + datetime.timedelta(minutes=55)) token_2 = self.client.GetAccessToken() # Since the access token wasn't expired, we get the cache token, and there # was no refresh request. self.assertEquals(token_1, token_2) self.assertEquals(access_token_1, token_2.token) self.assertFalse(self.client.fetched_token) # Advance time past expiry time, and fetch another token. self.client.Reset() self.mock_datetime.mock_now = ( self.start_time + datetime.timedelta(minutes=55, seconds=1)) self.client.datetime_strategy = self.mock_datetime token_3 = self.client.GetAccessToken() # This should have resulted in a refresh request and a fresh access token. self.assertTrue(self.client.fetched_token) self.assertEquals( self.mock_datetime.mock_now + datetime.timedelta(minutes=60), token_3.expiry) class AccessTokenTest(unittest.TestCase): """Unit tests for access token functions.""" def testShouldRefresh(self): """Tests that token.ShouldRefresh returns the correct value.""" mock_datetime = MockDateTime() start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) expiry = start + datetime.timedelta(minutes=60) token = oauth2_client.AccessToken( 'foo', expiry, datetime_strategy=mock_datetime) mock_datetime.mock_now = start self.assertFalse(token.ShouldRefresh()) mock_datetime.mock_now = start + datetime.timedelta(minutes=54) self.assertFalse(token.ShouldRefresh()) mock_datetime.mock_now = start + datetime.timedelta(minutes=55) self.assertFalse(token.ShouldRefresh()) mock_datetime.mock_now = start + datetime.timedelta( minutes=55, seconds=1) self.assertTrue(token.ShouldRefresh()) mock_datetime.mock_now = start + datetime.timedelta( minutes=61) self.assertTrue(token.ShouldRefresh()) mock_datetime.mock_now = start + datetime.timedelta(minutes=58) self.assertFalse(token.ShouldRefresh(time_delta=120)) mock_datetime.mock_now = start + datetime.timedelta( minutes=58, seconds=1) self.assertTrue(token.ShouldRefresh(time_delta=120)) def testShouldRefreshNoExpiry(self): """Tests token.ShouldRefresh with no expiry time.""" mock_datetime = MockDateTime() start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) token = oauth2_client.AccessToken( 'foo', None, datetime_strategy=mock_datetime) mock_datetime.mock_now = start self.assertFalse(token.ShouldRefresh()) mock_datetime.mock_now = start + datetime.timedelta( minutes=472) self.assertFalse(token.ShouldRefresh()) def testSerialization(self): """Tests token serialization.""" expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) token = oauth2_client.AccessToken('foo', expiry) serialized_token = token.Serialize() LOG.debug('testSerialization: serialized_token=%s', serialized_token) token2 = oauth2_client.AccessToken.UnSerialize(serialized_token) self.assertEquals(token, token2) class FileSystemTokenCacheTest(unittest.TestCase): """Unit tests for FileSystemTokenCache.""" def setUp(self): self.cache = oauth2_client.FileSystemTokenCache() self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) self.token_1 = oauth2_client.AccessToken('token1', self.start_time) self.token_2 = oauth2_client.AccessToken( 'token2', self.start_time + datetime.timedelta(seconds=492)) self.key = 'token1key' def tearDown(self): try: os.unlink(self.cache.CacheFileName(self.key)) except: # pylint: disable=bare-except pass def testPut(self): self.cache.PutToken(self.key, self.token_1) # Assert that the cache file exists and has correct permissions. if not IS_WINDOWS: self.assertEquals( 0600, stat.S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode)) def testPutGet(self): """Tests putting and getting various tokens.""" # No cache file present. self.assertEquals(None, self.cache.GetToken(self.key)) # Put a token self.cache.PutToken(self.key, self.token_1) cached_token = self.cache.GetToken(self.key) self.assertEquals(self.token_1, cached_token) # Put a different token self.cache.PutToken(self.key, self.token_2) cached_token = self.cache.GetToken(self.key) self.assertEquals(self.token_2, cached_token) def testGetBadFile(self): f = open(self.cache.CacheFileName(self.key), 'w') f.write('blah') f.close() self.assertEquals(None, self.cache.GetToken(self.key)) def testCacheFileName(self): """Tests configuring the cache with a specific file name.""" cache = oauth2_client.FileSystemTokenCache( path_pattern='/var/run/ccache/token.%(uid)s.%(key)s') if IS_WINDOWS: uid = '_' else: uid = os.getuid() self.assertEquals('/var/run/ccache/token.%s.abc123' % uid, cache.CacheFileName('abc123')) cache = oauth2_client.FileSystemTokenCache( path_pattern='/var/run/ccache/token.%(key)s') self.assertEquals('/var/run/ccache/token.abc123', cache.CacheFileName('abc123')) class RefreshTokenTest(unittest.TestCase): """Unit tests for refresh tokens.""" def setUp(self): self.mock_datetime = MockDateTime() self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) self.mock_datetime.mock_now = self.start_time self.client = CreateMockUserAccountClient(self.mock_datetime) def testUniqeId(self): cred_id = self.client.CacheKey() self.assertEquals('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id) def testGetAuthorizationHeader(self): self.assertEquals('Bearer %s' % ACCESS_TOKEN, self.client.GetAuthorizationHeader()) class FakeResponse: def __init__(self, status): self._status = status @property def status(self): return self._status class OAuth2GCEClientTest(unittest.TestCase): """Unit tests for OAuth2GCEClient.""" def setUp(self): self.mox = mox.Mox() self.mox.StubOutClassWithMocks(httplib2, 'Http') self.mock_http = httplib2.Http() def tearDown(self): self.mox.UnsetStubs() @freeze_time('2014-03-26 01:01:01') def testFetchAccessToken(self): token = 'my_token' self.mock_http.request( oauth2_client.META_TOKEN_URI, method='GET', body=None, headers=oauth2_client.META_HEADERS).AndReturn(( FakeResponse(200), '{"access_token":"%(TOKEN)s",' '"expires_in": %(EXPIRES_IN)d}' % { 'TOKEN': token, 'EXPIRES_IN': 42 })) self.mox.ReplayAll() client = oauth2_client.OAuth2GCEClient() self.assertEqual( str(client.FetchAccessToken()), 'AccessToken(token=%s, expiry=2014-03-26 01:01:43Z)' % token) self.mox.VerifyAll() def testIsGCENotFound(self): self.mock_http.request(oauth2_client.METADATA_SERVER).AndReturn(( FakeResponse(404), '')) self.mox.ReplayAll() self.assertFalse(oauth2_client._IsGCE()) self.mox.VerifyAll() def testIsGCEServerNotFound(self): self.mock_http.request(oauth2_client.METADATA_SERVER).AndRaise( httplib2.ServerNotFoundError) self.mox.ReplayAll() self.assertFalse(oauth2_client._IsGCE()) self.mox.VerifyAll() def testIsGCETrue(self): self.mock_http.request(oauth2_client.METADATA_SERVER).AndReturn(( FakeResponse(200), '')) self.mox.ReplayAll() self.assertTrue(oauth2_client._IsGCE()) self.mox.VerifyAll() if __name__ == '__main__': unittest.main()