package com.sissi.io.read.sax;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;

import javax.xml.parsers.ParserConfigurationException;
import javax.xml.parsers.SAXParser;
import javax.xml.parsers.SAXParserFactory;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.xml.sax.SAXException;

import com.sissi.commons.Trace;
import com.sissi.io.read.Counter;
import com.sissi.io.read.Mapping;
import com.sissi.io.read.Reader;
import com.sissi.resource.ResourceCounter;

/**
 * {@link Reader} that parses an input stream with SAX on a caller-supplied
 * {@link Executor}. Each call to {@link #future(InputStream)} creates a new
 * {@code SAXFuture} and a new {@link SAXParser}, and submits one parse task to
 * the executor.
 *
 * @author Kim.shen 2013-10-16
 */
public class SAXReader implements Reader {

    /** Counter used when no XML length limit is configured; its recount() does nothing. */
    private final NoneCounter nothing = new NoneCounter();

    private final Log log = LogFactory.getLog(this.getClass());

    /** Key under which running parse tasks are reported to the resource counter. */
    private final String resource = ParseRunnable.class.getSimpleName();

    private final ResourceCounter resourceCounter;

    private final SAXParserFactory factory;

    private final Executor executor;

    private final Mapping mapping;

    /** Maximum number of bytes that may be read between two recount() calls; 0 disables the check. */
    private final long limitXml;

    /** Limit handed to every SAXFuture created by this reader. */
    private final byte limitQueue;

    /**
     * Creates a reader that uses a default {@code XMLMapping}.
     *
     * @param limitXml XML length limit in bytes; 0 disables the limit
     * @param limitQueue Future length limit, passed to each {@code SAXFuture}
     * @param executor executor that runs the parse tasks
     * @param resourceCounter counter incremented while a parse task runs
     * @throws Exception if the parser factory cannot be configured
     */
    public SAXReader(long limitXml, byte limitQueue, Executor executor, ResourceCounter resourceCounter) throws Exception {
        this(limitXml, limitQueue, new XMLMapping(), executor, resourceCounter);
    }

    /**
     * Creates a reader with an explicit mapping.
     *
     * @param limitXml XML length limit in bytes; 0 disables the limit
     * @param limitQueue Future length limit, passed to each {@code SAXFuture}
     * @param mapping mapping handed to every {@code SAXHandler}
     * @param executor executor that runs the parse tasks
     * @param resourceCounter counter incremented while a parse task runs
     * @throws Exception if the parser factory cannot be configured
     */
    public SAXReader(long limitXml, byte limitQueue, Mapping mapping, Executor executor, ResourceCounter resourceCounter) throws Exception {
        super();
        this.mapping = mapping;
        this.executor = executor;
        this.limitXml = limitXml;
        this.limitQueue = limitQueue;
        this.resourceCounter = resourceCounter;
        this.factory = SAXParserFactory.newInstance();
        this.factory.setNamespaceAware(true);
        // Parser-specific feature: keep parsing after a fatal error instead of aborting.
        this.factory.setFeature("http://apache.org/xml/features/continue-after-fatal-error", true);
        // NOTE(review): no DTD / external-entity restrictions are configured on this factory.
        // If the streams parsed here carry untrusted input, consider disabling DOCTYPE
        // declarations and external entities (XXE) - confirm against the supported parsers.
    }

    /**
     * Starts parsing the stream asynchronously and returns the future associated
     * with that parse.
     *
     * @param stream stream to parse
     * @return the {@code SAXFuture} handed to the handler of this parse
     * @throws IOException declared by the interface; failures while setting up the
     *         parse are currently logged and rethrown as {@link RuntimeException}
     */
    public Future<Object> future(InputStream stream) throws IOException {
        try {
            SAXFuture future = new SAXFuture(this.limitQueue);
            // A limit of 0 means "no length check".
            return this.limitXml == 0 ? this.noneffectiveCount(stream, future) : this.effectiveCount(stream, future);
        } catch (Exception e) {
            this.log.error(e);
            Trace.trace(log, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * XML DDoS check: parses the stream through a byte-counting wrapper that fails
     * once more than {@code limitXml} bytes have been read since the last recount.
     *
     * @param stream stream to parse
     * @param future future handed to the handler
     * @return the given future
     * @throws ParserConfigurationException if no parser can be created
     * @throws SAXException if no parser can be created
     */
    private SAXFuture effectiveCount(InputStream stream, SAXFuture future) throws ParserConfigurationException, SAXException {
        SAXSecurityInputStream input = new SAXSecurityInputStream(stream, this.limitXml);
        // The wrapper doubles as the handler's Counter so the handler can reset the byte count.
        this.executor.execute(new ParseRunnable(input, this.factory.newSAXParser(), new SAXHandler(this.mapping, future, input)));
        return future;
    }

    /**
     * Parses the stream directly, without any length check.
     *
     * @param stream stream to parse
     * @param future future handed to the handler
     * @return the given future
     * @throws ParserConfigurationException if no parser can be created
     * @throws SAXException if no parser can be created
     */
    private SAXFuture noneffectiveCount(InputStream stream, SAXFuture future) throws ParserConfigurationException, SAXException {
        this.executor.execute(new ParseRunnable(stream, this.factory.newSAXParser(), new SAXHandler(this.mapping, future, this.nothing)));
        return future;
    }

    /**
     * Task that runs one SAX parse and tracks itself in the resource counter for
     * the duration of the parse.
     */
    private class ParseRunnable implements Runnable {

        private final SAXParser parser;

        private final SAXHandler handler;

        private final InputStream stream;

        public ParseRunnable(InputStream stream, SAXParser parser, SAXHandler handler) {
            super();
            this.parser = parser;
            this.stream = stream;
            this.handler = handler;
        }

        public void run() {
            try {
                SAXReader.this.resourceCounter.increment(SAXReader.this.resource);
                this.parser.parse(this.stream, this.handler);
            } catch (Exception e) {
                // Parse failures are only logged; nothing is rethrown from the task.
                SAXReader.this.log.debug(e.toString());
                Trace.trace(SAXReader.this.log, e);
            } finally {
                SAXReader.this.resourceCounter.decrement(SAXReader.this.resource);
            }
        }
    }

    /** {@link Counter} that counts nothing; used when the length limit is disabled. */
    private static class NoneCounter implements Counter {

        @Override
        public Counter recount() {
            return this;
        }
    }

    /**
     * Stream wrapper that counts the bytes actually read from the underlying stream
     * and fails with an {@link IOException} once the count exceeds the limit. The
     * count is reset to zero by {@link #recount()}.
     */
    private class SAXSecurityInputStream extends FilterInputStream implements Counter {

        private final AtomicLong counter = new AtomicLong();

        private final long limit;

        private SAXSecurityInputStream(InputStream proxy, long limit) {
            super(proxy);
            this.limit = limit;
        }

        /**
         * Adds {@code incr} consumed bytes to the counter and enforces the limit.
         *
         * @param incr number of bytes that were just read
         * @return this stream
         * @throws IOException if the running total exceeds the limit
         */
        private SAXSecurityInputStream incr(long incr) throws IOException {
            // Use the value returned by addAndGet so the message reports the total that tripped the limit.
            long total = this.counter.addAndGet(incr);
            if (total > this.limit) {
                IOException exception = new IOException("Leak: " + total + " / " + this.limit);
                SAXReader.this.log.error(exception.getMessage());
                throw exception;
            }
            return this;
        }

        @Override
        public int read() throws IOException {
            int value = super.read();
            // -1 signals end of stream: no byte was consumed, so nothing is counted.
            if (value != -1) {
                this.incr(1);
            }
            return value;
        }

        @Override
        public int read(byte b[]) throws IOException {
            return this.read(b, 0, b.length);
        }

        @Override
        public int read(byte b[], int off, int len) throws IOException {
            int read = super.read(b, off, len);
            // Count what was really delivered, not the size of the caller's buffer:
            // a large buffer on a short input must not trip the limit.
            if (read > 0) {
                this.incr(read);
            }
            return read;
        }

        /**
         * Resets the byte count to zero.
         *
         * @return this stream
         */
        @Override
        public SAXSecurityInputStream recount() {
            this.counter.set(0);
            return this;
        }
    }
}