/******************************************************************************* * Copyright 2016 Observational Health Data Sciences and Informatics * * This file is part of WhiteRabbit * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ package org.ohdsi.utilities.files; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import org.ohdsi.utilities.files.MultiRowIterator.MultiRowSet; /** * Allows iteration over multiple tables (as Iterator<Row>) simultaneously, synchronized by the value of the [linkingColumn]. Assumes all tables are sorted by * the [linkingColumn]. * * @author MSCHUEMI */ public class MultiRowIterator implements Iterator<MultiRowSet> { private Iterator<Row>[] iterators; private String[] tableNames; private Row[] nextRows; private MultiRowSet buffer; private String linkingColumn; private boolean sortedNumerically; @SafeVarargs public MultiRowIterator(String linkingColumn, String[] tableNames, Iterator<Row>... tableIterators) { this(linkingColumn, false, tableNames, tableIterators); } public MultiRowIterator(String linkingColumn, boolean sortedNumerically, String[] tableNames, Iterator<Row>[] tableIterators) { this.iterators = tableIterators; this.linkingColumn = linkingColumn; this.tableNames = tableNames; this.sortedNumerically = sortedNumerically; startRead(); } private void startRead() { nextRows = new Row[iterators.length]; for (int i = 0; i < iterators.length; i++) if (iterators[i].hasNext()) nextRows[i] = iterators[i].next(); else nextRows[i] = null; readNext(); } @Override public boolean hasNext() { return (buffer != null); } @Override public MultiRowSet next() { MultiRowSet result = buffer; readNext(); return result; } private void readNext() { String lowestLinkingColumn = findLowestLinkingColumn(nextRows); if (lowestLinkingColumn == null) { buffer = null; return; } buffer = new MultiRowSet(tableNames); buffer.linkingId = lowestLinkingColumn; for (int i = 0; i < iterators.length; i++) { Iterator<Row> iterator = iterators[i]; while (nextRows[i] != null && nextRows[i].get(linkingColumn).equals(lowestLinkingColumn)) { buffer.get(tableNames[i]).add(nextRows[i]); if (iterator.hasNext()) nextRows[i] = iterator.next(); else nextRows[i] = null; } } } private String findLowestLinkingColumn(Row[] rows) { String linkingId = null; for (Row row : rows) if (row != null && (linkingId == null || compare(row.get(linkingColumn), linkingId) < 0)) linkingId = row.get(linkingColumn); return linkingId; } private int compare(String value1, String value2) { if (sortedNumerically) return efficientLongCompare(value1, value2); else return value1.compareTo(value2); } private int efficientLongCompare(String value1, String value2) { if (value1.length() > value2.length()) return 1; else if (value1.length() < value2.length()) return -1; else return value1.compareTo(value2); } @Override public void remove() { System.err.println("Calling unimplemented remove method in class " + this.getClass().getName()); } public static class MultiRowSet extends HashMap<String, List<Row>> { private static final long serialVersionUID = 1164317535150664720L; public String linkingId; public MultiRowSet(String[] tableNames) { for (String tableName : tableNames) { put(tableName, new ArrayList<Row>()); } } public List<String> getNonEmptyTableNames() { List<String> result = new ArrayList<String>(); for (String tableName : keySet()) if (get(tableName).size() != 0) result.add(tableName); return result; } /** * returns the total number of rows (summed across the tables) * * @return */ public int totalSize() { int size = 0; for (List<Row> rows : values()) size += rows.size(); return size; } } }