Source code for sfepy.base.multiproc_mpi

"""
Multiprocessing functions.
"""
import logging
import os
import six

try:
    from mpi4py import MPI
    mpi_comm = MPI.COMM_WORLD
    mpi_rank = mpi_comm.Get_rank()
    mpi_status = MPI.Status()
    use_multiprocessing = mpi_comm.Get_size() > 1
except:
    use_multiprocessing = False


global_multiproc_dict = {}
mpi_master = 0


[docs] class MPILogFile(object): def __init__(self, comm, filename, mode): self._file = MPI.File.Open(comm, filename, mode) self._file.Set_atomicity(True)
[docs] def write(self, msg): try: msg = msg.encode() except AttributeError: pass self._file.Write_shared(msg)
[docs] def sync(self): self._file.Sync()
[docs] def close(self): self.sync() self._file.Close()
[docs] class MPIFileHandler(logging.StreamHandler): "MPI file class for logging process communication." def __init__(self, filename, mode=MPI.MODE_WRONLY, comm=MPI.COMM_WORLD): self.baseFilename = os.path.abspath(filename) self.mode = mode self.comm = comm super(MPIFileHandler, self).__init__(self._open()) def _open(self): stream = MPILogFile(self.comm, self.baseFilename, self.mode) return stream
[docs] def close(self): if self.stream: self.stream.close() self.stream = None
[docs] def emit(self, record): msg = self.format(record) self.stream.write('{}{}'.format(msg, self.terminator)) self.flush()
[docs] def set_logging_level(log_level='info'): if log_level == 'debug': logger.setLevel(logging.DEBUG) else: logger.setLevel(logging.INFO)
[docs] def get_logger(log_filename='multiproc_mpi.log'): """Get the MPI logger which log information into a shared file.""" open(log_filename, 'w').close() # empy log file log_id = 'master' if mpi_rank == 0 else 'slave%d' % mpi_comm.rank logger = logging.getLogger(log_id) logger.setLevel(logging.INFO) mh = MPIFileHandler(log_filename) formatter = logging.Formatter('%(asctime)s - %(levelname)s' + ' - %(name)s: %(message)s') mh.setFormatter(formatter) logger.addHandler(mh) return logger
logger = get_logger()
[docs] def enum(*sequential): enums = dict(zip(sequential, range(len(sequential)))) reverse = dict((value, key) for key, value in six.iteritems(enums)) enums['name'] = reverse return type('Enum', (), enums)
tags = enum('READY', 'DONE', 'START', 'CONTINUE', 'LOCK', 'UNLOCK', 'SET_DICT', 'SET_DICT_IMMUTABLE', 'GET_DICT', 'DICT_VAL', 'SET_DICT_STATUS', 'GET_DICT_KEYS', 'DICT_KEYS', 'GET_DICT_LEN', 'DICT_LEN', 'GET_DICT_IN', 'DICT_IN', 'GET_QUEUE', 'PUT_QUEUE', 'QUEUE_VAL')
[docs] def cpu_count(): """Get the number of MPI nodes.""" return mpi_comm.Get_size()
[docs] def get_int_value(name, init_value=0): """Get the remote integer value.""" key = '__int_%s' % name rdict = get_dict(key, mutable=True) if mpi_rank == mpi_master: val = RemoteInt(rdict, init_value) else: val = RemoteInt(rdict) return val
[docs] def get_queue(name): """Get the queue.""" key = '__queue_%s' % name if mpi_rank == mpi_master: if key not in global_multiproc_dict: queue = RemoteQueueMaster(name) global_multiproc_dict[key] = queue else: queue = global_multiproc_dict[key] queue.clean() else: queue = RemoteQueue(name) return queue
[docs] def get_dict(name, mutable=False, clear=False, soft_set=False): """Get the remote dictionary.""" if mpi_rank == mpi_master: if name in global_multiproc_dict: if clear: global_multiproc_dict[name].clear() logger.info("cleaning dict %s" % name) else: logger.info("using existing dict %s" % name) return global_multiproc_dict[name] else: rdict = RemoteDictMaster(name, mutable=mutable, soft_set=soft_set) global_multiproc_dict[name] = rdict logger.info("new dict %s" % name) return rdict else: return RemoteDict(name, mutable=mutable)
[docs] class RemoteInt(object): """Remote intiger class, data saved in RemoteDict."""
[docs] class IntDesc(object): def __get__(self, instance, owner=None): return instance.remote_dict['value'] def __set__(self, instance, val): instance.remote_dict['value'] = val
value = IntDesc() def __init__(self, remote_dict, value=None): self.remote_dict = remote_dict if value is not None: self.value = value
[docs] class RemoteQueueMaster(list): """Remote queue class - master side.""" def __init__(self, name, mode='fifo', *args): list.__init__(self, *args) self.name = name self.mode = mode
[docs] @staticmethod def get_gdict_key(name): return '__queue_%s' % name
[docs] def get(self): if len(self) == 0: return None else: if self.mode == 'fifo': out = self.pop(0) elif self.mode == 'lifo': out = self.pop(-1) return out
[docs] def put(self, value): self.append(value)
[docs] def clean(self): del self[:]
[docs] def remote_put(self, value, slave): self.put(value)
[docs] def remote_get(self, slave): mpi_comm.isend(self.get(), dest=slave, tag=tags.QUEUE_VAL) logger.debug('sent %s to %d (%s)' % (tags.name[tags.QUEUE_VAL], slave, self.name))
[docs] class RemoteQueue(object): """Remote queue class - slave side.""" def __init__(self, name): self.name = name
[docs] def get(self): mpi_comm.isend(self.name, dest=mpi_master, tag=tags.GET_QUEUE) logger.debug('sent %s to %d (%s)' % (tags.name[tags.GET_QUEUE], mpi_master, self.name)) value = mpi_comm.recv(source=mpi_master, tag=tags.QUEUE_VAL) logger.debug('received %s from %d' % (tags.name[tags.QUEUE_VAL], mpi_master)) return value
[docs] def put(self, value): mpi_comm.isend((self.name, value), dest=mpi_master, tag=tags.PUT_QUEUE) logger.debug('sent %s to %d (%s)' % (tags.name[tags.PUT_QUEUE], mpi_master, self.name))
[docs] class RemoteDictMaster(dict): """Remote dictionary class - master side.""" def __init__(self, name, mutable=False, soft_set=False, *args): dict.__init__(self, *args) self.name = name self.mutable = mutable self.immutable_soft_set = soft_set
[docs] def remote_set(self, data, slave, mutable=False): key, value = data if not self.mutable and key in self: if self.immutable_soft_set: mpi_comm.isend(False, dest=slave, tag=tags.SET_DICT_STATUS) else: msg = "imutable dict '%s'! key '%s' already in global dict"\ % (self.name, key) logger.error(msg) mpi_comm.isend(False, dest=slave, tag=tags.SET_DICT_STATUS) raise(KeyError) else: self.__setitem__(key, value) logger.debug('set master dict (%s[%s])' % (self.name, key)) mpi_comm.isend(True, dest=slave, tag=tags.SET_DICT_STATUS)
[docs] def remote_get(self, key, slave): if key in self: mpi_comm.isend(self.__getitem__(key), dest=slave, tag=tags.DICT_VAL) logger.debug('sent %s to %d (%s[%s])' % (tags.name[tags.DICT_VAL], slave, self.name, key)) else: mpi_comm.isend(None, dest=slave, tag=tags.DICT_VAL) logger.error('RemoteDict KeyError (%s[%s])' % (self.name, key)) raise(KeyError)
[docs] def remote_get_keys(self, slave): mpi_comm.isend(self.keys(), dest=slave, tag=tags.DICT_KEYS) logger.debug('sent %s to %d (%s)' % (tags.name[tags.DICT_KEYS], slave, self.name))
[docs] def remote_get_len(self, slave): mpi_comm.isend(self.__len__(), dest=slave, tag=tags.DICT_LEN) logger.debug('sent %s to %d (%s)' % (tags.name[tags.DICT_LEN], slave, self.name))
[docs] def remote_get_in(self, key, slave): mpi_comm.isend(self.__contains__(key), dest=slave, tag=tags.DICT_IN) logger.debug('sent %s to %d (%s)' % (tags.name[tags.DICT_IN], slave, self.name))
[docs] class RemoteDict(object): """Remote dictionary class - slave side.""" def __init__(self, name, mutable=False): self._dict = {} self.mutable = mutable self.name = name def _setitem(self, key, value, tag): mpi_comm.isend((self.name, key, value), dest=mpi_master, tag=tag) logger.debug('sent %s to %d (%s[%s])' % (tags.name[tag], mpi_master, self.name, key)) stat = mpi_comm.recv(source=mpi_master, tag=tags.SET_DICT_STATUS) logger.debug('recevied %s from %d' % (tags.name[tags.SET_DICT_STATUS], mpi_master)) return stat def __setitem__(self, key, value): if self.mutable: self._setitem(key, value, tags.SET_DICT) self._dict[key] = value else: if key in self._dict: msg = "imutable dict '%s'! key '%s' already in local dict"\ % (self.name, key) logger.error(msg) raise(KeyError) else: stat = self._setitem(key, value, tags.SET_DICT_IMMUTABLE) if stat: self._dict[key] = value else: msg = \ "imutable dict '%s'! key '%s' already in global dict"\ % (self.name, key) logger.error(msg) raise(KeyError) def __getitem__(self, key): if key in self._dict and not self.mutable: logger.debug('get value from local dict! (%s[%s])' % (self.name, key)) else: mpi_comm.isend((self.name, key), dest=mpi_master, tag=tags.GET_DICT) logger.debug('sent %s to %d (%s)' % (tags.name[tags.GET_DICT], mpi_master, key)) data = mpi_comm.recv(source=mpi_master, tag=tags.DICT_VAL) logger.debug('received %s from %d' % (tags.name[tags.DICT_VAL], mpi_master)) if data is not None: self._dict[key] = data else: logger.error('RemoteDict KeyError (%s[%s])' % (self.name, key)) raise(KeyError) return self._dict[key] def __len__(self): mpi_comm.isend((self.name,), dest=mpi_master, tag=tags.GET_DICT_LEN) length = mpi_comm.recv(source=mpi_master, tag=tags.DICT_LEN) return length def __contains__(self, key): if key in self._dict and not self.mutable: is_in = True else: mpi_comm.isend((self.name, key), dest=mpi_master, tag=tags.GET_DICT_IN) is_in = mpi_comm.recv(source=mpi_master, tag=tags.DICT_IN) return is_in
[docs] def keys(self): mpi_comm.isend((self.name,), dest=mpi_master, tag=tags.GET_DICT_KEYS) keys = mpi_comm.recv(source=mpi_master, tag=tags.DICT_KEYS) return keys
[docs] def get(self, key, default=None): if key in self.keys(): return self.__getitem__(key) else: return default
[docs] def update(self, other): for k in other.keys(): self.__setitem__(k, other[k])
[docs] def is_remote_dict(d): """Return True if 'd' is RemoteDict or RemoteDictMaster instance.""" return isinstance(d, RemoteDict) or isinstance(d, RemoteDictMaster)
[docs] class RemoteLock(object): """Remote lock class - lock and unlock restricted access to the master.""" def __init__(self): self.locked = False
[docs] def acquire(self): mpi_comm.isend(None, dest=mpi_master, tag=tags.LOCK) self.locked = True
[docs] def release(self): mpi_comm.isend(None, dest=mpi_master, tag=tags.UNLOCK) self.locked = False
[docs] def wait_for_tag(wtag, num=1): ndone = num start = MPI.Wtime() while ndone > 0: mpi_comm.recv(source=MPI.ANY_SOURCE, tag=wtag, status=mpi_status) tag = mpi_status.Get_tag() source = mpi_status.Get_source() logger.debug('received %s from %d (%.03fs)' % (tags.name[tag], source, MPI.Wtime() - start)) if tag == wtag: ndone -= 1
[docs] def slave_get_task(name=''): """Start the slave nodes.""" mpi_comm.isend(mpi_rank, dest=mpi_master, tag=tags.READY) logger.debug('%s ready' % name) task, data = mpi_comm.bcast(None, root=mpi_master) logger.info('%s received task %s' % (name, task)) return task, data
[docs] def slave_task_done(task=''): """Stop the slave nodes.""" mpi_comm.isend(mpi_rank, dest=mpi_master, tag=tags.DONE) logger.info('%s stopped' % task)
[docs] def get_slaves(): """Get the list of slave nodes""" slaves = list(range(mpi_comm.Get_size())) slaves.remove(mpi_master) return slaves
[docs] def master_send_task(task, data): """Send task to all slaves.""" slaves = get_slaves() wait_for_tag(tags.READY, len(slaves)) logger.info('all nodes are ready for task %s' % task) mpi_comm.bcast((task, data), root=mpi_master)
[docs] def master_send_continue(): """Send 'continue' to all slaves.""" for ii in get_slaves(): mpi_comm.send(None, dest=ii, tag=tags.CONTINUE) logger.info('slave nodes - continue')
[docs] def master_loop(): """Run the master loop - wait for requests from slaves.""" logger.info('main loop started') master_send_task('calculate', None) ndone = len(get_slaves()) source = MPI.ANY_SOURCE while ndone > 0: data = mpi_comm.recv(source=source, tag=MPI.ANY_TAG, status=mpi_status) tag = mpi_status.Get_tag() slave = mpi_status.Get_source() logger.debug('received %s from %d' % (tags.name[tag], slave)) if tag == tags.DONE: ndone -= 1 elif tag == tags.LOCK: source = slave elif tag == tags.UNLOCK: source = MPI.ANY_SOURCE elif tag == tags.SET_DICT: global_multiproc_dict[data[0]].remote_set(data[1:], slave, mutable=True) elif tag == tags.SET_DICT_IMMUTABLE: global_multiproc_dict[data[0]].remote_set(data[1:], slave) elif tag == tags.GET_DICT: global_multiproc_dict[data[0]].remote_get(data[1], slave) elif tag == tags.GET_DICT_KEYS: global_multiproc_dict[data[0]].remote_get_keys(slave) elif tag == tags.GET_DICT_LEN: global_multiproc_dict[data[0]].remote_get_len(slave) elif tag == tags.GET_DICT_IN: global_multiproc_dict[data[0]].remote_get_in(data[1], slave) elif tag == tags.GET_QUEUE: qkey = RemoteQueueMaster.get_gdict_key(data) global_multiproc_dict[qkey].remote_get(slave) elif tag == tags.PUT_QUEUE: qkey = RemoteQueueMaster.get_gdict_key(data[0]) global_multiproc_dict[qkey].remote_put(data[1], slave) logger.info('main loop finished')