Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Definition of the abstract base ensemble class.""" 

2 

3import os 

4import random 

5import warnings 

6 

7from abc import ABC, abstractmethod 

8from collections import OrderedDict 

9from math import gcd 

10from time import time 

11from typing import Any, Dict, List, Optional, Type, Union 

12 

13import numpy as np 

14 

15from ase import Atoms 

16from icet.core.sublattices import Sublattices 

17 

18from ..calculators.base_calculator import BaseCalculator 

19from ..configuration_manager import ConfigurationManager 

20from ..data_containers.base_data_container import BaseDataContainer 

21from ..observers.base_observer import BaseObserver 

22 

23 

class BaseEnsemble(ABC):
    """Base ensemble class.

    Child classes must implement :func:`_do_trial_step`. If a child
    class wants to record its own ensemble parameters it should assign
    a dict to ``self._ensemble_parameters`` *before* calling this
    constructor; otherwise an empty dict is created here.

    Parameters
    ----------
    structure : :class:`Atoms <ase.Atoms>`
        atomic configuration to be used in the Monte Carlo simulation;
        also defines the initial occupation vector
    calculator : :class:`BaseCalculator <mchammer.calculators.ClusterExpansionCalculator>`
        calculator to be used for calculating the potential changes
        that enter the evaluation of the Metropolis criterion
    user_tag : str
        human-readable tag for ensemble [default: None]
    random_seed : int
        seed for the random number generator used in the Monte Carlo
        simulation
    dc_filename : str
        name of file the data container associated with the ensemble
        will be written to; if the file exists it will be read, the
        data container will be appended, and the file will be
        updated/overwritten
    data_container : str
        deprecated alias of `dc_filename`; if given it takes
        precedence over `dc_filename` and a deprecation warning is
        issued
    data_container_class : BaseDataContainer
        used to initialize custom (ensemble specific) data container objects;
        by default the class uses the generic BaseDataContainer class
    data_container_write_period : float
        period in units of seconds at which the data container is
        written to file; writing periodically to file provides both
        a way to examine the progress of the simulation and to back up
        the data.
    ensemble_data_write_interval : int
        interval at which data is written to the data container; this
        includes for example the current value of the calculator
        (i.e. usually the energy) as well as ensembles specific fields
        such as temperature or the number of atoms of different species
        [default: number of sites in `structure`]
    trajectory_write_interval : int
        interval at which the current occupation vector of the atomic
        configuration is written to the data container
        [default: number of sites in `structure`]

    Raises
    ------
    ValueError
        if a chemical symbol occurs on more than one active sublattice
    ValueError
        if the ensemble parameters do not match the ones stored in an
        existing data container file
    FileNotFoundError
        if the directory of the data container file does not exist
    """

    def __init__(self,
                 structure: Atoms,
                 calculator: BaseCalculator,
                 user_tag: str = None,
                 random_seed: int = None,
                 dc_filename: str = None,
                 data_container: str = None,
                 data_container_class: Type[BaseDataContainer] = BaseDataContainer,
                 data_container_write_period: float = 600,
                 ensemble_data_write_interval: int = None,
                 trajectory_write_interval: int = None) -> None:

        # initialize basic variables
        self._accepted_trials = 0
        self._observers = {}  # type: Dict[str, BaseObserver]
        self._step = 0

        # calculator and configuration
        self._calculator = calculator
        self._user_tag = user_tag
        sublattices = self.calculator.sublattices

        sublattices.assert_occupation_is_allowed(structure.get_chemical_symbols())

        # a symbol must not appear on more than one active sublattice
        symbols_flat = [s for sl in sublattices.active_sublattices for s in sl.chemical_symbols]
        if len(symbols_flat) != len(set(symbols_flat)):
            bad_symbols = set([s for s in symbols_flat if symbols_flat.count(s) > 1])
            raise ValueError('Symbols {} found on multiple active sublattices'.format(bad_symbols))

        self.configuration = ConfigurationManager(structure, sublattices)

        # random number generator
        if random_seed is None:
            self._random_seed = random.randint(0, int(1e16))
        else:
            self._random_seed = random_seed
        random.seed(a=self._random_seed)

        # add ensemble parameters and metadata; getattr is used since a
        # child class may or may not have set the attribute already
        if not getattr(self, '_ensemble_parameters', None):
            self._ensemble_parameters = {}  # type: Dict[str, Any]
        self._ensemble_parameters['n_atoms'] = len(self.structure)
        metadata = OrderedDict(ensemble_name=self.__class__.__name__,
                               user_tag=user_tag, seed=self.random_seed)

        # data container
        self._data_container_write_period = data_container_write_period
        if data_container is not None:
            # force the deprecation warning to be displayed
            warnings.simplefilter('always', DeprecationWarning)
            warnings.warn('data_container is deprecated, use dc_filename', DeprecationWarning)
            self._data_container_filename = data_container
        else:
            self._data_container_filename = dc_filename

        # use the resolved file name for reading as well as for writing
        # such that the deprecated keyword behaves like dc_filename
        filename = self._data_container_filename

        if filename is not None and os.path.isfile(filename):
            self._data_container = data_container_class.read(filename)  # type: BaseDataContainer

            dc_ensemble_parameters = self.data_container.ensemble_parameters
            if not dicts_equal(self.ensemble_parameters,
                               dc_ensemble_parameters):
                raise ValueError('Ensemble parameters do not match those'
                                 ' stored in data container file: {}'.format(
                                     set(dc_ensemble_parameters.items()) -
                                     set(self.ensemble_parameters.items())))
            self._restart_ensemble()
        else:
            if filename is not None:
                # check if path to file exists
                filedir = os.path.dirname(filename)
                if filedir and not os.path.isdir(filedir):
                    raise FileNotFoundError('Path to data container file does'
                                            ' not exist: {}'.format(filedir))
            self._data_container = data_container_class(
                structure=structure,
                ensemble_parameters=self.ensemble_parameters,
                metadata=metadata)

        # interval for writing data and further preparation of data container
        self._default_interval = len(structure)

        if ensemble_data_write_interval is None:
            self._ensemble_data_write_interval = self._default_interval
        else:
            self._ensemble_data_write_interval = ensemble_data_write_interval

        # Handle trajectory writing
        if trajectory_write_interval is None:
            self._trajectory_write_interval = self._default_interval
        else:
            self._trajectory_write_interval = trajectory_write_interval

        self._find_observer_interval()

    @property
    def structure(self) -> Atoms:
        """ current configuration (copy) """
        return self.configuration.structure

    @property
    def data_container(self) -> BaseDataContainer:
        """ data container associated with ensemble """
        return self._data_container

    @property
    def observers(self) -> Dict[str, BaseObserver]:
        """ observers attached to the ensemble, keyed by tag """
        return self._observers

    @property
    def calculator(self) -> BaseCalculator:
        """ calculator attached to the ensemble """
        return self._calculator

    @property
    def step(self) -> int:
        """ current trial step counter """
        return self._step

    def run(self, number_of_trial_steps: int):
        """
        Samples the ensemble for the given number of trial steps.

        The sampling is carried out in chunks of at most
        :attr:`observer_interval` steps. Between chunks the observers
        are called and, if a file name is set and the write period has
        elapsed, the data container is written to file. The data
        container is also written once at the end of the run.

        Parameters
        ----------
        number_of_trial_steps
            number of MC trial steps to run in total

        Raises
        ------
        TypeError
            if `number_of_trial_steps` is not an int
        """

        if not isinstance(number_of_trial_steps, int):
            raise TypeError('number_of_trial_steps must be an integer ({})'
                            .format(number_of_trial_steps))

        last_write_time = time()

        initial_step = self.step
        final_step = self.step + number_of_trial_steps
        # run Monte Carlo simulation such that we start at an
        # interval which lands on the observer interval
        if initial_step != 0:
            first_run_interval = \
                self.observer_interval - initial_step % self.observer_interval
            first_run_interval = min(first_run_interval, number_of_trial_steps)
            self._run(first_run_interval)
            initial_step += first_run_interval

        step = initial_step
        while step < final_step and not self._terminate_sampling():
            uninterrupted_steps = min(self.observer_interval, final_step - step)
            if self.step % self.observer_interval == 0:
                self._observe(self.step)
            if self._data_container_filename is not None and \
                    time() - last_write_time > self._data_container_write_period:
                self.write_data_container(self._data_container_filename)
                last_write_time = time()

            self._run(uninterrupted_steps)
            step += uninterrupted_steps

        # if we end on an observation interval we also observe
        if self.step % self.observer_interval == 0:
            self._observe(self.step)

        # allow ensemble a chance to go clean
        self._finalize()

        if self._data_container_filename is not None:
            self.write_data_container(self._data_container_filename)

    def _run(self, number_of_trial_steps: int):
        """Runs MC simulation for a number of trial steps without
        interruption.

        Parameters
        ----------
        number_of_trial_steps
            number of trial steps to run without stopping
        """
        for _ in range(number_of_trial_steps):
            accepted = self._do_trial_step()
            self._step += 1
            self._accepted_trials += accepted

    def _observe(self, step: int):
        """Submits current configuration to observers and appends
        observations to data container.

        Nothing is appended if neither the ensemble data, the
        trajectory, nor any observer is due at the given step.

        Parameters
        ----------
        step
            the current trial step
        """
        row_dict = {}

        # Ensemble specific data
        if step % self._ensemble_data_write_interval == 0:
            ensemble_data = self._get_ensemble_data()
            for key, value in ensemble_data.items():
                row_dict[key] = value

            # reset accepted trial count; the acceptance ratio is
            # normalized by the ensemble data write interval and hence
            # the counter starts over each time ensemble data is recorded
            self._accepted_trials = 0

        # Trajectory data
        if step % self._trajectory_write_interval == 0:
            row_dict['occupations'] = self.configuration.occupations.tolist()

        # Observer data
        for observer in self.observers.values():
            assert isinstance(observer.interval, int), 'interval is not an int'
            if step % observer.interval == 0:
                if observer.return_type is dict:
                    # observers returning a dict contribute one field per key
                    for key, val in observer.get_observable(self.configuration.structure).items():
                        row_dict[key] = val
                else:
                    row_dict[observer.tag] = observer.get_observable(self.configuration.structure)

        if len(row_dict) > 0:
            self._data_container.append(mctrial=step, record=row_dict)

    @abstractmethod
    def _do_trial_step(self):
        """Carries out one trial step; to be implemented by child classes.

        The return value is added to the counter of accepted trial
        steps by :func:`_run`.
        """
        pass

    @property
    def user_tag(self) -> Optional[str]:
        """ tag used for labeling the ensemble """
        return self._user_tag

    @property
    def random_seed(self) -> int:
        """ seed used to initialize random number generator """
        return self._random_seed

    def _next_random_number(self) -> float:
        """ Returns the next random number from the PRNG. """
        return random.random()

    @property
    def observer_interval(self) -> int:
        """minimum number of steps to run Monte Carlo simulation without
        interruption for observation
        """
        return self._observer_interval

    def _find_observer_interval(self) -> None:
        """
        Finds the greatest common divisor of the observation intervals,
        i.e. the intervals of all attached observers as well as the
        ensemble data and trajectory write intervals (unless the latter
        are infinite), and stores it as the observer interval.
        """
        intervals = [obs.interval for obs in self.observers.values()]

        # compare by value rather than identity such that any
        # representation of infinity (np.inf, math.inf, float('inf'))
        # is recognized
        if self._ensemble_data_write_interval != np.inf:
            intervals.append(self._ensemble_data_write_interval)
        if self._trajectory_write_interval != np.inf:
            intervals.append(self._trajectory_write_interval)
        if intervals:
            assert all([isinstance(k, int) for k in intervals]), 'intervals must be ints'
            self._observer_interval = self._get_gcd(intervals)
        else:
            # nothing is scheduled for observation; fall back to the
            # default interval such that the observer interval is
            # always defined when the run method is invoked
            self._observer_interval = self._default_interval

    def _get_gcd(self, values: List[int]) -> int:
        """ Finds the greatest common divisor (GCD) from a list of integers.

        The input list is left unchanged.

        Parameters
        ----------
        values
            integers for which to determine the GCD

        Raises
        ------
        ValueError
            if `values` is empty
        """
        if len(values) == 0:
            raise ValueError('values must contain at least one integer')

        result = values[0]
        for value in values[1:]:
            result = gcd(result, value)
        return result

    def attach_observer(self, observer: BaseObserver, tag=None):
        """
        Attaches an observer to the ensemble.

        If the observer does not have an observation interval,
        then it will be set to the default_interval len(atoms).

        Parameters
        ----------
        observer
            observer instance to attach
        tag
            name used in data container; if given it replaces the tag
            of the observer

        Raises
        ------
        TypeError
            if `observer` is not a BaseObserver instance
        """
        if not isinstance(observer, BaseObserver):
            raise TypeError('observer has the wrong type: {}'.format(type(observer)))

        if observer.interval is None:
            observer.interval = self._default_interval

        if tag is not None:
            observer.tag = tag
            self.observers[tag] = observer
        else:
            self.observers[observer.tag] = observer

        # the new observer can change the common observation interval
        self._find_observer_interval()

    def update_occupations(self, sites: List[int], species: List[int]):
        """Updates the occupation vector of the configuration being
        sampled. This will change the state of the configuration in
        both the calculator and the configuration manager.

        Parameters
        ----------
        sites
            indices of sites of the configuration to change
        species
            new occupations (species) by atomic number

        Raises
        ------
        ValueError
            if input lists are not of the same length
        """

        if len(sites) != len(species):
            raise ValueError('sites and species must have the same length.')
        self.configuration.update_occupations(sites, species)

    def _get_property_change(self, sites: List[int], species: List[int]) -> float:
        """Computes and returns the property change due to a change of
        the configuration.

        _N.B.:_ This method leaves the configuration itself unchanged.

        Parameters
        ----------
        sites
            indices of sites to change
        species
            new occupations (species) by atomic number
        """
        return self.calculator.calculate_change(sites=sites,
                                                current_occupations=self.configuration.occupations,
                                                new_site_occupations=species)

    def _get_ensemble_data(self) -> dict:
        """ Returns the current calculator property (``potential``) along
        with the acceptance ratio, i.e. the number of accepted trial
        steps divided by the ensemble data write interval. """
        potential = self.calculator.calculate_total(occupations=self.configuration.occupations)
        return {'potential': potential,
                'acceptance_ratio': self._accepted_trials / self._ensemble_data_write_interval}

    def get_random_sublattice_index(self, probability_distribution) -> int:
        """Returns a random sublattice index based on the weights of the
        sublattice.

        Parameters
        ----------
        probability_distribution
            probability distributions for the sublattices

        Raises
        ------
        ValueError
            if `probability_distribution` does not have one entry per
            sublattice
        """

        if len(probability_distribution) != len(self.sublattices):
            raise ValueError('probability_distribution should have the same size as sublattices')
        # NOTE(review): this draws from NumPy's global generator, which is
        # neither seeded with random_seed nor saved/restored together
        # with the state of the `random` module; runs that rely on this
        # method are therefore not reproducible via random_seed alone.
        pick = np.random.choice(len(self.sublattices), p=probability_distribution)
        return pick

    def _restart_ensemble(self):
        """Restarts ensemble using the last state saved in data container file.

        Restores the step counter, the occupations of the sites on
        active sublattices, the number of accepted trial steps, and the
        state of the random number generator.
        """

        # Restart step
        self._step = self.data_container._last_state['last_step']

        # Update configuration
        occupations = self.data_container._last_state['occupations']
        active_sites = []
        for sl in self.sublattices.active_sublattices:
            active_sites.extend(sl.indices)
        active_occupations = [occupations[s] for s in active_sites]
        self.update_occupations(active_sites, active_occupations)

        # Restart number of total and accepted trial steps
        self._accepted_trials = self.data_container._last_state['accepted_trials']

        # Restart state of random number generator
        random.setstate(self.data_container._last_state['random_state'])

    def write_data_container(self, outfile: Union[str, bytes]):
        """Updates last state of the Monte Carlo simulation and
        writes data container to file.

        Parameters
        ----------
        outfile
            file to which to write
        """
        self._data_container._update_last_state(
            last_step=self.step,
            occupations=self.configuration.occupations.tolist(),
            accepted_trials=self._accepted_trials,
            random_state=random.getstate())

        self.data_container.write(outfile)

    @property
    def ensemble_parameters(self) -> dict:
        """Returns parameters associated with the ensemble (shallow copy)."""
        return self._ensemble_parameters.copy()

    @property
    def sublattices(self) -> Sublattices:
        """sublattices for the configuration being sampled"""
        return self.configuration.sublattices

    def _terminate_sampling(self) -> bool:
        """This method is called from the run method to determine whether the MC
        sampling loop should be terminated for a reason other than having exhausted
        the number of iterations. The method can be overridden by child classes in
        order to provide an alternative exit mechanism.
        """
        return False

    def _finalize(self) -> None:
        """This method is called from the run method after the conclusion of
        the MC cycles but before the data container is written. This
        method can be used by child classes to carry out clean-up
        tasks, including e.g., adding "left-over" data to the data
        container.
        """
        pass

    def __str__(self) -> str:
        """ string representation of BaseEnsemble. """
        width = 60
        name = self.__class__.__name__
        s = [' {} '.format(name).center(width, '=')]

        fmt = '{:15} : {}'
        for k, v in self.ensemble_parameters.items():
            s += [fmt.format(k, v)]

        s += [fmt.format('step', self.step)]
        s += [fmt.format('calculator', self._calculator.__class__.__name__)]
        return '\n'.join(s)

512 

513 

def dicts_equal(dict1: Dict, dict2: Dict, atol: float = 1e-12) -> bool:
    """Returns True (False) if two dicts are equal (not equal).

    Values that are ints or floats in both dicts are compared using
    the absolute tolerance `atol`; two NaN values are regarded as
    equal, whereas NaN never equals a regular number. All other values
    are compared using ``!=``.

    Parameters
    ----------
    dict1
        first dict
    dict2
        second dict
    atol
        absolute tolerance used when comparing numerical values
    """
    if len(dict1) != len(dict2):
        return False
    for key, value1 in dict1.items():
        if key not in dict2:
            return False
        value2 = dict2[key]
        if isinstance(value1, (int, float)) and isinstance(value2, (int, float)):
            # equal_nan ensures that NaN matches NaN only, rather
            # than letting a NaN on either side match any number
            if not np.isclose(value1, value2, rtol=0.0, atol=atol, equal_nan=True):
                return False
        elif value1 != value2:
            return False
    return True