/**
* Copyright 2015, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.experiment;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import org.kohsuke.args4j.Option;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.FileUtils;
import edu.emory.clir.clearnlp.util.IOUtils;
import edu.emory.clir.clearnlp.util.constant.StringConst;
/**
 * Distributes the records of a set of training files across {@code N} cross-validation files.
 * The files are written as {@code <outputPath>/0.cv} through {@code <outputPath>/(N-1).cv}.
 * <p>
 * A record is a run of lines terminated by a blank (whitespace-only) line. The terminating blank
 * line is written together with its record. Records are assigned to the output files round-robin,
 * in the order they are read. If an input file's last record is not followed by a blank line, a
 * blank line is appended so that the record is not merged with the first record of the next file.
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public class CVCreate
{
	/** Path to the training files; resolved via {@link FileUtils#getFileList}. */
	@Option(name="-t", usage="path to training files (required)", required=true, metaVar="<filepath>")
	protected String s_trainPath;
	/** Training file extension filter passed to {@link FileUtils#getFileList}. */
	@Option(name="-te", usage="training file extension (default: *)", required=false, metaVar="<regex>")
	protected String s_trainExt = "*";
	/** Directory in which the {@code <i>.cv} files are created. */
	@Option(name="-o", usage="path to output files (required)", required=true, metaVar="<filepath>")
	protected String s_outputPath;
	/** Number of cross-validation files to create. */
	@Option(name="-cv", usage="number of cross-validation sets (default: 10)", required=false, metaVar="<integer>")
	protected int n_cv = 10;

	/** Creates an instance without parsing arguments or touching the file system. */
	public CVCreate() {}

	/**
	 * Parses the command-line arguments into this instance's options, then creates the
	 * cross-validation files.
	 * An {@link IOException} is reported by printing its stack trace rather than being propagated.
	 * @param args command-line arguments (see the {@code @Option} fields).
	 */
	public CVCreate(String[] args)
	{
		BinUtils.initArgs(args, this);

		try
		{
			List<String> trainFiles = FileUtils.getFileList(s_trainPath, s_trainExt, false);
			create(trainFiles, s_outputPath, n_cv);
		}
		catch (IOException e) {e.printStackTrace();}
	}

	/**
	 * Distributes the records of the training files round-robin across {@code N} output files.
	 * @param trainFiles paths of the input files, read in list order.
	 * @param outputPath directory in which {@code 0.cv} through {@code (N-1).cv} are created.
	 * @param N number of cross-validation files; must be positive.
	 * @throws IllegalArgumentException if {@code N <= 0}.
	 * @throws IOException if reading an input file fails, or if writing any output file failed.
	 */
	public void create(List<String> trainFiles, String outputPath, final int N) throws IOException
	{
		if (N <= 0)
			throw new IllegalArgumentException("Number of cross-validation sets must be positive: " + N);

		PrintStream[] fout = createPrintStreams(outputPath, N);
		int cv = 0;

		try
		{
			for (String trainFile : trainFiles)
				cv = distribute(trainFile, fout, cv, N);

			// PrintStream swallows write errors; surface them instead of reporting success.
			for (int i=0; i<N; i++)
			{
				if (fout[i].checkError())
					throw new IOException("Failed to write cross-validation file: " + getOutputFilename(outputPath, i));
			}
		}
		finally
		{
			for (PrintStream f : fout) f.close();
		}
	}

	/**
	 * Copies every line of one input file to the output streams, advancing to the next stream after
	 * each blank line.
	 * @param trainFile path of the input file.
	 * @param fout output streams, one per cross-validation set.
	 * @param cv index of the stream that receives the next record.
	 * @param N number of output streams.
	 * @return index of the stream that should receive the next record after this file.
	 * @throws IOException if reading the input file fails.
	 */
	private int distribute(String trainFile, PrintStream[] fout, int cv, final int N) throws IOException
	{
		// true while the lines written since the last blank line form an unterminated record
		boolean inRecord = false;
		String line;

		try (BufferedReader fin = IOUtils.createBufferedReader(trainFile))
		{
			while ((line = fin.readLine()) != null)
			{
				fout[cv].println(line);

				if (line.trim().equals(StringConst.EMPTY))
				{
					cv = (cv + 1) % N;
					inRecord = false;
				}
				else
					inRecord = true;
			}
		}

		// Terminate a trailing record so it is not merged with the next file's first record.
		if (inRecord)
		{
			fout[cv].println();
			cv = (cv + 1) % N;
		}

		return cv;
	}

	/**
	 * Opens one output stream per cross-validation set.
	 * If opening any stream fails, the streams already opened are closed before the failure propagates.
	 * @param outputPath directory in which the files are created.
	 * @param N number of streams to open.
	 * @return the opened streams, indexed by cross-validation set.
	 */
	private PrintStream[] createPrintStreams(String outputPath, final int N)
	{
		PrintStream[] fout = new PrintStream[N];

		try
		{
			for (int i=0; i<N; i++)
				fout[i] = IOUtils.createBufferedPrintStream(getOutputFilename(outputPath, i));
		}
		catch (RuntimeException e)
		{
			for (PrintStream f : fout)
				if (f != null) f.close();

			throw e;
		}

		return fout;
	}

	/**
	 * Builds the output filename for the given cross-validation set.
	 * @return {@code outputPath/<index>.cv}.
	 */
	private String getOutputFilename(String outputPath, int index)
	{
		return outputPath + "/" + index + ".cv";
	}

	/**
	 * Command-line entry point.
	 * @param args see the {@code @Option} fields of this class.
	 */
	static public void main(String[] args)
	{
		new CVCreate(args);
	}
}