package com.spbsu.bernulli;

/**
 * User: Noxoomo
 * Date: 20.03.15
 * Time: 16:38
 *
 * <p>Template for an iterative fitting procedure: {@link #fit()} repeatedly calls
 * {@link #expectation()} followed by {@link #maximization()} until {@link #stop()}
 * returns {@code true}, then wraps {@link #likelihood()} and {@link #model()} into a
 * {@code FittedModel}.
 *
 * @param <Result> type of the model object produced by {@link #model()}
 */
public abstract class EM<Result> {
  /**
   * Maximum drop in likelihood between two consecutive iterations that the
   * correctness test still tolerates (absorbs numerical noise).
   */
  private static final double LIKELIHOOD_TOLERANCE = 1e-2;

  /** Invoked once per iteration, before {@link #maximization()}. */
  protected abstract void expectation();

  /** Invoked once per iteration, after {@link #expectation()}. */
  protected abstract void maximization();

  /**
   * Checked before every iteration.
   *
   * @return {@code true} to end the fitting loop
   */
  protected abstract boolean stop();

  /**
   * @return the model that {@link #fit()} places into the resulting {@code FittedModel}
   */
  public abstract Result model();

  /**
   * @return likelihood value of the current state; used by the correctness test and
   *         reported in the resulting {@code FittedModel}
   */
  protected abstract double likelihood();

  /**
   * Runs the fitting loop without the correctness test.
   *
   * @return the final likelihood together with the final model
   */
  public final FittedModel<Result> fit() {
    return fit(false);
  }

  /**
   * Runs the fitting loop until {@link #stop()} returns {@code true}.
   *
   * @param correctnessTest when {@code true}, {@link #likelihood()} is evaluated after every
   *                        iteration and compared with the value from the previous iteration
   * @return the final likelihood together with the final model
   * @throws RuntimeException if {@code correctnessTest} is set and the likelihood falls by more
   *                          than the tolerance between two consecutive iterations
   */
  public final FittedModel<Result> fit(boolean correctnessTest) {
    // Negative infinity makes the very first comparison pass unconditionally.
    double previousLL = Double.NEGATIVE_INFINITY;
    while (!stop()) {
      expectation();
      maximization();
      if (correctnessTest) {
        final double currentLL = likelihood();
        if (currentLL + LIKELIHOOD_TOLERANCE < previousLL) {
          throw new RuntimeException("EM always increase likelihood");
        }
        // Remember this value; without the update the check above could never fail.
        previousLL = currentLL;
      }
    }
    return new FittedModel<>(likelihood(), model());
  }
}