/** Class NormalizeAttribute.java
*
* @author AJB
* @version 1
* @since 14/4/09
*
* Class normalizes attributes, basic version.
1. Assumes no missing values.
2. Assumes all attributes real values
3. Assumes class index same in all data (vague checks made) but can be none set (classIndex==-1)
4. Batch process: by default it calculates the statistics (min, max, mean, standard deviation)
from the instances in trainData, then uses these to process whatever instances are passed in.
Note that this may produce values outside the target interval, since the min or max of the
test data may differ from those of the training data. If you want to avoid this, the only way
at the moment is to first merge train and test, then pass the merged set as the training data.
* Normalise onto [0,1] if norm==NormType.INTERVAL,
* Normalise onto Normal(0,1) if norm==NormType.STD_NORMAL,
*
* Usage:
* Instances train = //Get Train
* Instances test = //Get Test
*
* NormalizeAttribute na = new NormalizeAttribute(train);
*
* na.setNormMethod(NormalizeAttribute.NormType.INTERVAL); //Defaults to interval anyway
* try{
*     //Both processed with the stats from train.
*     Instances newTrain=na.process(train);
*     Instances newTest=na.process(test);
* }catch(Exception e){
*     //process throws if the class index differs from that of the training data
* }
*/
package weka.filters;
import weka.core.Instances;
/**
 * Batch filter that rescales every non-class attribute using statistics
 * (min, max, mean, standard deviation) gathered from a training set.
 *
 * The statistics are computed once, when the training data is supplied through the
 * constructor or {@link #setTrainData(Instances)}, and are then applied to any data
 * passed to {@link #process(Instances)}. All attributes are read as numeric values and
 * missing values are not handled specially.
 */
public class NormalizeAttribute extends SimpleBatchFilter{
    /**
     * Supported normalisation schemes. Public so that callers outside this package
     * can use {@link #setNormMethod(NormType)}.
     */
    public enum NormType {INTERVAL,STD_NORMAL}

    /** Data the statistics below were computed from. */
    Instances trainData;
    /** Per-attribute minimum over the training data. */
    double[] min;
    /** Per-attribute maximum over the training data. */
    double[] max;
    /** Per-attribute mean over the training data. */
    double[] mean;
    /** Per-attribute population standard deviation over the training data. */
    double[] stdev;
    /** Class index of the training data (-1 if none set); this attribute is never rescaled. */
    int classIndex;
    /** Normalisation scheme applied by {@link #process(Instances)}. */
    NormType norm=NormType.INTERVAL;

    /**
     * Builds the filter and computes all statistics from the given training data.
     *
     * @param data training instances the statistics are taken from
     */
    public NormalizeAttribute(Instances data){
        trainData=data;
        classIndex=data.classIndex();
        //Finds all the stats, doesnt cost much more really
        findStats(data);
    }

    /**
     * Computes min, max, mean and standard deviation of every attribute (including the
     * class attribute, if any) in a single pass per attribute and stores them in the
     * corresponding fields.
     *
     * The standard deviation is the population form, sqrt(E[x^2] - E[x]^2). For an empty
     * data set min/max stay at +/- infinity and mean/stdev are NaN.
     *
     * @param r instances to gather statistics from
     */
    protected void findStats(Instances r){
        // assert(classIndex==r.classIndex());
        int numAtts=r.numAttributes();
        int numInsts=r.numInstances();
        max=new double[numAtts];
        min=new double[numAtts];
        mean=new double[numAtts];
        stdev=new double[numAtts];
        for(int j=0;j<numAtts;j++){
            // Double.MIN_VALUE is the smallest POSITIVE double, so it must not be used to
            // seed a running maximum; infinities are correct for any input sign.
            double hi=Double.NEGATIVE_INFINITY;
            double lo=Double.POSITIVE_INFINITY;
            double sum=0;
            double sumSq=0;
            for(int i=0;i<numInsts;i++){
                double x=r.instance(i).value(j);
                if(x>hi)
                    hi=x;
                if(x<lo)
                    lo=x;
                sum+=x;
                sumSq+=x*x;
            }
            max[j]=hi;
            min[j]=lo;
            mean[j]=sum/numInsts;
            // Variance is E[x^2] - mean^2 (not E[x^2] - sum^2). Rounding can push the
            // difference marginally below zero for near-constant attributes, so clamp it.
            double variance=sumSq/numInsts-mean[j]*mean[j];
            if(variance<0)
                variance=0;
            stdev[j]=Math.sqrt(variance);
        }
    }

    /**
     * Returns the width (max - min) of each attribute as seen in the training data.
     *
     * @return a new array with one range per attribute
     */
    public double[] getRanges(){
        double[] r= new double[max.length];
        for(int i=0;i<r.length;i++)
            r[i]=max[i]-min[i];
        return r;
    }

    /**
     * Output has the same header as the input, with no instances.
     *
     * @param inputFormat format of the incoming data
     * @return an empty copy of the input header
     */
    //This should probably be connected to trainData?
    protected Instances determineOutputFormat(Instances inputFormat){
        return new Instances(inputFormat, 0);
    }

    /**
     * Replaces the training data and recomputes all statistics; same effect as the
     * constructor.
     *
     * @param data new training instances
     */
    public void setTrainData(Instances data){
        trainData=data;
        classIndex=data.classIndex();
        //Finds all the stats, doesnt cost much more really
        findStats(data);
    }

    /**
     * Selects the normalisation scheme used by {@link #process(Instances)}.
     *
     * @param n scheme to use
     */
    public void setNormMethod(NormType n){
        norm=n;
    }

    /**
     * Returns a normalised copy of the given data; the argument itself is not modified.
     * The statistics of the training data are used, so results may fall outside [0,1]
     * for INTERVAL when the given data exceeds the training range.
     *
     * @param inst data to normalise
     * @return normalised copy of {@code inst}
     * @throws Exception if the class index or the number of attributes differs from the
     *         training data, or the normalisation scheme is not recognised
     */
    public Instances process(Instances inst) throws Exception {
        if(classIndex!=inst.classIndex())
            throw new Exception("Wrong class index ="+inst.classIndex()+" expecting ="+classIndex);
        // The statistics arrays are indexed by attribute, so the layouts must agree.
        if(trainData.numAttributes()!=inst.numAttributes())
            throw new Exception("Wrong number of attributes ="+inst.numAttributes()
                    +" expecting ="+trainData.numAttributes());
        //Work on a copy so the caller's data is left untouched.
        Instances result = new Instances(inst);
        switch(norm){
            case INTERVAL:
                intervalNorm(result);
                break;
            case STD_NORMAL:
                standardNorm(result);
                break;
            default:
                throw new Exception("Unknown norm! "+norm);
        }
        return result;
    }

    /**
     * Rescales, in place, every non-class attribute to (x - min) / (max - min) using the
     * training statistics. An attribute whose training range is not positive (constant
     * attribute) is set to 0 rather than dividing by zero.
     *
     * @param r instances to modify in place
     */
    public void intervalNorm(Instances r){
        for(int j=0;j<r.numAttributes();j++){
            if(j!=classIndex){
                double range=max[j]-min[j];
                for(int i=0;i<r.numInstances();i++){
                    double x=r.instance(i).value(j);
                    double scaled=range>0 ? (x-min[j])/range : 0;
                    r.instance(i).setValue(j,scaled);
                }
            }
        }
    }

    /**
     * Standardises, in place, every non-class attribute to (x - mean) / stdev using the
     * training statistics. An attribute whose training standard deviation is not positive
     * is only centred (x - mean) rather than dividing by zero.
     *
     * @param r instances to modify in place
     */
    public void standardNorm(Instances r){
        for(int j=0;j<r.numAttributes();j++){
            if(j!=classIndex){
                for(int i=0;i<r.numInstances();i++){
                    double x=r.instance(i).value(j);
                    double centred=x-mean[j];
                    double scaled=stdev[j]>0 ? centred/stdev[j] : centred;
                    // The first argument of setValue is the ATTRIBUTE index.
                    r.instance(i).setValue(j,scaled);
                }
            }
        }
    }

    /**
     * Short description of the filter.
     *
     * @return human-readable description
     */
    public String globalInfo() {
        return "Normalises all non-class attributes using statistics taken from a training set: "
                + "INTERVAL rescales with the training min and max, "
                + "STD_NORMAL standardises with the training mean and standard deviation.";
    }

    /**
     * Revision of this class.
     *
     * @return the revision string
     */
    public String getRevision() {
        return "1";
    }

    /* Test Harness.
    *
    public static void main(String[] args){
    Instances test=weka.classifiers.evaluation.ClassifierTools.loadData("C:\\Research\\Data\\WekaTest\\NormalizeTest");
    Instances train=weka.classifiers.evaluation.ClassifierTools.loadData("C:\\Research\\Data\\WekaTest\\NormalizeTrain");
    test.setClassIndex(test.numAttributes()-1);
    train.setClassIndex(test.numAttributes()-1);
    NormalizeAttribute na=new NormalizeAttribute(test);
    try{
    na.setNormMethod(NormalizeAttribute.NormType.INTERVAL); //Defaults to interval anyway
    Instances newTrain=na.process(train);
    Instances newTest=na.process(test);
    System.out.println(" Fixed interval train ="+newTrain);
    System.out.println(" Fixed interval test ="+newTest);
    na.setNormMethod(NormalizeAttribute.NormType.STD_NORMAL); //Defaults to interval anyway
    na.setTrainData(train);
    newTrain=na.process(train);
    newTest=na.process(test);
    System.out.println(" Std Normal train ="+newTrain);
    System.out.println(" Std Normal test ="+newTest);
    }catch(Exception e){
    System.out.println(" Exception thrown somewhere, caught main ="+e);
    }
    }
    */
}