Coverage for icet/tools/ground_state_finder.py: 97%
148 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
2from math import inf
3import numpy as np
4from typing import List, Dict
6from ase import Atoms
7from ase.data import chemical_symbols as periodic_table
8from .. import ClusterExpansion
9from ..core.local_orbit_list_generator import LocalOrbitListGenerator
10from ..core.structure import Structure
11from .variable_transformation import transform_parameters
12from ..input_output.logging_tools import logger
14try:
15 import mip
16 from mip.constants import BINARY, INTEGER
17except ImportError:
18 raise ImportError('Python-MIP (https://python-mip.readthedocs.io/en/latest/) is required in '
19 'order to use the ground state finder.')
class GroundStateFinder:
    """
    This class provides functionality for determining the ground states
    using a binary cluster expansion. This is efficiently achieved through the
    use of mixed integer programming (MIP) as developed by Larsen *et al.* in
    `Phys. Rev. Lett. 120, 256101 (2018)
    <https://doi.org/10.1103/PhysRevLett.120.256101>`_.

    This class relies on the `Python-MIP package
    <https://python-mip.readthedocs.io>`_. Python-MIP can be used together
    with `Gurobi <https://www.gurobi.com/>`_, which is not open source
    but issues academic licenses free of charge. Please note that
    Gurobi needs to be installed separately. The :class:`GroundStateFinder` works
    also without Gurobi, but if performance is critical, Gurobi is highly
    recommended.

    Warning
    -------
    In order to be able to use Gurobi with python-mip one must ensure that
    `GUROBI_HOME` points to the installation directory
    (``<installdir>``)::

        export GUROBI_HOME=<installdir>

    Note
    ----
    The current implementation only works for binary systems, i.e., each
    active sublattice may host at most two species.

    Parameters
    ----------
    cluster_expansion
        Cluster expansion for which to find ground states.
    structure
        Atomic configuration.
    solver_name
        ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
        solvers if no value is provided.
    verbose
        If ``True`` print solver messages to stdout.

    Example
    -------
    The following snippet illustrates how to determine the ground state for a
    Au-Ag alloy. Here, the parameters of the cluster
    expansion are set to emulate a simple Ising model in order to obtain an
    example that can be run without modification. In practice, one should of
    course use a proper cluster expansion::

        >>> from ase.build import bulk
        >>> from icet import ClusterExpansion, ClusterSpace

        >>> # prepare cluster expansion
        >>> # the setup emulates a second nearest-neighbor (NN) Ising model
        >>> # (zerolet and singlet parameters are zero; only first and second neighbor
        >>> # pairs are included)
        >>> prim = bulk('Au')
        >>> chemical_symbols = ['Ag', 'Au']
        >>> cs = ClusterSpace(prim, cutoffs=[4.3], chemical_symbols=chemical_symbols)
        >>> ce = ClusterExpansion(cs, [0, 0, 0.1, -0.02])

        >>> # prepare initial configuration
        >>> structure = prim.repeat(3)

        >>> # set up the ground state finder and calculate the ground state energy
        >>> gsf = GroundStateFinder(ce, structure)
        >>> ground_state = gsf.get_ground_state({'Ag': 5})
        >>> print('Ground state energy:', ce.predict(ground_state))
    """

    def __init__(self,
                 cluster_expansion: ClusterExpansion,
                 structure: Atoms,
                 solver_name: str = None,
                 verbose: bool = True) -> None:
        self._cluster_expansion = cluster_expansion
        self._fractional_position_tolerance = cluster_expansion.fractional_position_tolerance
        self.structure = structure
        cluster_space = self._cluster_expansion.get_cluster_space_copy()
        primitive_structure = cluster_space.primitive_structure
        # Several active sublattices are allowed; each one receives its own
        # species-count constraint in the MIP model
        self._active_sublattices = cluster_space.get_sublattices(structure).active_sublattices

        # Check that there are no more than two allowed species
        active_species = [set(subl.chemical_symbols) for subl in self._active_sublattices]
        if any(len(species) > 2 for species in active_species):
            raise NotImplementedError('Currently, systems with more than two allowed species on '
                                      'any sublattice are not supported.')

        # Check that there are no merged orbits
        if any('merge' in elem for elem in cluster_space._pruning_history):
            raise NotImplementedError('Currently, systems with merged orbits are not supported.')

        self._active_species = active_species

        # For each active sublattice, map the value of a binary site variable
        # (0 or 1) to a chemical symbol; x = 1 corresponds to the species that
        # has index 0 in the cluster space species map
        self._reverse_id_maps = []
        for species in active_species:
            for species_map in cluster_space.species_maps:
                if {periodic_table[n] for n in species_map} == species:
                    self._reverse_id_maps.append(
                        {1 - species_map[n]: periodic_table[n] for n in species_map})
                    break
            else:
                # Without this, the per-sublattice lists below would silently
                # go out of sync with self._active_sublattices
                raise ValueError('Could not find a species map matching the species {} of an'
                                 ' active sublattice.'.format(sorted(species)))
        # The species counted by the binary variables (x = 1) on each sublattice
        self._count_symbols = [reverse_id_map[1] for reverse_id_map in self._reverse_id_maps]

        # Generate full orbit list
        self._orbit_list = cluster_space.orbit_list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(primitive_structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        self._full_orbit_list = lolg.generate_full_orbit_list()

        # Transform the parameters
        binary_parameters = transform_parameters(primitive_structure,
                                                 self._full_orbit_list,
                                                 self._cluster_expansion.parameters)
        self._transformed_parameters = binary_parameters

        # Build model; an empty solver name lets Python-MIP pick the solver
        if solver_name is None:
            solver_name = ''
        self._model = self._build_model(structure, solver_name, verbose)

        # Properties that are defined when searching for a ground state
        self._optimization_status = None

    def _build_model(self,
                     structure: Atoms,
                     solver_name: str,
                     verbose: bool) -> mip.Model:
        """
        Build a Python-MIP model based on the provided structure.

        Besides returning the model, this sets the cluster maps (via
        :meth:`_create_cluster_maps`) as well as ``self._slack_constraints``
        and ``self._count_constraints``, which hold one constraint per active
        sublattice in the order of ``self._active_sublattices``.

        Parameters
        ----------
        structure
            Atomic configuration.
        solver_name
            ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
            solvers if no value is provided.
        verbose
            If ``True`` print solver messages to stdout.
        """

        # Create cluster maps
        self._create_cluster_maps(structure)

        # Initiate MIP model
        model = mip.Model('CE', solver_name=solver_name)
        model.solver.set_mip_gap(0)  # avoid stopping prematurely
        model.solver.set_emphasis(2)  # focus on finding optimal solution
        model.preprocess = 2  # maximum preprocessing

        # Set verbosity
        model.verbose = int(verbose)

        # Spin variables (remapped) for all atoms in the structure
        xs = {i: model.add_var(name='atom_{}'.format(i), var_type=BINARY)
              for subl in self._active_sublattices for i in subl.indices}
        ys = [model.add_var(name='cluster_{}'.format(i), var_type=BINARY)
              for i in range(len(self._cluster_to_orbit_map))]

        # The objective function is added to 'model' first
        model.objective = mip.minimize(mip.xsum(self._get_total_energy(ys)))

        # Connect cluster variables to spin variables with cluster constraints
        # TODO: don't create cluster constraints for singlets
        constraint_count = 0
        for i, cluster in enumerate(self._cluster_to_sites_map):
            orbit = self._cluster_to_orbit_map[i]
            parameter = self._transformed_parameters[orbit + 1]
            assert parameter != 0

            if len(cluster) < 2 or parameter < 0:  # no "downwards" pressure
                for atom in cluster:
                    model.add_constr(ys[i] <= xs[atom],
                                     'Decoration -> cluster {}'.format(constraint_count))
                    constraint_count = constraint_count + 1

            if len(cluster) < 2 or parameter > 0:  # no "upwards" pressure
                model.add_constr(ys[i] >= 1 - len(cluster) +
                                 mip.xsum(xs[atom]
                                          for atom in cluster),
                                 'Decoration -> cluster {}'.format(constraint_count))
                constraint_count = constraint_count + 1

        # Species-count constraints, one pair per active sublattice. The
        # right-hand sides (-1) are placeholders that get_ground_state
        # overwrites before every optimization. The constraints are kept by
        # reference because their names only contain the count symbol, which
        # several sublattices may share, so a lookup by name can be ambiguous.
        self._slack_constraints = []
        self._count_constraints = []
        for sym, subl in zip(self._count_symbols, self._active_sublattices):
            # Create slack variable
            slack = model.add_var(name='slackvar_{}'.format(sym), var_type=INTEGER,
                                  lb=0, ub=len(subl.indices))

            # Add slack constraint
            self._slack_constraints.append(
                model.add_constr(slack <= -1, name='{} slack'.format(sym)))

            # Set species constraint
            self._count_constraints.append(
                model.add_constr(mip.xsum([xs[i] for i in subl.indices]) + slack == -1,
                                 name='{} count'.format(sym)))

        # Update the model so that variables and constraints can be queried
        if model.solver_name.upper() in ['GRB', 'GUROBI']:
            model.solver.update()
        return model

    def _create_cluster_maps(self, structure: Atoms) -> None:
        """
        Create maps that include information regarding which sites and orbits
        are associated with each cluster as well as the number of clusters per
        orbit.

        Only clusters belonging to orbits with a non-zero transformed
        parameter are included in the site and orbit maps. The number of
        clusters per orbit is stored for every orbit (zero for skipped
        orbits), preceded by ``1`` for the zerolet.

        Parameters
        ----------
        structure
            Atomic configuration.
        """
        # Generate full orbit list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        full_orbit_list = lolg.generate_full_orbit_list()
        n_orbits = len(full_orbit_list)

        # Create maps of site indices and orbits for all clusters
        cluster_to_sites_map = []
        cluster_to_orbit_map = []
        nclusters = [0] * n_orbits
        for orb_index in range(n_orbits):
            # Do not include clusters for which the parameter is 0
            if self._transformed_parameters[orb_index + 1] == 0:
                continue

            # Add the list of sites and the orbit to the respective cluster maps
            for cluster in full_orbit_list.get_orbit(orb_index).clusters:
                cluster_to_sites_map.append([site.index for site in cluster.lattice_sites])
                cluster_to_orbit_map.append(orb_index)
                nclusters[orb_index] += 1

        self._cluster_to_sites_map = cluster_to_sites_map
        self._cluster_to_orbit_map = cluster_to_orbit_map
        # The leading 1 accounts for the zerolet; counting over all orbits
        # also works when no non-zerolet parameter is non-zero
        self._nclusters_per_orbit = [1] + nclusters

    def _get_total_energy(self, cluster_instance_activities: List[mip.Var]) -> list:
        r"""
        Calculates the total energy using the expression based on binary variables.

        .. math::

            H({\boldsymbol x}, {\boldsymbol E})=E_0+
            \sum\limits_j\sum\limits_{{\boldsymbol c}
            \in{\boldsymbol C}_j}E_jy_{{\boldsymbol c}},

        where (:math:`y_{{\boldsymbol c}}=
        \prod\limits_{i\in{\boldsymbol c}}x_i`).

        Parameters
        ----------
        cluster_instance_activities
            list of cluster instance activities, (:math:`y_{{\boldsymbol c}}`),
            in the same order as ``self._cluster_to_orbit_map``

        Returns
        -------
        One energy term per transformed parameter (zerolet first). Each term
        is the parameter times the summed activities of its orbit divided by
        the number of clusters in that orbit; terms whose parameter is close
        to zero are ``0.0``.
        """
        # Sum the cluster activities per orbit (index 0 is the zerolet)
        activity_sums = [0.0 for _ in self._transformed_parameters]
        for orbit, activity in zip(self._cluster_to_orbit_map, cluster_instance_activities):
            activity_sums[orbit + 1] = activity_sums[orbit + 1] + activity
        activity_sums[0] = 1

        return [0.0 if np.isclose(parameter, 0.0) else
                activity_sums[k] * parameter / self._nclusters_per_orbit[k]
                for k, parameter in enumerate(self._transformed_parameters)]

    def get_ground_state(self,
                         species_count: Dict[str, int] = None,
                         max_seconds: float = inf,
                         threads: int = 0) -> Atoms:
        """Finds the ground state for a given structure and species count.

        Species counts are imposed per active sublattice. If no count is
        provided for an active sublattice (in particular if
        :attr:`species_count` is omitted), the concentration on that
        sublattice is allowed to vary freely.

        Parameters
        ----------
        species_count
            Dictionary with count for one of the species on each active
            sublattice. If no count is provided for a sublattice, the
            concentration is allowed to vary.
        max_seconds
            Maximum runtime in seconds.
        threads
            Number of threads to be used when solving the problem, given that a
            positive integer has been provided. If set to :math:`0` the solver default
            configuration is used while :math:`-1` corresponds to all available
            processing cores.

        Returns
        -------
        A copy of the structure passed at initialization, decorated according
        to the solution found by the solver.

        Raises
        ------
        ValueError
            If a species in :attr:`species_count` is not allowed on any active
            sublattice, if counts are given for more than one species of the
            same sublattice, or if a count is negative or exceeds the number of
            sites on its sublattice.
        Exception
            If the solver status is neither optimal nor feasible.
        """
        if species_count is None:
            species_count = {}

        # Check that the species_count is consistent with the cluster space
        all_active_species = set.union(*self._active_species)
        for symbol in species_count:
            if symbol not in all_active_species:
                raise ValueError('The species {} is not present on any of the active sublattices'
                                 ' ({})'.format(symbol, self._active_species))

        # The model is solved using python-MIPs choice of solver, which is
        # Gurobi, if available, and COIN-OR Branch-and-Cut, otherwise.
        model = self._model

        # Update the species counts by overwriting the right-hand sides of the
        # placeholder constraints created in _build_model
        for i, species in enumerate(self._active_species):
            count_symbol = self._count_symbols[i]
            max_count = len(self._active_sublattices[i].indices)

            symbols_to_add = set.intersection(set(species_count), set(species))
            if len(symbols_to_add) > 1:
                raise ValueError('Provide counts for at most one of the species on each active '
                                 'sublattice ({}), not {}!'.format(self._active_species,
                                                                  list(species_count)))
            elif len(symbols_to_add) == 1:
                sym = symbols_to_add.pop()
                count = species_count[sym]
                if count < 0 or count > max_count:
                    raise ValueError('The count for species {} ({}) must be a positive integer and'
                                     ' cannot exceed the number of sites on the active sublattice '
                                     '({})'.format(sym, count, max_count))
                # The binary variables count `count_symbol`; convert if needed
                if sym == count_symbol:
                    xcount = count
                else:
                    xcount = max_count - count
                # Zero slack pins the count exactly
                max_slack = 0
            else:
                # Slack in [0, max_count] lets the count take any value
                xcount = max_slack = max_count

            self._count_constraints[i].rhs = xcount
            self._slack_constraints[i].rhs = max_slack

        # Set the number of threads
        model.threads = threads

        # Optimize the model
        self._optimization_status = model.optimize(max_seconds=max_seconds)

        # Warn if optimality is not proven and fail if no solution was found
        if self._optimization_status != mip.OptimizationStatus.OPTIMAL:
            if self._optimization_status == mip.OptimizationStatus.FEASIBLE:
                logger.warning('Solution optimality not proven.')
            else:
                raise Exception('Optimization failed ({0})'.format(str(self._optimization_status)))

        # Decorate a copy of the structure according to the resolved site variables
        gs = self.structure.copy()
        active_index_to_sublattice_map = {i: j for j, subl in enumerate(self._active_sublattices)
                                          for i in subl.indices}
        for v in model.vars:
            if v.name.startswith('atom_'):
                index = int(v.name.split('_')[-1])
                sublattice_index = active_index_to_sublattice_map[index]
                # Solvers return binary values only up to an integrality
                # tolerance (e.g. 0.9999999), so round instead of truncating
                gs[index].symbol = self._reverse_id_maps[sublattice_index][int(round(v.x))]

        # Assert that the solution agrees with the prediction
        prediction = self._cluster_expansion.predict(gs)
        assert abs(model.objective_value - prediction) < 1e-6
        return gs

    @property
    def optimization_status(self) -> mip.OptimizationStatus:
        """ Optimization status of the most recent call to :meth:`get_ground_state`
        (``None`` before the first call). """
        return self._optimization_status

    @property
    def model(self) -> mip.Model:
        """ Copy of the Python-MIP model; modifying it does not affect this
        ground state finder. """
        return self._model.copy()