Coverage for icet/tools/ground_state_finder.py: 97%

152 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-09-14 04:08 +0000

1from math import inf 

2import numpy as np 

3 

4from ase import Atoms 

5from ase.data import chemical_symbols as periodic_table 

6from .. import ClusterExpansion 

7from ..core.local_orbit_list_generator import LocalOrbitListGenerator 

8from ..core.structure import Structure 

9from .variable_transformation import transform_parameters 

10from ..input_output.logging_tools import logger 

11 

12try: 

13 import mip 

14 from mip.constants import BINARY, INTEGER 

15 mip_loaded = True 

16except ImportError: 

17 from types import SimpleNamespace 

18 # For typing purposes 

19 mip = SimpleNamespace(Model=None, OptimizationStatus=None) 

20 mip_loaded = False 

21 

22 

class GroundStateFinder:
    """
    This class provides functionality for determining the ground states
    using a binary cluster expansion. This is efficiently achieved through the
    use of mixed integer programming (MIP) as developed by Larsen *et al.* in
    `Phys. Rev. Lett. 120, 256101 (2018)
    <https://doi.org/10.1103/PhysRevLett.120.256101>`_.

    This class relies on the `Python-MIP package
    <https://python-mip.readthedocs.io>`_. Python-MIP can be used together
    with `Gurobi <https://www.gurobi.com/>`_, which is not open source
    but issues academic licenses free of charge. Please note that
    Gurobi needs to be installed separately. The :class:`GroundStateFinder` also
    works without Gurobi, but if performance is critical, Gurobi is highly
    recommended.

    Warning
    -------
    In order to be able to use Gurobi with python-mip one must ensure that
    `GUROBI_HOME` points to the installation directory
    (``<installdir>``)::

        export GUROBI_HOME=<installdir>

    Note
    ----
    The current implementation only works for binary systems, i.e., at most
    two species may be allowed on each active sublattice. Cluster spaces with
    merged orbits are not supported either.

    Parameters
    ----------
    cluster_expansion
        Cluster expansion for which to find ground states.
    structure
        Atomic configuration.
    solver_name
        ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
        solvers if no value is provided.
    verbose
        If ``True`` print solver messages to stdout.

    Raises
    ------
    ImportError
        If Python-MIP is not installed.
    NotImplementedError
        If any active sublattice allows more than two species or if the
        cluster space contains merged orbits.

    Example
    -------
    The following snippet illustrates how to determine the ground state for a
    Au-Ag alloy. Here, the parameters of the cluster
    expansion are set to emulate a simple Ising model in order to obtain an
    example that can be run without modification. In practice, one should of
    course use a proper cluster expansion::

        >>> from ase.build import bulk
        >>> from icet import ClusterExpansion, ClusterSpace

        >>> # prepare cluster expansion
        >>> # the setup emulates a second nearest-neighbor (NN) Ising model
        >>> # (zerolet and singlet parameters are zero; only first and second
        >>> # neighbor pairs are included)
        >>> prim = bulk('Au')
        >>> chemical_symbols = ['Ag', 'Au']
        >>> cs = ClusterSpace(prim, cutoffs=[4.3], chemical_symbols=chemical_symbols)
        >>> ce = ClusterExpansion(cs, [0, 0, 0.1, -0.02])

        >>> # prepare initial configuration
        >>> structure = prim.repeat(3)

        >>> # set up the ground state finder and calculate the ground state energy
        >>> gsf = GroundStateFinder(ce, structure)
        >>> ground_state = gsf.get_ground_state({'Ag': 5})
        >>> print('Ground state energy:', ce.predict(ground_state))
    """

    def __init__(self,
                 cluster_expansion: ClusterExpansion,
                 structure: Atoms,
                 solver_name: str = None,
                 verbose: bool = True) -> None:
        if not mip_loaded:
            raise ImportError('Python-MIP (https://python-mip.readthedocs.io/en/latest/) is '
                              'required in order to use the ground state finder.')

        self._cluster_expansion = cluster_expansion
        self._fractional_position_tolerance = cluster_expansion.fractional_position_tolerance
        self.structure = structure
        cluster_space = self._cluster_expansion.get_cluster_space_copy()
        primitive_structure = cluster_space.primitive_structure
        self._active_sublattices = cluster_space.get_sublattices(structure).active_sublattices

        # Check that there are no more than two allowed species
        active_species = [set(subl.chemical_symbols) for subl in self._active_sublattices]
        if any(len(species) > 2 for species in active_species):
            raise NotImplementedError('Currently, systems with more than two allowed species on '
                                      'any sublattice are not supported.')

        # Check that there are no merged orbits
        if any('merge' in elem for elem in cluster_space._pruning_history):
            raise NotImplementedError('Currently, systems with merged orbits are not supported.')

        self._active_species = active_species

        # For each active sublattice, map the value of a binary site variable
        # (0 or 1) back to a chemical symbol, using the species map of the
        # cluster space whose symbols match the sublattice.
        self._reverse_id_maps = []
        for species in active_species:
            for species_map in cluster_space.species_maps:
                if {periodic_table[n] for n in species_map} == species:
                    self._reverse_id_maps.append(
                        {1 - species_map[n]: periodic_table[n] for n in species_map})
                    break
            else:
                # Without a map the per-sublattice lists below would be
                # misaligned, so fail early with an explicit message.
                raise RuntimeError('Could not find a species map for the active sublattice '
                                   'with species {}.'.format(sorted(species)))
        # The symbol encoded by a site variable equal to 1; its count is the
        # one constrained in the model.
        self._count_symbols = [id_map[1] for id_map in self._reverse_id_maps]

        # Generate full orbit list for the primitive structure
        self._orbit_list = cluster_space.orbit_list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(primitive_structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        self._full_orbit_list = lolg.generate_full_orbit_list()

        # Transform the cluster expansion parameters for use with the binary
        # (0/1) site variables of the MIP formulation
        self._transformed_parameters = transform_parameters(primitive_structure,
                                                            self._full_orbit_list,
                                                            self._cluster_expansion.parameters)

        # Build model; an empty solver name lets Python-MIP pick a solver
        if solver_name is None:
            solver_name = ''
        self._model = self._build_model(structure, solver_name, verbose)

        # Properties that are defined when searching for a ground state
        self._optimization_status = None

    def _build_model(self,
                     structure: Atoms,
                     solver_name: str,
                     verbose: bool) -> mip.Model:
        """
        Build a Python-MIP model based on the provided structure.

        The species-count constraints are created with placeholder right-hand
        sides, which are overwritten in :meth:`get_ground_state` before
        solving.

        Parameters
        ----------
        structure
            Atomic configuration.
        solver_name
            ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
            solvers if no value is provided.
        verbose
            If ``True`` print solver messages to stdout.

        Returns
        -------
        The Python-MIP model.
        """
        # Create cluster maps
        self._create_cluster_maps(structure)

        # Initiate MIP model
        model = mip.Model('CE', solver_name=solver_name)
        model.solver.set_mip_gap(0)  # avoid stopping prematurely
        model.solver.set_emphasis(2)  # focus on finding optimal solution
        model.preprocess = 2  # maximum preprocessing

        # Set verbosity
        model.verbose = int(verbose)

        # Binary site variables for all sites on the active sublattices
        xs = {i: model.add_var(name='atom_{}'.format(i), var_type=BINARY)
              for subl in self._active_sublattices for i in subl.indices}
        # Binary activity variables, one per cluster with a nonzero parameter
        ys = [model.add_var(name='cluster_{}'.format(i), var_type=BINARY)
              for i in range(len(self._cluster_to_orbit_map))]

        # The objective function is added to 'model' first
        model.objective = mip.minimize(mip.xsum(self._get_total_energy(ys)))

        # Connect cluster variables to site variables with cluster constraints.
        # Only the bounds the objective can push against are needed. A
        # negative parameter favors y = 1, so y is bounded from above by every
        # site variable. A positive parameter favors y = 0, so y is bounded
        # from below by the linearized product of the site variables.
        # TODO: don't create cluster constraints for singlets
        constraint_count = 0
        for i, (cluster, orbit) in enumerate(zip(self._cluster_to_sites_map,
                                                 self._cluster_to_orbit_map)):
            parameter = self._transformed_parameters[orbit + 1]
            assert parameter != 0

            if len(cluster) < 2 or parameter < 0:  # no "downwards" pressure
                for atom in cluster:
                    model.add_constr(ys[i] <= xs[atom],
                                     'Decoration -> cluster {}'.format(constraint_count))
                    constraint_count += 1

            if len(cluster) < 2 or parameter > 0:  # no "upwards" pressure
                model.add_constr(ys[i] >= 1 - len(cluster) +
                                 mip.xsum(xs[atom] for atom in cluster),
                                 'Decoration -> cluster {}'.format(constraint_count))
                constraint_count += 1

        # Species-count constraints, one pair per active sublattice. The
        # right-hand sides of -1 are placeholders that get_ground_state
        # overwrites. 'sum(x) + slack == count' together with
        # 'slack <= max_slack' either fixes the count (max_slack = 0) or lets
        # it vary freely (count = max_slack = number of sites).
        for sym, subl in zip(self._count_symbols, self._active_sublattices):
            slack = model.add_var(name='slackvar_{}'.format(sym), var_type=INTEGER,
                                  lb=0, ub=len(subl.indices))
            model.add_constr(slack <= -1, name='{} slack'.format(sym))
            model.add_constr(mip.xsum([xs[i] for i in subl.indices]) + slack == -1,
                             name='{} count'.format(sym))

        # Update the model so that variables and constraints can be queried
        if model.solver_name.upper() in ['GRB', 'GUROBI']:
            model.solver.update()
        return model

    def _create_cluster_maps(self, structure: Atoms) -> None:
        """
        Create maps that include information regarding which sites and orbits
        are associated with each cluster as well as the number of clusters per
        orbit. Clusters belonging to orbits whose transformed parameter is
        exactly zero are left out.

        The results are stored in ``_cluster_to_sites_map``,
        ``_cluster_to_orbit_map`` and ``_nclusters_per_orbit``. The latter has
        the zerolet (counted as a single cluster) at index 0.

        Parameters
        ----------
        structure
            Atomic configuration.
        """
        # Generate full orbit list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        full_orbit_list = lolg.generate_full_orbit_list()

        # Create maps of site indices and orbits for all clusters
        cluster_to_sites_map = []
        cluster_to_orbit_map = []
        for orb_index in range(len(full_orbit_list)):
            # Do not include clusters for which the parameter is 0
            if self._transformed_parameters[orb_index + 1] == 0:
                continue
            for cluster in full_orbit_list.get_orbit(orb_index).clusters:
                cluster_to_sites_map.append([site.index for site in cluster.lattice_sites])
                cluster_to_orbit_map.append(orb_index)

        # Count clusters per orbit in a single pass. Orbit indices are
        # appended in increasing order, so the maximum is the last one.
        # default=-1 keeps this valid when no cluster has a nonzero parameter.
        counts = [0] * (max(cluster_to_orbit_map, default=-1) + 1)
        for orb_index in cluster_to_orbit_map:
            counts[orb_index] += 1
        nclusters_per_orbit = [1] + counts

        self._cluster_to_sites_map = cluster_to_sites_map
        self._cluster_to_orbit_map = cluster_to_orbit_map
        self._nclusters_per_orbit = nclusters_per_orbit

    def _get_total_energy(self, cluster_instance_activities: list[int]) -> list[float]:
        r"""
        Calculates the total energy using the expression based on binary variables.

        .. math::

            H({\boldsymbol x}, {\boldsymbol E})=E_0+
            \sum\limits_j\sum\limits_{{\boldsymbol c}
            \in{\boldsymbol C}_j}E_jy_{{\boldsymbol c}},

        where (:math:`y_{{\boldsymbol c}}=
        \prod\limits_{i\in{\boldsymbol c}}x_i`).

        Parameters
        ----------
        cluster_instance_activities
            list of cluster instance activities, (:math:`y_{{\boldsymbol c}}`),
            ordered like ``_cluster_to_orbit_map``; when building the model
            these are the MIP cluster variables

        Returns
        -------
        One term per transformed parameter (index 0 being the zerolet), whose
        sum is the energy. Each orbit's summed activity is multiplied by its
        parameter and divided by the number of clusters in the orbit. Terms
        whose parameter is close to zero are ``0.0``.
        """
        activity_sums = [0.0] * len(self._transformed_parameters)
        for orbit, activity in zip(self._cluster_to_orbit_map, cluster_instance_activities):
            activity_sums[orbit + 1] = activity_sums[orbit + 1] + activity
        activity_sums[0] = 1  # the zerolet is always active

        return [0.0 if np.isclose(parameter, 0.0) else
                activity_sums[orbit] * parameter / self._nclusters_per_orbit[orbit]
                for orbit, parameter in enumerate(self._transformed_parameters)]

    def _set_species_count_constraints(self, species_count: dict[str, int]) -> None:
        """
        Overwrite the right-hand sides of the species-count constraints of the
        model according to ``species_count``.

        If a count is given for a species on a sublattice, its concentration
        is fixed. Otherwise it is left free.

        Parameters
        ----------
        species_count
            Dictionary with count for one of the species on each active
            sublattice.

        Raises
        ------
        ValueError
            If counts are provided for more than one species on the same
            sublattice, or if a count is negative or exceeds the number of
            sites on its sublattice.
        """
        for i, species in enumerate(self._active_species):
            count_symbol = self._count_symbols[i]
            max_count = len(self._active_sublattices[i].indices)

            symbols_to_add = set(species_count) & species
            if len(symbols_to_add) > 1:
                raise ValueError('Provide counts for at most one of the species on each active '
                                 'sublattice ({}), not {}!'.format(self._active_species,
                                                                   list(species_count)))
            if len(symbols_to_add) == 1:
                sym = symbols_to_add.pop()
                count = species_count[sym]
                if count < 0 or count > max_count:
                    raise ValueError('The count for species {} ({}) must be a positive integer and'
                                     ' cannot exceed the number of sites on the active sublattice '
                                     '({})'.format(sym, count, max_count))
                # The model counts sites occupied by count_symbol; convert
                # a count of the other species accordingly
                xcount = count if sym == count_symbol else max_count - count
                max_slack = 0
            else:
                xcount = max_slack = max_count

            self._model.constr_by_name('{} count'.format(count_symbol)).rhs = xcount
            self._model.constr_by_name('{} slack'.format(count_symbol)).rhs = max_slack

    def get_ground_state(self,
                         species_count: dict[str, int] = None,
                         max_seconds: float = inf,
                         threads: int = 0) -> Atoms:
        """Finds the ground state for a given structure and species count.

        The concentration on an active sublattice is fixed only if
        :attr:`species_count` contains a count for one of its species.
        Otherwise, including when :attr:`species_count` is not provided, the
        concentration on that sublattice is allowed to vary.

        Parameters
        ----------
        species_count
            Dictionary with count for one of the species on each active
            sublattice. If no count is provided for a sublattice, the
            concentration is allowed to vary.
        max_seconds
            Maximum runtime in seconds.
        threads
            Number of threads to be used when solving the problem. If set to
            :math:`0` the solver default configuration is used while
            :math:`-1` corresponds to all available processing cores.

        Returns
        -------
        A copy of the structure with the optimized decoration.

        Raises
        ------
        ValueError
            If a species is not present on any active sublattice, if counts
            are provided for more than one species on the same sublattice, or
            if a count is negative or exceeds the number of sublattice sites.
        RuntimeError
            If the optimization yields neither an optimal nor a feasible
            solution.
        """
        if species_count is None:
            species_count = {}

        # Check that the species_count is consistent with the cluster space.
        # set().union(...) also works when there are no active sublattices.
        all_active_species = set().union(*self._active_species)
        for symbol in species_count:
            if symbol not in all_active_species:
                raise ValueError('The species {} is not present on any of the active sublattices'
                                 ' ({})'.format(symbol, self._active_species))

        # The model is solved using python-MIPs choice of solver, which is
        # Gurobi, if available, and COIN-OR Branch-and-Cut, otherwise.
        model = self._model

        # Update the species counts
        self._set_species_count_constraints(species_count)

        # Set the number of threads
        model.threads = threads

        # Optimize the model
        status = model.optimize(max_seconds=max_seconds)
        self._optimization_status = status

        # Warn if optimality is not proven; fail if no solution was found.
        # Enum members are compared directly rather than via their str().
        if status != mip.OptimizationStatus.OPTIMAL:
            if status == mip.OptimizationStatus.FEASIBLE:
                logger.warning('Solution optimality not proven.')
            else:
                raise RuntimeError('Optimization failed ({0})'.format(str(status)))

        # Decorate a copy of the structure according to the site variables
        gs = self.structure.copy()
        active_index_to_sublattice_map = {i: j for j, subl in enumerate(self._active_sublattices)
                                          for i in subl.indices}
        for v in model.vars:
            if v.name.startswith('atom_'):
                index = int(v.name.split('_')[-1])
                sublattice_index = active_index_to_sublattice_map[index]
                # Round rather than truncate: solvers may report a binary
                # value such as 0.9999999 within their integrality tolerance
                occupation = int(round(v.x))
                gs[index].symbol = self._reverse_id_maps[sublattice_index][occupation]

        # Assert that the solution agrees with the prediction
        prediction = self._cluster_expansion.predict(gs)
        assert abs(model.objective_value - prediction) < 1e-6
        return gs

    @property
    def optimization_status(self) -> mip.OptimizationStatus:
        """ Optimization status of the most recent call to :meth:`get_ground_state`
        (``None`` before the first call). """
        return self._optimization_status

    @property
    def model(self) -> mip.Model:
        """ Copy of the Python-MIP model. """
        return self._model.copy()