package com.antbrains.crf;
import java.awt.BorderLayout;
import java.awt.Font;
import java.awt.GridBagLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextArea;
import javax.swing.JTextField;
import com.mxgraph.model.mxCell;
import com.mxgraph.swing.mxGraphComponent;
import com.mxgraph.view.mxGraph;
/**
 * Swing window that visualises how a CRF model tags a sentence.
 *
 * <p>The window shows a text field for the sentence, a graph with one column of vertices per
 * token (one vertex per label, captioned with its score) framed by a "Start" and an "End"
 * vertex, and a text area for results. The best tag sequence is pre-connected with edges
 * labelled by their transition weights. The user can:
 * <ul>
 * <li>click a vertex to list the features and weights behind its score,</li>
 * <li>left-click an edge to validate it and label it with its transition weight,</li>
 * <li>right-click an edge to delete it,</li>
 * <li>press the second button to sum the score of the path currently leaving "Start".</li>
 * </ul>
 *
 * <p>Token vertices carry the id {@code "v<item>,<label>"}; the boundary vertices use
 * {@code "vStart"} and {@code "vEnd"}.
 */
@SuppressWarnings("serial")
public class CRFExplainer extends JFrame {
  private static final int WIDTH = 800;
  private static final int HEIGHT = 600;

  /** Vertex ids of the two boundary nodes. */
  private static final String START_ID = "vStart";
  private static final String END_ID = "vEnd";

  /** Shown when the user-built path cannot be scored ("invalid path!"). */
  private static final String INVALID_PATH_MESSAGE = "错误的路径!";
  /** Shown when a clicked edge does not connect two adjacent columns. */
  private static final String INVALID_EDGE_MESSAGE = "Error Edge";

  /** Layout metrics of the lattice, in pixels. */
  private static final int NODE_WIDTH = 60;
  private static final int NODE_HEIGHT = 25;
  private static final int NODE_X_INTERVAL = 20;
  private static final int NODE_Y_INTERVAL = 40;

  private final JTextField txtInput = new JTextField(60);
  /** Button caption: "segment, then analyse". */
  private final JButton btnExplain = new JButton("分词然后分析");
  /** Button caption: "compute current path". */
  private final JButton btnCalcPath = new JButton("计算当前路径");
  private final JTextArea txtResult = new JTextArea(15, 80);
  /** Result of the last successful analysis; {@code null} until the first one completes. */
  private Explanation explanation;
  private final mxGraph graph = new mxGraph();
  /** The one component that is added to the frame; it must never be replaced after that. */
  private final mxGraphComponent graphComponent = new mxGraphComponent(graph);
  /** The "Start" vertex of the current drawing, or {@code null} when nothing is drawn. */
  private mxCell startCell = null;
  private final CrfModel model;
  private final TagConvertor tc = new BESB1B2MTagConvertor();

  private final String nodeStyle = "ROUNDED;fillColor=white;fontColor=blue";
  /** Currently not applied to any edge; kept for parity with {@link #nodeStyle}. */
  private final String edgeStyle = "ROUNDED;fillColor=white;fontColor=blue";
  private final DecimalFormat df = new DecimalFormat("##.0");

  /**
   * Builds the window, loads the model and analyses a default sentence.
   *
   * @param crfModelPath path handed to {@code SgdCrf.loadModel}
   * @throws Exception if the model cannot be loaded
   */
  public CRFExplainer(String crfModelPath) throws Exception {
    super("CRFs分词解析器");
    this.model = SgdCrf.loadModel(crfModelPath);
    txtInput.setText("今天的天气不错");

    btnExplain.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent actionevent) {
        clickButton();
      }
    });
    btnCalcPath.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent actionevent) {
        calcCurrentPath();
      }
    });

    this.setLayout(new BorderLayout());
    JPanel cmdFrame = new JPanel();
    cmdFrame.add(txtInput);
    cmdFrame.add(btnExplain);
    cmdFrame.add(btnCalcPath);

    JPanel resultFrame = new JPanel();
    Font font = txtResult.getFont();
    // Font's constructor takes (name, style, size); keep the current style, enlarge to 14pt.
    txtResult.setFont(new Font(font.getName(), font.getStyle(), 14));
    resultFrame.add(new JScrollPane(txtResult));

    this.draw();
    graphComponent.getGraphControl().addMouseListener(new MouseAdapter() {
      @Override
      public void mouseReleased(MouseEvent e) {
        handleGraphClick(e);
      }
    });

    this.getContentPane().add(cmdFrame, BorderLayout.NORTH);
    this.getContentPane().add(graphComponent, BorderLayout.CENTER);
    this.getContentPane().add(resultFrame, BorderLayout.SOUTH);
    this.clickButton();
  }

  /**
   * Analyses the sentence in the input field, prints the resulting tokens separated by spaces
   * and redraws the graph. Blank input is ignored. Any failure is reported in the result
   * area instead of being dropped, and leaves the previous analysis and drawing in place.
   */
  private void clickButton() {
    String sentence = txtInput.getText();
    if (sentence == null || sentence.trim().isEmpty()) {
      return;
    }
    try {
      Explanation result = SgdCrf.explain(sentence, model);
      String[] tagTexts = SgdCrf.tagId2Text(result.bestTagIds, model);
      List<String> tokens = this.tc.tags2TokenList(tagTexts, sentence);
      StringBuilder sb = new StringBuilder();
      for (String token : tokens) {
        sb.append(token).append(" ");
      }
      txtResult.setText(sb.toString().trim());
      // Publish the new analysis only once it is complete, so the field and the drawn graph
      // always describe the same sentence.
      explanation = result;
      draw();
    } catch (Exception e) {
      e.printStackTrace();
      txtResult.setText("分析失败: " + e);
    }
  }

  /** Dispatches a mouse release on the graph to the edge or vertex handler. */
  private void handleGraphClick(MouseEvent e) {
    Object hit = graphComponent.getCellAt(e.getX(), e.getY());
    if (!(hit instanceof mxCell)) {
      return;
    }
    mxCell cell = (mxCell) hit;
    if (cell.getId() != null) {
      if (cell.isEdge()) {
        handleEdgeClick(cell, e.getButton() == MouseEvent.BUTTON3);
      } else if (cell.isVertex()) {
        handleVertexClick(cell.getId());
      }
    }
    graph.repaint();
  }

  /**
   * Handles a click on an edge. A right click deletes the edge. Any other click replaces the
   * edge with a fresh one labelled with its transition weight, or deletes it and reports
   * {@value #INVALID_EDGE_MESSAGE} when it does not join two adjacent columns.
   */
  private void handleEdgeClick(mxCell edge, boolean rightClick) {
    if (rightClick) {
      graph.removeCells(new Object[] { edge });
      return;
    }
    mxCell source = (mxCell) edge.getSource();
    mxCell target = (mxCell) edge.getTarget();
    // An edge dropped on empty canvas has a null terminal; treat it like any invalid edge.
    Double weight = edgeWeight(source == null ? null : source.getId(),
        target == null ? null : target.getId());
    graph.removeCells(new Object[] { edge });
    if (weight == null) {
      txtResult.setText(INVALID_EDGE_MESSAGE);
      return;
    }
    graph.insertEdge(graph.getDefaultParent(), null, df.format(weight.doubleValue()), source,
        target);
  }

  /**
   * Handles a click on a vertex: clears the result area for the boundary vertices, otherwise
   * lists the total score followed by one "feature weight" line per feature.
   */
  private void handleVertexClick(String id) {
    if (START_ID.equals(id) || END_ID.equals(id)) {
      txtResult.setText("");
      return;
    }
    int[] index = getIndex(id);
    if (index == null || explanation == null || index[0] >= explanation.details.length
        || index[1] >= explanation.details[index[0]].length) {
      return;
    }
    FeatureWeightScore fws = explanation.details[index[0]][index[1]];
    StringBuilder sb = new StringBuilder();
    sb.append("total score: ").append(fws.score).append("\n\n");
    for (int k = 0; k < fws.features.size(); k++) {
      String feature = fws.features.get(k);
      double weight = fws.weights.get(k);
      sb.append(feature).append(" ").append(df.format(weight)).append("\n");
    }
    txtResult.setText(sb.toString());
  }

  /**
   * Looks up the transition weight of an edge between two vertex ids.
   *
   * <p>Valid edges are: Start to a vertex of the first column, a vertex of the last column to
   * End, and a vertex to a vertex of the immediately following column.
   *
   * @return the weight, or {@code null} when the edge is not valid or nothing is analysed yet
   */
  private Double edgeWeight(String sourceId, String targetId) {
    if (explanation == null || sourceId == null || targetId == null) {
      return null;
    }
    boolean fromStart = START_ID.equals(sourceId);
    boolean toEnd = END_ID.equals(targetId);
    if (fromStart && toEnd) {
      return null;
    }
    if (fromStart) {
      int[] to = getIndex(targetId);
      if (to == null || to[0] != 0) {
        return null;
      }
      double weight = explanation.bosTransitionWeights[to[1]];
      return Double.valueOf(weight);
    }
    if (toEnd) {
      int[] from = getIndex(sourceId);
      if (from == null || from[0] != explanation.details.length - 1) {
        return null;
      }
      double weight = explanation.eosTransitionWeights[from[1]];
      return Double.valueOf(weight);
    }
    int[] from = getIndex(sourceId);
    int[] to = getIndex(targetId);
    if (from == null || to == null || from[0] + 1 != to[0]) {
      return null;
    }
    // transitionWeights is a flattened matrix indexed as [fromLabel * labelCount + toLabel].
    int labelNum = explanation.details[0].length;
    double weight = explanation.transitionWeights[from[1] * labelNum + to[1]];
    return Double.valueOf(weight);
  }

  /**
   * Clears the graph and, when an analysis is available, draws the full lattice with the best
   * tag sequence connected from "Start" to "End".
   */
  private void draw() {
    graph.selectAll();
    graph.removeCells();
    startCell = null;
    if (this.explanation == null || explanation.details.length == 0) {
      return;
    }
    int charNum = explanation.details.length;
    int labelNum = explanation.details[0].length;
    int totalWidth = charNum * NODE_WIDTH + (charNum - 1) * NODE_X_INTERVAL + NODE_WIDTH * 2
        + NODE_X_INTERVAL * 2;
    int totalHeight = labelNum * NODE_HEIGHT + (labelNum - 1) * NODE_Y_INTERVAL;
    int xStart = Math.max(0, (WIDTH - totalWidth) / 2);
    int yStart = 30;
    int boundaryY = yStart + totalHeight / 2;

    Object parent = graph.getDefaultParent();
    graph.getModel().beginUpdate();
    try {
      Object start = graph.insertVertex(parent, START_ID, "Start", xStart, boundaryY, NODE_WIDTH,
          NODE_HEIGHT, nodeStyle);
      startCell = (mxCell) start;
      int xoffset = xStart + NODE_X_INTERVAL + NODE_WIDTH + 2 * NODE_WIDTH;
      Object lastNode = start;
      int lastId = 0;
      for (int i = 0; i < charNum; i++) {
        String curToken = explanation.tokens.get(i);
        int bestTag = explanation.bestTagIds[i];
        int yoffset = yStart;
        for (int j = 0; j < labelNum; j++) {
          String label = explanation.labelTexts[j];
          String scoreStr = df.format(explanation.details[i][j].score);
          Object node = graph.insertVertex(parent, "v" + i + "," + j, curToken + "/" + label + "("
              + scoreStr + ")", xoffset, yoffset, NODE_WIDTH, NODE_HEIGHT, nodeStyle);
          yoffset += NODE_Y_INTERVAL + NODE_HEIGHT;
          if (j == bestTag) {
            // Extend the best path: first column hangs off Start, later ones off the
            // previous column's best vertex.
            String edgeLabel;
            if (i == 0) {
              edgeLabel = df.format(explanation.bosTransitionWeights[j]);
            } else {
              edgeLabel = df.format(explanation.transitionWeights[lastId * labelNum + j]);
            }
            graph.insertEdge(parent, null, edgeLabel, lastNode, node);
            lastNode = node;
            lastId = j;
          }
        }
        xoffset += NODE_X_INTERVAL + NODE_WIDTH;
      }
      xoffset += NODE_WIDTH * 2;
      Object end = graph.insertVertex(parent, END_ID, "End", xoffset, boundaryY, NODE_WIDTH,
          NODE_HEIGHT, nodeStyle);
      graph.insertEdge(parent, null, df.format(explanation.eosTransitionWeights[lastId]),
          lastNode, end);
    } finally {
      graph.getModel().endUpdate();
    }
    // The component shown in the frame is bound to this graph; it is deliberately not
    // re-created here, because a new instance would never be displayed.
  }

  /**
   * Parses a token vertex id of the form {@code "v<item>,<label>"}.
   *
   * @return {@code {item, label}}, or {@code null} when the id does not have that form
   */
  private int[] getIndex(String s) {
    if (s == null || s.length() < 4 || s.charAt(0) != 'v') {
      return null;
    }
    int comma = s.indexOf(',');
    if (comma < 2) {
      return null;
    }
    try {
      int item = Integer.parseInt(s.substring(1, comma));
      int label = Integer.parseInt(s.substring(comma + 1));
      if (item < 0 || label < 0) {
        return null;
      }
      return new int[] { item, label };
    } catch (NumberFormatException ex) {
      return null;
    }
  }

  /**
   * Follows the single outgoing edge of a cell.
   *
   * @return the target of that edge, or {@code null} when the cell is {@code null} or does not
   *         have exactly one outgoing edge
   */
  private mxCell getNextCell(mxCell curCell) {
    if (curCell == null) {
      return null;
    }
    mxCell outgoing = null;
    int outgoingCount = 0;
    for (int i = 0; i < curCell.getEdgeCount(); i++) {
      mxCell edge = (mxCell) curCell.getEdgeAt(i);
      if (edge.getSource() == curCell) {
        outgoing = edge;
        outgoingCount++;
      }
    }
    return outgoingCount == 1 ? (mxCell) outgoing.getTarget() : null;
  }

  /**
   * Walks the path that leaves "Start" through single outgoing edges and shows the sum of all
   * transition weights and vertex scores on it. Shows {@value #INVALID_PATH_MESSAGE} when the
   * path forks, breaks off, skips a column or reaches "End" before the last column.
   */
  private void calcCurrentPath() {
    if (this.explanation == null || this.startCell == null) {
      return;
    }
    mxCell curCell = this.startCell;
    double score = 0;
    while (true) {
      mxCell next = this.getNextCell(curCell);
      Double transition = next == null ? null : edgeWeight(curCell.getId(), next.getId());
      if (transition == null) {
        txtResult.setText(INVALID_PATH_MESSAGE);
        return;
      }
      score += transition.doubleValue();
      if (END_ID.equals(next.getId())) {
        break;
      }
      // edgeWeight only accepts a non-End target whose id parses, so index is non-null here.
      int[] index = this.getIndex(next.getId());
      if (index[0] >= explanation.details.length
          || index[1] >= explanation.details[index[0]].length) {
        txtResult.setText(INVALID_PATH_MESSAGE);
        return;
      }
      score += (explanation.details[index[0]][index[1]].score);
      curCell = next;
    }
    txtResult.setText("当前路径得分: " + score);
  }

  /**
   * Opens the explainer window for the model given on the command line.
   *
   * @param args exactly one element: the model path
   * @throws Exception if the model cannot be loaded
   */
  public static void main(String[] args) throws Exception {
    if (args.length != 1) {
      System.out.println("Usage: CRFExplainer <model_path>");
      System.exit(-1);
    }
    CRFExplainer frame = new CRFExplainer(args[0]);
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    frame.setSize(WIDTH, HEIGHT);
    // NOTE(review): pack() resizes the window to its components' preferred sizes, so the
    // setSize() call above does not determine the final size.
    frame.pack();
    frame.setVisible(true);
  }
}