I am writing a PySpark implementation of an iterative algorithm. Part of the algorithm involves repeatedly applying a strategy until no more improvements can be made (i.e., a local maximum has been reached greedily).
The function `optimizeAll` returns a three-column DataFrame that looks as follows:
| id | current_value | best_value |
|---|---|---|
| 0 | 1 | 1 |
| 1 | 0 | 1 |
This function is called in a while loop until `current_value` and `best_value` are identical in every row (meaning that no more optimizations can be made).
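For reference, a toy stand-in for that output can be built directly (a minimal sketch; in the real code the DataFrame of course comes out of `optimizeAll`):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy stand-in for one pass, matching the table above
df = spark.createDataFrame(
    [(0, 1, 1), (1, 0, 1)],
    ['id', 'current_value', 'best_value'],
)

# Keep looping as long as at least one row still differs
df.where('current_value != best_value').count()  # 1 here, so iterate again
```

The loop itself looks like this: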
```python
# Init while loop
iterate = True

# Start iterating until optimization yields same result as before
while iterate:
    # Create (or overwrite) `df`
    df = optimizeAll(df2)  # Uses `df2` as input
    df.persist().count()

    # Check stopping condition
    iterate = df.where('current_value != best_value').count() > 0

    # Update `df2` with latest results
    if iterate:
        df2 = df2.join(other=df, on='id', how='left')  # <- Should I persist this?
```
The function itself runs very quickly when I call it with inputs manually. However, the time per pass grows roughly exponentially as the loop iterates: the first iteration runs in milliseconds, the second in seconds, and eventually each pass takes up to 10 minutes.
This question suggests that if `df` isn't cached, the while loop will start running from scratch on every iteration. Is this true?
If so, which objects should I persist? I know that persisting `df` is triggered by the count when defining `iterate`. However, `df2` has no action called on it, so even if I persist it, will the while loop still recompute everything from scratch on each iteration? Likewise, should I unpersist either table at some point inside the loop?
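For concreteness, this is roughly the variant I have in mind (a sketch only; the extra count on the joined DataFrame and the placement of the unpersist calls are exactly the parts I'm unsure about):

```python
iterate = True
while iterate:
    df = optimizeAll(df2).persist()  # `optimizeAll` and the initial `df2` as above
    df.count()                       # materialize `df` so the counts below reuse it

    iterate = df.where('current_value != best_value').count() > 0

    if iterate:
        # Persist the updated input and force its computation (?)
        new_df2 = df2.join(other=df, on='id', how='left').persist()
        new_df2.count()

        # Drop caches from the previous iteration, then swap in the new input
        df2.unpersist()
        df.unpersist()
        df2 = new_df2
```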