package is2.data; import is2.util.DB; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; final public class ParametersFloat { public float[] parameters; public float[] total; public ParametersFloat(int size) { parameters = new float[size]; total = new float[size]; for (int i = 0; i < parameters.length; i++) { parameters[i] = 0F; total[i] = 0F; } } /** * @param parameters2 */ public ParametersFloat(float[] p) { parameters = p; } public void average(double avVal) { for (int j = 0; j < total.length; j++) { parameters[j] = total[j] / ((float) avVal); } total = null; } public ParametersFloat average2(double avVal) { float[] px = new float[this.parameters.length]; for (int j = 0; j < total.length; j++) { px[j] = total[j] / ((float) avVal); } ParametersFloat pf = new ParametersFloat(px); return pf; } public void update(FV pred, FV act, float upd, float err) { float lam_dist = act.getScore(parameters, false) - pred.getScore(parameters, false); float loss = (float) err - lam_dist; FV dist = act.getDistVector(pred); float alpha; float A = dist.dotProduct(dist); if (A <= 0.0000000000000001) { alpha = 0.0f; } else { alpha = loss / A; } // alpha = Math.min(alpha, 0.00578125F); dist.update(parameters, total, alpha, upd, false); } public void update(FV pred, FV act, float upd, float err, float C) { float lam_dist = act.getScore(parameters, false) - pred.getScore(parameters, false); float loss = (float) err - lam_dist; FV dist = act.getDistVector(pred); float alpha; float A = dist.dotProduct(dist); if (A <= 0.0000000000000001) { alpha = 0.0f; } else { alpha = loss / A; } alpha = Math.min(alpha, C); dist.update(parameters, total, alpha, upd, false); } public double update(FV a, double b) { double A = a.dotProduct(a); if (A <= 0.0000000000000000001) { return 0.0; } return b / A; } public double getScore(FV fv) { if (fv == null) { return 0.0F; } return fv.getScore(parameters, false); } final public void write(DataOutputStream dos) throws IOException { dos.writeInt(parameters.length); for (float d : parameters) { dos.writeFloat(d); } } public void read(DataInputStream dis) throws IOException { parameters = new float[dis.readInt()]; int notZero = 0; for (int i = 0; i < parameters.length; i++) { parameters[i] = dis.readFloat(); if (parameters[i] != 0.0F) { notZero++; } } DB.println("read parameters " + parameters.length + " not zero " + notZero); } public int countNZ() { int notZero = 0; for (int i = 0; i < parameters.length; i++) { if (parameters[i] != 0.0F) { notZero++; } } return notZero; } public F2SF getFV() { return new F2SF(parameters); } public int size() { return parameters.length; } public void update(FVR act, FVR pred, Instances isd, int instc, Parse dx, double upd, double e, float lam_dist) { e++; float b = (float) e - lam_dist; FVR dist = act.getDistVector(pred); dist.update(parameters, total, hildreth(dist, b), upd, false); } public void update(FVR pred, FVR act, float upd, float e) { e++; float lam_dist = act.getScore(parameters, false) - pred.getScore(parameters, false); float b = (float) e - lam_dist; FVR dist = act.getDistVector(pred); dist.update(parameters, total, hildreth(dist, b), upd, false); } protected double hildreth(FVR a, double b) { double A = a.dotProduct(a); if (A <= 0.0000000000000000001) { return 0.0; } return b / A; } public float getScore(FVR fv) { //xx if (fv == null) { return 0.0F; } return fv.getScore(parameters, false); } }