Coverage for icet/tools/ground_state_finder.py: 97%

148 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-12-26 04:12 +0000

1 

2from math import inf 

3import numpy as np 

4from typing import List, Dict 

5 

6from ase import Atoms 

7from ase.data import chemical_symbols as periodic_table 

8from .. import ClusterExpansion 

9from ..core.local_orbit_list_generator import LocalOrbitListGenerator 

10from ..core.structure import Structure 

11from .variable_transformation import transform_parameters 

12from ..input_output.logging_tools import logger 

13 

14try: 

15 import mip 

16 from mip.constants import BINARY, INTEGER 

17except ImportError: 

18 raise ImportError('Python-MIP (https://python-mip.readthedocs.io/en/latest/) is required in ' 

19 'order to use the ground state finder.') 

20 

21 

class GroundStateFinder:
    """
    This class provides functionality for determining the ground states
    using a binary cluster expansion. This is efficiently achieved through the
    use of mixed integer programming (MIP) as developed by Larsen *et al.* in
    `Phys. Rev. Lett. 120, 256101 (2018)
    <https://doi.org/10.1103/PhysRevLett.120.256101>`_.

    This class relies on the `Python-MIP package
    <https://python-mip.readthedocs.io>`_. Python-MIP can be used together
    with `Gurobi <https://www.gurobi.com/>`_, which is not open source
    but issues academic licenses free of charge. Please note that
    Gurobi needs to be installed separately. The :class:`GroundStateFinder` works
    also without Gurobi, but if performance is critical, Gurobi is highly
    recommended.

    Warning
    -------
    In order to be able to use Gurobi with python-mip one must ensure that
    `GUROBI_HOME` points to the installation directory
    (``<installdir>``)::

        export GUROBI_HOME=<installdir>

    Note
    ----
    The current implementation only works for systems with at most two
    allowed species on each active sublattice, and cluster spaces with
    merged orbits are not supported.


    Parameters
    ----------
    cluster_expansion
        Cluster expansion for which to find ground states.
    structure
        Atomic configuration.
    solver_name
        ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
        solvers if no value is provided.
    verbose
        If ``True`` print solver messages to stdout.

    Example
    -------
    The following snippet illustrates how to determine the ground state for a
    Au-Ag alloy. Here, the parameters of the cluster
    expansion are set to emulate a simple Ising model in order to obtain an
    example that can be run without modification. In practice, one should of
    course use a proper cluster expansion::

        >>> from ase.build import bulk
        >>> from icet import ClusterExpansion, ClusterSpace

        >>> # prepare cluster expansion
        >>> # the setup emulates a second nearest-neighbor (NN) Ising model
        >>> # (zerolet and singlet parameters are zero; only first and second neighbor
        >>> # pairs are included)
        >>> prim = bulk('Au')
        >>> chemical_symbols = ['Ag', 'Au']
        >>> cs = ClusterSpace(prim, cutoffs=[4.3], chemical_symbols=chemical_symbols)
        >>> ce = ClusterExpansion(cs, [0, 0, 0.1, -0.02])

        >>> # prepare initial configuration
        >>> structure = prim.repeat(3)

        >>> # set up the ground state finder and calculate the ground state energy
        >>> gsf = GroundStateFinder(ce, structure)
        >>> ground_state = gsf.get_ground_state({'Ag': 5})
        >>> print('Ground state energy:', ce.predict(ground_state))
    """

    def __init__(self,
                 cluster_expansion: ClusterExpansion,
                 structure: Atoms,
                 solver_name: str = None,
                 verbose: bool = True) -> None:
        self._cluster_expansion = cluster_expansion
        self._fractional_position_tolerance = cluster_expansion.fractional_position_tolerance
        self.structure = structure
        cluster_space = self._cluster_expansion.get_cluster_space_copy()
        primitive_structure = cluster_space.primitive_structure
        self._active_sublattices = cluster_space.get_sublattices(structure).active_sublattices

        # Check that there are no more than two allowed species on any active sublattice
        active_species = [set(subl.chemical_symbols) for subl in self._active_sublattices]
        if any(len(species) > 2 for species in active_species):
            raise NotImplementedError('Currently, systems with more than two allowed species on '
                                      'any sublattice are not supported.')

        # Check that there are no merged orbits
        if any('merge' in elem for elem in cluster_space._pruning_history):
            raise NotImplementedError('Currently, systems with merged orbits are not supported.')

        self._active_species = active_species

        # For each active sublattice, build a map from the value of the binary
        # MIP variable to a chemical symbol. The binary value is 1 minus the
        # index assigned by the matching cluster-space species map.
        self._reverse_id_maps = []
        for species in active_species:
            for species_map in cluster_space.species_maps:
                symbols = [periodic_table[n] for n in species_map]
                if set(symbols) == species:
                    reverse_id_map = {1 - species_map[n]: periodic_table[n] for n in species_map}
                    self._reverse_id_maps.append(reverse_id_map)
                    break
        # Symbol that corresponds to a binary variable value of 1 on each active
        # sublattice; the species-count constraints are expressed in this symbol.
        self._count_symbols = [reverse_id_map[1] for reverse_id_map in self._reverse_id_maps]

        # Generate full orbit list
        self._orbit_list = cluster_space.orbit_list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(primitive_structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        self._full_orbit_list = lolg.generate_full_orbit_list()

        # Transform the parameters
        binary_parameters = transform_parameters(primitive_structure,
                                                 self._full_orbit_list,
                                                 self._cluster_expansion.parameters)
        self._transformed_parameters = binary_parameters

        # Build model; an empty solver name lets Python-MIP pick a solver
        if solver_name is None:
            solver_name = ''
        self._model = self._build_model(structure, solver_name, verbose)

        # Properties that are defined when searching for a ground state
        self._optimization_status = None

    def _build_model(self,
                     structure: Atoms,
                     solver_name: str,
                     verbose: bool) -> mip.Model:
        """
        Build a Python-MIP model based on the provided structure.

        As a side effect the cluster maps (``_cluster_to_sites_map``,
        ``_cluster_to_orbit_map``, ``_nclusters_per_orbit``) are (re)created
        for ``structure``.

        Parameters
        ----------
        structure
            Atomic configuration.
        solver_name
            ``'gurobi'``, ``'grb'`` or ``'cbc'``. Searches for available
            solvers if no value is provided.
        verbose
            If ``True`` print solver messages to stdout.

        Returns
        -------
        The MIP model. Its species-count constraints carry placeholder
        right-hand sides that :meth:`get_ground_state` sets before solving.
        """

        # Create cluster maps
        self._create_cluster_maps(structure)

        # Initiate MIP model
        model = mip.Model('CE', solver_name=solver_name)
        model.solver.set_mip_gap(0)  # avoid stopping prematurely
        model.solver.set_emphasis(2)  # focus on finding optimal solution
        model.preprocess = 2  # maximum preprocessing

        # Set verbosity
        model.verbose = int(verbose)

        # Spin variables (remapped) for all atoms on the active sublattices
        xs = {i: model.add_var(name='atom_{}'.format(i), var_type=BINARY)
              for subl in self._active_sublattices for i in subl.indices}
        # One binary activity variable per retained cluster
        ys = [model.add_var(name='cluster_{}'.format(i), var_type=BINARY)
              for i in range(len(self._cluster_to_orbit_map))]

        # The objective function is added to 'model' first
        model.objective = mip.minimize(mip.xsum(self._get_total_energy(ys)))

        # Connect cluster variables to spin variables with cluster constraints
        # (y_c must equal the product of the x_i of its sites). Only the
        # direction of the bound that the objective can push against is needed
        # for pairs and larger clusters; singlets always get both.
        # TODO: don't create cluster constraints for singlets
        constraint_count = 0
        for i, cluster in enumerate(self._cluster_to_sites_map):
            orbit = self._cluster_to_orbit_map[i]
            parameter = self._transformed_parameters[orbit + 1]
            assert parameter != 0

            if len(cluster) < 2 or parameter < 0:  # no "downwards" pressure
                for atom in cluster:
                    model.add_constr(ys[i] <= xs[atom],
                                     'Decoration -> cluster {}'.format(constraint_count))
                    constraint_count = constraint_count + 1

            if len(cluster) < 2 or parameter > 0:  # no "upwards" pressure
                model.add_constr(ys[i] >= 1 - len(cluster) +
                                 mip.xsum(xs[atom]
                                          for atom in cluster),
                                 'Decoration -> cluster {}'.format(constraint_count))
                constraint_count = constraint_count + 1

        for sym, subl in zip(self._count_symbols, self._active_sublattices):
            # Create slack variable
            slack = model.add_var(name='slackvar_{}'.format(sym), var_type=INTEGER,
                                  lb=0, ub=len(subl.indices))

            # Add slack constraint; the right-hand side -1 is a placeholder
            # that get_ground_state overwrites before every optimization
            model.add_constr(slack <= -1, name='{} slack'.format(sym))

            # Set species constraint; the right-hand side -1 is likewise a
            # placeholder overwritten in get_ground_state
            model.add_constr(mip.xsum([xs[i] for i in subl.indices]) + slack == -1,
                             name='{} count'.format(sym))

        # Update the model so that variables and constraints can be queried
        if model.solver_name.upper() in ['GRB', 'GUROBI']:
            model.solver.update()
        return model

    def _create_cluster_maps(self, structure: Atoms) -> None:
        """
        Create maps that include information regarding which sites and orbits
        are associated with each cluster as well as the number of clusters per
        orbit.

        Clusters belonging to orbits whose transformed parameter is (close to)
        zero are omitted, since they do not contribute to the energy.

        Parameters
        ----------
        structure
            Atomic configuration.
        """
        # Generate full orbit list
        lolg = LocalOrbitListGenerator(
            orbit_list=self._orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=self._fractional_position_tolerance)
        full_orbit_list = lolg.generate_full_orbit_list()

        # Create maps of site indices and orbits for all clusters
        cluster_to_sites_map = []
        cluster_to_orbit_map = []
        for orb_index in range(len(full_orbit_list)):

            # Do not include clusters for which the parameter vanishes. The
            # same np.isclose criterion is used in _get_total_energy, so both
            # agree on which orbits contribute.
            parameter = self._transformed_parameters[orb_index + 1]
            if np.isclose(parameter, 0.0):
                continue

            # Add the list of sites and the orbit to the respective cluster maps
            for cluster in full_orbit_list.get_orbit(orb_index).clusters:
                cluster_to_sites_map.append([site.index for site in cluster.lattice_sites])
                cluster_to_orbit_map.append(orb_index)

        # Calculate the number of clusters per orbit in a single pass; minlength
        # keeps this well-defined even when no clusters were retained. The
        # leading 1 accounts for the zerolet.
        counts = np.bincount(np.asarray(cluster_to_orbit_map, dtype=int),
                             minlength=len(full_orbit_list))
        nclusters_per_orbit = [1] + counts.tolist()

        self._cluster_to_sites_map = cluster_to_sites_map
        self._cluster_to_orbit_map = cluster_to_orbit_map
        self._nclusters_per_orbit = nclusters_per_orbit

    def _get_total_energy(self, cluster_instance_activities: List[int]) -> List[float]:
        r"""
        Calculates the total energy using the expression based on binary variables.

        .. math::

            H({\boldsymbol x}, {\boldsymbol E})=E_0+
            \sum\limits_j\frac{E_j}{|{\boldsymbol C}_j|}\sum\limits_{{\boldsymbol c}
            \in{\boldsymbol C}_j}y_{{\boldsymbol c}},

        where (:math:`y_{{\boldsymbol c}}=
        \prod\limits_{i\in{\boldsymbol c}}x_i`) and :math:`|{\boldsymbol C}_j|`
        is the number of clusters in orbit :math:`j`.

        Parameters
        ----------
        cluster_instance_activities
            list of cluster instance activities, (:math:`y_{{\boldsymbol c}}`);
            when building the model these are the MIP cluster variables, in
            which case the returned terms are linear expressions.

        Returns
        -------
        One energy term per transformed parameter (index 0 is the zerolet).
        Terms with a vanishing parameter, or an orbit without clusters, are 0.
        """

        # Sum the cluster activities of each orbit
        E = [0.0 for _ in self._transformed_parameters]
        for orbit, activity in zip(self._cluster_to_orbit_map, cluster_instance_activities):
            E[orbit + 1] = E[orbit + 1] + activity
        E[0] = 1

        terms = []
        for orbit in range(len(self._transformed_parameters)):
            parameter = self._transformed_parameters[orbit]
            # Short-circuit on the parameter first; a zero cluster count would
            # otherwise produce a division by zero (NaN in the objective).
            if np.isclose(parameter, 0.0) or self._nclusters_per_orbit[orbit] == 0:
                terms.append(0.0)
            else:
                terms.append(E[orbit] * parameter / self._nclusters_per_orbit[orbit])
        return terms

    def get_ground_state(self,
                         species_count: Dict[str, int] = None,
                         max_seconds: float = inf,
                         threads: int = 0) -> Atoms:
        """Finds the ground state for a given structure and species count. If
        :attr:`species_count` is not provided, the concentration on every
        active sublattice is allowed to vary.

        Parameters
        ----------
        species_count
            Dictionary with count for one of the species on each active
            sublattice. If no count is provided for a sublattice, the
            concentration is allowed to vary.
        max_seconds
            Maximum runtime in seconds.
        threads
            Number of threads to be used when solving the problem, given that a
            positive integer has been provided. If set to :math:`0` the solver default
            configuration is used while :math:`-1` corresponds to all available
            processing cores.

        Returns
        -------
        A copy of the structure passed at initialization, decorated with the
        ground state configuration.

        Raises
        ------
        ValueError
            If a species is not allowed on any active sublattice, if counts are
            given for both species of one sublattice, or if a count is negative
            or exceeds the number of sites of its sublattice.
        Exception
            If the optimization neither reaches an optimal nor a feasible solution.
        """
        if species_count is None:
            species_count = {}

        # Check that the species_count is consistent with the cluster space.
        # set().union(...) stays valid even without active sublattices.
        all_active_species = set().union(*self._active_species)
        for symbol in species_count:
            if symbol not in all_active_species:
                raise ValueError('The species {} is not present on any of the active sublattices'
                                 ' ({})'.format(symbol, self._active_species))

        # The model is solved using python-MIPs choice of solver, which is
        # Gurobi, if available, and COIN-OR Branch-and-Cut, otherwise.
        model = self._model

        # Update the species counts by overwriting the placeholder right-hand
        # sides of the count and slack constraints created in _build_model
        for i, species in enumerate(self._active_species):
            count_symbol = self._count_symbols[i]
            max_count = len(self._active_sublattices[i].indices)

            symbols_to_add = set(species_count) & set(species)
            if len(symbols_to_add) > 1:
                raise ValueError('Provide counts for at most one of the species on each active '
                                 'sublattice ({}), not {}!'.format(self._active_species,
                                                                   list(species_count)))
            elif len(symbols_to_add) == 1:
                sym = symbols_to_add.pop()
                count = species_count[sym]
                if count < 0 or count > max_count:
                    raise ValueError('The count for species {} ({}) must be a positive integer and'
                                     ' cannot exceed the number of sites on the active sublattice '
                                     '({})'.format(sym, count, max_count))
                # Express the count in terms of the symbol mapped to x = 1
                if sym == count_symbol:
                    xcount = count
                else:
                    xcount = max_count - count

                # Fixed count: the slack variable is forced to zero
                max_slack = 0
            else:
                # Free concentration: the slack absorbs any number of x = 0 sites
                xcount = max_slack = max_count

            model.constr_by_name('{} count'.format(count_symbol)).rhs = xcount
            model.constr_by_name('{} slack'.format(count_symbol)).rhs = max_slack

        # Set the number of threads
        model.threads = threads

        # Optimize the model
        self._optimization_status = model.optimize(max_seconds=max_seconds)

        # Compare against the enum members directly instead of their string form
        status = self._optimization_status
        if status != mip.OptimizationStatus.OPTIMAL:
            if status == mip.OptimizationStatus.FEASIBLE:
                logger.warning('Solution optimality not proven.')
            else:
                raise Exception('Optimization failed ({0})'.format(str(status)))

        # Decorate a copy of the structure with the resolved optimum values
        gs = self.structure.copy()

        active_index_to_sublattice_map = {i: j for j, subl in enumerate(self._active_sublattices)
                                          for i in subl.indices}
        for v in model.vars:
            if v.name.startswith('atom_'):
                index = int(v.name.split('_')[-1])
                sublattice_index = active_index_to_sublattice_map[index]
                # Solvers return floating point values within an integrality
                # tolerance (e.g. 0.9999999); round instead of truncating.
                gs[index].symbol = self._reverse_id_maps[sublattice_index][int(round(v.x))]

        # Assert that the solution agrees with the prediction
        prediction = self._cluster_expansion.predict(gs)
        assert abs(model.objective_value - prediction) < 1e-6
        return gs

    @property
    def optimization_status(self) -> mip.OptimizationStatus:
        """ Optimization status. """
        return self._optimization_status

    @property
    def model(self) -> mip.Model:
        """ Python-MIP model. """
        return self._model.copy()