Skip to content

Commit

Permalink
Added Steven's SimpleGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerben authored and Gerben committed Mar 20, 2015
1 parent d03033b commit 15c2003
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public static void main(String[] args) {
ds.create();

System.out.println(ds.getRDFData().getInstances());

List<EvaluationFunction> evalFuncs = new ArrayList<EvaluationFunction>();
evalFuncs.add(new Accuracy());
evalFuncs.add(new F1());
Expand Down Expand Up @@ -137,7 +137,7 @@ public static void main(String[] args) {
RDFData data = ds.getRDFData();
List<Double> target = ds.getTarget();

computeGraphStatistics(tripleStore, ds, inference, depths);
//computeGraphStatistics(tripleStore, ds, inference, depths);


/*
Expand All @@ -160,7 +160,7 @@ public static void main(String[] args) {
}
//*/

///* The baseline experiment, BoW (or BoL if you prefer)
/* The baseline experiment, BoW (or BoL if you prefer)
for (boolean inf : inference) {
resTable.newRow("Baseline BoL: " + inf);
for (int d : depths) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.learners.SparseVector;
import org.data2semantics.mustard.simplegraph.SimpleGraph;
import org.data2semantics.mustard.weisfeilerlehman.StringLabel;
import org.data2semantics.mustard.weisfeilerlehman.WLUtils;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanApproxDTGraphIterator;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanApproxIterator;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanDTGraphIterator;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanIterator;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanSimpleGraphIterator;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
Expand All @@ -37,14 +39,14 @@
*/
public class DTGraphWLSubTreeGeoProbKernel implements GraphKernel<SingleDTGraph>, FeatureVectorKernel<SingleDTGraph>, ComputationTimeTracker, FeatureInspector {

private Map<DTNode<StringLabel,StringLabel>, Map<DTNode<StringLabel,StringLabel>, Integer>> instanceVertexIndexMap;
private Map<DTNode<StringLabel,StringLabel>, Map<DTLink<StringLabel,StringLabel>, Integer>> instanceEdgeIndexMap;
private Map<SimpleGraph<StringLabel,StringLabel>.Node, Map<SimpleGraph<StringLabel,StringLabel>.Node, Integer>> instanceVertexIndexMap;
private Map<SimpleGraph<StringLabel,StringLabel>.Node, Map<SimpleGraph<StringLabel,StringLabel>.Link, Integer>> instanceEdgeIndexMap;

private Map<DTNode<StringLabel,StringLabel>, Map<DTNode<StringLabel,StringLabel>, Boolean>> instanceVertexIgnoreMap;
private Map<DTNode<StringLabel,StringLabel>, Map<DTLink<StringLabel,StringLabel>, Boolean>> instanceEdgeIgnoreMap;
//private Map<DTNode<StringLabel,StringLabel>, Map<DTNode<StringLabel,StringLabel>, Boolean>> instanceVertexIgnoreMap;
//private Map<DTNode<StringLabel,StringLabel>, Map<DTLink<StringLabel,StringLabel>, Boolean>> instanceEdgeIgnoreMap;

private DTGraph<StringLabel,StringLabel> rdfGraph;
private List<DTNode<StringLabel,StringLabel>> instanceVertices;
private SimpleGraph<StringLabel,StringLabel> rdfGraph;
private List<SimpleGraph<StringLabel,StringLabel>.Node> instanceVertices;

private int depth;
private int iterations;
Expand All @@ -54,7 +56,7 @@ public class DTGraphWLSubTreeGeoProbKernel implements GraphKernel<SingleDTGraph>

private long compTime;
private Map<String,String> dict;

private double p;
private double mean;
private Map<Integer, Double> probs;
Expand Down Expand Up @@ -89,51 +91,55 @@ public SparseVector[] computeFeatureVectors(SingleDTGraph data) {
for (int i = 0; i < featureVectors.length; i++) {
featureVectors[i] = new SparseVector();
}

probs = new HashMap<Integer, Double>();
p = 1.0 / (mean + 1.0); // mean is (1-p)/p


System.out.println("Depth threshold info");

for (int i = 0; i < 20; i++) {
System.out.print(i + ": " + getCumProb(i) + ", ");
}
System.out.println("");

long tic2 = System.currentTimeMillis();

init(data.getGraph(), data.getInstances());

System.out.println("DTGraph init (ms): " + (System.currentTimeMillis() - tic2));

WeisfeilerLehmanIterator<DTGraph<StringLabel,StringLabel>> wl = new WeisfeilerLehmanDTGraphIterator(reverse, true);

List<DTGraph<StringLabel,StringLabel>> gList = new ArrayList<DTGraph<StringLabel,StringLabel>>();
WeisfeilerLehmanIterator<SimpleGraph<StringLabel,StringLabel>> wl = new WeisfeilerLehmanSimpleGraphIterator(reverse, true);

List<SimpleGraph<StringLabel,StringLabel>> gList = new ArrayList<SimpleGraph<StringLabel,StringLabel>>();
gList.add(rdfGraph);

long tic = System.currentTimeMillis();

wl.wlInitialize(gList);
compTime = System.currentTimeMillis() - tic;

double weight = 1.0;
if (iterationWeighting) {
weight = Math.sqrt(1.0 / (iterations + 1));
}


computeFVs(rdfGraph, instanceVertices, weight, featureVectors, wl.getLabelDict().size()-1, 0);

for (int i = 0; i < iterations; i++) {
if (iterationWeighting) {
weight = Math.sqrt((2.0 + i) / (iterations + 1));
}

tic = System.currentTimeMillis();
wl.wlIterate(gList);
compTime += System.currentTimeMillis() - tic;

computeFVs(rdfGraph, instanceVertices, weight, featureVectors, wl.getLabelDict().size()-1, i + 1);
}

compTime = System.currentTimeMillis() - tic;
//compTime = System.currentTimeMillis() - tic;

System.out.println("DTGraph WL (ms): " + compTime);

// Set the reverse label dict, to reverse engineer the features
Expand Down Expand Up @@ -162,33 +168,33 @@ public double[][] compute(SingleDTGraph data) {


private void init(DTGraph<String,String> graph, List<DTNode<String,String>> instances) {
DTNode<StringLabel,StringLabel> startV;
SimpleGraph<StringLabel,StringLabel>.Node startV;
List<DTNode<String,String>> frontV, newFrontV;
Map<DTNode<StringLabel,StringLabel>, Integer> vertexIndexMap;
Map<DTLink<StringLabel,StringLabel>, Integer> edgeIndexMap;
Map<DTNode<StringLabel,StringLabel>, Boolean> vertexIgnoreMap;
Map<DTLink<StringLabel,StringLabel>, Boolean> edgeIgnoreMap;
Map<DTNode<String,String>, DTNode<StringLabel,StringLabel>> vOldNewMap = new HashMap<DTNode<String,String>,DTNode<StringLabel,StringLabel>>();
Map<DTLink<String,String>, DTLink<StringLabel,StringLabel>> eOldNewMap = new HashMap<DTLink<String,String>,DTLink<StringLabel,StringLabel>>();

rdfGraph = new LightDTGraph<StringLabel,StringLabel>();
instanceVertices = new ArrayList<DTNode<StringLabel,StringLabel>>();
instanceVertexIndexMap = new HashMap<DTNode<StringLabel,StringLabel>, Map<DTNode<StringLabel,StringLabel>, Integer>>();
instanceEdgeIndexMap = new HashMap<DTNode<StringLabel,StringLabel>, Map<DTLink<StringLabel,StringLabel>, Integer>>();
instanceVertexIgnoreMap = new HashMap<DTNode<StringLabel,StringLabel>, Map<DTNode<StringLabel,StringLabel>, Boolean>>();
instanceEdgeIgnoreMap = new HashMap<DTNode<StringLabel,StringLabel>, Map<DTLink<StringLabel,StringLabel>, Boolean>>();
Map<SimpleGraph<StringLabel,StringLabel>.Node, Integer> vertexIndexMap;
Map<SimpleGraph<StringLabel,StringLabel>.Link, Integer> edgeIndexMap;
//Map<DTNode<StringLabel,StringLabel>, Boolean> vertexIgnoreMap;
//Map<DTLink<StringLabel,StringLabel>, Boolean> edgeIgnoreMap;
Map<DTNode<String,String>, SimpleGraph<StringLabel,StringLabel>.Node> vOldNewMap = new HashMap<DTNode<String,String>,SimpleGraph<StringLabel,StringLabel>.Node>();
Map<DTLink<String,String>, SimpleGraph<StringLabel,StringLabel>.Link> eOldNewMap = new HashMap<DTLink<String,String>,SimpleGraph<StringLabel,StringLabel>.Link>();

rdfGraph = new SimpleGraph<StringLabel,StringLabel>();
instanceVertices = new ArrayList<SimpleGraph<StringLabel,StringLabel>.Node>();
instanceVertexIndexMap = new HashMap<SimpleGraph<StringLabel,StringLabel>.Node, Map<SimpleGraph<StringLabel,StringLabel>.Node, Integer>>();
instanceEdgeIndexMap = new HashMap<SimpleGraph<StringLabel,StringLabel>.Node, Map<SimpleGraph<StringLabel,StringLabel>.Link, Integer>>();
//instanceVertexIgnoreMap = new HashMap<DTNode<StringLabel,StringLabel>, Map<DTNode<StringLabel,StringLabel>, Boolean>>();
//instanceEdgeIgnoreMap = new HashMap<DTNode<StringLabel,StringLabel>, Map<DTLink<StringLabel,StringLabel>, Boolean>>();

for (DTNode<String,String> oldStartV : instances) {
vertexIndexMap = new HashMap<DTNode<StringLabel,StringLabel>, Integer>();
edgeIndexMap = new HashMap<DTLink<StringLabel,StringLabel>, Integer>();
vertexIgnoreMap = new HashMap<DTNode<StringLabel,StringLabel>, Boolean>();
edgeIgnoreMap = new HashMap<DTLink<StringLabel,StringLabel>, Boolean>();
vertexIndexMap = new HashMap<SimpleGraph<StringLabel,StringLabel>.Node, Integer>();
edgeIndexMap = new HashMap<SimpleGraph<StringLabel,StringLabel>.Link, Integer>();
//vertexIgnoreMap = new HashMap<DTNode<StringLabel,StringLabel>, Boolean>();
//edgeIgnoreMap = new HashMap<DTLink<StringLabel,StringLabel>, Boolean>();

// Get the start node
if (vOldNewMap.containsKey(oldStartV)) {
startV = vOldNewMap.get(oldStartV);
} else {
startV = rdfGraph.add(new StringLabel());
startV = rdfGraph.new Node(new StringLabel());
vOldNewMap.put(oldStartV, startV);
}
startV.label().clear();
Expand All @@ -198,51 +204,53 @@ private void init(DTGraph<String,String> graph, List<DTNode<String,String>> inst

instanceVertexIndexMap.put(startV, vertexIndexMap);
instanceEdgeIndexMap.put(startV, edgeIndexMap);
instanceVertexIgnoreMap.put(startV, vertexIgnoreMap);
instanceEdgeIgnoreMap.put(startV, edgeIgnoreMap);
//instanceVertexIgnoreMap.put(startV, vertexIgnoreMap);
//instanceEdgeIgnoreMap.put(startV, edgeIgnoreMap);

frontV = new ArrayList<DTNode<String,String>>();
frontV.add(oldStartV);

// Process the start node
vertexIndexMap.put(startV, depth);
vertexIgnoreMap.put(startV, false);
//vertexIgnoreMap.put(startV, false);

for (int j = depth - 1; j >= 0; j--) {
newFrontV = new ArrayList<DTNode<String,String>>();
for (DTNode<String,String> qV : frontV) {
for (DTLink<String,String> edge : qV.linksOut()) {
if (vOldNewMap.containsKey(edge.to())) { // This vertex has been added to rdfGraph
if (!vertexIndexMap.containsKey(vOldNewMap.get(edge.to())) || !reverse) { // we have not seen it for this instance or labels travel to the fringe vertices, in which case we want to have the lowest depth encounter
if (!vertexIndexMap.containsKey(vOldNewMap.get(edge.to())) || !reverse) { // we have not seen it for this instance or labels travel to the fringe vertices, in which case we want to have the lowest depth encounter
if (vOldNewMap.containsKey(edge.to())) { // This vertex has been added to rdfGraph
vertexIndexMap.put(vOldNewMap.get(edge.to()), j);
vertexIgnoreMap.put(vOldNewMap.get(edge.to()), false);
//vertexIgnoreMap.put(vOldNewMap.get(edge.to()), false);
}
//vOldNewMap.get(edge.to()).label().clear();
//vOldNewMap.get(edge.to()).label().append(edge.to().label());
else {
SimpleGraph<StringLabel,StringLabel>.Node newN = rdfGraph.new Node(new StringLabel());
newN.label().clear();
newN.label().append(edge.to().label());
vOldNewMap.put(edge.to(), newN);
vertexIndexMap.put(newN, j);
//vertexIgnoreMap.put(newN, false);
}
vOldNewMap.get(edge.to()).label().clear();
vOldNewMap.get(edge.to()).label().append(edge.to().label()); // However, we should always include it in the graph at depth j
} else {
DTNode<StringLabel,StringLabel> newN = rdfGraph.add(new StringLabel());
newN.label().clear();
newN.label().append(edge.to().label());
vOldNewMap.put(edge.to(), newN);
vertexIndexMap.put(newN, j);
vertexIgnoreMap.put(newN, false);
}

if (eOldNewMap.containsKey(edge)) {
// Process the edge, if we haven't seen it before
if (!edgeIndexMap.containsKey(eOldNewMap.get(edge)) || !reverse) { // see comment for vertices
if (!edgeIndexMap.containsKey(eOldNewMap.get(edge)) || !reverse) { // see comment for vertices
if (eOldNewMap.containsKey(edge)) {
// Process the edge, if we haven't seen it before
edgeIndexMap.put(eOldNewMap.get(edge), j);
edgeIgnoreMap.put(eOldNewMap.get(edge), false);
//edgeIgnoreMap.put(eOldNewMap.get(edge), false);
}
//eOldNewMap.get(edge).tag().clear();
//eOldNewMap.get(edge).tag().append(edge.tag());
else {
SimpleGraph<StringLabel,StringLabel>.Link newE = rdfGraph.new Link(vOldNewMap.get(qV), vOldNewMap.get(edge.to()), new StringLabel());
newE.tag().clear();
newE.tag().append(edge.tag());
eOldNewMap.put(edge, newE);
edgeIndexMap.put(newE, j);
//edgeIgnoreMap.put(newE, false);
}
eOldNewMap.get(edge).tag().clear();
eOldNewMap.get(edge).tag().append(edge.tag());
} else {
DTLink<StringLabel,StringLabel> newE = vOldNewMap.get(qV).connect(vOldNewMap.get(edge.to()), new StringLabel());
newE.tag().clear();
newE.tag().append(edge.tag());
eOldNewMap.put(edge, newE);
edgeIndexMap.put(newE, j);
edgeIgnoreMap.put(newE, false);
}

// Add the vertex to the new front, if we go into a new round
Expand All @@ -267,18 +275,18 @@ private void init(DTGraph<String,String> graph, List<DTNode<String,String>> inst
* @param weight
* @param featureVectors
*/
private void computeFVs(DTGraph<StringLabel,StringLabel> graph, List<DTNode<StringLabel,StringLabel>> instances, double weight, SparseVector[] featureVectors, int lastIndex, int currentIt) {
private void computeFVs(SimpleGraph<StringLabel,StringLabel> graph, List<SimpleGraph<StringLabel,StringLabel>.Node> instances, double weight, SparseVector[] featureVectors, int lastIndex, int currentIt) {
int index, depth;
Map<DTNode<StringLabel,StringLabel>, Integer> vertexIndexMap;
Map<DTLink<StringLabel,StringLabel>, Integer> edgeIndexMap;
Map<SimpleGraph<StringLabel,StringLabel>.Node, Integer> vertexIndexMap;
Map<SimpleGraph<StringLabel,StringLabel>.Link, Integer> edgeIndexMap;

for (int i = 0; i < instances.size(); i++) {
featureVectors[i].setLastIndex(lastIndex);

vertexIndexMap = instanceVertexIndexMap.get(instances.get(i));
edgeIndexMap = instanceEdgeIndexMap.get(instances.get(i));
for (DTNode<StringLabel,StringLabel> vertex : vertexIndexMap.keySet()) {

for (SimpleGraph<StringLabel,StringLabel>.Node vertex : vertexIndexMap.keySet()) {
depth = vertexIndexMap.get(vertex);

if (!vertex.label().isSameAsPrev() && (depth * 2) >= currentIt) {
Expand All @@ -287,7 +295,7 @@ private void computeFVs(DTGraph<StringLabel,StringLabel> graph, List<DTNode<Stri
}
}

for (DTLink<StringLabel,StringLabel> edge : edgeIndexMap.keySet()) {
for (SimpleGraph<StringLabel,StringLabel>.Link edge : edgeIndexMap.keySet()) {
depth = edgeIndexMap.get(edge);

if (!edge.tag().isSameAsPrev() && ((depth * 2)+1) >= currentIt) {
Expand All @@ -312,7 +320,7 @@ public List<String> getFeatureDescriptions(List<Integer> indices) {
}
}


/**
* From Wikipedia's article on the geometric distribution.
*
Expand All @@ -330,6 +338,6 @@ private double getProb(int depth) {
/**
 * Evaluates {@code 1 - (1 - p)^(depth + 1)} using the current value of the
 * field {@code p}.
 *
 * NOTE(review): this appears to be the cumulative probability of a geometric
 * distribution at {@code depth} (support starting at 0) -- confirm against
 * the companion {@code getProb} method.
 *
 * @param depth the depth at which to evaluate the expression
 * @return the value of {@code 1 - (1 - p)^(depth + 1)}
 */
private double getCumProb(int depth) {
    final double complement = 1 - p;
    final int exponent = depth + 1;
    return 1 - Math.pow(complement, exponent);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class Node {
private List<Link> _in_links = new ArrayList<Link>();
private List<Link> _out_links = new ArrayList<Link>();

private Node(V label) {
public Node(V label) {
_nodes.add(this);
this._label = label;
}
Expand All @@ -42,7 +42,7 @@ public class Link {
private Node _from;
private Node _to;

private Link(Node from, Node to, W tag) {
public Link(Node from, Node to, W tag) {
_from = from;
_to = to;
_tag = tag;
Expand Down
Loading

0 comments on commit 15c2003

Please sign in to comment.