/*
* Copyright 2012 Takao Nakaguchi
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.trie4j.doublearray;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.trie4j.AbstractTrie;
import org.trie4j.Node;
import org.trie4j.Trie;
import org.trie4j.tail.TailCharIterator;
import org.trie4j.tail.builder.SuffixTrieTailBuilder;
import org.trie4j.tail.builder.TailBuilder;
import org.trie4j.util.Pair;
/**
 * Read-only trie stored in parallel {@code base} / {@code check} / {@code tail}
 * arrays.
 * <p>
 * Each slot's {@code check} entry is a {@code short} holding the parent's index
 * relative to the slot itself (lookups test {@code i + check[i] == parent}),
 * and the characters of a node label after the first one are stored through a
 * {@link TailBuilder} and referenced from {@code tail}.
 * <p>
 * An instance is either built from an existing {@link Trie} or restored with
 * {@link #load(InputStream)}; {@link #insert(String)} and {@link #getRoot()}
 * throw {@link UnsupportedOperationException}.
 */
public class OptimizedTailDoubleArray extends AbstractTrie implements Trie{
/**
 * Value of a {@code base} entry that has never been assigned. Slots that are
 * occupied but have no base of their own hold {@code BASE_EMPTY - 1}.
 */
private static final int BASE_EMPTY = Integer.MAX_VALUE;
/**
 * Creates an instance without allocating any of the arrays. Apart from the
 * character code map every field keeps its default value, so the instance is
 * presumably meant to be filled by {@link #load(InputStream)} before use.
 */
public OptimizedTailDoubleArray(){
}
/**
 * Builds the double array from {@code trie} with an initial array size of
 * 65536 slots.
 *
 * @param trie the source trie to convert
 */
public OptimizedTailDoubleArray(Trie trie){
this(trie, 65536);
}
/**
 * Builds the double array from {@code trie}, storing tails with a new
 * {@link SuffixTrieTailBuilder}.
 *
 * @param trie the source trie to convert
 * @param arraySize initial length of the internal arrays
 */
public OptimizedTailDoubleArray(Trie trie, int arraySize){
this(trie, arraySize, new SuffixTrieTailBuilder());
}
/**
 * Builds the double array from {@code trie}.
 * <p>
 * The arrays are allocated with {@code arraySize} slots and grown on demand
 * while the nodes are laid out. Slot 0 is used for the root. The tail buffer
 * is always taken from {@code tb}, even when the source trie has no root, so
 * that later calls which read it (lookups, {@link #save(OutputStream)},
 * {@link #dump()}) do not fail on a {@code null} buffer.
 *
 * @param trie the source trie to convert
 * @param arraySize initial length of the internal arrays; must be at least 1
 *        because slot 0 is written unconditionally
 * @param tb builder that receives every node label beyond its first character
 */
public OptimizedTailDoubleArray(Trie trie, int arraySize, TailBuilder tb){
    size = trie.size();
    nodeSize = trie.nodeSize();
    base = new int[arraySize];
    Arrays.fill(base, BASE_EMPTY);
    check = new short[arraySize];
    Arrays.fill(check, (short)-1);
    tail = new int[arraySize];
    Arrays.fill(tail, -1);
    term = new BitSet(65536);
    int nodeIndex = 0;
    base[0] = nodeIndex;
    Node root = trie.getRoot();
    if(root != null){
        char[] rootLetters = root.getLetters();
        if(rootLetters != null){
            if(rootLetters.length == 0){
                if(root.isTerminate()) term.set(0);
            } else{
                // The root carries a label of its own: it is placed at the slot
                // named by its first character's code instead of slot 0.
                // NOTE(review): every other slot stores (parent - slot) in check,
                // which would be -c here; the lookups start at slot 0 and will not
                // accept +c. Confirm whether a labelled root is expected to work.
                int c = getCharId(rootLetters[0]);
                check[c] = (short)c;
                nodeIndex = c;
            }
        }
        build(root, nodeIndex, tb);
    }
    // Assigned unconditionally: the previous early return for a null root left
    // tails null.
    tails = tb.getTails();
}
/**
 * Returns the size recorded from the source trie at construction time, or
 * read back by {@link #load(InputStream)}.
 */
@Override
public int size() {
return size;
}
/**
 * Returns the node count recorded from the source trie at construction time,
 * or read back by {@link #load(InputStream)}.
 */
@Override
public int nodeSize() {
return nodeSize;
}
/**
 * Not supported: this implementation does not expose its nodes.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Node getRoot() {
throw new UnsupportedOperationException();
}
/**
 * Tests whether {@code text} is stored in this trie.
 * <p>
 * The walk starts at slot 0. At every node the node's tail (the label
 * characters after the first one) is matched against the text, then the next
 * text character selects the child slot {@code base[node] + charId}. A child
 * is accepted only when its relative check value points back at the node.
 *
 * @param text the key to look up
 * @return {@code true} when the walk consumes the text exactly at the end of
 *         a node label and that node is marked as terminal
 */
public boolean contains(String text){
    char[] chars = text.toCharArray();
    int charsIndex = 0;
    int nodeIndex = 0;
    TailCharIterator it = new TailCharIterator(tails, -1);
    while(charsIndex < chars.length){
        int tailIndex = tail[nodeIndex];
        if(tailIndex != -1){
            it.setIndex(tailIndex);
            while(it.hasNext()){
                // text ends in the middle of this node's label
                if(chars.length <= charsIndex) return false;
                if(chars[charsIndex] != it.next()) return false;
                charsIndex++;
            }
            // text ends exactly where this node's label ends
            if(chars.length == charsIndex) return term.get(nodeIndex);
        }
        int cid = findCharId(chars[charsIndex]);
        if(cid == -1) return false;
        int i = cid + base[nodeIndex];
        if(i < 0 || check.length <= i || (i + check[i]) != nodeIndex) return false;
        charsIndex++;
        nodeIndex = i;
    }
    // The text was used up by the last transition. If the node reached still
    // has tail characters, the text is only a proper prefix of that node's
    // label and must not be reported as contained.
    int lastTail = tail[nodeIndex];
    if(lastTail != -1){
        it.setIndex(lastTail);
        if(it.hasNext()) return false;
    }
    return term.get(nodeIndex);
}
/**
 * Collects every stored key that is a prefix of {@code query}.
 * <p>
 * The query is consumed one transition at a time; after each transition the
 * reached node's tail is matched and, if the node is terminal, the part of
 * the query consumed so far is added to the result.
 *
 * @param query the text whose prefixes are looked up
 * @return the matching keys in order of increasing length; empty when none
 *         match
 */
@Override
public Iterable<String> commonPrefixSearch(String query) {
    List<String> ret = new ArrayList<String>();
    char[] chars = query.toCharArray();
    int ci = 0;
    int ni = 0;
    if(tail[0] != -1){
        // NOTE(review): ci is advanced before the comparison, so chars[0] is
        // never compared here — confirm this matches how a root tail is stored.
        TailCharIterator rootIt = new TailCharIterator(tails, tail[0]);
        while(rootIt.hasNext()){
            ci++;
            if(ci >= chars.length) return ret;
            if(rootIt.next() != chars[ci]) return ret;
        }
        if(term.get(0)) ret.add(new String(chars, 0, ci + 1));
    }
    TailCharIterator it = new TailCharIterator(tails, -1);
    for(; ci < chars.length; ci++){
        int cid = findCharId(chars[ci]);
        if(cid == -1) return ret;
        int b = base[ni];
        // neither marker value is a usable base: the node has no children
        if(b == BASE_EMPTY) return ret;
        if(b == (BASE_EMPTY - 1)) return ret;
        int next = b + cid;
        // base offsets may be negative, so the lower bound has to be checked
        // as well before indexing check[]
        if(next < 0 || check.length <= next || (next + check[next]) != ni) return ret;
        ni = next;
        if(tail[ni] != -1){
            it.setIndex(tail[ni]);
            while(it.hasNext()){
                ci++;
                if(ci >= chars.length) return ret;
                if(it.next() != chars[ci]) return ret;
            }
        }
        if(term.get(ni)) ret.add(new String(chars, 0, ci + 1));
    }
    return ret;
}
/**
 * Collects every stored key that starts with {@code prefix}.
 * <p>
 * First the prefix is followed from slot 0, matching node tails and
 * transitions. Then the subtree under the node reached is enumerated with an
 * explicit stack, probing every known character code as a possible child.
 *
 * @param prefix the leading text the returned keys must share
 * @return the matching keys; empty when the prefix leads nowhere
 */
@Override
public Iterable<String> predictiveSearch(String prefix) {
    List<String> ret = new ArrayList<String>();
    StringBuilder current = new StringBuilder();
    char[] chars = prefix.toCharArray();
    int nodeIndex = 0;
    TailCharIterator it = new TailCharIterator(tails, -1);
    for(int i = 0; i < chars.length; i++){
        int ti = tail[nodeIndex];
        if(ti != -1){
            int first = i;
            it.setIndex(ti);
            do{
                if(!it.hasNext()) break;
                if(it.next() != chars[i]) return ret;
                i++;
            } while(i < chars.length);
            // The prefix ends inside (or exactly at the end of) this tail.
            // The matched part is not appended here because the enumeration
            // below re-appends the whole tail of the start node.
            if(i >= chars.length) break;
            current.append(chars, first, i - first);
        }
        int cid = findCharId(chars[i]);
        if(cid == -1) return ret;
        int next = base[nodeIndex] + cid;
        if(next < 0 || check.length <= next || (next + check[next]) != nodeIndex) return ret;
        nodeIndex = next;
        current.append(chars[i]);
    }
    // Each entry pairs a slot index with the text leading up to, and including,
    // the first character of that slot's label.
    Deque<Pair<Integer, char[]>> q = new LinkedList<Pair<Integer,char[]>>();
    q.add(Pair.create(nodeIndex, current.toString().toCharArray()));
    while(!q.isEmpty()){
        Pair<Integer, char[]> p = q.pop();
        int ni = p.getFirst();
        StringBuilder buff = new StringBuilder().append(p.getSecond());
        int ti = tail[ni];
        if(ti != -1){
            it.setIndex(ti);
            while(it.hasNext()){
                buff.append(it.next());
            }
        }
        if(term.get(ni)) ret.add(buff.toString());
        int b = base[ni];
        // neither marker value is a usable base: the node has no children
        if(b == BASE_EMPTY || b == (BASE_EMPTY - 1)) continue;
        for(Map.Entry<Character, Integer> e : charCodes.entrySet()){
            int next = b + e.getValue();
            // base offsets may be negative, so guard both ends of the array
            if(next < 0 || check.length <= next) continue;
            if(next + check[next] == ni){
                StringBuilder bu = new StringBuilder(buff);
                bu.append(e.getKey());
                q.push(Pair.create(next, bu.toString().toCharArray()));
            }
        }
    }
    return ret;
}
/**
 * Not supported: the double array does not currently support dynamic
 * construction, so keys cannot be added after the arrays have been built.
 *
 * @param word ignored
 * @throws UnsupportedOperationException always
 */
@Override
public void insert(String word) {
throw new UnsupportedOperationException();
}
/**
 * Writes this trie to {@code os} in the layout that
 * {@link #load(InputStream)} reads back.
 * <p>
 * Layout, in order: size, node count, array length, all base values (int),
 * all check values (short), all tail indices (int), the terminal bit set as a
 * serialized Java object, the tail buffer (length followed by its chars), and
 * the character code table (entry count followed by char/int pairs).
 * The stream is flushed but not closed.
 *
 * @param os destination stream
 * @throws IOException if writing fails
 */
public void save(OutputStream os) throws IOException{
    BufferedOutputStream buffered = new BufferedOutputStream(os);
    DataOutputStream out = new DataOutputStream(buffered);

    // header
    out.writeInt(size);
    out.writeInt(nodeSize);
    out.writeInt(base.length);

    // the three parallel arrays
    for(int i = 0; i < base.length; i++){
        out.writeInt(base[i]);
    }
    for(int i = 0; i < check.length; i++){
        out.writeShort(check[i]);
    }
    for(int i = 0; i < tail.length; i++){
        out.writeInt(tail[i]);
    }
    // push pending data down before a second writer shares the same buffer
    out.flush();

    // terminal flags, via Java object serialization
    ObjectOutputStream objectOut = new ObjectOutputStream(buffered);
    objectOut.writeObject(term);
    objectOut.flush();

    // tail buffer
    out.writeInt(tails.length());
    out.writeChars(tails.toString());

    // character code table
    out.writeInt(charCodes.size());
    for(Map.Entry<Character, Integer> code : charCodes.entrySet()){
        out.writeChar(code.getKey());
        out.writeInt(code.getValue());
    }
    out.flush();
    buffered.flush();
}
/**
 * Replaces the content of this trie with data written by
 * {@link #save(OutputStream)}.
 * <p>
 * Besides reading the stored fields this also clears the character code
 * table before refilling it, and recomputes the index of the last used slot
 * (which is not part of the stored format) so that a later
 * {@link #trimToSize()} does not cut off loaded data.
 * <p>
 * NOTE(security): the terminal bit set is read with Java object
 * deserialization. Only load data from trusted sources.
 *
 * @param is source stream; it is wrapped in a buffer, so more bytes than
 *        this object occupies may be consumed from it
 * @throws IOException if reading fails or the stream content is malformed
 */
public void load(InputStream is) throws IOException{
    BufferedInputStream bis = new BufferedInputStream(is);
    DataInputStream dis = new DataInputStream(bis);
    size = dis.readInt();
    nodeSize = dis.readInt();
    int len = dis.readInt();
    if(len < 0){
        throw new IOException("invalid array length: " + len);
    }
    base = new int[len];
    for(int i = 0; i < len; i++){
        base[i] = dis.readInt();
    }
    check = new short[len];
    for(int i = 0; i < len; i++){
        check[i] = dis.readShort();
    }
    tail = new int[len];
    for(int i = 0; i < len; i++){
        tail[i] = dis.readInt();
    }
    // A slot counts as used when it has a base (including the BASE_EMPTY - 1
    // marker given to occupied slots) or a tail.
    int lastUsed = 0;
    for(int i = len - 1; i > 0; i--){
        if(base[i] != BASE_EMPTY || tail[i] != -1){
            lastUsed = i;
            break;
        }
    }
    last = lastUsed;
    ObjectInputStream ois = new ObjectInputStream(bis);
    try{
        term = (BitSet)ois.readObject();
    } catch(ClassNotFoundException e){
        throw new IOException(e);
    }
    int n = dis.readInt();
    if(n < 0){
        throw new IOException("invalid tail buffer length: " + n);
    }
    StringBuilder b = new StringBuilder(n);
    for(int i = 0; i < n; i++){
        b.append(dis.readChar());
    }
    tails = b;
    n = dis.readInt();
    // drop codes left over from a previous build or load
    charCodes.clear();
    for(int i = 0; i < n; i++){
        char c = dis.readChar();
        int v = dis.readInt();
        charCodes.put(c, v);
    }
}
/**
 * Prints a diagnostic summary of the internal state to standard output:
 * array size, last used index, number of used slots, a table of the first
 * slots (at most 16), the beginning of the tail buffer (at most 64 chars),
 * the first character codes, and minimum/maximum statistics for base and
 * check.
 * <p>
 * The table and the tail preview are clipped to the data actually present,
 * so the method also works on arrays shorter than 16 slots.
 */
public void dump(){
    System.out.println("array size: " + base.length);
    System.out.println("last index of valid element: " + last);
    int vc = 0;
    for(int i = 0; i < base.length; i++){
        if(base[i] != BASE_EMPTY || check[i] >= 0) vc++;
    }
    System.out.println("valid elements: " + vc);

    // number of table columns: the first 16 slots, or fewer for short arrays
    int cols = Math.min(16, base.length);
    System.out.print(" |");
    for(int i = 0; i < cols; i++){
        System.out.print(String.format("%3d|", i));
    }
    System.out.println();
    System.out.print("|base |");
    for(int i = 0; i < cols; i++){
        if(base[i] == BASE_EMPTY){
            System.out.print("N/A|");
        } else{
            System.out.print(String.format("%3d|", base[i]));
        }
    }
    System.out.println();
    System.out.print("|check|");
    for(int i = 0; i < cols; i++){
        System.out.print(String.format("%3d|", check[i]));
    }
    System.out.println();
    System.out.print("|tail |");
    for(int i = 0; i < cols; i++){
        if(tail[i] < 0){
            System.out.print("N/A|");
        } else{
            System.out.print(String.format("%3d|", tail[i]));
        }
    }
    System.out.println();
    System.out.print("|term |");
    for(int i = 0; i < cols; i++){
        System.out.print(String.format("%3d|", term.get(i) ? 1 : 0));
    }
    System.out.println();

    int count = 0;
    for(int i : tail){
        if(i != -1) count++;
    }
    System.out.println("tail count: " + count);
    System.out.println();

    System.out.print("tails: [");
    char[] tailChars = tails.subSequence(0, Math.min(tails.length(), 64)).toString().toCharArray();
    for(int i = 0; i < tailChars.length; i++){
        char c = tailChars[i];
        if(c == '\0'){
            System.out.print("\\0");
            continue;
        }
        if(c == '\1'){
            // The marker is followed by two chars that are combined into an
            // index. They are read from the full buffer because the preview
            // window may end right after the marker.
            if(i + 2 < tails.length()){
                int index = tails.charAt(i + 1) + (tails.charAt(i + 2) << 16);
                System.out.print(String.format("\\1(%d)", index));
            } else{
                System.out.print("\\1(?)");
            }
            i += 2;
            continue;
        }
        System.out.print(c);
    }
    System.out.println("]");
    System.out.println("tailBuf size: " + tails.length());
    {
        System.out.print("chars: ");
        int c = 0;
        for(Map.Entry<Character, Integer> e : charCodes.entrySet()){
            System.out.print(String.format("%c:%d,", e.getKey(), e.getValue()));
            c++;
            if(c > 16) break;
        }
        System.out.println();
        System.out.println("chars count: " + charCodes.size());
    }
    {
        System.out.println("calculating max and min base.");
        int min = Integer.MAX_VALUE;
        int max = Integer.MIN_VALUE;
        int maxDelta = Integer.MIN_VALUE;
        for(int i = 0; i < base.length; i++){
            int b = base[i];
            if(b == BASE_EMPTY) continue;
            min = Math.min(min, b);
            max = Math.max(max, b);
            maxDelta = Math.max(maxDelta, Math.abs(i - b));
        }
        System.out.println("maxDelta: " + maxDelta);
        System.out.println("max: " + max);
        System.out.println("min: " + min);
    }
    {
        System.out.println("calculating min check.");
        int min = Integer.MAX_VALUE;
        // Every check entry takes part: a short can never equal BASE_EMPTY,
        // so the former comparison against it never skipped anything.
        for(int i = 0; i < check.length; i++){
            min = Math.min(min, check[i]);
        }
        System.out.println("min: " + min);
    }
    System.out.println();
}
/**
 * Shrinks the base, check and tail arrays so that they end right after the
 * last used slot, and trims the tail buffer when it is a
 * {@link StringBuilder}.
 * <p>
 * The recorded last index is only maintained while the arrays are being
 * built, so it may be out of date (for example when the arrays were filled
 * from a stream). The arrays are therefore also scanned from the end, and a
 * slot that has a base or a tail is never cut off.
 */
public void trimToSize(){
    int used = Math.min(last, base.length - 1);
    for(int i = base.length - 1; i > used; i--){
        if(base[i] != BASE_EMPTY || tail[i] != -1){
            used = i;
            break;
        }
    }
    last = used;
    int sz = used + 1;
    base = Arrays.copyOf(base, sz);
    check = Arrays.copyOf(check, sz);
    tail = Arrays.copyOf(tail, sz);
    if(tails instanceof StringBuilder){
        ((StringBuilder)tails).trimToSize();
    }
}
/**
 * Lays out {@code node}, already assigned to slot {@code nodeIndex}, and then
 * recursively its whole subtree.
 * <p>
 * Steps: store the node's tail and terminal flag; map the first letter of
 * every child to a character code; search for a base offset at which the
 * slots of all children are free; record that offset and mark the child
 * slots; recurse into the children, those with the most children first.
 *
 * @param node the source node to lay out
 * @param nodeIndex the slot already assigned to {@code node}
 * @param tb receives the node label from its second character onwards
 */
private void build(Node node, int nodeIndex, TailBuilder tb){
// letters
char[] letters = node.getLetters();
if(letters != null){
if(letters.length > 1){
// only the part after the first letter goes to the tail buffer; the first
// letter is encoded by the slot position itself
int tailIndex = tb.insert(letters, 1, letters.length - 1);
tail[nodeIndex] = tailIndex;
}
if(node.isTerminate()){
term.set(nodeIndex);
}
}
// children
Node[] children = node.getChildren();
if(children == null || children.length == 0) return;
// heads[i] is the character code of the first letter of children[i]
int[] heads = new int[children.length];
int maxHead = 0;
int minHead = Integer.MAX_VALUE;
for(int i = 0; i < children.length; i++){
heads[i] = getCharId(children[i].getLetters()[0]);
maxHead = Math.max(maxHead, heads[i]);
minHead = Math.min(minHead, heads[i]);
}
// candidate placement: the child with the smallest code lands on a free slot
int empty = findFirstEmptyCheck(nodeIndex);
int offset = empty - minHead;
while(true){
if(check.length <= (offset + maxHead)){
extend(offset + maxHead);
}
// find space
// NOTE(review): a slot is treated as taken only when its check value is
// >= 0, while occupied slots placed after their parent hold a negative
// relative value — confirm this test is sufficient to detect collisions.
boolean found = true;
for(int cid : heads){
if(check[offset + cid] >= 0){
found = false;
break;
}
}
if(found) break;
empty = findNextEmptyCheck(nodeIndex, empty);
offset = empty - minHead;
}
base[nodeIndex] = offset;
for(int cid : heads){
// NOTE(review): the guard tests the character code, whereas the value
// narrowed to short below is the parent-relative distance — confirm.
if(cid > Short.MAX_VALUE){
throw new RuntimeException("check value overflow");
}
// the check value of a child is the parent's index relative to the child
setCheck(offset + cid, (short)(nodeIndex - (offset + cid)));
}
/*
for(int i = 0; i < children.length; i++){
build(children[i], offset + heads[i]);
}
/*/
// sort children by children's children count.
// The map is keyed by child count in descending order; children with the
// same count keep their original relative order inside the list.
Map<Integer, List<Pair<Node, Integer>>> nodes = new TreeMap<Integer, List<Pair<Node, Integer>>>(new Comparator<Integer>() {
@Override
public int compare(Integer arg0, Integer arg1) {
return arg1 - arg0;
}
});
for(int i = 0; i < children.length; i++){
Node[] c = children[i].getChildren();
int n = 0;
if(c != null){
n = c.length;
}
List<Pair<Node, Integer>> p = nodes.get(n);
if(p == null){
p = new ArrayList<Pair<Node, Integer>>();
nodes.put(n, p);
}
p.add(Pair.create(children[i], heads[i]));
}
// recurse: each child lives at slot (offset + its character code)
for(Map.Entry<Integer, List<Pair<Node, Integer>>> e : nodes.entrySet()){
for(Pair<Node, Integer> e2 : e.getValue()){
build(e2.getFirst(), e2.getSecond() + offset, tb);
}
}
//*/
}
/**
 * Returns the code of {@code c}, assigning the next free code when the
 * character has not been seen before. Codes start at 1 and are handed out in
 * order of first use.
 *
 * @param c the character to encode
 * @return the existing or newly assigned code
 * @throws RuntimeException if a new code would exceed {@link Short#MAX_VALUE}
 */
private int getCharId(char c){
    Integer known = charCodes.get(c);
    if(known != null){
        return known;
    }
    int assigned = charCodes.size() + 1;
    if(assigned > Short.MAX_VALUE){
        throw new RuntimeException("too many kinds of character(max: 32767).");
    }
    charCodes.put(c, assigned);
    return assigned;
}
/**
 * Looks up the code of {@code c} without assigning a new one.
 *
 * @param c the character to look up
 * @return the code of {@code c}, or -1 when the character has no code
 */
private int findCharId(char c){
    Integer known = charCodes.get(c);
    return known != null ? known : -1;
}
/**
 * Grows the base, check and tail arrays so that index {@code i} becomes
 * valid. The new length is the larger of {@code i + 1} and 1.5 times the
 * current length; the added slots are initialised as unused
 * ({@code BASE_EMPTY}, -1, -1).
 *
 * @param i an index that must be addressable after the call
 */
private void extend(int i){
    int sz = base.length;
    // i + 1, not i: a length of exactly i would leave index i out of range
    int nsz = Math.max(i + 1, (int)(sz * 1.5));
    int[] nb = Arrays.copyOf(base, nsz);
    Arrays.fill(nb, sz, nsz, BASE_EMPTY);
    base = nb;
    short[] nc = Arrays.copyOf(check, nsz);
    Arrays.fill(nc, sz, nsz, (short)-1);
    check = nc;
    int[] nt = Arrays.copyOf(tail, nsz);
    Arrays.fill(nt, sz, nsz, -1);
    tail = nt;
}
/**
 * Finds the first unused slot from which a search for child positions of the
 * node at {@code baseNodeIndex} can start.
 * <p>
 * The scan begins {@link Short#MAX_VALUE} slots before the node (or at 0) and
 * moves upwards until it meets a slot whose check value is negative and whose
 * base has never been assigned. When every existing slot is in use the arrays
 * are grown and the first newly added slot is returned.
 *
 * @param baseNodeIndex slot of the node whose children are being placed
 * @return index of an unused slot
 */
private int findFirstEmptyCheck(int baseNodeIndex){
    int i = Math.max(baseNodeIndex - Short.MAX_VALUE, 0);
    while(i < check.length && (check[i] >= 0 || base[i] != BASE_EMPTY)){
        i++;
    }
    if(i >= check.length){
        // no free slot left: make room so that index i itself is valid
        extend(i + 1);
    }
    return i;
}
/**
 * Returns the next candidate slot after the unused slot {@code i}.
 * <p>
 * The negated check value of an unused slot is read as the distance to the
 * next candidate (the initial value -1 means "the following slot"). If the
 * slot found that way lies beyond the arrays they are extended; if it is
 * taken, the arrays are scanned forward for a slot that is unused, and the
 * check value of the starting slot is rewritten.
 *
 * @param baseNodeIndex slot of the node whose children are being placed
 * @param i an unused slot, i.e. one with a negative check value
 * @return index of the next candidate slot
 * @throws RuntimeException if {@code check[i]} is not negative, or if the
 *         value to be stored into check does not fit below
 *         {@link Short#MIN_VALUE}
 */
private int findNextEmptyCheck(int baseNodeIndex, int i){
/*
for(i++; i < check.length; i++){
if(check[i] < 0) return i;
}
extend(i);
return i;
/*/
// distance to the next candidate, stored negated in check[i]
int d = check[i] * -1;
if(d <= 0){
throw new RuntimeException();
}
int prev = i;
i += d;
if(check.length <= i){
extend(i);
return i;
}
// NOTE(review): unlike the scan below, this shortcut does not look at
// base[i] — confirm a negative check value alone identifies an unused slot.
if(check[i] < 0){
return i;
}
for(i++; i < check.length; i++){
if(check[i] < 0 && base[i] == BASE_EMPTY){
// NOTE(review): this branch stores (baseNodeIndex - i) into check[prev],
// whereas the fall-through below stores (prev - i). Only the latter matches
// the "negated distance from prev" reading used at the top of this method;
// confirm which one is intended.
int v = baseNodeIndex - i;
if(v < Short.MIN_VALUE){
throw new RuntimeException("check value overflow");
}
check[prev] = (short)v;
return i;
}
}
// no unused slot left; i equals the array length here, so after growing it
// is the first newly added slot
extend(i);
int v = prev - i;
if(v < Short.MIN_VALUE){
throw new RuntimeException("check value overflow");
}
check[prev] = (short)v;
return i;
//*/
}
/**
 * Marks the slot at {@code index} as occupied: stores its check value,
 * advances the last used index if necessary, and replaces a never-assigned
 * base with the marker {@code BASE_EMPTY - 1}.
 *
 * @param index the slot being claimed
 * @param value the check value to store
 */
private void setCheck(int index, short value){
    check[index] = value;
    if(index > last){
        last = index;
    }
    if(base[index] == BASE_EMPTY){
        base[index] = BASE_EMPTY - 1;
    }
}
/** Value of {@code size()} of the source trie, or the value read by load. */
private int size;
/** Value of {@code nodeSize()} of the source trie, or the value read by load. */
private int nodeSize;
/** Per slot: offset added to a character code to find a child slot. */
private int[] base;
/** Per slot: index of the parent relative to the slot (parent - slot). */
private short[] check;
/** Per slot: index into the tail buffer, or -1 when the slot has no tail. */
private int[] tail;
/** Highest slot index that has been claimed while building. */
private int last;
/** Bit i is set when the node at slot i ends a stored key. */
private BitSet term;
/** Buffer holding all tails, as produced by the tail builder or read by load. */
private CharSequence tails;
/** Character to code table, ordered by descending character value. */
private Map<Character, Integer> charCodes = new TreeMap<Character, Integer>(new Comparator<Character>(){
    @Override
    public int compare(Character left, Character right) {
        return right.compareTo(left);
    }
});
}