When I use a multiprocessing Manager Namespace to share a pandas DataFrame, and each target function calls .sample(5000) on that DataFrame, an EOFError occurs.
def get_sample(i):
    print("start round {}".format(i))
    sample = sharedData.data.sample(5000, random_state=i)

if __name__ == '__main__':
    with mp.Pool(cpu_count(logical=False)) as pool0:
        results = pool0.map(load_data, paths)
        sharedData.data = pd.concat(results, axis=0, copy=False)
        genes = sharedData.data.columns
        pool0.close()
        pool0.join()
    del results

    """sampling"""
    with mp.Pool(cpu_count(logical=True)) as pool:
        print("start sampling, total round = {}".format(1000))
        r = pool.map_async(get_sample, [j for j in range(1000)], error_callback=my_error)
        results2 = r.get()
        pool.close()
        pool.join()
which produces the following output and traceback:
start round 145
round35 returns output
round18 returns output
rount161 returns output
start round 704
start round 720
start round 736
start round 752
start round 768
start round 784
start round 800
start round 816
start round 832
start round 848
start round 864
start round 880
start round 896
start round 912
start round 928
start round 944
start round 960
start round 976
start round 992
from error_callback:
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
return list(map(*args))
File "sampling2temp.py", line 38, in get_sample_ys
sample = sharedData.data.sample(5000, random_state=i)
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/managers.py", line 1060, in __getattr__
return callmethod('__getattribute__', (key,))
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/managers.py", line 757, in _callmethod
kind, result = conn.recv()
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/connection.py", line 250, in recv
buf = self._recv_bytes()
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
buf = self._recv(4)
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/connection.py", line 383, in _recv
raise EOFError
EOFError
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "sampling2temp.py", line 105, in <module>
results2 = r.get()
File "/usr/usc/python/3.6.0/lib/python3.6/multiprocessing/pool.py", line 608, in get
raise self._value
EOFError
It seems like tasks 704 to 992 don't return any output at all, and then the Manager process shuts down. So when one of the still-running tasks reads data from manager.Namespace.data, it receives EOF.
By the way, if I change sample(5000) to sample(2500) and change the size of Manager.Namespace.data from 2127096024 bytes to 1738281624 bytes, there is no EOF problem. Is that because each worker uses too much memory?
A multiprocessing.Connection receiver raises EOFError once all of the associated sender Connections have been closed.
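As a small illustration of that statement (a sketch, not part of the original answer): the receive-only end of a Pipe raises EOFError as soon as the send end is closed.

from multiprocessing import Pipe

# receive-only and send-only ends of a one-way pipe
receiver, sender = Pipe(duplex=False)
sender.close()
try:
    receiver.recv()
except EOFError:
    print("EOFError: every sender Connection has been closed")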
Based on the stack trace, multiprocessing.Manager uses multiprocessing.Connection under the hood. Since your code doesn't appear to terminate the manager process prematurely, the problem is most likely that the manager process hits an exception and terminates before you are done with it. Since reducing the sample size seems to fix the problem, it's possible the Manager process gets killed by the OOM killer for using too much memory. You can check whether that happened by using the command suggested in that linked article:
dmesg | egrep -i "killed process"
You'd expect to see something like this:
host kernel: Out of Memory: Killed process 1234 (python).
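If memory does turn out to be the culprit, one possible workaround (not from the answer above; the names below are purely illustrative) is to hand each worker its own copy of the DataFrame once, via a Pool initializer, instead of pulling the entire frame through the Manager connection on every .sample() call:

import multiprocessing as mp
import numpy as np
import pandas as pd

shared_df = None  # set in each worker by the initializer

def init_worker(df):
    # runs once per worker process; stores the DataFrame as a worker-local global
    global shared_df
    shared_df = df

def get_sample(i):
    # sample from the worker-local copy instead of going through a Manager proxy
    return shared_df.sample(5000, random_state=i)

if __name__ == '__main__':
    data = pd.DataFrame(np.random.randn(100000, 10))
    with mp.Pool(processes=4, initializer=init_worker, initargs=(data,)) as pool:
        samples = pool.map(get_sample, range(100))
    print(len(samples))

Each worker still holds a copy of the DataFrame, so this trades per-task traffic for a one-time copy per worker; whether that fits in memory on your machine is an assumption you would have to verify.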
I have the following problem. I am running a parallel task and I am getting this error:
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "eclat_model.py", line 127, in do_work
function(*args, work_queue, valid_list)
File "eclat_model.py", line 115, in eclat_parallel_helper
valid_list.extend(next_vectors)
File "<string>", line 2, in extend
File "/usr/lib/python3.8/multiprocessing/managers.py", line 834, in _callmethod
conn.send((self._id, methodname, args, kwds))
File "/usr/lib/python3.8/multiprocessing/connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/usr/lib/python3.8/multiprocessing/connection.py", line 404, in _send_bytes
self._send(header)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 368, in _send
n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Relevant functions in eclat_model.py look like this:
def eclat_parallel_helper(index, bit_vectors, min_support, work_queue, valid_list):
    next_vectors = []
    for j in range(index + 1, len(bit_vectors)):
        item_vector = bit_vectors[index][0] | bit_vectors[j][0]
        transaction_vector = bit_vectors[index][1] & bit_vectors[j][1]
        support = get_vector_support(transaction_vector)
        if support >= min_support:
            next_vectors.append((item_vector, transaction_vector, support))
    if len(next_vectors) > 0:
        valid_list.extend(next_vectors)
        for i in range(len(next_vectors)):
            work_queue.put((eclat_parallel_helper, (i, next_vectors, min_support)))
def do_work(work_queue, valid_list, not_done):
    # work queue entries have the form (function, args)
    while not_done.value:
        try:
            function, args = work_queue.get_nowait()
        except QueueEmptyError:
            continue
        function(*args, work_queue, valid_list)
        work_queue.task_done()
    work_queue.close()
EDIT:
The multiprocessing part of the code is as follows. bit_vectors is a list of lists, where each entry has the form [items, transactions, support]: items is a bit vector encoding which items appear in the itemset, transactions is a bit vector encoding which transactions the itemset appears in, and support is the number of transactions in which the itemset occurs (a tiny made-up example of one entry is shown right after the code below).
from multiprocessing import Process, JoinableQueue, Manager, Value, cpu_count

def eclat_parallel(bit_vectors, min_support):
    not_done = Value('i', 1)
    manager = Manager()
    valid_list = manager.list()
    work_queue = JoinableQueue()
    for i in range(len(bit_vectors)):
        work_queue.put((eclat_parallel_helper, (i, bit_vectors, min_support)))

    processes = []
    for i in range(cpu_count()):
        p = Process(target=do_work, args=(work_queue, valid_list, not_done), daemon=True)
        p.start()
        processes.append(p)

    work_queue.join()
    not_done.value = 0
    work_queue.close()

    valid_itemset_vectors = bit_vectors
    for element in valid_list:
        valid_itemset_vectors.append(element)

    for p in processes:
        p.join()

    return valid_itemset_vectors
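As promised above, a tiny hypothetical example of what bit_vectors entries could look like (values invented purely for illustration):

# each entry is [items, transactions, support]; items and transactions are
# plain Python ints used as bit vectors
bit_vectors = [
    [0b001, 0b1011, 3],  # itemset {item 0}, present in transactions 0, 1 and 3
    [0b010, 0b0111, 3],  # itemset {item 1}, present in transactions 0, 1 and 2
]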
What does this error mean, please? Am I appending too many elements into the next_vectors list?
I had the same issue; in my case, just adding a small delay (time.sleep(0.01)) solved it.
The problem is that the individual processes hit the queue too quickly, and that is what causes the error.
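The answer doesn't say where the delay should go; one plausible placement (an assumption, not stated above) is inside the worker loop, backing off briefly whenever the queue is momentarily empty:

import time
from queue import Empty as QueueEmptyError

def do_work(work_queue, valid_list, not_done):
    # same worker as in the question, with a short back-off added
    while not_done.value:
        try:
            function, args = work_queue.get_nowait()
        except QueueEmptyError:
            time.sleep(0.01)  # assumed placement of the suggested delay
            continue
        function(*args, work_queue, valid_list)
        work_queue.task_done()
    work_queue.close()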
Hello fellow programmers!
I am trying to implement multiprocessing in a class to reduce the processing time of a program.
This is an abbreviation of the program:
import multiprocessing as mp
from functools import partial

class PlanningMachines():
    def __init__(self, machines, number_of_objectives, topology=False, episodes=None):
        ....

    def calculate_total_node_THD_func_real_data_with_topo(self):
        self.consider_topology = True
        func_part = partial(self.worker_function, consider_topology=self.consider_topology,
                            list_of_machines=self.list_of_machines, next_state=self.next_state, phase=phase,
                            grid_topo=self.grid_topo,
                            total_THD_for_all_timesteps_with_topo=total_THD_for_all_timesteps_with_topo,
                            smallest_harmonic=smallest_harmonic, pol2cart=self.pol2cart, cart2pol=self.cart2pol,
                            total_THD_for_all_timesteps=total_THD_for_all_timesteps,
                            harmonics_state_phase=harmonics_state_phase,
                            episode=self.episode, episodes=self.episodes, time_=self.time_,
                            steplength=self.steplength, longest_measurement=longest_measurement)
        with mp.Pool() as mpool:
            mpool.map(func_part, range(0, longest_measurement))

    def worker_function(measurement=None, consider_topology=None, list_of_machines=None, next_state=None,
                        phase=None, grid_topo=None, total_THD_for_all_timesteps_with_topo=None,
                        smallest_harmonic=None, pol2cart=None, cart2pol=None, total_THD_for_all_timesteps=None,
                        harmonics_state_phase=None, episode=None, episodes=None, time_=None, steplength=None,
                        longest_measurement=None):
        .....
As you might know, one way of implementing parallel processing is using multiprocessing.Pool().map:
with mp.Pool() as mpool:
    mpool.map(func_part, range(0, longest_measurement))
This function requires a worker_function which can be "packed" with functools.partial:
func_part = partial(self.worker_function, consider_topology=self.consider_topology,
                    list_of_machines=self.list_of_machines, next_state=self.next_state, phase=phase,
                    grid_topo=self.grid_topo,
                    total_THD_for_all_timesteps_with_topo=total_THD_for_all_timesteps_with_topo,
                    smallest_harmonic=smallest_harmonic, pol2cart=self.pol2cart, cart2pol=self.cart2pol,
                    total_THD_for_all_timesteps=total_THD_for_all_timesteps,
                    harmonics_state_phase=harmonics_state_phase,
                    episode=self.episode, episodes=self.episodes, time_=self.time_,
                    steplength=self.steplength, longest_measurement=longest_measurement)
The error is thrown when I try to execute mpool.map(func_part, range(0, longest_measurement)):
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 121, in worker
result = (True, func(*args, **kwds))
File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 44, in mapstar
return list(map(*args))
TypeError: worker_function() got multiple values for argument 'consider_topology'
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:/Users/Artur/Desktop/RL_framework/train.py", line 87, in <module>
main()
File "C:/Users/Artur/Desktop/RL_framework/train.py", line 77, in main
duration = cf.training(episodes, env, agent, filename, topology=topology, multi_processing=multi_processing, CPUs_used=CPUs_used)
File "C:\Users\Artur\Desktop\RL_framework\help_functions\custom_functions.py", line 166, in training
save_interval = parallel_training(range(episodes), env, agent, log_data_qvalues, log_data, filename, CPUs_used)
File "C:\Users\Artur\Desktop\RL_framework\help_functions\custom_functions.py", line 54, in parallel_training
next_state, reward = env.step(action, state) # given the action, the environment gives back the next_state and the reward for the transaction for all objectives seperately
File "C:\Users\Artur\Desktop\RL_framework\help_functions\environment_machines.py", line 127, in step
self.calculate_total_node_THD_func_real_data_with_topo() # THD_plant calculation with considering grid topo
File "C:\Users\Artur\Desktop\RL_framework\help_functions\environment_machines.py", line 430, in calculate_total_node_THD_func_real_data_with_topo
mpool.map(func_part, range(longest_measurement))
File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 268, in map
return self._map_async(func, iterable, mapstar, chunksize).get()
File "C:\Users\Artur\Anaconda\lib\multiprocessing\pool.py", line 657, in get
raise self._value
TypeError: worker_function() got multiple values for argument 'consider_topology'
Process finished with exit code 1
How can consider_topology have multiple values if it is passed right before the worker_function:
self.consider_topology = True
I hope I could describe my issue well enough for you to understand. Thank you in advance.
I think the problem is that your worker_function should be a static method.
What happens now is that you provide all values except the measurement variable in the partial call; I'm guessing you do this because measurement is the one value you are changing.
However, since worker_function is an instance method, the class instance is automatically passed as its first argument when it is called. You did not define self as the first parameter of worker_function, so the instance ends up in measurement. The value from range(0, longest_measurement) that map supplies is then inserted as the second argument. Since consider_topology is the second parameter, the function sees two values supplied for it: the one from the partial call and the one from the map call.
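A minimal sketch of that fix (heavily simplified, with made-up parameters; only measurement and consider_topology are kept from the original signature): declaring worker_function as a @staticmethod means no instance is passed implicitly, so the value from pool.map lands in measurement as intended.

import multiprocessing as mp
from functools import partial

class PlanningMachines:
    def __init__(self):
        self.consider_topology = True

    @staticmethod
    def worker_function(measurement, consider_topology=None):
        # measurement now receives the value supplied by pool.map
        return (measurement, consider_topology)

    def calculate(self, longest_measurement=8):
        func_part = partial(self.worker_function, consider_topology=self.consider_topology)
        with mp.Pool() as mpool:
            return mpool.map(func_part, range(longest_measurement))

if __name__ == '__main__':
    print(PlanningMachines().calculate())

The alternative is to keep worker_function an instance method and simply add self as its first parameter.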
I am writing a bootstrap algorithm using parallel loops and pandas. The problem I experience is that a merge command inside the parallel loop causes a "ValueError: buffer source array is read-only" error, but only if I use the full dataset to merge (120k lines). Any subset with fewer than 12k lines works just fine, so I infer it is not a problem with the syntax. What can I do?
The current pandas version is 0.24.2 and Cython is 0.29.7.
_RemoteTraceback Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/ubuntu/.local/lib/python3.6/site-packages/joblib/externals/loky/process_executor.py", line 418, in _process_worker
r = call_item()
File "/home/ubuntu/.local/lib/python3.6/site-packages/joblib/externals/loky/process_executor.py", line 272, in __call__
return self.fn(*self.args, **self.kwargs)
File "/home/ubuntu/.local/lib/python3.6/site-packages/joblib/_parallel_backends.py", line 567, in __call__
return self.func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.6/site-packages/joblib/parallel.py", line 225, in __call__
for func, args, kwargs in self.items]
File "/home/ubuntu/.local/lib/python3.6/site-packages/joblib/parallel.py", line 225, in <listcomp>
for func, args, kwargs in self.items]
File "<ipython-input-72-cdb83eaf594c>", line 12, in bootstrap
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/frame.py", line 6868, in merge
copy=copy, indicator=indicator, validate=validate)
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/reshape/merge.py", line 48, in merge
return op.get_result()
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/reshape/merge.py", line 546, in get_result
join_index, left_indexer, right_indexer = self._get_join_info()
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/reshape/merge.py", line 756, in _get_join_info
right_indexer) = self._get_join_indexers()
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/reshape/merge.py", line 735, in _get_join_indexers
how=self.how)
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/reshape/merge.py", line 1130, in _get_join_indexers
llab, rlab, shape = map(list, zip(* map(fkeys, left_keys, right_keys)))
File "/home/ubuntu/.local/lib/python3.6/site-packages/pandas/core/reshape/merge.py", line 1662, in _factorize_keys
rlab = rizer.factorize(rk)
File "pandas/_libs/hashtable.pyx", line 111, in pandas._libs.hashtable.Int64Factorizer.factorize
File "stringsource", line 653, in View.MemoryView.memoryview_cwrapper
File "stringsource", line 348, in View.MemoryView.memoryview.__cinit__
ValueError: buffer source array is read-only
"""
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-73-652c1db5701b> in <module>()
1 num_cores = multiprocessing.cpu_count()
----> 2 results = Parallel(n_jobs=num_cores, prefer='processes', verbose = 5)(delayed(bootstrap)() for i in range(n_trials))
3 #pd.DataFrame(results[0])
~/.local/lib/python3.6/site-packages/joblib/parallel.py in __call__(self, iterable)
932
933 with self._backend.retrieval_context():
--> 934 self.retrieve()
935 # Make sure that we get a last message telling us we are done
936 elapsed_time = time.time() - self._start_time
~/.local/lib/python3.6/site-packages/joblib/parallel.py in retrieve(self)
831 try:
832 if getattr(self._backend, 'supports_timeout', False):
--> 833 self._output.extend(job.get(timeout=self.timeout))
834 else:
835 self._output.extend(job.get())
~/.local/lib/python3.6/site-packages/joblib/_parallel_backends.py in wrap_future_result(future, timeout)
519 AsyncResults.get from multiprocessing."""
520 try:
--> 521 return future.result(timeout=timeout)
522 except LokyTimeoutError:
523 raise TimeoutError()
/usr/lib/python3.6/concurrent/futures/_base.py in result(self, timeout)
430 raise CancelledError()
431 elif self._state == FINISHED:
--> 432 return self.__get_result()
433 else:
434 raise TimeoutError()
/usr/lib/python3.6/concurrent/futures/_base.py in __get_result(self)
382 def __get_result(self):
383 if self._exception:
--> 384 raise self._exception
385 else:
386 return self._result
ValueError: buffer source array is read-only
and the code is
def bootstrap():
    df_resample_ids = skl.utils.resample(ob_ids)
    df_resample_ids = pd.DataFrame(df_resample_ids).sort_values(by="0").reset_index(drop=True)
    df_resample_ids.columns = [ob_id_field]
    df_resample = pd.DataFrame(df_resample_ids.merge(df, on=ob_id_field))
    return df_resample

num_cores = multiprocessing.cpu_count()
results = Parallel(n_jobs=num_cores, prefer='processes', verbose=5)(delayed(bootstrap)() for i in range(n_trials))
The algorithm creates resampled-with-replacement IDs from an ID variable and uses the merge command to build a new dataset based on the resampled IDs and the original dataset stored in df. If I cut out a subset of the original dataset (anywhere), leaving fewer than ~12k lines, the parallel loop finishes without an error and behaves as expected.
As requested, below is a new snippet to re-create the data structures and mirror the principal approach I am currently working on:
import numpy as np
import pandas as pd
import sklearn as skl
import multiprocessing
from joblib import Parallel, delayed

df = pd.DataFrame(np.random.randn(200000, 24), columns=list('ABCDDEFGHIJKLMNOPQRSTUVW'))
df["ID"] = df.index.drop_duplicates().tolist()
ob_ids = df.index.drop_duplicates().tolist()

def bootstrap2():
    df_resample_ids = skl.utils.resample(ob_ids)
    df_resample_ids = pd.DataFrame(df_resample_ids).sort_values(by=0).reset_index(drop=True)
    df_resample_ids.columns = ['ID']
    df_resample = pd.DataFrame(df.merge(df_resample_ids, on='ID'))
    return df_resample

num_cores = multiprocessing.cpu_count()
results = Parallel(n_jobs=num_cores, prefer='processes', verbose=5)(delayed(bootstrap2)() for i in range(n_trials))
However, I notice that when the data is made up entirely of np.random numbers, the loop goes through without an error. The dtypes of the original dataframe are:
start_rtg int64
end_rtg float64
days_diff float64
ultimate_customer_system_id int64
How can I avoid the read-only error?
Posting an answer to my own question, as I found that one of the variables was of int64 datatype. When I converted all variables to float64, the error disappeared. So it is an issue that is restricted to certain datatypes only...
Cheers,
Stephan
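A short sketch of the conversion described in that answer (the column names follow the dtypes listed in the question; the toy data here is made up):

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'start_rtg': np.arange(5, dtype='int64'),
    'end_rtg': np.random.randn(5),
    'days_diff': np.random.randn(5),
    'ultimate_customer_system_id': np.arange(5, dtype='int64'),
})

# cast the int64 columns (or simply the whole frame) to float64 before the
# DataFrame is handed to the parallel workers
df = df.astype({'start_rtg': 'float64', 'ultimate_customer_system_id': 'float64'})
print(df.dtypes)  # every column is now float64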
I'm playing with multiprocessing in Python. I'm trying to determine what happens if a worker raises an exception, so I wrote the following code:
def a(num):
    if num == 2:
        raise Exception("num can't be 2")
    print(num)

p = Pool()
p.map(a, [2, 1, 3, 4, 5, 6, 7, 100, 100000000000000, 234, 234, 5634, 0000])
output
3
4
5
7
6
100
100000000000000
234
234
5634
0
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/lib/python3.5/multiprocessing/pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "/usr/lib/python3.5/multiprocessing/pool.py", line 44, in mapstar
return list(map(*args))
File "<stdin>", line 3, in a
Exception: Error, num can't be 2
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/lib/python3.5/multiprocessing/pool.py", line 260, in map
return self._map_async(func, iterable, mapstar, chunksize).get()
File "/usr/lib/python3.5/multiprocessing/pool.py", line 608, in get
raise self._value
Exception: Error, num can't be 2
As you can see from the numbers that were printed, "2" is not there, but why is the number 1 not there either?
Note: I'm using Python 3.5.2 on Ubuntu
By default, Pool creates a number of workers equal to your number of cores. When one of those worker processes dies, it may leave work that has been assigned to it undone. It also may leave output in a buffer that never gets flushed.
The pattern with .map() is to handle exceptions in the workers and return some suitable error value, since the results of .map() are supposed to be one-to-one with the input.
from multiprocessing import Pool

def a(num):
    try:
        if num == 2:
            raise Exception("num can't be 2")
        print(num, flush=True)
        return num
    except Exception as e:
        print('failed', flush=True)
        return e

p = Pool()
n = 100
results = p.map(a, range(n))
print("missing numbers: ", tuple(i for i in range(n) if i not in results))
Here's another question with good information about how exceptions propagate in multiprocessing.map workers.
I'm having a lot of success using Dask and Distributed to develop data analysis pipelines. One thing that I'm still looking forward to improving, however, is the way I handle exceptions.
Right now, if I write the following
def my_function(value):
    return 1 / value

results = (dask.bag
           .from_sequence(range(-10, 10))
           .map(my_function))

print(results.compute())
... then on running the program I get a long, long list of tracebacks (one per worker, I'm guessing). The most relevant segment being
distributed.utils - ERROR - division by zero
Traceback (most recent call last):
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/distributed/utils.py", line 193, in f
result[0] = yield gen.maybe_future(func(*args, **kwargs))
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1015, in run
value = future.result()
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/concurrent.py", line 237, in result
raise_exc_info(self._exc_info)
File "<string>", line 3, in raise_exc_info
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1021, in run
yielded = self.gen.throw(*exc_info)
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/distributed/client.py", line 1473, in _get
result = yield self._gather(packed)
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1015, in run
value = future.result()
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/concurrent.py", line 237, in result
raise_exc_info(self._exc_info)
File "<string>", line 3, in raise_exc_info
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1021, in run
yielded = self.gen.throw(*exc_info)
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/distributed/client.py", line 923, in _gather
st.traceback)
File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/six.py", line 685, in reraise
raise value.with_traceback(tb)
File "/mnt/lustrefs/work/aurelien.mazurie/test_dask/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/dask/bag/core.py", line 1411, in reify
File "test.py", line 9, in my_function
return 1 / value
ZeroDivisionError: division by zero
Here, of course, visual inspection tells me that the error was a division by zero. What I'm wondering is whether there is a better way to track these errors. For example, I can't seem to catch the exception itself:
import dask.bag
import distributed

try:
    dask_scheduler = "127.0.0.1:8786"
    dask_client = distributed.Client(dask_scheduler)

    def my_function(value):
        return 1 / value

    results = (dask.bag
               .from_sequence(range(-10, 10))
               .map(my_function))

    #dask_client.persist(results)
    print(results.compute())

except Exception as e:
    print("error: %s" % e)
EDIT: Note that in my example I'm using distributed, not just dask. There is a dask-scheduler listening on port 8786 with four dask-worker processes registered to it.
This code will produce the exact same output as above, meaning that I'm not actually catching the exception with my try/except block.
Now, since we're talking about distributed tasks across a cluster, it is obviously non-trivial to propagate exceptions back to me. Is there any guideline for doing so? Right now my solution is to have functions return both a result and an optional error message, then process the results and error messages separately:
def my_function(value):
    try:
        return {"result": 1 / value, "error": None}
    except ZeroDivisionError:
        return {"result": None, "error": "boom!"}

results = (dask.bag
           .from_sequence(range(-10, 10))
           .map(my_function))

dask_client.persist(results)

errors = (results
          .pluck("error")
          .filter(lambda x: x is not None)
          .compute())
print(errors)

results = (results
           .pluck("result")
           .filter(lambda x: x is not None)
           .compute())
print(results)
This works, but I'm wondering if I'm sandblasting the soup cracker here. EDIT: Another option would be to use something like a Maybe monad, but once again I'd like to know if I'm overthinking it.
Dask automatically packages up exceptions that occurred remotely and reraises them locally. Here is what I get when I run your example
In [1]: from dask.distributed import Client
In [2]: client = Client('localhost:8786')
In [3]: import dask.bag
In [4]: try:
...: def my_function (value):
...: return 1 / value
...:
...: results = (dask.bag
...: .from_sequence(range(-10, 10))
...: .map(my_function))
...:
...: print(results.compute())
...:
...: except Exception as e:
...: import pdb; pdb.set_trace()
...: print("error: %s" % e)
...:
distributed.utils - ERROR - division by zero
> <ipython-input-4-17aa5fbfb732>(13)<module>()
-> print("error: %s" % e)
(Pdb) pp e
ZeroDivisionError('division by zero',)
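In other words, because the remote exception is re-raised locally by compute(), you can catch the concrete exception type directly. A self-contained sketch (it uses dask's default local scheduler so it runs without a dask-scheduler; with distributed, the behaviour is the same as in the session above):

import dask.bag

def my_function(value):
    return 1 / value

results = dask.bag.from_sequence(range(-10, 10)).map(my_function)

try:
    print(results.compute())
except ZeroDivisionError as e:
    # the ZeroDivisionError raised in a worker surfaces here
    print("caught: %s" % e)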
You could wrap your function like so:
def exception_handler(orig_func):
    def wrapper(*args, **kwargs):
        try:
            return orig_func(*args, **kwargs)
        except:
            import sys
            sys.exit(1)
    return wrapper
You could use a decorator or do:
wrapped = exception_handler(my_function)
dask_client.map(wrapped, range(100))
This seems to automatically rebalance tasks if a worker fails. But I don't know how to remove the failed worker from the pool.