""" open/DurusWorks/durus/client_storage.py """ from durus.error import DurusKeyError, ProtocolError from durus.error import ReadConflictError, ConflictError, WriteConflictError from durus.serialize import split_durus_ids from durus.storage import Storage from durus.storage_server import DEFAULT_PORT, DEFAULT_HOST from durus.storage_server import SocketAddress, StorageServer from durus.storage_server import STATUS_OKAY, STATUS_KEYERROR, STATUS_INVALID from durus.utils import int4_to_str, read, write, join_bytes, write_all from durus.utils import read_int4, write_int4, write_int4_str, iteritems from durus.utils import as_bytes class ClientStorage (Storage): def __init__(self, host=DEFAULT_HOST, port=DEFAULT_PORT, address=None): self.address = SocketAddress.new(address or (host, port)) self.s = self.address.get_connected_socket() assert self.s, "Could not connect to %s" % self.address self.durus_id_pool = [] self.durus_id_pool_size = 32 self.begin() protocol = StorageServer.protocol assert len(protocol) == 4 write_all(self.s, 'V', protocol) server_protocol = read(self.s, 4) if server_protocol != protocol: raise ProtocolError("Protocol version mismatch.") def __str__(self): return "ClientStorage(%s)" % self.address def new_durus_id(self): if not self.durus_id_pool: batch = self.durus_id_pool_size write(self.s, 'M%s' % chr(batch)) self.durus_id_pool = split_durus_ids(read(self.s, 8 * batch)) self.durus_id_pool.reverse() assert len(self.durus_id_pool) == len(set(self.durus_id_pool)) durus_id = self.durus_id_pool.pop() assert durus_id not in self.durus_id_pool self.transaction_new_durus_ids.append(durus_id) return durus_id def load(self, durus_id): write_all(self.s, 'L', durus_id) return self._get_load_response(durus_id) def _get_load_response(self, durus_id): status = read(self.s, 1) if status == STATUS_OKAY: pass elif status == STATUS_INVALID: raise ReadConflictError([durus_id]) elif status == STATUS_KEYERROR: raise DurusKeyError(durus_id) else: raise ProtocolError('status=%r, durus_id=%r' % (status, durus_id)) n = read_int4(self.s) record = read(self.s, n) return record def begin(self): self.records = {} self.transaction_new_durus_ids = [] def store(self, durus_id, record): assert len(durus_id) == 8 assert durus_id not in self.records self.records[durus_id] = record def end(self, handle_invalidations=None): write(self.s, 'C') n = read_int4(self.s) durus_id_list = [] if n != 0: packed_durus_ids = read(self.s, n*8) durus_id_list = split_durus_ids(packed_durus_ids) try: handle_invalidations(durus_id_list) except ConflictError: self.transaction_new_durus_ids.reverse() self.durus_id_pool.extend(self.transaction_new_durus_ids) assert len(self.durus_id_pool) == len(set(self.durus_id_pool)) self.begin() # clear out records and transaction_new_durus_ids. write_int4(self.s, 0) # Tell server we are done. raise tdata = [] for durus_id, record in iteritems(self.records): tdata.append(int4_to_str(8 + len(record))) tdata.append(as_bytes(durus_id)) tdata.append(record) tdata = join_bytes(tdata) write_int4_str(self.s, tdata) self.records.clear() if len(tdata) > 0: status = read(self.s, 1) if status == STATUS_OKAY: pass elif status == STATUS_INVALID: raise WriteConflictError() else: raise ProtocolError('server returned invalid status %r' % status) def sync(self): write(self.s, 'S') n = read_int4(self.s) if n == 0: packed_durus_ids = '' else: packed_durus_ids = read(self.s, n*8) return split_durus_ids(packed_durus_ids) def pack(self): write(self.s, 'P') status = read(self.s, 1) if status != STATUS_OKAY: raise ProtocolError('server returned invalid status %r' % status) def bulk_load(self, durus_ids): durus_id_str = join_bytes(durus_ids) num_durus_ids, remainder = divmod(len(durus_id_str), 8) assert remainder == 0, remainder write_all(self.s, 'B', int4_to_str(num_durus_ids), durus_id_str) records = [self._get_load_response(durus_id) for durus_id in durus_ids] for record in records: yield record def close(self): write(self.s, '.') # Closes the server side. self.s.close()