Skip to content

Commit fa91b55

Browse files
authored
Merge pull request #14 from postgresml/montana/build
update for 2021 edition
2 parents d851631 + 4f0f358 commit fa91b55

File tree

7 files changed

+51
-67
lines changed

7 files changed

+51
-67
lines changed

Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ homepage = "https://github.com/davechallis/rust-xgboost"
88
description = "Machine learning using XGBoost"
99
documentation = "https://docs.rs/xgboost"
1010
readme = "README.md"
11+
edition = "2021"
1112

1213
[dependencies]
1314
xgboost-sys = { path = "xgboost-sys" }
1415
libc = "0.2"
15-
derive_builder = "0.12"
16+
derive_builder = "0.20"
1617
log = "0.4"
17-
tempfile = "3.9"
18-
indexmap = "2.1"
18+
tempfile = "3.15"
19+
indexmap = "2.7"
1920

2021
[features]
2122
cuda = ["xgboost-sys/cuda"]

src/booster.rs

Lines changed: 28 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use dmatrix::DMatrix;
2-
use error::XGBError;
1+
use crate::dmatrix::DMatrix;
2+
use crate::error::XGBError;
33
use libc;
44
use std::collections::{BTreeMap, HashMap};
55
use std::io::{self, BufRead, BufReader, Write};
@@ -13,7 +13,7 @@ use tempfile;
1313
use xgboost_sys;
1414

1515
use super::XGBResult;
16-
use parameters::{BoosterParameters, TrainingParameters};
16+
use crate::parameters::{BoosterParameters, TrainingParameters};
1717

1818
pub type CustomObjective = fn(&[f32], &DMatrix) -> (Vec<f32>, Vec<f32>);
1919

@@ -148,29 +148,8 @@ impl Booster {
148148
dmats
149149
};
150150

151-
let mut bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
152-
// load distributed code checkpoint from rabit
153-
let mut version = bst.load_rabit_checkpoint()?;
154-
debug!("Loaded Rabit checkpoint: version={}", version);
155-
assert!(unsafe { xgboost_sys::RabitGetWorldSize() != 1 || version == 0 });
156-
let start_iteration = version / 2;
157-
for i in start_iteration..params.boost_rounds as i32 {
158-
// distributed code: need to resume to this point
159-
// skip first update if a recovery step
160-
if version % 2 == 0 {
161-
if let Some(objective_fn) = params.custom_objective_fn {
162-
debug!("Boosting in round: {}", i);
163-
bst.update_custom(params.dtrain, objective_fn)?;
164-
} else {
165-
debug!("Updating in round: {}", i);
166-
bst.update(params.dtrain, i)?;
167-
}
168-
let _ = bst.save_rabit_checkpoint()?;
169-
version += 1;
170-
}
171-
172-
assert!(unsafe { xgboost_sys::RabitGetWorldSize() == 1 || version == xgboost_sys::RabitVersionNumber() });
173-
151+
let bst = Booster::new_with_cached_dmats(&params.booster_params, &cached_dmats)?;
152+
for i in 0..params.boost_rounds as i32 {
174153
if let Some(eval_sets) = params.evaluation_sets {
175154
let mut dmat_eval_results = bst.eval_set(eval_sets, i)?;
176155

@@ -203,10 +182,6 @@ impl Booster {
203182
}
204183
println!();
205184
}
206-
207-
// do checkpoint after evaluation, in case evaluation also updates booster.
208-
let _ = bst.save_rabit_checkpoint();
209-
version += 1;
210185
}
211186

212187
Ok(bst)
@@ -365,13 +340,16 @@ impl Booster {
365340
let mut out_len = 0;
366341
let mut out = ptr::null_mut();
367342
xgb_call!(xgboost_sys::XGBoosterGetAttrNames(self.handle, &mut out_len, &mut out))?;
368-
369-
let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) };
370-
let out_vec = out_ptr_slice
371-
.iter()
372-
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
373-
.collect();
374-
Ok(out_vec)
343+
if out_len > 0 {
344+
let out_ptr_slice = unsafe { slice::from_raw_parts(out, out_len as usize) };
345+
let out_vec = out_ptr_slice
346+
.iter()
347+
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
348+
.collect();
349+
Ok(out_vec)
350+
} else {
351+
Ok(Vec::new())
352+
}
375353
}
376354

377355
/// Predict results for given data.
@@ -517,7 +495,7 @@ impl Booster {
517495
Err(err) => return Err(XGBError::new(err.to_string())),
518496
};
519497

520-
let file_path = tmp_dir.path().join("fmap.txt");
498+
let file_path = tmp_dir.path().join("fmap.json");
521499
let mut file: File = match File::create(&file_path) {
522500
Ok(f) => f,
523501
Err(err) => return Err(XGBError::new(err.to_string())),
@@ -551,24 +529,18 @@ impl Booster {
551529
&mut out_dump_array
552530
))?;
553531

554-
let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) };
555-
let out_vec: Vec<String> = out_ptr_slice
556-
.iter()
557-
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
558-
.collect();
532+
if out_len > 0 {
533+
let out_ptr_slice = unsafe { slice::from_raw_parts(out_dump_array, out_len as usize) };
534+
let out_vec: Vec<String> = out_ptr_slice
535+
.iter()
536+
.map(|str_ptr| unsafe { ffi::CStr::from_ptr(*str_ptr).to_str().unwrap().to_owned() })
537+
.collect();
559538

560-
assert_eq!(out_len as usize, out_vec.len());
561-
Ok(out_vec.join("\n"))
562-
}
563-
564-
pub(crate) fn load_rabit_checkpoint(&self) -> XGBResult<i32> {
565-
let mut version = 0;
566-
xgb_call!(xgboost_sys::XGBoosterLoadRabitCheckpoint(self.handle, &mut version))?;
567-
Ok(version)
568-
}
569-
570-
pub(crate) fn save_rabit_checkpoint(&self) -> XGBResult<()> {
571-
xgb_call!(xgboost_sys::XGBoosterSaveRabitCheckpoint(self.handle))
539+
assert_eq!(out_len as usize, out_vec.len());
540+
Ok(out_vec.join("\n"))
541+
} else {
542+
Ok(String::new())
543+
}
572544
}
573545

574546
pub fn set_param(&mut self, name: &str, value: &str) -> XGBResult<()> {
@@ -721,7 +693,7 @@ impl fmt::Display for FeatureType {
721693
#[cfg(test)]
722694
mod tests {
723695
use super::*;
724-
use parameters::{self, learning, tree};
696+
use crate::parameters::{self, learning, tree};
725697

726698
fn read_train_matrix() -> XGBResult<DMatrix> {
727699
DMatrix::load(r#"{"uri": "xgboost-sys/xgboost/demo/data/agaricus.txt.train?format=libsvm"}"#)
@@ -739,7 +711,6 @@ mod tests {
739711
assert!(res.is_ok());
740712
}
741713

742-
743714
#[test]
744715
fn get_set_attr() {
745716
let mut booster = load_test_booster();

src/dmatrix.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,11 @@ impl DMatrix {
314314
&mut out_dptr
315315
))?;
316316

317-
Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
317+
if out_len > 0 {
318+
Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_float, out_len as usize) })
319+
} else {
320+
Err(XGBError::new("error"))
321+
}
318322
}
319323

320324
fn set_float_info(&mut self, field: &str, array: &[f32]) -> XGBResult<()> {
File renamed without changes.

xgboost-sys/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ license = "MIT"
88
repository = "https://github.com/davechallis/rust-xgboost"
99
description = "Native bindings to the xgboost library"
1010
readme = "README.md"
11+
edition = "2021"
1112

1213
[dependencies]
1314
libc = "0.2"
1415

1516
[build-dependencies]
16-
bindgen = "0.69"
17+
bindgen = "0.71"
1718
cmake = "0.1"
1819

1920
[features]

xgboost-sys/build.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,21 @@ fn main() {
2525
dst.define("BUILD_STATIC_LIB", "ON").define("CMAKE_CXX_STANDARD", "17");
2626

2727
// CMake
28+
let mut dst = Config::new(&xgb_root);
29+
let mut dst = dst.define("BUILD_STATIC_LIB", "ON");
30+
2831
#[cfg(feature = "cuda")]
29-
dst.define("USE_CUDA", "ON")
32+
let mut dst = dst
33+
.define("USE_CUDA", "ON")
3034
.define("BUILD_WITH_CUDA", "ON")
3135
.define("BUILD_WITH_CUDA_CUB", "ON");
3236

3337
#[cfg(target_os = "macos")]
3438
{
3539
let path = PathBuf::from("/opt/homebrew/"); // check for m1 vs intel config
3640
if let Ok(_dir) = std::fs::read_dir(&path) {
37-
dst.define("CMAKE_C_COMPILER", "/opt/homebrew/opt/llvm/bin/clang")
41+
dst = dst
42+
.define("CMAKE_C_COMPILER", "/opt/homebrew/opt/llvm/bin/clang")
3843
.define("CMAKE_CXX_COMPILER", "/opt/homebrew/opt/llvm/bin/clang++")
3944
.define("OPENMP_LIBRARIES", "/opt/homebrew/opt/llvm/lib")
4045
.define("OPENMP_INCLUDES", "/opt/homebrew/opt/llvm/include");
@@ -54,9 +59,11 @@ fn main() {
5459

5560
#[cfg(feature = "cuda")]
5661
let bindings = bindings.clang_arg("-I/usr/local/cuda/include");
57-
let bindings = bindings.generate().expect("Unable to generate bindings.");
62+
let bindings = bindings
63+
.generate()
64+
.expect("Unable to generate bindings.");
5865

59-
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
66+
let out_path = PathBuf::from(out_dir);
6067
bindings
6168
.write_to_file(out_path.join("bindings.rs"))
6269
.expect("Couldn't write bindings.");

xgboost-sys/xgboost

Submodule xgboost updated 797 files

0 commit comments

Comments
 (0)