/**
* Copyright 2015, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.experiment;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.emory.clir.clearnlp.collection.pair.Pair;
import edu.emory.clir.clearnlp.util.CharUtils;
import edu.emory.clir.clearnlp.util.IOUtils;
import edu.emory.clir.clearnlp.util.MathUtils;
/**
 * Reads whitespace-delimited word embeddings (a word followed by {@code size} numeric values
 * per record) and discretizes every vector into a set of {@code "dimension:bucket"} strings.
 * <p>
 * Two discretization schemes are available: min-max scaling
 * ({@link #getWordEmbeddingsNorm(InputStream, int, int)}) and mean / standard-deviation scaling
 * ({@link #getWordEmbeddingsStdev(InputStream, int, int)}).
 * @since 3.0.3
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public class WordEmbeddingExtract
{
	/**
	 * Discretizes each embedding dimension by min-max scaling.
	 * For every word and every dimension {@code j}, the value is mapped to
	 * {@code round((value - min[j]) * norm / (max[j] - min[j]))}, where {@code min} and {@code max}
	 * are taken over all embeddings read from the stream; the feature {@code "j:bucket"} is added
	 * to the word's set. A dimension whose values are all identical always yields bucket {@code 0}.
	 * @param in the stream to read the embeddings from (not closed by this method).
	 * @param size the number of dimensions per embedding.
	 * @param norm the scaling factor; buckets range over {@code [0, norm]} for a positive value.
	 * @return a map from each word to its set of {@code "dimension:bucket"} features;
	 * empty if the stream contains no complete embedding.
	 * @throws Exception if reading fails or a value cannot be parsed as a number.
	 */
	public Map<String,Set<String>> getWordEmbeddingsNorm(InputStream in, int size, int norm) throws Exception
	{
		List<Pair<String,float[]>> embeddings = readEmbeddings(in, size);
		Map<String,Set<String>> map = new HashMap<>();

		// Nothing to scale against: getMaxMin needs at least one vector.
		if (embeddings.isEmpty()) return map;

		Pair<float[],float[]> maxMin = getMaxMin(embeddings, size);
		float[] max = maxMin.o1;
		float[] min = maxMin.o2;
		float range;
		int j, n;

		for (Pair<String,float[]> p : embeddings)
		{
			Set<String> set = new HashSet<>();
			float[] d = p.o2;
			map.put(p.o1, set);

			for (j=0; j<size; j++)
			{
				range = max[j] - min[j];
				// A zero range would produce 0/0 (NaN); make the resulting bucket explicit.
				n = (range == 0) ? 0 : Math.round((d[j] - min[j]) * norm / range);
				set.add(j+":"+n);
			}
		}

		return map;
	}

	/**
	 * Discretizes each embedding dimension by its mean and standard deviation.
	 * For every word and every dimension {@code j}, the value is mapped to
	 * {@code round((value - mean[j]) * norm / stdev[j])}; the feature {@code "j:bucket"} is added
	 * to the word's set. A dimension with zero deviation always yields bucket {@code 0}.
	 * @param in the stream to read the embeddings from (not closed by this method).
	 * @param size the number of dimensions per embedding.
	 * @param norm the scaling factor applied to the standardized value before rounding.
	 * @return a map from each word to its set of {@code "dimension:bucket"} features;
	 * empty if the stream contains no complete embedding.
	 * @throws Exception if reading fails or a value cannot be parsed as a number.
	 */
	public Map<String,Set<String>> getWordEmbeddingsStdev(InputStream in, int size, int norm) throws Exception
	{
		List<Pair<String,float[]>> embeddings = readEmbeddings(in, size);
		Map<String,Set<String>> map = new HashMap<>();

		if (embeddings.isEmpty()) return map;

		Pair<double[],double[]> meanStdev = getMeanStdev(embeddings, size);
		double[] mean  = meanStdev.o1;
		double[] stdev = meanStdev.o2;
		int j, n;

		for (Pair<String,float[]> p : embeddings)
		{
			Set<String> set = new HashSet<>();
			float[] d = p.o2;
			map.put(p.o1, set);

			for (j=0; j<size; j++)
			{
				// Zero deviation means every value equals the mean; avoid dividing by zero.
				n = (stdev[j] == 0) ? 0 : (int)Math.round((d[j] - mean[j]) * norm / stdev[j]);
				set.add(j+":"+n);
			}
		}

		return map;
	}

	/**
	 * Reads all complete embeddings from the stream.
	 * Tokens are separated by one or more whitespace characters (as decided by
	 * {@code CharUtils.isWhiteSpace}); each record is a word followed by {@code size} values.
	 * A trailing record that ends before all {@code size} values are present is discarded.
	 * @param in the stream to read from (not closed by this method).
	 * @param size the number of dimensions per embedding.
	 * @return the (word, vector) pairs in the order they appear in the stream.
	 * @throws Exception if reading fails or a value cannot be parsed as a number.
	 */
	public List<Pair<String,float[]>> readEmbeddings(InputStream in, int size) throws Exception
	{
		List<Pair<String,float[]>> embeddings = new ArrayList<>();
		BufferedReader reader = IOUtils.createBufferedReader(in);
		Pair<String,float[]> p;

		while ((p = readEmbedding(reader, size)) != null)
			embeddings.add(p);

		return embeddings;
	}

	/**
	 * Reads one record: a word token followed by {@code size} numeric tokens.
	 * @return the (word, vector) pair, or {@code null} if the input ends before the record is complete.
	 */
	private Pair<String,float[]> readEmbedding(BufferedReader reader, int size) throws Exception
	{
		float[] vector = new float[size];
		String token, word = null;

		// Index -1 is the word itself; indices 0..size-1 are the vector values.
		for (int i=-1; i<size; i++)
		{
			token = readToken(reader);
			if (token == null) return null;

			if (i < 0) word = token;
			else vector[i] = (float)Double.parseDouble(token);
		}

		return new Pair<String,float[]>(word, vector);
	}

	/**
	 * Reads the next whitespace-delimited token of any length.
	 * Leading whitespace is skipped so that consecutive separators (e.g., a trailing space
	 * followed by a newline, or CRLF line endings) do not produce empty tokens, and a token
	 * terminated by the end of the input is still returned.
	 * @return the token, or {@code null} if the input ends before any token character is read.
	 */
	private String readToken(BufferedReader reader) throws Exception
	{
		StringBuilder build = new StringBuilder();
		int ch;

		while ((ch = reader.read()) != -1)
		{
			if (CharUtils.isWhiteSpace((char)ch))
			{
				// Whitespace ends a token only once the token has started.
				if (build.length() > 0) break;
			}
			else
				build.append((char)ch);
		}

		return (build.length() > 0) ? build.toString().trim() : null;
	}

	/**
	 * Computes the per-dimension maximum and minimum over all embeddings.
	 * @param embeddings a non-empty list of embeddings.
	 * @return a pair of (max, min) arrays, each of length {@code size}.
	 */
	private Pair<float[],float[]> getMaxMin(List<Pair<String,float[]>> embeddings, int size)
	{
		// Seed both arrays with the first vector, then fold in the rest.
		float[] max = Arrays.copyOf(embeddings.get(0).o2, size);
		float[] min = Arrays.copyOf(max, size);
		int i, j, len = embeddings.size();
		float[] d;

		for (i=1; i<len; i++)
		{
			d = embeddings.get(i).o2;

			for (j=0; j<size; j++)
			{
				max[j] = Math.max(max[j], d[j]);
				min[j] = Math.min(min[j], d[j]);
			}
		}

		return new Pair<>(max, min);
	}

	/**
	 * Computes the per-dimension mean and population standard deviation over all embeddings.
	 * @return a pair of (mean, stdev) arrays, each of length {@code size};
	 * both are all zeros if {@code embeddings} is empty.
	 */
	private Pair<double[],double[]> getMeanStdev(List<Pair<String,float[]>> embeddings, int size)
	{
		int i, j, len = embeddings.size();
		double[] mean  = new double[size];
		double[] stdev = new double[size];
		float[] d;

		if (len == 0) return new Pair<>(mean, stdev);

		for (i=0; i<len; i++)
		{
			d = embeddings.get(i).o2;

			for (j=0; j<size; j++)
				mean[j] += d[j];
		}

		// Each dimension is summed over 'len' vectors, so 'len' (not len*size) is the divisor.
		for (j=0; j<size; j++)
			mean[j] /= len;

		for (i=0; i<len; i++)
		{
			d = embeddings.get(i).o2;

			for (j=0; j<size; j++)
				stdev[j] += MathUtils.sq(d[j] - mean[j]);
		}

		for (j=0; j<size; j++)
			stdev[j] = Math.sqrt(stdev[j] / len);

		return new Pair<>(mean, stdev);
	}

	/**
	 * Discretizes the embeddings in a file using the mean / standard-deviation scheme with a
	 * scaling factor of 5, and serializes the resulting map to {@code <filename>.xz.5}.
	 * @param args {@code args[0]}: the embedding file; {@code args[1]}: the number of dimensions.
	 */
	static public void main(String[] args)
	{
		if (args.length < 2)
		{
			System.err.println("Usage: WordEmbeddingExtract <embedding file> <vector size>");
			return;
		}

		String filename = args[0];
		int size = Integer.parseInt(args[1]);
		int norm = 5;

		try
		{
			WordEmbeddingExtract emb = new WordEmbeddingExtract();
			Map<String,Set<String>> tree;

			// Close the input as soon as it has been consumed, even on failure.
			try (InputStream in = new FileInputStream(filename))
			{
				tree = emb.getWordEmbeddingsStdev(in, size, norm);
			}

			// Ensure the output stream is flushed and closed even if serialization fails.
			try (ObjectOutputStream out = IOUtils.createObjectXZBufferedOutputStream(filename+".xz"+"."+norm))
			{
				out.writeObject(tree);
			}
		}
		catch (Exception e) {e.printStackTrace();}
	}
}