package com.jinoh.ruby.marshal; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Array; import java.lang.reflect.Field; import java.math.BigInteger; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; public class Unmarshaler { protected String mEncoding = "UTF-8"; protected InputStream mStream; protected ArrayList<Object> mCache = new ArrayList<Object>(); protected ArrayList<Object> mSymCache = new ArrayList<Object>(); public List<Object> getCache () { return mCache; } public List<Object> getSymCache () { return mSymCache; } public Unmarshaler(InputStream is) { mStream = is; } public InputStream getInputStream () { return mStream; } public void setEncoding(String encoding) { mEncoding = encoding; } public String getEncoding() { return mEncoding; } public Object unmarshalAuto() throws IOException { return unmarshalAuto(mStream.read(), null); } public Object unmarshalAuto(int c) throws IOException { return unmarshalAuto(c, null); } public Object unmarshalAuto(Class<?> type) throws IOException { return unmarshalAuto(mStream.read(), type); } @SuppressWarnings("unchecked") public Object unmarshalAuto(int c, Class<?> type) throws IOException { switch (c) { case '0': // nil return null; case 'T': // true return true; case 'F': // false return false; case 'i': return readInt(); case 'l': return readBigInteger(); case '"': return readString(); case ':': return readSymbol(); case '[': if (type == null) return readArray(); return readArray(type.getComponentType()); case '{': try { return readHash(type.asSubclass(Map.class)); } catch (Exception e) { // Either NullPointerException or ClassCastException return readHash(); } case 'o': return readMarshalable(); case 'u': return readUserdef(); case 'f': return readRubyFloat(); case '@': int i = readInt(); if (i >= mCache.size()) throw new IOException ("Invalid cache index"); return mCache.get(i); case ';': i = readInt(); if (i >= mSymCache.size()) throw new IOException ("Invalid symbol cache index"); return mSymCache.get(i); case 'I': return readIvar(); } throw new RuntimeException("WTF?"); } public Symbol forceUnmarshalSymbol() throws IOException { int c = mStream.read(); if (c == ':' || c == ';') return (Symbol) unmarshalAuto(c); throw new IOException("Expected symbol, got " + ((c == -1) ? "EOF" : (char) c)); } public Symbol readSymbol() throws IOException { Symbol o = new Symbol(new String(readBytesAsString(), "UTF-8")); mSymCache.add(o); return o; } public Object[] readArray() throws IOException { return (Object[]) readArray(Object.class); } public Object readArray(Class<?> clazz) throws IOException { int size = readInt(); Object array = Array.newInstance(clazz, size); mCache.add(array); for (int i = 0; i < size; i++) { Array.set(array, i, unmarshalAuto()); } return array; } public Map<?, ?> readHash() throws IOException { int size = readInt(); Map<Object, Object> hash = new LinkedHashMap<Object, Object>(); mCache.add(hash); for (int i = 0; i < size; i++) { Object key = unmarshalAuto(); Object value = unmarshalAuto(); hash.put(key, value); } return hash; } @SuppressWarnings("unchecked") public <K, V, T extends Map<K, V>> T readHash(Class<T> clazz) throws IOException { int size = readInt(); T hash; try { hash = clazz.getConstructor(int.class).newInstance(size); } catch (Exception e) { try { hash = clazz.getConstructor().newInstance(); } catch (Exception e1) { throw new RuntimeException(e1); } } mCache.add(hash); for (int i = 0; i < size; i++) { K key = (K) unmarshalAuto(); V value = (V) unmarshalAuto(); hash.put(key, value); } return hash; } public Marshallable readMarshalable () throws IOException { Symbol objName = forceUnmarshalSymbol(); Class<?> clazz = Marshal.sSymbolToClass.get(objName); if (clazz == null) throw new IOException ("Symbol " + objName + " not registered. Counts : " + Marshal.sSymbolToClass.size()); try { Marshallable inst = (Marshallable) clazz.newInstance(); mCache.add(inst); int fieldCount = readInt(); String name; Field f; Class<?> cl; for (int i = 0; i < fieldCount; i++) { name = forceUnmarshalSymbol().toString(); if (name.charAt(0) == '@') { name = name.substring(1); cl = clazz; Object val = null; boolean bRead = false; int oos = 0; do { try { f = cl.getDeclaredField(name); f.setAccessible(true); val = unmarshalAuto(f.getType()); bRead = true; f.set(inst, val); break; } catch (NoSuchFieldException e) { cl = cl.getSuperclass(); } catch (IllegalAccessException e) { break; } if (oos++ > 40) throw new RuntimeException("40+ deep / " + clazz + " > " + cl); } while (cl != null); if (!bRead) val = unmarshalAuto(); } } return inst; } catch (RuntimeException e1) { throw e1; } catch (IOException e1) { throw e1; } catch (InstantiationException e1) { throw new RuntimeException(e1); } catch (IllegalAccessException e1) { throw new RuntimeException(e1); } } public CustomMarshallable readUserdef() throws IOException { Symbol objName = forceUnmarshalSymbol(); Class<?> clazz = Marshal.sSymbolToClass.get(objName); if (clazz == null) throw new IOException ("Symbol " + objName + " not registered."); try { CustomMarshallable inst = (CustomMarshallable) clazz.newInstance(); mCache.add(inst); inst.load(this, readBytesAsString()); return inst; } catch (RuntimeException e1) { throw e1; } catch (IOException e1) { throw e1; } catch (InstantiationException e) { throw new RuntimeException(e); } catch (IllegalAccessException e) { throw new RuntimeException(e); } } public byte readByte() throws IOException { int c = mStream.read(); if (c == -1) throw new IOException("End of stream"); return (byte) c; } public int readUnsignedByte() throws IOException { return ((int) readByte()) & 0xFF; } public String readString(String encoding) throws IOException { String s = new String(readBytesAsString(), encoding); mCache.add(s); return s; } public byte[] readBytesAsString() throws IOException { byte[] buf = new byte[readInt()]; if (mStream.read(buf) < buf.length) throw new IOException("End of stream"); return buf; } public String readString() throws IOException { return readString(mEncoding); } public Object readIvar() throws IOException { int c = mStream.read(); Map<Symbol, Object> map = new HashMap<Symbol, Object>(); RubyIVar ivar = new RubyIVar(null, map); mCache.add(ivar); int count; Symbol name; Object o; String str; if (c == '"') { // it is a raw string, so encode is necessary byte[] data = readBytesAsString(); int idx; synchronized (mCache) { // space for string idx = mCache.size(); mCache.add(null); } count = readInt(); String encoding = null; while (count-- > 0) { str = (name = forceUnmarshalSymbol()).toString(); map.put(name, o = unmarshalAuto()); if (str.equals("E")) { if (o instanceof Boolean) { encoding = ((Boolean) o).booleanValue() ? "UTF-8" : "US-ASCII"; } } else if (str.equals("encoding")) { if (o instanceof String) { encoding = o.toString(); } } } if (encoding != null) { str = new String(data, encoding); } else str = new String(data, mEncoding); ivar.setValue(str); mCache.set(idx, str); return ivar; } ivar.setValue(unmarshalAuto()); count = readInt(); while (count-- > 0) { str = (name = forceUnmarshalSymbol()).toString(); map.put(name, o = unmarshalAuto()); } return ivar; } public int readInt() throws IOException { int c = readUnsignedByte(); if (c == 0) { return 0; } else if (5 < c && c < 128) { return c - 5; } else if (-129 < c && c < -5) { return c + 5; } int result; if (c > 0) { c <<= 3; result = 0; for (int i = 0; i < c; i += 8) result |= readUnsignedByte() << i; } else { c = (byte) ((-c) << 3); result = -1; for (int i = 0; i < c; i += 8) result = (result & ~(0xff << i)) | (readUnsignedByte() << i); } return result; } public Object readBigInteger() throws IOException { // what a convoluted way to serialize a big integer (gotta love ruby) boolean positive = readByte() == '+'; int shortLength = readInt(), i; Object o; if (shortLength > 8) { // BigInteger required a sign byte in incoming array byte[] digits = new byte[(shortLength << 1) + 1]; digits[0] = positive ? 0 : ((byte) -1); for (i = digits.length - 1; i > 0; i--) { digits[i] = readByte(); } o = new BigInteger(digits); } else { long value = 0; shortLength <<= 4; for (i = 0; i < shortLength; i += 8) { value |= readByte() << i; } if (!positive) o = -value; o = value; } mCache.add(o); return o; } public double readRubyFloat () throws IOException { double val = Double.parseDouble(new String(readBytesAsString())); mCache.add(val); return val; } }