Skip to content

Commit 856bebf

Browse files
author
Colin Lin
committed
Update Segment Tree
1 parent 3ac3673 commit 856bebf

File tree

6 files changed

+111
-77
lines changed

6 files changed

+111
-77
lines changed

Algon/dstruct/SegmentTree.h

Lines changed: 78 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,96 @@
11
#pragma once
22
#include <functional>
3-
#include <iterator>
43
#include <vector>
54
#include "../utility/Exception.h"
65
#include "../common.h"
7-
#include "../utility/Util.h"
86
namespace colinli {
97
namespace algon {
108

11-
using std::vector;
9+
using std::vector;
1210

13-
template<typename TValue, typename TKey>
14-
class SegmentTree
15-
{
16-
public:
17-
typedef std::function<TValue(TKey)> Transformer;
18-
typedef std::function<TValue(TValue, TValue)> Reducer;
19-
SegmentTree(vector<TKey>& collection, Reducer reducer, Transformer trans) {
20-
vector_ = collection;
21-
reducer_ = reducer;
22-
trans_ = trans;
23-
root_ = build_tree(0, collection.size(), reducer, trans);
24-
}
25-
SegmentTree(vector<TKey>&& collection, Reducer reducer, Transformer trans) {
26-
vector_ = collection;
27-
reducer_ = reducer;
28-
trans_ = trans;
29-
root_ = build_tree(0, collection.size(), reducer, trans);
30-
}
11+
#define LCHILD(i) (((i)<<1)+1)
12+
#define RCHILD(i) (((i)<<1)+2)
3113

32-
struct Node
33-
{
34-
Node* left;
35-
Node* right;
36-
TValue value;
37-
explicit Node(TValue v) :value(v), left(nullptr), right(nullptr){}
38-
explicit Node() :left(nullptr), right(nullptr), value(TValue()){}
39-
};
40-
private:
41-
Node* root_;
42-
vector<TKey> vector_;
43-
Transformer trans_;
44-
Reducer reducer_;
14+
template <typename TValue, typename TKey>
15+
class SegmentTree
16+
{
17+
public:
18+
typedef std::function<TValue(TKey)> Transformer;
19+
typedef std::function<TValue(TValue, TValue)> Reducer;
4520

46-
Node* build_tree(size_t begin,size_t end) {
47-
if (begin == end) {
48-
return NULL;
49-
}
50-
auto mid = begin + (end - begin) / 2;
51-
Node* node = new Node(trans_(vector_.at(mid)));
52-
node->left = build_tree(begin, mid);
53-
node->right = build_tree( std::next(mid), end);
54-
if (node->left && node->right) {
55-
node->value = reducer_(node->left->value, node->right->value);
56-
}
57-
return node;
58-
}
21+
SegmentTree(vector<TKey>& collection, Reducer reducer, Transformer trans, TValue initVal):
22+
vector_(collection),
23+
N(collection.size()),
24+
trans_(trans),
25+
reducer_(reducer),
26+
initVal_(initVal) {
27+
node_values_ = new TValue[(N << 2) + 1];
28+
build_tree(0, 0, N-1);
29+
}
30+
31+
void Update(size_t pos, TKey newkey) {
32+
update(0, 0, N - 1, pos, newkey);
33+
}
34+
35+
TValue Query(size_t begin, size_t end) {
36+
return query(0, 0, N - 1, begin, end);
37+
}
38+
39+
virtual ~SegmentTree() {
40+
delete[] node_values_;
41+
}
42+
43+
private:
44+
vector<TKey> vector_;
45+
size_t N;
46+
Transformer trans_;
47+
Reducer reducer_;
48+
TValue initVal_;
49+
TValue* node_values_;
5950

60-
void update(Node* node ,size_t begin, size_t end, size_t pos, TKey newkey) {
61-
if (begin == end || node == NULL) {
62-
return;
63-
}
64-
if (++begin == end) {
65-
node->value = trans_(newkey);
66-
}
67-
else {
68-
auto mid = begin + (begin - end) / 2;
69-
if ( pos <= mid ) {
70-
update(node->left, begin, mid, pos, newkey);
71-
update(node->right, mid+1 , end, pos, newkey);
72-
CHECK_THROW(node->left && node->right, NullArgumentError);
73-
node->value = reducer_(node->left->value, node->right->value);
74-
}
75-
}
51+
// build substree from x[begin...end] (inclusive range)
52+
void build_tree(size_t node_idx, size_t begin, size_t end) {
53+
//leaf node
54+
if (begin == end) {
55+
node_values_[node_idx] = trans_(vector_.at(begin));
56+
return;
7657
}
77-
};
58+
auto mid = begin + (end - begin) / 2;
59+
build_tree(LCHILD(node_idx), begin, mid);
60+
build_tree(RCHILD(node_idx), mid + 1, end);
61+
node_values_[node_idx] = reducer_(node_values_[LCHILD(node_idx)], node_values_[RCHILD(node_idx)]);
62+
}
7863

64+
void update(size_t node_idx, size_t begin, size_t end, size_t pos, TKey newkey) {
65+
if (begin == end) {
66+
node_values_[node_idx] = trans_(newkey);
67+
return;
68+
}
69+
auto mid = begin + (end - begin) / 2;
70+
if (pos <= mid) {
71+
update(LCHILD(node_idx), begin, mid, pos, newkey);
72+
}
73+
else {
74+
update(RCHILD(node_idx), mid + 1, end, pos, newkey);
75+
}
76+
node_values_[node_idx] = reducer_(node_values_[LCHILD(node_idx)], node_values_[RCHILD(node_idx)]);
77+
}
7978

79+
TValue query(size_t node_idx, size_t node_begin, size_t node_end, size_t q_begin, size_t q_end) {
80+
if (node_begin >= q_begin && node_end <= q_end) {
81+
return node_values_[node_idx];
82+
}
83+
auto mid = node_begin + (node_end - node_begin) / 2;
84+
TValue result = initVal_;
85+
if (mid > q_begin && node_begin <= q_end) {
86+
result = reducer_(result, query(LCHILD(node_idx), node_begin, mid, q_begin, q_end));
87+
}
88+
if (node_end >= q_begin && mid + 1 <= q_end) {
89+
result = reducer_(result, query(RCHILD(node_idx), mid + 1, node_end, q_begin, q_end));
90+
}
91+
return result;
92+
}
93+
};
8094

8195
}// end of algon ns
8296
}// end of colinli ns

Algon/utility/Util.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,6 @@ namespace colinli {
1515
uint32_t sdigits10(int64_t v);
1616
int ll2string(char *dst, size_t dstlen, long long svalue);
1717

18-
19-
/// <summary>
20-
/// Identity functor
21-
/// </summary>
22-
struct identity {
23-
template<typename U>
24-
constexpr auto operator()(U&& v) const noexcept
25-
-> decltype(std::forward<U>(v))
26-
{
27-
return std::forward<U>(v);
28-
}
29-
};
3018

3119

3220
/// <summary>

tests/SegmentTreeTest.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "stdafx.h"
2+
#include "CppUnitTest.h"
3+
4+
using namespace Microsoft::VisualStudio::CppUnitTestFramework;
5+
using namespace colinli::algon;
6+
using namespace std;
7+
namespace tests
8+
{
9+
TEST_CLASS(SegmentTreeTest)
10+
{
11+
public:
12+
13+
TEST_METHOD(TestQueryTree)
14+
{
15+
auto min_reducer = [](int x, int y) throw() { return x < y ? x : y; };
16+
auto identity = [](int x) {return x; };
17+
int initv = INT_MAX;
18+
vector<int> x{ 1, 2, 3, 4, 5, 6, 7 };
19+
SegmentTree<int, int> tree(x, min_reducer, identity, initv);
20+
Assert::AreEqual(3, tree.Query(2, 5));
21+
tree.Update(3, -1);
22+
Assert::AreEqual(-1, tree.Query(1, 6));
23+
}
24+
25+
};
26+
}

tests/stdafx.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
// Headers for CppUnitTest
1111
#include "CppUnitTest.h"
1212
#include <string>
13+
#include <vector>
1314
// TODO: reference additional headers your program requires here
1415
#include "../Algon/dstruct/SkipList.h"
1516
#include "../Algon/dstruct/DisjointSet.h"
16-
#include "../Algon/algorithm/Numeric.h"
17+
#include "../Algon/algorithm/Numeric.h"
18+
#include "../Algon/dstruct/SegmentTree.h"

tests/tests.vcxproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
<ItemGroup>
8989
<ClCompile Include="DisjointSetTest.cpp" />
9090
<ClCompile Include="NumericTest.cpp" />
91+
<ClCompile Include="SegmentTreeTest.cpp" />
9192
<ClCompile Include="stdafx.cpp">
9293
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
9394
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader>

tests/tests.vcxproj.filters

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,8 @@
3535
<ClCompile Include="NumericTest.cpp">
3636
<Filter>Source Files</Filter>
3737
</ClCompile>
38+
<ClCompile Include="SegmentTreeTest.cpp">
39+
<Filter>Source Files</Filter>
40+
</ClCompile>
3841
</ItemGroup>
3942
</Project>

0 commit comments

Comments
 (0)