Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,39 @@ def get_output_path(self):
"""Returns the algo output paths to find the algo scripts and configs."""
return self.output_path

def state_dict(self) -> dict:
"""
Return state for serialization.

Returns:
A dictionary containing the BundleAlgo state to serialize.

Note:
template_path is excluded as it is determined dynamically at load time
based on which path successfully imports the Algo class.
"""
return {
"data_stats_files": self.data_stats_files,
"data_list_file": self.data_list_file,
"mlflow_tracking_uri": self.mlflow_tracking_uri,
"mlflow_experiment_name": self.mlflow_experiment_name,
"output_path": self.output_path,
"name": self.name,
"best_metric": self.best_metric,
"fill_records": self.fill_records,
"device_setting": self.device_setting,
}

def load_state_dict(self, state: dict) -> None:
"""
Restore state from a dictionary.

Args:
state: A dictionary containing the state to restore.
"""
for key, value in state.items():
setattr(self, key, value)


# path to download the algo_templates
default_algo_zip = (
Expand Down
21 changes: 14 additions & 7 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os

from monai.apps.auto3dseg.bundle_gen import BundleAlgo
from monai.auto3dseg import algo_from_pickle, algo_to_pickle
from monai.auto3dseg import algo_from_json, algo_to_json
from monai.utils.enums import AlgoKeys

__all__ = ["import_bundle_algo_history", "export_bundle_algo_history", "get_name_from_algo_id"]
Expand Down Expand Up @@ -42,11 +42,18 @@ def import_bundle_algo_history(
if not os.path.isdir(write_path):
continue

obj_filename = os.path.join(write_path, "algo_object.pkl")
if not os.path.isfile(obj_filename): # saved mode pkl
# Prefer JSON format, fall back to legacy pickle
json_filename = os.path.join(write_path, "algo_object.json")
pkl_filename = os.path.join(write_path, "algo_object.pkl")

if os.path.isfile(json_filename):
obj_filename = json_filename
elif os.path.isfile(pkl_filename):
obj_filename = pkl_filename
else:
continue

algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)
algo, algo_meta_data = algo_from_json(obj_filename, template_path=template_path)

best_metric = algo_meta_data.get(AlgoKeys.SCORE, None)
if best_metric is None:
Expand All @@ -57,7 +64,7 @@ def import_bundle_algo_history(

is_trained = best_metric is not None

if (only_trained and is_trained) or not only_trained:
if is_trained or not only_trained:
history.append(
{AlgoKeys.ID: name, AlgoKeys.ALGO: algo, AlgoKeys.SCORE: best_metric, AlgoKeys.IS_TRAINED: is_trained}
)
Expand All @@ -67,14 +74,14 @@ def import_bundle_algo_history(

def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:
"""
Save all the BundleAlgo in the history to algo_object.pkl in each individual folder
Save all the BundleAlgo in the history to algo_object.json in each individual folder.

Args:
history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
"""
for algo_dict in history:
algo = algo_dict[AlgoKeys.ALGO]
algo_to_pickle(algo, template_path=algo.template_path)
algo_to_json(algo, template_path=algo.template_path)


def get_name_from_algo_id(id: str) -> str:
Expand Down
2 changes: 2 additions & 0 deletions monai/auto3dseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from .operations import Operations, SampleOperations, SummaryOperations
from .seg_summarizer import SegSummarizer
from .utils import (
algo_from_json,
algo_from_pickle,
algo_to_json,
algo_to_pickle,
concat_multikeys_to_dict,
concat_val_to_np,
Expand Down
25 changes: 25 additions & 0 deletions monai/auto3dseg/algo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,31 @@ def get_output_path(self, *args, **kwargs):
"""Returns the algo output paths for scripts location"""
pass

def state_dict(self) -> dict:
"""
Return state for serialization.

Subclasses should override this method to return a dictionary of
attributes that need to be serialized. This follows the PyTorch
convention for state management.

Returns:
A dictionary containing the state to serialize.
"""
return {}

def load_state_dict(self, state: dict) -> None:
"""
Restore state from a dictionary.

Subclasses should override this method to restore their state
from the dictionary returned by state_dict().

Args:
state: A dictionary containing the state to restore.
"""
pass


class AlgoGen(Randomizable):
"""
Expand Down
Loading
Loading