Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -253,22 +253,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
<< "Weight size should equal to number of groups.";
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
auto it =
thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx + base_rowid) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
d_temp_weights[idx] = weights[group];
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid);
d_temp_weights[idx] = weights[group_idx];
});
} else {
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}
Expand Down
107 changes: 28 additions & 79 deletions src/common/hist_util.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,8 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
thrust::make_constant_iterator(0lu),
[=]__device__(size_t idx) -> float {
auto ridx = batch.GetElement(idx).row_idx;
auto it = thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
return weights[group];
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
return weights[group_idx];
});
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
weight_iter + begin, weight_iter + end,
Expand Down Expand Up @@ -277,98 +274,50 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}

template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
float missing,
size_t sketch_batch_num_elements = 0) {
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows());
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);

adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
adapter->NumRows(), adapter->NumColumns(), std::numeric_limits<size_t>::max(),
adapter->DeviceIdx(),
num_cuts_per_feature, false);

// Enforce single batch
CHECK(!adapter->Next());

HistogramCuts cuts;
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
adapter->NumRows(), adapter->DeviceIdx());

for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
auto const& batch = adapter->Value();
ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(),
begin, end, missing, &sketch_container, num_cuts_per_feature);
}

sketch_container.MakeCuts(&cuts);
return cuts;
}

/*
* \brief Perform sketching on GPU.
*
* \param batch A batch from adapter.
* \param num_bins Bins per column.
* \param info Metainfo used for sketching.
* \param missing Floating point value that represents invalid value.
* \param sketch_container Container for output sketch.
* \param sketch_batch_num_elements Number of element per-sliding window, use it only for
* testing.
*/
template <typename Batch>
void AdapterDeviceSketch(Batch batch, int num_bins,
MetaInfo const& info,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols();
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, device, num_cols,
begin, end, missing, sketch_container, num_cuts_per_feature);
}
}

/*
* \brief Perform weighted sketching on GPU.
*
* When weight in info is empty, this function is equivalent to unweighted version.
*/
template <typename Batch>
void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
MetaInfo const& info,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
if (info.weights_.Size() == 0) {
return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements);
}

size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols();
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info,
num_cuts_per_feature,
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container);
bool weighted = info.weights_.Size() != 0;

if (weighted) {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info,
num_cuts_per_feature,
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container);
}
} else {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, device, num_cols,
begin, end, missing, sketch_container, num_cuts_per_feature);
}
}
}
} // namespace common
Expand Down
17 changes: 7 additions & 10 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,13 @@ class CutsBuilder {
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
virtual ~CutsBuilder() = default;

static uint32_t SearchGroupIndFromRow(
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
using KIt = std::vector<bst_uint>::const_iterator;
KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
// Cannot use CHECK_NE because it will try to print the iterator.
bool const found = res != group_ptr.cend() - 1;
if (!found) {
LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!";
}
uint32_t group_ind = std::distance(group_ptr.cbegin(), res);
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
size_t const base_rowid) {
CHECK_LT(base_rowid, group_ptr.back())
<< "Row: " << base_rowid << " is not found in any group.";
auto it =
std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
bst_group_t group_ind = it - group_ptr.cbegin() - 1;
return group_ind;
}

Expand Down
24 changes: 0 additions & 24 deletions src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,30 +486,6 @@ class QuantileSketchTemplate {
this->data = dmlc::BeginPtr(space);
}
}
/*!
* \brief set the space to be merge of all Summary arrays
* \param begin beginning position in the summary array
* \param end ending position in the Summary array
*/
inline void SetMerge(const Summary *begin,
const Summary *end) {
CHECK(begin < end) << "can not set combine to empty instance";
size_t len = end - begin;
if (len == 1) {
this->Reserve(begin[0].size);
this->CopyFrom(begin[0]);
} else if (len == 2) {
this->Reserve(begin[0].size + begin[1].size);
this->SetMerge(begin[0], begin[1]);
} else {
// recursive merge
SummaryContainer lhs, rhs;
lhs.SetCombine(begin, begin + len / 2);
rhs.SetCombine(begin + len / 2, end);
this->Reserve(lhs.size + rhs.size);
this->SetCombine(lhs, rhs);
}
}
/*!
* \brief do elementwise combination of summary array
* this[i] = combine(this[i], src[i]) for each i
Expand Down
25 changes: 0 additions & 25 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -228,31 +228,6 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx,
});
}

template <typename AdapterT>
EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
int max_bin, common::Span<size_t> row_counts_span,
size_t row_stride) {
common::HistogramCuts cuts =
common::AdapterDeviceSketch(adapter, max_bin, missing);
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
auto& batch = adapter->Value();

*this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
adapter->NumRows());
CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
}

#define ELLPACK_SPECIALIZATION(__ADAPTER_T) \
template EllpackPageImpl::EllpackPageImpl( \
__ADAPTER_T* adapter, float missing, bool is_dense, int nthread, int max_bin, \
common::Span<size_t> row_counts_span, \
size_t row_stride);

ELLPACK_SPECIALIZATION(data::CudfAdapter)
ELLPACK_SPECIALIZATION(data::CupyAdapter)


template <typename AdapterBatch>
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
bool is_dense, int nthread,
Expand Down
6 changes: 0 additions & 6 deletions src/data/ellpack_page.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,6 @@ class EllpackPageImpl {
*/
explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);

template <typename AdapterT>
explicit EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense, int nthread,
int max_bin,
common::Span<size_t> row_counts_span,
size_t row_stride);

template <typename AdapterBatch>
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
common::Span<size_t> row_counts_span,
Expand Down
4 changes: 2 additions & 2 deletions src/data/iterative_device_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
auto* p_sketch = &sketch_containers.back();
proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
common::AdapterDeviceSketch(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
});

auto batch_rows = num_rows();
Expand Down
7 changes: 6 additions & 1 deletion tests/cpp/common/test_hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ TEST(CutsBuilder, SearchGroupInd) {
group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
ASSERT_EQ(group_ind, 2);

EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
p_mat->Info().Validate(-1);
EXPECT_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17),
dmlc::Error);

std::vector<bst_uint> group_ptr {0, 1, 2};
CHECK_EQ(CutsBuilder::SearchGroupIndFromRow(group_ptr, 1), 1);
}

TEST(SparseCuts, SingleThreadedBuild) {
Expand Down
Loading