1
- use dmatrix:: DMatrix ;
2
- use error:: XGBError ;
1
+ use crate :: dmatrix:: DMatrix ;
2
+ use crate :: error:: XGBError ;
3
3
use libc;
4
4
use std:: collections:: { BTreeMap , HashMap } ;
5
5
use std:: io:: { self , BufRead , BufReader , Write } ;
@@ -13,7 +13,7 @@ use tempfile;
13
13
use xgboost_sys;
14
14
15
15
use super :: XGBResult ;
16
- use parameters:: { BoosterParameters , TrainingParameters } ;
16
+ use crate :: parameters:: { BoosterParameters , TrainingParameters } ;
17
17
18
18
pub type CustomObjective = fn ( & [ f32 ] , & DMatrix ) -> ( Vec < f32 > , Vec < f32 > ) ;
19
19
@@ -148,29 +148,8 @@ impl Booster {
148
148
dmats
149
149
} ;
150
150
151
- let mut bst = Booster :: new_with_cached_dmats ( & params. booster_params , & cached_dmats) ?;
152
- // load distributed code checkpoint from rabit
153
- let mut version = bst. load_rabit_checkpoint ( ) ?;
154
- debug ! ( "Loaded Rabit checkpoint: version={}" , version) ;
155
- assert ! ( unsafe { xgboost_sys:: RabitGetWorldSize ( ) != 1 || version == 0 } ) ;
156
- let start_iteration = version / 2 ;
157
- for i in start_iteration..params. boost_rounds as i32 {
158
- // distributed code: need to resume to this point
159
- // skip first update if a recovery step
160
- if version % 2 == 0 {
161
- if let Some ( objective_fn) = params. custom_objective_fn {
162
- debug ! ( "Boosting in round: {}" , i) ;
163
- bst. update_custom ( params. dtrain , objective_fn) ?;
164
- } else {
165
- debug ! ( "Updating in round: {}" , i) ;
166
- bst. update ( params. dtrain , i) ?;
167
- }
168
- let _ = bst. save_rabit_checkpoint ( ) ?;
169
- version += 1 ;
170
- }
171
-
172
- assert ! ( unsafe { xgboost_sys:: RabitGetWorldSize ( ) == 1 || version == xgboost_sys:: RabitVersionNumber ( ) } ) ;
173
-
151
+ let bst = Booster :: new_with_cached_dmats ( & params. booster_params , & cached_dmats) ?;
152
+ for i in 0 ..params. boost_rounds as i32 {
174
153
if let Some ( eval_sets) = params. evaluation_sets {
175
154
let mut dmat_eval_results = bst. eval_set ( eval_sets, i) ?;
176
155
@@ -203,10 +182,6 @@ impl Booster {
203
182
}
204
183
println ! ( ) ;
205
184
}
206
-
207
- // do checkpoint after evaluation, in case evaluation also updates booster.
208
- let _ = bst. save_rabit_checkpoint ( ) ;
209
- version += 1 ;
210
185
}
211
186
212
187
Ok ( bst)
@@ -365,13 +340,16 @@ impl Booster {
365
340
let mut out_len = 0 ;
366
341
let mut out = ptr:: null_mut ( ) ;
367
342
xgb_call ! ( xgboost_sys:: XGBoosterGetAttrNames ( self . handle, & mut out_len, & mut out) ) ?;
368
-
369
- let out_ptr_slice = unsafe { slice:: from_raw_parts ( out, out_len as usize ) } ;
370
- let out_vec = out_ptr_slice
371
- . iter ( )
372
- . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
373
- . collect ( ) ;
374
- Ok ( out_vec)
343
+ if out_len > 0 {
344
+ let out_ptr_slice = unsafe { slice:: from_raw_parts ( out, out_len as usize ) } ;
345
+ let out_vec = out_ptr_slice
346
+ . iter ( )
347
+ . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
348
+ . collect ( ) ;
349
+ Ok ( out_vec)
350
+ } else {
351
+ Ok ( Vec :: new ( ) )
352
+ }
375
353
}
376
354
377
355
/// Predict results for given data.
@@ -517,7 +495,7 @@ impl Booster {
517
495
Err ( err) => return Err ( XGBError :: new ( err. to_string ( ) ) ) ,
518
496
} ;
519
497
520
- let file_path = tmp_dir. path ( ) . join ( "fmap.txt " ) ;
498
+ let file_path = tmp_dir. path ( ) . join ( "fmap.json " ) ;
521
499
let mut file: File = match File :: create ( & file_path) {
522
500
Ok ( f) => f,
523
501
Err ( err) => return Err ( XGBError :: new ( err. to_string ( ) ) ) ,
@@ -551,24 +529,18 @@ impl Booster {
551
529
& mut out_dump_array
552
530
) ) ?;
553
531
554
- let out_ptr_slice = unsafe { slice:: from_raw_parts ( out_dump_array, out_len as usize ) } ;
555
- let out_vec: Vec < String > = out_ptr_slice
556
- . iter ( )
557
- . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
558
- . collect ( ) ;
532
+ if out_len > 0 {
533
+ let out_ptr_slice = unsafe { slice:: from_raw_parts ( out_dump_array, out_len as usize ) } ;
534
+ let out_vec: Vec < String > = out_ptr_slice
535
+ . iter ( )
536
+ . map ( |str_ptr| unsafe { ffi:: CStr :: from_ptr ( * str_ptr) . to_str ( ) . unwrap ( ) . to_owned ( ) } )
537
+ . collect ( ) ;
559
538
560
- assert_eq ! ( out_len as usize , out_vec. len( ) ) ;
561
- Ok ( out_vec. join ( "\n " ) )
562
- }
563
-
564
- pub ( crate ) fn load_rabit_checkpoint ( & self ) -> XGBResult < i32 > {
565
- let mut version = 0 ;
566
- xgb_call ! ( xgboost_sys:: XGBoosterLoadRabitCheckpoint ( self . handle, & mut version) ) ?;
567
- Ok ( version)
568
- }
569
-
570
- pub ( crate ) fn save_rabit_checkpoint ( & self ) -> XGBResult < ( ) > {
571
- xgb_call ! ( xgboost_sys:: XGBoosterSaveRabitCheckpoint ( self . handle) )
539
+ assert_eq ! ( out_len as usize , out_vec. len( ) ) ;
540
+ Ok ( out_vec. join ( "\n " ) )
541
+ } else {
542
+ Ok ( String :: new ( ) )
543
+ }
572
544
}
573
545
574
546
pub fn set_param ( & mut self , name : & str , value : & str ) -> XGBResult < ( ) > {
@@ -721,7 +693,7 @@ impl fmt::Display for FeatureType {
721
693
#[ cfg( test) ]
722
694
mod tests {
723
695
use super :: * ;
724
- use parameters:: { self , learning, tree} ;
696
+ use crate :: parameters:: { self , learning, tree} ;
725
697
726
698
fn read_train_matrix ( ) -> XGBResult < DMatrix > {
727
699
DMatrix :: load ( r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"# )
@@ -739,7 +711,6 @@ mod tests {
739
711
assert ! ( res. is_ok( ) ) ;
740
712
}
741
713
742
-
743
714
#[ test]
744
715
fn get_set_attr ( ) {
745
716
let mut booster = load_test_booster ( ) ;
0 commit comments