diff --git a/CHANGELOG.md b/CHANGELOG.md index 57dcf33..37b86aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## 0.6.1 + +### Bug fixes + +- Fixed variable shadowing in `polyhedra_from_atom_indices` that caused only the first polyhedron to receive correct vertices. +- Invalid atom indices passed to `polyhedra_from_atom_indices` now raise `ValueError` with a descriptive message instead of a raw `KeyError`. + +### Performance + +- `polyhedra_from_atom_indices` now uses precomputed index-to-atom mappings for O(1) lookups instead of repeated linear scans. + +### Testing + +- Expanded test coverage across `configuration`, `coordination_polyhedron`, `orientation_parameters`, `polyhedra_recipe`, `trajectory`, and `utils` modules. + ## 0.6.0 ### Breaking changes diff --git a/polyhedral_analysis/polyhedra_recipe.py b/polyhedral_analysis/polyhedra_recipe.py index cc14a03..5e5317c 100644 --- a/polyhedral_analysis/polyhedra_recipe.py +++ b/polyhedral_analysis/polyhedra_recipe.py @@ -377,12 +377,23 @@ def polyhedra_from_atom_indices(central_atoms: list[Atom], if len(central_indices) != len(vertex_indices): raise ValueError('central_indices and vertex_indices are different lengths: ' f'{len(central_indices)} vs. {len(vertex_indices)}.') + central_atom_map = {atom.index: atom for atom in central_atoms} + vertex_atom_map = {atom.index: atom for atom in vertex_atoms} polyhedra = [] for ic, iv in zip(central_indices, vertex_indices): - central_atom = next(atom for atom in central_atoms if atom.index == ic) - vertex_atoms = [atom for atom in vertex_atoms if atom.index in iv] + try: + central_atom = central_atom_map[ic] + except KeyError: + raise ValueError( + f'Central atom index {ic} not found in central_atoms.') from None + try: + vertices = [vertex_atom_map[i] for i in iv] + except KeyError: + missing = [i for i in iv if i not in vertex_atom_map] + raise ValueError( + f'Vertex atom indices {missing} not found in vertex_atoms.') from None polyhedra.append(CoordinationPolyhedron(central_atom=central_atom, - vertices=vertex_atoms, + vertices=vertices, label=label)) return polyhedra diff --git a/pyproject.toml b/pyproject.toml index 204ffe0..9b3a166 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "polyhedral-analysis" -version = "0.6.0" +version = "0.6.1" description = "A library for analysis of coordination polyhedra from molecular dynamics trajectories" readme = "README.md" license = { text = "MIT" } diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b06b917..547ce0a 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,5 +1,8 @@ import unittest +import tempfile +import os from unittest.mock import Mock, patch +import numpy as np from pymatgen.core.structure import Structure from pymatgen.core.sites import Site from polyhedral_analysis.configuration import Configuration @@ -79,6 +82,47 @@ def test_polyhedra_by_label(self): self.assertEqual(p[0].label, 'foo') self.assertEqual(p[0].central_atom.index, 0) + def test_polyhedra_by_label_with_list(self): + p = self.configuration.polyhedra_by_label(['foo', 'bar']) + self.assertEqual(len(p), 2) + + def test_polyhedra_by_label_raises_type_error_for_invalid_type(self): + with self.assertRaises(TypeError): + self.configuration.polyhedra_by_label(123) + + def test_face_sharing_neighbour_list(self): + self.configuration.polyhedra[0].index = 0 + self.configuration.polyhedra[1].index = 2 + self.configuration.polyhedra[0].face_sharing_neighbour_list = Mock( + return_value=(2,)) + self.configuration.polyhedra[1].face_sharing_neighbour_list = Mock( + return_value=()) + result = self.configuration.face_sharing_neighbour_list(['foo', 'bar']) + self.assertEqual(result, {0: (2,), 2: ()}) + + def test_to_lattice_mc_writes_file(self): + self.configuration.polyhedra[0].index = 0 + self.configuration.polyhedra[1].index = 2 + self.configuration.polyhedra[0].central_atom.coords = np.array( + [1.0, 2.0, 3.0]) + self.configuration.polyhedra[1].central_atom.coords = np.array( + [4.0, 5.0, 6.0]) + neighbour_list = {0: (2,), 2: (0,)} + fd, fname = tempfile.mkstemp(suffix='.txt') + os.close(fd) + try: + self.configuration.to_lattice_mc( + fname, ['foo', 'bar'], neighbour_list) + with open(fname) as f: + content = f.read() + self.assertTrue(content.startswith('2\n\n')) + self.assertIn('site: 0', content) + self.assertIn('site: 2', content) + self.assertIn('label: foo', content) + self.assertIn('label: bar', content) + finally: + os.unlink(fname) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_coordination_polyhedron.py b/tests/test_coordination_polyhedron.py index 2a58816..78874c7 100644 --- a/tests/test_coordination_polyhedron.py +++ b/tests/test_coordination_polyhedron.py @@ -645,5 +645,251 @@ def test_vertex_vector_orientations(self): with self.assertRaises(ValueError): self.coordination_polyhedron.vertex_vector_orientations(reference='invalid') +class TestCoordinationPolyhedronLabel(unittest.TestCase): + + def test_default_label_from_central_atom(self): + mock_central_atom = Mock(spec=Atom) + mock_central_atom.in_polyhedra = [] + mock_central_atom.index = 0 + mock_central_atom.label = 'Ti' + mock_vertices = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(mock_vertices, 1): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly = CoordinationPolyhedron( + central_atom=mock_central_atom, vertices=mock_vertices) + self.assertEqual(poly.label, 'Ti') + + def test_explicit_label_overrides_central_atom(self): + mock_central_atom = Mock(spec=Atom) + mock_central_atom.in_polyhedra = [] + mock_central_atom.index = 0 + mock_central_atom.label = 'Ti' + mock_vertices = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(mock_vertices, 1): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly = CoordinationPolyhedron( + central_atom=mock_central_atom, vertices=mock_vertices, + label='oct') + self.assertEqual(poly.label, 'oct') + + def test_set_label(self): + mock_central_atom = Mock(spec=Atom) + mock_central_atom.in_polyhedra = [] + mock_central_atom.index = 0 + mock_central_atom.label = 'Ti' + mock_vertices = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(mock_vertices, 1): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly = CoordinationPolyhedron( + central_atom=mock_central_atom, vertices=mock_vertices) + poly.set_label('tet') + self.assertEqual(poly.label, 'tet') + + +class TestCoordinationPolyhedronEquality(unittest.TestCase): + + def setUp(self): + self.mock_central_atom_a = Mock(spec=Atom) + self.mock_central_atom_a.in_polyhedra = [] + self.mock_central_atom_a.index = 0 + self.mock_central_atom_a.label = 'A' + self.mock_central_atom_a.__eq__ = mock_atom_eq + self.mock_vertices_a = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(self.mock_vertices_a, 1): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.__eq__ = mock_atom_eq + v.in_polyhedra = [] + self.poly_a = CoordinationPolyhedron( + central_atom=self.mock_central_atom_a, + vertices=self.mock_vertices_a) + + def test_equal_vertices_true_for_same_indices(self): + mock_central_atom_b = Mock(spec=Atom) + mock_central_atom_b.in_polyhedra = [] + mock_central_atom_b.index = 10 + mock_central_atom_b.label = 'B' + mock_vertices_b = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(mock_vertices_b, 1): + v._neighbours = {} + v.index = i # same indices as poly_a + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly_b = CoordinationPolyhedron( + central_atom=mock_central_atom_b, vertices=mock_vertices_b) + self.assertTrue(self.poly_a.equal_vertices(poly_b)) + + def test_equal_vertices_false_for_different_indices(self): + mock_central_atom_b = Mock(spec=Atom) + mock_central_atom_b.in_polyhedra = [] + mock_central_atom_b.index = 10 + mock_central_atom_b.label = 'B' + mock_vertices_b = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(mock_vertices_b, 11): + v._neighbours = {} + v.index = i # different indices + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly_b = CoordinationPolyhedron( + central_atom=mock_central_atom_b, vertices=mock_vertices_b) + self.assertFalse(self.poly_a.equal_vertices(poly_b)) + + def test_eq_delegates_to_equal_edge_graph(self): + poly_b = copy.deepcopy(self.poly_a) + edge_graph = {1: [2, 3], 2: [1, 3], 3: [1, 2], 4: []} + self.poly_a._edge_graph = edge_graph + poly_b._edge_graph = edge_graph + self.assertEqual(self.poly_a, poly_b) + + def test_eq_false_for_different_edge_graphs(self): + poly_b = copy.deepcopy(self.poly_a) + self.poly_a._edge_graph = {1: [2, 3], 2: [1, 3], 3: [1, 2], 4: []} + poly_b._edge_graph = {1: [2], 2: [1], 3: [4], 4: [3]} + self.assertNotEqual(self.poly_a, poly_b) + + def test_eq_returns_not_implemented_for_non_polyhedron(self): + result = self.poly_a.__eq__('not a polyhedron') + self.assertIs(result, NotImplemented) + + +class TestCoordinationPolyhedronFromSitesOctahedron(unittest.TestCase): + """Tests using a real octahedron built from pymatgen PeriodicSite objects.""" + + def setUp(self): + from pymatgen.core import Lattice + from pymatgen.core.sites import PeriodicSite + lattice = Lattice.cubic(10.0) + central = PeriodicSite('Ti', [0.5, 0.5, 0.5], lattice) + vertices = [ + PeriodicSite('O', [0.6, 0.5, 0.5], lattice), + PeriodicSite('O', [0.4, 0.5, 0.5], lattice), + PeriodicSite('O', [0.5, 0.6, 0.5], lattice), + PeriodicSite('O', [0.5, 0.4, 0.5], lattice), + PeriodicSite('O', [0.5, 0.5, 0.6], lattice), + PeriodicSite('O', [0.5, 0.5, 0.4], lattice), + ] + self.poly = CoordinationPolyhedron.from_sites(central, vertices) + + def test_best_fit_geometry_is_octahedron(self): + result = self.poly.best_fit_geometry + self.assertEqual(result['geometry'], 'Octahedron') + self.assertAlmostEqual(result['symmetry_measure'], 0.0) + + def test_best_fit_geometry_has_expected_keys(self): + result = self.poly.best_fit_geometry + self.assertIn('geometry', result) + self.assertIn('symmetry_measure', result) + + def test_volume_of_regular_octahedron(self): + # Vertices at ±1 from centre: volume = 4/3 * a^3 where a = 1 + self.assertAlmostEqual(self.poly.volume, 4.0 / 3.0, places=5) + + def test_edge_graph_each_vertex_has_four_neighbours(self): + for neighbours in self.poly.edge_graph.values(): + self.assertEqual(len(neighbours), 4) + + def test_edge_graph_opposite_vertices_not_connected(self): + edge_graph = self.poly.edge_graph + vi = self.poly.vertex_indices + # Vertices 0,1 are ±x; 2,3 are ±y; 4,5 are ±z + opposite_pairs = [(vi[0], vi[1]), (vi[2], vi[3]), (vi[4], vi[5])] + for v1, v2 in opposite_pairs: + self.assertNotIn(v2, edge_graph[v1]) + self.assertNotIn(v1, edge_graph[v2]) + + def test_faces_returns_eight_triangles(self): + faces = self.poly.faces() + self.assertEqual(len(faces), 8) + for face in faces: + self.assertEqual(len(face), 3) + + def test_faces_are_sorted_tuples(self): + for face in self.poly.faces(): + self.assertIsInstance(face, tuple) + self.assertEqual(list(face), sorted(face)) + + +class TestMinimumImageVertexCoordinates(unittest.TestCase): + + def test_vertex_wraps_across_periodic_boundary(self): + from pymatgen.core import Lattice + from pymatgen.core.sites import PeriodicSite + lattice = Lattice.cubic(10.0) + central = PeriodicSite('Ti', [0.95, 0.5, 0.5], lattice) + vertices = [ + PeriodicSite('O', [0.05, 0.5, 0.5], lattice), # wraps across x + PeriodicSite('O', [0.95, 0.6, 0.5], lattice), + PeriodicSite('O', [0.95, 0.4, 0.5], lattice), + PeriodicSite('O', [0.95, 0.5, 0.6], lattice), + PeriodicSite('O', [0.95, 0.5, 0.4], lattice), + PeriodicSite('O', [0.85, 0.5, 0.5], lattice), + ] + poly = CoordinationPolyhedron.from_sites(central, vertices) + min_image_coords = poly.minimum_image_vertex_coordinates() + # Vertex at frac [0.05] wraps to image at frac [1.05] → 10.5 Angstrom + # Central atom is at 9.5 Angstrom, so distance is 1.0 Angstrom + np.testing.assert_array_almost_equal( + min_image_coords[0], [10.5, 5.0, 5.0]) + + +class TestIntersectionNoSharedVertices(unittest.TestCase): + + def test_no_shared_vertices_returns_empty_tuple(self): + mock_central_a = Mock(spec=Atom) + mock_central_a.in_polyhedra = [] + mock_central_a.index = 0 + mock_central_a.label = 'Li' + mock_central_b = Mock(spec=Atom) + mock_central_b.in_polyhedra = [] + mock_central_b.index = 1 + mock_central_b.label = 'Li' + verts_a = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(verts_a, 1): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + verts_b = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(verts_b, 11): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly_a = CoordinationPolyhedron( + central_atom=mock_central_a, vertices=verts_a) + poly_b = CoordinationPolyhedron( + central_atom=mock_central_b, vertices=verts_b) + self.assertEqual(poly_a.intersection(poly_b), ()) + + +class TestPolyhedronNotOwnNeighbour(unittest.TestCase): + + def test_polyhedron_is_not_its_own_neighbour(self): + mock_central = Mock(spec=Atom) + mock_central.in_polyhedra = [] + mock_central.index = 0 + mock_central.label = 'Li' + mock_vertices = [Mock(spec=Atom) for _ in range(4)] + for i, v in enumerate(mock_vertices, 1): + v._neighbours = {} + v.index = i + v.__lt__ = mock_atom_lt + v.in_polyhedra = [] + poly = CoordinationPolyhedron( + central_atom=mock_central, vertices=mock_vertices) + poly._edge_graph = Mock() + self.assertNotIn(poly, poly.neighbours()) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_orientation_parameters.py b/tests/test_orientation_parameters.py index dfb84bc..f41a9be 100644 --- a/tests/test_orientation_parameters.py +++ b/tests/test_orientation_parameters.py @@ -3,7 +3,7 @@ import numpy as np import math -from polyhedral_analysis.orientation_parameters import cos_theta, oct_rotational_order_parameter +from polyhedral_analysis.orientation_parameters import cos_theta, projection_xyz, oct_rotational_order_parameter class OrientationParametersTestCase( unittest.TestCase ): @@ -22,6 +22,18 @@ def test_cos_theta_three( self ): b = np.array( [ 0.0, 1.0, 1.0 ] ) self.assertTrue( cos_theta( a, b ) - math.sqrt(2)/2.0 < 1e-10 ) +class TestProjectionXyz(unittest.TestCase): + + def test_vector_along_axis_returns_one(self): + self.assertAlmostEqual(projection_xyz(np.array([1.0, 0.0, 0.0])), 1.0) + self.assertAlmostEqual(projection_xyz(np.array([0.0, 1.0, 0.0])), 1.0) + self.assertAlmostEqual(projection_xyz(np.array([0.0, 0.0, 1.0])), 1.0) + + def test_vector_along_diagonal_returns_minus_one_third(self): + self.assertAlmostEqual( + projection_xyz(np.array([1.0, 1.0, 1.0])), -1.0 / 3.0) + + class TestOctRotationalOrderParameter(unittest.TestCase): def test_perfect_aligned_octahedron_returns_one(self): @@ -36,6 +48,21 @@ def test_perfect_aligned_octahedron_returns_one(self): ]) self.assertAlmostEqual(oct_rotational_order_parameter(points), 1.0) + def test_45_degree_rotated_octahedron_returns_one_third(self): + # Rotate around z by 45 degrees: 4 in-plane vertices contribute 0, + # 2 axial vertices contribute 1; sum = 2/6 = 1/3. + s = math.sqrt(2) / 2 + points = np.array([ + [ s, s, 0.0], + [-s, -s, 0.0], + [-s, s, 0.0], + [ s, -s, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, -1.0], + ]) + self.assertAlmostEqual( + oct_rotational_order_parameter(points), 1.0 / 3.0) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_polyhedra_recipe.py b/tests/test_polyhedra_recipe.py index afa46bf..52803b3 100644 --- a/tests/test_polyhedra_recipe.py +++ b/tests/test_polyhedra_recipe.py @@ -5,9 +5,11 @@ polyhedra_from_distance_cutoff, generator_from_atom_argument, polyhedra_from_nearest_neighbours, - polyhedra_from_closest_centre) + polyhedra_from_closest_centre, + polyhedra_from_atom_indices) from polyhedral_analysis.coordination_polyhedron import CoordinationPolyhedron from pymatgen.core import Structure, Lattice +from pymatgen.core.sites import PeriodicSite from polyhedral_analysis.atom import Atom from unittest.mock import Mock, patch @@ -150,7 +152,223 @@ def distance(self, other): for atom in self.mock_central_atoms + self.mock_vertex_atoms: atom.distance = distance.__get__(atom) - + + +class TestGeneratorFromAtomArgument(unittest.TestCase): + + def test_callable_is_returned_directly(self): + fn = lambda structure: [0, 1] + result = generator_from_atom_argument(fn) + self.assertIs(result, fn) + + def test_list_of_ints_returns_those_indices(self): + gen = generator_from_atom_argument([3, 7]) + lattice = Lattice.cubic(10.0) + structure = Structure(lattice, ['Ti'] * 8, + [[i / 8, 0.5, 0.5] for i in range(8)]) + self.assertEqual(list(gen(structure)), [3, 7]) + + def test_raises_type_error_for_invalid_argument(self): + with self.assertRaises(TypeError): + generator_from_atom_argument(42) + + +class TestPolyhedraFromDistanceCutoff(unittest.TestCase): + + def test_includes_vertices_within_cutoff(self): + lattice = Lattice.cubic(10.0) + central = Atom(0, PeriodicSite('Ti', [0.5, 0.5, 0.5], lattice)) + close_coords = [ + [0.6, 0.5, 0.5], [0.4, 0.5, 0.5], + [0.5, 0.6, 0.5], [0.5, 0.4, 0.5], + [0.5, 0.5, 0.6], [0.5, 0.5, 0.4], + ] + far_coords = [[0.8, 0.5, 0.5], [0.2, 0.5, 0.5]] + vertex_atoms = [Atom(i + 1, PeriodicSite('O', c, lattice)) + for i, c in enumerate(close_coords + far_coords)] + # Cutoff of 1.5 Angstrom: close vertices are 1.0 Angstrom away, + # far vertices are 3.0 Angstrom away. + polyhedra = polyhedra_from_distance_cutoff( + central_atoms=[central], vertex_atoms=vertex_atoms, cutoff=1.5) + self.assertEqual(len(polyhedra), 1) + self.assertEqual(polyhedra[0].coordination_number, 6) + self.assertEqual(sorted(polyhedra[0].vertex_indices), [1, 2, 3, 4, 5, 6]) + + +class TestPolyhedraFromNearestNeighbours(unittest.TestCase): + + def test_selects_n_nearest_vertex_atoms(self): + lattice = Lattice.cubic(10.0) + central = Atom(0, PeriodicSite('Ti', [0.5, 0.5, 0.5], lattice)) + close_coords = [ + [0.6, 0.5, 0.5], [0.4, 0.5, 0.5], + [0.5, 0.6, 0.5], [0.5, 0.4, 0.5], + [0.5, 0.5, 0.6], [0.5, 0.5, 0.4], + ] + far_coords = [[0.8, 0.5, 0.5], [0.2, 0.5, 0.5]] + vertex_atoms = [Atom(i + 1, PeriodicSite('O', c, lattice)) + for i, c in enumerate(close_coords + far_coords)] + polyhedra = polyhedra_from_nearest_neighbours( + central_atoms=[central], vertex_atoms=vertex_atoms, nn=6) + self.assertEqual(len(polyhedra), 1) + self.assertEqual(polyhedra[0].coordination_number, 6) + self.assertEqual(sorted(polyhedra[0].vertex_indices), [1, 2, 3, 4, 5, 6]) + + def test_each_central_atom_gets_n_neighbours(self): + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.2, 0.5, 0.5], lattice)) + c2 = Atom(1, PeriodicSite('Ti', [0.8, 0.5, 0.5], lattice)) + vertex_coords = [ + [0.15, 0.5, 0.5], [0.25, 0.5, 0.5], + [0.2, 0.55, 0.5], [0.2, 0.45, 0.5], + [0.75, 0.5, 0.5], [0.85, 0.5, 0.5], + [0.8, 0.55, 0.5], [0.8, 0.45, 0.5], + ] + vertex_atoms = [Atom(i + 2, PeriodicSite('O', c, lattice)) + for i, c in enumerate(vertex_coords)] + polyhedra = polyhedra_from_nearest_neighbours( + central_atoms=[c1, c2], vertex_atoms=vertex_atoms, nn=4) + self.assertEqual(len(polyhedra), 2) + for p in polyhedra: + self.assertEqual(p.coordination_number, 4) + + +class TestPolyhedraFromClosestCentre(unittest.TestCase): + + def test_each_vertex_assigned_to_nearest_centre(self): + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.2, 0.5, 0.5], lattice)) + c2 = Atom(1, PeriodicSite('Ti', [0.8, 0.5, 0.5], lattice)) + v1 = Atom(2, PeriodicSite('O', [0.25, 0.5, 0.5], lattice)) + v2 = Atom(3, PeriodicSite('O', [0.15, 0.5, 0.5], lattice)) + v3 = Atom(4, PeriodicSite('O', [0.75, 0.5, 0.5], lattice)) + v4 = Atom(5, PeriodicSite('O', [0.85, 0.5, 0.5], lattice)) + polyhedra = polyhedra_from_closest_centre( + central_atoms=[c1, c2], vertex_atoms=[v1, v2, v3, v4]) + self.assertEqual(len(polyhedra), 2) + self.assertEqual(sorted(polyhedra[0].vertex_indices), [2, 3]) + self.assertEqual(sorted(polyhedra[1].vertex_indices), [4, 5]) + + def test_no_vertex_assigned_to_multiple_polyhedra(self): + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.2, 0.5, 0.5], lattice)) + c2 = Atom(1, PeriodicSite('Ti', [0.8, 0.5, 0.5], lattice)) + vertex_atoms = [ + Atom(2, PeriodicSite('O', [0.25, 0.5, 0.5], lattice)), + Atom(3, PeriodicSite('O', [0.15, 0.5, 0.5], lattice)), + Atom(4, PeriodicSite('O', [0.75, 0.5, 0.5], lattice)), + Atom(5, PeriodicSite('O', [0.85, 0.5, 0.5], lattice)), + ] + polyhedra = polyhedra_from_closest_centre( + central_atoms=[c1, c2], vertex_atoms=vertex_atoms) + all_vertex_indices = [] + for p in polyhedra: + all_vertex_indices.extend(p.vertex_indices) + self.assertEqual(len(all_vertex_indices), len(set(all_vertex_indices))) + + +class TestPolyhedraFromAtomIndices(unittest.TestCase): + + def test_correct_assignment_from_explicit_indices(self): + # Regression test for #15: vertex_atoms parameter was shadowed + # inside the loop, so only the first polyhedron got correct vertices. + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.2, 0.5, 0.5], lattice)) + c2 = Atom(1, PeriodicSite('Ti', [0.8, 0.5, 0.5], lattice)) + v1 = Atom(2, PeriodicSite('O', [0.25, 0.5, 0.5], lattice)) + v2 = Atom(3, PeriodicSite('O', [0.15, 0.5, 0.5], lattice)) + v3 = Atom(4, PeriodicSite('O', [0.75, 0.5, 0.5], lattice)) + v4 = Atom(5, PeriodicSite('O', [0.85, 0.5, 0.5], lattice)) + polyhedra = polyhedra_from_atom_indices( + central_atoms=[c1, c2], + vertex_atoms=[v1, v2, v3, v4], + central_indices=[0, 1], + vertex_indices=[[2, 3], [4, 5]]) + self.assertEqual(len(polyhedra), 2) + self.assertEqual(sorted(polyhedra[0].vertex_indices), [2, 3]) + self.assertEqual(sorted(polyhedra[1].vertex_indices), [4, 5]) + + def test_raises_for_mismatched_index_lengths(self): + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.5, 0.5, 0.5], lattice)) + v1 = Atom(1, PeriodicSite('O', [0.6, 0.5, 0.5], lattice)) + with self.assertRaises(ValueError): + polyhedra_from_atom_indices( + central_atoms=[c1], + vertex_atoms=[v1], + central_indices=[0, 1], + vertex_indices=[[1]]) + + def test_raises_for_invalid_central_index(self): + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.5, 0.5, 0.5], lattice)) + v1 = Atom(1, PeriodicSite('O', [0.6, 0.5, 0.5], lattice)) + with self.assertRaisesRegex(ValueError, 'Central atom index 99'): + polyhedra_from_atom_indices( + central_atoms=[c1], + vertex_atoms=[v1], + central_indices=[99], + vertex_indices=[[1]]) + + def test_raises_for_invalid_vertex_index(self): + lattice = Lattice.cubic(10.0) + c1 = Atom(0, PeriodicSite('Ti', [0.5, 0.5, 0.5], lattice)) + v1 = Atom(1, PeriodicSite('O', [0.6, 0.5, 0.5], lattice)) + with self.assertRaisesRegex(ValueError, 'Vertex atom indices'): + polyhedra_from_atom_indices( + central_atoms=[c1], + vertex_atoms=[v1], + central_indices=[0], + vertex_indices=[[88, 99]]) + + +class TestFindPolyhedraNearestNeighbours(unittest.TestCase): + + def test_end_to_end_nearest_neighbours(self): + lattice = Lattice.cubic(10.0) + species = ['Ti'] + ['O'] * 8 + coords = [ + [0.5, 0.5, 0.5], + [0.6, 0.5, 0.5], [0.4, 0.5, 0.5], + [0.5, 0.6, 0.5], [0.5, 0.4, 0.5], + [0.5, 0.5, 0.6], [0.5, 0.5, 0.4], + [0.8, 0.5, 0.5], [0.2, 0.5, 0.5], + ] + structure = Structure(lattice, species, coords) + atoms = [Atom(i, site) for i, site in enumerate(structure.sites)] + recipe = PolyhedraRecipe( + method='nearest neighbours', + central_atoms='Ti', + vertex_atoms='O', + n_neighbours=6) + polyhedra = recipe.find_polyhedra(atoms, structure) + self.assertEqual(len(polyhedra), 1) + self.assertEqual(polyhedra[0].coordination_number, 6) + self.assertEqual(sorted(polyhedra[0].vertex_indices), [1, 2, 3, 4, 5, 6]) + + +class TestFindPolyhedraClosestCentre(unittest.TestCase): + + def test_end_to_end_closest_centre(self): + lattice = Lattice.cubic(10.0) + species = ['Ti', 'Ti', 'O', 'O', 'O', 'O'] + coords = [ + [0.2, 0.5, 0.5], [0.8, 0.5, 0.5], + [0.25, 0.5, 0.5], [0.15, 0.5, 0.5], + [0.75, 0.5, 0.5], [0.85, 0.5, 0.5], + ] + structure = Structure(lattice, species, coords) + atoms = [Atom(i, site) for i, site in enumerate(structure.sites)] + recipe = PolyhedraRecipe( + method='closest centre', + central_atoms='Ti', + vertex_atoms='O') + polyhedra = recipe.find_polyhedra(atoms, structure) + self.assertEqual(len(polyhedra), 2) + self.assertEqual(sorted(polyhedra[0].vertex_indices), [2, 3]) + self.assertEqual(sorted(polyhedra[1].vertex_indices), [4, 5]) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index cea9e79..3d43810 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -6,6 +6,7 @@ from polyhedral_analysis.configuration import Configuration from pymatgen.io.vasp.outputs import Xdatcar from pymatgen.core.structure import Structure +from pymatgen.core import Lattice class TestTrajectoryInit( unittest.TestCase ): @@ -59,5 +60,68 @@ def test___len__(self): trajectory = self.trajectory self.assertEqual(len(trajectory), len(trajectory.structures)) + def test_add_returns_new_trajectory(self): + trajectory2 = copy.deepcopy(self.trajectory) + combined = self.trajectory + trajectory2 + self.assertEqual(len(combined), 4) + + def test_add_does_not_modify_originals(self): + trajectory2 = copy.deepcopy(self.trajectory) + _ = self.trajectory + trajectory2 + self.assertEqual(len(self.trajectory), 2) + self.assertEqual(len(trajectory2), 2) + + def test_add_returns_not_implemented_for_non_trajectory(self): + result = self.trajectory.__add__('not a trajectory') + self.assertIs(result, NotImplemented) + + +class TestTrajectoryFromStructures(unittest.TestCase): + + def test_returns_correct_number_of_configurations(self): + lattice = Lattice.cubic(10.0) + species = ['Ti'] + ['O'] * 6 + coords = [ + [0.5, 0.5, 0.5], + [0.6, 0.5, 0.5], [0.4, 0.5, 0.5], + [0.5, 0.6, 0.5], [0.5, 0.4, 0.5], + [0.5, 0.5, 0.6], [0.5, 0.5, 0.4], + ] + structure = Structure(lattice, species, coords) + recipe = PolyhedraRecipe( + method='nearest neighbours', + central_atoms='Ti', + vertex_atoms='O', + n_neighbours=6) + trajectory = Trajectory.from_structures( + [structure, structure], [recipe]) + self.assertEqual(len(trajectory), 2) + + def test_polyhedra_constructed_per_recipes(self): + lattice = Lattice.cubic(10.0) + species = ['Ti'] + ['O'] * 6 + coords = [ + [0.5, 0.5, 0.5], + [0.6, 0.5, 0.5], [0.4, 0.5, 0.5], + [0.5, 0.6, 0.5], [0.5, 0.4, 0.5], + [0.5, 0.5, 0.6], [0.5, 0.5, 0.4], + ] + structure = Structure(lattice, species, coords) + recipe = PolyhedraRecipe( + method='nearest neighbours', + central_atoms='Ti', + vertex_atoms='O', + n_neighbours=6) + trajectory = Trajectory.from_structures( + [structure], [recipe]) + self.assertEqual(len(trajectory.configurations[0].polyhedra), 1) + self.assertEqual( + trajectory.configurations[0].polyhedra[0].coordination_number, 6) + + def test_raises_type_error_for_non_bool_progress(self): + with self.assertRaises(TypeError): + Trajectory.from_structures([], [], progress='notebook') + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index 3bfb24a..139f79a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,12 +1,60 @@ import unittest +from unittest.mock import Mock + +import numpy as np + import polyhedral_analysis.utils as utils -class TestUtilsFunctions( unittest.TestCase ): - def test_flatten( self ): - nested = [ [ 1, 2, 3 ], [ 4, 5, 6 ] ] - flat = [ 1, 2, 3, 4, 5, 6 ] - self.assertEqual( flat, utils.flatten( nested ) ) +class TestFlatten(unittest.TestCase): + + def test_flatten(self): + nested = [[1, 2, 3], [4, 5, 6]] + self.assertEqual(utils.flatten(nested), [1, 2, 3, 4, 5, 6]) + + +class TestPruneNeighbourList(unittest.TestCase): + + def test_prunes_keys_not_in_indices(self): + neighbours = {1: (2, 3), 2: (1, 3), 4: (1,)} + result = utils.prune_neighbour_list(neighbours, [1, 2]) + self.assertNotIn(4, result) + + def test_prunes_values_not_in_indices(self): + neighbours = {1: (2, 3, 5), 2: (1, 3)} + result = utils.prune_neighbour_list(neighbours, [1, 2]) + self.assertNotIn(3, result[1]) + self.assertNotIn(5, result[1]) + + def test_retains_valid_neighbours(self): + neighbours = {1: (2, 3), 2: (1, 3), 3: (1, 2)} + result = utils.prune_neighbour_list(neighbours, [1, 2, 3]) + self.assertEqual(set(result[1]), {2, 3}) + self.assertEqual(set(result[2]), {1, 3}) + self.assertEqual(set(result[3]), {1, 2}) + + def test_empty_indices_returns_empty_dict(self): + neighbours = {1: (2, 3), 2: (1, 3)} + result = utils.prune_neighbour_list(neighbours, []) + self.assertEqual(result, {}) + + +class TestLatticeMcString(unittest.TestCase): + + def test_format(self): + poly = Mock() + poly.index = 5 + poly.central_atom.coords = np.array([1.5, 2.5, 3.5]) + poly.label = 'oct' + neighbour_list = {5: (7, 12)} + result = utils.lattice_mc_string(poly, neighbour_list) + lines = result.strip().split('\n') + self.assertEqual(lines[0], 'site: 5') + self.assertTrue(lines[1].startswith('centre: ')) + self.assertIn('1.50000000', lines[1]) + self.assertEqual(lines[2], 'neighbours: 7 12') + self.assertEqual(lines[3], 'label: oct') + if __name__ == '__main__': unittest.main()