@@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3131 let mut runtime: Option < String > = None ;
3232 let mut algorithm: Option < String > = None ;
3333 let mut task: Option < String > = None ;
34+ let mut hyperparams: Option < JsonB > = None ;
3435
3536 Spi :: connect ( |client| {
3637 let result = client
@@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
3940 data,
4041 runtime::TEXT,
4142 algorithm::TEXT,
42- task::TEXT
43+ task::TEXT,
44+ hyperparams
4345 FROM pgml.models
4446 INNER JOIN pgml.files
4547 ON models.id = files.model_id
@@ -66,6 +68,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
6668 runtime = result. get ( 2 ) . expect ( "Runtime for model is corrupted." ) ;
6769 algorithm = result. get ( 3 ) . expect ( "Algorithm for model is corrupted." ) ;
6870 task = result. get ( 4 ) . expect ( "Task for project is corrupted." ) ;
71+ hyperparams = result. get ( 5 ) . expect ( "Hyperparams for model is corrupted." ) ;
6972 }
7073 } ) ;
7174
@@ -83,6 +86,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
8386 let runtime = Runtime :: from_str ( & runtime. unwrap ( ) ) . unwrap ( ) ;
8487 let algorithm = Algorithm :: from_str ( & algorithm. unwrap ( ) ) . unwrap ( ) ;
8588 let task = Task :: from_str ( & task. unwrap ( ) ) . unwrap ( ) ;
89+ let hyperparams = hyperparams. unwrap ( ) ;
8690
8791 debug1 ! (
8892 "runtime = {:?}, algorithm = {:?}, task = {:?}" ,
@@ -94,22 +98,22 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
9498 let bindings: Box < dyn Bindings > = match runtime {
9599 Runtime :: rust => {
96100 match algorithm {
97- Algorithm :: xgboost => crate :: bindings:: xgboost:: Estimator :: from_bytes ( & data) ?,
98- Algorithm :: lightgbm => crate :: bindings:: lightgbm:: Estimator :: from_bytes ( & data) ?,
101+ Algorithm :: xgboost => crate :: bindings:: xgboost:: Estimator :: from_bytes ( & data, & hyperparams ) ?,
102+ Algorithm :: lightgbm => crate :: bindings:: lightgbm:: Estimator :: from_bytes ( & data, & hyperparams ) ?,
99103 Algorithm :: linear => match task {
100- Task :: regression => crate :: bindings:: linfa:: LinearRegression :: from_bytes ( & data) ?,
104+ Task :: regression => crate :: bindings:: linfa:: LinearRegression :: from_bytes ( & data, & hyperparams ) ?,
101105 Task :: classification => {
102- crate :: bindings:: linfa:: LogisticRegression :: from_bytes ( & data) ?
106+ crate :: bindings:: linfa:: LogisticRegression :: from_bytes ( & data, & hyperparams ) ?
103107 }
104108 _ => error ! ( "Rust runtime only supports `classification` and `regression` task types for linear algorithms." ) ,
105109 } ,
106- Algorithm :: svm => crate :: bindings:: linfa:: Svm :: from_bytes ( & data) ?,
110+ Algorithm :: svm => crate :: bindings:: linfa:: Svm :: from_bytes ( & data, & hyperparams ) ?,
107111 _ => todo ! ( ) , //smartcore_load(&data, task, algorithm, &hyperparams),
108112 }
109113 }
110114
111115 #[ cfg( feature = "python" ) ]
112- Runtime :: python => crate :: bindings:: sklearn:: Estimator :: from_bytes ( & data) ?,
116+ Runtime :: python => crate :: bindings:: sklearn:: Estimator :: from_bytes ( & data, & hyperparams ) ?,
113117
114118 #[ cfg( not( feature = "python" ) ) ]
115119 Runtime :: python => {
0 commit comments