Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/migration_bench/eval/final_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,43 @@ def run_batch_eval_parallel(predictions, max_workers=8, **kwargs) -> int:
pred_data = [(pred, kwargs) for pred in predictions]

count = 0
pool = mp.Pool(max_workers)
manager = mp.Manager()
return_dict = manager.dict()

try:
results = pool.map(_process_single_prediction, pred_data)
count = sum(results)
for i in range(0, len(pred_data), max_workers):
batch = pred_data[i:i + max_workers]
processes = []

for idx, data in enumerate(batch):
def worker(data, return_dict, idx):
result = _process_single_prediction(data)
return_dict[idx] = result

p = mp.Process(target=worker, args=(data, return_dict, i + idx))
p.start()
processes.append((p, i + idx))

# Wait for each process with timeout
for p, idx in processes:
p.join(timeout=1800) # 30 minutes
if p.is_alive():
logging.warning("Process %d timed out after 30 minutes, terminating...", idx)
p.terminate()
p.join() # Clean up
return_dict[idx] = False

# Collect results from this batch
for _, idx in processes:
if idx in return_dict:
count += return_dict[idx]

except KeyboardInterrupt:
logging.info("Interrupted by user, shutting down...")
pool.terminate() # Kill immediately
pool.join() # Clean up
finally:
pool.close() # No more tasks
pool.join() # Wait for completion if not terminated
for p, _ in processes:
if p.is_alive():
p.terminate()
p.join()

logging.info(
"[batch-parallel] Final eval result: Success = %d out of %d.",
Expand Down