Coverage for icet/tools/variable_transformation.py: 98%
54 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-14 04:08 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-14 04:08 +0000
1from itertools import combinations, permutations
3import numpy as np
4from ase import Atoms
5from icet.core.orbit import Orbit
6from icet.core.orbit_list import OrbitList
7from icet.core.lattice_site import LatticeSite
10def _is_site_group_in_orbit(orbit: Orbit, site_group: list[LatticeSite]) -> bool:
11 """Checks if a list of sites is found among the clusters in an orbit.
12 The number of sites must match the order of the orbit.
14 Parameters
15 ----------
16 orbit
17 Orbit.
18 site_group
19 Sites to be searched for.
20 """
22 # Ensure that the number of sites matches the order of the orbit
23 if len(site_group) != orbit.order: 23 ↛ 24line 23 didn't jump to line 24 because the condition on line 23 was never true
24 return False
26 # Check if the set of lattice sites is found among the equivalent sites
27 if set(site_group) in [set(cl.lattice_sites) for cl in orbit.clusters]:
28 return True
30 # Go through all clusters
31 site_indices = [site.index for site in site_group]
32 for cluster in orbit.clusters:
33 cluster_site_indices = [s.index for s in cluster.lattice_sites]
35 # Skip if the site indices do not match
36 if set(site_indices) != set(cluster_site_indices):
37 continue
39 # Loop over all permutations of the lattice sites in cluster
40 for cluster_site_group in permutations(cluster.lattice_sites):
42 # Skip all cases that include pairs of sites with different site indices
43 if any(site1.index != site2.index
44 for site1, site2 in zip(site_group, cluster_site_group)):
45 continue
47 # If the relative offsets for all pairs of sites match, the two
48 # clusters are equivalent
49 relative_offsets = [site1.unitcell_offset - site2.unitcell_offset
50 for site1, site2 in zip(site_group, cluster_site_group)]
51 if all(np.array_equal(ro, relative_offsets[0]) for ro in relative_offsets):
52 return True
53 return False
def get_transformation_matrix(structure: Atoms,
                              full_orbit_list: OrbitList) -> np.ndarray:
    r"""
    Determines the matrix that transforms the cluster functions in the form
    of spin variables, :math:`\sigma_i\in\{-1,1\}`, to their binary
    equivalents, :math:`x_i\in\{0,1\}`. The form is obtained by
    performing the substitution (:math:`\sigma_i=1-2x_i`) in the
    cluster expansion expression of the predicted property (commonly the energy).

    Parameters
    ----------
    structure
        Atomic configuration (not used by this function).
    full_orbit_list
        Full orbit list.

    Returns
    -------
    np.ndarray
        Square matrix with ``len(full_orbit_list) + 1`` rows and columns.
        Index 0 holds the constant term and index ``k + 1`` belongs to
        orbit ``k`` of the orbit list.

    Raises
    ------
    ValueError
        If the number of contributions found for some orbit and subcluster
        order differs from the number of subclusters of that order.
    """
    # Fetch every orbit once and group them by order, so that the innermost
    # loop neither repeats the lookup nor visits orbits of the wrong order.
    # Matrix indices are shifted by one because index 0 is the constant term.
    orbits = [full_orbit_list.get_orbit(k) for k in range(len(full_orbit_list))]
    orbits_by_order = {}
    for row, candidate in enumerate(orbits, 1):
        orbits_by_order.setdefault(candidate.order, []).append((row, candidate))

    transformation = np.zeros((len(orbits) + 1, len(orbits) + 1))
    transformation[0, 0] = 1.0

    # Go through all clusters associated with each active orbit and
    # determine its contribution to each orbit
    for i, orbit in enumerate(orbits, 1):
        repr_sites = orbit.representative_cluster.lattice_sites

        # add contributions to the lower order orbits to which the
        # subclusters belong
        for sub_order in range(orbit.order + 1):
            n_terms_target = sum(1 for _ in combinations(repr_sites, sub_order))
            n_terms_actual = 0
            weight = (-2.0) ** (sub_order)

            if sub_order == 0:
                transformation[0, i] += 1.0
                n_terms_actual += 1
            if sub_order == orbit.order:
                transformation[i, i] += weight
                n_terms_actual += 1
            else:
                candidates = orbits_by_order.get(sub_order, ())
                for sub_sites in combinations(repr_sites, sub_order):
                    for j, sub_orbit in candidates:
                        if _is_site_group_in_orbit(sub_orbit, sub_sites):
                            transformation[j, i] += weight
                            n_terms_actual += 1

            # If the number of contributions does not match the number of
            # subclusters, this orbit list is incompatible with the ground
            # state finder
            if n_terms_actual != n_terms_target:
                raise ValueError('At least one cluster had subclusters that were not included'
                                 ' in the cluster space. This is typically caused by cutoffs'
                                 ' that are longer for a higher-order orbit than lower-order one'
                                 ' (such as 8 Angstrom for triplets and 6 Angstrom for pairs).'
                                 ' Please use a different cluster space for the ground state '
                                 ' finder.')

    return transformation
def transform_parameters(structure: Atoms,
                         full_orbit_list: OrbitList,
                         parameters: np.ndarray) -> np.ndarray:
    r"""
    Converts a parameter vector that was obtained with cluster functions
    expressed in spin variables, :math:`\sigma_i\in\{-1,1\}`, into the
    corresponding vector for binary variables, :math:`x_i\in\{0,1\}`.

    The result is the product of the matrix returned by
    :func:`get_transformation_matrix` and the parameter vector.

    Parameters
    ----------
    structure
        Atomic configuration.
    full_orbit_list
        Full orbit list.
    parameters
        Parameter vector (spin variables).
    """
    spin_to_binary = get_transformation_matrix(structure, full_orbit_list)
    return spin_to_binary.dot(parameters)