package tv.dyndns.kishibe.qmaclone.server; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import libsvm.svm; import libsvm.svm_model; import libsvm.svm_node; import libsvm.svm_parameter; import libsvm.svm_problem; import tv.dyndns.kishibe.qmaclone.client.constant.Constant; import tv.dyndns.kishibe.qmaclone.client.game.ProblemGenre; import tv.dyndns.kishibe.qmaclone.client.packet.NewAndOldProblems; import tv.dyndns.kishibe.qmaclone.client.packet.PacketProblemMinimum; import tv.dyndns.kishibe.qmaclone.client.packet.PacketThemeQuery; import tv.dyndns.kishibe.qmaclone.server.database.Database; import tv.dyndns.kishibe.qmaclone.server.database.DatabaseException; import tv.dyndns.kishibe.qmaclone.server.util.IntArray; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.primitives.Doubles; import com.google.inject.Inject; public class ThemeModeProblemManager extends ProblemManager { private static final Logger logger = Logger.getLogger(ThemeModeProblemManager.class.toString()); private final Database database; private final ThreadPool threadPool; /** * テーマと検索クエリのマップ。 */ private volatile Map<String, IntArray> themeToProblems; /** * ジャンル毎のテーマリスト。一時間毎に更新される。 */ private volatile List<List<String>> themes; @Inject public ThemeModeProblemManager(Database database, ThreadPool threadPool) { super(database); this.database = database; this.threadPool = threadPool; } private void initializeIfNotInitialized() { if (themeToProblems == null) { synchronized (this) { if (themeToProblems == null) { try { updateProblemTablesForThemeMode(); } catch (DatabaseException e) { logger.log(Level.WARNING, "テーマモードの読み込みに失敗しました", e); } threadPool.addHourTask(new Runnable() { public void run() { try { updateProblemTablesForThemeMode(); } catch (DatabaseException e) { logger.log(Level.WARNING, "テーマモードの読み込みに失敗しました", e); } } }); } } } } private void updateProblemTablesForThemeMode() throws DatabaseException { List<List<String>> themes = Lists.newArrayList(); for (int i = 0; i < ProblemGenre.values().length; ++i) { themes.add(new ArrayList<String>()); } Map<String, IntArray> themeToProblems = database.getThemeToProblems(getThemeModeQueries()); double[] min = new double[ProblemGenre.values().length]; double[] max = new double[ProblemGenre.values().length]; svm_model model = createSvmModel(themeToProblems, min, max); for (Entry<String, IntArray> entry : themeToProblems.entrySet()) { String theme = entry.getKey(); IntArray problemIds = entry.getValue(); // 問題数が少なすぎる場合はロビーに表示しない if (problemIds.size() < Constant.MIN_NUMBER_OF_THEME_MODE_PROBLEMS) { continue; } svm_node[] x = createNode(problemIds); scale(x, min, max); double[] prob = new double[ProblemGenre.values().length]; double y = svm.svm_predict_probability(model, x, prob); ProblemGenre themeBySvm = ProblemGenre.values()[(int) Math.rint(y)]; themes.get(themeBySvm.getIndex()).add(theme); } for (List<String> list : themes) { Collections.sort(list); } this.themeToProblems = themeToProblems; this.themes = themes; } private Map<String, List<String>> getThemeModeQueries() throws DatabaseException { Map<String, List<String>> themetoQueries = Maps.newHashMap(); for (PacketThemeQuery query : database.getThemeModeQueries()) { if (!themetoQueries.containsKey(query.getTheme())) { themetoQueries.put(query.getTheme(), new ArrayList<String>()); } themetoQueries.get(query.getTheme()).add(query.query); } return themetoQueries; } public PacketProblemMinimum selectProblem(String theme, int difficultSelect, int classLevel, Set<Integer> selectedProblemIds) throws Exception { initializeIfNotInitialized(); // 難易度調整 switch (difficultSelect) { case Constant.DIFFICULT_SELECT_DIFFICULT: classLevel = Constant.CLASS_LEVEL_DIFFICULT; break; case Constant.DIFFICULT_SELECT_LITTLE_DIFFICULT: classLevel = Constant.CLASS_LEVEL_LITTLE_DIFFICULT; break; case Constant.DIFFICULT_SELECT_LITTLE_EASY: classLevel = Constant.CLASS_LEVEL_LITTLE_EASY; break; case Constant.DIFFICULT_SELECT_EASY: classLevel = Constant.CLASS_LEVEL_EASY; break; case Constant.DIFFICULT_SELECT_NORMAL: classLevel = Constant.CLASS_LEVEL_NORMAL; break; } // 問題の選択 PacketProblemMinimum data = null; IntArray problemIds = themeToProblems.get(theme); for (int findLoop = 0; findLoop < MAX_FIND_LOOP && data == null; ++findLoop) { data = selectProblemFromList(problemIds, selectedProblemIds, classLevel, NewAndOldProblems.Both, false, new HashSet<Integer>(), new HashSet<Integer>(), false); } if (data == null && difficultSelect != Constant.DIFFICULT_SELECT_NORMAL) { // 問題が選択されなかった場合は全難易度から選択しなおす data = selectProblem(theme, Constant.DIFFICULT_SELECT_NORMAL, classLevel, selectedProblemIds); } if (data == null) { throw new Exception("問題が見つかりませんでした " + theme + " " + difficultSelect + " " + classLevel + " " + selectedProblemIds); } selectedProblemIds.add(data.id); return data; } /** * テーマモード一覧を返す。 * * @return [ジャンル][テーマ] */ public List<List<String>> getThemes() { initializeIfNotInitialized(); return themes; } /** * テーマモード問題一覧を返す * * @return テーマモード問題一覧 */ public Map<String, IntArray> getThemesAndProblems() { initializeIfNotInitialized(); return themeToProblems; } private static final Map<String, ProblemGenre> LEARNING_DATA = ImmutableMap .<String, ProblemGenre> builder().put("数字で答えなさい", ProblemGenre.Random) .put("「唯一」", ProblemGenre.Random).put("???", ProblemGenre.Random) .put("ガンダム", ProblemGenre.Anige).put("ドラえもん", ProblemGenre.Anige) .put("コナミ", ProblemGenre.Anige).put("名言・名台詞", ProblemGenre.Anige) .put("アイドル", ProblemGenre.Geinou).put("映画", ProblemGenre.Geinou) .put("お笑い", ProblemGenre.Geinou).put("プロ野球", ProblemGenre.Sports) .put("ワールドカップ", ProblemGenre.Sports).put("格闘技", ProblemGenre.Sports) .put("ファッション", ProblemGenre.Zatsugaku).put("漢字", ProblemGenre.Zatsugaku) .put("新聞", ProblemGenre.Zatsugaku).put("神話", ProblemGenre.Gakumon) .put("日本史", ProblemGenre.Gakumon).put("科学", ProblemGenre.Gakumon).build(); private svm_model createSvmModel(Map<String, IntArray> themeToProblems, double[] min, double[] max) throws DatabaseException { Preconditions.checkArgument(min.length == ProblemGenre.values().length); Preconditions.checkArgument(max.length == ProblemGenre.values().length); for (int i = 0; i < min.length; ++i) { min[i] = Double.POSITIVE_INFINITY; max[i] = Double.NEGATIVE_INFINITY; } svm_parameter param = new svm_parameter(); // default values param.svm_type = svm_parameter.C_SVC; param.kernel_type = svm_parameter.RBF; param.degree = 3; param.gamma = 1.0 / ProblemGenre.values().length; // 1/num_features param.coef0 = 0; param.nu = 0.5; param.cache_size = 100; param.C = 1; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = new int[0]; param.weight = new double[0]; List<Double> y = Lists.newArrayList(); List<svm_node[]> x = Lists.newArrayList(); for (Entry<String, ProblemGenre> entry : LEARNING_DATA.entrySet()) { String theme = entry.getKey(); IntArray problemIds = themeToProblems.get(theme); if (problemIds == null) { continue; } y.add((double) entry.getValue().getIndex()); svm_node[] node = createNode(problemIds); for (int i = 0; i < ProblemGenre.values().length; ++i) { min[i] = Math.min(min[i], node[i].value); max[i] = Math.max(max[i], node[i].value); } x.add(node); } for (svm_node[] node : x) { scale(node, min, max); } svm_problem problem = new svm_problem(); problem.l = y.size(); problem.y = Doubles.toArray(y); problem.x = x.toArray(new svm_node[0][]); return svm.svm_train(problem, param); } private svm_node[] createNode(IntArray problemIds) throws DatabaseException { svm_node[] node = new svm_node[ProblemGenre.values().length]; for (int i = 0; i < node.length; ++i) { node[i] = new svm_node(); node[i].index = i + 1; } for (int problemId : problemIds.data()) { ++node[database.getProblemMinimum(problemId).genre.getIndex()].value; } for (svm_node e : node) { e.value /= problemIds.data().length; } return node; } private void scale(svm_node[] x, double[] min, double[] max) { Preconditions.checkArgument(x.length == min.length); Preconditions.checkArgument(x.length == max.length); for (int i = 0; i < x.length; ++i) { if (max[i] == min[i]) { x[i].value = 0.0; } else { x[i].value = (x[i].value - min[i]) / (max[i] - min[i]) * 2.0 - 1.0; } } } }