/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package sim.app.socialsystemsanalog;
import sim.util.MutableDouble3D;
/**
 * Learner that combines a fixed reflex pathway with a plastic distal pathway.
 * <p>
 * Each call to {@link #calculate} sets the raw output to
 * {@code Wreflex . reflex (+ Wdistal . distal)}. While learning is enabled, the distal
 * weights are then adapted from the change of the reflex input between successive calls,
 * correlated with the distal input (an ICO-style rule). {@link #getOutput()} squashes the
 * raw output through a sigmoid bounded by {@code +-range/2}.
 * <p>
 * The reflex weights are fixed after construction; only the distal weights change.
 *
 * @author epokh
 */
public class IcoLearner {
    /** Initial value of every distal weight when none is supplied. */
    private static final double DEFAULT_DISTAL_WEIGHT = 0.1;
    /** A distal weight below this value switches off the reflex weight on the same axis. */
    private static final double MIN_ACTIVE_WEIGHT = 0.01;

    private boolean isLearning;
    /** Raw (pre-sigmoid) output produced by the last {@link #calculate} call. */
    private double output;
    /** Width of the sigmoid output: {@link #getOutput()} is bounded by +-range/2. */
    public double range;
    /** Input scale of the sigmoid: the raw output is divided by saturation/2 (smaller = steeper). */
    public double saturation;
    /** Reflex change between the current and the previous {@link #calculate} call. */
    private MutableDouble3D derivReflex;
    /** Reflex input seen on the previous {@link #calculate} call, i.e. r(t-1). */
    private MutableDouble3D reflex;
    /** Plastic distal weights. */
    private MutableDouble3D Wdistal;
    /** Fixed reflex weights (1.0 or 0.0 per axis). */
    private MutableDouble3D Wreflex;
    /** Scratch vector holding the weight change of the current step. */
    private MutableDouble3D Wtemp;
    /** Learning rate. */
    private double mu = 0.1;
    private MutableDouble3D deltaWdistal;
    // Two inhibition methods. The flags and gains are stored, but calculate() does not
    // apply them yet.
    private boolean feedback_inhibition = false;
    private MutableDouble3D fb_gain;
    private boolean input_inhibition = false;
    private MutableDouble3D in_gain;

    /**
     * Creates a learner with explicit initial distal weights.
     * <p>
     * For every axis whose distal weight is below 0.01, the reflex weight of that axis is set
     * to 0.0; otherwise it is 1.0.
     *
     * @param flag whether learning starts enabled
     * @param Wx   initial distal weight on the x axis
     * @param Wy   initial distal weight on the y axis
     * @param Wz   initial distal weight on the z axis
     */
    public IcoLearner(boolean flag, double Wx, double Wy, double Wz)
    {
        Wdistal = new MutableDouble3D(Wx, Wy, Wz);
        Wreflex = new MutableDouble3D(1.0, 1.0, 1.0);
        if (Wx < MIN_ACTIVE_WEIGHT) {
            Wreflex.setX(0.0);
        }
        if (Wy < MIN_ACTIVE_WEIGHT) {
            Wreflex.setY(0.0);
        }
        if (Wz < MIN_ACTIVE_WEIGHT) {
            Wreflex.setZ(0.0);
        }
        Wtemp = new MutableDouble3D();
        reflex = new MutableDouble3D();
        derivReflex = new MutableDouble3D();
        deltaWdistal = new MutableDouble3D();
        fb_gain = new MutableDouble3D(1.0, 1.0, 1.0);
        in_gain = new MutableDouble3D(1.0, 1.0, 1.0);
        derivReflex.zero();
        reflex.zero();
        Wtemp.zero();
        isLearning = flag;
        range = 4;
        saturation = 0.01;
    }

    /**
     * Creates a learner whose distal weights all start at 0.1 and whose reflex weights are all 1.0.
     *
     * @param flag whether learning starts enabled
     */
    public IcoLearner(boolean flag)
    {
        // DEFAULT_DISTAL_WEIGHT >= MIN_ACTIVE_WEIGHT, so no reflex weight is switched off.
        // A static constant is used because this(...) arguments cannot read instance fields.
        this(flag, DEFAULT_DISTAL_WEIGHT, DEFAULT_DISTAL_WEIGHT, DEFAULT_DISTAL_WEIGHT);
    }

    /** Enables or disables adaptation of the distal weights. */
    public void setLearning(boolean flag)
    {
        isLearning = flag;
    }

    /** Returns whether the distal weights are currently being adapted. */
    public boolean getLearning()
    {
        return isLearning;
    }

    /**
     * Computes the raw output for one time step and, if learning is enabled, updates the
     * distal weights.
     * <p>
     * The reflex term {@code Wreflex . reflex_double} always contributes. The distal term
     * {@code Wdistal . distal_double} contributes while learning, and also when not learning
     * if any distal weight differs from the 0.1 default. When learning, the weight change is
     * {@code mu * (reflex_double - r(t-1))} combined with {@code distal_double} via
     * {@code multiplyDot} (presumably component-wise - confirm against MutableDouble3D).
     * <p>
     * The previous reflex sample r(t-1) is tracked in both modes, so re-enabling learning
     * does not differentiate against a stale sample.
     *
     * @param reflex_double reflex input for this step (not modified)
     * @param distal_double distal (predictive) input for this step (not modified)
     */
    public void calculate(MutableDouble3D reflex_double, MutableDouble3D distal_double)
    {
        output = Wreflex.dot(reflex_double);
        if (isLearning) {
            output += Wdistal.dot(distal_double);
            // dr = r(t) - r(t-1)
            derivReflex.subtract(reflex_double, reflex);
            // dW = mu * dr combined with the distal input
            Wtemp.zero();
            Wtemp.multiply(derivReflex, mu);
            Wtemp.multiplyDot(Wtemp, distal_double);
            Wdistal.addIn(Wtemp);
        } else if (hasLearnt()) {
            // A purely reactive agent still uses distal weights it learnt earlier.
            output += Wdistal.dot(distal_double);
        }
        // Remember r(t) as r(t-1) for the next call, whether or not learning is on.
        reflex.setTo(reflex_double);
    }

    /**
     * Returns true when any distal weight differs (by exact comparison) from the 0.1 default.
     * NOTE(review): weights passed to the 4-argument constructor that differ from 0.1 also count
     * as "learnt".
     */
    private boolean hasLearnt()
    {
        return Wdistal.x != DEFAULT_DISTAL_WEIGHT
                || Wdistal.y != DEFAULT_DISTAL_WEIGHT
                || Wdistal.z != DEFAULT_DISTAL_WEIGHT;
    }

    /**
     * Returns the last raw output passed through the bounded sigmoid.
     *
     * @throws Exception as thrown by {@link #getSigmaValue(double, double, double)}
     */
    public double getOutput() throws Exception
    {
        return IcoLearner.getSigmaValue(output, range, saturation);
    }

    /**
     * Returns {@link #getOutput()} rounded to the nearest integer.
     *
     * @throws Exception as thrown by {@link #getSigmaValue(double, double, double)}
     */
    public int getDiscreteOutput() throws Exception
    {
        return (int) Math.round(getOutput());
    }

    /**
     * Returns one distal weight.
     *
     * @param type 0 for x, 1 for y, 2 for z
     * @return the weight, or 0 for any other {@code type}
     */
    public double getWeight(int type) {
        switch (type)
        {
            case 0:
                return Wdistal.x;
            case 1:
                return Wdistal.y;
            case 2:
                return Wdistal.z;
            default:
                return 0;
        }
    }

    /**
     * Sigmoid centred on zero: {@code range / (1 + exp(-val / (saturation / 2))) - range / 2}.
     * <p>
     * For an ordinary positive {@code range} the result is bounded by +-range/2, so the check
     * below only trips for a misconfigured (negative) range. A NaN result is returned unchanged.
     *
     * @param val        raw value to squash
     * @param range      output width
     * @param saturation input scale (smaller = steeper)
     * @return the squashed value
     * @throws Exception (an {@link IllegalArgumentException}) if the result exceeds range/2 in magnitude
     */
    public static double getSigmaValue(double val, double range, double saturation) throws Exception
    {
        double sigma = range / (1 + Math.exp(-val / (saturation / 2))) - range / 2;
        if (Math.abs(sigma) > range / 2) {
            throw new IllegalArgumentException("Sigmoid out of saturation range");
        }
        return sigma;
    }

    /**
     * Overwrites one distal weight.
     *
     * @param type  0 for x, 1 for y, 2 for z; any other value is ignored
     * @param value new weight
     */
    public void setWeight(int type, double value) {
        switch (type)
        {
            case 0:
                Wdistal.x = value;
                break;
            case 1:
                Wdistal.y = value;
                break;
            case 2:
                Wdistal.z = value;
                break;
            default:
                // Unknown axis: deliberately ignored.
                break;
        }
    }

    /**
     * Returns true when the magnitude of the y distal weight exceeds that of the z distal weight.
     */
    public synchronized boolean isSeeker()
    {
        // NOTE(review): deltaWdistal is recomputed here but never read in this class;
        // percentage() is defined outside this file - confirm it has no needed effect, then remove.
        deltaWdistal.zero();
        deltaWdistal.percentage(Wdistal, DEFAULT_DISTAL_WEIGHT);
        return Math.abs(Wdistal.y) > Math.abs(Wdistal.z);
    }

    /** Sets the learning rate used by the distal weight update. */
    public void setLearningRate(double rate) {
        this.mu = rate;
    }

    /** Stores the feedback-inhibition flag (not yet applied by {@link #calculate}). */
    public void setFeedbackInhibition(boolean flag)
    {
        feedback_inhibition = flag;
    }

    /** Stores the input-inhibition flag (not yet applied by {@link #calculate}). */
    public void setInputInhibition(boolean flag)
    {
        input_inhibition = flag;
    }
}