@@ -31,6 +31,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
31
31
let mut runtime: Option < String > = None ;
32
32
let mut algorithm: Option < String > = None ;
33
33
let mut task: Option < String > = None ;
34
+ let mut hyperparams: Option < JsonB > = None ;
34
35
35
36
Spi :: connect ( |client| {
36
37
let result = client
@@ -39,7 +40,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
39
40
data,
40
41
runtime::TEXT,
41
42
algorithm::TEXT,
42
- task::TEXT
43
+ task::TEXT,
44
+ hyperparams
43
45
FROM pgml.models
44
46
INNER JOIN pgml.files
45
47
ON models.id = files.model_id
@@ -66,6 +68,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
66
68
runtime = result. get ( 2 ) . expect ( "Runtime for model is corrupted." ) ;
67
69
algorithm = result. get ( 3 ) . expect ( "Algorithm for model is corrupted." ) ;
68
70
task = result. get ( 4 ) . expect ( "Task for project is corrupted." ) ;
71
+ hyperparams = result. get ( 5 ) . expect ( "Hyperparams for model is corrupted." ) ;
69
72
}
70
73
} ) ;
71
74
@@ -83,6 +86,7 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
83
86
let runtime = Runtime :: from_str ( & runtime. unwrap ( ) ) . unwrap ( ) ;
84
87
let algorithm = Algorithm :: from_str ( & algorithm. unwrap ( ) ) . unwrap ( ) ;
85
88
let task = Task :: from_str ( & task. unwrap ( ) ) . unwrap ( ) ;
89
+ let hyperparams = hyperparams. unwrap ( ) ;
86
90
87
91
debug1 ! (
88
92
"runtime = {:?}, algorithm = {:?}, task = {:?}" ,
@@ -94,22 +98,22 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
94
98
let bindings: Box < dyn Bindings > = match runtime {
95
99
Runtime :: rust => {
96
100
match algorithm {
97
- Algorithm :: xgboost => crate :: bindings:: xgboost:: Estimator :: from_bytes ( & data) ?,
98
- Algorithm :: lightgbm => crate :: bindings:: lightgbm:: Estimator :: from_bytes ( & data) ?,
101
+ Algorithm :: xgboost => crate :: bindings:: xgboost:: Estimator :: from_bytes ( & data, & hyperparams ) ?,
102
+ Algorithm :: lightgbm => crate :: bindings:: lightgbm:: Estimator :: from_bytes ( & data, & hyperparams ) ?,
99
103
Algorithm :: linear => match task {
100
- Task :: regression => crate :: bindings:: linfa:: LinearRegression :: from_bytes ( & data) ?,
104
+ Task :: regression => crate :: bindings:: linfa:: LinearRegression :: from_bytes ( & data, & hyperparams ) ?,
101
105
Task :: classification => {
102
- crate :: bindings:: linfa:: LogisticRegression :: from_bytes ( & data) ?
106
+ crate :: bindings:: linfa:: LogisticRegression :: from_bytes ( & data, & hyperparams ) ?
103
107
}
104
108
_ => error ! ( "Rust runtime only supports `classification` and `regression` task types for linear algorithms." ) ,
105
109
} ,
106
- Algorithm :: svm => crate :: bindings:: linfa:: Svm :: from_bytes ( & data) ?,
110
+ Algorithm :: svm => crate :: bindings:: linfa:: Svm :: from_bytes ( & data, & hyperparams ) ?,
107
111
_ => todo ! ( ) , //smartcore_load(&data, task, algorithm, &hyperparams),
108
112
}
109
113
}
110
114
111
115
#[ cfg( feature = "python" ) ]
112
- Runtime :: python => crate :: bindings:: sklearn:: Estimator :: from_bytes ( & data) ?,
116
+ Runtime :: python => crate :: bindings:: sklearn:: Estimator :: from_bytes ( & data, & hyperparams ) ?,
113
117
114
118
#[ cfg( not( feature = "python" ) ) ]
115
119
Runtime :: python => {
0 commit comments