package charts.graphics; import java.awt.BasicStroke; import java.awt.Color; import java.awt.Font; import java.awt.Graphics2D; import java.awt.Paint; import java.awt.Shape; import java.awt.Stroke; import java.awt.geom.Line2D; import java.awt.geom.Rectangle2D; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import org.apache.commons.lang3.StringUtils; import org.jfree.chart.axis.AxisSpace; import org.jfree.chart.axis.AxisState; import org.jfree.chart.axis.CategoryAxis; import org.jfree.chart.axis.CategoryLabelPositions; import org.jfree.chart.plot.Plot; import org.jfree.chart.plot.PlotRenderingInfo; import org.jfree.data.category.CategoryDataset; import org.jfree.data.category.DefaultCategoryDataset; import org.jfree.text.TextBlock; import org.jfree.text.TextBlockAnchor; import org.jfree.text.TextLine; import org.jfree.ui.RectangleEdge; import com.google.common.collect.Lists; import com.google.common.collect.Maps; public class AutoSubCategoryAxis extends CategoryAxis { public static final String DEFAULT_SEPARATOR = "_||_"; private static final Double DEFAULT_CATEGORY_MARGIN = 0.0; public static enum Border { NONE, BETWEEN, ALL } public static class CategoryLabelConfig { final private CategoryLabelPositions position; final private Font font; final private Paint fontPaint; final private double marginTop; final private double marginBottom; final private Border border; final private Paint borderPaint; final private Stroke borderStroke; final private Map<String, String> abbrDict; final private boolean wrap; public CategoryLabelConfig(CategoryLabelPositions position, Font font, Paint fontPaint, double marginTop, double marginBottom, Border border, Paint borderPaint, Stroke borderStroke, Map<String, String> abbrDict, boolean wrap) { super(); this.position = position; this.font = font; this.fontPaint = fontPaint; this.marginTop = marginTop; this.marginBottom = marginBottom; this.border = border; this.borderPaint = borderPaint; this.borderStroke = borderStroke; this.abbrDict = abbrDict; this.wrap = wrap; } public CategoryLabelConfig(CategoryLabelPositions position, Font font, Paint fontPaint, double marginTop, double marginBottom) { this(position, font, fontPaint, marginTop, marginBottom, Border.NONE, null, null, null, false); } public CategoryLabelConfig(Font font, double marginTop, double marginBottom) { this(CategoryLabelPositions.STANDARD, font, Color.black, marginTop, marginBottom, Border.NONE, null, null, null, false); } public CategoryLabelConfig(Font font, double marginTop, double marginBottom, Border border, Paint borderPaint) { this(CategoryLabelPositions.STANDARD, font, Color.black, marginTop, marginBottom, border, borderPaint, new BasicStroke(1), null, false); } public CategoryLabelConfig(Font font, double marginTop, double marginBottom, Border border, Paint borderPaint, Map<String, String> abbrDict, boolean wrap) { this(CategoryLabelPositions.STANDARD, font, Color.black, marginTop, marginBottom, border, borderPaint, new BasicStroke(1), abbrDict, wrap); } public CategoryLabelPositions getPosition() { return position; } public Font getFont() { return font; } public double getMarginTop() { return marginTop; } public double getMarginBottom() { return marginBottom; } public Paint getFontPaint() { return fontPaint; } public Border getBorder() { return border; } public Paint getBorderPaint() { return borderPaint; } public Stroke getBorderStroke() { return borderStroke; } public boolean wrap() { return wrap; } } private class SubCategory { private String name; private SubCategory parent; private List<SubCategory> subcategories = Lists.newArrayList(); public SubCategory() {} public SubCategory(String name, SubCategory parent) { this.name = name; this.parent = parent; } public void addSubCategory(String path) { String[] split = StringUtils.split(path, AutoSubCategoryAxis.this.separator); String name = split[0]; SubCategory sub = getSubCategory(name); if(sub == null) { sub = new SubCategory(name, this); subcategories.add(sub); } if(split.length > 1) { sub.addSubCategory(StringUtils.join(split, AutoSubCategoryAxis.this.separator, 1, split.length)); } } public SubCategory getSubCategory(String name) { for(SubCategory sub : subcategories) { if(sub.name.equals(name)) { return sub; } } return null; } public String getName() { return this.name; } public List<SubCategory> getLeafs() { return getLeafs(new ArrayList<SubCategory>()); } private List<SubCategory> getLeafs(List<SubCategory> leafs) { if(subcategories.isEmpty()) { leafs.add(this); } else { for(SubCategory c : subcategories) { c.getLeafs(leafs); } } return leafs; } public int getDepth() { return getDepth(0); } private int getDepth(int d) { if(parent == null) { return d; } else { return parent.getDepth(d+1); } } public boolean isSibling(SubCategory other) { if(parent == null) { return false; } for(SubCategory s : parent.subcategories) { if(s == other) { return true; } } return false; } public int findSiblingDepth(SubCategory other) { if(isSibling(other)) { return getDepth(); } else if(parent == null) { throw new RuntimeException("root"); } else { return parent.findSiblingDepth(other.parent); } } public int depth() { return depth(0); } private int depth(int depth) { if(subcategories.isEmpty()) { return depth; } else { int d = depth; for(SubCategory c : subcategories) { d = Math.max(d, c.depth(depth+1)); } return d; } } public List<SubCategory> getAllOfDepth(int depth) { return getAllOfDepth(new ArrayList<SubCategory>(), depth, 0); } private List<SubCategory> getAllOfDepth(List<SubCategory> l, int depth, int current) { if(depth == current) { l.add(this); } else { for(SubCategory c : subcategories) { c.getAllOfDepth(l, depth, current+1); } } return l; } public String getPath() { List<String> path = Lists.newArrayList(); SubCategory current = this; while(current.name != null) { path.add(current.name); current = current.parent; } return StringUtils.join(Lists.reverse(path), AutoSubCategoryAxis.this.separator); } public String toString(String indent) { String s = indent + (name!=null?name:"root") + "\n"; for(SubCategory sub : subcategories) { s += sub.toString(indent+" "); } return s; } @Override public String toString() { return toString(""); } } private CategoryDataset dataset; private SubCategory root = new SubCategory(); private Map<Integer, Double> categoryMargins = Maps.newHashMap(); private Map<Integer, CategoryLabelConfig> labelConfigs = Maps.newHashMap(); private String separator = DEFAULT_SEPARATOR; public AutoSubCategoryAxis(CategoryDataset dataset) { this.dataset = dataset; for(int col=0;col<dataset.getColumnCount();col++) { root.addSubCategory(dataset.getColumnKey(col).toString()); } } public String getSeparator() { return separator; } public void setSeparator(String separator) { this.separator = separator; } public void setItemMargin(int depth, double margin) { categoryMargins.put(depth, margin); } public double getItemMargin(int depth) { Double m = categoryMargins.get(depth); if(m == null) { return DEFAULT_CATEGORY_MARGIN; } else { return m; } } public void setCategoryLabelConfig(int depth, CategoryLabelConfig config) { labelConfigs.put(depth, config); } public CategoryLabelConfig getCategoryLabelConfig(int depth) { CategoryLabelConfig config = labelConfigs.get(depth); if(config == null) { config = new CategoryLabelConfig(new Font(Font.DIALOG, Font.PLAIN, 10), 0,0); } return config; } @Override public double getCategoryMargin() { return 0.0; } /** * Returns the starting coordinate for the specified category. * * @param category the category. * @param categoryCount the number of categories. * @param area the data area. * @param edge the axis location. * * @return The coordinate. * * @see #getCategoryMiddle(int, int, Rectangle2D, RectangleEdge) * @see #getCategoryEnd(int, int, Rectangle2D, RectangleEdge) */ @Override public double getCategoryStart(int category, int categoryCount, Rectangle2D area, RectangleEdge edge) { double result = 0.0; if ((edge == RectangleEdge.TOP) || (edge == RectangleEdge.BOTTOM)) { result = area.getX() + area.getWidth() * getLowerMargin(); } else if ((edge == RectangleEdge.LEFT) || (edge == RectangleEdge.RIGHT)) { result = area.getMinY() + area.getHeight() * getLowerMargin(); } double categorySize = calculateCategorySize(categoryCount, area, edge); double margin = 0.0; for(int i = 0; i < category;i++) { margin += categorySize * getMarginAsPercentage(i); } return result += category * categorySize + margin; } public double calculateCategorySize(Rectangle2D area) { return calculateCategorySize(dataset.getColumnCount(), area, RectangleEdge.BOTTOM); } @Override protected double calculateCategorySize(int categoryCount, Rectangle2D area, RectangleEdge edge) { double result = 0.0; double available = 0.0; if ((edge == RectangleEdge.TOP) || (edge == RectangleEdge.BOTTOM)) { available = area.getWidth(); } else if ((edge == RectangleEdge.LEFT) || (edge == RectangleEdge.RIGHT)) { available = area.getHeight(); } List<SubCategory> leafs = root.getLeafs(); if(leafs.size() != categoryCount) { throw new RuntimeException(String.format("leaf size %s does not match category count %s", leafs.size(), categoryCount)); } double margin = 0.0; for(int i = 0;i<leafs.size()-1;i++) { margin += getMarginAsPercentage(i); } result = available * (1 - getLowerMargin() - getUpperMargin()) / (leafs.size() + margin); return result; } /** * Returns the Margin as a percentage of the category size between category1 and category1+1 */ private double getMarginAsPercentage(int category1) { List<SubCategory> leafs = root.getLeafs(); SubCategory c1 = leafs.get(category1); SubCategory c2 = leafs.get(category1+1); return getItemMargin(c1.findSiblingDepth(c2)-1); } @Override protected AxisState drawCategoryLabels(Graphics2D g2, Rectangle2D plotArea, Rectangle2D dataArea, RectangleEdge edge, AxisState state, PlotRenderingInfo plotState) { return drawCategoryLabels(g2, plotArea, dataArea, edge, state, plotState, true); } private String[] labelRows(String name, CategoryLabelConfig config) { // TODO make number of rows configurable return config.wrap?StringUtils.split(name, " ", 2):new String[] {name}; } private TextBlock makeLabelBlock(String[] rows, CategoryLabelConfig config) { Font f = config.getFont(); Paint p = config.getFontPaint(); TextBlock block = new TextBlock(); for(String s : rows) { if(StringUtils.isNotBlank(s)) { TextLine line = new TextLine(s, f, p); block.addLine(line); } } return block; } private String[] shortenLabel(String[] rows, CategoryLabelConfig config) { if(config.abbrDict == null) { return rows; } String[] result = new String[rows.length]; Map<String, String> abbrDict = config.abbrDict; for(int i=0;i<rows.length;i++) { if(StringUtils.isNotBlank(rows[i])) { result[i] = rows[i]; for(Map.Entry<String, String> me : abbrDict.entrySet()) { result[i] = StringUtils.replace(result[i], me.getKey(), me.getValue()); } } } return result; } private String[] shortestLables(String[] rows) { String[] result = new String[rows.length]; for(int i=0;i<rows.length;i++) { if(StringUtils.isNotBlank(rows[i])) { result[i] = rows[i].substring(0, 1); } } return result; } private List<TextBlock> labelAlternatives(String name, CategoryLabelConfig config) { if(config.wrap) { String[] rows = labelRows(name, config); TextBlock b0 = makeLabelBlock(rows, config); TextBlock b1 = makeLabelBlock(shortenLabel(rows, config), config); TextBlock b2 = makeLabelBlock(shortestLables(rows), config); return Lists.newArrayList(b0, b1, b2); } else { return Collections.singletonList(makeLabelBlock(labelRows(name, config), config)); } } private AxisState drawCategoryLabels(Graphics2D g2, Rectangle2D plotArea, Rectangle2D dataArea, RectangleEdge edge, AxisState state, PlotRenderingInfo plotState, boolean render) { if(!isTickLabelsVisible()) { return state; } double y0 = state.getCursor(); for(int depth = root.depth();depth>0;depth--) { List<SubCategory> l = root.getAllOfDepth(depth); CategoryLabelConfig config = getCategoryLabelConfig(depth-1); double yb0 = state.getCursor(); state.cursorDown(config.getMarginTop()); float maxHeight = 0.0f; for(SubCategory c : l) { double angle = 0; if(config.position.equals(CategoryLabelPositions.UP_90)) { angle = -Math.PI/2; } int categoryStartIndex = getCategoryStartIndex(c); int categoryEndIndex = getCategoryEndIndex(c); double start = getCategoryStart(categoryStartIndex, dataset.getColumnCount(), dataArea, edge); double end = getCategoryEnd(categoryEndIndex, dataset.getColumnCount(), dataArea, edge); double middle = start + (end-start)/2; double maxWidth = end-start; g2.setFont(config.getFont()); g2.setPaint(config.getFontPaint()); List<TextBlock> blocks = labelAlternatives(c.getName(), config); Iterator<TextBlock> iter = blocks.iterator(); while(iter.hasNext()) { TextBlock block = iter.next(); Shape b = block.calculateBounds(g2, (float)middle, (float)state.getCursor(), TextBlockAnchor.CENTER, (float)middle, (float)state.getCursor(), angle); Rectangle2D bounds = b.getBounds2D(); if(bounds.getWidth() > maxWidth && config.wrap() && iter.hasNext()) { continue; } float height = (float)(bounds.getMaxY() - bounds.getMinY()); maxHeight = Math.max(height, maxHeight); if(render) { block.draw(g2, (float)middle, (float)state.getCursor()+height/2, TextBlockAnchor.CENTER, (float)middle, (float)state.getCursor()+height/2, angle); } break; } } state.cursorDown(maxHeight); state.cursorDown(config.getMarginBottom()); if(render) { for(SubCategory c : l) { drawBorder(g2, y0, yb0, state.getCursor(), config, dataArea, c); } } } return state; } private void drawBorder(Graphics2D g2, double y0, double yb0, double y1, CategoryLabelConfig config, Rectangle2D area, SubCategory c) { if(config.getBorder() == Border.NONE) { return; } int start = getCategoryStartIndex(c); int end = getCategoryEndIndex(c); g2.setStroke(config.getBorderStroke()); g2.setPaint(config.getBorderPaint()); if(config.getBorder() == Border.BETWEEN) { if(end+1 >= dataset.getColumnCount()) { return; } double m1 = getCategoryMiddle(end, dataset.getColumnCount(), area, RectangleEdge.BOTTOM); double m2 = getCategoryMiddle(end+1, dataset.getColumnCount(), area, RectangleEdge.BOTTOM); double x = (m1+m2)/2; line(g2,x,y0,x,y1); } else if(config.getBorder() == Border.ALL) { double x0; double x1; if(start == end) { x0 = getCategoryStart(start, dataset.getColumnCount(), area, RectangleEdge.BOTTOM); x1 = getCategoryEnd(start, dataset.getColumnCount(), area, RectangleEdge.BOTTOM); } else { x0 = getMiddleX(area, start-1, start); x1 = getMiddleX(area, end, end+1); } line(g2, x0, yb0, x1, yb0); line(g2, x0, y1, x1, y1); line(g2, x0, y0, x0, y1); line(g2, x1, y0, x1, y1); } } private void line(Graphics2D g2, double x1, double y1, double x2, double y2) { g2.draw(new Line2D.Double(x1, y1, x2, y2)); } private double getMiddleX(Rectangle2D area, int index1, int index2) { double x; if(index1 <= 0) { x = area.getMinX(); } else if(index2 >= dataset.getColumnCount()){ x = area.getMaxX(); } else { double m1 = getCategoryMiddle(index1, dataset.getColumnCount(), area, RectangleEdge.BOTTOM); double m2 = getCategoryMiddle(index2, dataset.getColumnCount(), area, RectangleEdge.BOTTOM); x = (m1+m2)/2; } return x; } private int getCategoryStartIndex(SubCategory c) { List<SubCategory> leafs = root.getLeafs(); SubCategory l1 = c.getLeafs().get(0); return leafs.indexOf(l1); } private int getCategoryEndIndex(SubCategory c) { List<SubCategory> leafs = root.getLeafs(); List<SubCategory> tmp = c.getLeafs(); SubCategory lx = tmp.get(tmp.size()-1); return leafs.indexOf(lx); } @Override public AxisSpace reserveSpace(Graphics2D g2, Plot plot, Rectangle2D plotArea, RectangleEdge edge, AxisSpace space) { Rectangle2D labelEnclosure = getLabelEnclosure(g2, edge); space.add(labelEnclosure.getHeight() + this.getCategoryLabelPositionOffset(), edge); AxisState axisState = drawCategoryLabels(g2, plotArea, space.shrink(plotArea, null), edge, new AxisState(), null, false); space.add(axisState.getCursor(), RectangleEdge.BOTTOM); return space; } // rearrange dataset so that the order of leafs // matches the order of columns public CategoryDataset getFixedDataset() { DefaultCategoryDataset fixed = new DefaultCategoryDataset(); List<SubCategory> leafs = root.getLeafs(); for(SubCategory leaf : leafs) { String path = leaf.getPath(); int col = dataset.getColumnIndex(path); if(col == -1) { throw new RuntimeException(String.format("column '%s' not found in dataset", path)); } for(int row=0;row<dataset.getRowCount();row++) { Comparable<?> rc = dataset.getRowKey(row); if(rc != null) { fixed.addValue(dataset.getValue(row, col), rc, path); } } } return fixed; } }