/**
 * Copyright (C) 2012 cogroo <cogroo@cogroo.org>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.cogroo.tools.featurizer;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.util.BaseToolFactory;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.ext.ExtensionLoader;
import opennlp.tools.util.model.ArtifactSerializer;
import opennlp.tools.util.model.UncloseableInputStream;

import org.cogroo.dictionary.FeatureDictionary;
import org.cogroo.tools.chunker2.TokenTag;

/**
 * Base factory for the featurizer resources: the feature dictionary, the
 * context generator and the sequence validator.
 * <p>
 * Subclasses supply the dictionary through {@link #loadFeatureDictionary()}.
 * The context generator flags are written to the model manifest and read back
 * from it when they were not given programmatically.
 */
public abstract class FeaturizerFactory extends BaseToolFactory {

  // Persisting the poisoned tags as a model artifact is currently disabled;
  // the related code is kept commented out below.
  // private static final String POISONED_TAGS_ENTRY_NAME =
  // "poisonedtags.serialized_set";

  /** Manifest key under which the context generator flags are stored. */
  private static final String CG_FLAGS_PROPERTY = "cgFlags";

  /** Flags used when neither the caller nor the manifest provides any. */
  private static final String DEFAULT_CG_FLAGS = "wshnc";

  protected FeatureDictionary featureDictionary;

  /**
   * Dictionary tags that are not outcomes of the model. Stays {@code null}
   * until {@link #validateFeatureDictionary()} computes it.
   */
  private Set<String> poisonedDictionaryTags = null;

  private String cgFlags;

  /**
   * Creates a {@link FeaturizerFactory} that provides the default
   * implementation of the resources.
   */
  public FeaturizerFactory() {
  }

  /**
   * Creates a {@link FeaturizerFactory}. Use this constructor to
   * programmatically create a factory.
   *
   * @param featureDictionary
   *          the dictionary to use, may be {@code null}
   * @param cgFlags
   *          the context generator flags, may be {@code null}
   */
  public FeaturizerFactory(FeatureDictionary featureDictionary, String cgFlags) {
    this.init(featureDictionary, cgFlags);
  }

  /**
   * Stores the programmatically supplied resources.
   *
   * @param featureDictionary
   *          the dictionary to use, may be {@code null}
   * @param cgFlags
   *          the context generator flags, may be {@code null}
   */
  protected void init(FeatureDictionary featureDictionary, String cgFlags) {
    this.featureDictionary = featureDictionary;
    this.cgFlags = cgFlags;
  }

  /**
   * Returns the serializers of the super class plus the
   * {@code "serialized_set"} serializer registered by {@link SetSerializer}.
   */
  @Override
  @SuppressWarnings("rawtypes")
  public Map<String, ArtifactSerializer> createArtifactSerializersMap() {
    Map<String, ArtifactSerializer> serializers = super
        .createArtifactSerializersMap();
    SetSerializer.register(serializers);
    return serializers;
  }

  /**
   * Returns the manifest entries of the super class plus the context
   * generator flags.
   */
  @Override
  public Map<String, String> createManifestEntries() {
    Map<String, String> manifestEntries = super.createManifestEntries();

    // getCGFlags() resolves the flags once (explicit value, manifest value or
    // default), so the resolved value is what gets written to the manifest.
    String flags = getCGFlags();
    if (flags != null) {
      manifestEntries.put(CG_FLAGS_PROPERTY, flags);
    }

    return manifestEntries;
  }

  /**
   * Returns the artifact map of the super class unchanged; registering the
   * poisoned tags set as an artifact is disabled.
   */
  @Override
  public Map<String, Object> createArtifactMap() {
    Map<String, Object> artifactMap = super.createArtifactMap();
    // add an empty set that will be populated later
    // artifactMap.put(POISONED_TAGS_ENTRY_NAME, new HashSet<String>());
    return artifactMap;
  }

  /**
   * Returns the context generator flags. When none were set, the value of the
   * {@code "cgFlags"} manifest property is used if an artifact provider is
   * available and defines it; otherwise the default {@code "wshnc"} is used.
   * The resolved value is cached in this factory.
   *
   * @return the flags, never {@code null}
   */
  public String getCGFlags() {
    if (this.cgFlags == null) {
      if (artifactProvider != null) {
        String prop = this.artifactProvider
            .getManifestProperty(CG_FLAGS_PROPERTY);
        if (prop != null) {
          this.cgFlags = prop;
        }
      }
      if (this.cgFlags == null) {
        this.cgFlags = DEFAULT_CG_FLAGS;
      }
    }
    return this.cgFlags;
  }

  /**
   * Returns the feature dictionary, loading it through
   * {@link #loadFeatureDictionary()} when none has been set yet.
   *
   * @return the dictionary; {@code null} if the loader returned {@code null}
   */
  public FeatureDictionary getFeatureDictionary() {
    if (this.featureDictionary == null) {
      this.featureDictionary = loadFeatureDictionary();
    }
    return this.featureDictionary;
  }

  /**
   * Loads the feature dictionary. Called by {@link #getFeatureDictionary()}
   * whenever no dictionary is set.
   *
   * @return the loaded dictionary
   */
  protected abstract FeatureDictionary loadFeatureDictionary();

  /**
   * Returns the dictionary tags that are unknown to the model.
   *
   * @return an unmodifiable set, or {@code null} if
   *         {@link #validateFeatureDictionary()} has not computed it
   */
  public Set<String> getDictionaryPoisonedTags() {
    // if (this.poisonedDictionaryTags == null && artifactProvider != null)
    // this.poisonedDictionaryTags = artifactProvider
    // .getArtifact(POISONED_TAGS_ENTRY_NAME);
    return this.poisonedDictionaryTags;
  }

  /**
   * Creates a {@link DefaultFeaturizerContextGenerator} configured with
   * {@link #getCGFlags()}.
   */
  public FeaturizerContextGenerator getFeaturizerContextGenerator() {
    return new DefaultFeaturizerContextGenerator(getCGFlags());
  }

  /**
   * Creates a {@link DefaultFeaturizerSequenceValidator} from the feature
   * dictionary and the poisoned tags (which may be {@code null}).
   */
  public SequenceValidator<TokenTag> getSequenceValidator() {
    return new DefaultFeaturizerSequenceValidator(getFeatureDictionary(),
        this.getDictionaryPoisonedTags());
  }

  /**
   * Computes the poisoned tags: the tags found in the feature dictionary that
   * are not outcomes of the featurizer model.
   * <p>
   * Nothing is computed, and the current poisoned tags are left untouched,
   * when the dictionary is {@code null} or not {@link Iterable}, when no
   * artifact provider is available, or when the provider has no featurizer
   * model.
   */
  // call this method to find the poisoned tags. Call only during training
  // because the poisoned tags are persisted...
  protected void validateFeatureDictionary() {
    FeatureDictionary dict = getFeatureDictionary();
    if (!(dict instanceof Iterable<?>)) {
      // also covers a null dictionary
      return;
    }
    if (this.artifactProvider == null) {
      return;
    }

    AbstractModel model = this.artifactProvider
        .getArtifact(FeaturizerModel.FEATURIZER_MODEL_ENTRY_NAME);
    if (model == null) {
      return;
    }

    // NOTE(review): assumes an Iterable dictionary yields WordTag entries;
    // the element type cannot be checked at runtime.
    @SuppressWarnings("unchecked")
    Iterable<WordTag> entries = (Iterable<WordTag>) dict;

    Set<String> poisoned = new HashSet<String>();
    for (WordTag entry : entries) {
      poisoned.add(entry.getPostag());
    }

    Set<String> modelTags = new HashSet<String>();
    for (int i = 0; i < model.getNumOutcomes(); i++) {
      modelTags.add(model.getOutcome(i));
    }

    // what remains are the dictionary tags the model does not know
    poisoned.removeAll(modelTags);

    this.poisonedDictionaryTags = Collections.unmodifiableSet(poisoned);

    // if (poisonedDictionaryTags.size() > 0) {
    // System.err
    // .println("WARNING: Feature dictionary contains tags which are unknown by the model! "
    // + this.poisonedDictionaryTags.toString());
    // }
  }

  /**
   * Validates the artifacts by recomputing the poisoned tags through
   * {@link #validateFeatureDictionary()}.
   *
   * @throws InvalidFormatException
   *           declared by the overridden method; not thrown by this
   *           implementation itself
   */
  @Override
  public void validateArtifactMap() throws InvalidFormatException {

    // Ensure that the tag dictionary is compatible with the model

    // Object poisonedTags = this.artifactProvider
    // .getArtifact(POISONED_TAGS_ENTRY_NAME);
    // if (poisonedTags != null && !(poisonedTags instanceof Set<?>)) {
    // throw new InvalidFormatException("Invalid serialized poisoned tags!");
    // }

    validateFeatureDictionary();
  }

  /**
   * Creates a factory and initializes it with the given resources.
   *
   * @param subclassName
   *          name of the {@link FeaturizerFactory} subclass to instantiate
   *          through {@link ExtensionLoader}; {@code null} selects
   *          {@link DefaultFeaturizerFactory}
   * @param posDictionary
   *          the dictionary handed to the factory, may be {@code null}
   * @param cgFlags
   *          the context generator flags, may be {@code null}
   * @return the initialized factory
   * @throws InvalidFormatException
   *           declared for callers; NOTE(review): failures of the extension
   *           loader are not translated here, confirm how they surface
   */
  public static FeaturizerFactory create(String subclassName,
      FeatureDictionary posDictionary, String cgFlags)
      throws InvalidFormatException {
    if (subclassName == null) {
      // will create the default factory
      return new DefaultFeaturizerFactory(posDictionary, cgFlags);
    }
    FeaturizerFactory theFactory = ExtensionLoader.instantiateExtension(
        FeaturizerFactory.class, subclassName);
    theFactory.init(posDictionary, cgFlags);
    return theFactory;
  }
}

/**
 * {@link ArtifactSerializer} that stores a {@code Set<String>} using Java
 * object serialization, registered under the name {@code "serialized_set"}.
 */
class SetSerializer implements ArtifactSerializer<Set<String>> {

  /**
   * Reads a set from the stream.
   *
   * @param in
   *          the stream to read from; it is not closed by this method
   * @return an unmodifiable view of the restored set
   * @throws InvalidFormatException
   *           if a class of the serialized object cannot be found, or the
   *           restored object is {@code null} or not a {@link Set}
   * @throws IOException
   *           if reading from the stream fails
   */
  @Override
  public Set<String> create(InputStream in) throws IOException,
      InvalidFormatException {
    // SECURITY: this is Java-native deserialization; only load artifacts that
    // come from a trusted source.
    // The input is wrapped so that the caller's stream is never closed from
    // here.
    ObjectInputStream oin = new ObjectInputStream(
        new UncloseableInputStream(in));

    Object restored;
    try {
      restored = oin.readObject();
    } catch (ClassNotFoundException e) {
      // previously this was only printed, and the null result then caused a
      // NullPointerException in Collections.unmodifiableSet
      throw new InvalidFormatException(
          "Could not restore serialized set", e);
    }

    if (!(restored instanceof Set<?>)) {
      String found = restored == null ? "null" : restored.getClass().getName();
      throw new InvalidFormatException(
          "Serialized artifact is not a Set: " + found);
    }

    @SuppressWarnings("unchecked")
    Set<String> set = (Set<String>) restored;
    return Collections.unmodifiableSet(set);
  }

  /**
   * Writes the set to the stream using Java object serialization.
   *
   * @param artifact
   *          the set to write
   * @param out
   *          the stream to write to; it is flushed but not closed
   * @throws IOException
   *           if writing to the stream fails
   */
  @Override
  public void serialize(Set<String> artifact, OutputStream out)
      throws IOException {
    ObjectOutputStream objOut = new ObjectOutputStream(out);
    objOut.writeObject(artifact);
    // push everything to the caller's stream without closing it
    objOut.flush();
  }

  /**
   * Registers a {@link SetSerializer} under the name
   * {@code "serialized_set"}.
   *
   * @param factories
   *          the serializer map to add the entry to
   */
  static void register(
      @SuppressWarnings("rawtypes") Map<String, ArtifactSerializer> factories) {
    factories.put("serialized_set", new SetSerializer());
  }
}