/* * Copyright 2012 Takao Nakaguchi * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.trie4j.doublearray; import java.io.Externalizable; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInput; import java.io.ObjectInputStream; import java.io.ObjectOutput; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.io.Writer; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.Deque; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; import org.trie4j.AbstractTermIdTrie; import org.trie4j.Node; import org.trie4j.TermIdNode; import org.trie4j.TermIdTrie; import org.trie4j.Trie; import org.trie4j.bv.BytesRank1OnlySuccinctBitVector; import org.trie4j.bv.SuccinctBitVector; import org.trie4j.util.BitSet; import org.trie4j.util.FastBitSet; import org.trie4j.util.Pair; public class DoubleArray extends AbstractTermIdTrie implements Externalizable, TermIdTrie{ public static interface TermNodeListener{ void listen(Node node, int nodeIndex); } public DoubleArray() { } public DoubleArray(Trie trie){ this(trie, trie.size() * 2); } public DoubleArray(Trie trie, int arraySize){ this(trie, arraySize, new TermNodeListener(){ @Override public void listen(Node node, int nodeIndex) { } }); } public DoubleArray(Trie trie, int arraySize, TermNodeListener listener){ if(arraySize <= 1) arraySize = 2; size = trie.size(); base = new int[arraySize]; Arrays.fill(base, BASE_EMPTY); check = new int[arraySize]; Arrays.fill(check, -1); FastBitSet bs = new FastBitSet(arraySize); nodeSize = 1; // for root node because it has no letter; build(trie.getRoot(), 0, bs, listener); term = new BytesRank1OnlySuccinctBitVector(bs.getBytes(), bs.size()); base = Arrays.copyOf(base, last + chars.size()); check = Arrays.copyOf(check, last + chars.size()); } @Override public int nodeSize() { return nodeSize; } @Override public int size() { return size; } @Override public TermIdNode getRoot() { return newDoubleArrayNode(0); } public int[] getBase(){ return base; } public int[] getCheck(){ return check; } public BitSet getTerm() { return term; } protected class DoubleArrayNode implements TermIdNode{ public DoubleArrayNode(int nodeId){ this.nodeId = nodeId; } public DoubleArrayNode(int nodeId, char firstChar){ this.nodeId = nodeId; this.firstChar = firstChar; } @Override public boolean isTerminate() { return term.get(nodeId); } @Override public char[] getLetters() { StringBuilder ret = new StringBuilder(); if(firstChar != 0) ret.append(firstChar); return ret.toString().toCharArray(); } @Override public DoubleArrayNode[] getChildren() { CharSequence children = listupChildChars(nodeId); if(children.length() == 0) return emptyNodes; return listupChildNodes(base[nodeId], children); } @Override public DoubleArrayNode getChild(char c) { int code = charToCode[c]; if(code == -1) return null; int nid = base[nodeId] + code; if(nid >= 0 && nid < check.length && check[nid] == nodeId) return new DoubleArrayNode(nid, c); return null; } public int getNodeId() { return nodeId; } @Override public int getTermId(){ if(!term.get(nodeId)){ return -1; } return term.rank1(nodeId) - 1; } private CharSequence listupChildChars(int nodeId){ StringBuilder b = new StringBuilder(); int bs = base[nodeId]; for(char c : chars){ int nid = bs + charToCode[c]; if(nid >= 0 && nid < check.length && check[nid] == nodeId){ b.append(c); } } return b; } private DoubleArrayNode[] listupChildNodes(int base, CharSequence chars){ int n = chars.length(); DoubleArrayNode[] ret = new DoubleArrayNode[n]; for(int i = 0; i < n; i++){ char c = chars.charAt(i); char code = charToCode[c]; ret[i] = newDoubleArrayNode(base + code, c); } return ret; } private char firstChar = 0; private int nodeId; } @Override public boolean contains(String text){ int nodeIndex = 0; // root int n = text.length(); for(int i = 0; i < n; i++){ char cid = charToCode[text.charAt(i)]; if(cid == 0) return false; int next = base[nodeIndex] + cid; if(next < 0 || check[next] != nodeIndex) return false; nodeIndex = next; } return term.get(nodeIndex); } public int getNodeId(String text) { int nodeIndex = 0; // root int n = text.length(); for(int i = 0; i < n; i++){ char cid = charToCode[text.charAt(i)]; if(cid == 0) return -1; int next = base[nodeIndex] + cid; if(next < 0 || check[next] != nodeIndex) return -1; nodeIndex = next; } return nodeIndex; } @Override public int getTermId(String text) { int nid = getNodeId(text); if(nid == -1) return -1; return term.get(nid) ? term.rank1(nid) - 1 : -1; } @Override public Iterable<String> commonPrefixSearch(String query) { List<String> ret = new ArrayList<String>(); char[] chars = query.toCharArray(); int charsLen = chars.length; int checkLen = check.length; int nodeIndex = 0; for(int i = 0; i < charsLen; i++){ int cid = findCharId(chars[i]); if(cid == -1) return ret; int b = base[nodeIndex]; if(b == BASE_EMPTY) return ret; int next = b + cid; if(next >= checkLen || check[next] != nodeIndex) return ret; nodeIndex = next; if(term.get(nodeIndex)) ret.add(new String(chars, 0, i + 1)); } return ret; } @Override public Iterable<Pair<String, Integer>> commonPrefixSearchWithTermId( String query) { List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>(); char[] chars = query.toCharArray(); int charsLen = chars.length; int checkLen = check.length; int nodeIndex = 0; for(int i = 0; i < charsLen; i++){ int cid = findCharId(chars[i]); if(cid == -1) return ret; int b = base[nodeIndex]; if(b == BASE_EMPTY) return ret; int next = b + cid; if(next >= checkLen || check[next] != nodeIndex) return ret; nodeIndex = next; if(term.get(nodeIndex)){ ret.add(Pair.create( new String(chars, 0, i + 1), term.rank1(nodeIndex) - 1 )); } } return ret; } @Override public int findWord(CharSequence chars, int start, int end, StringBuilder word) { for(int i = start; i < end; i++){ int nodeIndex = 0; try{ for(int j = i; j < end; j++){ int cid = findCharId(chars.charAt(j)); if(cid == -1) break; int b = base[nodeIndex]; if(b == BASE_EMPTY) break; int next = b + cid; if(nodeIndex != check[next]) break; nodeIndex = next; if(term.get(nodeIndex)){ if(word != null) word.append(chars, i, j + 1); return i; } } } catch(ArrayIndexOutOfBoundsException e){ break; } } return -1; } @Override public Iterable<String> predictiveSearch(String prefix) { List<String> ret = new ArrayList<String>(); char[] chars = prefix.toCharArray(); int charsLen = chars.length; int checkLen = check.length; int nodeIndex = 0; for(int i = 0; i < charsLen; i++){ int cid = findCharId(chars[i]); if(cid == -1) return ret; int next = base[nodeIndex] + cid; if(next < 0 || next >= checkLen || check[next] != nodeIndex) return ret; nodeIndex = next; } if(term.get(nodeIndex)){ ret.add(prefix); } Deque<Pair<Integer, String>> q = new LinkedList<Pair<Integer, String>>(); q.add(Pair.create(nodeIndex, prefix)); while(!q.isEmpty()){ Pair<Integer, String> p = q.pop(); int ni = p.getFirst(); int b = base[ni]; if(b == BASE_EMPTY) continue; String c = p.getSecond(); for(char v : this.chars){ int next = b + charToCode[v]; if(next < 0 || next >= checkLen) continue; if(check[next] == ni){ String n = new StringBuilder(c).append(v).toString(); if(term.get(next)){ ret.add(n); } q.push(Pair.create(next, n)); } } } return ret; } @Override public Iterable<Pair<String, Integer>> predictiveSearchWithTermId( String prefix) { List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>(); char[] chars = prefix.toCharArray(); int charsLen = chars.length; if(charsLen == 0) return ret; if(this.nodeSize == 0) return ret; int checkLen = check.length; int nodeIndex = 0; for(int i = 0; i < charsLen; i++){ int cid = findCharId(chars[i]); if(cid == -1) return ret; int next = base[nodeIndex] + cid; if(next < 0 || next >= checkLen || check[next] != nodeIndex) return ret; nodeIndex = next; } if(term.get(nodeIndex)){ ret.add(Pair.create(prefix, term.rank1(nodeIndex) - 1)); } Deque<Pair<Integer, String>> q = new LinkedList<Pair<Integer, String>>(); q.add(Pair.create(nodeIndex, prefix)); while(!q.isEmpty()){ Pair<Integer, String> p = q.pop(); int ni = p.getFirst(); int b = base[ni]; if(b == BASE_EMPTY) continue; String c = p.getSecond(); for(char v : this.chars){ int next = b + charToCode[v]; if(next < 0 || next >= checkLen) continue; if(check[next] == ni){ String n = new StringBuilder(c).append(v).toString(); if(term.get(next)){ ret.add(Pair.create( n, term.rank1(next) - 1 )); } q.push(Pair.create(next, n)); } } } return ret; } @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(size); out.writeInt(nodeSize); out.writeInt(base.length); for(int v : base){ out.writeInt(v); } for(int v : check){ out.writeInt(v); } out.writeObject(term); out.writeInt(firstEmptyCheck); out.writeInt(chars.size()); for(char c : chars){ out.writeChar(c); out.writeChar(charToCode[c]); } } public void save(OutputStream os) throws IOException{ ObjectOutputStream out = new ObjectOutputStream(os); try{ writeExternal(out); } finally{ out.flush(); } } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { size = in.readInt(); nodeSize = in.readInt(); int len = in.readInt(); base = new int[len]; for(int i = 0; i < len; i++){ base[i] = in.readInt(); } check = new int[len]; for(int i = 0; i < len; i++){ check[i] = in.readInt(); } try{ term = (SuccinctBitVector)in.readObject(); } catch(ClassNotFoundException e){ throw new IOException(e); } firstEmptyCheck = in.readInt(); int n = in.readInt(); for(int i = 0; i < n; i++){ char c = in.readChar(); char v = in.readChar(); chars.add(c); charToCode[c] = v; } } public void load(InputStream is) throws IOException{ try{ readExternal(new ObjectInputStream(is)); } catch(ClassNotFoundException e){ throw new IOException(e); } } @Override public void trimToSize(){ int sz = last + 1 + 0xFFFF; base = Arrays.copyOf(base, sz); check = Arrays.copyOf(check, sz); } @Override public void dump(Writer w){ PrintWriter writer = new PrintWriter(w); try{ int n = Math.min(16, base.length); writer.println("array size: " + base.length); writer.print(" |"); for(int i = 0; i < n; i++){ writer.print(String.format("%3d|", i)); } writer.println(); writer.print("|base |"); for(int i = 0; i < n; i++){ if(base[i] == BASE_EMPTY){ writer.print("N/A|"); } else{ writer.print(String.format("%3d|", base[i])); } } writer.println(); writer.print("|check|"); for(int i = 0; i < n; i++){ if(check[i] < 0){ writer.print("N/A|"); } else{ writer.print(String.format("%3d|", check[i])); } } writer.println(); writer.print("|term |"); for(int i = 0; i < n; i++){ writer.print(String.format("%3d|", term.get(i) ? 1 : 0)); } writer.println(); writer.print("chars: "); int c = 0; for(char e : chars){ writer.print(String.format("%c:%d,", e, (int)charToCode[e])); c++; if(c > 16) break; } writer.println(); writer.println("chars count: " + chars.size()); writer.println(); } finally{ writer.flush(); } } private void build(Node node, int nodeIndex, FastBitSet bs, TermNodeListener listener){ // letters char[] letters = node.getLetters(); int lettersLen = letters.length; if(lettersLen > 0) nodeSize++; // for first letter for(int i = 1; i < lettersLen; i++){ bs.unsetIfLE(nodeIndex); int cid = getCharId(letters[i]); int empty = findFirstEmptyCheck(); setCheck(empty, nodeIndex); base[nodeIndex] = empty - cid; nodeSize++; nodeIndex = empty; } if(node.isTerminate()){ bs.set(nodeIndex); listener.listen(node, nodeIndex); } else{ bs.unsetIfLE(nodeIndex); } // children Node[] children = node.getChildren(); int childrenLen = children.length; if(childrenLen == 0) return; int[] heads = new int[childrenLen]; int maxHead = 0; int minHead = Integer.MAX_VALUE; for(int i = 0; i < childrenLen; i++){ heads[i] = getCharId(children[i].getLetters()[0]); maxHead = Math.max(maxHead, heads[i]); minHead = Math.min(minHead, heads[i]); } int offset = findInsertOffset(heads, minHead, maxHead); base[nodeIndex] = offset; for(int cid : heads){ setCheck(offset + cid, nodeIndex); } /* for(int i = 0; i < children.length; i++){ build(children[i], offset + heads[i]); } /*/ // sort children by children's children count. Map<Integer, List<Pair<Node, Integer>>> nodes = new TreeMap<Integer, List<Pair<Node, Integer>>>(new Comparator<Integer>() { @Override public int compare(Integer arg0, Integer arg1) { return arg1 - arg0; } }); for(int i = 0; i < children.length; i++){ Node[] c = children[i].getChildren(); int n = 0; if(c != null){ n = c.length; } List<Pair<Node, Integer>> p = nodes.get(n); if(p == null){ p = new ArrayList<Pair<Node, Integer>>(); nodes.put(n, p); } p.add(Pair.create(children[i], heads[i])); } for(Map.Entry<Integer, List<Pair<Node, Integer>>> e : nodes.entrySet()){ for(Pair<Node, Integer> e2 : e.getValue()){ build(e2.getFirst(), e2.getSecond() + offset, bs, listener); } } //*/ } private DoubleArrayNode newDoubleArrayNode(int id){ return new DoubleArrayNode(id); } private DoubleArrayNode newDoubleArrayNode(int id, char s){ return new DoubleArrayNode(id, s); } private int findCharId(char c){ char v = charToCode[c]; if(v != 0) return v; return -1; } private int findInsertOffset(int[] heads, int minHead, int maxHead){ for(int empty = findFirstEmptyCheck(); ; empty = findNextEmptyCheck(empty)){ int offset = empty - minHead; if((offset + maxHead) >= check.length){ extend(offset + maxHead); } // find space boolean found = true; for(int cid : heads){ if(check[offset + cid] >= 0){ found = false; break; } } if(found) return offset; } } private int getCharId(char c){ char v = charToCode[c]; if(v != 0) return v; v = (char)(chars.size() + 1); chars.add(c); charToCode[c] = v; return v; } private void extend(int i){ int sz = base.length; int nsz = Math.max(i + 0xFFFF, (int)(sz * 1.5)); // System.out.println("extend to " + nsz); base = Arrays.copyOf(base, nsz); Arrays.fill(base, sz, nsz, BASE_EMPTY); check = Arrays.copyOf(check, nsz); Arrays.fill(check, sz, nsz, -1); } private int findFirstEmptyCheck(){ int i = firstEmptyCheck; while(check[i] >= 0 || base[i] != BASE_EMPTY){ i++; } firstEmptyCheck = i; return i; } private int findNextEmptyCheck(int i){ /* for(i++; i < check.length; i++){ if(check[i] < 0) return i; } extend(i); return i; /*/ int d = check[i] * -1; if(d <= 0){ throw new RuntimeException(); } int prev = i; i += d; if(check.length <= i){ extend(i); return i; } if(check[i] < 0){ return i; } for(i++; i < check.length; i++){ if(check[i] < 0){ check[prev] = prev - i; return i; } } extend(i); check[prev] = prev - i; return i; //*/ } private void setCheck(int index, int id){ if(firstEmptyCheck == index){ firstEmptyCheck = findNextEmptyCheck(firstEmptyCheck); } check[index] = id; last = Math.max(last, index); } private int size; private int nodeSize; private int[] base; private int[] check; private int firstEmptyCheck = 1; private int last; private SuccinctBitVector term; private Set<Character> chars = new TreeSet<Character>(); private char[] charToCode = new char[Character.MAX_VALUE]; private static final int BASE_EMPTY = Integer.MAX_VALUE; private static final DoubleArrayNode[] emptyNodes = {}; }