package com.compomics.util.experiment.identification.protein_inference.fm_index; import com.compomics.util.waiting.WaitingHandler; import java.util.ArrayList; import java.util.Collections; /** * Wavelet tree. * * @author Dominik Kopczynski */ public class WaveletTree { /** * Instance of a rank. */ private Rank rank; /** * Stored alphabet in a 128 bitfield. */ private long[] alphabetDirections = new long[2]; // 1 equals left child /** * First character in alphabet. */ private int firstChar; /** * Last character in alphabet. */ private int lastChar; /** * Text length. */ private int lenText; /** * Continue range query for left child. */ private boolean continueLeftRangeQuery = false; /** * Continue range query for right child. */ private boolean continueRightRangeQuery = false; /** * Left child of the wavelet tree. */ private WaveletTree leftChild; /** * Right child of the wavelet tree. */ private WaveletTree rightChild; /** * Shift number for fast bitwise divisions. */ private final int shift = 6; /** * Mask for fast bitwise modulo operations. */ private final int mask = 63; /** * number of masses */ private int numMasses; /** * left right mask */ private int leftRightMask; private int[] less; /** * Class for huffman nodes. */ public class HuffmanNode implements Comparable<HuffmanNode> { long[] alphabet = new long[]{0, 0}; int counts = 0; int depth = 0; int innernodes = 0; HuffmanNode leftChild = null; HuffmanNode rightChild = null; ArrayList<Byte> charAlphabet = new ArrayList<Byte>(); public HuffmanNode(int counts, int character) { this.counts = counts; alphabet[character >>> shift] |= 1L << (character & mask); charAlphabet.add((byte) character); } public HuffmanNode(HuffmanNode first, HuffmanNode second) { this.counts = first.counts + second.counts; alphabet[0] |= first.alphabet[0]; alphabet[0] |= second.alphabet[0]; alphabet[1] |= first.alphabet[1]; alphabet[1] |= second.alphabet[1]; leftChild = first; rightChild = second; charAlphabet.addAll(first.charAlphabet); charAlphabet.addAll(second.charAlphabet); depth = Math.max(first.depth, second.depth) + 1; innernodes = first.innernodes + second.innernodes + 1; } @Override public int compareTo(HuffmanNode argument) { if (counts < argument.counts) { return -1; } if (counts > argument.counts) { return 1; } return 0; } } /** * Constructor. * * @param text the text * @param aAlphabet the alphabet * @param waitingHandler the waiting handler * @param numMasses number of masses plus modifications * @param hasPTMatTerminus indicates how to handle / sign */ public WaveletTree(byte[] text, long[] aAlphabet, WaitingHandler waitingHandler, int numMasses, boolean hasPTMatTerminus) { prepareWaveletTree(text, aAlphabet, waitingHandler, numMasses, hasPTMatTerminus); } /** * Constructor. * * @param text the text * @param aAlphabet the alphabet * @param waitingHandler the waiting handler * @param numMasses number of masses plus modifications */ public WaveletTree(byte[] text, long[] aAlphabet, WaitingHandler waitingHandler, int numMasses) { prepareWaveletTree(text, aAlphabet, waitingHandler, numMasses, false); } /** * Constructor forward function. * * @param text the text * @param aAlphabet the alphabet * @param waitingHandler the waiting handler * @param numMasses number of masses plus modifications * @param hasPTMatTerminus indicates how to handle / sign */ private void prepareWaveletTree(byte[] text, long[] aAlphabet, WaitingHandler waitingHandler, int numMasses, boolean hasPTMatTerminus) { int[] counts = new int[128]; for (byte c : text) { ++counts[c]; } ArrayList<HuffmanNode> huffmanNodes = new ArrayList<HuffmanNode>(); for (int i = 0; i < 128; ++i) { if (((aAlphabet[i >>> shift] >>> (i & mask)) & 1L) == 1) { huffmanNodes.add(new HuffmanNode(counts[i], i)); } } while (huffmanNodes.size() > 1) { Collections.sort(huffmanNodes); HuffmanNode first = huffmanNodes.remove(0); HuffmanNode second = huffmanNodes.remove(0); huffmanNodes.add(new HuffmanNode(first, second)); } createWaveletTreeHuffman(text, waitingHandler, huffmanNodes.get(0), numMasses, hasPTMatTerminus); less = new int[128]; long[] alphabet = new long[2]; alphabet[0] = huffmanNodes.get(0).alphabet[0]; alphabet[1] = huffmanNodes.get(0).alphabet[1]; int cumulativeSum = 0; for (int i = 0; i < 128; ++i) { less[i] = cumulativeSum; if (((alphabet[i >>> shift] >>> (i & mask)) & 1L) != 0) { cumulativeSum += getRank(lenText - 1, i); } } } /** * Constructor. * * @param text the text * @param waitingHandler the waiting handler * @param root the root * @param numMasses number of masses plus modifications * @param hasPTMatTerminus if there is a PTM at the terminus */ public WaveletTree(byte[] text, WaitingHandler waitingHandler, HuffmanNode root, int numMasses, boolean hasPTMatTerminus) { this.numMasses = numMasses; createWaveletTreeHuffman(text, waitingHandler, root, numMasses, hasPTMatTerminus); } /** * Create wavelet tree huffman. * * @param text the text * @param waitingHandler the waiting handler * @param root the root * @param numMasses number of masses plus modifications * @param hasPTMatTerminus if there is a PTM at the terminus */ public void createWaveletTreeHuffman(byte[] text, WaitingHandler waitingHandler, HuffmanNode root, int numMasses, boolean hasPTMatTerminus) { this.numMasses = numMasses; long[] alphabet = new long[2]; alphabet[0] = root.alphabet[0]; alphabet[1] = root.alphabet[1]; long[] alphabetExcluded = new long[2]; alphabetExcluded[0] = 1L << '$'; if (!hasPTMatTerminus) alphabetExcluded[0] |= 1L << '/'; alphabetExcluded[1] = 1L << ('B' & 63); alphabetExcluded[1] |= 1L << ('X' & 63); alphabetExcluded[1] |= 1L << ('Z' & 63); long[] alphabet_left = new long[2]; long[] alphabet_right = new long[2]; alphabetDirections[0] = alphabet_left[0] = root.leftChild.alphabet[0]; alphabetDirections[1] = alphabet_left[1] = root.leftChild.alphabet[1]; alphabet_right[0] = root.rightChild.alphabet[0]; alphabet_right[1] = root.rightChild.alphabet[1]; continueLeftRangeQuery = (((alphabet_left[0] & (~alphabetExcluded[0])) + (alphabet_left[1] & (~alphabetExcluded[1]))) > 0); continueRightRangeQuery = (((alphabet_right[0] & (~alphabetExcluded[0])) + (alphabet_right[1] & (~alphabetExcluded[1]))) > 0); lenText = text.length; rank = new Rank(text, alphabet_right); leftChild = null; rightChild = null; int lenAlphabet = Long.bitCount(alphabet[0]) + Long.bitCount(alphabet[1]); byte[] charAlphabetField = new byte[lenAlphabet]; for (int i = 0; i < root.charAlphabet.size(); ++i) { charAlphabetField[i] = root.charAlphabet.get(i); } firstChar = charAlphabetField[0]; lastChar = charAlphabetField[lenAlphabet - 1]; int len_alphabet_left = Long.bitCount(alphabet_left[0]) + Long.bitCount(alphabet_left[1]); int len_alphabet_right = Long.bitCount(alphabet_right[0]) + Long.bitCount(alphabet_right[1]); if (len_alphabet_left > 1) { int len_text_left = 0; for (int i = 0; i < text.length; ++i) { int cell = text[i] >>> shift; int pos = text[i] & mask; len_text_left += (int) ((alphabet_left[cell] >>> pos) & 1L); } if (len_text_left > 0) { byte[] text_left = new byte[len_text_left]; int j = 0; for (int i = 0; i < text.length; ++i) { int cell = text[i] >>> shift; int pos = text[i] & mask; long bit = (alphabet_left[cell] >>> pos) & 1L; if (bit > 0) { text_left[j++] = text[i]; } } leftChild = new WaveletTree(text_left, waitingHandler, root.leftChild, numMasses, hasPTMatTerminus); } } if (waitingHandler != null && waitingHandler.isRunCanceled()) { return; } if (len_alphabet_right > 1) { int len_text_right = 0; for (int i = 0; i < text.length; ++i) { int cell = text[i] >>> shift; int pos = text[i] & mask; len_text_right += (int) ((alphabet_right[cell] >>> pos) & 1L); } if (len_text_right > 0) { byte[] text_right = new byte[len_text_right]; int j = 0; for (int i = 0; i < text.length; ++i) { int cell = text[i] >>> shift; int pos = text[i] & mask; long bit = (alphabet_right[cell] >>> pos) & 1L; if (bit > 0) { text_right[j++] = text[i]; } } rightChild = new WaveletTree(text_right, waitingHandler, root.rightChild, numMasses, hasPTMatTerminus); } } if (leftChild != null) leftRightMask = 4; if (rightChild != null) leftRightMask |= 2; } /** * Create the less table. * * @return the less table */ public int[] createLessTable() { return less; } /** * Returns the number of occurrences of a given character until position * index. * * @param index the index * @param character the character * @return the rank */ public int getRank(int index, int character) { if (index < lenText) { return getRankRecursive(index, character); } throw new ArrayIndexOutOfBoundsException(); } /** * Returns the number of occurrences of a given character until position * index. * * @param index the index * @param character the character * @return the rank */ public int getRankRecursive(int index, int character) { if (index >= 0) { int cell = character >>> shift; int pos = character & mask; boolean left = ((alphabetDirections[cell] >>> pos) & 1) == 1; int result = rank.getRank(index, left); if (left && leftChild != null) { return leftChild.getRankRecursive(result - 1, character); } else if (!left && rightChild != null) { return rightChild.getRankRecursive(result - 1, character); } return result; } return 0; } /** * Returns the character and rank at a given index. * * @param index the index * @return the character and rank */ public int[] getCharacterInfo(int index) { if (index < lenText) { boolean left = !rank.isOne(index); int result = rank.getRank(index, left); if (result == 0) { return new int[]{firstChar, 0}; } result -= 1; if (left) { if (leftChild == null) { return new int[]{firstChar, result}; } else { return leftChild.getCharacterInfo(result); } } else if (rightChild == null) { return new int[]{lastChar, result}; } else { return rightChild.getCharacterInfo(result); } } throw new ArrayIndexOutOfBoundsException(); } /** * Returns the number of bytes for the allocated arrays. * * @return number of allocated bytes */ public int getAllocatedBytes() { int bytes = rank.getAllocatedBytes(); if (leftChild != null) { bytes += leftChild.getAllocatedBytes(); } if (rightChild != null) { bytes += rightChild.getAllocatedBytes(); } return bytes; } /** * Returns a list of character and new left/right index for a given range. * * @param leftIndex left index boundary * @param rightIndex right index boundary * @return list of counted characters */ public int[][] rangeQuery(int leftIndex, int rightIndex) { int[][] query = new int[numMasses + 1][]; query[numMasses] = new int[]{0}; if (leftIndex + 1 < rightIndex) rangeQuery(leftIndex, rightIndex, query); else rangeQueryOneValue(rightIndex, query); return query; } /** * Fills a list of character and new left/right index for a given range. * * @param leftIndex left index boundary * @param rightIndex right index boundary * @param setCharacter list of counted characters */ public void rangeQuery(int leftIndex, int rightIndex, int[][] setCharacter) { int newLeftIndex = (leftIndex >= 0) ? rank.getRankOne(leftIndex) : 0; int newRightIndex = (rightIndex >= 0) ? rank.getRankOne(rightIndex) : 0; if (continueRightRangeQuery && newRightIndex - newLeftIndex > 0) { if (rightChild != null) { rightChild.rangeQuery(newLeftIndex - 1, newRightIndex - 1, setCharacter); } else { setCharacter[setCharacter[numMasses][0]++] = new int[]{lastChar, newLeftIndex, newRightIndex, lastChar}; } } newLeftIndex = leftIndex - newLeftIndex; newRightIndex = rightIndex - newRightIndex; if (continueLeftRangeQuery && newRightIndex - newLeftIndex > 0) { if (leftChild != null) { leftChild.rangeQuery(newLeftIndex, newRightIndex, setCharacter); } else { setCharacter[setCharacter[numMasses][0]++] = new int[]{firstChar, newLeftIndex + 1, newRightIndex + 1, firstChar}; } } } /** * Fills a list of character and new left/right index for a given range. * * @param index index boundary * @param setCharacter list of counted characters */ public void rangeQueryOneValue(int index, int[][] setCharacter) { int switchOption = rank.isOneInt(index); switchOption += leftRightMask & (4 >> switchOption); switch(switchOption){ case 3: // go right and right child avaliable int newIndex3 = rank.getRankOne(index); rightChild.rangeQueryOneValue(newIndex3 - 1, setCharacter); break; case 4: // go left and left child avaliable int newIndex4 = rank.getRankZero(index); leftChild.rangeQueryOneValue(newIndex4 - 1, setCharacter); break; case 0: // go left and no left child avaliable int newIndex0 = rank.getRankZero(index); setCharacter[setCharacter[numMasses][0]++] = new int[]{firstChar, newIndex0 - 1, newIndex0, firstChar}; break; case 1: // go right and no right child avaliable int newIndex1 = rank.getRankOne(index); setCharacter[setCharacter[numMasses][0]++] = new int[]{lastChar, newIndex1 - 1, newIndex1, lastChar}; break; } } /** * Returns a list of character and new left/right index for a given range * recursively. * * @param leftIndex left index boundary * @param rightIndex right index boundary * @param character character to check * @return a list of character and new left/right index for a given range * recursively */ public int[] singleRangeQuery(int leftIndex, int rightIndex, int character) { boolean left = ((alphabetDirections[character >>> shift] >>> (character & mask)) & 1) == 1; if (left) { int newLeftIndex = (leftIndex >= 0) ? rank.getRankZero(leftIndex) : 0; int newRightIndex = (rightIndex >= 0) ? rank.getRankZero(rightIndex) : 0; if (leftChild != null) { return leftChild.singleRangeQuery(newLeftIndex - 1, newRightIndex - 1, character); } else { return new int[]{newLeftIndex, newRightIndex}; } } else { int newLeftIndex = (leftIndex >= 0) ? rank.getRankOne(leftIndex) : 0; int newRightIndex = (rightIndex >= 0) ? rank.getRankOne(rightIndex) : 0; if (rightChild != null) { return rightChild.singleRangeQuery(newLeftIndex - 1, newRightIndex - 1, character); } else { return new int[]{newLeftIndex, newRightIndex}; } } } }