Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
}
}

int XGBRegisterLogCallback(void (*callback)(const char*)) {
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback);
API_END();
}

int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {
XGB_DLL int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {
API_BEGIN();
bool load_row_split = false;
if (rabit::IsDistributed()) {
Expand All @@ -60,7 +60,7 @@ int XGDMatrixCreateFromFile(const char *fname,
}

XGB_DLL int XGDMatrixCreateFromDataIter(
void *data_handle, // a Java interator
void *data_handle, // a Java iterator
XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp
const char *cache_info, DMatrixHandle *out) {
API_BEGIN();
Expand All @@ -69,7 +69,8 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
if (cache_info != nullptr) {
scache = cache_info;
}
xgboost::data::IteratorAdapter adapter(data_handle, callback);
xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> adapter(data_handle, callback);
*out = new std::shared_ptr<DMatrix> {
DMatrix::Create(
&adapter, std::numeric_limits<float>::quiet_NaN(),
Expand Down
3 changes: 2 additions & 1 deletion src/c_api/c_api_error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* \brief C error handling
*/
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "./c_api_error.h"

struct XGBAPIErrorEntry {
Expand All @@ -12,7 +13,7 @@ struct XGBAPIErrorEntry {

using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;

const char *XGBGetLastError() {
XGB_DLL const char *XGBGetLastError() {
return XGBAPIErrorStore::Get()->last_error.c_str();
}

Expand Down
1 change: 0 additions & 1 deletion src/c_api/c_api_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <xgboost/c_api.h>

/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
Expand Down
3 changes: 2 additions & 1 deletion src/data/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/span.h"
#include "xgboost/c_api.h"

#include "array_interface.h"
#include "../c_api/c_api_error.h"

namespace xgboost {
Expand Down Expand Up @@ -496,6 +496,7 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {

/*! \brief Data iterator that takes callback to return data, used in JVM package for
* accepting data iterator. */
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
public:
IteratorAdapter(DataIterHandle data_handle,
Expand Down
11 changes: 7 additions & 4 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "dmlc/io.h"
#include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/logging.h"
#include "xgboost/version_config.h"
Expand Down Expand Up @@ -533,7 +534,7 @@ DMatrix* DMatrix::Load(const std::string& uri,

template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size ) {
const std::string& cache_prefix, size_t page_size) {
if (cache_prefix.length() == 0) {
// Data split mode is fixed to be row right now.
return new data::SimpleDMatrix(adapter, missing, nthread);
Expand Down Expand Up @@ -563,9 +564,11 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::IteratorAdapter>(
data::IteratorAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix *
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> *adapter,
float missing, int nthread, const std::string &cache_prefix,
size_t page_size);

SparsePage SparsePage::GetTranspose(int num_columns) const {
SparsePage transpose;
Expand Down
2 changes: 1 addition & 1 deletion src/data/data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "array_interface.h"
#include "../common/device_helpers.cuh"
#include "device_adapter.cuh"
#include "device_dmatrix.h"
#include "simple_dmatrix.h"

namespace xgboost {

Expand Down
16 changes: 12 additions & 4 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
* \brief the input data structure for gradient boosting
* \author Tianqi Chen
*/
#include "./simple_dmatrix.h"
#include <xgboost/data.h>
#include <vector>
#include <limits>
#include <algorithm>

#include "xgboost/data.h"
#include "xgboost/c_api.h"

#include "simple_dmatrix.h"
#include "./simple_batch_iterator.h"
#include "../common/random.h"
#include "adapter.h"
Expand Down Expand Up @@ -195,7 +201,9 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(IteratorAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
*adapter,
float missing, int nthread);
} // namespace data
} // namespace xgboost
22 changes: 11 additions & 11 deletions tests/cpp/data/test_adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
}

// A mock for JVM data iterator.
class DataIterForTest {
class CSRIterForTest {
std::vector<float> data_ {1, 2, 3, 4, 5};
std::vector<std::remove_pointer<decltype(std::declval<XGBoostBatchCSR>().index)>::type>
feature_idx_ {0, 1, 0, 1, 1};
Expand Down Expand Up @@ -100,16 +100,16 @@ class DataIterForTest {
size_t Iter() const { return iter_; }
};

size_t constexpr DataIterForTest::kCols;
size_t constexpr CSRIterForTest::kCols;

int SetDataNextForTest(DataIterHandle data_handle,
XGBCallbackSetData *set_function,
DataHolderHandle set_function_handle) {
int CSRSetDataNextForTest(DataIterHandle data_handle,
XGBCallbackSetData *set_function,
DataHolderHandle set_function_handle) {
size_t constexpr kIters { 2 };
auto iter = static_cast<DataIterForTest *>(data_handle);
auto iter = static_cast<CSRIterForTest *>(data_handle);
if (iter->Iter() < kIters) {
auto batch = iter->Next();
batch.columns = DataIterForTest::kCols;
batch.columns = CSRIterForTest::kCols;
set_function(set_function_handle, batch);
return 1;
} else {
Expand All @@ -118,15 +118,15 @@ int SetDataNextForTest(DataIterHandle data_handle,
}

TEST(Adapter, IteratorAdaper) {
DataIterForTest iter;
data::IteratorAdapter adapter{&iter, SetDataNextForTest};
CSRIterForTest iter;
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> adapter{&iter, CSRSetDataNextForTest};
constexpr size_t kRows { 6 };

std::unique_ptr<DMatrix> data {
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)
};
ASSERT_EQ(data->Info().num_col_, DataIterForTest::kCols);
ASSERT_EQ(data->Info().num_col_, CSRIterForTest::kCols);
ASSERT_EQ(data->Info().num_row_, kRows);
}

} // namespace xgboost