Unable to use Prefect with stateful flows #15806
---
I am currently bumping against an apparent limitation with stateful flows. Consider the following example:

```python
import numpy as np
import threading

from prefect import flow, task
from prefect.futures import as_completed
from prefect_dask import DaskTaskRunner
from prefect.task_runners import ThreadPoolTaskRunner


@task
def work(arg):
    print(f"Starting task in a function with argument: {arg}")
    return arg


class MainObj:
    def __init__(self):
        # This object should not be pickled
        self.large_object = np.random.random((10000, 10000))
        self.lock = threading.Lock()
        self.counter = 0

    # @flow(task_runner=ThreadPoolTaskRunner)
    @flow(task_runner=DaskTaskRunner())
    def run(self):
        futures = [work.submit(i) for i in range(3)]
        for fut in as_completed(futures):
            res = fut.result()
            self.process_result(res)

    def process_result(self, res):
        print(f"Task finished with {res}")
        self.counter += res


if __name__ == "__main__":
    obj = MainObj()
    print(f"Counter is {obj.counter}")
    obj.run()
    print(f"Counter is {obj.counter}")
```

Here, the flow needs to update the `counter` attribute of the `MainObj` instance as task results come in. The code works with `ThreadPoolTaskRunner` (the commented-out decorator above).
However, it fails with `DaskTaskRunner`.
This is presumably because Dask has to serialize the flow, and therefore `self`, which holds the large array and a `threading.Lock`? If so, then it's reasonable that the code doesn't work: the task runner cannot update the state of an object that lives in the main process.
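As a sanity check, independent of Prefect, an object holding a `threading.Lock` cannot be pickled at all. A minimal illustration (using a stripped-down stand-in for `MainObj`):

```python
import pickle
import threading


class Stateful:
    """Stripped-down stand-in for MainObj: just the lock and the counter."""

    def __init__(self):
        self.lock = threading.Lock()
        self.counter = 0


try:
    pickle.dumps(Stateful())
except TypeError as exc:
    # Raises something like: cannot pickle '_thread.lock' object
    print(exc)
```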
Is there any way to use Prefect in this case? Any insight would be very much appreciated!
---
would something like this work for your case?

example:

```python
# /// script
# dependencies = [
#     "numpy",
#     "prefect-dask",
# ]
# ///
import time

import dask.array as da

from prefect import flow, task
from prefect.futures import as_completed
from prefect_dask import DaskTaskRunner


@task
def process_chunk(chunk_data, chunk_id: int):
    print(f"Processing chunk {chunk_id} with shape: {chunk_data.shape}")
    time.sleep(1)  # Simulate work
    # Calculate mean absolute value - naturally some will be above 0.5
    result = abs(chunk_data).mean().compute()
    print(f"Chunk {chunk_id} result: {result:.4f}")
    return {"chunk_id": chunk_id, "result": result}


@task
def process_followup(previous_result: float):
    print(f"Starting followup for result: {previous_result}")
    time.sleep(3)  # Longer sleep to make it obvious
    print(f"Completed followup for result: {previous_result}")
    return previous_result * 0.5


@flow(task_runner=DaskTaskRunner(), log_prints=True)
def process_large_dataset():
    large_array = da.random.random((3000, 3000), chunks=(1000, 1000)) * 2 - 1
    print(
        f"Created array with shape {large_array.shape} and chunks {large_array.chunks}"
    )

    futures = []
    for i in range(large_array.numblocks[0]):
        chunk = large_array[i * 1000 : (i + 1) * 1000, :]
        futures.append(process_chunk.submit(chunk, i))

    # Track all results including followups
    all_results = {}
    followup_futures = []  # Separate list for followups

    print("Starting to process results...")
    for future in as_completed(futures):
        result = future.result()
        all_results[result["chunk_id"]] = result["result"]
        print(f"Processed chunk {result['chunk_id']}")

        if result["result"] > 0.5:
            print(f"Spawning followup task for chunk {result['chunk_id']}")
            followup_future = process_followup.submit(result["result"])
            followup_futures.append(followup_future)

    print(
        f"Main processing complete. Waiting for {len(followup_futures)} followup tasks..."
    )

    # Explicitly wait for followups
    for future in as_completed(followup_futures):
        result = future.result()
        print(f"Completed followup with result: {result:.4f}")

    return all_results


if __name__ == "__main__":
    print(sorted(process_large_dataset().items()))
```
here we are! feel free to let me know if there's some nuance where prefect specifically is getting in the way
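the main thing the sketch above is doing is keeping all of the mutable state (`all_results`, `followup_futures`) in plain locals inside the flow function, so nothing outside the flow ever has to be pickled and shipped to the dask workers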
---
I recently stumbled upon Covalent, which works in the desired manner:

```python
import numpy as np
import threading
import time

import covalent as ct


@ct.electron
def work(arg):
    print(f"Starting task in a function with argument: {arg}")
    time.sleep(1)
    return arg


class MainObj:
    def __init__(self):
        # This object should not be pickled
        self.large_object = np.random.random((10000, 10000))
        self.lock = threading.Lock()
        self.counter = 0

    def run(self):
        dask_executor = (
            ct.executor.DaskExecutor()
        )  # Not actually necessary, since Dask is also the default executor
        futures = [
            ct.dispatch(ct.lattice(work, executor=dask_executor))(i) for i in range(3)
        ]
        for fut in futures:
            wresult = ct.get_result(dispatch_id=fut, wait=True)
            res = wresult.get_node_result(node_id=0)["output"].get_deserialized()
            self.process_result(res)

    def process_result(self, res):
        print(f"Task finished with {res}")
        self.counter += res


if __name__ == "__main__":
    obj = MainObj()
    print(f"Counter is {obj.counter}")
    obj.run()
    print(f"Counter is {obj.counter}")
```

The main difference here is that `run` is not itself decorated as a workflow: only `work` is dispatched to the executor, so the `MainObj` instance (with its large array and lock) stays in the main process and never needs to be serialized.
hmm, I'm not sure I understand your intended use of `threading.Lock` in this context.
if the above represents what you're trying to do, what about this? i.e. just call your dask work (instance method or not) from a flow
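for example, something roughly like this (just a sketch: `run_work` is an illustrative name, and `MainObj`/`work` are borrowed from your example; the flow is a plain module-level function, so `self` never has to be pickled):

```python
import threading

import numpy as np

from prefect import flow, task
from prefect.futures import as_completed
from prefect_dask import DaskTaskRunner


@task
def work(arg):
    print(f"Starting task in a function with argument: {arg}")
    return arg


@flow(task_runner=DaskTaskRunner())
def run_work(n: int) -> list:
    # only plain ints and the task results cross the Dask boundary
    futures = [work.submit(i) for i in range(n)]
    return [fut.result() for fut in as_completed(futures)]


class MainObj:
    def __init__(self):
        self.large_object = np.random.random((10000, 10000))  # never leaves this process
        self.lock = threading.Lock()
        self.counter = 0

    def run(self):
        # stateful updates happen here, in the main process
        for res in run_work(3):
            with self.lock:
                self.counter += res


if __name__ == "__main__":
    obj = MainObj()
    print(f"Counter is {obj.counter}")
    obj.run()
    print(f"Counter is {obj.counter}")
```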
docs on this