Skip to content

Commit 81a3c23

Browse files
authored
Merge pull request #61 from raminqaf/fix/gensim-size-param
Check gensim version and set size parameter name respectively
2 parents f6623b7 + d61531d commit 81a3c23

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

node2vec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from . import edges
22
from .node2vec import Node2Vec
33

4-
__version__ = '0.4.1'
4+
__version__ = '0.4.2'

node2vec/node2vec.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import random
21
import os
2+
import random
33
from collections import defaultdict
44

5-
import numpy as np
6-
import networkx as nx
75
import gensim
6+
import networkx as nx
7+
import numpy as np
8+
import pkg_resources
89
from joblib import Parallel, delayed
910
from tqdm.auto import tqdm
1011

@@ -166,16 +167,20 @@ def _generate_walks(self) -> list:
166167
def fit(self, **skip_gram_params) -> gensim.models.Word2Vec:
167168
"""
168169
Creates the embeddings using gensim's Word2Vec.
169-
:param skip_gram_params: Parameteres for gensim.models.Word2Vec - do not supply 'size' it is taken from the Node2Vec 'dimensions' parameter
170+
:param skip_gram_params: Parameters for gensim.models.Word2Vec - do not supply 'size' / 'vector_size' it is
171+
taken from the Node2Vec 'dimensions' parameter
170172
:type skip_gram_params: dict
171173
:return: A gensim word2vec model
172174
"""
173175

174176
if 'workers' not in skip_gram_params:
175177
skip_gram_params['workers'] = self.workers
176178

177-
if 'size' not in skip_gram_params:
178-
skip_gram_params['size'] = self.dimensions
179+
# Figure out gensim version, naming of output dimensions changed from size to vector_size in v4.0.0
180+
gensim_version = pkg_resources.get_distribution("gensim").version
181+
size = 'size' if gensim_version < '4.0.0' else 'vector_size'
182+
if size not in skip_gram_params:
183+
skip_gram_params[size] = self.dimensions
179184

180185
if 'sg' not in skip_gram_params:
181186
skip_gram_params['sg'] = 1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name='node2vec',
55
packages=['node2vec'],
6-
version='0.4.1',
6+
version='0.4.2',
77
description='Implementation of the node2vec algorithm.',
88
author='Elior Cohen',
99
author_email='[email protected]',

0 commit comments

Comments
 (0)