package processing.hashtag.solr;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.primitives.Ints;

import common.Bookmark;
import common.SolrConnector;
import file.BookmarkReader;
import file.PredictionFileWriter;
import file.ResultSerializer;

/**
 * Hashtag-recommendation routines backed by a Solr index. Each public method
 * connects to a "{core}_train" and "{core}_test" Solr core pair, queries the
 * train core for hashtag candidates, and writes the predictions to disk via
 * {@link PredictionFileWriter} / {@link ResultSerializer}.
 */
public class SolrHashtagCalculator {

	/** Maximum number of hashtags returned per prediction. */
	private final static int LIMIT = 10;

	// Statics ----------------------------------------------------------------------------------------------------------------------

	/**
	 * Predicts hashtags for every test tweet using its own text as the query
	 * against the train core, and writes the prediction file.
	 *
	 * @param sampleDir directory (below the prediction output root) to write into
	 * @param solrCore  base core name; "_train" / "_test" suffixes are appended
	 * @param solrUrl   base URL of the Solr server
	 * @return the filename suffix used for the prediction file ("solrht")
	 */
	public static String predictSample(String sampleDir, String solrCore, String solrUrl) {
		List<Set<String>> predictionValues = new ArrayList<Set<String>>();
		List<Set<String>> realValues = new ArrayList<Set<String>>();
		SolrConnector trainConnector = new SolrConnector(solrUrl, solrCore + "_train");
		SolrConnector testConnector = new SolrConnector(solrUrl, solrCore + "_test");

		Map<String, Set<String>> tweets = testConnector.getTweets();
		for (Map.Entry<String, Set<String>> tweet : tweets.entrySet()) {
			// Only tweets that actually carry hashtags can be evaluated.
			if (tweet.getValue().size() > 0) {
				Map<String, Double> map = trainConnector.getTopHashtagsForTweetText(tweet.getKey(), LIMIT);
				predictionValues.add(map.keySet());
				realValues.add(tweet.getValue());
				if (predictionValues.size() % 100 == 0) {
					System.out.println(predictionValues.size() + " users done. Left ones: " + (tweets.size() - predictionValues.size()));
				}
			}
		}

		String suffix = "solrht";
		PredictionFileWriter.writeSimplePredictions(predictionValues, realValues, null, sampleDir + "/" + solrCore + "_" + suffix);
		return suffix;
	}

	/**
	 * Predicts hashtags for the test tweets and softmax-normalizes the Solr
	 * relevance scores per user. Results are serialized to
	 * "./data/results/{sampleDir}/{solrCore}_cbpredictions.ser" and also written
	 * as a simple prediction file.
	 *
	 * @param sampleDir  directory fragment used in the output paths
	 * @param solrCore   base core name; "_train" / "_test" suffixes are appended
	 * @param solrUrl    base URL of the Solr server
	 * @param reader     provides the user/tag/resource string-to-id mappings
	 * @param trainHours if non-null, restrict to tweets of the last n train hours
	 * @return map of internal user id to (internal tag id -> normalized score)
	 */
	public static Map<Integer, Map<Integer, Double>> getNormalizedHashtagPredictions(String sampleDir, String solrCore, String solrUrl,
			BookmarkReader reader, Integer trainHours) {
		SolrConnector trainConnector = new SolrConnector(solrUrl, solrCore + "_train");
		SolrConnector testConnector = new SolrConnector(solrUrl, solrCore + "_test");

		Map<Integer, Map<Integer, Double>> hashtagMaps = new LinkedHashMap<Integer, Map<Integer, Double>>();
		List<Set<String>> realValues = new ArrayList<Set<String>>();
		List<Set<String>> predictionValues = new ArrayList<Set<String>>();
		List<String> tweetIDs = new ArrayList<String>();
		List<Tweet> tweets = null;
		if (trainHours == null) {
			tweets = testConnector.getTweetObjects(true);
		} else {
			tweets = testConnector.getTrainTweetObjects(trainConnector, trainHours.intValue());
		}
		for (Tweet tweet : tweets) {
			if (tweet.getHashtags().size() > 0) {
				Set<String> predTagIDs = new LinkedHashSet<String>();
				Map<Integer, Double> normalizedIntResult = new LinkedHashMap<Integer, Double>();
				// Fetch a generous candidate list (50) before normalization.
				Map<String, Double> stringResult = trainConnector.getTopHashtagsForTweetText(tweet.getText(), 50);
				double denom = 0.0;
				for (Map.Entry<String, Double> e : stringResult.entrySet()) {
					// Only keep hashtags known to the reader's tag map.
					Integer tID = reader.getTagMap().get(e.getKey().toLowerCase());
					if (tID != null) {
						normalizedIntResult.put(tID.intValue(), e.getValue());
						predTagIDs.add(tID.toString());
						denom += Math.exp(e.getValue());
					}
				}
				// Softmax: exp(score) / sum(exp(scores)) over the kept candidates.
				for (Map.Entry<Integer, Double> e : normalizedIntResult.entrySet()) {
					e.setValue(Math.exp(e.getValue()) / denom);
				}

				Integer uID = reader.getUserMap().get(tweet.getUserid());
				if (uID != null) {
					hashtagMaps.put(uID.intValue(), normalizedIntResult);
					predictionValues.add(predTagIDs);
					if (hashtagMaps.size() % 100 == 0) {
						System.out.println(hashtagMaps.size() + " users done. Left ones: " + (tweets.size() - hashtagMaps.size()));
					}

					// Map the tweet's real hashtags to internal tag ids for evaluation.
					Set<String> tagIDs = new LinkedHashSet<String>();
					for (String t : tweet.getHashtags()) {
						Integer tID = reader.getTagMap().get(t.toLowerCase());
						if (tID != null) {
							tagIDs.add(tID.toString());
						}
					}
					realValues.add(tagIDs);
					tweetIDs.add(uID + "-" + reader.getResourceMap().get(tweet.getId()));
				}
			}
		}

		ResultSerializer.serializePredictions(hashtagMaps, "./data/results/" + sampleDir + "/" + solrCore + "_cbpredictions.ser");
		PredictionFileWriter.writeSimplePredictions(predictionValues, realValues, tweetIDs, sampleDir + "/" + solrCore + "_solrht_normalized");
		return hashtagMaps;
	}

	/**
	 * Predicts hashtags for every test user based on that user's own train
	 * tweets (most recent tweet, last n tweets, or last n hours of tweets).
	 *
	 * @param sampleDir            directory fragment used in the output path
	 * @param solrCore             base core name; "_train"/"_test" are appended
	 * @param solrUrl              base URL of the Solr server
	 * @param hours                if true, recentTweetThreshold is interpreted as hours
	 * @param recentTweetThreshold null = use only the single most recent tweet;
	 *                             otherwise the number of recent tweets or hours
	 * @return the filename suffix used for the prediction file
	 */
	public static String predictTrainSample(String sampleDir, String solrCore, String solrUrl, boolean hours, Integer recentTweetThreshold) {
		List<Set<String>> predictionValues = new ArrayList<Set<String>>();
		List<Set<String>> realValues = new ArrayList<Set<String>>();
		SolrConnector trainConnector = new SolrConnector(solrUrl, solrCore + "_train");
		SolrConnector testConnector = new SolrConnector(solrUrl, solrCore + "_test");

		String suffix = "";
		Map<String, Set<String>> userIDs = testConnector.getUserIDs();
		for (Map.Entry<String, Set<String>> user : userIDs.entrySet()) {
			if (user.getValue().size() > 0) {
				Map<String, Double> map = null;
				if (recentTweetThreshold == null) {
					// Query with the user's single most recent train tweet.
					String id = trainConnector.getMostRecentTweetOfUser(user.getKey());
					map = trainConnector.getTopHashtagsForTweetID(id, LIMIT);
					suffix = "solrht_train";
				} else {
					String text = null;
					if (hours) {
						text = trainConnector.getTweetTextOfLastHours(user.getKey(), recentTweetThreshold.intValue());
						suffix = "solrht_train_" + recentTweetThreshold.intValue() + "hours";
					} else {
						text = trainConnector.getTweetTextOfRecentTweets(user.getKey(), recentTweetThreshold.intValue());
						suffix = "solrht_train_" + recentTweetThreshold.intValue();
					}
					map = trainConnector.getTopHashtagsForTweetText(text, LIMIT);
				}
				predictionValues.add(map.keySet());
				realValues.add(user.getValue());
				if (predictionValues.size() % 100 == 0) {
					System.out.println(predictionValues.size() + " users done. Left ones: " + (userIDs.size() - predictionValues.size()));
				}
			}
		}

		PredictionFileWriter.writeSimplePredictions(predictionValues, realValues, null, sampleDir + "/" + solrCore + "_" + suffix);
		return suffix;
	}

	/**
	 * Writes predictions as text, one user per line:
	 * {@code userID|tagID:score;tagID:score;...}
	 * <p>
	 * The output is UTF-8 so it round-trips with
	 * {@link #readHashtagPrediction(String)}. Exceptions are caught and logged
	 * (best-effort writer, same as before).
	 *
	 * @param predictions map of user id to (tag id -> score)
	 * @param filePath    destination file path
	 */
	public static void printHashtagPrediction(Map<Integer, Map<Integer, Double>> predictions, String filePath) {
		// try-with-resources guarantees the writer is closed even on exception
		// (the old code leaked the FileWriter if a write failed).
		try (BufferedWriter bw = new BufferedWriter(
				new OutputStreamWriter(new FileOutputStream(new File(filePath)), StandardCharsets.UTF_8))) {
			for (Map.Entry<Integer, Map<Integer, Double>> predEntry : predictions.entrySet()) {
				bw.write(predEntry.getKey() + "|");
				int i = 1;
				for (Map.Entry<Integer, Double> mapEntry : predEntry.getValue().entrySet()) {
					bw.write(mapEntry.getKey() + ":" + mapEntry.getValue());
					// Separator between entries, but not after the last one.
					if (i++ < predEntry.getValue().size()) {
						bw.write(";");
					}
				}
				bw.write("\n");
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Deserializes a prediction map previously written by
	 * {@link ResultSerializer#serializePredictions}.
	 * <p>
	 * SECURITY NOTE: Java native deserialization of untrusted files is unsafe;
	 * only use this on files this application produced itself.
	 *
	 * @param filePath path of the ".ser" file
	 * @return the deserialized map, or {@code null} if reading failed
	 */
	public static Map<Integer, Map<Integer, Double>> deSerializeHashtagPrediction(String filePath) {
		Map<Integer, Map<Integer, Double>> predictions = null;
		// try-with-resources closes all three streams even when readObject throws
		// (the old code leaked them on the exception path).
		try (ObjectInputStream input = new ObjectInputStream(new BufferedInputStream(new FileInputStream(filePath)))) {
			// Cast is unavoidable with ObjectInputStream; the file format is
			// produced by ResultSerializer with exactly this type.
			@SuppressWarnings("unchecked")
			Map<Integer, Map<Integer, Double>> read = (Map<Integer, Map<Integer, Double>>) input.readObject();
			predictions = read;
		} catch (Exception e) {
			e.printStackTrace();
		}
		return predictions;
	}

	/**
	 * Reads a text prediction file in the format written by
	 * {@link #printHashtagPrediction}: {@code userID|tagID:score;tagID:score;...}
	 * Malformed tag entries are skipped with a log message; a user line without
	 * tags yields an empty map for that user.
	 *
	 * @param filePath path of the UTF-8 prediction file
	 * @return map of user id to (tag id -> score); empty on read failure
	 */
	public static Map<Integer, Map<Integer, Double>> readHashtagPrediction(String filePath) {
		Map<Integer, Map<Integer, Double>> hashtagMaps = new LinkedHashMap<Integer, Map<Integer, Double>>();
		// try-with-resources + StandardCharsets.UTF_8 (replaces the "UTF8" magic
		// string and the reader leak on the exception path).
		try (BufferedReader br = new BufferedReader(
				new InputStreamReader(new FileInputStream(new File(filePath)), StandardCharsets.UTF_8))) {
			String line = null;
			while ((line = br.readLine()) != null) {
				Map<Integer, Double> tagMap = new LinkedHashMap<Integer, Double>();
				String[] parts = line.split("\\|");
				int userID = Integer.parseInt(parts[0]);
				if (parts.length > 1) {
					String[] tags = parts[1].split(";");
					for (String t : tags) {
						String[] tParts = t.split(":");
						// "null" tag ids can occur when a tag had no internal mapping.
						if (tParts.length > 1 && !tParts[0].equals("null")) {
							try {
								tagMap.put(Integer.parseInt(tParts[0]), Double.parseDouble(tParts[1]));
							} catch (Exception e) {
								System.out.println("Parse Exception: " + tParts[0] + " " + tParts[1]);
							}
						}
					}
				}
				hashtagMaps.put(userID, tagMap);
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
		return hashtagMaps;
	}
}