Coverage for icet/tools/variable_transformation.py: 98%
55 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
1from itertools import combinations, permutations
2from typing import List
4import numpy as np
5from ase import Atoms
6from icet.core.orbit import Orbit
7from icet.core.orbit_list import OrbitList
8from icet.core.lattice_site import LatticeSite
def _is_site_group_in_orbit(orbit: Orbit, site_group: List[LatticeSite]) -> bool:
    """Checks if a list of sites is found among the clusters in an orbit.
    The number of sites must match the order of the orbit.

    A site group matches a cluster if either

    * it consists of exactly the same set of lattice sites, or
    * its sites can be paired one-to-one with the cluster's sites such that
      paired sites share the same site index and every pair differs by the
      same unit cell offset (i.e., the two clusters are translations of
      each other).

    Parameters
    ----------
    orbit
        Orbit.
    site_group
        Sites to be searched for.

    Returns
    -------
    bool
        ``True`` if ``site_group`` matches one of the clusters in ``orbit``,
        ``False`` otherwise.
    """
    # Ensure that the number of sites matches the order of the orbit
    if len(site_group) != orbit.order:
        return False

    # Cheap check first: is the exact set of lattice sites one of the clusters?
    # any() short-circuits instead of building the sets of all clusters.
    target_sites = set(site_group)
    if any(target_sites == set(cluster.lattice_sites) for cluster in orbit.clusters):
        return True

    # The site indices of the query group do not depend on the cluster,
    # so compute them once outside the loop.
    target_site_indices = {site.index for site in site_group}
    for cluster in orbit.clusters:
        # Skip if the site indices do not match
        if {site.index for site in cluster.lattice_sites} != target_site_indices:
            continue

        # Loop over all permutations of the lattice sites in the cluster
        for cluster_site_group in permutations(cluster.lattice_sites):
            pairs = list(zip(site_group, cluster_site_group))

            # Skip all orderings that pair up sites with different site indices
            if any(site1.index != site2.index for site1, site2 in pairs):
                continue

            # If the relative offsets for all pairs of sites match, the two
            # clusters are equivalent (related by a pure translation)
            relative_offsets = [site1.unitcell_offset - site2.unitcell_offset
                                for site1, site2 in pairs]
            if all(np.array_equal(ro, relative_offsets[0]) for ro in relative_offsets):
                return True
    return False
def get_transformation_matrix(structure: Atoms,
                              full_orbit_list: OrbitList) -> np.ndarray:
    r"""
    Determines the matrix that transforms the cluster functions in the form
    of spin variables, :math:`\sigma_i\in\{-1,1\}`, to their binary
    equivalents, :math:`x_i\in\{0,1\}`. The form is obtained by
    performing the substitution (:math:`\sigma_i=1-2x_i`) in the
    cluster expansion expression of the predicted property (commonly the energy).

    Expanding :math:`\prod_k (1-2x_k)` over the sites of a cluster yields one
    term per subcluster :math:`S`, with coefficient :math:`(-2)^{|S|}`. Each
    such term is attributed to the orbit that the subcluster belongs to.

    Parameters
    ----------
    structure
        Atomic configuration (currently not used by this function).
    full_orbit_list
        Full orbit list.

    Returns
    -------
    np.ndarray
        Square matrix of size ``len(full_orbit_list) + 1``. Index 0 refers to
        the constant (zerolet) term and index ``i >= 1`` to orbit ``i - 1`` of
        ``full_orbit_list``. Entry ``[j, i]`` accumulates the contribution of
        the spin term of orbit ``i`` to the binary term of orbit ``j``.

    Raises
    ------
    ValueError
        If some subcluster of an orbit's representative cluster cannot be
        attributed to exactly one orbit of matching order in
        ``full_orbit_list``.
    """
    # Fetch every orbit once; the nested loops below would otherwise call
    # get_orbit repeatedly for the same indices.
    orbits = [full_orbit_list.get_orbit(index) for index in range(len(full_orbit_list))]
    n_orbits = len(orbits)

    transformation = np.zeros((n_orbits + 1, n_orbits + 1))
    transformation[0, 0] = 1.0
    for i, orbit in enumerate(orbits, 1):
        repr_sites = orbit.representative_cluster.lattice_sites
        # Add contributions to the lower order orbits to which the
        # subclusters belong
        for sub_order in range(orbit.order + 1):
            # Materialize the subclusters once: used both for the expected
            # number of terms and for the search below.
            sub_site_groups = list(combinations(repr_sites, sub_order))
            n_terms_target = len(sub_site_groups)
            coefficient = (-2.0) ** sub_order
            n_terms_actual = 0
            if sub_order == 0:
                # The empty subcluster contributes to the constant term
                transformation[0, i] += 1.0
                n_terms_actual += 1
            if sub_order == orbit.order:
                # The full cluster contributes to its own orbit
                transformation[i, i] += coefficient
                n_terms_actual += 1
            else:
                for sub_sites in sub_site_groups:
                    for j, sub_orbit in enumerate(orbits, 1):
                        if sub_orbit.order != sub_order:
                            continue
                        if _is_site_group_in_orbit(sub_orbit, sub_sites):
                            transformation[j, i] += coefficient
                            n_terms_actual += 1
            # If the number of contributions does not match the number of
            # subclusters, some subcluster is missing from (or ambiguous in)
            # the orbit list, which makes it incompatible with the ground
            # state finder.
            if n_terms_actual != n_terms_target:
                raise ValueError('At least one cluster had subclusters that were not included'
                                 ' in the cluster space. This is typically caused by cutoffs'
                                 ' that are longer for a higher-order orbit than lower-order one'
                                 ' (such as 8 Angstrom for triplets and 6 Angstrom for pairs).'
                                 ' Please use a different cluster space for the ground state'
                                 ' finder.')

    return transformation
def transform_parameters(structure: Atoms,
                         full_orbit_list: OrbitList,
                         parameters: np.ndarray) -> np.ndarray:
    r"""
    Transforms the list of parameters, obtained using cluster functions in the
    form of spin variables, :math:`\sigma_i\in\{-1,1\}`, to their
    equivalents for the case of binary variables,
    :math:`x_i\in\{0,1\}`.

    Parameters
    ----------
    structure
        Atomic configuration; forwarded to :func:`get_transformation_matrix`.
    full_orbit_list
        Full orbit list.
    parameters
        Parameter vector (spin variables).

    Returns
    -------
    np.ndarray
        Parameter vector for binary variables, i.e. the transformation matrix
        from :func:`get_transformation_matrix` multiplied with ``parameters``.

    Raises
    ------
    ValueError
        Propagated from :func:`get_transformation_matrix` when the orbit list
        is incompatible with the transformation.
    """
    spin_to_binary = get_transformation_matrix(structure, full_orbit_list)
    return spin_to_binary.dot(parameters)