/**
* Copyright 2008 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.sf.katta.lib.lucene;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.sf.katta.util.WritableType;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.log4j.Logger;
/**
 * A {@link Writable} container for a list of {@link Hit}s tagged with the name
 * of one node, a total hit count, the set of shard names the hits refer to and
 * an optional array of sort field types.
 *
 * <p>
 * Serialized layout, in order:
 * <ol>
 * <li>node name (UTF)</li>
 * <li>total hits (int)</li>
 * <li>number of sort field types (byte), followed by one ordinal byte per
 * type</li>
 * <li>number of shards (int), followed by each shard name (UTF)</li>
 * <li>number of hits (int), followed per hit by: shard index (byte), score
 * (float), doc id (int), number of sort fields (byte) and each sort field's
 * own serialization</li>
 * </ol>
 *
 * <p>
 * Because a hit references its shard through a single byte, at most
 * {@value #MAX_SHARDS} distinct shards can be serialized.
 */
public class HitsMapWritable implements Writable {

  private static final Logger LOG = Logger.getLogger(HitsMapWritable.class);

  /** Number of distinct values a one-byte shard index can address. */
  private static final int MAX_SHARDS = 256;

  private String _nodeName;
  private int _totalHits;
  private WritableType[] _sortFieldTypes;
  // Initialized eagerly so that an instance created through the no-arg
  // constructor can be used (addHit, getHits, ...) before readFields is called.
  private List<Hit> _hits = new ArrayList<Hit>();
  private Set<String> _shards = new HashSet<String>();

  /**
   * Creates an empty instance without a node name; intended for
   * deserialization through {@link #readFields(DataInput)}.
   */
  public HitsMapWritable() {
    // for serialization
  }

  /**
   * Creates an empty instance for the given node.
   *
   * @param nodeName
   *          the node name stored with this instance and written as the first
   *          field of the serialized form
   */
  public HitsMapWritable(final String nodeName) {
    _nodeName = nodeName;
  }

  /**
   * Replaces the complete state of this instance with the state read from the
   * given input. All fields are overwritten, including the sort field types,
   * which are reset to {@code null} when the stream declares none, so an
   * instance can safely be reused for several reads.
   *
   * @param in
   *          the input to read from, expected to hold data produced by
   *          {@link #write(DataOutput)}
   * @throws IOException
   *           if reading fails, if a hit references a shard index outside of
   *           the shard table, or if a hit declares more sort fields than sort
   *           field types were declared
   */
  public void readFields(final DataInput in) throws IOException {
    long start = 0;
    if (LOG.isDebugEnabled()) {
      start = System.currentTimeMillis();
    }
    _nodeName = in.readUTF();
    _totalHits = in.readInt();
    _sortFieldTypes = readSortFieldTypes(in);
    if (LOG.isDebugEnabled()) {
      LOG.debug("HitsMap reading start at: " + start + " for server " + _nodeName);
    }

    // The shard table: a hit refers to its shard by position in this table.
    final int shardCount = in.readInt();
    final String[] shardNames = new String[shardCount];
    _shards = new HashSet<String>(shardCount);
    for (int i = 0; i < shardCount; i++) {
      final String shardName = in.readUTF();
      shardNames[i] = shardName;
      _shards.add(shardName);
    }

    final int hitCount = in.readInt();
    _hits = new ArrayList<Hit>(hitCount + 1);
    for (int i = 0; i < hitCount; i++) {
      // Read unsigned so that indices 128..255 map back to the table position
      // the writer assigned them.
      final int shardIndex = in.readUnsignedByte();
      final float score = in.readFloat();
      final int docId = in.readInt();
      if (shardIndex >= shardCount) {
        throw new IOException("Hit references shard index " + shardIndex + " but only " + shardCount
                + " shards were declared");
      }
      final String shard = shardNames[shardIndex];
      final Hit hit;
      if (_sortFieldTypes != null) {
        hit = new Hit(shard, _nodeName, score, docId, _sortFieldTypes);
      } else {
        hit = new Hit(shard, _nodeName, score, docId);
      }
      addHit(hit);

      final byte sortFieldsLen = in.readByte();
      if (sortFieldsLen > 0) {
        if (_sortFieldTypes == null || sortFieldsLen > _sortFieldTypes.length) {
          final int declaredTypes = _sortFieldTypes == null ? 0 : _sortFieldTypes.length;
          throw new IOException("Hit declares " + sortFieldsLen + " sort fields but only " + declaredTypes
                  + " sort field types were declared");
        }
        final WritableComparable[] sortFields = new WritableComparable[sortFieldsLen];
        for (int k = 0; k < sortFieldsLen; k++) {
          // The k-th declared type determines how the k-th sort field is read.
          sortFields[k] = _sortFieldTypes[k].newWritableComparable();
          sortFields[k].readFields(in);
        }
        hit.setSortFields(sortFields);
      }
    }
    if (LOG.isDebugEnabled()) {
      final long end = System.currentTimeMillis();
      LOG.debug("HitsMap reading of " + hitCount + " entries took " + (end - start) / 1000.0 + "sec.");
    }
  }

  /**
   * Reads the sort field type section: a length byte followed by one ordinal
   * byte per type.
   *
   * @return the decoded types, or {@code null} if the length byte is not
   *         positive
   */
  private static WritableType[] readSortFieldTypes(final DataInput in) throws IOException {
    final byte typeCount = in.readByte();
    if (typeCount <= 0) {
      return null;
    }
    // values() is looked up once instead of once per element.
    final WritableType[] allTypes = WritableType.values();
    final WritableType[] types = new WritableType[typeCount];
    for (int i = 0; i < typeCount; i++) {
      types[i] = allTypes[in.readByte()];
    }
    return types;
  }

  /**
   * Writes the state of this instance in the layout described in the class
   * documentation.
   *
   * @param out
   *          the output to write to
   * @throws IOException
   *           if writing fails, if more than {@value #MAX_SHARDS} shards are
   *           registered (they cannot be addressed by the one-byte shard
   *           index), or if a hit refers to a shard that is not registered
   *           with this instance
   */
  public void write(final DataOutput out) throws IOException {
    long start = 0;
    if (LOG.isDebugEnabled()) {
      start = System.currentTimeMillis();
    }

    // Fail before anything is written: with more shards the one-byte index
    // would wrap around and hits would be attributed to the wrong shard.
    final int shardCount = _shards.size();
    if (shardCount > MAX_SHARDS) {
      throw new IOException("Cannot serialize hits of " + shardCount + " shards, at most " + MAX_SHARDS
              + " shards are supported");
    }

    out.writeUTF(_nodeName);
    out.writeInt(_totalHits);
    if (_sortFieldTypes == null) {
      out.writeByte(0);
    } else {
      out.writeByte(_sortFieldTypes.length);
      for (WritableType writableType : _sortFieldTypes) {
        out.writeByte(writableType.ordinal());
      }
    }

    // Shard table; the iteration position of a shard becomes its index.
    out.writeInt(shardCount);
    final Map<String, Integer> shardIndexByShard = new HashMap<String, Integer>(shardCount);
    int nextShardIndex = 0;
    for (String shard : _shards) {
      out.writeUTF(shard);
      shardIndexByShard.put(shard, nextShardIndex);
      nextShardIndex++;
    }

    out.writeInt(_hits.size());
    for (Hit hit : _hits) {
      final Integer shardIndex = shardIndexByShard.get(hit.getShard());
      if (shardIndex == null) {
        throw new IOException("Hit refers to shard '" + hit.getShard()
                + "' which is not registered, use addHit(Hit) to add hits");
      }
      // writeByte stores the low eight bits, i.e. indices 0..255.
      out.writeByte(shardIndex.intValue());
      out.writeFloat(hit.getScore());
      out.writeInt(hit.getDocId());
      final WritableComparable[] sortFields = hit.getSortFields();
      if (sortFields == null) {
        out.writeByte(0);
      } else {
        out.writeByte(sortFields.length);
        for (Writable writable : sortFields) {
          writable.write(out);
        }
      }
    }
    if (LOG.isDebugEnabled()) {
      final long end = System.currentTimeMillis();
      LOG.debug("HitsMap writing took " + (end - start) / 1000.0 + "sec.");
      LOG.debug("HitsMap writing ended at: " + end + " for server " + _nodeName);
    }
  }

  /**
   * Appends the hit to the hit list and registers the hit's shard in the shard
   * set.
   *
   * @param hit
   *          the hit to add
   */
  public void addHit(final Hit hit) {
    _hits.add(hit);
    _shards.add(hit.getShard());
  }

  /**
   * Adds the hit; the {@code shard} argument is ignored, the shard is taken
   * from the hit itself.
   *
   * @deprecated use {@link #addHit(Hit)} instead
   */
  @Deprecated
  public void addHitToShard(final String shard, final Hit hit) {
    addHit(hit);
  }

  /**
   * @deprecated use {@link #getNodeName()} instead
   */
  @Deprecated
  public String getServerName() {
    return getNodeName();
  }

  /**
   * @return the node name given at construction or read by
   *         {@link #readFields(DataInput)}
   */
  public String getNodeName() {
    return _nodeName;
  }

  /**
   * Returns the internal hit list itself, not a copy. Hits added directly to
   * the returned list are not registered in the shard set; use
   * {@link #addHit(Hit)} for that.
   *
   * @return the list of hits
   */
  public List<Hit> getHitList() {
    return _hits;
  }

  /**
   * Builds a new {@link Hits} object holding the total hit count and the hits
   * of this instance.
   *
   * @deprecated use {@link #getHitList()} instead
   */
  @Deprecated
  public Hits getHits() {
    final Hits result = new Hits();
    result.setTotalHits(_totalHits);
    result.addHits(_hits);
    return result;
  }

  /**
   * Increases the total hit count by the given amount.
   *
   * @param length
   *          the number of hits to add to the total
   */
  public void addTotalHits(final int length) {
    _totalHits += length;
  }

  /**
   * @return the accumulated total hit count
   */
  public int getTotalHits() {
    return _totalHits;
  }

  /**
   * @return the sort field types, or {@code null} if none were set or read
   */
  public WritableType[] getSortFieldTypes() {
    return _sortFieldTypes;
  }

  /**
   * @param sortFieldTypes
   *          the sort field types to store and serialize; may be {@code null}
   */
  public void setSortFieldTypes(WritableType[] sortFieldTypes) {
    _sortFieldTypes = sortFieldTypes;
  }
}