package com.webpieces.hpack.impl;
import java.util.List;
import org.webpieces.data.api.DataWrapper;
import org.webpieces.data.api.DataWrapperGenerator;
import org.webpieces.data.api.DataWrapperGeneratorFactory;
import com.twitter.hpack.Decoder;
import com.twitter.hpack.Encoder;
import com.webpieces.hpack.api.HpackParser;
import com.webpieces.hpack.api.MarshalState;
import com.webpieces.hpack.api.UnmarshalState;
import com.webpieces.hpack.api.dto.Http2Headers;
import com.webpieces.hpack.api.dto.Http2Push;
import com.webpieces.http2parser.api.ConnectionException;
import com.webpieces.http2parser.api.Http2Memento;
import com.webpieces.http2parser.api.Http2Parser;
import com.webpieces.http2parser.api.ParseFailReason;
import com.webpieces.http2parser.api.dto.ContinuationFrame;
import com.webpieces.http2parser.api.dto.HeadersFrame;
import com.webpieces.http2parser.api.dto.PushPromiseFrame;
import com.webpieces.http2parser.api.dto.UnknownFrame;
import com.webpieces.http2parser.api.dto.lib.HasHeaderFragment;
import com.webpieces.http2parser.api.dto.lib.Http2Frame;
import com.webpieces.http2parser.api.dto.lib.Http2Header;
import com.webpieces.http2parser.api.dto.lib.Http2Msg;
import com.webpieces.http2parser.api.dto.lib.Http2Setting;
/**
 * {@link HpackParser} implementation that sits on top of a frame-level {@link Http2Parser}.
 *
 * <p>When unmarshalling, HEADERS / PUSH_PROMISE frames and their CONTINUATION frames are
 * collected until a frame with END_HEADERS arrives, their fragments are concatenated and
 * decoded, and a single {@link Http2Headers} or {@link Http2Push} is handed back in place of
 * the individual frames. When marshalling, {@link Http2Headers} and {@link Http2Push} are
 * translated to frames by {@link HeaderEncoding} and serialized by the underlying parser.
 * Everything else is delegated to the underlying parser.
 */
public class HpackParserImpl implements HpackParser {

    //private static final Logger log = LoggerFactory.getLogger(HpackParserImpl.class);
    private static final DataWrapperGenerator dataGen = DataWrapperGeneratorFactory.createDataWrapperGenerator();

    private final HeaderEncoding encoding = new HeaderEncoding();
    private final HeaderDecoding decoding = new HeaderDecoding();
    private final Http2Parser parser;
    private final boolean ignoreUnkownFrames;

    /**
     * @param parser the frame-level parser used for all raw frame parsing and serialization
     * @param ignoreUnkownFrames when true, parsed {@link UnknownFrame}s are silently dropped
     *        instead of being returned (or rejected) by {@link #unmarshal}
     */
    public HpackParserImpl(Http2Parser parser, boolean ignoreUnkownFrames) {
        this.parser = parser;
        this.ignoreUnkownFrames = ignoreUnkownFrames;
    }

    /**
     * Creates the state object to pass to {@link #unmarshal}. It bundles a new HPACK
     * {@link Decoder} built from {@code maxHeaderSize} and {@code maxHeaderTableSize} with the
     * low-level parse state the underlying parser creates for {@code localMaxFrameSize}.
     */
    @Override
    public UnmarshalState prepareToUnmarshal(int maxHeaderSize, int maxHeaderTableSize, long localMaxFrameSize) {
        Decoder decoder = new Decoder(maxHeaderSize, maxHeaderTableSize);
        Http2Memento result = parser.prepareToParse(localMaxFrameSize);
        return new UnmarshalStateImpl(result, decoding, decoder);
    }

    /**
     * Feeds {@code newData} to the underlying parser and post-processes every frame it reports.
     *
     * <p>The messages parsed by the previous call are cleared from the state first, so after
     * this call the state only holds messages completed by this call. Header fragments of a
     * header block that has not yet seen END_HEADERS stay buffered in the state.
     *
     * @param memento the state from {@link #prepareToUnmarshal}; it is cast to
     *        {@code UnmarshalStateImpl}
     * @param newData the newly received bytes
     * @return the same state instance that was passed in
     * @throws ConnectionException if header frames are interleaved with other frames, start
     *         with the wrong frame type, or switch streams mid header block
     * @throws IllegalStateException if a parsed frame is neither a header fragment, an ignored
     *         unknown frame, nor an {@link Http2Msg}
     */
    @Override
    public UnmarshalState unmarshal(UnmarshalState memento, DataWrapper newData) {
        UnmarshalStateImpl state = (UnmarshalStateImpl) memento;
        state.clearParsedFrames(); //reset any parsed frames

        Http2Memento result = parser.parse(state.getLowLevelState(), newData);

        List<Http2Frame> parsedFrames = result.getParsedFrames();
        for(Http2Frame frame : parsedFrames) {
            processFrame(state, frame);
        }

        return state;
    }

    /**
     * Routes one low-level frame: header fragments are buffered (and combined once END_HEADERS
     * is seen), any other frame is either dropped, passed through, or rejected.
     */
    private void processFrame(UnmarshalStateImpl state, Http2Frame frame) {
        List<HasHeaderFragment> pendingFragments = state.getHeadersToCombine();

        if(frame instanceof HasHeaderFragment) {
            HasHeaderFragment headerFrame = (HasHeaderFragment) frame;
            // must be added before validating: validateHeader inspects the pending list,
            // including this frame, to tell an opening frame from a continuation
            pendingFragments.add(headerFrame);
            validateHeader(state, headerFrame);
            if(headerFrame.isEndHeaders())
                combineAndSendHeadersToClient(state);
            return;
        }

        if(!pendingFragments.isEmpty()) {
            throw new ConnectionException(ParseFailReason.HEADERS_MIXED_WITH_FRAMES, frame.getStreamId(),
                    "Parser in the middle of accepting headers(spec "
                    + "doesn't allow frames between header fragments).  frame="+frame+" list="+pendingFragments);
        }

        if(frame instanceof UnknownFrame && ignoreUnkownFrames) {
            //do nothing
        } else if(frame instanceof Http2Msg) {
            state.getParsedFrames().add((Http2Msg) frame);
        } else {
            throw new IllegalStateException("bug forgot support for frame="+frame);
        }
    }

    /**
     * Validates the header fragment that was just appended to the pending list.
     *
     * <p>The first fragment of a header block must be a {@link HeadersFrame} or a
     * {@link PushPromiseFrame}; every later fragment must be a {@link ContinuationFrame} on the
     * stream of the header block.
     *
     * @throws ConnectionException if any of these rules is violated
     */
    private void validateHeader(UnmarshalStateImpl state, HasHeaderFragment lowLevelFrame) {
        List<HasHeaderFragment> list = state.getHeadersToCombine();
        HasHeaderFragment first = list.get(0);

        if(list.size() == 1) {
            if(!(first instanceof HeadersFrame) && !(first instanceof PushPromiseFrame))
                throw new ConnectionException(ParseFailReason.HEADERS_MIXED_WITH_FRAMES, lowLevelFrame.getStreamId(),
                        "First has header frame must be HeadersFrame or PushPromiseFrame first frame="+first);
            return;
        }

        if(!isOnHeaderBlockStream(first, lowLevelFrame.getStreamId())) {
            throw new ConnectionException(ParseFailReason.HEADERS_MIXED_WITH_FRAMES, lowLevelFrame.getStreamId(),
                    "Headers/continuations from two different streams per spec cannot be"
                    + " interleaved.  frames="+list);
        } else if(!(lowLevelFrame instanceof ContinuationFrame)) {
            throw new ConnectionException(ParseFailReason.HEADERS_MIXED_WITH_FRAMES, lowLevelFrame.getStreamId(),
                    "Must be continuation frame and wasn't.  frames="+list);
        }
    }

    /**
     * Tells whether a continuation with the given stream id belongs to the header block opened
     * by {@code first}.
     *
     * <p>Per RFC 7540 (sections 4.3 and 6.6) a CONTINUATION must be on the same stream as the
     * HEADERS or PUSH_PROMISE frame that opened the block, i.e. that frame's own stream id.
     * This class previously required the <em>promised</em> stream id after a PUSH_PROMISE,
     * which rejected spec-compliant peers. The promised id is still tolerated so input that
     * used to be accepted keeps working.
     * TODO(review): drop the promised-id tolerance once it is confirmed no encoder relies on it.
     */
    private boolean isOnHeaderBlockStream(HasHeaderFragment first, int continuationStreamId) {
        if(continuationStreamId == first.getStreamId())
            return true;
        return first instanceof PushPromiseFrame
                && continuationStreamId == ((PushPromiseFrame) first).getPromisedStreamId();
    }

    /**
     * Concatenates all buffered header fragments, decodes them, and appends the resulting
     * {@link Http2Headers} or {@link Http2Push} (depending on the type of the first buffered
     * frame) to the parsed messages. The fragment buffer is cleared afterwards.
     */
    private void combineAndSendHeadersToClient(UnmarshalStateImpl state) {
        List<HasHeaderFragment> hasHeaderFragmentList = state.getHeadersToCombine();
        // The first frame carries the stream id / priority / promised id of the whole block
        HasHeaderFragment firstFrame = hasHeaderFragmentList.get(0);

        DataWrapper allSerializedHeaders = dataGen.emptyWrapper();
        for (HasHeaderFragment iterFrame : hasHeaderFragmentList) {
            allSerializedHeaders = dataGen.chainDataWrappers(allSerializedHeaders, iterFrame.getHeaderFragment());
        }

        List<Http2Header> headers = decoding.decode(state.getDecoder(), allSerializedHeaders, firstFrame.getStreamId());

        if(firstFrame instanceof HeadersFrame) {
            HeadersFrame f = (HeadersFrame) firstFrame;
            Http2Headers fullHeaders = new Http2Headers(headers);
            fullHeaders.setStreamId(f.getStreamId());
            fullHeaders.setPriorityDetails(f.getPriorityDetails());
            fullHeaders.setEndOfStream(f.isEndOfStream());
            state.getParsedFrames().add(fullHeaders);
        } else if(firstFrame instanceof PushPromiseFrame) {
            PushPromiseFrame f = (PushPromiseFrame) firstFrame;
            Http2Push fullHeaders = new Http2Push(headers);
            fullHeaders.setStreamId(f.getStreamId());
            fullHeaders.setPromisedStreamId(f.getPromisedStreamId());
            state.getParsedFrames().add(fullHeaders);
        }
        // no other first-frame type can reach here: validateHeader rejects it up front

        hasHeaderFragmentList.clear();
    }

    /**
     * Creates the state object to pass to {@link #marshal}. It bundles a new HPACK
     * {@link Encoder} built from {@code maxHeaderTableSize} with {@code remoteMaxFrameSize}.
     */
    @Override
    public MarshalState prepareToMarshal(int maxHeaderTableSize, long remoteMaxFrameSize) {
        Encoder encoder = new Encoder(maxHeaderTableSize);
        return new MarshalStateImpl(encoding, encoder, remoteMaxFrameSize);
    }

    /**
     * Serializes one message. {@link Http2Headers} and {@link Http2Push} are first translated
     * into header frames; any other {@link Http2Frame} is serialized directly by the
     * underlying parser.
     *
     * @param memento the state from {@link #prepareToMarshal}; it is cast to
     *        {@code MarshalStateImpl}
     * @param msg the message to serialize
     * @return the serialized bytes
     * @throws IllegalArgumentException if either argument is null
     * @throws IllegalStateException if {@code msg} is of a type this method does not handle
     */
    @Override
    public DataWrapper marshal(MarshalState memento, Http2Msg msg) {
        if(memento == null || msg == null)
            throw new IllegalArgumentException("no parameters can be null");

        MarshalStateImpl state = (MarshalStateImpl) memento;
        if(msg instanceof Http2Headers) {
            Http2Headers h = (Http2Headers) msg;
            return createHeadersData(state, h);
        } else if(msg instanceof Http2Push) {
            Http2Push p = (Http2Push) msg;
            return createPushPromiseData(state, p);
        } else if(msg instanceof Http2Frame) {
            return parser.marshal((Http2Frame) msg);
        } else
            throw new IllegalStateException("bug, missing case for msg="+msg);
    }

    /** Translates a push promise into frames and serializes them back to back. */
    private DataWrapper createPushPromiseData(MarshalStateImpl state, Http2Push p) {
        long maxFrameSize = state.getMaxRemoteFrameSize();
        Encoder encoder = state.getEncoder();
        List<Http2Frame> headerFrames = encoding.translateToFrames(maxFrameSize, encoder, p);
        return translate(headerFrames);
    }

    /** Translates a headers message into frames and serializes them back to back. */
    private DataWrapper createHeadersData(MarshalStateImpl state, Http2Headers headers) {
        long maxFrameSize = state.getMaxRemoteFrameSize();
        Encoder encoder = state.getEncoder();
        List<Http2Frame> headerFrames = encoding.translateToFrames(maxFrameSize, encoder, headers);
        return translate(headerFrames);
    }

    /** Serializes each frame with the underlying parser and chains the results in order. */
    private DataWrapper translate(List<Http2Frame> headerFrames) {
        DataWrapper allData = DataWrapperGeneratorFactory.EMPTY;
        for(Http2Frame f : headerFrames) {
            DataWrapper frameData = parser.marshal(f);
            allData = dataGen.chainDataWrappers(allData, frameData);
        }
        return allData;
    }

    /** Delegates to the underlying parser. */
    @Override
    public List<Http2Setting> unmarshalSettingsPayload(String base64SettingsPayload) {
        return parser.unmarshalSettingsPayload(base64SettingsPayload);
    }

    /** Delegates to the underlying parser. */
    @Override
    public String marshalSettingsPayload(List<Http2Setting> settingsPayload) {
        return parser.marshalSettingsPayload(settingsPayload);
    }
}