package edu.cmu.minorthird.classify.experiments;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.LineNumberReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.apache.log4j.Logger;

import edu.cmu.minorthird.classify.HasSubpopulationId;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.util.StringUtil;

/**
 * A complicated splitter that stratifies samples according to an
 * arbitrary "profile" property, and restricts train/test splits to
 * not cross boundaries defined by "user" and "request" properties.
 * This will do a random split according to users, then a stratified
 * split according to requests (with the stratification done
 * according to profiles).
 *
 * <p>Constraints on splitting are defined by a file with
 * multiple lines of the form
 * <pre>
 * msgId userId requestId profileId
 * </pre>
 * where each Id is a String and fields are separated by a single
 * space. Lines starting with <code>#</code> are ignored. Every
 * request must be associated with exactly one profile.
 *
 * <p>The main purpose of this is to split webmaster data, hence
 * the name.
 *
 * @author William Cohen
 */
public class WebmasterSplitter<T> implements Splitter<T>{

	private static final Logger log=Logger.getLogger(WebmasterSplitter.class);

	// number of cross-validation splits for the r-th population
	private int folds=3;

	// how to split the first r-1 populations
	private double fraction=0.7;

	// map subpopulationId -> user
	private final Map<String,String> userMap=new HashMap<String,String>();

	// map subpopulationId -> request
	private final Map<String,String> requestMap=new HashMap<String,String>();

	// map subpopulationId -> profile
	private final Map<String,String> profileMap=new HashMap<String,String>();

	// map request -> profile
	private final Map<String,String> req2ProfileMap=new HashMap<String,String>();

	// trainList[k] is training list for fold k; null until split() is called
	private List<List<T>> trainList=null;

	// testList[k] is test list for fold k; null until split() is called
	private List<List<T>> testList=null;

	/**
	 * Creates a splitter and immediately loads its constraint file.
	 *
	 * @param constraintFileName path of the constraint file (see the class
	 *        comment for the line format)
	 * @param fraction value handed to the {@link RandomSplitter} that divides
	 *        the users into train and test groups
	 * @param folds number of request partitions, i.e. the number of
	 *        train/test pairs produced by {@link #split(Iterator)}
	 * @throws IllegalArgumentException if the file cannot be read
	 * @throws IllegalStateException if a line is malformed or a request is
	 *         associated with two different profiles
	 */
	public WebmasterSplitter(String constraintFileName,double fraction,int folds){
		this.folds=folds;
		this.fraction=fraction;
		loadFile(constraintFileName);
	}

	/**
	 * Reads the constraint file and fills the subpopulation-to-user,
	 * subpopulation-to-request, subpopulation-to-profile and
	 * request-to-profile maps.
	 */
	private void loadFile(String constraintFileName){
		LineNumberReader in=null;
		try{
			// read in constraints
			in=new LineNumberReader(new FileReader(new File(constraintFileName)));
			String line=null;
			while((line=in.readLine())!=null){
				if(line.startsWith("#")){
					continue; // comment line
				}
				String[] f=line.split(" ");
				if(f.length!=4){
					badInput(line,constraintFileName,in);
				}
				// f = { msgId (the subpopulation id), userId, requestId, profileId }.
				// The user mapping is required by split() and similarTo().
				userMap.put(f[0],f[1]);
				requestMap.put(f[0],f[2]);
				profileMap.put(f[0],f[3]);
				String oldProfForRequest=req2ProfileMap.get(f[2]);
				if(oldProfForRequest!=null&&!oldProfForRequest.equals(f[3])){
					log.error("request "+f[2]+" associated with two profiles: "+
						oldProfForRequest+" and "+f[3]);
					badInput(line,constraintFileName,in);
				}
				req2ProfileMap.put(f[2],f[3]);
			}
		}catch(IOException ex){
			// keep the original exception as the cause so the stack trace survives
			throw new IllegalArgumentException("can't load from "+constraintFileName+
				": "+ex.toString(),ex);
		}finally{
			// also runs when badInput() throws, so the reader is never leaked
			closeQuietly(in);
		}
	}

	/** Closes the reader if it was opened; a failure to close is only logged. */
	private static void closeQuietly(LineNumberReader in){
		if(in==null){
			return;
		}
		try{
			in.close();
		}catch(IOException ex){
			log.warn("error closing reader: "+ex.toString());
		}
	}

	/** Always throws an IllegalStateException naming the offending file, line number and line. */
	private void badInput(String line,String fileName,LineNumberReader in){
		throw new IllegalStateException("Bad input at "+fileName+" line "+
			in.getLineNumber()+": "+line);
	}

	/**
	 * Computes the train/test lists for every fold.
	 *
	 * <p>Users are divided by a {@link RandomSplitter}; requests are sorted by
	 * profile (after a shuffle, so the order within a profile is random) and
	 * dealt round-robin into {@code folds} partitions. An example whose user is
	 * a test user goes into the test list of the fold holding its request;
	 * any other example goes into the training list of every fold except the
	 * one holding its request.
	 *
	 * @param it the examples to split; each must implement
	 *        {@link HasSubpopulationId} and have an entry in the constraint file
	 * @throws IllegalArgumentException if an example does not implement
	 *         HasSubpopulationId or has no user in the constraint file
	 * @throws IllegalStateException if the resulting split fails verification
	 */
	@Override
	public void split(Iterator<T> it){
		// collect set of users, and also set of requests
		List<T> inputList=asList(it);
		Set<String> users=new HashSet<String>();
		Set<String> requests=new HashSet<String>();
		for(T example:inputList){
			if(!(example instanceof HasSubpopulationId)){
				badExample(example,"doesn't have a subpopulationId");
			}
			String subpop=((HasSubpopulationId)example).getSubpopulationId();
			String userId=userMap.get(subpop);
			if(userId==null){
				badExample(example,"no userId for "+subpop+" in the constraint file");
			}
			users.add(userId);
			requests.add(requestMap.get(subpop));
		}

		// split users
		Splitter<String> userSplitter=new RandomSplitter<String>(fraction);
		userSplitter.split(users.iterator());
		Set<String> testUsers=asSet(userSplitter.getTest(0));
		if(log.isDebugEnabled()){
			log.debug("testUsers = "+testUsers);
		}

		// do cross-val split of requests stratified by profile
		List<String> requestList=new ArrayList<String>(requests.size());
		requestList.addAll(requests);
		Comparator<String> byProfile=new Comparator<String>(){
			@Override
			public int compare(String s1,String s2){
				String prof1=req2ProfileMap.get(s1);
				String prof2=req2ProfileMap.get(s2);
				return prof1.compareTo(prof2);
			}
		};
		// shuffle first: the sort is stable, so requests sharing a profile
		// keep the random relative order produced here
		Collections.shuffle(requestList);
		Collections.sort(requestList,byProfile);
		List<Set<String>> partition=new ArrayList<Set<String>>(folds);
		for(int k=0;k<folds;k++){
			partition.add(new HashSet<String>());
		}
		// deal the profile-sorted requests round-robin over the partitions
		for(int i=0;i<requestList.size();i++){
			partition.get(i%folds).add(requestList.get(i));
		}
		if(log.isDebugEnabled()){
			for(int k=0;k<folds;k++){
				Set<String> profilesForPartition=new TreeSet<String>();
				for(String req:partition.get(k)){
					profilesForPartition.add(req2ProfileMap.get(req));
				}
				log.debug("partition "+k+": "+partition.get(k)+" profiles: "+
					profilesForPartition);
			}
		}

		// allocate the test and training lists
		trainList=new ArrayList<List<T>>(folds);
		testList=new ArrayList<List<T>>(folds);
		for(int k=0;k<folds;k++){
			trainList.add(new ArrayList<T>());
			testList.add(new ArrayList<T>());
		}

		// populate them
		for(T item:inputList){
			String subpop=((HasSubpopulationId)item).getSubpopulationId();
			String userId=userMap.get(subpop);
			String reqId=requestMap.get(subpop);
			int k=partitionContaining(partition,reqId);
			if(testUsers.contains(userId)){
				testList.get(k).add(item);
			}else{
				for(int j=0;j<folds;j++){
					if(j!=k){
						trainList.get(j).add(item);
					}
				}
			}
		}
		verifySplit();
	}

	/** Always throws an IllegalArgumentException describing the rejected example. */
	private void badExample(Object o,String msg){
		throw new IllegalArgumentException(msg+" on input "+o);
	}

	/**
	 * Returns the index of the partition containing the given request id.
	 *
	 * @throws IllegalStateException if no partition contains it
	 */
	private int partitionContaining(List<Set<String>> partition,String req){
		for(int i=0;i<partition.size();i++){
			if(partition.get(i).contains(req)){
				return i;
			}
		}
		throw new IllegalStateException("request id "+req+
			" not found in partition???");
	}

	/**
	 * Checks correctness of the split: within a fold no training example may
	 * share a user or a request with a test example, and no example object may
	 * appear in the test lists of two different folds.
	 *
	 * @throws IllegalStateException if either condition is violated
	 */
	private void verifySplit(){
		for(int k=0;k<folds;k++){
			for(T trainItem:trainList.get(k)){
				for(T testItem:testList.get(k)){
					if(similarTo(trainItem,testItem)){
						throw new IllegalStateException("bad split for train/test "+
							trainItem+"/"+testItem);
					}
				}
			}
		}
		for(int k1=0;k1<folds;k1++){
			for(int k2=0;k2<folds;k2++){
				if(k2==k1){
					continue;
				}
				for(T t1:testList.get(k1)){
					for(T t2:testList.get(k2)){
						// identity comparison on purpose: looking for the very same example object
						if(t1==t2){
							throw new IllegalStateException(
								"overlapping test cases for lists "+k1+" and "+k2);
						}
					}
				}
			}
		}
	}

	/** Returns true if the two examples map to the same user or to the same request. */
	private boolean similarTo(Object o1,Object o2){
		String subpop1=((HasSubpopulationId)o1).getSubpopulationId();
		String subpop2=((HasSubpopulationId)o2).getSubpopulationId();
		if(userMap.get(subpop1).equals(userMap.get(subpop2))){
			return true;
		}
		return requestMap.get(subpop1).equals(requestMap.get(subpop2));
	}

	/** Returns the number of folds given to the constructor. */
	@Override
	public int getNumPartitions(){
		return folds;
	}

	/** Returns the training examples of fold k; {@link #split(Iterator)} must have been called first. */
	@Override
	public Iterator<T> getTrain(int k){
		return trainList.get(k).iterator();
	}

	/** Returns the test examples of fold k; {@link #split(Iterator)} must have been called first. */
	@Override
	public Iterator<T> getTest(int k){
		return testList.get(k).iterator();
	}

	@Override
	public String toString(){
		return "[WebmasterSplitter "+folds+"]";
	}

	/**
	 * Command-line check: loads the constraint file named by args[0], builds
	 * one dummy example per non-comment line (its subpopulation id is the
	 * first field), splits them into args[1] folds with fraction 0.7 and
	 * prints each fold plus total sizes. On any exception the stack trace and
	 * a usage line are printed.
	 */
	static public void main(String args[]){
		LineNumberReader in=null;
		try{
			String file=args[0];
			WebmasterSplitter<HasSubpopulationId> splitter=
				new WebmasterSplitter<HasSubpopulationId>(file,0.7,StringUtil.atoi(args[1]));
			List<HasSubpopulationId> list=new ArrayList<HasSubpopulationId>();
			in=new LineNumberReader(new FileReader(new File(args[0])));
			String line=null;
			while((line=in.readLine())!=null){
				if(line.startsWith("#")){
					continue;
				}
				String[] f=line.split(" ");
				final String subpop=f[0];
				list.add(new HasSubpopulationId(){
					@Override
					public String toString(){
						return "[Ex "+subpop+"]";
					}
					@Override
					public String getSubpopulationId(){
						return subpop;
					}
				});
			}
			splitter.split(list.iterator());
			int totTestSize=0;
			int totTrainSize=0;
			for(int k=0;k<splitter.getNumPartitions();k++){
				totTestSize+=splitter.asList(splitter.getTest(k)).size();
				totTrainSize+=splitter.asList(splitter.getTrain(k)).size();
				System.out.println("fold "+k+":");
				System.out.println("test: "+splitter.testList.get(k));
				System.out.println("train: "+splitter.trainList.get(k));
			}
			System.out.println("data.size = "+list.size());
			System.out.println("total test size="+totTestSize);
			System.out.println("total train size="+totTrainSize);
		}catch(Exception e){
			e.printStackTrace();
			System.out.println("usage: WebmasterSplitter constraint-file #folds");
		}finally{
			closeQuietly(in);
		}
	}

	/** Drains the iterator into a new list, preserving iteration order. */
	private List<T> asList(Iterator<T> i){
		List<T> accum=new ArrayList<T>();
		while(i.hasNext()){
			accum.add(i.next());
		}
		return accum;
	}

	/** Drains the iterator into a new set. */
	private Set<String> asSet(Iterator<String> i){
		Set<String> accum=new HashSet<String>();
		while(i.hasNext()){
			accum.add(i.next());
		}
		return accum;
	}
}