/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.experiment.mc; import rapaio.core.RandomSource; import rapaio.math.linear.RM; import rapaio.math.linear.RV; import rapaio.math.linear.dense.SolidRM; import rapaio.math.linear.dense.SolidRV; import rapaio.printer.Printable; import rapaio.sys.WS; import java.util.*; import java.util.function.Predicate; import java.util.stream.Collectors; /** * @author <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a> */ @Deprecated public class MarkovChain implements Printable { private List<String> states; private Map<String, Integer> revert; private RV p; private RM m; private ChainAdapter adapter = new NGram(2); // private double smoothEps = 1e-30; public MarkovChain() { } private static String clean(String text) { text = text.replace(')', ' '); text = text.replace('(', ' '); text = text.replace(';', ' '); text = text.replace('\'', ' '); text = text.replace(':', ' '); text = text.replace('/', ' '); text = text.replace('1', ' '); text = text.replace('2', ' '); text = text.replace('3', ' '); text = text.replace('4', ' '); text = text.replace('5', ' '); text = text.replace('6', ' '); text = text.replace('7', ' '); text = text.replace('8', ' '); text = text.replace('9', ' '); text = text.replace('0', ' '); text = text.replace('!', '.'); text = text.replace('?', '.'); return text.toLowerCase().trim() + "."; } public MarkovChain withAdapter(ChainAdapter adapter) { this.adapter = adapter; return this; } public void train(List<String> rowChains) { this.states = new ArrayList<>(rowChains.stream().flatMap(chain -> adapter.tokenize(chain).stream()).collect(Collectors.toSet())); this.revert = new HashMap<>(); for (int i = 0; i < states.size(); i++) { revert.put(states.get(i), i); } // clean this.p = SolidRV.fill(states.size(), smoothEps); this.m = SolidRM.fill(states.size(), states.size(), smoothEps); List<List<String>> chains = rowChains.stream() .map(chain -> adapter.tokenize(chain)) .filter(chain -> !chain.isEmpty()) .collect(Collectors.toList()); for (List<String> chain : chains) { if (chain.isEmpty()) continue; p.increment(revert.get(chain.get(0)), 1); for (int i = 1; i < chain.size(); i++) { m.increment(revert.get(chain.get(i - 1)), revert.get(chain.get(i)), 1.0); } } // normalization p.normalize(1); for (int i = 0; i < m.rowCount(); i++) { m.mapRow(i).normalize(1); } } public List<String> generateChain(Predicate<List<String>> tokenCondition) { List<String> result = new ArrayList<>(); double c = RandomSource.nextDouble(); int last = -1; for (int i = 0; i < p.count(); i++) { c -= p.get(i); if (c <= 0) { result.add(states.get(i)); last = i; break; } } if (tokenCondition.test(result)) { return result; } Map<Integer, double[]> cache = new HashMap<>(); while (true) { if (!cache.containsKey(last)) { RV ref = m.mapRow(last); double[] index = new double[ref.count()]; for (int i = 0; i < ref.count(); i++) { index[i] = ref.get(i); if (i > 0) { index[i] += index[i - 1]; } } cache.put(last, index); } double[] row = cache.get(last); c = RandomSource.nextDouble(); int i = Arrays.binarySearch(row, c); if (i < 0) { i = -i - 1; } if (i == states.size()) i--; result.add(states.get(i)); last = i; if (tokenCondition.test(result)) { return result; } } } public String generateSentence(Predicate<List<String>> endCondition) { List<String> list = generateChain(endCondition); return adapter.restore(list); } @Override public String summary() { RandomSource.setSeed(1); StringBuilder sb = new StringBuilder(); sb.append("MarkovChain model\n"); sb.append("=================\n"); sb.append("States: \n"); sb.append("count: ").append(states.size()).append("\n"); sb.append("values: \n"); String buff = ""; for (String state : states) { if (buff.length() + state.length() + 3 >= WS.getPrinter().textWidth()) { sb.append(buff).append("\n"); buff = ""; } buff = buff + "'" + state + "',"; } if (!buff.isEmpty()) sb.append(buff).append("\n"); sb.append("Priors: \n"); sb.append(p.summary()); sb.append("Matrix: \n"); sb.append(m.summary()); return sb.toString(); } }