Skip to content

Commit 7e8cbec

Browse files
committed
Fix replace SparkSQL function (4922)
1 parent 8c9af76 commit 7e8cbec

File tree

5 files changed

+62
-8
lines changed

5 files changed

+62
-8
lines changed

velox/functions/lib/string/StringCore.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ inline int64_t findNthInstanceByteIndexFromEnd(
299299
/// each charecter. When inputString is empty results is empty.
300300
/// replace("", "", "x") = ""
301301
/// replace("aa", "", "x") = "xaxax"
302+
template <bool ignoreEmptyReplaced = false>
302303
inline static size_t replace(
303304
char* outputString,
304305
const std::string_view& inputString,
@@ -309,6 +310,13 @@ inline static size_t replace(
309310
return 0;
310311
}
311312

313+
if (ignoreEmptyReplaced && replaced.size() == 0) {
314+
if (!inPlace) {
315+
std::memcpy(outputString, inputString.data(), inputString.size());
316+
}
317+
return inputString.size();
318+
}
319+
312320
size_t readPosition = 0;
313321
size_t writePosition = 0;
314322
// Copy needed in out of place replace, and when replaced and replacement are

velox/functions/lib/string/StringImpl.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ stringPosition(const T& string, const T& subString, int64_t instance = 0) {
209209

210210
/// Replace replaced with replacement in inputString and write results to
211211
/// outputString.
212-
template <typename TOutString, typename TInString>
212+
template <
213+
typename TOutString,
214+
typename TInString,
215+
bool ignoreEmptyReplaced = false>
213216
FOLLY_ALWAYS_INLINE void replace(
214217
TOutString& outputString,
215218
const TInString& inputString,
@@ -226,7 +229,7 @@ FOLLY_ALWAYS_INLINE void replace(
226229
(inputString.size() / replaced.size()) * replacement.size());
227230
}
228231

229-
auto outputSize = stringCore::replace(
232+
auto outputSize = stringCore::replace<ignoreEmptyReplaced>(
230233
outputString.data(),
231234
std::string_view(inputString.data(), inputString.size()),
232235
std::string_view(replaced.data(), replaced.size()),
@@ -237,14 +240,17 @@ FOLLY_ALWAYS_INLINE void replace(
237240
}
238241

239242
/// Replace replaced with replacement in place in string.
240-
template <typename TInOutString, typename TInString>
243+
template <
244+
typename TInOutString,
245+
typename TInString,
246+
bool ignoreEmptyReplaced = false>
241247
FOLLY_ALWAYS_INLINE void replaceInPlace(
242248
TInOutString& string,
243249
const TInString& replaced,
244250
const TInString& replacement) {
245251
assert(replacement.size() <= replaced.size() && "invalid inplace replace");
246252

247-
auto outputSize = stringCore::replace(
253+
auto outputSize = stringCore::replace<ignoreEmptyReplaced>(
248254
string.data(),
249255
std::string_view(string.data(), string.size()),
250256
std::string_view(replaced.data(), replaced.size()),

velox/functions/prestosql/StringFunctions.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ class ConcatFunction : public exec::VectorFunction {
284284
* If search is an empty string, inserts replace in front of every character
285285
*and at the end of the string.
286286
**/
287-
class Replace : public exec::VectorFunction {
287+
template <bool ignoreEmptyReplaced>
288+
class ReplaceBase : public exec::VectorFunction {
288289
private:
289290
template <
290291
typename StringReader,
@@ -298,7 +299,10 @@ class Replace : public exec::VectorFunction {
298299
FlatVector<StringView>* results) const {
299300
rows.applyToSelected([&](int row) {
300301
auto proxy = exec::StringWriter<>(results, row);
301-
stringImpl::replace(
302+
stringImpl::replace<
303+
decltype(proxy),
304+
decltype(searchReader(row)),
305+
ignoreEmptyReplaced>(
302306
proxy, stringReader(row), searchReader(row), replaceReader(row));
303307
proxy.finalize();
304308
});
@@ -317,7 +321,10 @@ class Replace : public exec::VectorFunction {
317321
rows.applyToSelected([&](int row) {
318322
auto proxy = exec::StringWriter<true /*reuseInput*/>(
319323
results, row, stringReader(row) /*reusedInput*/, true /*inPlace*/);
320-
stringImpl::replaceInPlace(proxy, searchReader(row), replaceReader(row));
324+
stringImpl::replaceInPlace<
325+
decltype(proxy),
326+
decltype(searchReader(row)),
327+
ignoreEmptyReplaced>(proxy, searchReader(row), replaceReader(row));
321328
proxy.finalize();
322329
});
323330
}
@@ -429,6 +436,11 @@ class Replace : public exec::VectorFunction {
429436
return {{0, 2}};
430437
}
431438
};
439+
440+
class Replace : public ReplaceBase<false /*ignoreEmptyReplaced*/> {};
441+
442+
class ReplaceIgnoreEmptyReplaced
443+
: public ReplaceBase<true /*ignoreEmptyReplaced*/> {};
432444
} // namespace
433445

434446
VELOX_DECLARE_VECTOR_FUNCTION(
@@ -456,4 +468,9 @@ VELOX_DECLARE_VECTOR_FUNCTION(
456468
Replace::signatures(),
457469
std::make_unique<Replace>());
458470

471+
VELOX_DECLARE_VECTOR_FUNCTION(
472+
udf_replace_ignore_empty_replaced,
473+
ReplaceIgnoreEmptyReplaced::signatures(),
474+
std::make_unique<ReplaceIgnoreEmptyReplaced>());
475+
459476
} // namespace facebook::velox::functions

velox/functions/sparksql/Register.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ static void workAroundRegistrationMacro(const std::string& prefix) {
6363
// String functions.
6464
VELOX_REGISTER_VECTOR_FUNCTION(udf_concat, prefix + "concat");
6565
VELOX_REGISTER_VECTOR_FUNCTION(udf_lower, prefix + "lower");
66-
VELOX_REGISTER_VECTOR_FUNCTION(udf_replace, prefix + "replace");
66+
VELOX_REGISTER_VECTOR_FUNCTION(
67+
udf_replace_ignore_empty_replaced, prefix + "replace");
6768
VELOX_REGISTER_VECTOR_FUNCTION(udf_upper, prefix + "upper");
6869
// Logical.
6970
VELOX_REGISTER_VECTOR_FUNCTION(udf_not, prefix + "not");

velox/functions/sparksql/tests/StringTest.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,14 @@ class StringTest : public SparkFunctionBaseTest {
190190
std::optional<int32_t> size) {
191191
return evaluateOnce<std::string>("lpad(c0, c1)", string, size);
192192
}
193+
194+
std::optional<std::string> replace(
195+
std::optional<std::string> str,
196+
std::optional<std::string> replaced,
197+
std::optional<std::string> replacement) {
198+
return evaluateOnce<std::string>(
199+
"replace(c0, c1, c2)", str, replaced, replacement);
200+
}
193201
};
194202

195203
TEST_F(StringTest, Ascii) {
@@ -682,5 +690,19 @@ TEST_F(StringTest, translateNonconstantMatch) {
682690
expected = makeFlatVector<std::string>({"åbaæçè", "åæcèaç"});
683691
testTranslate({input, match, replace}, expected);
684692
}
693+
694+
TEST_F(StringTest, replace) {
695+
EXPECT_EQ(replace("aaabaac", "a", "z"), "zzzbzzc");
696+
EXPECT_EQ(replace("aaabaac", "", "z"), "aaabaac");
697+
EXPECT_EQ(replace("aaabaac", "a", ""), "bc");
698+
EXPECT_EQ(replace("aaabaac", "x", "z"), "aaabaac");
699+
EXPECT_EQ(replace("aaabaac", "ab", "z"), "aazaac");
700+
EXPECT_EQ(replace("aaabaac", "aa", "z"), "zabzc");
701+
EXPECT_EQ(replace("aaabaac", "aa", "xyz"), "xyzabxyzc");
702+
EXPECT_EQ(replace("aaabaac", "aaabaac", "z"), "z");
703+
EXPECT_EQ(
704+
replace("123\u6570\u6570\u636E", "\u6570\u636E", "data"),
705+
"123\u6570data");
706+
}
685707
} // namespace
686708
} // namespace facebook::velox::functions::sparksql::test

0 commit comments

Comments
 (0)