/**
* Copyright 2013, Landz and its contributors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package z.znr.invoke.linux.x64;
import com.kenai.jffi.*;
import jnr.udis86.X86Disassembler;
import jnr.x86asm.Assembler;
import jnr.x86asm.REG;
import jnr.x86asm.Register;
import z.znr.InlineAssembler;
import z.znr.invoke.types.ParameterType;
import z.znr.invoke.types.ResultType;
import java.io.PrintStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;
import static jnr.x86asm.Asm.*;
import static z.znr.invoke.linux.x64.CodegenUtils.sig;
/**
 * Compiles method trampoline stubs for x86_64 (System V AMD64 calling convention).
 * <p>
 * {@link #compile} assembles one stub per native method and queues it; {@link #attach}
 * copies all queued stubs into freshly allocated pages, makes those pages executable and
 * registers the stubs as the native method implementations of a class.
 */
final class X64StubCompiler extends StubCompiler {
    /**
     * When the system property {@code jnr.invoke.compile.dump} is true (and the udis86
     * disassembler is available), {@link #attach} prints a disassembly of every stub to stderr.
     */
    public static final boolean DEBUG = Boolean.getBoolean("jnr.invoke.compile.dump");

    /** Integer/pointer arguments passed in registers: %rdi, %rsi, %rdx, %rcx, %r8, %r9. */
    private static final int MAX_INT_REGISTER_ARGS = 6;

    /** Floating point arguments passed in registers: %xmm0..%xmm7. */
    private static final int MAX_FLOAT_REGISTER_ARGS = 8;

    /**
     * Reports whether {@link #compile} can generate a stub for the given native signature.
     * <p>
     * The result must be void or a scalar integer, pointer, float or double. Every parameter
     * must be a scalar integer, pointer, float or double. There may be at most 6
     * integer/pointer parameters and at most 8 floating point parameters, so that all
     * arguments travel in registers.
     * <p>
     * This method does not inspect any calling convention.
     *
     * @param returnType     native result type of the target function
     * @param parameterTypes native parameter types of the target function
     * @return {@code true} if the signature is supported
     */
    @Override
    boolean canCompile(ResultType returnType, ParameterType[] parameterTypes) {
        if (!isSupportedResultType(returnType)) {
            return false;
        }
        for (ParameterType t : parameterTypes) {
            if (!isIntegerClass(t) && !isFloatClass(t)) {
                // Fail on any other native parameter type
                return false;
            }
        }
        // We can only safely compile methods with up to 6 integer and 8 floating point parameters
        return iCount(parameterTypes) <= MAX_INT_REGISTER_ARGS
                && fCount(parameterTypes) <= MAX_FLOAT_REGISTER_ARGS;
    }

    /** Returns true if the native result type is one a stub can return. */
    private static boolean isSupportedResultType(ResultType returnType) {
        switch (returnType.nativeType()) {
            case VOID:
            case SCHAR:
            case UCHAR:
            case SSHORT:
            case USHORT:
            case SINT:
            case UINT:
            case SLONG:
            case ULONG:
            case SLONG_LONG:
            case ULONG_LONG:
            case FLOAT:
            case DOUBLE:
            case POINTER:
                return true;
            default:
                return false;
        }
    }

    /** Returns true if the parameter is an integer or pointer (passed in a general-purpose register). */
    private static boolean isIntegerClass(ParameterType t) {
        switch (t.nativeType()) {
            case SCHAR:
            case UCHAR:
            case SSHORT:
            case USHORT:
            case SINT:
            case UINT:
            case SLONG:
            case ULONG:
            case SLONG_LONG:
            case ULONG_LONG:
            case POINTER:
                return true;
            default:
                return false;
        }
    }

    /** Returns true if the parameter is a float or double (passed in an %xmm register). */
    private static boolean isFloatClass(ParameterType t) {
        switch (t.nativeType()) {
            case FLOAT:
            case DOUBLE:
                return true;
            default:
                return false;
        }
    }

    // Register tables for argument shuffling. They are not referenced by the code in this class.
    // NOTE(review): keep only if other classes in this package still use them.
    static final Register[] srcRegisters8 = { dl, cl, r8b, r9b };
    static final Register[] srcRegisters16 = { dx, cx, r8w, r9w };
    static final Register[] srcRegisters32 = { edx, ecx, Register.gpr(REG.REG_R8D), Register.gpr(REG.REG_R9D) };
    static final Register[] srcRegisters64 = { rdx, rcx, r8, r9 };
    static final Register[] dstRegisters32 = { edi, esi, edx, ecx, Register.gpr(REG.REG_R8D), Register.gpr(REG.REG_R9D) };
    static final Register[] dstRegisters64 = { rdi, rsi, rdx, rcx, r8, r9 };

    /**
     * Assembles a stub for one native method and queues it for {@link #attach}.
     * <p>
     * When errno does not have to be saved and the native result register already holds
     * the value in the form the Java result class expects, the inline assembler's code is
     * emitted directly as the stub. Otherwise a trampoline is generated. It calls the inline
     * code and then either saves errno (preserving the return value) or sign/zero-extends
     * the result.
     *
     * @throws IllegalArgumentException if more than 6 integer or 8 floating point arguments are given
     */
    @Override
    final void compile(InlineAssembler inlineAssembler, String name, ResultType resultType, ParameterType[] parameterTypes,
            Class resultClass, Class[] parameterClasses, boolean saveErrno) {
        int iCount = iCount(parameterTypes);
        int fCount = fCount(parameterTypes);

        // All arguments must already sit in their argument registers; stack-passed arguments are not supported.
        if (iCount > MAX_INT_REGISTER_ARGS) {
            throw new IllegalArgumentException("integer argument count > 6");
        }
        if (fCount > MAX_FLOAT_REGISTER_ARGS) {
            throw new IllegalArgumentException("float argument count > 8");
        }

        Assembler a = new Assembler(X86_64);

        // NOTE(review): no argument shuffling is emitted. Before the call the stub only adjusts
        // %rsp and clears %eax, so %rdi..%r9 and %xmm0..%xmm7 reach the target unchanged.
        // Confirm that callers of these stubs already place the native arguments there
        // (i.e. no JNIEnv*/jobject prefix needs to be dropped).

        if (!saveErrno && isDirectlyReturnable(resultType, resultClass)) {
            inlineAssembler.assemble(a);
            stubs.add(new Stub(name, sig(resultClass, parameterClasses), a));
            return;
        }

        // Need to align the stack to 16 bytes for function call.
        // It already has 8 bytes pushed (the return address), so 8 + 8 = 16 and 8 + 24 = 32
        // both keep %rsp 16-byte aligned while reserving a slot at 0(%rsp) for the return value.
        int space = (resultClass == float.class || resultClass == double.class) ? 24 : 8;
        a.sub(rsp, imm(space));

        // Clear %rax, since it is used by varargs functions to determine the number of float registers to be saved
        a.xor_(eax, eax);

        // Call to the actual native function
        long function = Util.inlineAssemblerToCodeAddress(inlineAssembler).address();
        a.call(imm(function));

        if (saveErrno) {
            // The errno helper may clobber %rax/%xmm0, so spill the result around the call.
            storeResult(a, resultType);
            a.call(imm(errnoFunctionAddress));
            reloadResult(a, resultType);
        } else {
            extendResult(a, resultType, resultClass);
        }

        // Restore rsp to original position
        a.add(rsp, imm(space));
        a.ret();
        stubs.add(new Stub(name, sig(resultClass, parameterClasses), a));
    }

    /**
     * Returns true when the native result needs no post-processing for the given Java result class.
     * This holds for void, and for int/long/float/double results whose Java class matches exactly.
     */
    private static boolean isDirectlyReturnable(ResultType resultType, Class resultClass) {
        switch (resultType.nativeType()) {
            case SINT:
            case UINT:
                return int.class == resultClass;
            case SLONG_LONG:
            case ULONG_LONG:
                return long.class == resultClass;
            case FLOAT:
                return float.class == resultClass;
            case DOUBLE:
                return double.class == resultClass;
            case VOID:
                return true;
            default:
                return false;
        }
    }

    /** Spills the native return value (%xmm0 for float/double, %rax otherwise) to 0(%rsp). */
    private static void storeResult(Assembler a, ResultType resultType) {
        switch (resultType.nativeType()) {
            case VOID:
                // No need to save/reload return value registers
                break;
            case FLOAT:
                a.movss(dword_ptr(rsp, 0), xmm0);
                break;
            case DOUBLE:
                a.movsd(qword_ptr(rsp, 0), xmm0);
                break;
            default:
                a.mov(qword_ptr(rsp, 0), rax);
                break;
        }
    }

    /** Reloads the value spilled by {@link #storeResult} into the return register, sign/zero-extending small integers. */
    private static void reloadResult(Assembler a, ResultType resultType) {
        switch (resultType.nativeType()) {
            case VOID:
                // No need to save/reload return value registers
                break;
            case SCHAR:
                a.movsx(rax, byte_ptr(rsp, 0));
                break;
            case UCHAR:
                a.movzx(rax, byte_ptr(rsp, 0));
                break;
            case SSHORT:
                a.movsx(rax, word_ptr(rsp, 0));
                break;
            case USHORT:
                a.movzx(rax, word_ptr(rsp, 0));
                break;
            case SINT:
                a.movsxd(rax, dword_ptr(rsp, 0));
                break;
            case UINT:
                // storing a value in eax zeroes out the upper 32 bits of rax
                a.mov(eax, dword_ptr(rsp, 0));
                break;
            case FLOAT:
                a.movss(xmm0, dword_ptr(rsp, 0));
                break;
            case DOUBLE:
                a.movsd(xmm0, qword_ptr(rsp, 0));
                break;
            default:
                a.mov(rax, qword_ptr(rsp, 0));
                break;
        }
    }

    /** Sign/zero-extends a narrow integer result in %rax; 32-bit results are widened only for a long result class. */
    private static void extendResult(Assembler a, ResultType resultType, Class resultClass) {
        switch (resultType.nativeType()) {
            case SCHAR:
                a.movsx(rax, al);
                break;
            case UCHAR:
                a.movzx(rax, al);
                break;
            case SSHORT:
                a.movsx(rax, ax);
                break;
            case USHORT:
                a.movzx(rax, ax);
                break;
            case SINT:
                if (long.class == resultClass) {
                    a.movsxd(rax, eax);
                }
                break;
            case UINT:
                if (long.class == resultClass) {
                    // a 32-bit mov zero-extends into the full 64-bit register
                    a.mov(eax, eax);
                }
                break;
            default:
                break;
        }
    }

    /** Counts the float/double parameters. */
    static int fCount(ParameterType[] parameterTypes) {
        int fCount = 0;
        for (ParameterType t : parameterTypes) {
            if (isFloatClass(t)) {
                ++fCount;
            }
        }
        return fCount;
    }

    /** Counts the integer and pointer parameters. */
    static int iCount(ParameterType[] parameterTypes) {
        int iCount = 0;
        for (ParameterType t : parameterTypes) {
            if (isIntegerClass(t)) {
                ++iCount;
            }
        }
        return iCount;
    }

    private static final class StaticDataHolder {
        // Keep a reference from the loaded class to the pages holding the code for that class.
        static final Map<Class<?>, PageHolder> PAGES
                = Collections.synchronizedMap(new WeakHashMap<Class<?>, PageHolder>());
    }

    /** Stubs assembled by {@link #compile} and not yet materialized by {@link #attach}. */
    final List<Stub> stubs = new ArrayList<Stub>();

    /** An assembled stub together with the Java method name and JVM signature it implements. */
    static final class Stub {
        final String name;
        final String signature;
        final Assembler assembler;

        public Stub(String name, String signature, Assembler assembler) {
            this.name = name;
            this.signature = signature;
            this.assembler = assembler;
        }
    }

    /** Owns a run of native pages and frees them when it is finalized. */
    static final class PageHolder {
        final PageManager pm;
        final long memory;
        final long pageCount;

        public PageHolder(PageManager pm, long memory, long pageCount) {
            this.pm = pm;
            this.memory = memory;
            this.pageCount = pageCount;
        }

        @Override
        protected void finalize() throws Throwable {
            try {
                pm.freePages(memory, (int) pageCount);
            } catch (Throwable t) {
                // java.util.logging formats with MessageFormat ({0}), not printf (%s); passing the
                // throwable itself logs the message, the cause and the stack trace.
                Logger.getLogger(getClass().getName()).log(Level.WARNING,
                        "Exception when freeing native pages", t);
            } finally {
                super.finalize();
            }
        }
    }

    /**
     * Copies all queued stubs into newly allocated executable pages and registers them as
     * native methods of {@code clazz}.
     *
     * @param clazz the class whose native methods the stubs implement
     * @return a {@link PageHolder} owning the code pages (also kept alive via a weak map keyed by
     *         {@code clazz}), or a plain {@code Object} if there were no stubs
     * @throws OutOfMemoryError if the code pages cannot be allocated
     */
    @Override
    Object attach(Class clazz) {
        if (stubs.isEmpty()) {
            return new Object();
        }

        long codeSize = 0;
        for (Stub stub : stubs) {
            // add 8 bytes per stub for the 8-byte start alignment applied below
            codeSize += stub.assembler.codeSize() + 8;
        }

        PageManager pm = PageManager.getInstance();
        long npages = (codeSize + pm.pageSize() - 1) / pm.pageSize();
        // Allocate writable memory first; it is switched to read+exec once all code is copied in.
        long code = pm.allocatePages((int) npages, PageManager.PROT_READ | PageManager.PROT_WRITE);
        if (code == 0) {
            throw new OutOfMemoryError("allocatePages failed for codeSize=" + codeSize);
        }
        PageHolder page = new PageHolder(pm, code, npages);

        if (DEBUG) {
            // keep previously buffered output ahead of the disassembly dump
            System.out.flush();
            System.err.flush();
        }

        // Now relocate/copy all the assembler stubs into the real code area
        List<NativeMethod> methods = new ArrayList<NativeMethod>(stubs.size());
        long fn = code;
        for (Stub stub : stubs) {
            Assembler asm = stub.assembler;
            // align the start of all functions on a 8 byte boundary
            fn = align(fn, 8);
            ByteBuffer buf = ByteBuffer.allocate(asm.codeSize()).order(ByteOrder.LITTLE_ENDIAN);
            asm.relocCode(buf, fn);
            buf.flip();
            MemoryIO.getInstance().putByteArray(fn, buf.array(), buf.arrayOffset(), buf.limit());

            if (DEBUG && X86Disassembler.isAvailable()) {
                dumpStub(System.err, clazz, stub, fn, buf);
            }

            methods.add(new NativeMethod(fn, stub.name, stub.signature));
            fn += asm.codeSize();
        }

        pm.protectPages(code, (int) npages, PageManager.PROT_READ | PageManager.PROT_EXEC);
        NativeMethods.register(clazz, methods);
        StaticDataHolder.PAGES.put(clazz, page);
        return page;
    }

    /** Prints a disassembly of the stub copied to address {@code fn} (debug output only). */
    private static void dumpStub(PrintStream dbg, Class<?> clazz, Stub stub, long fn, ByteBuffer buf) {
        Assembler asm = stub.assembler;
        dbg.println(clazz.getName() + "." + stub.name + " " + stub.signature);
        X86Disassembler disassembler = X86Disassembler.create();
        disassembler.setMode(Platform.getPlatform().getCPU() == Platform.CPU.I386
                ? X86Disassembler.Mode.I386 : X86Disassembler.Mode.X86_64);
        disassembler.setInputBuffer(fn, asm.offset());
        while (disassembler.disassemble()) {
            dbg.printf("%8x: %s\n", disassembler.offset(), disassembler.insn());
        }
        if (buf.remaining() > asm.offset()) {
            // libudis86 for some reason cannot understand the code asmjit emits for the trampolines
            dbg.printf("%8x: <indirect call trampolines>\n", asm.offset());
        }
        dbg.println();
    }

    /** Rounds {@code offset} up to a multiple of {@code align}, which must be a power of two. */
    private static long align(long offset, long align) {
        return (offset + align - 1) & ~(align - 1);
    }
}