Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,10 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
kResultNullIfNull, "gdv_mask_last_n_utf8_int32",
NativeFunction::kNeedsContext),

NativeFunction("find_in_set", {}, DataTypeVector{utf8(), utf8()}, int32(),
kResultNullIfNull, "find_in_set_utf8_utf8",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("instr", {}, DataTypeVector{utf8(), utf8()}, int32(),
kResultNullIfNull, "instr_utf8"),

Expand Down
41 changes: 41 additions & 0 deletions cpp/src/gandiva/precompiled/string_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3034,4 +3034,45 @@ int32_t instr_utf8(const char* string, int32_t string_len, const char* substring
}
return 0;
}

// Locates `to_find` inside `string_list`, a comma-separated list of entries.
//
// Parameters:
//   context          - execution context handle; not used by this implementation.
//   to_find          - bytes of the entry to look for (not NUL-terminated).
//   to_find_len      - length of `to_find` in bytes.
//   string_list      - bytes of the comma-separated list (not NUL-terminated).
//   string_list_len  - length of `string_list` in bytes.
//
// Returns the 1-based position of the first list entry that is byte-for-byte
// equal to `to_find`, or 0 when no entry matches or when `to_find` itself
// contains a comma. An empty `to_find` matches the first empty entry, so an
// empty list yields 1.
FORCE_INLINE
int32_t find_in_set_utf8_utf8(int64_t context, const char* to_find, int32_t to_find_len,
                              const char* string_list, int32_t string_list_len) {
  // An entry containing a comma can never equal a single list element. ',' is a
  // one-byte code point and never appears inside a multi-byte UTF-8 sequence, so
  // a plain byte search is sufficient. The length guard keeps memchr away from
  // a possibly null pointer when the entry is empty.
  if (to_find_len > 0 &&
      memchr(to_find, ',', static_cast<size_t>(to_find_len)) != nullptr) {
    return 0;
  }

  int32_t cur_pos_in_array = 0;  // number of list entries fully scanned so far
  int32_t cur_length = 0;        // bytes consumed from the current entry
  bool matching = true;          // current entry still equals the to_find prefix

  for (int32_t i = 0; i < string_list_len; i++) {
    if (string_list[i] == ',') {
      // End of an entry: it matches only if every byte agreed and the lengths
      // are equal.
      cur_pos_in_array++;
      if (matching && cur_length == to_find_len) {
        return cur_pos_in_array;
      }
      matching = true;
      cur_length = 0;
    } else {
      // Never index to_find at or beyond to_find_len: an entry longer than
      // to_find cannot match, and reading further would leave the buffer.
      if (cur_length >= to_find_len || string_list[i] != to_find[cur_length]) {
        matching = false;
      }
      cur_length++;
    }
  }

  // The last entry has no trailing comma, so it is checked after the loop.
  if (matching && cur_length == to_find_len) {
    return cur_pos_in_array + 1;
  }
  return 0;
}
} // extern "C"
26 changes: 26 additions & 0 deletions cpp/src/gandiva/precompiled/string_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2702,4 +2702,30 @@ TEST(TestStringOps, TestInstr) {
result = instr_utf8(s1.c_str(), s1_len, s2.c_str(), s2_len);
EXPECT_EQ(result, 8);
}

// Calls find_in_set_utf8_utf8 directly with a table of inputs and checks the
// 1-based position it reports (0 meaning "not found").
TEST(TestStringOps, TestFindInSet) {
  gandiva::ExecutionContext ctx;
  auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  // One row per call: entry to look up, comma-separated list, expected result.
  struct FindInSetCase {
    const char* to_find;
    int32_t to_find_len;
    const char* string_list;
    int32_t string_list_len;
    int32_t expected;
  };

  const FindInSetCase cases[] = {
      {"EE", 2, ",A,B,C,D,EE,F", 13, 6},
      {"A", 1, "A,B,C,D,EE,F", 12, 1},
      {"AAAB", 4, "A,B,C,D,EE,F", 12, 0},
      // An entry that contains a comma is never found.
      {"E,E", 3, "A,B,C,D,EE,F", 12, 0},
      // Empty list elements still count towards the position.
      {"C", 1, "A,B,,,,,,,C,,,,,", 16, 9},
      // Empty entry and/or empty list.
      {"", 0, "", 0, 1},
      {"", 0, " ", 1, 0},
      {" ", 1, "", 0, 0},
      {"", 0, "a,b,,c,d", 8, 3},
      {"", 0, ",", 1, 1},
  };

  for (const auto& tc : cases) {
    int32_t result = find_in_set_utf8_utf8(ctx_ptr, tc.to_find, tc.to_find_len,
                                           tc.string_list, tc.string_list_len);
    EXPECT_EQ(result, tc.expected);
  }
}
} // namespace gandiva
3 changes: 3 additions & 0 deletions cpp/src/gandiva/precompiled/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -829,4 +829,7 @@ const char* elt_int32_utf8_utf8_utf8_utf8_utf8(
int32_t instr_utf8(const char* string, int32_t string_len, const char* substring,
int32_t substring_len);

// Returns the 1-based position of `to_find` among the comma-separated entries of
// `string_list`, or 0 when no entry matches or `to_find` contains a comma.
// Lengths are in bytes. Defined in string_ops.cc.
int32_t find_in_set_utf8_utf8(int64_t context, const char* to_find, int32_t to_find_len,
const char* string_list, int32_t string_list_len);

} // extern "C"
42 changes: 42 additions & 0 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2824,6 +2824,48 @@ TEST_F(TestProjector, TestInstr) {
// Validate results
EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0));
}
// Projects find_in_set(f0, f1) over an 8-row batch of utf8 columns, including
// multi-byte values, and compares the int32 output with the expected positions.
TEST_F(TestProjector, TestFindInSet) {
  // Input schema: f0 holds the entry to look for, f1 the comma-separated list.
  auto entry_field = field("f0", arrow::utf8());
  auto list_field = field("f1", arrow::utf8());
  auto schema = arrow::schema({entry_field, list_field});

  // Expression under test, producing an int32 column.
  auto result_field = field("find_in_set_output", int32());
  auto find_in_set_expr = TreeExprBuilder::MakeExpression(
      "find_in_set", {entry_field, list_field}, result_field);

  std::shared_ptr<Projector> projector;
  auto status =
      Projector::Make(schema, {find_in_set_expr}, TestConfiguration(), &projector);
  EXPECT_TRUE(status.ok());

  // Sample data: every row is valid (no nulls).
  int num_records = 8;
  auto entries =
      MakeArrowArrayUtf8({"ABC", "...", "!C", "MORE", "学路", "b大", "路", "学路"},
                         {true, true, true, true, true, true, true, true});
  auto lists = MakeArrowArrayUtf8(
      {"ZXL,KMY,DDD,ABC", "!!!,@@@,###,...,,,", ",A,,,,,,,,!C,,,,,", "MORE",
       "学路,学路,学路,123", "大b,,,b大", "大b,,学路,学,b大", "学路"},
      {true, true, true, true, true, true, true, true});

  // Expected 1-based positions; 0 means the entry is not in the list.
  auto expected = MakeArrowArrayInt32({4, 4, 10, 1, 1, 4, 0, 1},
                                      {true, true, true, true, true, true, true, true});

  // Assemble the input batch and run the projector.
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {entries, lists});

  arrow::ArrayVector outputs;
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  EXPECT_TRUE(status.ok());

  // The single output column must equal the expected positions.
  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs.at(0));
}

TEST_F(TestProjector, TestNextDay) {
// schema for input fields
Expand Down