Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions Makefile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Makefile for SMORe-Go

# Build output directory
BINDIR := bin

# Go build flags
GOFLAGS := -ldflags="-s -w"

.PHONY: all clean deepwalk node2vec fastrp transe rotate complex sne metapath2vec han ctdne jodie line bpr hpe textgcn skewopt sasrec gsasrec recdenoiser cpr tpr test

all: deepwalk node2vec fastrp transe rotate complex sne metapath2vec han ctdne jodie line bpr hpe textgcn skewopt sasrec gsasrec recdenoiser cpr tpr

# Create bin directory
$(BINDIR):
mkdir -p $(BINDIR)

# Build DeepWalk
deepwalk: $(BINDIR)
@echo "Building deepwalk..."
go build $(GOFLAGS) -o $(BINDIR)/deepwalk ./cmd/deepwalk

# Build Node2Vec
node2vec: $(BINDIR)
@echo "Building node2vec..."
go build $(GOFLAGS) -o $(BINDIR)/node2vec ./cmd/node2vec

# Build FastRP
fastrp: $(BINDIR)
@echo "Building fastrp..."
go build $(GOFLAGS) -o $(BINDIR)/fastrp ./cmd/fastrp

# Build TransE
transe: $(BINDIR)
@echo "Building transe..."
go build $(GOFLAGS) -o $(BINDIR)/transe ./cmd/transe

# Build RotatE
rotate: $(BINDIR)
@echo "Building rotate..."
go build $(GOFLAGS) -o $(BINDIR)/rotate ./cmd/rotate

# Build ComplEx
complex: $(BINDIR)
@echo "Building complex..."
go build $(GOFLAGS) -o $(BINDIR)/complex ./cmd/complex

# Build SNE
sne: $(BINDIR)
@echo "Building sne..."
go build $(GOFLAGS) -o $(BINDIR)/sne ./cmd/sne

# Build Metapath2Vec
metapath2vec: $(BINDIR)
@echo "Building metapath2vec..."
go build $(GOFLAGS) -o $(BINDIR)/metapath2vec ./cmd/metapath2vec

# Build HAN
han: $(BINDIR)
@echo "Building han..."
go build $(GOFLAGS) -o $(BINDIR)/han ./cmd/han

# Build CTDNE
ctdne: $(BINDIR)
@echo "Building ctdne..."
go build $(GOFLAGS) -o $(BINDIR)/ctdne ./cmd/ctdne

# Build JODIE
jodie: $(BINDIR)
@echo "Building jodie..."
go build $(GOFLAGS) -o $(BINDIR)/jodie ./cmd/jodie

# Build LINE
line: $(BINDIR)
@echo "Building line..."
go build $(GOFLAGS) -o $(BINDIR)/line ./cmd/line

# Build BPR
bpr: $(BINDIR)
@echo "Building bpr..."
go build $(GOFLAGS) -o $(BINDIR)/bpr ./cmd/bpr

# Build HPE
hpe: $(BINDIR)
@echo "Building hpe..."
go build $(GOFLAGS) -o $(BINDIR)/hpe ./cmd/hpe

# Build TextGCN
textgcn: $(BINDIR)
@echo "Building textgcn..."
go build $(GOFLAGS) -o $(BINDIR)/textgcn ./cmd/textgcn

# Build Skew-Opt
skewopt: $(BINDIR)
@echo "Building Skew-Opt..."
go build $(GOFLAGS) -o $(BINDIR)/skewopt ./cmd/skewopt

# Build SASRec
sasrec: $(BINDIR)
@echo "Building sasrec..."
go build $(GOFLAGS) -o $(BINDIR)/sasrec ./cmd/sasrec

# Build gSASRec (RecSys 2023 Best Paper)
gsasrec: $(BINDIR)
@echo "Building gsasrec (RecSys 2023 Best Paper)..."
go build $(GOFLAGS) -o $(BINDIR)/gsasrec ./cmd/gsasrec

# Build Rec-Denoiser (RecSys 2022 Best Paper)
recdenoiser: $(BINDIR)
@echo "Building recdenoiser (RecSys 2022 Best Paper)..."
go build $(GOFLAGS) -o $(BINDIR)/recdenoiser ./cmd/recdenoiser

# Build CPR (Cross-Domain Preference Ranking)
cpr: $(BINDIR)
@echo "Building CPR..."
go build $(GOFLAGS) -o $(BINDIR)/cpr ./cmd/cpr

# Build TPR (Text-aware Preference Ranking)
tpr: $(BINDIR)
@echo "Building TPR..."
go build $(GOFLAGS) -o $(BINDIR)/tpr ./cmd/tpr

# Run tests
test:
go test -v ./...

# Clean build artifacts
clean:
rm -rf $(BINDIR)
go clean

# Install binaries to GOPATH/bin
install:
go install ./cmd/deepwalk
go install ./cmd/node2vec
go install ./cmd/fastrp
go install ./cmd/transe
go install ./cmd/rotate
go install ./cmd/complex
go install ./cmd/sne
go install ./cmd/metapath2vec
go install ./cmd/han
go install ./cmd/ctdne
go install ./cmd/jodie
go install ./cmd/line
go install ./cmd/bpr
go install ./cmd/hpe
go install ./cmd/textgcn
go install ./cmd/skewopt
go install ./cmd/sasrec
go install ./cmd/gsasrec
go install ./cmd/recdenoiser
go install ./cmd/cpr
go install ./cmd/tpr

# Format code
fmt:
go fmt ./...

# Run linter
lint:
golangci-lint run ./...
56 changes: 56 additions & 0 deletions cmd/bpr/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package main

import (
"flag"
"fmt"
"os"

"github.com/cnclabs/smore/internal/models/bpr"
)

func main() {
// Define command-line flags
train := flag.String("train", "", "Train the Network data")
save := flag.String("save", "", "Save the representation data")
dimensions := flag.Int("dimensions", 64, "Dimension of vertex representation")
undirected := flag.Bool("undirected", false, "Whether the edge is undirected")
sampleTimes := flag.Int("sample_times", 10, "Number of training iterations")
threads := flag.Int("threads", 1, "Number of training threads")
alpha := flag.Float64("alpha", 0.025, "Init learning rate")
lambda := flag.Float64("lambda", 0.001, "Regularization parameter")

flag.Usage = func() {
fmt.Println("[SMORe-Go]")
fmt.Println("\tGolang implementation of SMORe - BPR")
fmt.Println()
fmt.Println("Options Description:")
flag.PrintDefaults()
fmt.Println()
fmt.Println("Usage:")
fmt.Println("./bpr -train net.txt -save rep.txt -dimensions 64 -sample_times 10 -alpha 0.025 -lambda 0.001 -threads 1")
}

flag.Parse()

// Check required parameters
if *train == "" || *save == "" {
flag.Usage()
os.Exit(1)
}

// Create and train model
b := bpr.New()

if err := b.LoadEdgeList(*train, *undirected); err != nil {
fmt.Printf("Error loading edge list: %v\n", err)
os.Exit(1)
}

b.Init(*dimensions)
b.Train(*sampleTimes, *alpha, *lambda, *threads)

if err := b.SaveWeights(*save); err != nil {
fmt.Printf("Error saving weights: %v\n", err)
os.Exit(1)
}
}
166 changes: 166 additions & 0 deletions cmd/complex/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package main

import (
"flag"
"fmt"
"os"

complex_embeddings "github.com/cnclabs/smore/internal/models/complex"
)

func main() {
// Command-line flags
train := flag.String("train", "", "Path to training triples file (format: head relation tail)")
output := flag.String("output", "", "Path to output embeddings file")
dim := flag.Int("dim", 100, "Embedding dimension")
epochs := flag.Int("epochs", 100, "Number of training epochs")
batchSize := flag.Int("batch-size", 128, "Batch size for training")
negativeSamples := flag.Int("negative-samples", 10, "Number of negative samples per positive triple")
learningRate := flag.Float64("lr", 0.01, "Learning rate")
margin := flag.Float64("margin", 1.0, "Margin for ranking loss")
workers := flag.Int("workers", 4, "Number of parallel workers")
evalSize := flag.Int("eval-size", 1000, "Number of triples to use for evaluation")

flag.Usage = func() {
fmt.Fprintf(os.Stderr, "ComplEx - Complex Embeddings for Knowledge Graphs\n\n")
fmt.Fprintf(os.Stderr, "Usage:\n")
fmt.Fprintf(os.Stderr, " %s [options]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Description:\n")
fmt.Fprintf(os.Stderr, " ComplEx learns complex-valued embeddings for knowledge graph entities and relations.\n")
fmt.Fprintf(os.Stderr, " Unlike real-valued models, ComplEx can capture:\n")
fmt.Fprintf(os.Stderr, " - Symmetric relations (e.g., \"is_similar_to\")\n")
fmt.Fprintf(os.Stderr, " - Antisymmetric relations (e.g., \"is_parent_of\")\n")
fmt.Fprintf(os.Stderr, " - Inverse relations\n")
fmt.Fprintf(os.Stderr, " - Composition patterns\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nInput Format:\n")
fmt.Fprintf(os.Stderr, " Triple format: head_entity relation tail_entity\n")
fmt.Fprintf(os.Stderr, " Example:\n")
fmt.Fprintf(os.Stderr, " Paris capitalOf France\n")
fmt.Fprintf(os.Stderr, " France locatedIn Europe\n")
fmt.Fprintf(os.Stderr, " Berlin capitalOf Germany\n\n")
fmt.Fprintf(os.Stderr, "Output Format:\n")
fmt.Fprintf(os.Stderr, " Complex-valued embeddings in format:\n")
fmt.Fprintf(os.Stderr, " E entity_name real1 imag1i real2 imag2i ...\n")
fmt.Fprintf(os.Stderr, " R relation_name real1 imag1i real2 imag2i ...\n\n")
fmt.Fprintf(os.Stderr, "Examples:\n")
fmt.Fprintf(os.Stderr, " # Train on FB15k knowledge graph\n")
fmt.Fprintf(os.Stderr, " %s -train fb15k.txt -output complex.emb \\\n", os.Args[0])
fmt.Fprintf(os.Stderr, " -dim 100 -epochs 100 -lr 0.01 -negative-samples 10\n\n")
fmt.Fprintf(os.Stderr, " # Train on WordNet with higher dimensions\n")
fmt.Fprintf(os.Stderr, " %s -train wn18.txt -output complex.emb \\\n", os.Args[0])
fmt.Fprintf(os.Stderr, " -dim 200 -epochs 200 -lr 0.005 -margin 2.0\n\n")
fmt.Fprintf(os.Stderr, " # Train with more workers for speed\n")
fmt.Fprintf(os.Stderr, " %s -train data.txt -output complex.emb \\\n", os.Args[0])
fmt.Fprintf(os.Stderr, " -dim 150 -epochs 150 -workers 8\n\n")
fmt.Fprintf(os.Stderr, "Key Features:\n")
fmt.Fprintf(os.Stderr, " ✓ Complex-valued embeddings (uses Go's native complex128)\n")
fmt.Fprintf(os.Stderr, " ✓ Handles symmetric and antisymmetric relations\n")
fmt.Fprintf(os.Stderr, " ✓ More expressive than real-valued models (TransE)\n")
fmt.Fprintf(os.Stderr, " ✓ Parallel training with goroutines\n")
fmt.Fprintf(os.Stderr, " ✓ Margin-based ranking loss\n\n")
fmt.Fprintf(os.Stderr, "Scoring Function:\n")
fmt.Fprintf(os.Stderr, " score(h, r, t) = Re(<h, r, conj(t)>)\n")
fmt.Fprintf(os.Stderr, " where <> is the trilinear dot product:\n")
fmt.Fprintf(os.Stderr, " Σ h_i * r_i * conj(t_i)\n\n")
fmt.Fprintf(os.Stderr, "References:\n")
fmt.Fprintf(os.Stderr, " Trouillon et al. \"Complex Embeddings for Simple Link Prediction\", ICML 2016\n")
fmt.Fprintf(os.Stderr, " https://arxiv.org/abs/1606.06357\n\n")
}

flag.Parse()

// Validate required arguments
if *train == "" {
fmt.Fprintf(os.Stderr, "Error: -train is required\n\n")
flag.Usage()
os.Exit(1)
}

if *output == "" {
fmt.Fprintf(os.Stderr, "Error: -output is required\n\n")
flag.Usage()
os.Exit(1)
}

// Validate parameters
if *dim <= 0 {
fmt.Fprintf(os.Stderr, "Error: -dim must be positive\n")
os.Exit(1)
}

if *epochs <= 0 {
fmt.Fprintf(os.Stderr, "Error: -epochs must be positive\n")
os.Exit(1)
}

if *batchSize <= 0 {
fmt.Fprintf(os.Stderr, "Error: -batch-size must be positive\n")
os.Exit(1)
}

if *negativeSamples <= 0 {
fmt.Fprintf(os.Stderr, "Error: -negative-samples must be positive\n")
os.Exit(1)
}

if *learningRate <= 0 {
fmt.Fprintf(os.Stderr, "Error: -lr must be positive\n")
os.Exit(1)
}

if *margin <= 0 {
fmt.Fprintf(os.Stderr, "Error: -margin must be positive\n")
os.Exit(1)
}

if *workers <= 0 {
fmt.Fprintf(os.Stderr, "Error: -workers must be positive\n")
os.Exit(1)
}

// Print configuration
fmt.Println("===================================================")
fmt.Println("ComplEx - Complex Embeddings for Knowledge Graphs")
fmt.Println("===================================================")
fmt.Println()

// Create ComplEx model
model := complex_embeddings.New()

// Load knowledge graph
fmt.Println("Loading Knowledge Graph:")
fmt.Printf("\tInput: %s\n", *train)
fmt.Println()

if err := model.LoadTriples(*train); err != nil {
fmt.Fprintf(os.Stderr, "Error loading triples: %v\n", err)
os.Exit(1)
}

// Initialize model
fmt.Println()
model.Init(*dim, *learningRate, *margin)

// Train model
fmt.Println()
model.Train(*epochs, *batchSize, *negativeSamples, *workers)

// Evaluate link prediction
if *evalSize > 0 {
model.EvaluateLinkPrediction(*evalSize)
}

// Save embeddings
fmt.Println()
if err := model.SaveEmbeddings(*output); err != nil {
fmt.Fprintf(os.Stderr, "Error saving embeddings: %v\n", err)
os.Exit(1)
}

fmt.Println()
fmt.Println("===================================================")
fmt.Println("Training Complete!")
fmt.Println("===================================================")
}
Loading