diff --git a/manm_cs/graph/graph_builder.py b/manm_cs/graph/graph_builder.py index dd0e544..a20646d 100644 --- a/manm_cs/graph/graph_builder.py +++ b/manm_cs/graph/graph_builder.py @@ -186,7 +186,7 @@ def generate_continuous_variable(self, parents, node_idx) -> 'ContinuousVariable def generate_dag(self, seed: int) -> 'DiGraph': # Generate graph using networkx package - G = nx.gnp_random_graph(n=self.num_nodes, p=self.edge_density, seed=seed, directed=True) + G = nx.gnp_random_graph(n=self.num_nodes, p=self.edge_density, seed=seed, directed=False) # Convert generated graph to DAG dag = nx.DiGraph() dag.add_nodes_from(G)