/* * Copyright 2012 Takao Nakaguchi * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.trie4j.louds; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.io.Writer; import java.util.ArrayList; import java.util.Arrays; import java.util.Deque; import java.util.LinkedList; import java.util.List; import org.trie4j.AbstractTermIdTrie; import org.trie4j.Node; import org.trie4j.TermIdNode; import org.trie4j.TermIdTrie; import org.trie4j.Trie; import org.trie4j.bv.BytesRank1OnlySuccinctBitVector; import org.trie4j.bv.SuccinctBitVector; import org.trie4j.louds.bvtree.BvTree; import org.trie4j.louds.bvtree.LOUDSBvTree; import org.trie4j.patricia.PatriciaTrie; import org.trie4j.tail.ConcatTailArrayBuilder; import org.trie4j.tail.TailArray; import org.trie4j.tail.TailArrayBuilder; import org.trie4j.tail.TailCharIterator; import org.trie4j.util.FastBitSet; import org.trie4j.util.Pair; import org.trie4j.util.Range; public class TailLOUDSTrie extends AbstractTermIdTrie implements Externalizable, TermIdTrie{ protected static interface NodeListener{ void listen(Node node, int id); } public TailLOUDSTrie(){ this(new PatriciaTrie()); } public TailLOUDSTrie(Trie orig){ this(orig, new LOUDSBvTree(orig.nodeSize())); } public TailLOUDSTrie(Trie orig, BvTree bvTree){ this(orig, bvTree, new ConcatTailArrayBuilder(orig.size() * 4), new NodeListener(){ @Override public void listen(Node node, int id) { } }); } public TailLOUDSTrie(Trie orig, BvTree bvTree, TailArrayBuilder tailArrayBuilder){ this(orig, bvTree, tailArrayBuilder, new NodeListener(){ @Override public void listen(Node node, int id) { } }); } public TailLOUDSTrie(Trie orig, TailArrayBuilder tailArrayBuilder){ this(orig, tailArrayBuilder, new NodeListener(){ @Override public void listen(Node node, int id) { } }); } public TailLOUDSTrie(Trie orig, TailArrayBuilder tailArrayBuilder, NodeListener listener){ this(orig, new LOUDSBvTree(orig.size()), tailArrayBuilder, listener); } public TailLOUDSTrie(Trie orig, BvTree bvTree, TailArrayBuilder tailArrayBuilder, NodeListener listener){ FastBitSet bs = new FastBitSet(orig.size()); build(orig, bvTree, tailArrayBuilder, bs, listener); this.term = new BytesRank1OnlySuccinctBitVector(bs.getBytes(), bs.size()); this.tailArray = tailArrayBuilder.build(); this.bvtree.trimToSize(); } public TailLOUDSTrie(int size, int nodeSize, BvTree bvTree, char[] labels, TailArray tailArray, SuccinctBitVector term){ this.size = size; this.nodeSize = nodeSize; this.bvtree = bvTree; this.labels = labels; this.tailArray = tailArray; this.term = term; } @Override public int size() { return size; } public int nodeSize(){ return nodeSize; } public void setNodeSize(int nodeSize) { this.nodeSize = nodeSize; } @Override public boolean contains(String text) { int nodeId = 0; // root Range r = new Range(); TailCharIterator it = tailArray.newIterator(); int n = text.length(); for(int i = 0; i < n; i++){ nodeId = getChildNode(nodeId, text.charAt(i), r); if(nodeId == -1) return false; it.setOffset(tailArray.getIteratorOffset(nodeId)); while(it.hasNext()){ i++; if(i == n) return false; if(text.charAt(i) != it.next()) return false; } } return term.get(nodeId); } public int getNodeId(String text){ int nodeId = 0; // root Range r = new Range(); TailCharIterator it = tailArray.newIterator(); int n = text.length(); for(int i = 0; i < n; i++){ nodeId = getChildNode(nodeId, text.charAt(i), r); if(nodeId == -1) return -1; it.setOffset(tailArray.getIteratorOffset(nodeId)); while(it.hasNext()){ i++; if(i == n) return -1; if(text.charAt(i) != it.next()) return -1; } } return nodeId; } @Override public int getTermId(String text){ int nodeId = getNodeId(text); if(nodeId == -1) return -1; return term.get(nodeId) ? term.rank1(nodeId) - 1 : -1; } private void build(Trie orig, BvTree bvtree, TailArrayBuilder tailArrayBuilder, FastBitSet termBs, NodeListener listener){ this.bvtree = bvtree; this.size = orig.size(); this.labels = new char[size]; LinkedList<Node> queue = new LinkedList<Node>(); int count = 0; if(orig.getRoot() != null) queue.add(orig.getRoot()); while(!queue.isEmpty()){ Node node = queue.pollFirst(); int index = count++; if(index >= labels.length){ extend(); } listener.listen(node, index); if(node.isTerminate()){ termBs.set(index); } else if(termBs.size() <= index){ termBs.ensureCapacity(index); } for(Node c : node.getChildren()){ bvtree.appendChild(); queue.offerLast(c); } bvtree.appendSelf(); char[] letters = node.getLetters(); if(letters.length == 0){ labels[index] = 0xffff; tailArrayBuilder.appendEmpty(index); } else{ labels[index] = letters[0]; if(letters.length >= 2){ tailArrayBuilder.append(index, letters, 1, letters.length - 1); } else{ tailArrayBuilder.appendEmpty(index); } } } this.nodeSize = count; } public BvTree getBvTree() { return bvtree; } public void setBvtree(BvTree bvtree) { this.bvtree = bvtree; } public char[] getLabels(){ return labels; } public TailArray getTailArray(){ return tailArray; } public SuccinctBitVector getTerm(){ return term; } @Override public TermIdNode getRoot(){ return new LOUDSNode(0); } @Override public void dump(Writer writer) throws IOException{ super.dump(writer); writer.write(bvtree.toString()); writer.write("\nlabels: "); int count = 0; for(char c : labels){ writer.write(c); if(count++ == 99) break; } writer.write("\n"); } @Override public Iterable<Pair<String, Integer>> commonPrefixSearchWithTermId(String query) { List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>(); char[] chars = query.toCharArray(); int charsLen = chars.length; int nodeId = 0; // root TailCharIterator tci = tailArray.newIterator(); Range r = new Range(); for(int charsIndex = 0; charsIndex < charsLen; charsIndex++){ int child = getChildNode(nodeId, chars[charsIndex], r); if(child == -1) return ret; tci.setOffset(tailArray.getIteratorOffset(child)); while(tci.hasNext()){ charsIndex++; if(charsLen <= charsIndex) return ret; if(chars[charsIndex] != tci.next()) return ret; } if(term.get(child)){ ret.add(Pair.create( new String(chars, 0, charsIndex + 1), term.rank1(child) - 1)); } nodeId = child; } return ret; } @Override public Iterable<Pair<String, Integer>> predictiveSearchWithTermId(String query) { List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>(); char[] chars = query.toCharArray(); int charsLen = chars.length; int nodeId = 0; // root Range r = new Range(); TailCharIterator tci = tailArray.newIterator(); String pfx = null; int charsIndexBack = 0; for(int charsIndex = 0; charsIndex < charsLen; charsIndex++){ charsIndexBack = charsIndex; int child = getChildNode(nodeId, chars[charsIndex], r); if(child == -1) return ret; tci.setOffset(tailArray.getIteratorOffset(child)); while(tci.hasNext()){ charsIndex++; if(charsIndex >= charsLen) break; if(chars[charsIndex] != tci.next()) return ret; } nodeId = child; } pfx = new String(chars, 0, charsIndexBack); Deque<Pair<Integer, String>> queue = new LinkedList<Pair<Integer,String>>(); queue.offerLast(Pair.create(nodeId, pfx)); while(queue.size() > 0){ Pair<Integer, String> element = queue.pollFirst(); int nid = element.getFirst(); StringBuilder b = new StringBuilder(element.getSecond()); if(nid > 0){ b.append(labels[nid]); } tci.setOffset(tailArray.getIteratorOffset(nid)); while(tci.hasNext()) b.append(tci.next()); String letter = b.toString(); if(term.get(nid)){ ret.add(Pair.create(letter, term.rank1(nid) - 1)); } bvtree.getChildNodeIds(nid, r); for(int i = (r.getEnd() - 1); i >= r.getStart(); i--){ queue.offerFirst(Pair.create(i, letter)); } } return ret; } public class LOUDSNode implements TermIdNode{ public LOUDSNode(int nodeId) { this.nodeId = nodeId; } public int getNodeId() { return nodeId; } @Override public int getTermId() { if(!term.get(nodeId)){ return -1; } else{ return term.rank1(nodeId) - 1; } } @Override public char[] getLetters() { StringBuilder b = new StringBuilder(); char h = labels[nodeId]; if(h != 0xffff){ b.append(h); } int ti = tailArray.getIteratorOffset(nodeId); if(ti != -1){ TailCharIterator it = tailArray.newIterator(ti); it.setOffset(ti); while(it.hasNext()) b.append(it.next()); } return b.toString().toCharArray(); } @Override public boolean isTerminate() { return term.get(nodeId); } @Override public LOUDSNode getChild(char c) { int nid = getChildNode(nodeId, c, new Range()); if(nid == -1) return null; else return new LOUDSNode(nid); } @Override public LOUDSNode[] getChildren() { Range r = new Range(); bvtree.getChildNodeIds(nodeId, r); LOUDSNode[] children = new LOUDSNode[r.getLength()]; for(int i = r.getStart(); i < r.getEnd(); i++){ children[i - r.getStart()] = new LOUDSNode(i); } return children; } private int nodeId; } public void trimToSize(){ if(labels.length > nodeSize){ labels = Arrays.copyOf(labels, nodeSize); } bvtree.trimToSize(); } @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(size); out.writeInt(nodeSize); trimToSize(); out.writeObject(bvtree); out.writeObject(labels); out.writeObject(tailArray); out.writeObject(term); } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { size = in.readInt(); nodeSize = in.readInt(); bvtree = (BvTree)in.readObject(); labels = (char[])in.readObject(); tailArray = (TailArray)in.readObject(); term = (SuccinctBitVector)in.readObject(); } private int getChildNode(int nodeId, char c, Range r){ bvtree.getChildNodeIds(nodeId, r); int start = r.getStart(); int end = r.getEnd(); if(end == -1) return -1; if((end - start) <= 16){ for(int i = start; i < end; i++){ if(c == labels[i]) return i; } return -1; } else{ do{ int i = (start + end) / 2; int d = c - labels[i]; if(d < 0){ end = i; } else if(d > 0){ if(start == i) return -1; else start = i; } else{ return i; } } while(start != end); return -1; } } private void extend(){ int nsz = (int)(labels.length * 1.2); if(nsz <= labels.length) nsz = labels.length * 2 + 1; labels = Arrays.copyOf(labels, nsz); } private BvTree bvtree; private int size; private char[] labels; private TailArray tailArray; private SuccinctBitVector term; private int nodeSize; }