package sharpen.xobotos.api.interop; import org.eclipse.jdt.core.dom.ITypeBinding; import sharpen.core.csharp.ast.*; import sharpen.xobotos.api.interop.NativeMethodBuilder.ElementInfo; import sharpen.xobotos.api.interop.glue.*; import java.util.List; public class ArrayHelperClass extends HelperClassBuilder { private final ElementInfo _elementInfo; private Member _lengthMember; private ElementMember _ptrMember; public ArrayHelperClass(ITypeBinding type, String name, NativeBuilder builder, ElementInfo elementInfo, boolean isShared) { super(type, name, builder, new CSArrayTypeReference(elementInfo.getManagedType(), 1), computeArrayType(elementInfo), isShared, isShared); this._elementInfo = elementInfo; } private static AbstractTypeReference computeArrayType(ElementInfo info) { if (!info.isClass()) return new TemplateTypeReference("NativeArray", info.getNativeType()); AbstractNativeTypeBuilder builder = info.getTypeBuilder(); if (builder instanceof HelperClassBuilder) { HelperClassBuilder helper = (HelperClassBuilder) builder; if (helper.isByRef()) return new TemplateTypeReference("NativePtrArray", helper.getRealNativeType()); } return new TemplateTypeReference("NativeArray", builder.getNativeType()); } @Override protected boolean isBlittable() { return HelperClassBuilder.isBlittable(_elementInfo.getType()); } @Override protected boolean isByRef() { return true; } @Override public List<String> getIncludes() { return null; } @Override public boolean resolve(IMarshalContext context) { return true; } @Override protected void buildWrap(Method wrap, NativeVariable src) { if (!isBlittable()) { super.buildWrap(wrap, src); return; } ConstructorInvocation ctorCall = new ConstructorInvocation(getRealNativeType()); ctorCall.addArgument(_lengthMember.getReference(src)); ctorCall.addArgument(_ptrMember.getReference(src)); wrap.getBody().addStatement(new ReturnStatement(ctorCall)); } @Override protected void buildMembers() { _lengthMember = addMember(new LengthMember()); _ptrMember = new PtrMember(); addMember(_ptrMember); } private CSTypeReferenceExpression getManagedElementType() { return _elementInfo.getManagedType(); } @Override protected CSExpression computeNativeSize(CSExpression expr) { CSExpression baseSize = NATIVE_SIZE.expr(); final CSExpression elementSize; if (_ptrMember.isByRef()) { CSExpression sizeof = new CSReferenceExpression("Marshal.SizeOf"); elementSize = new CSMethodInvocationExpression(sizeof, new CSTypeofExpression( new CSTypeReference("System.IntPtr"))); } else { elementSize = _ptrMember.getNativeElementSizeExpr(); } CSExpression length = new CSMemberReferenceExpression(expr, "Length"); return new CSInfixExpression("+", baseSize, new CSInfixExpression("*", elementSize, length)); } @Override protected CSExpression createInstance(ManagedVariable obj) { return new CSArrayCreationExpression(getManagedElementType(), _lengthMember.getReference(obj)); } private class LengthMember extends Member { public LengthMember() { super(new TypeReference("uint32_t"), new CSTypeReference("int"), "length", "length", Flags.FUNCTION, Flags.PASS_TO_CTOR); } @Override protected Statement unwrap(NativeVariable src, NativeVariable dest) { return new AssignmentStatement(getReference(dest), getReference(src)); } @Override protected void createPinnedPtr(CSStruct struct, CSMethod free) { ; } @Override protected CSStatement getPinnedPtr(ManagedVariable arg, ManagedVariable obj, ManagedVariable pinned) { final CSExpression length = new CSMemberReferenceExpression(arg.getReference(), "Length"); return new CSExpressionStatement(-1, new CSInfixExpression("=", getReference(obj), length)); } @Override protected CSStatement marshalIn(ManagedVariable arg, ManagedVariable obj, ManagedVariable ptr) { final CSExpression length = new CSMemberReferenceExpression(arg.getReference(), "Length"); return new CSExpressionStatement(-1, new CSInfixExpression("=", getReference(obj), length)); } @Override protected CSStatement marshalOut(ManagedVariable arg, ManagedVariable obj, ManagedVariable ptr) { return createManagedAssert(new CSInfixExpression("!=", getReference(obj), new CSMemberReferenceExpression(arg.getReference(), "Length"))); } @Override protected CSStatement nativeToManaged(ManagedVariable arg, ManagedVariable obj, ManagedVariable ptr) { return createManagedAssert(new CSInfixExpression("!=", getReference(obj), new CSMemberReferenceExpression(arg.getReference(), "Length"))); } @Override public CSStatement free(ManagedVariable obj) { return null; } @Override public Statement freeMembers(NativeVariable obj) { return null; } @Override protected Statement wrap(NativeVariable src, NativeVariable dest) { return createAssert(new BinaryOperator("==", getReference(src), getReference(dest))); } @Override protected Statement marshalOut(NativeVariable src, NativeVariable dest) { return createAssert(new BinaryOperator("==", getReference(src), getReference(dest))); } } private class PtrMember extends ElementMember { private final String _handleName; public PtrMember() { super(_elementInfo, "ptr", "ptr", Flags.FUNCTION, Flags.POINTER); this._handleName = "handle_array_ptr"; } @Override protected void createPinnedPtr(CSStruct struct, CSMethod free) { CSField handle = new CSField(_handleName, new CSTypeReference("GCHandle"), CSVisibility.Public); struct.addMember(handle); free.body().addStatement( new CSMethodInvocationExpression(new CSMemberReferenceExpression( new CSReferenceExpression(_handleName), "Free"))); } @Override protected Statement unwrap(NativeVariable src, NativeVariable dest) { if (isBlittable()) { return new AssignmentStatement(getReference(dest), (new MethodInvocation( src.getMemberAccess("takeOwnership")))); } Block block = new Block(); Expression alloc; if (TRACK_ALLOCATIONS) { alloc = new MethodInvocation(new TemplateFunctionReference("MarshalHelper::allocArray", getNativeType()), _lengthMember.getReference(src)); } else { alloc = new ArrayCreationExpression(getNativeType(), _lengthMember.getReference(src)); } block.addStatement(new AssignmentStatement(getReference(dest), alloc)); LocalVariable i = new LocalVariable(new TypeReference("uint32_t"), "i"); Expression init = new NumberLiteralExpression(0); Expression iRef = new VariableReference(i); Expression check = new BinaryOperator("<", iRef, _lengthMember.getReference(src)); Expression update = new PostfixIncrement(i); ForStatement forStm = new ForStatement(i, init, check, update); Expression targetIdx = new IndexedExpression(getReference(dest), iRef); Expression srcIdx = getIndex(src, iRef); if (!isClass()) { forStm.getBody().addStatement(new AssignmentStatement(targetIdx, srcIdx)); } else if (isByRef()) { forStm.getBody().addStatement( new AssignmentStatement(targetIdx, new MethodInvocation( getHelper().UNWRAP.expr(), srcIdx))); } else { Expression targetAddr = new AddressOfExpression(targetIdx); forStm.getBody().addStatement( new MethodInvocation(getHelper().DEEP_UNWRAP.expr(), srcIdx, targetAddr)); } block.addStatement(forStm); return block; } @Override protected CSStatement getPinnedPtr(ManagedVariable arg, ManagedVariable obj, ManagedVariable pinned) { CSExpression handleRef = new CSMemberReferenceExpression(pinned.getReference(), _handleName); CSExpression createHandle = new CSMethodInvocationExpression( new CSReferenceExpression("GCHandle.Alloc"), arg.getReference(), new CSReferenceExpression("GCHandleType.Pinned")); CSExpression getAddr = new CSMethodInvocationExpression(new CSMemberReferenceExpression( handleRef, "AddrOfPinnedObject")); CSBlock block = new CSBlock(); block.addStatement(new CSInfixExpression("=", handleRef, createHandle)); block.addStatement(new CSInfixExpression("=", getReference(obj), getAddr)); return block; } @Override protected CSStatement marshalIn(ManagedVariable arg, ManagedVariable obj, ManagedVariable ptr) { CSBlock block = new CSBlock(); final CSExpression lengthRef = new CSMemberReferenceExpression(arg.getReference(), "Length"); final CSExpression zero = new CSNumberLiteralExpression("0"); final CSExpression baseSize = NATIVE_SIZE.expr(); final CSExpression elementSize = getNativeElementSizeExpr(); block.addStatement(new CSInfixExpression("=", getReference(obj), new CSInfixExpression("+", ptr.getReference(), baseSize))); if (isBlittable()) { CSExpression copyRef = new CSReferenceExpression("Marshal.Copy"); block.addStatement(new CSMethodInvocationExpression(copyRef, arg.getReference(), zero, getReference(obj), lengthRef)); return block; } ManagedVariable i = new ManagedVariable("i", new CSTypeReference("int")); i.getDeclaration().initializer(zero); CSExpression check = new CSInfixExpression("<", i.getReference(), lengthRef); CSForStatement forStm = new CSForStatement(-1, check); forStm.addInitializer(new CSDeclarationExpression(i.getDeclaration())); forStm.addUpdater(new CSPostfixExpression("++", i.getReference())); final CSExpression expr = new CSIndexedExpression(arg.getReference(), i.getReference()); final CSTypeReference intPtrType = new CSTypeReference("System.IntPtr"); if (isByRef()) { ManagedVariable vector = new ManagedVariable("vector", new CSArrayTypeReference( intPtrType, 1), new CSArrayCreationExpression(intPtrType, lengthRef)); block.addStatement(vector.getDeclarationStatement()); block.addStatement(forStm); CSExpression idx = new CSIndexedExpression(vector.getReference(), i.getReference()); forStm.body().addStatement( new CSInfixExpression("=", idx, new CSMethodInvocationExpression( getHelper().MANAGED_TO_NATIVE.expr(), expr))); CSExpression copyRef = new CSReferenceExpression("Marshal.Copy"); block.addStatement(new CSMethodInvocationExpression(copyRef, vector.getReference(), zero, getReference(obj), lengthRef)); } else { ManagedVariable addr = new ManagedVariable("addr", intPtrType, getReference(obj)); block.addStatement(addr.getDeclarationStatement()); block.addStatement(forStm); forStm.addUpdater(new CSInfixExpression("+=", addr.getReference(), elementSize)); forStm.body().addStatement(marshalIn(addr.getReference(), expr)); } return block; } @Override protected CSStatement marshalOut(ManagedVariable arg, ManagedVariable obj, ManagedVariable ptr) { final CSExpression lengthRef = _lengthMember.getReference(obj); final CSExpression zero = new CSNumberLiteralExpression("0"); final CSExpression elementSize = getNativeElementSizeExpr(); if (isBlittable()) { return createManagedAssert(null); } else { CSBlock block = new CSBlock(); ManagedVariable addr = new ManagedVariable("addr", new CSTypeReference("System.IntPtr")); addr.getDeclaration().initializer(getReference(obj)); block.addStatement(new CSDeclarationStatement(-1, addr.getDeclaration())); ManagedVariable i = new ManagedVariable("i", new CSTypeReference("int")); i.getDeclaration().initializer(zero); CSExpression check = new CSInfixExpression("<", i.getReference(), lengthRef); CSForStatement forStm = new CSForStatement(-1, check); forStm.addInitializer(new CSDeclarationExpression(i.getDeclaration())); forStm.addUpdater(new CSPostfixExpression("++", i.getReference())); forStm.addUpdater(new CSInfixExpression("+=", addr.getReference(), elementSize)); final CSExpression expr = new CSIndexedExpression(arg.getReference(), i.getReference()); forStm.body().addStatement( marshalOut(addr.getReference(), expr, _elementInfo.getManagedType())); block.addStatement(forStm); return block; } } @Override protected CSStatement nativeToManaged(ManagedVariable arg, ManagedVariable obj, ManagedVariable ptr) { final CSExpression lengthRef = _lengthMember.getReference(obj); final CSExpression zero = new CSNumberLiteralExpression("0"); final CSExpression elementSize = getNativeElementSizeExpr(); if (isBlittable()) { CSExpression copyRef = new CSReferenceExpression("Marshal.Copy"); return new CSExpressionStatement(-1, new CSMethodInvocationExpression(copyRef, getReference(obj), arg.getReference(), zero, lengthRef)); } CSBlock block = new CSBlock(); ManagedVariable i = new ManagedVariable("i", new CSTypeReference("int"), zero); final CSExpression expr = new CSIndexedExpression(arg.getReference(), i.getReference()); CSExpression check = new CSInfixExpression("<", i.getReference(), lengthRef); CSForStatement forStm = new CSForStatement(-1, check); forStm.addInitializer(new CSDeclarationExpression(i.getDeclaration())); forStm.addUpdater(new CSPostfixExpression("++", i.getReference())); final CSTypeReference intPtrType = new CSTypeReference("System.IntPtr"); if (isByRef()) { ManagedVariable vector = new ManagedVariable("vector", new CSArrayTypeReference( intPtrType, 1), new CSArrayCreationExpression(intPtrType, lengthRef)); block.addStatement(vector.getDeclarationStatement()); CSExpression copyRef = new CSReferenceExpression("Marshal.Copy"); block.addStatement(new CSMethodInvocationExpression(copyRef, getReference(obj), vector .getReference(), zero, lengthRef)); CSExpression idx = new CSIndexedExpression(vector.getReference(), i.getReference()); forStm.body().addStatement(nativeToManaged(idx, expr, _elementInfo.getManagedType())); } else { ManagedVariable addr = new ManagedVariable("addr", new CSTypeReference("System.IntPtr")); addr.getDeclaration().initializer(getReference(obj)); block.addStatement(new CSDeclarationStatement(-1, addr.getDeclaration())); forStm.addUpdater(new CSInfixExpression("+=", addr.getReference(), elementSize)); forStm.body().addStatement( nativeToManaged(addr.getReference(), expr, _elementInfo.getManagedType())); } block.addStatement(forStm); return block; } @Override public CSStatement free(ManagedVariable obj) { if (isBlittable() || !isClass()) return null; final CSExpression lengthRef = _lengthMember.getReference(obj); final CSExpression zero = new CSNumberLiteralExpression("0"); final CSExpression elementSize = getNativeElementSizeExpr(); CSBlock block = new CSBlock(); ManagedVariable i = new ManagedVariable("i", new CSTypeReference("int"), zero); CSExpression check = new CSInfixExpression("<", i.getReference(), lengthRef); CSForStatement forStm = new CSForStatement(-1, check); forStm.addInitializer(new CSDeclarationExpression(i.getDeclaration())); forStm.addUpdater(new CSPostfixExpression("++", i.getReference())); final CSTypeReference intPtrType = new CSTypeReference("System.IntPtr"); if (isByRef()) { ManagedVariable vector = new ManagedVariable("vector", new CSArrayTypeReference( intPtrType, 1), new CSArrayCreationExpression(intPtrType, lengthRef)); block.addStatement(vector.getDeclarationStatement()); CSExpression copyRef = new CSReferenceExpression("Marshal.Copy"); block.addStatement(new CSMethodInvocationExpression(copyRef, getReference(obj), vector .getReference(), zero, lengthRef)); CSExpression idx = new CSIndexedExpression(vector.getReference(), i.getReference()); forStm.body().addStatement( new CSMethodInvocationExpression(getHelper().FREE_MANAGED_PTR.expr(), idx)); } else { ManagedVariable addr = new ManagedVariable("addr", new CSTypeReference("System.IntPtr"), getReference(obj)); block.addStatement(addr.getDeclarationStatement()); forStm.addUpdater(new CSInfixExpression("+=", addr.getReference(), elementSize)); forStm.body().addStatement( new CSMethodInvocationExpression(getHelper().DEEP_FREE_MANAGED_PTR .expr(), addr.getReference())); } block.addStatement(forStm); return block; } @Override public Statement freeMembers(NativeVariable obj) { final Statement freeArray; if (TRACK_ALLOCATIONS) { freeArray = new ExpressionStatement(new MethodInvocation(new TemplateFunctionReference( "MarshalHelper::freeArray", getNativeType()), getReference(obj), _lengthMember.getReference(obj))); } else { freeArray = new ArrayDestructorInvocation(getNativeType()); } if (isBlittable() || !isClass()) return freeArray; Block block = new Block(); LocalVariable i = new LocalVariable(new TypeReference("uint32_t"), "i"); Expression init = new NumberLiteralExpression(0); Expression iRef = new VariableReference(i); Expression check = new BinaryOperator("<", iRef, _lengthMember.getReference(obj)); Expression update = new PostfixIncrement(i); ForStatement forStm = new ForStatement(i, init, check, update); Expression srcIdx = new IndexedExpression(getReference(obj), iRef); if (isByRef()) forStm.getBody().addStatement( new MethodInvocation(getHelper().DESTRUCTOR.expr(), srcIdx)); else forStm.getBody().addStatement( new MethodInvocation(getHelper().FREE_MEMBERS.expr(), srcIdx)); block.addStatement(forStm); block.addStatement(freeArray); return block; } @Override protected Statement wrap(NativeVariable src, NativeVariable dest) { LocalVariable i = new LocalVariable(new TypeReference("uint32_t"), "i"); Expression init = new NumberLiteralExpression(0); Expression iRef = new VariableReference(i); Expression check = new BinaryOperator("<", iRef, _lengthMember.getReference(src)); Expression update = new PostfixIncrement(i); ForStatement forStm = new ForStatement(i, init, check, update); Expression srcIdx = new IndexedExpression(getReference(src), iRef); Expression targetIdx = getIndex(dest, iRef); if (!isClass()) { forStm.getBody().addStatement(new AssignmentStatement(targetIdx, srcIdx)); } else if (isByRef()) { forStm.getBody().addStatement( new AssignmentStatement(targetIdx, new MethodInvocation( getHelper().WRAP.expr(), srcIdx))); } else { Expression destAddr = new AddressOfExpression(targetIdx); forStm.getBody().addStatement( new MethodInvocation(getHelper().WRAP.expr(), srcIdx, destAddr)); } return forStm; } @Override protected Statement marshalOut(NativeVariable src, NativeVariable dest) { if (isByRef() || isBlittable()) return createAssert(null); LocalVariable i = new LocalVariable(new TypeReference("uint32_t"), "i"); Expression init = new NumberLiteralExpression(0); Expression iRef = new VariableReference(i); Expression check = new BinaryOperator("<", iRef, _lengthMember.getReference(src)); Expression update = new PostfixIncrement(i); ForStatement forStm = new ForStatement(i, init, check, update); Expression targetIdx = new IndexedExpression(getReference(dest), iRef); Expression srcIdx = getIndex(src, iRef); Expression srcExpr = isByRef() ? new DereferenceExpression(srcIdx) : srcIdx; if (isClass()) { Expression targetAddr = new AddressOfExpression(targetIdx); forStm.getBody().addStatement( new MethodInvocation(getHelper().MARSHAL_OUT.expr(), srcExpr, targetAddr)); } else { forStm.getBody().addStatement(new AssignmentStatement(targetIdx, srcIdx)); } return forStm; } } }