// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.hmm;
import hmmla.util.Arrays;
import hmmla.util.Numerics;
import hmmla.util.SymbolTable;
import java.io.Serializable;
public class Statistics implements Serializable {
private static final long serialVersionUID = 1L;
protected int num_tags;
protected int num_words;
private double[][] emissions;
public double[][] transistions;
public Statistics(int num_tags,int num_words){
this.num_tags = num_tags;
this.num_words = num_words;
emissions = new double[num_tags][num_words];
transistions = new double[num_tags][num_tags];
setZero();
}
public Statistics(SymbolTable<String> inputs,SymbolTable<String> outputs){
this(inputs.size(),outputs.size());
}
public Statistics(Statistics statistics) {
this(statistics.num_tags,statistics.num_words);
add(statistics);
}
public void set(double[][] tr, double[][] em) {
Arrays.multiArrayCopy(tr, transistions);
Arrays.multiArrayCopy(em, emissions);
}
public double getTransitions(int i, int j) {
return transistions[i][j];
}
public double getEmissions(int i, int o) {
return emissions[i][o];
}
public void setZero() {
for (int i=0;i<num_tags;i++){
for (int o=0;o<num_words;o++){
emissions[i][o] = 0.0;
}
for (int j=0;j<num_tags;j++){
transistions[i][j] = 0.0;
}
}
}
public void addEmissions(int toIndex, int output, double p) {
emissions[toIndex][output] += p;
}
public void addTransitions(int from, int to, double p) {
transistions[from][to] += p;
}
public void setTransitions(int from, int to, double p) {
transistions[from][to] = p;
}
public void setEmissions(int from, int o, double p) {
emissions[from][o] = p;
}
public void add(Statistics statistics) {
for (int i=0;i<num_tags;i++){
for (int o=0;o<num_words;o++){
addEmissions(i, o, statistics.getEmissions(i, o));
}
for (int j=0;j<num_tags;j++){
addTransitions(i, j, statistics.getTransitions(i, j));
}
}
}
public int getNumTags() {
return num_tags;
}
public int getNumOutputs() {
return num_words;
}
public void substract_onehalf(){
for (int i=0;i<num_tags;i++){
for (int o=0;o<num_words;o++){
double f;
try{
f = Numerics.exp_digamma(getEmissions(i, o));
}catch (IllegalArgumentException e){
f = 0.0;
}
setEmissions(i, o, f);
}
for (int j=0;j<num_tags;j++){
double f;
try{
f = Numerics.exp_digamma(getEmissions(i, j));
}catch (IllegalArgumentException e){
f = 0.0;
}
setTransitions(i, j, f);
}
}
}
public double totalEmission(){
double total = 0.0;
for (int i=0;i<num_tags;i++){
for (int o=0;o<num_words;o++){
total += getEmissions(i, o);
}
}
return total;
}
public double totalTransmission(){
double total = 0.0;
for (int i=0;i<num_tags;i++){
for (int j=0;j<num_tags;j++){
total += getTransitions(i, j);
}
}
return total;
}
}