/*
* Copyright 2011, Mysema Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mysema.codegen;
import static com.mysema.codegen.Symbols.ASSIGN;
import static com.mysema.codegen.Symbols.COMMA;
import static com.mysema.codegen.Symbols.DOT;
import static com.mysema.codegen.Symbols.QUOTE;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import com.google.common.base.Function;
import com.mysema.codegen.model.Parameter;
import com.mysema.codegen.model.Type;
import com.mysema.codegen.model.Types;
import com.mysema.codegen.support.ScalaSyntaxUtils;
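/**
 * Code writer that emits Scala source text: package and import declarations,
 * classes, traits, objects, case classes, methods, fields and annotations.
 *
 * @author tiwe
 */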
public class ScalaWriter extends AbstractCodeWriter<ScalaWriter> {
/**
 * Names of the Java primitive types. getRawName passes a matching name through
 * StringUtils.capitalize before emitting it.
 */
private static final Set<String> PRIMITIVE_TYPES = new HashSet<String>(Arrays.asList("boolean",
"byte", "char", "int", "long", "short", "double", "float"));
// Scala keyword fragments written verbatim into the output. Most of them carry the
// whitespace that separates them from the following token.
private static final String DEF = "def ";
private static final String OVERRIDE_DEF = "override " + DEF;
private static final String EXTENDS = " extends ";
private static final String WITH = " with ";
private static final String IMPORT = "import ";
// Same text as IMPORT; staticimports combines it with a "._;" suffix.
private static final String IMPORT_STATIC = "import ";
private static final String PACKAGE = "package ";
private static final String PRIVATE = "private ";
private static final String PRIVATE_VAL = "private val ";
private static final String PROTECTED = "protected ";
private static final String PROTECTED_VAL = "protected val ";
// NOTE(review): PUBLIC is not referenced anywhere in this file.
private static final String PUBLIC = "public ";
// The "public" class/object prefixes contain no visibility keyword.
private static final String PUBLIC_CLASS = "class ";
private static final String PUBLIC_OBJECT = "object ";
private static final String CASE_CLASS = "case class ";
private static final String VAR = "var ";
private static final String VAL = "val ";
private static final String THIS = "this";
private static final String TRAIT = "trait ";
/** Fully qualified class names that are rendered without their package prefix. */
private final Set<String> classes = new HashSet<String>();
/** Package names whose types are rendered without the package prefix. */
private final Set<String> packages = new HashSet<String>();
/** Type handed to the latest beginClass(Type, ...) or beginInterface call; only assigned in this file, never read. */
private Type type;
/** When true, the field(...) methods do not append an extra newline after the declaration. */
private final boolean compact;
/**
 * Creates a writer with the {@code compact} flag switched off.
 *
 * @param appendable target that receives the generated source
 */
public ScalaWriter(Appendable appendable) {
this(appendable, false);
}
/**
 * Creates a writer and registers a handful of {@code java.lang} classes so that
 * they are rendered by simple name.
 *
 * @param appendable target that receives the generated source
 * @param compact    when true, field declarations are not followed by an extra newline
 */
public ScalaWriter(Appendable appendable, boolean compact) {
    // The second argument is presumably the indentation width - confirm in AbstractCodeWriter.
    super(appendable, 2);
    this.compact = compact;
    classes.addAll(Arrays.asList(
            "java.lang.String",
            "java.lang.Object",
            "java.lang.Integer",
            "java.lang.Comparable"));
}
/**
 * Writes an annotation instance on its own line using Scala syntax.
 *
 * <p>If the annotation type declares exactly one method and it is named
 * {@code value}, the output is {@code @Type(constant)}. Otherwise every member
 * whose value differs from its declared default is written as
 * {@code name=constant}; no parentheses are written when no member qualifies.</p>
 *
 * @param annotation annotation instance to render
 * @return the result of terminating the line
 * @throws IOException      if writing fails
 * @throws CodegenException if a member cannot be read reflectively; an
 *                          {@link IllegalArgumentException} raised while rendering a
 *                          value is wrapped as well, because rendering happens inside
 *                          the same try block
 */
@Override
public ScalaWriter annotation(Annotation annotation) throws IOException {
    beginLine().append("@").appendType(annotation.annotationType());
    // The JDK does not specify the order of getDeclaredMethods(); members are written in that order.
    Method[] methods = annotation.annotationType().getDeclaredMethods();
    if (methods.length == 1 && methods[0].getName().equals("value")) {
        try {
            Object value = methods[0].invoke(annotation);
            append("(");
            annotationConstant(value);
            append(")");
        } catch (IllegalArgumentException e) {
            throw new CodegenException(e);
        } catch (IllegalAccessException e) {
            throw new CodegenException(e);
        } catch (InvocationTargetException e) {
            throw new CodegenException(e);
        }
    } else {
        boolean first = true;
        for (Method method : methods) {
            try {
                Object value = method.invoke(annotation);
                // Wrapping both sides lets Arrays.deepEquals compare scalars with equals()
                // and arrays element-wise, including primitive arrays such as int[],
                // which cannot be cast to Object[].
                if (value == null
                        || Arrays.deepEquals(new Object[] {value},
                                new Object[] {method.getDefaultValue()})) {
                    continue;
                }
                if (first) {
                    append("(");
                } else {
                    append(COMMA);
                }
                append(escape(method.getName())).append("=");
                annotationConstant(value);
            } catch (IllegalArgumentException e) {
                throw new CodegenException(e);
            } catch (IllegalAccessException e) {
                throw new CodegenException(e);
            } catch (InvocationTargetException e) {
                throw new CodegenException(e);
            }
            first = false;
        }
        // Close the argument list only if at least one member was written.
        if (!first) {
            append(")");
        }
    }
    return nl();
}
/**
 * Writes a marker annotation ({@code @Type} without arguments) on its own line.
 *
 * @param annotation annotation type to render
 * @return the result of terminating the line
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter annotation(Class<? extends Annotation> annotation) throws IOException {
    final ScalaWriter atSign = beginLine().append("@");
    final ScalaWriter typed = atSign.appendType(annotation);
    return typed.nl();
}
/**
 * Appends the Scala literal for a single annotation member value.
 *
 * <ul>
 * <li>arrays (object or primitive) become {@code Array(e1, e2, ...)}</li>
 * <li>classes become {@code classOf[Type]}</li>
 * <li>numbers and booleans are written via {@code toString()}</li>
 * <li>enum constants are written by bare name when their declaring class or its
 *     package is registered, otherwise as {@code declaring.Class.NAME}</li>
 * <li>strings are escaped with {@code StringUtils.escapeJava} and quoted</li>
 * </ul>
 *
 * @param value member value to render; must not be null
 * @throws IOException              if writing fails
 * @throws IllegalArgumentException if the value is of any other kind
 */
private void annotationConstant(Object value) throws IOException {
    if (value.getClass().isArray()) {
        append("Array(");
        // Reflective access works for int[], boolean[], ... as well as Object[];
        // a plain (Object[]) cast throws ClassCastException on primitive arrays.
        int length = java.lang.reflect.Array.getLength(value);
        for (int i = 0; i < length; i++) {
            if (i > 0) {
                append(", ");
            }
            annotationConstant(java.lang.reflect.Array.get(value, i));
        }
        append(")");
    } else if (value instanceof Class) {
        append("classOf[");
        appendType((Class<?>) value);
        append("]");
    } else if (value instanceof Number || value instanceof Boolean) {
        append(value.toString());
    } else if (value instanceof Enum) {
        Enum<?> enumValue = (Enum<?>) value;
        // getClass() yields an anonymous subclass for constants with a body, so the
        // declaring class is used both for the lookup and for the qualified output.
        Class<?> enumType = enumValue.getDeclaringClass();
        Package enumPackage = enumType.getPackage();
        boolean known = classes.contains(enumType.getName())
                || (enumPackage != null && packages.contains(enumPackage.getName()));
        if (known) {
            append(enumValue.name());
        } else {
            append(enumType.getName()).append(DOT).append(enumValue.name());
        }
    } else if (value instanceof String) {
        append(QUOTE).append(StringUtils.escapeJava(value.toString())).append(QUOTE);
    } else {
        throw new IllegalArgumentException("Unsupported annotation value : " + value);
    }
}
/**
 * Appends the name used for a Java class in the generated source.
 *
 * <p>Primitive type names are passed through {@code StringUtils.capitalize}. The
 * simple name is used when the class has no package, is a registered class, or
 * lives in a registered package; otherwise the name from {@code getName()} is
 * written.</p>
 *
 * @param type class to render
 * @return this writer
 * @throws IOException if writing fails
 */
private ScalaWriter appendType(Class<?> type) throws IOException {
    final String rendered;
    if (type.isPrimitive()) {
        rendered = StringUtils.capitalize(type.getName());
    } else {
        final Package owner = type.getPackage();
        final boolean shortForm = owner == null
                || classes.contains(type.getName())
                || packages.contains(owner.getName());
        rendered = shortForm ? type.getSimpleName() : type.getName();
    }
    append(rendered);
    return this;
}
/**
 * Opens a Scala {@code object}: writes the line {@code object <header> {} and
 * then calls goIn() (presumably increasing the indentation - confirm in the base class).
 *
 * @param header text placed between the {@code object} keyword and the opening brace
 * @return this writer
 * @throws IOException if writing fails
 */
public ScalaWriter beginObject(String header) throws IOException {
line(PUBLIC_OBJECT, header, " {");
goIn();
return this;
}
/**
 * Opens a Scala {@code class} from a pre-formatted header: writes the line
 * {@code class <header> {} and then calls goIn().
 *
 * @param header text placed between the {@code class} keyword and the opening brace
 * @return this writer
 * @throws IOException if writing fails
 */
public ScalaWriter beginClass(String header) throws IOException {
line(PUBLIC_CLASS, header, " {");
goIn();
return this;
}
/**
 * Opens a class declaration, passing {@code null} as the second argument of the
 * delegated {@code beginClass} call.
 *
 * @param type type being declared
 * @return the result of the delegated {@code beginClass} call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter beginClass(Type type) throws IOException {
return beginClass(type, null);
}
/**
 * Opens a class declaration with an optional superclass and any number of traits.
 *
 * <p>The first parent (the superclass if present, otherwise the first interface)
 * follows {@code extends}; every further parent is introduced by its own
 * {@code with}, as Scala requires. The type's package is registered so that types
 * from the same package are rendered without prefix, and the type is stored in
 * the {@code type} field.</p>
 *
 * @param type       type being declared
 * @param superClass superclass, or null for none
 * @param interfaces traits to mix in; may be empty
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter beginClass(Type type, Type superClass, Type... interfaces)
        throws IOException {
    packages.add(type.getPackageName());
    beginLine(PUBLIC_CLASS, getGenericName(false, type));
    boolean firstParent = true;
    if (superClass != null) {
        append(EXTENDS).append(getGenericName(false, superClass));
        firstParent = false;
    }
    for (Type parent : interfaces) {
        // "extends A with B with C": a comma is not a valid separator after "with",
        // and a single interface must not leave a dangling "with".
        append(firstParent ? EXTENDS : WITH).append(getGenericName(false, parent));
        firstParent = false;
    }
    append(" {").nl().nl();
    goIn();
    this.type = type;
    return this;
}
/**
 * Opens an auxiliary constructor, written as {@code def this(params) {}, with the
 * parameters obtained by applying {@code transformer} to each element.
 *
 * @param parameters  source objects, one per constructor parameter
 * @param transformer converts a source object into a {@link Parameter}
 * @return the result of goIn()
 * @throws IOException if writing fails
 */
@Override
public <T> ScalaWriter beginConstructor(Collection<T> parameters,
        Function<T, Parameter> transformer) throws IOException {
    final ScalaWriter header = beginLine(DEF, THIS).params(parameters, transformer);
    header.append(" {").nl();
    return goIn();
}
/**
 * Opens an auxiliary constructor, written as {@code def this(params) {}.
 *
 * @param params constructor parameters in declaration order
 * @return the result of goIn()
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter beginConstructor(Parameter... params) throws IOException {
    final ScalaWriter header = beginLine(DEF, THIS).params(params);
    header.append(" {").nl();
    return goIn();
}
/**
 * Opens a {@code trait} declaration.
 *
 * <p>The first parent follows {@code extends}; every further parent is introduced
 * by its own {@code with}, as Scala requires. The type's package is registered so
 * that types from the same package are rendered without prefix, and the type is
 * stored in the {@code type} field.</p>
 *
 * @param type       trait being declared
 * @param interfaces parent traits; may be empty
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter beginInterface(Type type, Type... interfaces) throws IOException {
    packages.add(type.getPackageName());
    beginLine(TRAIT, getGenericName(false, type));
    for (int i = 0; i < interfaces.length; i++) {
        // "extends A with B with C": a comma is not a valid separator after "with".
        append(i == 0 ? EXTENDS : WITH).append(getGenericName(false, interfaces[i]));
    }
    append(" {").nl().nl();
    goIn();
    this.type = type;
    return this;
}
/**
 * Writes a method header and opens its body.
 *
 * <p>When {@code returnType} equals {@code Types.VOID} the header is
 * {@code <modifiers>name(params) {}; otherwise it is
 * {@code <modifiers>name(params): ReturnType = {}. The method name is escaped if
 * it is a reserved word.</p>
 *
 * @param modifiers  text written before the method name (e.g. {@code "def "})
 * @param returnType declared return type
 * @param methodName unescaped method name
 * @param args       method parameters
 * @return the result of goIn()
 * @throws IOException if writing fails
 */
private ScalaWriter beginMethod(String modifiers, Type returnType, String methodName,
        Parameter... args) throws IOException {
    // Evaluated before anything is written, as in the branch condition it replaces.
    final boolean returnsNothing = returnType.equals(Types.VOID);
    final ScalaWriter signature = beginLine(modifiers, escape(methodName)).params(args);
    if (returnsNothing) {
        signature.append(" {").nl();
    } else {
        signature.append(": ").append(getGenericName(true, returnType)).append(" = {").nl();
    }
    return goIn();
}
/**
 * Opens a {@code def} method whose parameters are produced by applying
 * {@code transformer} to each element of {@code parameters}.
 *
 * @param returnType  declared return type
 * @param methodName  method name
 * @param parameters  source objects, one per parameter
 * @param transformer converts a source object into a {@link Parameter}
 * @return the result of beginMethod
 * @throws IOException if writing fails
 */
@Override
public <T> ScalaWriter beginPublicMethod(Type returnType, String methodName,
        Collection<T> parameters, Function<T, Parameter> transformer) throws IOException {
    final Parameter[] args = transform(parameters, transformer);
    return beginMethod(DEF, returnType, methodName, args);
}
/**
 * Opens a method introduced by {@code def}; no visibility keyword is written.
 *
 * @param returnType declared return type
 * @param methodName method name
 * @param args       method parameters
 * @return the result of beginMethod
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter beginPublicMethod(Type returnType, String methodName, Parameter... args)
throws IOException {
return beginMethod(DEF, returnType, methodName, args);
}
/**
 * Opens an {@code override def} method whose parameters are produced by applying
 * {@code transformer} to each element of {@code parameters}.
 *
 * @param returnType  declared return type
 * @param methodName  method name
 * @param parameters  source objects, one per parameter
 * @param transformer converts a source object into a {@link Parameter}
 * @return the result of beginMethod
 * @throws IOException if writing fails
 */
public <T> ScalaWriter beginOverridePublicMethod(Type returnType, String methodName,
        Collection<T> parameters, Function<T, Parameter> transformer)
        throws IOException {
    final Parameter[] args = transform(parameters, transformer);
    return beginMethod(OVERRIDE_DEF, returnType, methodName, args);
}
/**
 * Opens a method introduced by {@code override def}.
 *
 * @param returnType declared return type
 * @param methodName method name
 * @param args       method parameters
 * @return the result of beginMethod
 * @throws IOException if writing fails
 */
public ScalaWriter beginOverridePublicMethod(Type returnType, String methodName, Parameter... args)
throws IOException {
return beginMethod(OVERRIDE_DEF, returnType, methodName, args);
}
/**
 * Opens a "static" method. It is written with a plain {@code def}, exactly like
 * a public method; no static marker is emitted.
 *
 * @param returnType  declared return type
 * @param methodName  method name
 * @param parameters  source objects, one per parameter
 * @param transformer converts a source object into a {@link Parameter}
 * @return the result of beginMethod
 * @throws IOException if writing fails
 */
@Override
public <T> ScalaWriter beginStaticMethod(Type returnType, String methodName,
        Collection<T> parameters, Function<T, Parameter> transformer) throws IOException {
    final Parameter[] args = transform(parameters, transformer);
    return beginMethod(DEF, returnType, methodName, args);
}
/**
 * Opens a "static" method. It is written with a plain {@code def}; no static
 * marker is emitted.
 *
 * @param returnType declared return type
 * @param methodName method name
 * @param args       method parameters
 * @return the result of beginMethod
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter beginStaticMethod(Type returnType, String methodName, Parameter... args)
throws IOException {
return beginMethod(DEF, returnType, methodName, args);
}
/**
 * Writes a body-less case class on a single line: {@code case class <header>(params)}.
 *
 * @param header     text placed after the {@code case class} keywords
 * @param parameters constructor parameters
 * @return this writer
 * @throws IOException if writing fails
 */
public ScalaWriter caseClass(String header, Parameter... parameters) throws IOException {
    final ScalaWriter declaration = beginLine(CASE_CLASS, header);
    declaration.params(parameters).nl();
    return this;
}
/**
 * Closes the current block: calls goOut(), writes a line containing {@code }} and
 * follows it with a newline.
 *
 * @return the result of the final nl() call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter end() throws IOException {
    goOut();
    final ScalaWriter closed = line("}");
    return closed.nl();
}
/**
 * Writes a mutable field declaration {@code var name: Type}. In non-compact mode
 * an extra newline follows.
 *
 * @param type field type
 * @param name field name; escaped if it is a reserved word
 * @return this writer in compact mode, otherwise the result of nl()
 * @throws IOException if writing fails
 */
public ScalaWriter field(Type type, String name) throws IOException {
    line(VAR, escape(name), ": ", getGenericName(true, type));
    if (compact) {
        return this;
    }
    return nl();
}
/**
 * Writes a field declaration {@code <modifier>name: Type}. In non-compact mode an
 * extra newline follows.
 *
 * @param modifier text written before the name (e.g. {@code "val "})
 * @param type     field type
 * @param name     field name; escaped if it is a reserved word
 * @return this writer in compact mode, otherwise the result of nl()
 * @throws IOException if writing fails
 */
private ScalaWriter field(String modifier, Type type, String name) throws IOException {
    line(modifier, escape(name), ": ", getGenericName(true, type));
    if (compact) {
        return this;
    }
    return nl();
}
/**
 * Writes an initialized field declaration: modifier, name, {@code ": "}, type,
 * the {@code ASSIGN} symbol and the value expression. In non-compact mode an
 * extra newline follows.
 *
 * @param modifier text written before the name (e.g. {@code "val "})
 * @param type     field type
 * @param name     field name; escaped if it is a reserved word
 * @param value    initializer expression, written verbatim
 * @return this writer in compact mode, otherwise the result of nl()
 * @throws IOException if writing fails
 */
private ScalaWriter field(String modifier, Type type, String name, String value)
        throws IOException {
    line(modifier, escape(name), ": ", getGenericName(true, type), ASSIGN, value);
    if (compact) {
        return this;
    }
    return nl();
}
/**
 * Returns the Scala class literal for a class name, i.e. {@code classOf[Name]}.
 *
 * @param className name to place inside the brackets
 * @return the class-literal expression
 */
@Override
public String getClassConstant(String className) {
    final StringBuilder literal = new StringBuilder("classOf[");
    literal.append(className).append(']');
    return literal.toString();
}
/**
 * Returns the Scala rendering of a type including its type arguments, e.g.
 * {@code Map[String, Int]}.
 *
 * <p>A type without parameters is rendered by {@link #getRawName(Type)} alone. A
 * type argument that is null, or whose full name equals the full name of
 * {@code type} itself, is rendered as the wildcard {@code _}.</p>
 *
 * @param asArgType not consulted by this implementation
 * @param type      type to render
 * @return the rendered type name
 */
@Override
public String getGenericName(boolean asArgType, Type type) {
    if (type.getParameters().isEmpty()) {
        return getRawName(type);
    }
    final StringBuilder rendered = new StringBuilder();
    rendered.append(getRawName(type)).append('[');
    final String ownName = type.getFullName();
    String separator = "";
    for (Type argument : type.getParameters()) {
        rendered.append(separator);
        if (argument != null && !argument.getFullName().equals(ownName)) {
            rendered.append(getGenericName(false, argument));
        } else {
            rendered.append("_");
        }
        separator = ", ";
    }
    return rendered.append(']').toString();
}
/**
 * Returns the Scala rendering of a type without type arguments.
 *
 * <ul>
 * <li>Java primitive names are capitalized.</li>
 * <li>Dots after the package part are replaced by {@code $}.</li>
 * <li>The package prefix is dropped when the package or the class is registered.</li>
 * <li>A name ending in {@code []} is rendered as {@code Array[Element]}.</li>
 * </ul>
 *
 * <p>A null package name is treated like the empty (default) package.</p>
 *
 * @param type type to render
 * @return the rendered raw type name
 */
@Override
public String getRawName(Type type) {
    String fullName = type.getFullName();
    if (PRIMITIVE_TYPES.contains(fullName)) {
        fullName = StringUtils.capitalize(fullName);
    }
    // Normalize once so that every later use is null-safe.
    String packageName = type.getPackageName();
    if (packageName == null) {
        packageName = "";
    }
    final boolean hasPackage = packageName.length() > 0;
    if (hasPackage) {
        fullName = packageName + "."
                + fullName.substring(packageName.length() + 1).replace('.', '$');
    } else {
        fullName = fullName.replace('.', '$');
    }
    String rv = fullName;
    if (type.isPrimitive() && !hasPackage) {
        rv = Character.toUpperCase(rv.charAt(0)) + rv.substring(1);
    }
    if (hasPackage && (packages.contains(packageName) || classes.contains(fullName))) {
        rv = fullName.substring(packageName.length() + 1);
    }
    if (!rv.endsWith("[]")) {
        return rv;
    }
    rv = rv.substring(0, rv.length() - 2);
    if (PRIMITIVE_TYPES.contains(rv)) {
        rv = StringUtils.capitalize(rv);
    } else if (hasPackage && classes.contains(rv)) {
        // Only strip a prefix when there is one; with an empty package the
        // unguarded substring would drop the first character of the name.
        rv = rv.substring(packageName.length() + 1);
    }
    return "Array[" + rv + "]";
}
/**
 * Writes one {@code import} line per class and registers each class so that it
 * is subsequently rendered without package prefix. A newline follows the group.
 *
 * @param imports classes to import
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter imports(Class<?>... imports) throws IOException {
    for (int i = 0; i < imports.length; i++) {
        final Class<?> imported = imports[i];
        classes.add(imported.getName());
        line(IMPORT, imported.getName());
    }
    nl();
    return this;
}
/**
 * Writes one wildcard import ({@code import pkg._}) per package and registers
 * each package so that its types are subsequently rendered without prefix. A
 * newline follows the group.
 *
 * @param imports packages to import
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter imports(Package... imports) throws IOException {
    for (int i = 0; i < imports.length; i++) {
        final Package imported = imports[i];
        packages.add(imported.getName());
        line(IMPORT, imported.getName(), "._");
    }
    nl();
    return this;
}
/**
 * Writes one {@code import} line per class name and registers each name so that
 * the class is subsequently rendered without package prefix. A newline follows
 * the group.
 *
 * @param imports class names to import
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter importClasses(String... imports) throws IOException {
    for (int i = 0; i < imports.length; i++) {
        final String className = imports[i];
        classes.add(className);
        line(IMPORT, className);
    }
    nl();
    return this;
}
/**
 * Writes one wildcard import ({@code import pkg._}) per package name and
 * registers each name so that its types are subsequently rendered without
 * prefix. A newline follows the group.
 *
 * @param imports package names to import
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter importPackages(String... imports) throws IOException {
    for (int i = 0; i < imports.length; i++) {
        final String packageName = imports[i];
        packages.add(packageName);
        line(IMPORT, packageName, "._");
    }
    nl();
    return this;
}
/**
 * Writes a documentation comment: an opening line, one {@code " * "}-prefixed
 * line per entry, and a closing line.
 *
 * @param lines comment text, one entry per output line
 * @return the result of writing the closing line
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter javadoc(String... lines) throws IOException {
    line("/**");
    for (int i = 0; i < lines.length; i++) {
        line(" * ", lines[i]);
    }
    return line(" */");
}
/**
 * Writes the {@code package} declaration followed by a newline and registers the
 * package so that its types are rendered without prefix.
 *
 * @param packageName package of the generated source
 * @return the result of the final nl() call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter packageDecl(String packageName) throws IOException {
    packages.add(packageName);
    final ScalaWriter declaration = line(PACKAGE, packageName);
    return declaration.nl();
}
/**
 * Appends a parenthesized parameter list, converting each element with
 * {@code transformer} and separating entries with the {@code COMMA} symbol.
 *
 * @param parameters  source objects, one per parameter
 * @param transformer converts a source object into a {@link Parameter}
 * @return this writer
 * @throws IOException if writing fails
 */
private <T> ScalaWriter params(Collection<T> parameters, Function<T, Parameter> transformer)
        throws IOException {
    append("(");
    int written = 0;
    for (T source : parameters) {
        if (written > 0) {
            append(COMMA);
        }
        param(transformer.apply(source));
        written++;
    }
    append(")");
    return this;
}
/**
 * Appends a parenthesized parameter list, separating entries with the
 * {@code COMMA} symbol.
 *
 * @param params parameters in declaration order
 * @return this writer
 * @throws IOException if writing fails
 */
private ScalaWriter params(Parameter... params) throws IOException {
    append("(");
    boolean needsSeparator = false;
    for (Parameter current : params) {
        if (needsSeparator) {
            append(COMMA);
        }
        param(current);
        needsSeparator = true;
    }
    append(")");
    return this;
}
/**
 * Appends a single parameter as {@code name: Type}; the name is escaped if it is
 * a reserved word.
 *
 * @param parameter parameter to render
 * @return this writer
 * @throws IOException if writing fails
 */
private ScalaWriter param(Parameter parameter) throws IOException {
    // Chained in the original order: name, separator, then the type name.
    append(escape(parameter.getName()))
            .append(": ")
            .append(getGenericName(true, parameter.getType()));
    return this;
}
/**
 * Writes a private mutable field declaration {@code private var name: Type}.
 *
 * <p>The {@code var} keyword is required: a Scala member declaration consisting
 * of a visibility modifier and a name alone is not valid.</p>
 *
 * @param type field type
 * @param name field name
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter privateField(Type type, String name) throws IOException {
    return field(PRIVATE + VAR, type, name);
}
/**
 * Writes an uninitialized {@code private val name: Type} declaration.
 *
 * @param type field type
 * @param name field name
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter privateFinal(Type type, String name) throws IOException {
return field(PRIVATE_VAL, type, name);
}
/**
 * Writes an initialized {@code private val} declaration.
 *
 * @param type  field type
 * @param name  field name
 * @param value initializer expression, written verbatim
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter privateFinal(Type type, String name, String value) throws IOException {
return field(PRIVATE_VAL, type, name, value);
}
/**
 * Writes an initialized {@code private val} declaration; the output is the same
 * as for the non-static variant, no static marker is emitted.
 *
 * @param type  field type
 * @param name  field name
 * @param value initializer expression, written verbatim
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter privateStaticFinal(Type type, String name, String value) throws IOException {
return field(PRIVATE_VAL, type, name, value);
}
/**
 * Writes a protected mutable field declaration {@code protected var name: Type}.
 *
 * <p>The {@code var} keyword is required: a Scala member declaration consisting
 * of a visibility modifier and a name alone is not valid.</p>
 *
 * @param type field type
 * @param name field name
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter protectedField(Type type, String name) throws IOException {
    return field(PROTECTED + VAR, type, name);
}
/**
 * Writes an uninitialized {@code protected val name: Type} declaration.
 *
 * @param type field type
 * @param name field name
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter protectedFinal(Type type, String name) throws IOException {
return field(PROTECTED_VAL, type, name);
}
/**
 * Writes an initialized {@code protected val} declaration.
 *
 * @param type  field type
 * @param name  field name
 * @param value initializer expression, written verbatim
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter protectedFinal(Type type, String name, String value) throws IOException {
return field(PROTECTED_VAL, type, name, value);
}
/**
 * Writes an uninitialized {@code var name: Type} declaration; no visibility
 * keyword is emitted.
 *
 * @param type field type
 * @param name field name
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter publicField(Type type, String name) throws IOException {
return field(VAR, type, name);
}
/**
 * Writes an initialized {@code var} declaration; no visibility keyword is emitted.
 *
 * @param type  field type
 * @param name  field name
 * @param value initializer expression, written verbatim
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter publicField(Type type, String name, String value) throws IOException {
return field(VAR, type, name, value);
}
/**
 * Writes an uninitialized {@code val name: Type} declaration; no visibility
 * keyword is emitted.
 *
 * @param type field type
 * @param name field name
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter publicFinal(Type type, String name) throws IOException {
return field(VAL, type, name);
}
/**
 * Writes an initialized {@code val} declaration; no visibility keyword is emitted.
 *
 * @param type  field type
 * @param name  field name
 * @param value initializer expression, written verbatim
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter publicFinal(Type type, String name, String value) throws IOException {
return field(VAL, type, name, value);
}
/**
 * Writes an initialized {@code val} declaration; the output is the same as for
 * the non-static variant, no static marker is emitted.
 *
 * @param type  field type
 * @param name  field name
 * @param value initializer expression, written verbatim
 * @return the result of the underlying field call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter publicStaticFinal(Type type, String name, String value) throws IOException {
return field(VAL, type, name, value);
}
/**
 * Writes, for each class, an import of all its members: {@code import a.b.C._;}.
 * Unlike the other import methods, nothing is registered and no trailing newline
 * is added.
 *
 * @param imports classes whose members are imported
 * @return this writer
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter staticimports(Class<?>... imports) throws IOException {
    for (int i = 0; i < imports.length; i++) {
        line(IMPORT_STATIC, imports[i].getName(), "._;");
    }
    return this;
}
/**
 * Writes the line {@code @SuppressWarnings("<type>")}; the argument is inserted
 * between the quotes without escaping.
 *
 * @param type warning name to suppress
 * @return the result of the line call
 * @throws IOException if writing fails
 */
@Override
public ScalaWriter suppressWarnings(String type) throws IOException {
return line("@SuppressWarnings(\"", type, "\")");
}
/**
 * Writes a warning-suppression annotation for several warning names by wrapping
 * them in a {@code MultiSuppressWarnings} instance and passing it to
 * {@code annotation(Annotation)}.
 *
 * @param types warning names to suppress
 * @return the result of the annotation call
 * @throws IOException if writing fails
 */
@Override
public CodeWriter suppressWarnings(String... types) throws IOException {
return annotation(new MultiSuppressWarnings(types));
}
/**
 * Converts a collection of source objects into a parameter array, applying
 * {@code transformer} to each element in iteration order.
 *
 * @param parameters  source objects
 * @param transformer converts a source object into a {@link Parameter}
 * @return array with one entry per source object
 */
private <T> Parameter[] transform(Collection<T> parameters,
        Function<T, Parameter> transformer) {
    final Parameter[] converted = new Parameter[parameters.size()];
    int index = 0;
    for (T source : parameters) {
        converted[index] = transformer.apply(source);
        index++;
    }
    return converted;
}
/**
 * Wraps the token in backticks when {@code ScalaSyntaxUtils.isReserved} reports
 * it as reserved; otherwise returns it unchanged.
 *
 * @param token identifier to check
 * @return the token, backtick-quoted if reserved
 */
private String escape(String token) {
    return ScalaSyntaxUtils.isReserved(token) ? "`" + token + "`" : token;
}
}