/* * Copyright 2012, 2015 Takao Nakaguchi * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.trie4j.doublearray; import java.io.Externalizable; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInput; import java.io.ObjectInputStream; import java.io.ObjectOutput; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.io.Writer; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.Deque; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; import org.trie4j.AbstractTermIdTrie; import org.trie4j.Node; import org.trie4j.TermIdNode; import org.trie4j.TermIdTrie; import org.trie4j.Trie; import org.trie4j.bv.BytesRank1OnlySuccinctBitVector; import org.trie4j.bv.SuccinctBitVector; import org.trie4j.util.BitSet; import org.trie4j.util.FastBitSet; import org.trie4j.util.Pair; import sun.misc.Unsafe; @Deprecated @SuppressWarnings("restriction") public class UnsafeDoubleArray extends AbstractTermIdTrie implements Externalizable, TermIdTrie{ public static interface TermNodeListener{ void listen(Node node, int nodeIndex); } public UnsafeDoubleArray() { } public UnsafeDoubleArray(Trie trie){ this(trie, trie.size() * 2); } public UnsafeDoubleArray(Trie trie, int arraySize){ this(trie, arraySize, new TermNodeListener(){ @Override public void listen(Node node, int nodeIndex) { } }); } public UnsafeDoubleArray(Trie trie, int arraySize, TermNodeListener listener){ if(arraySize <= 1) arraySize = 2; size = trie.size(); nodeSize = trie.nodeSize(); base = new int[arraySize]; Arrays.fill(base, BASE_EMPTY); check = new int[arraySize]; Arrays.fill(check, -1 * Unsafe.ARRAY_INT_INDEX_SCALE); FastBitSet bs = new FastBitSet(arraySize); build(trie.getRoot(), arrayIndexToOffset(0), bs, listener); term = new BytesRank1OnlySuccinctBitVector(bs.getBytes(), bs.size()); base = Arrays.copyOf(base, offsetToIntArrayIndex(last) + chars.size()); check = Arrays.copyOf(check, offsetToIntArrayIndex(last) + chars.size()); } @Override public int nodeSize() { return nodeSize; } @Override public int size() { return size; } @Override public TermIdNode getRoot() { return newDoubleArrayNode(0); } public int[] getBase(){ return base; } public int[] getCheck(){ return check; } public BitSet getTerm() { return term; } protected class DoubleArrayNode implements TermIdNode{ public DoubleArrayNode(int nodeId){ this.nodeId = nodeId; } public DoubleArrayNode(int nodeId, char firstChar){ this.nodeId = nodeId; this.firstChar = firstChar; } @Override public boolean isTerminate() { int nid = nodeId; while(true){ CharSequence children = listupChildChars(nid); int n = children.length(); if(n == 0) return term.get(nid); int b = offsetToIntArrayIndex(base[nid]); char firstChar = children.charAt(0); if(n > 1){ return term.get(nid); } else{ int firstNid = b + charToScaledCode[firstChar] / Unsafe.ARRAY_INT_INDEX_SCALE; if(term.get(firstNid)) return true; nid = firstNid; } } } @Override public char[] getLetters() { StringBuilder ret = new StringBuilder(); if(firstChar != 0) ret.append(firstChar); return ret.toString().toCharArray(); } @Override public DoubleArrayNode[] getChildren() { int nid = nodeId; while(true){ CharSequence children = listupChildChars(nid); int n = children.length(); if(n == 0) return emptyNodes; int b = offsetToIntArrayIndex(base[nid]); if(n > 1 || term.get(nid)){ return listupChildNodes(b, children); } nid = b + charToScaledCode[children.charAt(0)] / Unsafe.ARRAY_INT_INDEX_SCALE; } } @Override public DoubleArrayNode getChild(char c) { int code = charToScaledCode[c]; if(code == -1) return null; int nid = offsetToIntArrayIndex(base[nodeId] + code); if(nid >= 0 && nid < check.length && offsetToIntArrayIndex(check[nid]) == nodeId) return new DoubleArrayNode(nid, c); return null; } public int getNodeId() { return nodeId; } @Override public int getTermId(){ if(!term.get(nodeId)){ return -1; } return term.rank1(nodeId) - 1; } private CharSequence listupChildChars(int nodeId){ StringBuilder b = new StringBuilder(); long bs = base[nodeId]; for(char c : chars){ long nodeOffset = bs + charToScaledCode[c]; if(nodeOffset >= 0 && nodeOffset < arrayIndexToOffset(check.length) && unsafe.getInt(check, nodeOffset) == nodeId){ b.append(c); } } return b; } private DoubleArrayNode[] listupChildNodes(int baseIndex, CharSequence chars){ int n = chars.length(); DoubleArrayNode[] ret = new DoubleArrayNode[n]; for(int i = 0; i < n; i++){ char c = chars.charAt(i); int code = charToScaledCode[c] / Unsafe.ARRAY_INT_INDEX_SCALE; ret[i] = newDoubleArrayNode(baseIndex + code, c); } return ret; } private char firstChar = 0; private int nodeId; } @Override //* public boolean contains(String text){ long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; // root int n = text.length(); for(int i = 0; i < n; i++){ int cid = charToScaledCode[text.charAt(i)]; // int cid = unsafe.getInt(charToScaledCode, (long)Unsafe.ARRAY_INT_BASE_OFFSET + (text.charAt(i) << 2)); if(cid == 0) return false; long next = unsafe.getInt(base, nodeOffset) + cid; if(next < 0 || unsafe.getInt(check, next) != nodeOffset) return false; nodeOffset = next; } return term.get((int)((nodeOffset - Unsafe.ARRAY_INT_BASE_OFFSET) >> 2)); } /*/ public boolean contains(String text){ int nodeIndex = 0; // root int n = text.length(); for(int i = 0; i < n; i++){ char cid = unsafe.getChar(charToCode, (long)Unsafe.ARRAY_CHAR_BASE_OFFSET + (text.charAt(i) << 1)); if(cid == 0) return false; int next = unsafe.getInt(base, (long)Unsafe.ARRAY_INT_BASE_OFFSET + (nodeIndex << 2)) + cid; if(next < 0 || unsafe.getInt(check, (long)Unsafe.ARRAY_INT_BASE_OFFSET + (next << 2)) != nodeIndex) return false; nodeIndex = next; } return term.get(nodeIndex); } //*/ private static final sun.misc.Unsafe unsafe; static { try { Field field = sun.misc.Unsafe.class.getDeclaredField("theUnsafe"); field.setAccessible(true); unsafe = (sun.misc.Unsafe) field.get(null); } catch (Exception e) { throw new RuntimeException(e); } } public int getNodeId(String text) { long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; // root int n = text.length(); for(int i = 0; i < n; i++){ int code = charToScaledCode[text.charAt(i)]; if(code == 0) return -1; long next = unsafe.getInt(base, nodeOffset) + code; if(next < 0 || unsafe.getInt(check, next) != nodeOffset) return -1; nodeOffset = next; } return offsetToIntArrayIndex(nodeOffset); } @Override public int getTermId(String text) { int nid = getNodeId(text); return term.get(nid) ? term.rank1(nid) - 1 : -1; } @Override public Iterable<String> commonPrefixSearch(String query) { List<String> ret = new ArrayList<String>(); char[] chars = query.toCharArray(); int charsLen = chars.length; long checkEnd = arrayIndexToOffset(check.length); long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; for(int i = 0; i < charsLen; i++){ int cid = findCharScaledCode(chars[i]); if(cid == -1) return ret; int b = unsafe.getInt(base, nodeOffset); if(b == BASE_EMPTY) return ret; long next = b + cid; if(next >= checkEnd || unsafe.getInt(check, next) != nodeOffset) return ret; nodeOffset = next; if(term.get(offsetToIntArrayIndex(nodeOffset))) ret.add(new String(chars, 0, i + 1)); } return ret; } @Override public Iterable<Pair<String, Integer>> commonPrefixSearchWithTermId( String query) { List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>(); char[] chars = query.toCharArray(); int charsLen = chars.length; long checkEnd = arrayIndexToOffset(check.length); long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; for(int i = 0; i < charsLen; i++){ int cid = findCharScaledCode(chars[i]); if(cid == -1) return ret; int b = unsafe.getInt(base, nodeOffset); if(b == BASE_EMPTY) return ret; long next = b + cid; if(next >= checkEnd || unsafe.getInt(check, next) != nodeOffset) return ret; nodeOffset = next; int nodeIndex = offsetToIntArrayIndex(nodeOffset); if(term.get(nodeIndex)){ ret.add(Pair.create( new String(chars, 0, i + 1), term.rank1(nodeIndex) - 1 )); } } return ret; } @Override public int findWord(CharSequence chars, int start, int end, StringBuilder word) { for(int i = start; i < end; i++){ long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; try{ for(int j = i; j < end; j++){ int cid = findCharScaledCode(chars.charAt(j)); if(cid == -1) break; int b = unsafe.getInt(base, nodeOffset); if(b == BASE_EMPTY) break; long next = b + cid; if(nodeOffset != unsafe.getInt(check, next)) break; nodeOffset = next; if(term.get(offsetToIntArrayIndex(nodeOffset))){ if(word != null) word.append(chars, i, j + 1); return i; } } } catch(ArrayIndexOutOfBoundsException e){ break; } } return -1; } @Override public Iterable<String> predictiveSearch(String prefix) { List<String> ret = new ArrayList<String>(); char[] chars = prefix.toCharArray(); int charsLen = chars.length; long checkEnd = arrayIndexToOffset(check.length); long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; for(int i = 0; i < charsLen; i++){ int cid = findCharScaledCode(chars[i]); if(cid == -1) return ret; long next = unsafe.getInt(base, nodeOffset) + cid; if(next < 0 || next >= checkEnd || unsafe.getInt(check, next) != nodeOffset) return ret; nodeOffset = next; } int nodeIndex = offsetToIntArrayIndex(nodeOffset); if(term.get(nodeIndex)){ ret.add(prefix); } Deque<Pair<Long, String>> q = new LinkedList<Pair<Long, String>>(); q.add(Pair.create(nodeOffset, prefix)); while(!q.isEmpty()){ Pair<Long, String> p = q.pop(); long no = p.getFirst(); long b = unsafe.getInt(base, no); if(b == BASE_EMPTY) continue; String c = p.getSecond(); for(char v : this.chars){ long next = b + charToScaledCode[v]; if(next < 0 || next >= checkEnd) continue; if(unsafe.getInt(check, next) == no){ String n = new StringBuilder(c).append(v).toString(); if(term.get(offsetToIntArrayIndex(next))){ ret.add(n); } q.push(Pair.create(next, n)); } } } return ret; } @Override public Iterable<Pair<String, Integer>> predictiveSearchWithTermId( String prefix) { List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>(); char[] chars = prefix.toCharArray(); int charsLen = chars.length; if(charsLen == 0) return ret; if(this.nodeSize == 0) return ret; long checkEnd = arrayIndexToOffset(check.length); long nodeOffset = Unsafe.ARRAY_INT_BASE_OFFSET; for(int i = 0; i < charsLen; i++){ int cid = findCharScaledCode(chars[i]); if(cid == -1) return ret; long next = unsafe.getInt(base, nodeOffset) + cid; if(next < 0 || next >= checkEnd || unsafe.getInt(check, next) != nodeOffset) return ret; nodeOffset = next; } int nodeIndex = offsetToIntArrayIndex(nodeOffset); if(term.get(nodeIndex)){ ret.add(Pair.create(prefix, term.rank1(nodeIndex) - 1)); } Deque<Pair<Long, String>> q = new LinkedList<Pair<Long, String>>(); q.add(Pair.create(nodeOffset, prefix)); while(!q.isEmpty()){ Pair<Long, String> p = q.pop(); long no = p.getFirst(); int b = unsafe.getInt(base, no); if(b == BASE_EMPTY) continue; String c = p.getSecond(); for(char v : this.chars){ long next = b + charToScaledCode[v]; if(next < 0 || next >= checkEnd) continue; if(unsafe.getInt(check, next) == no){ String n = new StringBuilder(c).append(v).toString(); int nextIndex = offsetToIntArrayIndex(next); if(term.get(nextIndex)){ ret.add(Pair.create( n, term.rank1(nextIndex) - 1 )); } q.push(Pair.create(next, n)); } } } return ret; } @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(size); out.writeInt(nodeSize); out.writeInt(base.length); for(int v : base){ out.writeInt(v); } for(int v : check){ out.writeInt(v); } out.writeObject(term); out.writeInt((int)firstEmptyCheckOffset); out.writeInt(chars.size()); for(char c : chars){ out.writeChar(c); out.writeInt(charToScaledCode[c]); } } public void save(OutputStream os) throws IOException{ ObjectOutputStream out = new ObjectOutputStream(os); try{ writeExternal(out); } finally{ out.flush(); } } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { size = in.readInt(); nodeSize = in.readInt(); int len = in.readInt(); base = new int[len]; for(int i = 0; i < len; i++){ base[i] = in.readInt(); } check = new int[len]; for(int i = 0; i < len; i++){ check[i] = in.readInt(); } try{ term = (SuccinctBitVector)in.readObject(); } catch(ClassNotFoundException e){ throw new IOException(e); } firstEmptyCheckOffset = in.readInt(); int n = in.readInt(); for(int i = 0; i < n; i++){ char c = in.readChar(); int v = in.readInt(); chars.add(c); charToScaledCode[c] = v; } } public void load(InputStream is) throws IOException{ try{ readExternal(new ObjectInputStream(is)); } catch(ClassNotFoundException e){ throw new IOException(e); } } @Override public void trimToSize(){ int sz = (int)last + 1 + 0xFFFF; base = Arrays.copyOf(base, sz); check = Arrays.copyOf(check, sz); } @Override public void dump(Writer w){ PrintWriter writer = new PrintWriter(w); try{ writer.println("array size: " + base.length); writer.print(" |"); for(int i = 0; i < 16; i++){ writer.print(String.format("%3d|", i)); } writer.println(); writer.print("|base |"); for(int i = 0; i < 16; i++){ if(base[i] == BASE_EMPTY){ writer.print("N/A|"); } else{ writer.print(String.format("%3d|", base[i])); } } writer.println(); writer.print("|check|"); for(int i = 0; i < 16; i++){ if(check[i] < 0){ writer.print("N/A|"); } else{ writer.print(String.format("%3d|", check[i])); } } writer.println(); writer.print("|term |"); for(int i = 0; i < 16; i++){ writer.print(String.format("%3d|", term.get(i) ? 1 : 0)); } writer.println(); writer.print("chars: "); int c = 0; for(char e : chars){ writer.print(String.format("%c:%d,", e, (int)charToScaledCode[e])); c++; if(c > 16) break; } writer.println(); writer.println("chars count: " + chars.size()); writer.println(); } finally{ writer.flush(); } } private void build(Node node, long nodeOffset, FastBitSet bs, TermNodeListener listener){ // letters char[] letters = node.getLetters(); int lettersLen = letters.length; for(int i = 1; i < lettersLen; i++){ bs.unsetIfLE(offsetToIntArrayIndex(nodeOffset)); int cid = getCharScaledCode(letters[i]); long emptyOffset = findFirstEmptyCheckOffset(); setCheck(emptyOffset, nodeOffset); unsafe.putInt(base, nodeOffset, (int)(emptyOffset - cid)); nodeOffset = emptyOffset; } int nodeIndex = offsetToIntArrayIndex(nodeOffset); if(node.isTerminate()){ bs.set(nodeIndex); listener.listen(node, nodeIndex); } else{ bs.unsetIfLE(nodeIndex); } // children Node[] children = node.getChildren(); int childrenLen = children.length; if(childrenLen == 0) return; int[] heads = new int[childrenLen]; int maxHead = 0; int minHead = Integer.MAX_VALUE; for(int i = 0; i < childrenLen; i++){ heads[i] = getCharScaledCode(children[i].getLetters()[0]); maxHead = Math.max(maxHead, heads[i]); minHead = Math.min(minHead, heads[i]); } long offset = findInsertOffset(heads, minHead, maxHead); unsafe.putInt(base, nodeOffset, (int)offset); for(int cid : heads){ setCheck(offset + cid, nodeOffset); } /* for(int i = 0; i < children.length; i++){ build(children[i], offset + heads[i]); } /*/ // sort children by children's children count. Map<Integer, List<Pair<Node, Integer>>> nodes = new TreeMap<Integer, List<Pair<Node, Integer>>>(new Comparator<Integer>() { @Override public int compare(Integer arg0, Integer arg1) { return arg1 - arg0; } }); for(int i = 0; i < children.length; i++){ Node[] c = children[i].getChildren(); int n = 0; if(c != null){ n = c.length; } List<Pair<Node, Integer>> p = nodes.get(n); if(p == null){ p = new ArrayList<Pair<Node, Integer>>(); nodes.put(n, p); } p.add(Pair.create(children[i], heads[i])); } for(Map.Entry<Integer, List<Pair<Node, Integer>>> e : nodes.entrySet()){ for(Pair<Node, Integer> e2 : e.getValue()){ build(e2.getFirst(), e2.getSecond() + offset, bs, listener); } } //*/ } private DoubleArrayNode newDoubleArrayNode(int id){ return new DoubleArrayNode(id); } private DoubleArrayNode newDoubleArrayNode(int id, char s){ return new DoubleArrayNode(id, s); } private int findCharScaledCode(char c){ int v = charToScaledCode[c]; if(v != 0) return v; return -1; } private int offsetToIntArrayIndex(long offset){ return (int)((offset - Unsafe.ARRAY_INT_BASE_OFFSET) / Unsafe.ARRAY_INT_INDEX_SCALE); } private long arrayIndexToOffset(int index){ return Unsafe.ARRAY_INT_BASE_OFFSET + index * Unsafe.ARRAY_INT_INDEX_SCALE; } private long findInsertOffset(int[] headOffsets, int minHeadOffset, int maxHeadOffset){ for(long empty = findFirstEmptyCheckOffset(); ; empty = findNextEmptyCheckOffset(empty)){ long offset = empty - minHeadOffset; if((offset + maxHeadOffset) >= arrayIndexToOffset(check.length)){ extend(offsetToIntArrayIndex(offset + maxHeadOffset)); } // find space boolean found = true; for(int ho : headOffsets){ if(unsafe.getInt(check, offset + ho) >= 0){ found = false; break; } } if(found) return offset; } } private int getCharScaledCode(char c){ int v = charToScaledCode[c]; if(v != 0) return v; chars.add(c); v = (chars.size() + 1) * Unsafe.ARRAY_INT_INDEX_SCALE; charToScaledCode[c] = v; return v; } private void extend(int i){ int sz = base.length; int nsz = Math.max(i + 0xFFFF, (int)(sz * 1.5)); // System.out.println("extend to " + nsz); base = Arrays.copyOf(base, nsz); Arrays.fill(base, sz, nsz, BASE_EMPTY); check = Arrays.copyOf(check, nsz); Arrays.fill(check, sz, nsz, -1 * Unsafe.ARRAY_INT_INDEX_SCALE); } private long findFirstEmptyCheckOffset(){ long i = firstEmptyCheckOffset; while(unsafe.getInt(check, i) >= 0 || unsafe.getInt(base, i) != BASE_EMPTY){ i += Unsafe.ARRAY_INT_INDEX_SCALE; } firstEmptyCheckOffset = i; return i; } private long findNextEmptyCheckOffset(long offset){ /* for(i++; i < check.length; i++){ if(check[i] < 0) return i; } extend(i); return i; /*/ int d = unsafe.getInt(check, offset) * -1; if(d <= 0){ throw new RuntimeException(); } long prev = offset; offset += d; long endOffset = arrayIndexToOffset(check.length); if(endOffset <= offset){ extend(offsetToIntArrayIndex(offset)); return offset; } if(unsafe.getInt(check, offset) < 0){ return offset; } for(offset += Unsafe.ARRAY_INT_INDEX_SCALE; offset < endOffset; offset += Unsafe.ARRAY_INT_INDEX_SCALE){ if(unsafe.getInt(check, offset) < 0){ unsafe.putInt(check, prev, (int)(prev - offset)); return offset; } } extend(offsetToIntArrayIndex(offset)); unsafe.putInt(check, prev, (int)(prev - offset)); return offset; //*/ } private void setCheck(long offset, long nodeOffset){ if(firstEmptyCheckOffset == offset){ firstEmptyCheckOffset = findNextEmptyCheckOffset(firstEmptyCheckOffset); } unsafe.putInt(check, offset, (int)nodeOffset); last = Math.max(last, offset); } private int size; private int nodeSize; private int[] base; private int[] check; private long firstEmptyCheckOffset = Unsafe.ARRAY_INT_BASE_OFFSET + Unsafe.ARRAY_INT_INDEX_SCALE; private long last = Unsafe.ARRAY_INT_BASE_OFFSET; private SuccinctBitVector term; private Set<Character> chars = new TreeSet<Character>(); private int[] charToScaledCode = new int[Character.MAX_VALUE]; private static final int BASE_EMPTY = Integer.MAX_VALUE; private static final DoubleArrayNode[] emptyNodes = {}; }