/**
* Implements operations for managing the model table of the aidr_predict DB
*
* @author Koushik
*/
package qa.qcri.aidr.dbmanager.ejb.remote.facade.imp;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import org.apache.log4j.Logger;
import org.hibernate.Criteria;
import org.hibernate.Hibernate;
import org.hibernate.HibernateException;
import org.hibernate.Query;
import org.hibernate.Session;
import org.hibernate.criterion.Order;
import org.hibernate.criterion.Restrictions;
import qa.qcri.aidr.common.exception.PropertyNotSetException;
import qa.qcri.aidr.dbmanager.dto.ModelDTO;
import qa.qcri.aidr.dbmanager.dto.taggerapi.ModelHistoryWrapper;
import qa.qcri.aidr.dbmanager.dto.taggerapi.ModelWrapper;
import qa.qcri.aidr.dbmanager.dto.taggerapi.TrainingDataDTO;
import qa.qcri.aidr.dbmanager.ejb.local.facade.impl.CoreDBServiceFacadeImp;
import qa.qcri.aidr.dbmanager.ejb.remote.facade.CollectionResourceFacade;
import qa.qcri.aidr.dbmanager.ejb.remote.facade.MiscResourceFacade;
import qa.qcri.aidr.dbmanager.ejb.remote.facade.ModelResourceFacade;
import qa.qcri.aidr.dbmanager.entities.misc.Collection;
import qa.qcri.aidr.dbmanager.entities.model.Model;
import qa.qcri.aidr.dbmanager.entities.model.ModelFamily;
import qa.qcri.aidr.dbmanager.entities.model.ModelNominalLabel;
import qa.qcri.aidr.util.NativeQueryUtil;
/**
 * Stateless EJB facade exposing read and delete operations on {@link Model}
 * entities, returning DTO / wrapper objects to callers.
 */
@Stateless(name="ModelResourceFacadeImp")
public class ModelResourceFacadeImp extends CoreDBServiceFacadeImp<Model, Long> implements ModelResourceFacade {

    private static final Logger logger = Logger.getLogger("db-manager-log");

    /** Property used for ordering when the caller supplies no sort column. */
    private static final String DEFAULT_SORT_COLUMN = "trainingTime";

    /** Direction used for ordering when the caller supplies no (or an unknown) direction. */
    private static final String DEFAULT_SORT_DIRECTION = "DESC";

    @EJB
    private CollectionResourceFacade collectionResourceFacade;

    @EJB
    private MiscResourceFacade miscResourceFacade;

    public ModelResourceFacadeImp() {
        super(Model.class);
    }

    /**
     * Fetches every model and converts each one to a {@link ModelDTO}.
     *
     * @return list of DTOs, empty when no models were fetched; never {@code null}
     * @throws PropertyNotSetException propagated from the {@link ModelDTO} constructor
     */
    @Override
    public List<ModelDTO> getAllModels() throws PropertyNotSetException {
        List<ModelDTO> modelDTOList = new ArrayList<ModelDTO>();
        List<Model> modelList = getAll();
        if (modelList == null) {
            // Defensive: treat a missing result the same as an empty one.
            return modelDTOList;
        }
        logger.info("Fetched models list size: " + modelList.size());
        for (Model model : modelList) {
            modelDTOList.add(new ModelDTO(model));
        }
        return modelDTOList;
    }

    /**
     * Looks up a single model by primary key and wraps it in a {@link ModelDTO}.
     *
     * NOTE(review): the result of {@code getById} is passed straight to the
     * {@code ModelDTO} constructor; behaviour for an unknown id therefore
     * depends on how that constructor treats {@code null} - confirm.
     *
     * @param id primary key of the model
     * @return DTO built from the fetched entity
     * @throws PropertyNotSetException propagated from the {@link ModelDTO} constructor
     */
    @Override
    public ModelDTO getModelByID(Long id) throws PropertyNotSetException {
        return new ModelDTO(getById(id));
    }

    /**
     * Returns the number of models associated with a model family.
     *
     * The count is obtained by listing the matching entities and taking the
     * size of that list.
     *
     * @param modelFamilyID id of the model family to count models for
     * @return number of matching models, {@code 0} when the query yields no list
     */
    @Override
    @SuppressWarnings("unchecked")
    public Integer getModelCountByModelFamilyID(Long modelFamilyID) throws PropertyNotSetException {
        Criteria criteria = getCurrentSession().createCriteria(Model.class);
        criteria.add(Restrictions.eq("modelFamily.modelFamilyId", modelFamilyID));
        List<Model> modelList = criteria.list();
        return modelList != null ? Integer.valueOf(modelList.size()) : 0;
    }

    /**
     * Returns one page of model history for a model family, ordered by
     * training time, newest first.
     *
     * @param modelFamilyID id of the model family
     * @param start         index of the first row to return
     * @param limit         maximum number of rows to return
     * @return history wrappers for the requested page
     */
    @Override
    public List<ModelHistoryWrapper> getModelByModelFamilyID(Long modelFamilyID, Integer start, Integer limit) throws PropertyNotSetException {
        return getModelByModelFamilyID(modelFamilyID, start, limit, DEFAULT_SORT_COLUMN, DEFAULT_SORT_DIRECTION);
    }

    /**
     * Returns one page of model history for a model family with caller-chosen
     * ordering.
     *
     * @param modelFamilyID id of the model family
     * @param start         index of the first row to return; when {@code null}
     *                      no offset is applied
     * @param limit         maximum number of rows to return; when {@code null}
     *                      no limit is applied
     * @param sortColumn    entity property to order by; {@code null} or empty
     *                      falls back to {@code "trainingTime"}
     * @param sortDirection {@code "ASC"} (case-insensitive) for ascending order;
     *                      any other value, including {@code null} or empty,
     *                      yields descending order
     * @return history wrappers for the requested page, in query order
     */
    @Override
    @SuppressWarnings("unchecked")
    public List<ModelHistoryWrapper> getModelByModelFamilyID(Long modelFamilyID, Integer start, Integer limit,
            String sortColumn, String sortDirection) throws PropertyNotSetException {
        List<ModelHistoryWrapper> wrapperList = new ArrayList<ModelHistoryWrapper>();

        Criteria criteria = getCurrentSession().createCriteria(Model.class);
        criteria.add(Restrictions.eq("modelFamily.modelFamilyId", modelFamilyID));

        String column = (sortColumn == null || sortColumn.isEmpty()) ? DEFAULT_SORT_COLUMN : sortColumn;
        // Always add an explicit order: paging over an unordered result set
        // is not repeatable, so unknown directions fall back to descending.
        if ("ASC".equalsIgnoreCase(sortDirection)) {
            criteria.addOrder(Order.asc(column));
        } else {
            criteria.addOrder(Order.desc(column));
        }

        if (start != null) {
            criteria.setFirstResult(start);
        }
        if (limit != null) {
            criteria.setMaxResults(limit);
        }

        List<Model> modelList = criteria.list();
        for (Model model : modelList) {
            ModelHistoryWrapper w = new ModelHistoryWrapper();
            w.setModelID(model.getModelId());
            w.setAvgAuc(model.getAvgAuc());
            w.setAvgPrecision(model.getAvgPrecision());
            w.setAvgRecall(model.getAvgRecall());
            w.setTrainingCount(model.getTrainingCount());
            w.setTrainingTime(model.getTrainingTime());
            wrapperList.add(w);
        }
        return wrapperList;
    }

    /**
     * Builds one {@link ModelWrapper} per model family of the given
     * collection, populated from that family's current model (if any).
     *
     * For each family the wrapper carries the family id, attribute name and
     * id, an "Active"/"Inactive" status, and - taken from the model flagged
     * as current - its id, average AUC and the sum of classified document
     * counts over its nominal labels. When a family has no current model
     * these last three stay at zero.
     *
     * @param crisisID primary key of the collection
     * @return wrappers for every model family of the collection; empty when
     *         the collection does not exist or has no model families
     */
    @Override
    public List<ModelWrapper> getModelByCrisisID(Long crisisID) throws PropertyNotSetException{
        List<ModelWrapper> modelWrapperList = new ArrayList<ModelWrapper>();

        Collection collection = getEntityManager().find(Collection.class, crisisID);
        if (collection == null) {
            // Unknown collection id: nothing to report.
            return modelWrapperList;
        }

        Hibernate.initialize(collection.getModelFamilies());
        List<ModelFamily> modelFamilyList = collection.getModelFamilies();
        if (modelFamilyList == null) {
            return modelWrapperList;
        }

        for (ModelFamily modelFamily : modelFamilyList) {
            Hibernate.initialize(modelFamily.getModels());
            List<Model> modelList = modelFamily.getModels();

            ModelWrapper modelWrapper = new ModelWrapper();
            modelWrapper.setModelFamilyID(modelFamily.getModelFamilyId());

            long classifiedElements = 0;
            double auc = 0.0;
            Long modelID = 0L;
            // NOTE(review): never assigned below, so every wrapper reports 0
            // training examples - confirm whether this should be derived from
            // the current model or from the training data.
            long trainingExamples = 0;

            if (modelList != null) {
                for (Model model : modelList) {
                    if (model.isIsCurrentModel()) {
                        auc = model.getAvgAuc();
                        modelID = model.getModelId();

                        // Sum classifiedDocumentCount over all labels of the current model.
                        Hibernate.initialize(model.getModelNominalLabels());
                        long totalClassifiedDocuments = 0;
                        for (ModelNominalLabel label : model.getModelNominalLabels()) {
                            totalClassifiedDocuments += label.getClassifiedDocumentCount();
                        }
                        classifiedElements = totalClassifiedDocuments;
                    }
                }
            }

            modelWrapper.setTrainingExamples(trainingExamples);
            modelWrapper.setAttribute(modelFamily.getNominalAttribute().getName());
            modelWrapper.setAttributeID(modelFamily.getNominalAttribute().getNominalAttributeId());
            modelWrapper.setAuc(auc);
            modelWrapper.setClassifiedDocuments(classifiedElements);
            modelWrapper.setStatus(modelFamily.isIsActive() ? "Active" : "Inactive");
            modelWrapper.setModelID(modelID);
            modelWrapperList.add(modelWrapper);
        }
        return modelWrapperList;
    }

    /**
     * Deletes the model with the given id.
     *
     * @param modelID primary key of the model to delete
     * @return {@code true} when the entity was found and removed;
     *         {@code false} when no model exists for the id or a
     *         {@link HibernateException} was raised during removal
     *         (details are written to the log)
     */
    @Override
    public boolean deleteModel(Long modelID) {
        Model model = getEntityManager().find(Model.class, modelID);
        if (model == null) {
            return false; // delete operation failed because no model is found against the given ID.
        }
        try {
            getEntityManager().remove(model);
        } catch (HibernateException he) {
            // Pass the exception as the throwable so the full stack trace is logged.
            logger.error("Hibernate exception on deleting Model using ID=" + modelID, he);
            return false; // hibernate delete operation failed. Details in the logs.
        }
        return true; // successfully deleted.
    }
}