/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.drivers.input;
import org.apache.commons.lang3.text.WordUtils;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.client.program.ProgramParametrizationException;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.asm.translate.TranslateFunction;
import org.apache.flink.graph.asm.translate.TranslateGraphIds;
import org.apache.flink.graph.asm.translate.translators.LongValueToStringValue;
import org.apache.flink.graph.asm.translate.translators.LongValueToUnsignedIntValue;
import org.apache.flink.graph.drivers.parameter.ChoiceParameter;
import org.apache.flink.graph.drivers.parameter.ParameterizedBase;
import org.apache.flink.types.ByteValue;
import org.apache.flink.types.CharValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.NullValue;
import org.apache.flink.types.ShortValue;
/**
* Base class for generated graphs.
*
* @param <K> graph ID type
*/
public abstract class GeneratedGraph<K>
        extends ParameterizedBase
        implements Input<K, NullValue, NullValue> {

    // Values accepted by the "type" parameter. In create(), the plain names
    // select translators producing Flink value types (ByteValue, ShortValue,
    // ...) while the "native" names select translators producing boxed Java
    // types (Byte, Short, ...).
    private static final String BYTE = "byte";

    private static final String NATIVE_BYTE = "nativeByte";

    private static final String SHORT = "short";

    private static final String NATIVE_SHORT = "nativeShort";

    private static final String CHAR = "char";

    private static final String NATIVE_CHAR = "nativeChar";

    private static final String INTEGER = "integer";

    private static final String NATIVE_INTEGER = "nativeInteger";

    private static final String LONG = "long";

    private static final String NATIVE_LONG = "nativeLong";

    private static final String STRING = "string";

    private static final String NATIVE_STRING = "nativeString";

    /**
     * Selects the vertex ID type. Defaults to {@code "integer"}; {@code "long"}
     * and {@code "string"} are added as regular choices and all remaining
     * names are registered as hidden choices.
     */
    private final ChoiceParameter type = new ChoiceParameter(this, "type")
        .setDefaultValue(INTEGER)
        .addChoices(LONG, STRING)
        .addHiddenChoices(BYTE, NATIVE_BYTE, SHORT, NATIVE_SHORT, CHAR, NATIVE_CHAR, NATIVE_INTEGER, NATIVE_LONG, NATIVE_STRING);

    /**
     * The vertex count is verified to be no greater than the capacity of the
     * selected data type. All vertices must be counted even if skipped or
     * unused when generating graph edges.
     *
     * @return number of vertices configured for the graph
     */
    protected abstract long vertexCount();

    /**
     * Generate the graph as configured.
     *
     * @param env Flink execution environment
     * @return generated graph
     * @throws Exception on error
     */
    protected abstract Graph<K, NullValue, NullValue> generate(ExecutionEnvironment env) throws Exception;

    /**
     * Get the name of the type.
     *
     * @return the configured type value passed through
     *         {@link WordUtils#capitalize(String)}
     */
    protected String getTypeName() {
        return WordUtils.capitalize(type.getValue());
    }

    /**
     * Generates the graph and, unless the selected type is {@code "long"},
     * translates the vertex IDs to the selected type.
     *
     * <p>The vertex count reported by {@link #vertexCount()} is checked against
     * the capacity of the selected type before the graph is generated.
     *
     * @param env Flink execution environment
     * @return the generated graph, with translated IDs where a translator applies
     * @throws ProgramParametrizationException if the type value is not one of
     *         the known names, or if the vertex count exceeds the capacity of
     *         the selected type
     * @throws Exception on error while generating or translating the graph
     */
    @Override
    public Graph<K, NullValue, NullValue> create(ExecutionEnvironment env)
            throws Exception {
        // Types without an explicit bound below keep the full long range.
        long maxVertexCount = Long.MAX_VALUE;

        // Every translator used here consumes LongValue IDs; only the output
        // type differs, hence the wildcard. null means "no translation".
        TranslateFunction<LongValue, ?> translator = null;

        switch (type.getValue()) {
            case BYTE:
                maxVertexCount = LongValueToUnsignedByteValue.MAX_VERTEX_COUNT;
                translator = new LongValueToUnsignedByteValue();
                break;

            case NATIVE_BYTE:
                maxVertexCount = LongValueToUnsignedByte.MAX_VERTEX_COUNT;
                translator = new LongValueToUnsignedByte();
                break;

            case SHORT:
                maxVertexCount = LongValueToUnsignedShortValue.MAX_VERTEX_COUNT;
                translator = new LongValueToUnsignedShortValue();
                break;

            case NATIVE_SHORT:
                maxVertexCount = LongValueToUnsignedShort.MAX_VERTEX_COUNT;
                translator = new LongValueToUnsignedShort();
                break;

            case CHAR:
                maxVertexCount = LongValueToCharValue.MAX_VERTEX_COUNT;
                translator = new LongValueToCharValue();
                break;

            case NATIVE_CHAR:
                maxVertexCount = LongValueToChar.MAX_VERTEX_COUNT;
                translator = new LongValueToChar();
                break;

            case INTEGER:
                maxVertexCount = LongValueToUnsignedIntValue.MAX_VERTEX_COUNT;
                translator = new LongValueToUnsignedIntValue();
                break;

            case NATIVE_INTEGER:
                maxVertexCount = LongValueToUnsignedInt.MAX_VERTEX_COUNT;
                translator = new LongValueToUnsignedInt();
                break;

            case LONG:
                // No translator: the graph is returned exactly as generated.
                break;

            case NATIVE_LONG:
                translator = new LongValueToLong();
                break;

            case STRING:
                translator = new LongValueToStringValue();
                break;

            case NATIVE_STRING:
                translator = new LongValueToString();
                break;

            default:
                throw new ProgramParametrizationException("Unknown type '" + type.getValue() + "'");
        }

        long vertexCount = vertexCount();
        if (vertexCount > maxVertexCount) {
            throw new ProgramParametrizationException("Vertex count '" + vertexCount +
                "' must be no greater than " + maxVertexCount + " for type '" + type.getValue() + "'.");
        }

        Graph<K, NullValue, NullValue> graph = generate(env);
        if (translator == null) {
            return graph;
        }

        // K is only known to the subclass, so the translation cannot be typed
        // statically; the raw algorithm and the cast are confined to this
        // single statement. NOTE(review): correctness relies on the subclass's
        // K matching the output type of the selected translator.
        @SuppressWarnings({"unchecked", "rawtypes"})
        Graph<K, NullValue, NullValue> translated =
            (Graph<K, NullValue, NullValue>) graph.run(new TranslateGraphIds(translator));
        return translated;
    }

    /**
     * Reads the long held by {@code value} and verifies that it lies in the
     * range {@code [0, maxVertexCount)}.
     *
     * @param value the ID to check
     * @param maxVertexCount exclusive upper bound for the ID
     * @param typeName name of the target type, used in the error message
     * @return the checked long value
     * @throws IllegalArgumentException if the value is negative or not less
     *         than {@code maxVertexCount}
     */
    private static long checkedId(LongValue value, long maxVertexCount, String typeName) {
        long l = value.getValue();
        if (l < 0 || l >= maxVertexCount) {
            throw new IllegalArgumentException("Cannot cast long value " + value + " to " + typeName + ".");
        }
        return l;
    }

    /**
     * Translate {@link LongValue} to {@link ByteValue}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToUnsignedByteValue
            implements TranslateFunction<LongValue, ByteValue> {
        public static final long MAX_VERTEX_COUNT = 1L << 8;

        @Override
        public ByteValue translate(LongValue value, ByteValue reuse)
                throws Exception {
            if (reuse == null) {
                reuse = new ByteValue();
            }

            // The narrowing cast keeps the low 8 bits, so IDs 128..255 are
            // stored as negative bytes.
            reuse.setValue((byte) checkedId(value, MAX_VERTEX_COUNT, "byte"));
            return reuse;
        }
    }

    /**
     * Translate {@link LongValue} to {@link Byte}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToUnsignedByte
            implements TranslateFunction<LongValue, Byte> {
        public static final long MAX_VERTEX_COUNT = 1L << 8;

        @Override
        public Byte translate(LongValue value, Byte reuse)
                throws Exception {
            // The narrowing cast keeps the low 8 bits, so IDs 128..255 become
            // negative bytes. The reuse argument is not used.
            return (byte) checkedId(value, MAX_VERTEX_COUNT, "byte");
        }
    }

    /**
     * Translate {@link LongValue} to {@link ShortValue}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToUnsignedShortValue
            implements TranslateFunction<LongValue, ShortValue> {
        public static final long MAX_VERTEX_COUNT = 1L << 16;

        @Override
        public ShortValue translate(LongValue value, ShortValue reuse)
                throws Exception {
            if (reuse == null) {
                reuse = new ShortValue();
            }

            // The narrowing cast keeps the low 16 bits, so IDs 32768..65535
            // are stored as negative shorts.
            reuse.setValue((short) checkedId(value, MAX_VERTEX_COUNT, "short"));
            return reuse;
        }
    }

    /**
     * Translate {@link LongValue} to {@link Short}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToUnsignedShort
            implements TranslateFunction<LongValue, Short> {
        public static final long MAX_VERTEX_COUNT = 1L << 16;

        @Override
        public Short translate(LongValue value, Short reuse)
                throws Exception {
            // The narrowing cast keeps the low 16 bits, so IDs 32768..65535
            // become negative shorts. The reuse argument is not used.
            return (short) checkedId(value, MAX_VERTEX_COUNT, "short");
        }
    }

    /**
     * Translate {@link LongValue} to {@link CharValue}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToCharValue
            implements TranslateFunction<LongValue, CharValue> {
        public static final long MAX_VERTEX_COUNT = 1L << 16;

        @Override
        public CharValue translate(LongValue value, CharValue reuse)
                throws Exception {
            if (reuse == null) {
                reuse = new CharValue();
            }

            reuse.setValue((char) checkedId(value, MAX_VERTEX_COUNT, "char"));
            return reuse;
        }
    }

    /**
     * Translate {@link LongValue} to {@link Character}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToChar
            implements TranslateFunction<LongValue, Character> {
        public static final long MAX_VERTEX_COUNT = 1L << 16;

        @Override
        public Character translate(LongValue value, Character reuse)
                throws Exception {
            // The reuse argument is not used.
            return (char) checkedId(value, MAX_VERTEX_COUNT, "char");
        }
    }

    /**
     * Translate {@link LongValue} to {@link Integer}.
     *
     * <p>Throws {@link IllegalArgumentException} when the value is negative or
     * not less than {@link #MAX_VERTEX_COUNT}.
     */
    static class LongValueToUnsignedInt
            implements TranslateFunction<LongValue, Integer> {
        public static final long MAX_VERTEX_COUNT = 1L << 32;

        @Override
        public Integer translate(LongValue value, Integer reuse)
                throws Exception {
            // The narrowing cast keeps the low 32 bits, so IDs at or above
            // 2^31 become negative ints. The reuse argument is not used.
            return (int) checkedId(value, MAX_VERTEX_COUNT, "integer");
        }
    }

    /**
     * Translate {@link LongValue} to {@link Long}.
     */
    static class LongValueToLong
            implements TranslateFunction<LongValue, Long> {
        @Override
        public Long translate(LongValue value, Long reuse)
                throws Exception {
            // The reuse argument is not used.
            return value.getValue();
        }
    }

    /**
     * Translate {@link LongValue} to {@link String}.
     */
    static class LongValueToString
            implements TranslateFunction<LongValue, String> {
        @Override
        public String translate(LongValue value, String reuse)
                throws Exception {
            // Decimal rendering of the ID; the reuse argument is not used.
            return Long.toString(value.getValue());
        }
    }
}