package org.apache.lucene.classification.utils; /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; import org.apache.lucene.index.AtomicReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.StorableField; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.util.Version; import java.io.IOException; /** * Utility class for creating training / test / cross validation indexes from the original index. */ public class DatasetSplitter { private final double crossValidationRatio; private final double testRatio; /** * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes * * @param testRatio the ratio of the original index to be used for the test IDX as a <code>double</code> between 0.0 and 1.0 * @param crossValidationRatio the ratio of the original index to be used for the c.v. IDX as a <code>double</code> between 0.0 and 1.0 */ public DatasetSplitter(double testRatio, double crossValidationRatio) { this.crossValidationRatio = crossValidationRatio; this.testRatio = testRatio; } /** * Split a given index into 3 indexes for training, test and cross validation tasks respectively * * @param originalIndex an {@link AtomicReader} on the source index * @param trainingIndex a {@link Directory} used to write the training index * @param testIndex a {@link Directory} used to write the test index * @param crossValidationIndex a {@link Directory} used to write the cross validation index * @param analyzer {@link Analyzer} used to create the new docs * @param fieldNames names of fields that need to be put in the new indexes or <code>null</code> if all should be used * @throws IOException if any writing operation fails on any of the indexes */ public void split(AtomicReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, String... fieldNames) throws IOException { // create IWs for train / test / cv IDXs // :Post-Release-Update-Version.LUCENE_XY: IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer)); IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer)); IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(Version.LUCENE_50, analyzer)); try { int size = originalIndex.maxDoc(); IndexSearcher indexSearcher = new IndexSearcher(originalIndex); TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE); // set the type to be indexed, stored, with term vectors FieldType ft = new FieldType(TextField.TYPE_STORED); ft.setStoreTermVectors(true); ft.setStoreTermVectorOffsets(true); ft.setStoreTermVectorPositions(true); int b = 0; // iterate over existing documents for (ScoreDoc scoreDoc : topDocs.scoreDocs) { // create a new document for indexing Document doc = new Document(); if (fieldNames != null && fieldNames.length > 0) { for (String fieldName : fieldNames) { doc.add(new Field(fieldName, originalIndex.document(scoreDoc.doc).getField(fieldName).stringValue(), ft)); } } else { for (StorableField storableField : originalIndex.document(scoreDoc.doc).getFields()) { if (storableField.readerValue() != null) { doc.add(new Field(storableField.name(), storableField.readerValue(), ft)); } else if (storableField.binaryValue() != null) { doc.add(new Field(storableField.name(), storableField.binaryValue(), ft)); } else if (storableField.stringValue() != null) { doc.add(new Field(storableField.name(), storableField.stringValue(), ft)); } else if (storableField.numericValue() != null) { doc.add(new Field(storableField.name(), storableField.numericValue().toString(), ft)); } } } // add it to one of the IDXs if (b % 2 == 0 && testWriter.maxDoc() < size * testRatio) { testWriter.addDocument(doc); } else if (cvWriter.maxDoc() < size * crossValidationRatio) { cvWriter.addDocument(doc); } else { trainingWriter.addDocument(doc); } b++; } } catch (Exception e) { throw new IOException(e); } finally { testWriter.commit(); cvWriter.commit(); trainingWriter.commit(); // close IWs testWriter.close(); cvWriter.close(); trainingWriter.close(); } } }