/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.type.Constraint;
import com.facebook.presto.type.JoniRegexpType;
import io.airlift.joni.Matcher;
import io.airlift.joni.Option;
import io.airlift.joni.Regex;
import io.airlift.joni.Region;
import io.airlift.joni.exception.ValueException;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import java.nio.charset.StandardCharsets;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
/**
 * Scalar SQL functions for regular expression matching, replacing, extracting and splitting,
 * implemented on top of the Joni regex engine.
 *
 * <p>All offsets handled in this class are byte offsets into the source {@link Slice}.
 * The class is a static utility holder and cannot be instantiated.
 */
public final class JoniRegexpFunctions
{
    private JoniRegexpFunctions()
    {
    }

    /**
     * Tests whether {@code pattern} matches anywhere inside {@code source}.
     *
     * @param source the string to search
     * @param pattern the compiled regular expression
     * @return {@code true} if the search finds a match, {@code false} otherwise
     */
    @Description("returns whether the pattern is contained within the string")
    @ScalarFunction
    @LiteralParameters("x")
    @SqlType(StandardTypes.BOOLEAN)
    public static boolean regexpLike(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern)
    {
        Matcher matcher = pattern.matcher(source.getBytes());
        return matcher.search(0, source.length(), Option.DEFAULT) != -1;
    }

    /**
     * Removes every substring of {@code source} that matches {@code pattern}.
     * Equivalent to the three-argument overload called with an empty replacement.
     *
     * @param source the string to process
     * @param pattern the compiled regular expression
     * @return {@code source} with all matches removed
     */
    @Description("removes substrings matching a regular expression")
    @ScalarFunction
    @LiteralParameters("x")
    @SqlType("varchar(x)")
    public static Slice regexpReplace(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern)
    {
        return regexpReplace(source, pattern, Slices.EMPTY_SLICE);
    }

    /**
     * Replaces every match of {@code pattern} in {@code source} with the expansion of {@code replacement}.
     *
     * <p>The replacement may reference capturing groups as {@code $N} or {@code ${name}} and may escape
     * a character with a backslash; see {@code appendReplacement} for the exact rules.
     *
     * @param source the string to process
     * @param pattern the compiled regular expression
     * @param replacement the replacement template
     * @return a new slice holding the text between matches interleaved with the expanded replacements
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if the replacement template is malformed
     *         or references an unknown group (only detected when at least one match is found)
     */
    @Description("replaces substrings matching a regular expression by given string")
    @ScalarFunction
    @LiteralParameters({"x", "y", "z"})
    // The longest possible output occurs when the pattern matches the empty string: the replacement is then
    // inserted between every two letters of source, and at both ends, i.e. (x + 1) times.
    // The replacement may contain group references. A reference takes two letters of the replacement and can
    // expand to as many as (x) letters, so one expanded replacement is at most (x * y / 2) letters long.
    // For (x < 2), however, the replacement itself (without references) may be longer, so the bound for one
    // expanded replacement is the max of (x * y / 2) and (y). Adding the length of source (x) gives the formula:
    // x + max(x * y / 2, y) * (x + 1)
    @Constraint(variable = "z", expression = "min(2147483647, x + max(x * y / 2, y) * (x + 1))")
    @SqlType("varchar(z)")
    public static Slice regexpReplace(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern, @SqlType("varchar(y)") Slice replacement)
    {
        Matcher matcher = pattern.matcher(source.getBytes());
        // Initial capacity is only a guess; the output grows on demand.
        SliceOutput sliceOutput = new DynamicSliceOutput(source.length() + replacement.length() * 5);

        int lastEnd = 0; // end offset of the previous match, i.e. where the next unmatched text begins
        int nextStart = 0; // where the next search begins; differs from lastEnd only after a zero-width match
        while (true) {
            int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
            if (offset == -1) {
                break;
            }
            nextStart = nextSearchStart(source, matcher);

            // Copy the text between the previous match and this one, then the expanded replacement.
            sliceOutput.appendBytes(source.slice(lastEnd, matcher.getBegin() - lastEnd));
            lastEnd = matcher.getEnd();
            appendReplacement(sliceOutput, source, pattern, matcher.getEagerRegion(), replacement);
        }

        // Copy whatever follows the last match (the whole source when nothing matched).
        sliceOutput.appendBytes(source.slice(lastEnd, source.length() - lastEnd));
        return sliceOutput.slice();
    }

    /**
     * Expands the {@code replacement} template for a single match and appends the result to {@code result}.
     *
     * <p>The template is processed byte by byte and handles the following items:
     * <ol>
     * <li>{@code ${name}}: the text captured by the named group;</li>
     * <li>{@code $0}, {@code $1}, {@code $123}: the text captured by a numbered group. Digits are consumed
     * greedily as long as the resulting number is an existing group (group 123, if it exists; or group 12,
     * if it exists; or group 1);</li>
     * <li>{@code \\}, {@code \$}, {@code \t}: a backslash makes the following byte literal (so {@code \t}
     * is a literal 't');</li>
     * <li>anything that does not start with {@code \} or {@code $} is copied as is.</li>
     * </ol>
     *
     * @param result the output to append to
     * @param source the original input; captured group text is sliced from it
     * @param pattern the compiled regular expression, used to resolve group names
     * @param region the group boundaries of the current match
     * @param replacement the replacement template
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if the template ends right after
     *         {@code $} or {@code \}, if {@code $} is followed by something other than a digit or
     *         <code>{</code>, or if a referenced group does not exist
     */
    private static void appendReplacement(SliceOutput result, Slice source, Regex pattern, Region region, Slice replacement)
    {
        int idx = 0;
        while (idx < replacement.length()) {
            byte nextByte = replacement.getByte(idx);
            if (nextByte == '$') {
                idx++;
                // Plain `if` rather than checkArgument: `.toStringUtf8` is expensive and must only run on failure
                if (idx == replacement.length()) {
                    throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
                }
                nextByte = replacement.getByte(idx);

                int backref;
                if (nextByte == '{') { // case 1: ${name}
                    idx++;
                    int startCursor = idx;
                    // Scan up to the closing brace.
                    // NOTE(review): if the brace is missing, the rest of the template is used as the group name,
                    // so an unterminated "${name" is accepted when "name" is a valid group - confirm this is intended.
                    while (idx < replacement.length()) {
                        nextByte = replacement.getByte(idx);
                        if (nextByte == '}') {
                            break;
                        }
                        idx++;
                    }
                    byte[] groupName = replacement.getBytes(startCursor, idx - startCursor);
                    try {
                        backref = pattern.nameToBackrefNumber(groupName, 0, groupName.length, region);
                    }
                    catch (ValueException e) {
                        // Keep the original exception as the cause so the failure remains diagnosable
                        throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: unknown group { " + new String(groupName, StandardCharsets.UTF_8) + " }", e);
                    }
                    idx++; // step over the closing brace
                }
                else { // case 2: $N
                    backref = nextByte - '0';
                    if (backref < 0 || backref > 9) {
                        throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
                    }
                    if (region.numRegs <= backref) {
                        throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: unknown group " + backref);
                    }
                    idx++;
                    // Adaptive group number: keep consuming digits while the number still names an existing group
                    while (idx < replacement.length()) {
                        int nextDigit = replacement.getByte(idx) - '0';
                        if (nextDigit < 0 || nextDigit > 9) {
                            break;
                        }
                        int newBackref = (backref * 10) + nextDigit;
                        if (region.numRegs <= newBackref) {
                            break;
                        }
                        backref = newBackref;
                        idx++;
                    }
                }

                int beg = region.beg[backref];
                int end = region.end[backref];
                // Append the captured text only when the group has boundaries in the current match;
                // a group whose boundaries are -1 contributes nothing.
                if (beg != -1 && end != -1) {
                    result.appendBytes(source.slice(beg, end - beg));
                }
            }
            else { // case 3 and 4: escaped or regular byte
                if (nextByte == '\\') {
                    idx++;
                    if (idx == replacement.length()) {
                        throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
                    }
                    // The byte after the backslash is emitted literally
                    nextByte = replacement.getByte(idx);
                }
                result.appendByte(nextByte);
                idx++;
            }
        }
    }

    /**
     * Extracts every substring of {@code source} matched by {@code pattern}.
     * Equivalent to the three-argument overload called with group 0 (the whole match).
     *
     * @param source the string to search
     * @param pattern the compiled regular expression
     * @return a block with one varchar entry per match
     */
    @Description("string(s) extracted using the given pattern")
    @ScalarFunction
    @LiteralParameters("x")
    @SqlType("array(varchar(x))")
    public static Block regexpExtractAll(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern)
    {
        return regexpExtractAll(source, pattern, 0);
    }

    /**
     * Extracts, for every match of {@code pattern} in {@code source}, the text captured by group {@code groupIndex}.
     *
     * @param source the string to search
     * @param pattern the compiled regular expression
     * @param groupIndex the capturing group to extract; 0 is the whole match
     * @return a block with one entry per match; the entry is null when the group has no boundaries in that match
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if {@code groupIndex} is negative
     *         or greater than the number of groups in the pattern
     */
    @Description("group(s) extracted using the given pattern")
    @ScalarFunction
    @LiteralParameters("x")
    @SqlType("array(varchar(x))")
    public static Block regexpExtractAll(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern, @SqlType(StandardTypes.BIGINT) long groupIndex)
    {
        Matcher matcher = pattern.matcher(source.getBytes());
        validateGroup(groupIndex, matcher.getEagerRegion());
        // Safe after validation: the index is known to be within [0, numRegs - 1]
        int group = toIntExact(groupIndex);

        BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 32);
        int nextStart = 0;
        while (true) {
            int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
            if (offset == -1) {
                break;
            }
            nextStart = nextSearchStart(source, matcher);

            Region region = matcher.getEagerRegion();
            int beg = region.beg[group];
            int end = region.end[group];
            if (beg == -1 || end == -1) {
                blockBuilder.appendNull();
            }
            else {
                VARCHAR.writeSlice(blockBuilder, source.slice(beg, end - beg));
            }
        }
        return blockBuilder.build();
    }

    /**
     * Extracts the first substring of {@code source} matched by {@code pattern}.
     * Equivalent to the three-argument overload called with group 0 (the whole match).
     *
     * @param source the string to search
     * @param pattern the compiled regular expression
     * @return the first match, or {@code null} if there is none
     */
    @SqlNullable
    @Description("string extracted using the given pattern")
    @ScalarFunction
    @LiteralParameters("x")
    @SqlType("varchar(x)")
    public static Slice regexpExtract(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern)
    {
        return regexpExtract(source, pattern, 0);
    }

    /**
     * Extracts the text captured by group {@code groupIndex} in the first match of {@code pattern} in {@code source}.
     *
     * @param source the string to search
     * @param pattern the compiled regular expression
     * @param groupIndex the capturing group to extract; 0 is the whole match
     * @return the captured text, or {@code null} if there is no match or the group has no boundaries in the match
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if {@code groupIndex} is negative
     *         or greater than the number of groups in the pattern
     */
    @SqlNullable
    @Description("returns regex group of extracted string with a pattern")
    @ScalarFunction
    @LiteralParameters("x")
    @SqlType("varchar(x)")
    public static Slice regexpExtract(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern, @SqlType(StandardTypes.BIGINT) long groupIndex)
    {
        Matcher matcher = pattern.matcher(source.getBytes());
        validateGroup(groupIndex, matcher.getEagerRegion());
        // Safe after validation: the index is known to be within [0, numRegs - 1]
        int group = toIntExact(groupIndex);

        int offset = matcher.search(0, source.length(), Option.DEFAULT);
        if (offset == -1) {
            return null;
        }

        Region region = matcher.getEagerRegion();
        int beg = region.beg[group];
        int end = region.end[group];
        if (beg == -1) {
            // end == -1 must be true
            return null;
        }
        return source.slice(beg, end - beg);
    }

    /**
     * Splits {@code source} around the matches of {@code pattern}.
     *
     * @param source the string to split
     * @param pattern the compiled regular expression used as the delimiter
     * @return a block holding the text before each match followed by the text after the last match;
     *         it always has one more entry than the number of matches
     */
    @ScalarFunction
    @LiteralParameters("x")
    @Description("returns array of strings split by pattern")
    @SqlType("array(varchar(x))")
    public static Block regexpSplit(@SqlType("varchar(x)") Slice source, @SqlType(JoniRegexpType.NAME) Regex pattern)
    {
        Matcher matcher = pattern.matcher(source.getBytes());
        BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 32);

        int lastEnd = 0; // end offset of the previous delimiter
        int nextStart = 0; // where the next search begins; differs from lastEnd only after a zero-width match
        while (true) {
            int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
            if (offset == -1) {
                break;
            }
            nextStart = nextSearchStart(source, matcher);

            VARCHAR.writeSlice(blockBuilder, source.slice(lastEnd, matcher.getBegin() - lastEnd));
            lastEnd = matcher.getEnd();
        }

        // Trailing piece after the last delimiter (the whole source when nothing matched)
        VARCHAR.writeSlice(blockBuilder, source.slice(lastEnd, source.length() - lastEnd));
        return blockBuilder.build();
    }

    /**
     * Computes the byte offset at which the search following the current match should start.
     *
     * <p>After a non-empty match the search resumes at the end of the match. After a zero-width match
     * it has to move forward, otherwise the same empty match would be found again. It moves past the
     * whole character rather than a single byte, so that the next match cannot begin in the middle of
     * a multi-byte character.
     *
     * @param source the input being searched
     * @param matcher the matcher, positioned on the match that was just found
     * @return the start offset for the next search
     */
    private static int nextSearchStart(Slice source, Matcher matcher)
    {
        int begin = matcher.getBegin();
        int end = matcher.getEnd();
        if (end != begin) {
            return end;
        }
        if (begin >= source.length()) {
            // Empty match at the very end of the input: step past the end, as before.
            // The following search is expected to report no match, which ends the caller's loop.
            return end + 1;
        }
        return end + utf8SequenceLength(source, begin);
    }

    /**
     * Returns the number of bytes occupied by the character starting at {@code position},
     * derived from its leading byte.
     *
     * <p>Assumes the input is UTF-8 encoded (varchar data in this engine - TODO confirm). Bytes that are
     * not a valid leading byte count as a single byte, and the result never reaches beyond the end of
     * {@code source}.
     *
     * @param source the input bytes
     * @param position offset of the leading byte; must be less than {@code source.length()}
     * @return a length between 1 and 4
     */
    private static int utf8SequenceLength(Slice source, int position)
    {
        int leadByte = source.getByte(position) & 0xFF;
        int length;
        if (leadByte >= 0xC0 && leadByte < 0xE0) {
            length = 2; // 110x_xxxx
        }
        else if (leadByte >= 0xE0 && leadByte < 0xF0) {
            length = 3; // 1110_xxxx
        }
        else if (leadByte >= 0xF0 && leadByte < 0xF8) {
            length = 4; // 1111_0xxx
        }
        else {
            length = 1; // ASCII, stray continuation byte or invalid leading byte
        }
        return Math.min(length, source.length() - position);
    }

    /**
     * Verifies that {@code group} refers to a group present in {@code region}.
     *
     * @param group the requested group index; 0 is the whole match
     * @param region a region of the pattern, used only for its number of registers
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if {@code group} is negative
     *         or greater than {@code region.numRegs - 1}
     */
    private static void validateGroup(long group, Region region)
    {
        if (group < 0) {
            throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Group cannot be negative");
        }
        if (group > region.numRegs - 1) {
            throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Pattern has %d groups. Cannot access group %d", region.numRegs - 1, group));
        }
    }
}