/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.codegen;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import javax.tools.Diagnostic;
import javax.tools.Diagnostic.Kind;
import javax.tools.DiagnosticCollector;
import javax.tools.JavaCompiler;
import javax.tools.JavaCompiler.CompilationTask;
import javax.tools.JavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.ToolProvider;
import org.apache.commons.io.IOUtils;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.codegen.SpoofCompiler;
import org.apache.sysml.hops.codegen.SpoofCompiler.CompilerType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.io.IOUtilFunctions;
import org.apache.sysml.runtime.util.LocalFileUtils;
import org.apache.sysml.utils.Statistics;
import org.codehaus.janino.SimpleCompiler;
public class CodegenUtils 
{
	//cache to reuse compiled and loaded classes (keyed by class name)
	private static ConcurrentHashMap<String, Class<?>> _cache = new ConcurrentHashMap<String,Class<?>>();
	
	//janino-specific map of source code transfer/recompile on-demand
	private static ConcurrentHashMap<String, String> _src = new ConcurrentHashMap<String,String>();
	
	//javac-specific working directory for src/class files; lazily initialized
	//by the synchronized createWorkingDir(), hence volatile for unsynchronized reads
	private static volatile String _workingDir = null;
	
	/**
	 * Compiles the given java source code into a class, or returns the
	 * already cached class of the same name. The compiler (janino or javac)
	 * is selected via {@code SpoofCompiler.JAVA_COMPILER}.
	 * 
	 * @param name class name (as used for loading the class)
	 * @param src java source code of the class
	 * @return compiled and loaded class
	 * @throws DMLRuntimeException if compilation or class loading fails
	 */
	public static Class<?> compileClass(String name, String src) 
		throws DMLRuntimeException
	{
		//reuse existing compiled class
		Class<?> ret = _cache.get(name);
		if( ret != null ) 
			return ret;
		
		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
		
		//compile java source w/ specific compiler
		if( SpoofCompiler.JAVA_COMPILER == CompilerType.JANINO )
			ret = compileClassJanino(name, src);
		else
			ret = compileClassJavac(name, src);
		
		//keep compiled class for reuse; if another thread concurrently compiled
		//the same class, hand out the cached instance so all callers share it
		Class<?> prev = _cache.putIfAbsent(name, ret);
		if( prev != null )
			ret = prev;
		
		if( DMLScript.STATISTICS ) {
			Statistics.incrementCodegenClassCompile();
			Statistics.incrementCodegenClassCompileTime(System.nanoTime()-t0);
		}
		
		return ret;
	}
	
	/**
	 * Obtains the class of the given name without explicit class data,
	 * i.e., from the class cache, the retained janino source, or the
	 * javac working directory.
	 * 
	 * @param name class name
	 * @return loaded class
	 * @throws DMLRuntimeException if the class cannot be obtained
	 */
	public static Class<?> getClass(String name) throws DMLRuntimeException {
		return getClass(name, null);
	}
	
	/**
	 * Obtains the class of the given name, reusing the class cache if possible.
	 * With janino, {@code classBytes} is interpreted as UTF-8 encoded java source
	 * code (as produced by {@link #getClassData(String)}); with javac, it is the
	 * byte representation of the class file.
	 * 
	 * @param name class name
	 * @param classBytes source code (janino) or class file bytes (javac), may be null
	 * @return loaded class
	 * @throws DMLRuntimeException if the class cannot be compiled or loaded
	 */
	public static Class<?> getClass(String name, byte[] classBytes) 
		throws DMLRuntimeException 
	{
		//reuse existing compiled class
		Class<?> ret = _cache.get(name);
		if( ret != null ) 
			return ret;
		
		//get class in a compiler-specific manner
		if( SpoofCompiler.JAVA_COMPILER == CompilerType.JANINO ) {
			//fall back to retained source code if no class data was given
			String src = (classBytes != null) ?
				new String(classBytes, StandardCharsets.UTF_8) : _src.get(name);
			if( src == null )
				throw new DMLRuntimeException("No source code available for class "+name+".");
			ret = compileClassJanino(name, src);
		}
		else
			ret = loadFromClassFile(name, classBytes);
		
		//keep loaded class for reuse (first registered instance wins)
		Class<?> prev = _cache.putIfAbsent(name, ret);
		return (prev != null) ? prev : ret;
	}
	
	/**
	 * Returns a byte representation of the given class for transfer: the
	 * UTF-8 encoded source code with janino, or the class file with javac.
	 * 
	 * @param name class name
	 * @return source code bytes (janino) or class file bytes (javac)
	 * @throws DMLRuntimeException if no data is available for the class
	 */
	public static byte[] getClassData(String name) 
		throws DMLRuntimeException
	{
		//get class in a compiler-specific manner
		if( SpoofCompiler.JAVA_COMPILER == CompilerType.JANINO ) {
			String src = _src.get(name);
			if( src == null )
				throw new DMLRuntimeException("No source code available for class "+name+".");
			return src.getBytes(StandardCharsets.UTF_8);
		}
		else
			return getClassAsByteArray(name);
	}
	
	/**
	 * Clears the class cache and the retained janino source code.
	 */
	public static void clearClassCache() {
		_cache.clear();
		_src.clear();
	}
	
	/**
	 * Removes all class cache entries that map to the given class.
	 * 
	 * @param cla class to remove from the cache
	 */
	public static void clearClassCache(Class<?> cla) {
		//one-pass, in-place filtering of class cache
		Iterator<Entry<String,Class<?>>> iter = _cache.entrySet().iterator();
		while( iter.hasNext() )
			if( iter.next().getValue()==cla )
				iter.remove();
	}
	
	/**
	 * Creates a new instance of the given class via its no-arg constructor
	 * and casts it to {@link SpoofOperator}.
	 * 
	 * @param cla generated operator class
	 * @return new operator instance
	 * @throws DMLRuntimeException if instantiation or the cast fails
	 */
	public static SpoofOperator createInstance(Class<?> cla) 
		throws DMLRuntimeException 
	{
		SpoofOperator ret = null;
		
		try {
			ret = (SpoofOperator) cla.newInstance();	
		}
		catch( Exception ex ) {
			throw new DMLRuntimeException(ex);
		}
		
		return ret;
	}
	
	////////////////////////////
	//JANINO-specific methods (used for spark environments)

	private static Class<?> compileClassJanino(String name, String src) 
		throws DMLRuntimeException
	{
		try {
			//compile source code
			SimpleCompiler compiler = new SimpleCompiler();
			compiler.cook(src);
			
			//keep source code for later re-construction
			_src.put(name, src);
			
			//load compiled class
			return compiler.getClassLoader()
				.loadClass(name);
		}
		catch(Exception ex) {
			throw new DMLRuntimeException("Failed to compile class "+name+".", ex);
		}
	}
	
	////////////////////////////
	//JAVAC-specific methods (used for hadoop environments)

	private static Class<?> compileClassJavac(String name, String src) 
		throws DMLRuntimeException
	{
		StandardJavaFileManager fileManager = null;
		try
		{
			//create working dir on demand
			if( _workingDir == null )
				createWorkingDir();
			
			//write input file (for debugging / classpath handling)
			File ftmp = new File(_workingDir+"/"+name.replace(".", "/")+".java");
			if( !ftmp.getParentFile().exists() )
				ftmp.getParentFile().mkdirs();
			LocalFileUtils.writeTextFile(ftmp, src);
			
			//get system java compiler
			JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
			if( compiler == null )
				throw new RuntimeException("Unable to obtain system java compiler.");
			
			//prepare file manager (closed in finally to release its resources)
			DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<JavaFileObject>(); 
			fileManager = compiler.getStandardFileManager(diagnostics, null, null);
			
			//prepare input source code
			Iterable<? extends JavaFileObject> sources = fileManager
					.getJavaFileObjectsFromFiles(Arrays.asList(ftmp));
			
			//prepare class path
			URL runDir = getRunDir();
			String classpath = System.getProperty("java.class.path") + 
					File.pathSeparator + runDir.getPath();
			List<String> options = Arrays.asList("-classpath",classpath);
			
			//compile source code
			CompilationTask task = compiler.getTask(null, fileManager, diagnostics, options, null, sources);
			Boolean success = task.call();
			
			//output diagnostics and error handling
			for(Diagnostic<? extends JavaFileObject> tmp : diagnostics.getDiagnostics())
				if( tmp.getKind()==Kind.ERROR )
					System.err.println("ERROR: "+tmp.toString());				
			if( success == null || !success )
				throw new RuntimeException("Failed to compile class "+name);
			
			//dynamically load compiled class
			//NOTE(review): the loader is closed right after loadClass; classes the
			//generated class resolves lazily later (if any) can no longer be loaded
			//through it -- confirm generated code never needs this.
			URLClassLoader classLoader = null;
			try {
				classLoader = new URLClassLoader(getClassPathURLs(), 
					CodegenUtils.class.getClassLoader());
				return classLoader.loadClass(name);
			}
			finally {
				IOUtilFunctions.closeSilently(classLoader);
			}
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}
		finally {
			IOUtilFunctions.closeSilently(fileManager);
		}
	}
	
	private static Class<?> loadFromClassFile(String name, byte[] classBytes) 
		throws DMLRuntimeException 
	{
		if(classBytes != null) {
			//load from byte representation of class file
			try(ByteClassLoader byteLoader = new ByteClassLoader(new URL[]{}, 
				CodegenUtils.class.getClassLoader(), classBytes)) 
			{
				return byteLoader.findClass(name);
			} 
			catch (Exception e) {
				throw new DMLRuntimeException(e);
			}
		}
		else {
			//load compiled class file from the working dir / run dir
			try(URLClassLoader classLoader = new URLClassLoader(getClassPathURLs(), 
				CodegenUtils.class.getClassLoader())) 
			{
				return classLoader.loadClass(name);
			} 
			catch (Exception e) {
				throw new DMLRuntimeException(e);
			}	
		}
	}
	
	private static byte[] getClassAsByteArray(String name) 
		throws DMLRuntimeException
	{
		String classAsPath = name.replace('.', '/') + ".class";
		
		URLClassLoader classLoader = null;
		InputStream stream = null;
		
		try {
			//dynamically load compiled class
			classLoader = new URLClassLoader(getClassPathURLs(), 
				CodegenUtils.class.getClassLoader());
			stream = classLoader.getResourceAsStream(classAsPath);
			if( stream == null )
				throw new IOException("Class file "+classAsPath+" not found.");
			return IOUtils.toByteArray(stream);
		} 
		catch (IOException e) {
			throw new DMLRuntimeException(e);
		}
		finally {
			//close the stream before the loader it was obtained from
			IOUtilFunctions.closeSilently(stream);
			IOUtilFunctions.closeSilently(classLoader);
		}
	}
	
	/** Returns the code source location (jar or class dir) of this class. */
	private static URL getRunDir() {
		return CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
	}
	
	/**
	 * Returns the class path URLs (javac working dir and run dir) used to
	 * load generated classes.
	 * 
	 * @throws IOException if the working dir has not been created yet or
	 *   cannot be converted into a URL
	 */
	private static URL[] getClassPathURLs() throws IOException {
		String wdir = _workingDir;
		if( wdir == null )
			throw new IOException("Codegen working directory has not been created.");
		return new URL[]{new File(wdir).toURI().toURL(), getRunDir()};
	}
	
	/**
	 * Creates the javac working directory once; synchronized so concurrent
	 * first-time compilations do not race on the initialization.
	 */
	private static synchronized void createWorkingDir() throws DMLRuntimeException {
		if( _workingDir != null )
			return;
		String tmp = LocalFileUtils.getWorkingDir(LocalFileUtils.CATEGORY_CODEGEN);
		LocalFileUtils.createLocalFileIfNotExist(tmp);
		_workingDir = tmp;
	}
}