Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions gqlalchemy/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: Optional[str] = None,
):
self.host = host
self.port = port
self.scheme = scheme
self.username = username
self.password = password
self.encrypted = encrypted
Expand All @@ -65,14 +67,21 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: Optional[str] = None,
lazy: bool = False,
):
super().__init__(
host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name
host=host,
port=port,
scheme=scheme,
username=username,
password=password,
encrypted=encrypted,
client_name=client_name,
)
self.lazy = lazy
self._connection = self._create_connection()
Expand Down Expand Up @@ -106,6 +115,7 @@ def _create_connection(self) -> Connection:
connection = mgclient.connect(
host=self.host,
port=self.port,
scheme=self.scheme,
username=self.username,
password=self.password,
sslmode=sslmode,
Expand Down Expand Up @@ -154,14 +164,21 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: Optional[str] = None,
lazy: bool = True,
):
super().__init__(
host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name
host=host,
port=port,
scheme=scheme,
username=username,
password=password,
encrypted=encrypted,
client_name=client_name,
)
self.lazy = lazy
self._connection = self._create_connection()
Expand All @@ -184,6 +201,7 @@ def is_active(self) -> bool:
return self._connection is not None

def _create_connection(self):
# TODO: antepusic fit in scheme
return GraphDatabase.driver(
f"bolt://{self.host}:{self.port}", auth=(self.username, self.password), encrypted=self.encrypted
)
Expand Down
1 change: 1 addition & 0 deletions gqlalchemy/memgraph_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

MG_HOST = os.getenv("MG_HOST", "127.0.0.1")
MG_PORT = int(os.getenv("MG_PORT", "7687"))
MG_SCHEME = os.getenv("MG_SCHEME", "")
MG_USERNAME = os.getenv("MG_USERNAME", "")
MG_PASSWORD = os.getenv("MG_PASSWORD", "")
MG_ENCRYPTED = os.getenv("MG_ENCRYPT", "false").lower() == "true"
Expand Down
7 changes: 4 additions & 3 deletions gqlalchemy/transformations/export/graph_transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
graph_type: str,
host: str = mg_consts.MG_HOST,
port: int = mg_consts.MG_PORT,
scheme: str = mg_consts.MG_SCHEME,
username: str = mg_consts.MG_USERNAME,
password: str = mg_consts.MG_PASSWORD,
encrypted: bool = mg_consts.MG_ENCRYPTED,
Expand All @@ -58,12 +59,12 @@ def __init__(
self.graph_type = graph_type.upper()
if self.graph_type == GraphType.DGL.name:
raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl")
self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = DGLTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.PYG.name:
raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric")
self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = PyGTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.NX.name:
self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = NxTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
else:
raise ValueError("Unknown export option. Currently supported are DGL, PyG and NetworkX.")

Expand Down
7 changes: 4 additions & 3 deletions gqlalchemy/transformations/importing/graph_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
graph_type: str,
host: str = mg_consts.MG_HOST,
port: int = mg_consts.MG_PORT,
scheme: str = mg_consts.MG_SCHEME,
username: str = mg_consts.MG_USERNAME,
password: str = mg_consts.MG_PASSWORD,
encrypted: bool = mg_consts.MG_ENCRYPTED,
Expand All @@ -57,12 +58,12 @@ def __init__(
self.graph_type = graph_type.upper()
if self.graph_type == GraphType.DGL.name:
raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl")
self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = DGLTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.PYG.name:
raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric")
self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = PyGTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.NX.name:
self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = NxTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
else:
raise ValueError("Unknown import option. Currently supported options are: DGL, PyG and NetworkX.")

Expand Down
4 changes: 3 additions & 1 deletion gqlalchemy/transformations/translators/dgl_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand All @@ -45,13 +46,14 @@ def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
super().__init__(host, port, username, password, encrypted, client_name, lazy)
super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy)

def to_cypher_queries(self, graph: Union[dgl.DGLGraph, dgl.DGLHeteroGraph]):
"""Produce cypher queries for data saved as part of the DGL graph. The method handles both homogeneous and heterogeneous graph. If the graph is homogeneous, a default DGL's labels will be used.
Expand Down
15 changes: 10 additions & 5 deletions gqlalchemy/transformations/translators/nx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand Down Expand Up @@ -152,14 +153,15 @@ def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
self.__all__ = ("nx_to_cypher", "nx_graph_to_memgraph_parallel")
super().__init__(host, port, username, password, encrypted, client_name, lazy)
super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy)

def to_cypher_queries(self, graph: nx.Graph, config: NetworkXCypherConfig = None) -> Iterator[str]:
"""Generates a Cypher query for creating a graph."""
Expand Down Expand Up @@ -187,14 +189,15 @@ def nx_graph_to_memgraph_parallel(
self._check_for_index_hint(
self.host,
self.port,
self.scheme,
self.username,
self.password,
self.encrypted,
)

for query_group in query_groups:
self._start_parallel_execution(
query_group, self.host, self.port, self.username, self.password, self.encrypted
query_group, self.host, self.port, self.scheme, self.username, self.password, self.encrypted
)

def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None:
Expand All @@ -212,6 +215,7 @@ def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None:
process_queries,
self.host,
self.port,
self.scheme,
self.username,
self.password,
self.encrypted,
Expand All @@ -224,10 +228,10 @@ def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None:
p.join()

def _insert_queries(
self, queries: List[str], host: str, port: int, username: str, password: str, encrypted: bool
self, queries: List[str], host: str, port: int, scheme: str, username: str, password: str, encrypted: bool
) -> None:
"""Used by multiprocess insertion of nx into memgraph, works on a chunk of queries."""
memgraph = Memgraph(host, port, username, password, encrypted)
memgraph = Memgraph(host, port, scheme, username, password, encrypted)
while len(queries) > 0:
try:
query = queries.pop()
Expand All @@ -241,12 +245,13 @@ def _check_for_index_hint(
self,
host: str = "127.0.0.1",
port: int = 7687,
scheme: str = "",
username: str = "",
password: str = "",
encrypted: bool = False,
):
"""Check if the there are indexes, if not show warnings."""
memgraph = Memgraph(host, port, username, password, encrypted)
memgraph = Memgraph(host, port, scheme, username, password, encrypted)
indexes = memgraph.get_indexes()
if len(indexes) == 0:
logging.getLogger(__file__).warning(
Expand Down
4 changes: 3 additions & 1 deletion gqlalchemy/transformations/translators/pyg_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand All @@ -33,13 +34,14 @@ def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
super().__init__(host, port, username, password, encrypted, client_name, lazy)
super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy)

@classmethod
def get_node_properties(cls, graph, node_label: str, node_id: int):
Expand Down
10 changes: 5 additions & 5 deletions gqlalchemy/transformations/translators/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand All @@ -40,25 +41,24 @@

class Translator(ABC):
# Lambda function to concat list of labels
merge_labels: Callable[[Set[str]], str] = (
lambda labels, default_node_label: LABELS_CONCAT.join([label for label in sorted(labels)])
if len(labels)
else default_node_label
merge_labels: Callable[[Set[str]], str] = lambda labels, default_node_label: (
LABELS_CONCAT.join([label for label in sorted(labels)]) if len(labels) else default_node_label
)

@abstractmethod
def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
super().__init__()
self.connection = Memgraph(host, port, username, password, encrypted, client_name, lazy)
self.connection = Memgraph(host, port, scheme, username, password, encrypted, client_name, lazy)

@abstractmethod
def to_cypher_queries(graph):
Expand Down
2 changes: 2 additions & 0 deletions gqlalchemy/vendors/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: str,
):
self._host = host
self._port = port
self._scheme = scheme
self._username = username
self._password = password
self._encrypted = encrypted
Expand Down
11 changes: 10 additions & 1 deletion gqlalchemy/vendors/memgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@ def __init__(
self,
host: str = mg_consts.MG_HOST,
port: int = mg_consts.MG_PORT,
scheme: str = mg_consts.MG_SCHEME,
username: str = mg_consts.MG_USERNAME,
password: str = mg_consts.MG_PASSWORD,
encrypted: bool = mg_consts.MG_ENCRYPTED,
client_name: str = mg_consts.MG_CLIENT_NAME,
lazy: bool = mg_consts.MG_LAZY,
):
super().__init__(
host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name
host=host,
port=port,
scheme=scheme,
username=username,
password=password,
encrypted=encrypted,
client_name=client_name,
)
self._lazy = lazy
self._on_disk_db = None
Expand Down Expand Up @@ -124,6 +131,7 @@ def new_connection(self) -> Connection:
args = dict(
host=self._host,
port=self._port,
scheme=self._scheme,
username=self._username,
password=self._password,
encrypted=self._encrypted,
Expand Down Expand Up @@ -197,6 +205,7 @@ def _new_connection(self) -> Connection:
args = dict(
host=self._host,
port=self._port,
scheme=self._scheme,
username=self._username,
password=self._password,
encrypted=self._encrypted,
Expand Down
Loading