/* * Copyright (C) 2015 Google Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package com.android.switchaccess; import android.content.Context; import android.util.Log; import com.android.utils.LogUtils; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.Map; import java.util.Set; /** * Builds a trie to be used for prediction by partial matching. A PPM model maintains the * frequencies for actions that have been seen before in all context that have occurred, up to * some maximum order. * * Necessary information to understand the PPM model and how to calculate the probability * distribution according to this model was obtained from the paper * "Implementing the PPM Data Compression Scheme" by Alistair Moffat, found at the following link: * http://cs1.cs.nyu.edu/~roweis/csc310-2006/extras/implementing_ppm.pdf */ public class PPMTrie { private final TrieNode mRoot; private final int mTrieDepth; private TrieNode mStartInsertionNode; public PPMTrie(int depth) { mTrieDepth = depth; mRoot = new TrieNode('\0'); mRoot.setVineNode(null); mStartInsertionNode = mRoot; } /** * Uses the text in a training file to form a ppm model and store it in a trie. The file is a * .txt file that contains plain unicode text. * * @param fileResource The file to be used for training the ppm model and constructing the * trie. */ public void initializeTrie(Context context, int fileResource) { TrieNode startInsertionNode = mRoot; InputStream stream = context.getResources().openRawResource(fileResource); BufferedReader reader = null; try { reader = new BufferedReader(new InputStreamReader(stream, "UTF-8")); String input; while ((input = reader.readLine()) != null) { for (int i = 0; i < input.length(); i++) { startInsertionNode = insertSymbol(startInsertionNode, input.charAt(i)); } } } catch (IOException e) { LogUtils.log(this, Log.ERROR, "Unable to read PPMTrie input file: %1$s", e.toString()); } finally { try { if (reader != null) { reader.close(); } } catch(IOException e) { LogUtils.log(this, Log.ERROR, "Unable to close input file: %1$s", e.toString()); } } } /** * Updates the trie to include the specified symbol. When the symbol is inserted into the trie * the node that was inserted at the highest level is tracked. This enables us to track the * N most recent symbols inserted into the trie, which hence give us the user context. * * @param symbol The symbol to be updated/inserted into the trie. */ public void learnSymbol(int symbol) { mStartInsertionNode = insertSymbol(mStartInsertionNode, symbol); } public void clearUserContext() { mStartInsertionNode = mRoot; } /** * Given the context, computes the probability for all the symbols in the set of * {@code symbols}. If a symbol doesn't appear in the context we are given, we escape to a * lower order context and attempt to find the symbol in the lower order context. If we escape * from every context and can't find the symbol, we assign a default probability, which is a * uniform distribution based on the number of symbols in the {@code symbols} set. * * To compute the escape probability the well known PPM method C is used, in which at any * context, the escape is counted as having occurred a number of times equal to the number of * unique symbols encountered in the context, with the total context count inflated by the same * amount. Finally, the principle of exclusion is applied when calculating the probabilities: * when switching to a lower order context, the count of characters that occurred in the * higher context is excluded and only those characters that did not occur in a higher-order * context is considered. * * For further details on how and why the escape count in calculated in the following way * refer to the paper found at this link: * http://cs1.cs.nyu.edu/~roweis/csc310-2006/extras/implementing_ppm.pdf * * @param userContext The actions that the user has taken so far. * @param symbols The set of symbol whose probability we're interested in. * @return A map associating each symbol with a probability value. */ public Map<Integer, Double> getProbabilityDistribution(String userContext, Set<Integer> symbols) { Map<Integer, Double> probabilityDistribution = new HashMap<>(symbols.size()); if (symbols.size() == 0) { return probabilityDistribution; } TrieNode node = lookupTrieNode(mRoot, userContext, 0); if (node != null) { double escapeProbability = 1.0; Set<Integer> seenSymbols = new HashSet<>(); int currentOrder = getNodeDepth(node); while (currentOrder >= 0) { LinkedList<TrieNode> children = node.getChildren(); /* the escape character is counted as having occurred a number of times equal to * the number of unique symbols encountered in the context, hence adding the * children.size() to the node count. */ int parentCount = node.getCount() + children.size(); int exclusionCount = parentCount; for (TrieNode child : children) { if (seenSymbols.contains(child.getContent())) { /* symbols that have been seen in higher order contexts can be excluded * from consideration in lower contexts so that increased probabilities can * be allocated to the remaining symbols. */ exclusionCount -= child.getCount(); } seenSymbols.add(child.getContent()); if (symbols.contains(child.getContent()) && !probabilityDistribution.containsKey(child.getContent())) { Double childProbability = (escapeProbability * child.getCount()) / parentCount; probabilityDistribution.put(child.getContent(), childProbability); } } escapeProbability = escapeProbability * children.size() / exclusionCount ; node = node.getVineNode(); currentOrder--; } } assignDefaultProbability(symbols, probabilityDistribution); return probabilityDistribution; } /** * The {@code startInsertionNode} points to the TrieNode with symbol X where insertion should * begin. If this TrieNode is at a depth less then the max trie depth, the symbol is inserted * as a child of that node. Then the vine pointer from TrieNode with symbol X on depth n is * followed to a node with the same symbol X on level n - 1. A node is then inserted as a child * of this TrieNode at level n - 1. This process is repeated until a node is inserted as a * child of the root node. The vine pointers of all the nodes at depth 1, point to the root of * the trie. * * @param startInsertionNode The TrieNode where insertion should begin. * @param symbol The symbol to be inserted into the trie. * @return The TrieNode inserted at the highest context. Returning this node helps implicitly * keep track of the user context. */ private TrieNode insertSymbol (TrieNode startInsertionNode, int symbol) { int currentLevel = getNodeDepth(startInsertionNode); TrieNode currentNode = startInsertionNode; TrieNode prevModifiedNode = null; TrieNode nodeAtHighestDepth = null; while (currentLevel >= 0) { if (currentLevel < mTrieDepth) { TrieNode child = currentNode.addChild(symbol); if (child.getCount() == Integer.MAX_VALUE) { scaleCount(mRoot); } if (prevModifiedNode == null) { // keep track of the node inserted at the greatest depth. nodeAtHighestDepth = child; } else if (prevModifiedNode.getVineNode() == null) { /* if the vineNode reference is null that means the node was recently added and * doesn't point to a node on a lower context. Hence update this reference. */ prevModifiedNode.setVineNode(child); } if (currentNode == mRoot && child.getVineNode() == null) { /* For a node that has been inserted as a child of the root node and doesn't * have a vineNode reference, update the vineNode reference to be the root * node */ child.setVineNode(mRoot); } prevModifiedNode = child; } currentNode = currentNode.getVineNode(); currentLevel -= 1; } mRoot.setCount(mRoot.getCount() + 1); return nodeAtHighestDepth; } /* TODO Figure out if there's a more efficient way of scaling without having to * scale the entire trie, but rather only certain branches of the trie */ private void scaleCount(TrieNode rootNode) { LinkedList<TrieNode> children = rootNode.getChildren(); rootNode.setCount(rootNode.getCount() / 2); for (TrieNode child : children) { scaleCount(child); } } /** * Given the context, tries to find the context of greatest length within the trie. The maximum * length will naturally be the max depth of the trie. Null is returned only if even a context * of length 1 can't be found. * * @param rootNode The trie node from where the search should begin * @param userContext The overall context we are searching * @param index The position in the context from where to begin searching. * @return The TrieNode found that matches the the max possible components of the context. If * even a context of length 1 can't be found, {@code null} is returned. */ private TrieNode lookupTrieNode(TrieNode rootNode, String userContext, int index) { if (index >= userContext.length()) { rootNode = (rootNode == mRoot) ? null : rootNode; return rootNode; } int curContent = (int) userContext.charAt(index); if (rootNode.hasChild(curContent)) { return lookupTrieNode(rootNode.getChild(curContent), userContext, index + 1); } else if (rootNode == mRoot) { /* could not find context, trying to find a context starting at the next element in * the userContext */ return lookupTrieNode(rootNode, userContext, index + 1); } else { return lookupTrieNode(rootNode.getVineNode(), userContext, index); } } /** * Given a set of symbols whose probability we're interested in and a map which associates a * subset of these symbols to probability value, finds the symbols in the set that are not in * the map and assigns them a default probability. It is possible that all or none of the * symbols in the set are included in the map as well. The default probability is a uniform * distribution based on the number of symbols in the {@code symbols} set. * * @param symbols The set of symbols, whose probability value we're interested in. * @param probabilityDistribution The map that associates a probability values to a subset of * symbols in the {@code symbols} set. If a symbol is in the set but not in the map, a * default probability is assigned to the symbol. */ private static void assignDefaultProbability(Set<Integer> symbols, Map<Integer, Double> probabilityDistribution) { int unassignedSymbolsSize = symbols.size() - probabilityDistribution.size(); if (unassignedSymbolsSize > 0) { Double totalProbability = 0.0; for (Double value : probabilityDistribution.values()) { totalProbability += value; } Double missingProbabilityMass = 1.0 - totalProbability; Double defaultProbability = missingProbabilityMass / unassignedSymbolsSize; for (Integer symbol : symbols) { if (probabilityDistribution.get(symbol) == null) { probabilityDistribution.put(symbol, defaultProbability); } } } } /** * Given a TrieNode, finds the node depth by counting the number of vine pointers that have to * be followed to reach the root of the trie. * * @param node The TrieNode whose depth we've interested in. * @return The depth of the node */ private int getNodeDepth(TrieNode node) { if (node == mRoot) { return 0; } return getNodeDepth(node.getVineNode()) + 1; } /** * Prints the trie. This method is intended for debugging. * * @param node The TriNode from which printing should begin * @param prefix The trie prefix * @param index The index of the next free spot in the prefix array * @param debugPrefix Any prefix that should be prepended to each line. */ @SuppressWarnings("unused") public void printTrie(TrieNode node, char[] prefix, int index, String debugPrefix) { LinkedList<TrieNode> children = node.getChildren(); LogUtils.log(this, Log.DEBUG, "%1$s: current Prefix %2$s", debugPrefix, new String(prefix)); LogUtils.log(this, Log.DEBUG, "%1$s: children size %2$d", debugPrefix, children.size()); for (TrieNode child : children) { LogUtils.log(this, Log.INFO, "%1$s: Prefix children %2$c : %3$d", debugPrefix, (char) child.getContent(), child.getCount()); } for (TrieNode child : children) { char content = (char) child.getContent(); prefix[index] = content; printTrie(child, prefix, index + 1, debugPrefix + "-"); } if (index > 0) { prefix[index - 1] = ' '; } } /** * The trie nodes */ private class TrieNode { /* The content is an int representation for an AccessibilityNodeInfoCompat. For * AccessibilityNodeInfoCompats that represent each of the symbols in the keyboard, * the unicode for the first character of the content description of these * AccessibilityNodeInfoCompat is obtained. For other views a hashing function is probably * needed to enable this int representation. */ private final int mContent; /* TODO Consider using a sparse array */ private final LinkedList<TrieNode> mChildren; /* The number of times we can seen the content */ private int mCount; /* The vine node is a reference to a node with the same content on a depth level one less * than the depth of the current node. */ private TrieNode mVineNode; public TrieNode(int content) { mContent = content; mCount = 0; mChildren = new LinkedList<>(); mVineNode = null; } public int getContent() { return mContent; } public int getCount() { return mCount; } public void setCount(int updatedValue) { mCount = updatedValue; } public TrieNode getVineNode() { return mVineNode; } public void setVineNode(TrieNode trieNode) { mVineNode = trieNode; } public LinkedList<TrieNode> getChildren() { return mChildren; } public TrieNode getChild(int content) { for (TrieNode child : mChildren) { if (child.getContent() == content) { return child; } } return null; } public TrieNode addChild(int content) { TrieNode child = getChild(content); if (child == null) { child = new TrieNode(content); mChildren.add(child); } child.mCount += 1; return child; } public boolean hasChild(int content) { TrieNode child = getChild(content); return child != null; } } }