/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.hadoop.stochasticsvd.qr;
import java.io.Closeable;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.commons.lang.Validate;
import org.apache.mahout.common.iterator.CopyConstructorIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.UpperTriangular;
import com.google.common.collect.Lists;
/**
 * Second/last step of QR iterations. Takes input of qtHats and rHats and
 * provides an iterator to pull ready rows of the final Q.
 * <p>
 * Rows are produced block by block. Each Q-hat block read from the input is
 * converted into a Q^t block via {@link GivensThinSolver#computeQtHat}. Its
 * rows are then emitted in reverse storage order, because Q blocks are
 * initially stored in inverse order.
 * <p>
 * {@link #next()} reuses a single {@link Vector} instance for every row it
 * returns. Callers that need to keep a row must copy it before calling
 * {@link #next()} again. Instances keep mutable iteration state without any
 * synchronization.
 */
public class QRLastStep implements Closeable, Iterator<Vector> {

  private final Iterator<DenseBlockWritable> qHatInput;
  /** Preloaded R-hats; see the constructor for how earlier ones are merged. */
  private final List<UpperTriangular> mRs = Lists.newArrayList();
  private final int blockNum;
  /** Current Q^t block indexed as [column][row]; null when no block is loaded. */
  private double[][] mQt;
  /** Number of rows of the current block already returned by next(). */
  private int cnt;
  /** Number of rows in the current block. */
  private int r;
  /** Number of columns of Q, i.e. the length of each returned row. */
  private int kp;
  /** Row buffer that next() fills in and returns on every call. */
  private Vector qRow;

  /**
   * Creates the last QR step for one block.
   * <p>
   * All R-hat inputs are preloaded into memory to make R sequence
   * modifications more efficient. R-hats at indices 1 .. blockNum-1 are merged
   * into the first R-hat. The R-hat at index {@code blockNum} therefore ends
   * up at position {@code blockNum == 0 ? 0 : 1} of the retained sequence. That
   * is the position later passed to {@link GivensThinSolver#computeQtHat}.
   *
   * @param qHatInput
   *          the Q-Hat input that was output in the first step
   * @param rHatInput
   *          all RHat outputs in the group, in order of groups
   * @param blockNum
   *          our RHat number in the group
   */
  public QRLastStep(Iterator<DenseBlockWritable> qHatInput,
                    Iterator<VectorWritable> rHatInput,
                    int blockNum) {
    this.blockNum = blockNum;
    this.qHatInput = qHatInput;
    /*
     * in this implementation we actually preload all Rs into memory to make R
     * sequence modifications more efficient.
     */
    int block = 0;
    while (rHatInput.hasNext()) {
      Vector value = rHatInput.next().get();
      if (block < blockNum && block > 0) {
        // R-hats strictly between the first one and ours are folded into the first
        GivensThinSolver.mergeR(mRs.get(0), new UpperTriangular(value));
      } else {
        mRs.add(new UpperTriangular(value));
      }
      block++;
    }
  }

  /**
   * Loads the next Q-hat block that yields at least one row and converts it
   * into the current Q^t block. Blocks that produce no rows are skipped, so a
   * successfully loaded block always has {@code r > 0}.
   *
   * @return true if a non-empty block was loaded, false if the Q-hat input is
   *         exhausted (in which case no block is current)
   */
  private boolean loadNextQt() {
    while (qHatInput.hasNext()) {
      DenseBlockWritable v = qHatInput.next();
      double[][] qt =
        GivensThinSolver.computeQtHat(v.getBlock(),
                                      blockNum == 0 ? 0 : 1,
                                      new CopyConstructorIterator<UpperTriangular>(mRs.iterator()));
      // guard against a zero-column result before touching qt[0]
      int rows = qt.length == 0 ? 0 : qt[0].length;
      if (rows == 0) {
        // nothing to emit from this block; move on to the next one
        continue;
      }
      mQt = qt;
      r = rows;
      kp = qt.length;
      if (qRow == null) {
        qRow = new DenseVector(kp);
      }
      cnt = 0;
      return true;
    }
    mQt = null;
    return false;
  }

  /**
   * Reports whether another row of Q is available, loading the next block if
   * the current one is exhausted. Repeated calls without an intervening
   * {@link #next()} do not advance the iteration.
   */
  @Override
  public boolean hasNext() {
    if (mQt != null && cnt < r) {
      return true;
    }
    // current block (if any) is exhausted; advance to the next non-empty one
    mQt = null;
    cnt = 0;
    return loadNextQt();
  }

  /**
   * Returns the next row of Q.
   * <p>
   * The returned vector is a shared buffer that is overwritten by the next
   * call; copy it if it must be retained.
   *
   * @throws NoSuchElementException if there are no more rows
   */
  @Override
  public Vector next() {
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    /*
     * because Q blocks are initially stored in inverse order
     */
    int qRowIndex = r - cnt - 1;
    // side-effect-free invariant check (calling hasNext() here could advance state)
    Validate.isTrue(qRowIndex >= 0 && qRowIndex < r, "Q input overrun");
    for (int j = 0; j < kp; j++) {
      qRow.setQuick(j, mQt[j][qRowIndex]);
    }
    cnt++;
    return qRow;
  }

  /** Not supported. */
  @Override
  public void remove() {
    throw new UnsupportedOperationException();
  }

  /** Drops the current Q^t block and all preloaded R-hats. */
  @Override
  public void close() throws IOException {
    mQt = null;
    mRs.clear();
  }
}