Coverage for icet/tools/ground_state_finder.py: 97%
152 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-14 04:08 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-14 04:08 +0000
1from math import inf
2import numpy as np
4from ase import Atoms
5from ase.data import chemical_symbols as periodic_table
6from .. import ClusterExpansion
7from ..core.local_orbit_list_generator import LocalOrbitListGenerator
8from ..core.structure import Structure
9from .variable_transformation import transform_parameters
10from ..input_output.logging_tools import logger
12try:
13 import mip
14 from mip.constants import BINARY, INTEGER
15 mip_loaded = True
16except ImportError:
17 from types import SimpleNamespace
18 # For typing purposes
19 mip = SimpleNamespace(Model=None, OptimizationStatus=None)
20 mip_loaded = False
class GroundStateFinder:
    """
    This class provides functionality for determining the ground states
    using a binary cluster expansion. This is efficiently achieved through the
    use of mixed integer programming (MIP) as developed by Larsen *et al.* in
    `Phys. Rev. Lett. 120, 256101 (2018)
    <https://doi.org/10.1103/PhysRevLett.120.256101>`_.

    This class relies on the `Python-MIP package
    <https://python-mip.readthedocs.io>`_. Python-MIP can be used together
    with `Gurobi <https://www.gurobi.com/>`_, which is not open source
    but issues academic licenses free of charge. Please note that
    Gurobi needs to be installed separately. The :class:`GroundStateFinder` also
    works without Gurobi, but if performance is critical, Gurobi is highly
    recommended.

    Warning
    -------
    In order to be able to use Gurobi with python-mip one must ensure that
    `GUROBI_HOME` points to the installation directory
    (``<installdir>``)::

        export GUROBI_HOME=<installdir>

    Note
    ----
    The current implementation only works for binary systems, i.e., systems
    with at most two allowed species on each active sublattice.

    Parameters
    ----------
    cluster_expansion
        Cluster expansion for which to find ground states.
    structure
        Atomic configuration.
    solver_name
        ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
        solvers if no value is provided.
    verbose
        If ``True`` print solver messages to stdout.

    Example
    -------
    The following snippet illustrates how to determine the ground state for a
    Au-Ag alloy. Here, the parameters of the cluster
    expansion are set to emulate a simple Ising model in order to obtain an
    example that can be run without modification. In practice, one should of
    course use a proper cluster expansion::

        >>> from ase.build import bulk
        >>> from icet import ClusterExpansion, ClusterSpace

        >>> # prepare cluster expansion
        >>> # the setup emulates a second nearest-neighbor (NN) Ising model
        >>> # (zerolet and singlet parameters are zero; only first and second
        >>> # neighbor pairs are included)
        >>> prim = bulk('Au')
        >>> chemical_symbols = ['Ag', 'Au']
        >>> cs = ClusterSpace(prim, cutoffs=[4.3], chemical_symbols=chemical_symbols)
        >>> ce = ClusterExpansion(cs, [0, 0, 0.1, -0.02])

        >>> # prepare initial configuration
        >>> structure = prim.repeat(3)

        >>> # set up the ground state finder and calculate the ground state energy
        >>> gsf = GroundStateFinder(ce, structure)
        >>> ground_state = gsf.get_ground_state({'Ag': 5})
        >>> print('Ground state energy:', ce.predict(ground_state))
    """

    def __init__(self,
                 cluster_expansion: ClusterExpansion,
                 structure: Atoms,
                 solver_name: str = None,
                 verbose: bool = True) -> None:
        if not mip_loaded:
            raise ImportError('Python-MIP (https://python-mip.readthedocs.io/en/latest/) is '
                              'required in order to use the ground state finder.')

        self._cluster_expansion = cluster_expansion
        self._fractional_position_tolerance = cluster_expansion.fractional_position_tolerance
        self.structure = structure
        cluster_space = self._cluster_expansion.get_cluster_space_copy()
        primitive_structure = cluster_space.primitive_structure
        self._active_sublattices = cluster_space.get_sublattices(structure).active_sublattices

        # Check that there are no more than two allowed species
        active_species = [set(subl.chemical_symbols) for subl in self._active_sublattices]
        if any(len(species) > 2 for species in active_species):
            raise NotImplementedError('Currently, systems with more than two allowed species on '
                                      'any sublattice are not supported.')

        # Check that there are no merged orbits
        if any('merge' in elem for elem in cluster_space._pruning_history):
            raise NotImplementedError('Currently, systems with merged orbits are not supported.')

        self._active_species = active_species

        # Define cluster functions for elements: for each active sublattice,
        # map the binary variable value (0/1) back to a chemical symbol.
        self._reverse_id_maps = []
        for species in active_species:
            for species_map in cluster_space.species_maps:
                symbols = [periodic_table[n] for n in species_map]
                if set(symbols) == species:
                    reverse_id_map = {1 - species_map[n]: periodic_table[n] for n in species_map}
                    self._reverse_id_maps.append(reverse_id_map)
                    break
        # The species represented by x_i = 1 on each active sublattice; its
        # count is what the species-count constraint fixes.
        self._count_symbols = [reverse_id_map[1] for reverse_id_map in self._reverse_id_maps]

        # Generate full orbit list
        self._orbit_list = cluster_space.orbit_list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(primitive_structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        self._full_orbit_list = lolg.generate_full_orbit_list()

        # Transform the parameters
        binary_parameters = transform_parameters(primitive_structure,
                                                 self._full_orbit_list,
                                                 self._cluster_expansion.parameters)
        self._transformed_parameters = binary_parameters

        # Build model; an empty solver name lets Python-MIP choose the solver
        if solver_name is None:
            solver_name = ''
        self._model = self._build_model(structure, solver_name, verbose)

        # Properties that are defined when searching for a ground state
        self._optimization_status = None

    def _build_model(self,
                     structure: Atoms,
                     solver_name: str,
                     verbose: bool) -> mip.Model:
        """
        Build a Python-MIP model based on the provided structure.

        The species-count and slack constraints are created with placeholder
        right-hand sides; :meth:`get_ground_state` overwrites them before
        every solve.

        Parameters
        ----------
        structure
            Atomic configuration.
        solver_name
            ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
            solvers if no value is provided.
        verbose
            If ``True`` print solver messages to stdout.

        Returns
        -------
        The assembled MIP model.
        """
        # Create cluster maps
        self._create_cluster_maps(structure)

        # Initiate MIP model
        model = mip.Model('CE', solver_name=solver_name)
        model.solver.set_mip_gap(0)  # avoid stopping prematurely
        model.solver.set_emphasis(2)  # focus on finding optimal solution
        model.preprocess = 2  # maximum preprocessing

        # Set verbosity
        model.verbose = int(verbose)

        # Spin variables (remapped) for all atoms in the structure
        xs = {i: model.add_var(name='atom_{}'.format(i), var_type=BINARY)
              for subl in self._active_sublattices for i in subl.indices}
        ys = [model.add_var(name='cluster_{}'.format(i), var_type=BINARY)
              for i in range(len(self._cluster_to_orbit_map))]

        # The objective function is added to 'model' first
        model.objective = mip.minimize(mip.xsum(self._get_total_energy(ys)))

        # Connect cluster variables to spin variables with cluster constraints,
        # so that y_c equals the product of the x_i of its sites. Only the
        # inequality the minimizer could exploit needs to be added.
        # TODO: don't create cluster constraints for singlets
        constraint_count = 0
        for i, cluster in enumerate(self._cluster_to_sites_map):
            orbit = self._cluster_to_orbit_map[i]
            parameter = self._transformed_parameters[orbit + 1]
            assert parameter != 0

            # y_c <= x_i for all sites: prevents y_c = 1 unless every site is 1
            # (needed when a negative parameter rewards y_c = 1).
            if len(cluster) < 2 or parameter < 0:  # no "downwards" pressure
                for atom in cluster:
                    model.add_constr(ys[i] <= xs[atom],
                                     'Decoration -> cluster {}'.format(constraint_count))
                    constraint_count = constraint_count + 1

            # y_c >= 1 - |c| + sum(x_i): forces y_c = 1 when every site is 1
            # (needed when a positive parameter rewards y_c = 0).
            if len(cluster) < 2 or parameter > 0:  # no "upwards" pressure
                model.add_constr(ys[i] >= 1 - len(cluster) +
                                 mip.xsum(xs[atom]
                                          for atom in cluster),
                                 'Decoration -> cluster {}'.format(constraint_count))
                constraint_count = constraint_count + 1

        for sym, subl in zip(self._count_symbols, self._active_sublattices):
            # Create slack variable
            slack = model.add_var(name='slackvar_{}'.format(sym), var_type=INTEGER,
                                  lb=0, ub=len(subl.indices))

            # Add slack constraint (right-hand side -1 is a placeholder that
            # get_ground_state overwrites before solving)
            model.add_constr(slack <= -1, name='{} slack'.format(sym))

            # Set species constraint (right-hand side -1 is likewise a placeholder)
            model.add_constr(mip.xsum([xs[i] for i in subl.indices]) + slack == -1,
                             name='{} count'.format(sym))

        # Update the model so that variables and constraints can be queried
        if model.solver_name.upper() in ['GRB', 'GUROBI']:
            model.solver.update()
        return model

    def _create_cluster_maps(self, structure: Atoms) -> None:
        """
        Create maps that include information regarding which sites and orbits
        are associated with each cluster as well as the number of clusters per
        orbit.

        Clusters belonging to orbits whose transformed parameter is exactly
        zero are skipped. The results are stored in
        ``self._cluster_to_sites_map`` (site indices per cluster),
        ``self._cluster_to_orbit_map`` (orbit index per cluster) and
        ``self._nclusters_per_orbit`` (a leading 1 for the zerolet followed by
        the number of included clusters for each orbit, 0 for skipped orbits).

        Parameters
        ----------
        structure
            Atomic configuration.
        """
        # Generate full orbit list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        full_orbit_list = lolg.generate_full_orbit_list()

        # Create maps of site indices and orbits for all clusters, counting the
        # clusters per orbit in the same pass (index 0 is the zerolet)
        cluster_to_sites_map = []
        cluster_to_orbit_map = []
        nclusters_per_orbit = [1]
        for orb_index in range(len(full_orbit_list)):
            # Do not include clusters for which the parameter is 0
            parameter = self._transformed_parameters[orb_index + 1]
            if parameter == 0:
                nclusters_per_orbit.append(0)
                continue

            # Add the list of sites and the orbit to the respective cluster maps
            n_before = len(cluster_to_orbit_map)
            for cluster in full_orbit_list.get_orbit(orb_index).clusters:
                cluster_sites = [site.index for site in cluster.lattice_sites]
                cluster_to_sites_map.append(cluster_sites)
                cluster_to_orbit_map.append(orb_index)
            nclusters_per_orbit.append(len(cluster_to_orbit_map) - n_before)

        self._cluster_to_sites_map = cluster_to_sites_map
        self._cluster_to_orbit_map = cluster_to_orbit_map
        self._nclusters_per_orbit = nclusters_per_orbit

    def _get_total_energy(self, cluster_instance_activities: list[int]) -> list[float]:
        r"""
        Calculates the total energy using the expression based on binary variables.

        .. math::

            H({\boldsymbol x}, {\boldsymbol E})=E_0+
            \sum\limits_j\sum\limits_{{\boldsymbol c}
            \in{\boldsymbol C}_j}E_jy_{{\boldsymbol c}},

        where (:math:`y_{{\boldsymbol c}}=
        \prod\limits_{i\in{\boldsymbol c}}x_i`).

        Parameters
        ----------
        cluster_instance_activities
            list of cluster instance activities, (:math:`y_{{\boldsymbol c}}`)

        Returns
        -------
        One term per transformed parameter (index 0 is the zerolet). Each term is
        the summed activity of the orbit's clusters times the transformed
        parameter, divided by the number of clusters in that orbit; terms whose
        parameter is close to zero are ``0.0``. When MIP variables are passed,
        the terms are linear expressions.
        """
        E = [0.0 for _ in self._transformed_parameters]
        for activity, orbit in zip(cluster_instance_activities, self._cluster_to_orbit_map):
            E[orbit + 1] = E[orbit + 1] + activity
        E[0] = 1

        E = [0.0 if np.isclose(self._transformed_parameters[orbit], 0.0) else
             E[orbit] * self._transformed_parameters[orbit] / self._nclusters_per_orbit[orbit]
             for orbit in range(len(self._transformed_parameters))]
        return E

    def get_ground_state(self,
                         species_count: dict[str, int] = None,
                         max_seconds: float = inf,
                         threads: int = 0) -> Atoms:
        """Finds the ground state for a given structure and species count.

        For each active sublattice the count of at most one of its species may
        be specified. If no count is provided for a sublattice, the
        concentration on that sublattice is allowed to vary freely.

        Parameters
        ----------
        species_count
            Dictionary with count for one of the species on each active
            sublattice. If no count is provided for a sublattice, the
            concentration is allowed to vary.
        max_seconds
            Maximum runtime in seconds.
        threads
            Number of threads to be used when solving the problem, given that a
            positive integer has been provided. If set to :math:`0` the solver default
            configuration is used while :math:`-1` corresponds to all available
            processing cores.

        Returns
        -------
        A copy of the structure with the species on the active sublattices set
        according to the solution.

        Raises
        ------
        ValueError
            If a species in ``species_count`` is not allowed on any active
            sublattice, if counts are given for more than one species on the
            same sublattice, or if a count is negative or exceeds the number of
            sites on its sublattice.
        Exception
            If the solver terminates with a status other than optimal or feasible.
        """
        if species_count is None:
            species_count = {}

        # Check that the species_count is consistent with the cluster space
        all_active_species = set.union(*self._active_species)
        for symbol in species_count:
            if symbol not in all_active_species:
                raise ValueError('The species {} is not present on any of the active sublattices'
                                 ' ({})'.format(symbol, self._active_species))

        # The model is solved using python-MIPs choice of solver, which is
        # Gurobi, if available, and COIN-OR Branch-and-Cut, otherwise.
        model = self._model

        # Update the species counts
        for i, species in enumerate(self._active_species):
            count_symbol = self._count_symbols[i]
            max_count = len(self._active_sublattices[i].indices)

            symbols_to_add = set.intersection(set(species_count), set(species))
            if len(symbols_to_add) > 1:
                raise ValueError('Provide counts for at most one of the species on each active '
                                 'sublattice ({}), not {}!'.format(self._active_species,
                                                                   list(species_count)))
            elif len(symbols_to_add) == 1:
                sym = symbols_to_add.pop()
                count = species_count[sym]
                if count < 0 or count > max_count:
                    raise ValueError('The count for species {} ({}) must be a positive integer and'
                                     ' cannot exceed the number of sites on the active sublattice '
                                     '({})'.format(sym, count, max_count))
                # The model counts count_symbol; convert if the other species was given
                if sym == count_symbol:
                    xcount = count
                else:
                    xcount = max_count - count
                # Slack forced to zero: the count is fixed exactly
                max_slack = 0
            else:
                # sum(x) + slack == max_count with 0 <= slack <= max_count
                # lets the count take any value from 0 to max_count
                xcount = max_slack = max_count

            model.constr_by_name('{} count'.format(count_symbol)).rhs = xcount
            model.constr_by_name('{} slack'.format(count_symbol)).rhs = max_slack

        # Set the number of threads
        model.threads = threads

        # Optimize the model
        status = model.optimize(max_seconds=max_seconds)
        self._optimization_status = status

        # Compare against the enum members directly rather than their string form
        if status != mip.OptimizationStatus.OPTIMAL:
            if status == mip.OptimizationStatus.FEASIBLE:
                logger.warning('Solution optimality not proven.')
            else:
                raise Exception('Optimization failed ({0})'.format(str(status)))

        # Decode each atom variable into a chemical symbol on a copy of the structure
        gs = self.structure.copy()
        active_index_to_sublattice_map = {i: j for j, subl in enumerate(self._active_sublattices)
                                          for i in subl.indices}
        for v in model.vars:
            if v.name.startswith('atom_'):
                index = int(v.name.split('_')[-1])
                sublattice_index = active_index_to_sublattice_map[index]
                # Solvers return binaries only within an integer tolerance
                # (e.g. 0.9999999); round instead of truncating with int().
                value = int(round(v.x))
                gs[index].symbol = self._reverse_id_maps[sublattice_index][value]

        # Assert that the solution agrees with the prediction
        prediction = self._cluster_expansion.predict(gs)
        assert abs(model.objective_value - prediction) < 1e-6
        return gs

    @property
    def optimization_status(self) -> mip.OptimizationStatus:
        """ Optimization status. """
        return self._optimization_status

    @property
    def model(self) -> mip.Model:
        """ Python-MIP model (a copy; modifying it does not affect this instance). """
        return self._model.copy()