Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions tests/test_hexary_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest

from trie import HexaryTrie
from trie.utils.db import (
KeyAccessLogger,
)


@pytest.mark.parametrize(
    'items1, items2, expected_key_diffs',
    (
        # two empty tries: nothing to report
        ([], [], {}),
        # key only present on the left
        (
            [(b'a', b'A')],
            [],
            {b'a': (b'A', None)},
        ),
        # key only present on the right
        (
            [],
            [(b'a', b'A')],
            {b'a': (None, b'A')},
        ),
        # disjoint single keys
        (
            [(b'a', b'A')],
            [(b'b', b'B')],
            {b'a': (b'A', None), b'b': (None, b'B')},
        ),
        # right side adds a key that shares its first nibbles with an existing key
        (
            [(b'aa', b'A')],
            [(b'aa', b'A'), (b'ab', b'B')],
            {b'ab': (None, b'B')},
        ),
        # right side adds a third key that differs in the first nibble
        (
            [(b'\x0a', b'A'), (b'\x1a', b'B')],
            [(b'\x0a', b'A'), (b'\x1a', b'B'), (b'\x2a', b'C')],
            {b'\x2a': (None, b'C')},
        ),
        # two short keys replaced by one longer key with a shared prefix
        (
            [(b'\x0a', b'A'), (b'\x0b', b'B')],
            [(b'\x0ac', b'C')],
            {
                b'\x0a': (b'A', None),
                b'\x0b': (b'B', None),
                b'\x0ac': (None, b'C'),
            },
        ),
        # same shape as the previous case, with 33-byte values
        # (presumably long enough that nodes are stored by hash -- TODO confirm)
        (
            [(b'\x0a', b'A' * 33), (b'\x0b', b'B' * 33)],
            [(b'\x0ac', b'C' * 33)],
            {
                b'\x0a': (b'A' * 33, None),
                b'\x0b': (b'B' * 33, None),
                b'\x0ac': (None, b'C' * 33),
            },
        ),
    ),
)
def test_hexary_diff(items1, items2, expected_key_diffs):
    """
    Diff two tries built from ``items1`` and ``items2``, then check that the
    same diff can be rebuilt from only the database entries that were read
    while looking up the diffed keys.
    """
    db1, db2 = {}, {}
    trie1 = HexaryTrie(db1)
    trie2 = HexaryTrie(db2)
    for trie, items in ((trie1, items1), (trie2, items2)):
        for key, val in items:
            trie[key] = val

    key_diffs = HexaryTrie.diff(trie1, trie2)
    assert key_diffs == expected_key_diffs

    # re-open both tries on top of read-logging views of their databases
    logger1db = KeyAccessLogger(db1)
    logger1 = HexaryTrie(logger1db, trie1.root_hash)
    logger2db = KeyAccessLogger(db2)
    logger2 = HexaryTrie(logger2db, trie2.root_hash)
    for key in key_diffs:
        # trigger reads to the relevant (diffed) keys
        logger1[key]
        logger2[key]

    # keep only the database entries that the lookups above touched
    proof_only_db1 = {db_key: db1[db_key] for db_key in logger1db.read_keys}
    proof_only_db2 = {db_key: db2[db_key] for db_key in logger2db.read_keys}

    # make sure you don't get KeyErrors when creating the same diff from only the proof:
    proof_only_trie1 = HexaryTrie(proof_only_db1, trie1.root_hash)
    proof_only_trie2 = HexaryTrie(proof_only_db2, trie2.root_hash)
    proof_only_diff = HexaryTrie.diff(proof_only_trie1, proof_only_trie2)
    # also make sure you get the same diff
    assert proof_only_diff == key_diffs

# NOTE(review): this module-level string is evaluated and discarded, so it has
# no effect on the tests. It looks like a parametrize case for test_hexary_diff
# that was left out of the table -- presumably because HexaryTrie.diff does not
# handle it yet. Confirm, then either enable it or delete it.
'''
(
[(b'aa', b'A' * 33)],
[(b'aa', b'A' * 33), (b'ab', b'B' * 33)],
{b'ab': (None, b'B' * 33)},
),
'''
17 changes: 3 additions & 14 deletions tests/test_hexary_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
MissingTrieNode,
ValidationError,
)
from trie.utils.db import (
KeyAccessLogger,
)
from trie.utils.nodes import (
decode_node,
)
Expand Down Expand Up @@ -153,20 +156,6 @@ def test_trie_using_fixtures(name, updates, expected, deleted, final_root):
trie.get_proof(invalid_proof_key)


class KeyAccessLogger(dict):
    """
    A dict that records every key successfully read through ``logger[key]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # keys that have been fetched via item access so far
        self.read_keys = set()

    def __getitem__(self, key):
        # look up first: a missing key raises before anything is recorded
        value = super().__getitem__(key)
        self.read_keys.add(key)
        return value

    def unread_keys(self):
        """
        Return the set of stored keys that have not been read via item access.
        """
        return set(self).difference(self.read_keys)


def test_hexary_trie_saves_each_root():
changes = ((b'ab', b'b'*32), (b'ac', b'c'*32), (b'ac', None), (b'ad', b'd'*32))
expected = ((b'ab', b'b'*32), (b'ad', b'd'*32))
Expand Down
159 changes: 159 additions & 0 deletions trie/hexary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
to_list,
to_tuple,
)
from eth_utils.toolz import (
merge,
)

from trie.constants import (
BLANK_NODE,
Expand Down Expand Up @@ -551,6 +554,162 @@ def at_root(self, at_root_hash):
snapshot = type(self)(self.db, at_root_hash, prune=False)
yield snapshot

#
# Differ
#
@classmethod
def diff(cls, trie1, trie2):
    """
    Compare two tries, starting from their root nodes.

    :return: dict mapping each differing key (as bytes) to a
        ``(value_in_trie1, value_in_trie2)`` pair
    """
    nibble_keyed = cls._diff_nodes(trie1, trie2, trie1.root_node, trie2.root_node)
    # the node-level diff is keyed by nibble sequences; callers get byte keys
    byte_keyed = {}
    for nibble_key, value_pair in nibble_keyed.items():
        byte_keyed[nibbles_to_bytes(nibble_key)] = value_pair
    return byte_keyed

@classmethod
def _diff_nodes(cls, trie1, trie2, node1, node2):
    """
    Recursively diff two trie nodes.

    :param trie1: trie that ``node1`` comes from; used to resolve its subnode references
    :param trie2: trie that ``node2`` comes from; used to resolve its subnode references
    :param node1: node taken from ``trie1``
    :param node2: node taken from ``trie2``
    :return: dict mapping a nibble tuple (relative to the given nodes) to a
        ``(value1, value2)`` pair; ``None`` marks a key missing on that side
    :raises NotImplementedError: for node combinations that are not handled yet,
        including any branch node that carries a value
    """
    if node1 == node2:
        return {}
    node1type = get_node_type(node1)
    node2type = get_node_type(node2)

    if node1type == NODE_TYPE_BRANCH:
        if node1[-1]:
            # branch nodes holding their own value are not supported yet
            raise NotImplementedError
        elif node2type == NODE_TYPE_BLANK:
            # everything below the branch exists only in trie1
            return merge([
                cls._prefix_diff_keys(
                    (branch_nibble, ),
                    cls._diff_nodes(trie1, trie2, trie1._get_subnode(subnode), node2),
                )
                for branch_nibble, subnode in enumerate(node1[:-1])
            ])
        elif node2type == NODE_TYPE_LEAF:
            key2 = extract_key(node2)
            leaf_nibble = key2[0]
            branch_subnode = trie1._get_subnode(node1[leaf_nibble])
            if len(key2) == 1:
                if not is_leaf_node(branch_subnode):
                    raise NotImplementedError
                subnode_key = extract_key(branch_subnode)
                if subnode_key:
                    # The branch's leaf still has key nibbles left, so it is a
                    # different key from the single-nibble leaf: report both.
                    diff_leaf = {
                        key2 + subnode_key: (branch_subnode[1], None),
                        key2: (None, node2[1]),
                    }
                elif branch_subnode[1] != node2[-1]:
                    diff_leaf = {key2: (branch_subnode[1] or None, node2[-1])}
                else:
                    diff_leaf = {}
            else:
                # strip the nibble consumed by the branch and compare one level down
                trimmed_leaf = [
                    compute_leaf_key(key2[1:]),
                    node2[1],
                ]
                diff_leaf = cls._prefix_diff_keys(
                    (leaf_nibble, ),
                    cls._diff_nodes(trie1, trie2, branch_subnode, trimmed_leaf),
                )

            # every other slot of the branch has no counterpart in trie2
            diff_blanks = [
                cls._prefix_diff_keys(
                    (branch_nibble, ),
                    cls._diff_nodes(trie1, trie2, trie1._get_subnode(subnode), BLANK_NODE),
                )
                for branch_nibble, subnode in enumerate(node1[:-1])
                if branch_nibble != leaf_nibble
            ]
            return merge(diff_leaf, *diff_blanks)
        elif node2type == NODE_TYPE_BRANCH:
            if node2[-1]:
                raise NotImplementedError
            # only descend into slots whose references differ
            return merge([
                cls._prefix_diff_keys(
                    (branch_nibble, ),
                    cls._diff_nodes(
                        trie1,
                        trie2,
                        trie1._get_subnode(subnode),
                        trie2._get_subnode(node2[branch_nibble]),
                    ),
                )
                for branch_nibble, subnode in enumerate(node1[:-1])
                if subnode != node2[branch_nibble]
            ])
        else:
            raise NotImplementedError

    elif node1type == NODE_TYPE_EXTENSION:
        key1 = extract_key(node1)
        subnode = trie1._get_subnode(node1[1])
        if node2type == NODE_TYPE_BLANK:
            # Iterate key/value pairs of the sub-diff; iterating the dict itself
            # would yield only its keys.
            return cls._prefix_diff_keys(
                key1,
                cls._diff_nodes(trie1, trie2, subnode, node2),
            )
        elif node2type == NODE_TYPE_LEAF:
            key2 = extract_key(node2)
            common_prefix, key1_remainder, key2_remainder = consume_common_prefix(
                key1,
                key2,
            )
            if not common_prefix:
                # nothing shared: each side is diffed against a blank node
                return merge(
                    cls._diff_nodes(trie1, trie2, node1, BLANK_NODE),
                    cls._diff_nodes(trie1, trie2, BLANK_NODE, node2),
                )
            elif not key1_remainder and key2_remainder:
                # the leaf key runs through the whole extension: continue below it
                trimmed_leaf = [
                    compute_leaf_key(key2_remainder),
                    node2[1],
                ]
                return cls._prefix_diff_keys(
                    common_prefix,
                    cls._diff_nodes(trie1, trie2, subnode, trimmed_leaf),
                )
            else:
                raise NotImplementedError

        else:
            raise NotImplementedError

    elif is_leaf_node(node1):
        key1 = extract_key(node1)
        if is_blank_node(node2):
            return {key1: (node1[1], None)}
        elif is_leaf_node(node2):
            diff = {}
            key2 = extract_key(node2)
            if key2 == key1:
                diff[key1] = (node1[1], node2[1])
            else:
                diff[key1] = (node1[1], None)
                diff[key2] = (None, node2[1])
            return diff
        elif node2type in (NODE_TYPE_EXTENSION, NODE_TYPE_BRANCH):
            # handled by the mirrored case with the arguments swapped
            return cls._flip_diff(trie1, trie2, node1, node2)
        else:
            raise NotImplementedError
    elif is_blank_node(node1):
        if is_leaf_node(node2):
            return cls._flip_diff(trie1, trie2, node1, node2)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

@staticmethod
def _prefix_diff_keys(prefix, diff):
    """
    Return a copy of ``diff`` with ``prefix`` prepended to every nibble-tuple key.
    """
    return {prefix + keytail: val for keytail, val in diff.items()}

def _get_subnode(self, subnode_ref):
if len(subnode_ref) == 32:
return self.get_node(subnode_ref)
else:
return subnode_ref

@classmethod
def _flip_diff(cls, t1, t2, h1, h2):
return {k: (diff1, diff2) for k, (diff2, diff1) in cls._diff_nodes(t2, t1, h2, h1).items()}



@to_tuple
def tuplify(node):
Expand Down
14 changes: 14 additions & 0 deletions trie/utils/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import contextlib


class KeyAccessLogger(dict):
    """
    A dict that records every key successfully read through ``logger[key]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # keys that have been fetched via item access so far
        self.read_keys = set()

    def __getitem__(self, key):
        # look up first: a missing key raises before anything is recorded
        value = super().__getitem__(key)
        self.read_keys.add(key)
        return value

    def unread_keys(self):
        """
        Return the set of stored keys that have not been read via item access.
        """
        return set(self).difference(self.read_keys)


class ScratchDB:
"""
A wrapper of basic DB objects with uncommitted DB changes stored in local cache,
Expand Down
2 changes: 1 addition & 1 deletion trie/utils/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_node_type(node):
elif len(node) == 17:
return NODE_TYPE_BRANCH
else:
raise InvalidNode("Unable to determine node type")
raise InvalidNode("Unable to determine node type: %r" % node)


def is_blank_node(node):
Expand Down