# Copyright 2010 Google Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

"""
Provides basic mocks of core storage service classes, for unit testing:
ACL, Key, Bucket, Connection, and StorageUri. We implement a subset of
the interfaces defined in the real boto classes, but don't handle most
of the optional params (which we indicate with the constant "NOT_IMPL").
"""

import copy
import boto
import base64
import re
from hashlib import md5

from boto.utils import compute_md5
from boto.utils import find_matching_headers
from boto.utils import merge_headers_by_name
from boto.s3.prefix import Prefix
from boto.compat import six

# Default value for parameters that are accepted for interface
# compatibility but ignored by the mocks.
NOT_IMPL = None


class MockAcl(object):
    """Placeholder ACL object; it stores nothing and parses nothing."""

    def __init__(self, parent=NOT_IMPL):
        pass

    def startElement(self, name, attrs, connection):
        pass

    def endElement(self, name, value, connection):
        pass

    def to_xml(self):
        return ''


class MockKey(object):
    """In-memory mock of a storage key; contents live in ``self.data``."""

    def __init__(self, bucket=None, name=None):
        self.bucket = bucket
        self.name = name
        self.data = None
        self.etag = None
        self.size = None
        self.closed = True
        self.content_encoding = None
        self.content_language = None
        self.content_type = None
        self.last_modified = 'Wed, 06 Oct 2010 05:11:54 GMT'
        self.BufferSize = 8192

    def __repr__(self):
        # The format strings must contain one placeholder per argument;
        # formatting into an empty string raises TypeError.
        if self.bucket:
            return '<MockKey: %s,%s>' % (self.bucket.name, self.name)
        else:
            return '<MockKey: %s>' % self.name

    def get_contents_as_string(self, headers=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL,
                               torrent=NOT_IMPL,
                               version_id=NOT_IMPL):
        """Return the stored contents unchanged."""
        return self.data

    def get_contents_to_file(self, fp, headers=NOT_IMPL,
                             cb=NOT_IMPL, num_cb=NOT_IMPL,
                             torrent=NOT_IMPL,
                             version_id=NOT_IMPL,
                             res_download_handler=NOT_IMPL):
        """Write the stored contents to ``fp``."""
        fp.write(self.data)

    def get_file(self, fp, headers=NOT_IMPL, cb=NOT_IMPL, num_cb=NOT_IMPL,
                 torrent=NOT_IMPL, version_id=NOT_IMPL,
                 override_num_retries=NOT_IMPL):
        """Write the stored contents to ``fp``."""
        fp.write(self.data)

    def _handle_headers(self, headers):
        """Record Content-Encoding/-Type/-Language from ``headers``."""
        if not headers:
            return
        if find_matching_headers('Content-Encoding', headers):
            self.content_encoding = merge_headers_by_name(
                'Content-Encoding', headers)
        if find_matching_headers('Content-Type', headers):
            self.content_type = merge_headers_by_name(
                'Content-Type', headers)
        if find_matching_headers('Content-Language', headers):
            self.content_language = merge_headers_by_name(
                'Content-Language', headers)

    # Simplistic partial implementation for headers: Just supports range GETs
    # of flavor 'Range: bytes=xyz-'.
    def open_read(self, headers=None, query_args=NOT_IMPL,
                  override_num_retries=NOT_IMPL):
        """Open the key for reading, honoring an open-ended Range header."""
        if self.closed:
            # Only a fresh open rewinds; re-opening an already open key
            # keeps the current read position.
            self.read_pos = 0
        self.closed = False
        if headers and 'Range' in headers:
            match = re.match('bytes=([0-9]+)-$', headers['Range'])
            if match:
                self.read_pos = int(match.group(1))

    def close(self, fast=NOT_IMPL):
        self.closed = True

    def read(self, size=0):
        """Return up to ``size`` items from the read position.

        A ``size`` of 0 returns everything that remains. The key is closed
        once a read returns no data.
        """
        self.open_read()
        if size == 0:
            data = self.data[self.read_pos:]
            self.read_pos = self.size
        else:
            data = self.data[self.read_pos:self.read_pos + size]
            self.read_pos += size
        if not data:
            self.close()
        return data

    def set_contents_from_file(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL,
                               policy=NOT_IMPL, md5=NOT_IMPL,
                               res_upload_handler=NOT_IMPL):
        """Replace the contents with everything read from ``fp``."""
        self.data = fp.read()
        self.set_etag()
        self.size = len(self.data)
        self._handle_headers(headers)

    def set_contents_from_stream(self, fp, headers=None, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL,
                                 query_args=NOT_IMPL, size=NOT_IMPL):
        """Replace the contents by reading ``fp`` in BufferSize chunks.

        Chunks are collected and joined once, using an empty value of the
        same type as the first chunk, so both text and bytes streams work.
        An empty stream yields ``''``.
        """
        chunks = []
        chunk = fp.read(self.BufferSize)
        while chunk:
            chunks.append(chunk)
            chunk = fp.read(self.BufferSize)
        if chunks:
            # chunks[0][:0] is '' for str chunks and b'' for bytes chunks.
            self.data = chunks[0][:0].join(chunks)
        else:
            self.data = ''
        self.set_etag()
        self.size = len(self.data)
        self._handle_headers(headers)

    def set_contents_from_string(self, s, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL, md5=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL):
        """Replace the contents with a copy of ``s``."""
        self.data = copy.copy(s)
        self.set_etag()
        self.size = len(s)
        self._handle_headers(headers)

    def set_contents_from_filename(self, filename, headers=None,
                                   replace=NOT_IMPL, cb=NOT_IMPL,
                                   num_cb=NOT_IMPL, policy=NOT_IMPL,
                                   md5=NOT_IMPL,
                                   res_upload_handler=NOT_IMPL):
        """Replace the contents with the bytes of the file ``filename``."""
        # The context manager closes the file even if reading fails.
        with open(filename, 'rb') as fp:
            self.set_contents_from_file(fp, headers, replace, cb, num_cb,
                                        policy, md5, res_upload_handler)

    def copy(self, dst_bucket_name, dst_key, metadata=NOT_IMPL,
             reduced_redundancy=NOT_IMPL, preserve_acl=NOT_IMPL):
        """Copy this key to ``dst_key`` in the named bucket.

        Returns the new key produced by the destination bucket's
        ``copy_key``.
        """
        dst_bucket = self.bucket.connection.get_bucket(dst_bucket_name)
        return dst_bucket.copy_key(dst_key, self.bucket.name,
                                   self.name, metadata)

    @property
    def provider(self):
        """The provider of this key's connection, or None if detached."""
        provider = None
        if self.bucket and self.bucket.connection:
            provider = self.bucket.connection.provider
        return provider

    def set_etag(self):
        """
        Set etag attribute by generating hex MD5 checksum on current
        contents of mock key.
        """
        m = md5()
        if not isinstance(self.data, bytes):
            m.update(self.data.encode('utf-8'))
        else:
            m.update(self.data)
        hex_md5 = m.hexdigest()
        self.etag = hex_md5

    def compute_md5(self, fp):
        """
        :type fp: file
        :param fp: File pointer to the file to MD5 hash.  The file pointer
                   will be reset to the beginning of the file before the
                   method returns.

        :rtype: tuple
        :return: A tuple containing the hex digest version of the MD5 hash
                 as the first element and the base64 encoded version of the
                 plain digest as the second element.
        """
        # Inside the method body this name resolves to the module-level
        # boto.utils.compute_md5, not to this method.
        tup = compute_md5(fp)
        # Returned values are MD5 hash, base64 encoded MD5 hash, and file size.
        # The internal implementation of compute_md5() needs to return the
        # file size but we don't want to return that value to the external
        # caller because it changes the class interface (i.e. it might
        # break some code) so we consume the third tuple value here and
        # return the remainder of the tuple to the caller, thereby preserving
        # the existing interface.
        self.size = tup[2]
        return tup[0:2]


class MockBucket(object):
    """In-memory mock of a bucket: a dict of keys plus per-name ACLs."""

    def __init__(self, connection=None, name=None, key_class=NOT_IMPL):
        self.name = name
        self.keys = {}
        # The bucket's own ACL is stored under the bucket name.
        self.acls = {name: MockAcl()}
        # default object ACLs are one per bucket and not supported for keys
        self.def_acl = MockAcl()
        self.subresources = {}
        self.connection = connection
        self.logging = False

    def __repr__(self):
        return 'MockBucket: %s' % self.name

    def copy_key(self, new_key_name, src_bucket_name,
                 src_key_name, metadata=NOT_IMPL, src_version_id=NOT_IMPL,
                 storage_class=NOT_IMPL, preserve_acl=NOT_IMPL,
                 encrypt_key=NOT_IMPL, headers=NOT_IMPL, query_args=NOT_IMPL):
        """Copy a key's data into a new key in this bucket and return it.

        The source is looked up before the destination key is created, so
        copying a key onto itself does not replace the source with an
        empty key first.
        """
        # NOTE: a missing source key is not handled here; get_key returns
        # None for it and the attribute access below raises AttributeError.
        src_key = self.connection.get_bucket(
            src_bucket_name).get_key(src_key_name)
        src_data = src_key.data
        new_key = self.new_key(key_name=new_key_name)
        new_key.data = copy.copy(src_data)
        new_key.size = len(new_key.data)
        return new_key

    def disable_logging(self):
        self.logging = False

    def enable_logging(self, target_bucket_prefix):
        self.logging = True

    def get_logging_config(self):
        return {"Logging": {}}

    def get_versioning_status(self, headers=NOT_IMPL):
        return False

    def get_acl(self, key_name='', headers=NOT_IMPL, version_id=NOT_IMPL):
        """Return the ACL of ``key_name``, or of the bucket if it is empty."""
        if key_name:
            # Return ACL for the key.
            return self.acls[key_name]
        else:
            # Return ACL for the bucket.
            return self.acls[self.name]

    def get_def_acl(self, key_name=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        # Return default ACL for the bucket.
        return self.def_acl

    def get_subresource(self, subresource, key_name=NOT_IMPL, headers=NOT_IMPL,
                        version_id=NOT_IMPL):
        """Return the stored subresource value, or '' if none was set."""
        if subresource in self.subresources:
            return self.subresources[subresource]
        else:
            return ''

    def new_key(self, key_name=None):
        """Create, register and return an empty key with a fresh ACL.

        An existing key of the same name is replaced.
        """
        mock_key = MockKey(self, key_name)
        self.keys[key_name] = mock_key
        self.acls[key_name] = MockAcl()
        return mock_key

    def delete_key(self, key_name, headers=NOT_IMPL,
                   version_id=NOT_IMPL, mfa_token=NOT_IMPL):
        """Remove ``key_name``; raise StorageResponseError 404 if absent."""
        if key_name not in self.keys:
            raise boto.exception.StorageResponseError(404, 'Not Found')
        del self.keys[key_name]

    def get_all_keys(self, headers=NOT_IMPL):
        return six.itervalues(self.keys)

    def get_key(self, key_name, headers=NOT_IMPL, version_id=NOT_IMPL):
        # Emulate behavior of boto when get_key called with non-existent key.
        if key_name not in self.keys:
            return None
        return self.keys[key_name]

    def list(self, prefix='', delimiter='', marker=NOT_IMPL,
             headers=NOT_IMPL):
        """Return keys and prefixes whose names start with ``prefix``.

        With a ``delimiter``, names containing it past the prefix are
        collapsed into a single Prefix entry per distinct leading segment.
        Matching keys are returned as new MockKey objects carrying only
        the bucket and name.
        """
        prefix = prefix or ''  # Turn None into '' for prefix match.
        # Return list instead of using a generator so we don't get
        # 'dictionary changed size during iteration' error when performing
        # deletions while iterating (e.g., during test cleanup).
        result = []
        key_name_set = set()
        for k in six.itervalues(self.keys):
            if k.name.startswith(prefix):
                k_name_past_prefix = k.name[len(prefix):]
                if delimiter:
                    pos = k_name_past_prefix.find(delimiter)
                else:
                    pos = -1
                if pos != -1:
                    key_or_prefix = Prefix(
                        bucket=self, name=k.name[:len(prefix) + pos + 1])
                else:
                    key_or_prefix = MockKey(bucket=self, name=k.name)
                # Several keys can collapse to the same prefix; emit each
                # name once.
                if key_or_prefix.name not in key_name_set:
                    key_name_set.add(key_or_prefix.name)
                    result.append(key_or_prefix)
        return result

    def set_acl(self, acl_or_str, key_name='', headers=NOT_IMPL,
                version_id=NOT_IMPL):
        # We only handle setting ACL XML here; if you pass a canned ACL
        # the get_acl call will just return that string name.
        if key_name:
            # Set ACL for the key.
            self.acls[key_name] = MockAcl(acl_or_str)
        else:
            # Set ACL for the bucket.
            self.acls[self.name] = MockAcl(acl_or_str)

    def set_def_acl(self, acl_or_str, key_name=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        # We only handle setting ACL XML here; if you pass a canned ACL
        # the get_acl call will just return that string name.
        # Set default ACL for the bucket.
        self.def_acl = acl_or_str

    def set_subresource(self, subresource, value, key_name=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        self.subresources[subresource] = value


class MockProvider(object):
    """Holds a provider name and returns it on request."""

    def __init__(self, provider):
        self.provider = provider

    def get_provider_name(self):
        return self.provider


class MockConnection(object):
    """In-memory mock of a storage connection: a dict of buckets."""

    def __init__(self, aws_access_key_id=NOT_IMPL,
                 aws_secret_access_key=NOT_IMPL, is_secure=NOT_IMPL,
                 port=NOT_IMPL, proxy=NOT_IMPL, proxy_port=NOT_IMPL,
                 proxy_user=NOT_IMPL, proxy_pass=NOT_IMPL,
                 host=NOT_IMPL, debug=NOT_IMPL,
                 https_connection_factory=NOT_IMPL,
                 calling_format=NOT_IMPL,
                 path=NOT_IMPL, provider='s3',
                 bucket_class=NOT_IMPL):
        self.buckets = {}
        self.provider = MockProvider(provider)

    def create_bucket(self, bucket_name, headers=NOT_IMPL, location=NOT_IMPL,
                      policy=NOT_IMPL, storage_class=NOT_IMPL):
        """Create and return a bucket; raise StorageCreateError if it exists."""
        if bucket_name in self.buckets:
            raise boto.exception.StorageCreateError(
                409, 'BucketAlreadyOwnedByYou',
                "Your previous request to create the named bucket "
                "succeeded and you already own it.")
        mock_bucket = MockBucket(name=bucket_name, connection=self)
        self.buckets[bucket_name] = mock_bucket
        return mock_bucket

    def delete_bucket(self, bucket, headers=NOT_IMPL):
        """Remove the bucket named ``bucket``; raise 404 if it is absent."""
        if bucket not in self.buckets:
            raise boto.exception.StorageResponseError(
                404, 'NoSuchBucket', 'no such bucket')
        del self.buckets[bucket]

    def get_bucket(self, bucket_name, validate=NOT_IMPL, headers=NOT_IMPL):
        """Return the named bucket; raise StorageResponseError 404 if absent."""
        if bucket_name not in self.buckets:
            raise boto.exception.StorageResponseError(404, 'NoSuchBucket',
                                                      'Not Found')
        return self.buckets[bucket_name]

    def get_all_buckets(self, headers=NOT_IMPL):
        return six.itervalues(self.buckets)


# We only mock a single provider/connection.
mock_connection = MockConnection()


class MockBucketStorageUri(object):
    """Mock storage URI backed by the shared ``mock_connection``."""

    delim = '/'

    def __init__(self, scheme, bucket_name=None, object_name=None,
                 debug=NOT_IMPL, suppress_consec_slashes=NOT_IMPL,
                 version_id=None, generation=None, is_latest=False):
        self.scheme = scheme
        self.bucket_name = bucket_name
        self.object_name = object_name
        self.suppress_consec_slashes = suppress_consec_slashes
        if self.bucket_name and self.object_name:
            self.uri = ('%s://%s/%s' % (self.scheme, self.bucket_name,
                                        self.object_name))
        elif self.bucket_name:
            self.uri = ('%s://%s/' % (self.scheme, self.bucket_name))
        else:
            self.uri = ('%s://' % self.scheme)

        self.version_id = version_id
        # A falsy generation (None, 0, '') is stored as given; anything
        # else is converted to int.
        self.generation = generation and int(generation)
        self.is_version_specific = (bool(self.generation)
                                    or bool(self.version_id))
        self.is_latest = is_latest
        # versionless_uri is only defined for object URIs.
        if bucket_name and object_name:
            self.versionless_uri = '%s://%s/%s' % (scheme, bucket_name,
                                                   object_name)

    def __repr__(self):
        """Returns string representation of URI."""
        return self.uri

    def acl_class(self):
        return MockAcl

    def canned_acls(self):
        return boto.provider.Provider('aws').canned_acls

    def clone_replace_name(self, new_name):
        """Return a new URI for ``new_name`` in the same scheme and bucket."""
        return self.__class__(self.scheme, self.bucket_name, new_name)

    def clone_replace_key(self, key):
        """Return a new URI describing ``key``.

        Version attributes are copied from the key when it has them.
        """
        return self.__class__(
            key.provider.get_provider_name(),
            bucket_name=key.bucket.name,
            object_name=key.name,
            suppress_consec_slashes=self.suppress_consec_slashes,
            version_id=getattr(key, 'version_id', None),
            generation=getattr(key, 'generation', None),
            is_latest=getattr(key, 'is_latest', None))

    def connect(self, access_key_id=NOT_IMPL, secret_access_key=NOT_IMPL):
        return mock_connection

    def create_bucket(self, headers=NOT_IMPL, location=NOT_IMPL,
                      policy=NOT_IMPL, storage_class=NOT_IMPL):
        return self.connect().create_bucket(self.bucket_name)

    def delete_bucket(self, headers=NOT_IMPL):
        return self.connect().delete_bucket(self.bucket_name)

    def get_versioning_config(self, headers=NOT_IMPL):
        """Return the bucket's versioning status."""
        return self.get_bucket().get_versioning_status(headers)

    def has_version(self):
        return (isinstance(self, MockBucketStorageUri)
                and ((self.version_id is not None)
                     or (self.generation is not None)))

    def delete_key(self, validate=NOT_IMPL, headers=NOT_IMPL,
                   version_id=NOT_IMPL, mfa_token=NOT_IMPL):
        self.get_bucket().delete_key(self.object_name)

    def disable_logging(self, validate=NOT_IMPL, headers=NOT_IMPL,
                        version_id=NOT_IMPL):
        self.get_bucket().disable_logging()

    def enable_logging(self, target_bucket, target_prefix, validate=NOT_IMPL,
                       headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().enable_logging(target_bucket)

    def get_logging_config(self, validate=NOT_IMPL, headers=NOT_IMPL,
                           version_id=NOT_IMPL):
        return self.get_bucket().get_logging_config()

    def equals(self, uri):
        return self.uri == uri.uri

    def get_acl(self, validate=NOT_IMPL, headers=NOT_IMPL,
                version_id=NOT_IMPL):
        return self.get_bucket().get_acl(self.object_name)

    def get_def_acl(self, validate=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        return self.get_bucket().get_def_acl(self.object_name)

    def get_subresource(self, subresource, validate=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        return self.get_bucket().get_subresource(subresource,
                                                 self.object_name)

    def get_all_buckets(self, headers=NOT_IMPL):
        return self.connect().get_all_buckets()

    def get_all_keys(self, validate=NOT_IMPL, headers=NOT_IMPL):
        # The bucket method's only parameter is ``headers``; this URI
        # object must not be passed in its place.
        return self.get_bucket().get_all_keys()

    def list_bucket(self, prefix='', delimiter='', headers=NOT_IMPL,
                    all_versions=NOT_IMPL):
        return self.get_bucket().list(prefix=prefix, delimiter=delimiter)

    def get_bucket(self, validate=NOT_IMPL, headers=NOT_IMPL):
        return self.connect().get_bucket(self.bucket_name)

    def get_key(self, validate=NOT_IMPL, headers=NOT_IMPL,
                version_id=NOT_IMPL):
        return self.get_bucket().get_key(self.object_name)

    def is_file_uri(self):
        return False

    def is_cloud_uri(self):
        return True

    def names_container(self):
        return bool(not self.object_name)

    def names_singleton(self):
        return bool(self.object_name)

    def names_directory(self):
        return False

    def names_provider(self):
        return bool(not self.bucket_name)

    def names_bucket(self):
        return self.names_container()

    def names_file(self):
        return False

    def names_object(self):
        return not self.names_container()

    def is_stream(self):
        return False

    def new_key(self, validate=NOT_IMPL, headers=NOT_IMPL):
        bucket = self.get_bucket()
        return bucket.new_key(self.object_name)

    def set_acl(self, acl_or_str, key_name='', validate=NOT_IMPL,
                headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().set_acl(acl_or_str, key_name)

    def set_def_acl(self, acl_or_str, key_name=NOT_IMPL, validate=NOT_IMPL,
                    headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().set_def_acl(acl_or_str)

    def set_subresource(self, subresource, value, validate=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().set_subresource(subresource, value,
                                          self.object_name)

    def copy_key(self, src_bucket_name, src_key_name, metadata=NOT_IMPL,
                 src_version_id=NOT_IMPL, storage_class=NOT_IMPL,
                 preserve_acl=NOT_IMPL, encrypt_key=NOT_IMPL,
                 headers=NOT_IMPL, query_args=NOT_IMPL,
                 src_generation=NOT_IMPL):
        """Copy the named source key to this URI's object; return the key."""
        dst_bucket = self.get_bucket()
        return dst_bucket.copy_key(new_key_name=self.object_name,
                                   src_bucket_name=src_bucket_name,
                                   src_key_name=src_key_name)

    def set_contents_from_string(self, s, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL, md5=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL):
        key = self.new_key()
        key.set_contents_from_string(s)

    def set_contents_from_file(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                               md5=NOT_IMPL, size=NOT_IMPL, rewind=NOT_IMPL,
                               res_upload_handler=NOT_IMPL):
        key = self.new_key()
        return key.set_contents_from_file(fp, headers=headers)

    def set_contents_from_stream(self, fp, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL,
                                 query_args=NOT_IMPL, size=NOT_IMPL):
        """Create this URI's key and fill it from the stream ``fp``."""
        # Create the destination key here, as the sibling set_contents_*
        # methods do, rather than relying on an undefined name.
        key = self.new_key()
        return key.set_contents_from_stream(fp)

    def get_contents_to_file(self, fp, headers=NOT_IMPL, cb=NOT_IMPL,
                             num_cb=NOT_IMPL, torrent=NOT_IMPL,
                             version_id=NOT_IMPL,
                             res_download_handler=NOT_IMPL,
                             response_headers=NOT_IMPL):
        key = self.get_key()
        key.get_contents_to_file(fp)

    def get_contents_to_stream(self, fp, headers=NOT_IMPL, cb=NOT_IMPL,
                               num_cb=NOT_IMPL, version_id=NOT_IMPL):
        key = self.get_key()
        return key.get_contents_to_file(fp)