Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" Base data container class. """ 

2 

3import getpass 

4import json 

5import numbers 

6import os 

7import shutil 

8import socket 

9import tarfile 

10import tempfile 

11import warnings 

12 

13from collections import OrderedDict 

14from datetime import datetime 

15from typing import Any, BinaryIO, Dict, List, Set, TextIO, Tuple, Union 

16 

17import numpy as np 

18import pandas as pd 

19 

20from ase import Atoms 

21from ase.io import read as ase_read 

22from ase.io import write as ase_write 

23from icet import __version__ as icet_version 

24from ..observers.base_observer import BaseObserver 

25 

26 

class Int64Encoder(json.JSONEncoder):
    """ JSON encoder that serializes numpy integer scalars, which the
    standard :class:`json.JSONEncoder` cannot handle. """

    def default(self, obj):
        """ Converts numpy integers to plain Python ints; defers to the
        base class for any other unsupported type. """
        # Accept all numpy integer types (int32, int64, ...), not only
        # int64; int() round-trips them losslessly.
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)

33 

34 

class BaseDataContainer:
    """
    Base data container for storing information concerned with
    Monte Carlo simulations performed with mchammer.

    Parameters
    ----------
    structure : ase.Atoms
        reference atomic structure associated with the data container

    ensemble_parameters : dict
        parameters associated with the underlying ensemble

    metadata : dict
        metadata associated with the data container
    """

    def __init__(self, structure: Atoms,
                 ensemble_parameters: dict,
                 metadata: dict = None):
        """
        Initializes a BaseDataContainer object.

        Raises
        ------
        TypeError
            if structure is not an ASE Atoms object
        """
        if not isinstance(structure, Atoms):
            raise TypeError('structure is not an ASE Atoms object')

        self.structure = structure.copy()
        self._ensemble_parameters = ensemble_parameters
        # Fix of mutable-default-argument bug: the previous default
        # (metadata=OrderedDict()) created ONE dict shared by every
        # instance constructed without an explicit metadata argument, so
        # _add_default_metadata() mutated the same dict across containers.
        self._metadata = OrderedDict() if metadata is None else metadata
        self._add_default_metadata()
        self._last_state = {}  # type: Dict[str, Any]

        self._observables = set()  # type: Set[str]
        self._data_list = []  # type: List[Dict[str, Any]]

69 

70 def append(self, mctrial: int, record: Dict[str, Union[int, float, list]]): 

71 """ 

72 Appends data to data container. 

73 

74 Parameters 

75 ---------- 

76 mctrial 

77 current Monte Carlo trial step 

78 record 

79 dictionary of tag-value pairs representing observations 

80 

81 Raises 

82 ------ 

83 TypeError 

84 if input parameters have the wrong type 

85 

86 """ 

87 if not isinstance(mctrial, numbers.Integral): 

88 raise TypeError('mctrial has the wrong type: {}'.format(type(mctrial))) 

89 

90 if self._data_list: 

91 if self._data_list[-1]['mctrial'] > mctrial: 

92 raise ValueError('mctrial values should be given in ascending' 

93 ' order. This error can for example occur' 

94 ' when trying to append to an existing data' 

95 ' container after having reset the time step.' 

96 ' Note that the latter happens automatically' 

97 ' when initializing a new ensemble.') 

98 

99 if not isinstance(record, dict): 

100 raise TypeError('record has the wrong type: {}'.format(type(record))) 

101 

102 for tag in record.keys(): 

103 self._observables.add(tag) 

104 

105 row_data = OrderedDict() 

106 row_data['mctrial'] = mctrial 

107 row_data.update(record) 

108 self._data_list.append(row_data) 

109 

110 def _update_last_state(self, last_step: int, occupations: List[int], 

111 accepted_trials: int, random_state: Any): 

112 """Updates last state of the simulation: last step, occupation vector 

113 and number of accepted trial steps. 

114 

115 Parameters 

116 ---------- 

117 last_step 

118 last trial step 

119 occupations 

120 occupation vector observed during the last trial step 

121 accepted_trial 

122 number of current accepted trial steps 

123 random_state 

124 tuple representing the last state of the random generator 

125 """ 

126 self._last_state['last_step'] = last_step 

127 self._last_state['occupations'] = occupations 

128 self._last_state['accepted_trials'] = accepted_trials 

129 self._last_state['random_state'] = random_state 

130 

131 def apply_observer(self, observer: BaseObserver): 

132 """ Adds observer data from observer to data container. 

133 

134 The observer will only be run for the mctrials for which the 

135 trajectory have been saved. 

136 

137 The interval of the observer is ignored. 

138 

139 Parameters 

140 ---------- 

141 observer 

142 observer to be used 

143 """ 

144 for row_data in self._data_list: 

145 if 'occupations' in row_data: 

146 structure = self.structure.copy() 

147 structure.numbers = row_data['occupations'] 

148 record = dict() 

149 if observer.return_type is dict: 149 ↛ 150line 149 didn't jump to line 150, because the condition on line 149 was never true

150 for key, value in observer.get_observable(structure).items(): 

151 record[key] = value 

152 self._observables.add(key) 

153 else: 

154 record[observer.tag] = observer.get_observable(structure) 

155 self._observables.add(observer.tag) 

156 row_data.update(record) 

157 

158 def get(self, 

159 *tags: str, 

160 start: int = 0) \ 

161 -> Union[np.ndarray, List[Atoms], Tuple[np.ndarray, List[Atoms]]]: 

162 """Returns the accumulated data for the requested observables, 

163 including configurations stored in the data container. The latter 

164 can be achieved by including 'trajectory' as one of the tags. 

165 

166 Parameters 

167 ---------- 

168 tags 

169 names of the requested properties 

170 start 

171 minimum value of trial step to consider; by default the 

172 smallest value in the mctrial column will be used. 

173 

174 Raises 

175 ------ 

176 ValueError 

177 if tags is empty 

178 ValueError 

179 if observables are requested that are not in data container 

180 

181 Examples 

182 -------- 

183 Below the `get` method is illustrated but first we require a data container. 

184 

185 >>> from ase.build import bulk 

186 >>> from icet import ClusterExpansion, ClusterSpace 

187 >>> from mchammer.calculators import ClusterExpansionCalculator 

188 >>> from mchammer.ensembles import CanonicalEnsemble 

189 

190 >>> # prepare cluster expansion 

191 >>> prim = bulk('Au') 

192 >>> cs = ClusterSpace(prim, cutoffs=[4.3], chemical_symbols=['Ag', 'Au']) 

193 >>> ce = ClusterExpansion(cs, [0, 0, 0.1, -0.02]) 

194 

195 >>> # prepare initial configuration 

196 >>> structure = prim.repeat(3) 

197 >>> for k in range(5): 

198 ... structure[k].symbol = 'Ag' 

199 

200 >>> # set up and run MC simulation 

201 >>> calc = ClusterExpansionCalculator(structure, ce) 

202 >>> mc = CanonicalEnsemble(structure=structure, calculator=calc, 

203 ... temperature=600, 

204 ... dc_filename='myrun_canonical.dc') 

205 >>> mc.run(100) # carry out 100 trial swaps 

206 

207 We can now access the data container by reading it from file by using 

208 the `read` method. For the purpose of this example, however, we access 

209 the data container associated with the ensemble directly. 

210 

211 >>> dc = mc.data_container 

212 

213 The following lines illustrate how to use the `get` method 

214 for extracting data from the data container. 

215 

216 >>> # obtain all values of the potential represented by 

217 >>> # the cluster expansion along the trajectory 

218 >>> p = dc.get('potential') 

219 

220 >>> import matplotlib.pyplot as plt 

221 >>> # as above but this time the MC trial step is included as well 

222 >>> s, p = dc.get('mctrial', 'potential') 

223 >>> _ = plt.plot(s, p) 

224 >>> plt.show() 

225 

226 >>> # obtain configurations along the trajectory along with 

227 >>> # their potential 

228 >>> p, confs = dc.get('potential', 'trajectory') 

229 """ 

230 

231 if len(tags) == 0: 

232 raise TypeError('Missing tags argument') 

233 

234 local_tags = ['occupations' if tag == 'trajectory' else tag for tag in tags] 

235 

236 for tag in local_tags: 

237 if isinstance(tag, str) and tag in 'mctrial': 

238 continue 

239 if tag not in self.observables: 

240 raise ValueError('No observable named {} in data container'.format(tag)) 

241 

242 # collect data 

243 mctrials = [row_dict['mctrial'] for row_dict in self._data_list] 

244 data = pd.DataFrame.from_records(self._data_list, index=mctrials, columns=local_tags) 

245 data = data.loc[start::, local_tags].copy() 

246 data.dropna(inplace=True) 

247 

248 # handling of trajectory 

249 def occupation_to_atoms(occupation): 

250 structure = self.structure.copy() 

251 structure.numbers = occupation 

252 return structure 

253 

254 data_list = [] 

255 for tag in local_tags: 

256 if tag == 'occupations': 

257 traj = [occupation_to_atoms(o) for o in data['occupations']] 

258 data_list.append(traj) 

259 else: 

260 data_list.append(data[tag].values) 

261 

262 if len(data_list) > 1: 

263 return tuple(data_list) 

264 else: 

265 return data_list[0] 

266 

267 @property 

268 def data(self) -> pd.DataFrame: 

269 """ pandas data frame (see :class:`pandas.DataFrame`) """ 

270 if self._data_list: 

271 df = pd.DataFrame.from_records(self._data_list, index='mctrial', 

272 exclude=['occupations']) 

273 df.dropna(axis='index', how='all', inplace=True) 

274 df.reset_index(inplace=True) 

275 return df 

276 else: 

277 return pd.DataFrame() 

278 

279 @property 

280 def ensemble_parameters(self) -> dict: 

281 """ parameters associated with Monte Carlo simulation """ 

282 return self._ensemble_parameters.copy() 

283 

284 @property 

285 def observables(self) -> List[str]: 

286 """ observable names """ 

287 return list(self._observables) 

288 

289 @property 

290 def metadata(self) -> dict: 

291 """ metadata associated with data container """ 

292 return self._metadata 

293 

294 def write(self, outfile: Union[bytes, str]): 

295 """ 

296 Writes BaseDataContainer object to file. 

297 

298 Parameters 

299 ---------- 

300 outfile 

301 file to which to write 

302 """ 

303 self._metadata['date_last_backup'] = datetime.now().strftime('%Y-%m-%dT%H:%M:%S') 

304 

305 # Save reference atomic structure 

306 reference_structure_file = tempfile.NamedTemporaryFile() 

307 ase_write(reference_structure_file.name, self.structure, format='json') 

308 

309 # Save reference data 

310 data_container_type = str(self.__class__).split('.')[-1].replace("'>", '') 

311 reference_data = {'parameters': self._ensemble_parameters, 

312 'metadata': self._metadata, 

313 'last_state': self._last_state, 

314 'data_container_type': data_container_type} 

315 

316 reference_data_file = tempfile.NamedTemporaryFile() 

317 with open(reference_data_file.name, 'w') as fileobj: 

318 json.dump(reference_data, fileobj, cls=Int64Encoder) 

319 

320 # Save runtime data 

321 runtime_data_file = tempfile.NamedTemporaryFile() 

322 np.savez_compressed(runtime_data_file, self._data_list) 

323 

324 # Write temporary tar file 

325 with tempfile.NamedTemporaryFile('wb', delete=False) as f: 

326 with tarfile.open(fileobj=f, mode='w') as handle: 

327 handle.add(reference_data_file.name, arcname='reference_data') 

328 handle.add(reference_structure_file.name, arcname='atoms') 

329 handle.add(runtime_data_file.name, arcname='runtime_data') 

330 

331 # Copy to permanent location 

332 file_name = f.name 

333 f.close() # Required for Windows 

334 shutil.copy(file_name, outfile) 

335 os.remove(file_name) 

336 runtime_data_file.close() 

337 

338 def _add_default_metadata(self): 

339 """Adds default metadata to metadata dict.""" 

340 

341 self._metadata['date_created'] = \ 

342 datetime.now().strftime('%Y-%m-%dT%H:%M:%S') 

343 self._metadata['username'] = getpass.getuser() 

344 self._metadata['hostname'] = socket.gethostname() 

345 self._metadata['icet_version'] = icet_version 

346 

347 def __str__(self): 

348 """ string representation of data container """ 

349 width = 80 

350 s = [] # type: List 

351 s += ['{s:=^{n}}'.format(s=' Data Container ', n=width)] 

352 data_container_type = str(self.__class__).split('.')[-1].replace("'>", '') 

353 s += [' {:22}: {}'.format('data_container_type', data_container_type)] 

354 for key, value in self._last_state.items(): 

355 if isinstance(value, int) or isinstance(value, float) or isinstance(value, str): 

356 s += [' {:22}: {}'.format(key, value)] 

357 for key, value in sorted(self._ensemble_parameters.items()): 

358 s += [' {:22}: {}'.format(key, value)] 

359 for key, value in sorted(self._metadata.items()): 

360 s += [' {:22}: {}'.format(key, value)] 

361 s += [' {:22}: {}'.format('columns_in_data', self.data.columns.tolist())] 

362 s += [' {:22}: {}'.format('n_rows_in_data', len(self.data))] 

363 s += [''.center(width, '=')] 

364 return '\n'.join(s) 

365 

    @classmethod
    # todo: cls and the return should be type hinted as BaseDataContainer.
    # Unfortunately, this requires from __future__ import annotations, which
    # in turn requires Python 3.8.
    def read(cls, infile: Union[str, BinaryIO, TextIO], old_format: bool = False):
        """Reads data container from file.

        Parameters
        ----------
        infile
            file from which to read
        old_format
            if True the runtime data is read using the old json format;
            defaults to False

        Raises
        ------
        FileNotFoundError
            if file is not found (str)
        TypeError
            if file is of incorrect type (not a tarball)
        """
        # accept either a path or an open file object
        if isinstance(infile, str):
            filename = infile
        else:
            filename = infile.name

        if not tarfile.is_tarfile(filename):
            raise TypeError('{} is not a tar file'.format(filename))

        with tarfile.open(mode='r', name=filename) as tf:
            # file with structures; extracted to a temporary file so that
            # ase_read can open it by name
            with tempfile.NamedTemporaryFile() as fobj:
                fobj.write(tf.extractfile('atoms').read())
                fobj.flush()
                structure = ase_read(fobj.name, format='json')

            # file with reference data (parameters, metadata, last state)
            with tempfile.NamedTemporaryFile() as fobj:
                fobj.write(tf.extractfile('reference_data').read())
                fobj.flush()
                with open(fobj.name, encoding='utf-8') as fd:
                    reference_data = json.load(fd)

            # init DataContainer
            dc = cls(structure=structure,
                     ensemble_parameters=reference_data['parameters'])

            # overwrite metadata (replaces the defaults set by __init__)
            dc._metadata = reference_data['metadata']

            for tag, value in reference_data['last_state'].items():
                if tag == 'random_state':
                    # json stores tuples as lists; restore the nested
                    # tuple structure expected by the random generator
                    value = tuple(tuple(x) if isinstance(x, list) else x for x in value)
                dc._last_state[tag] = value

            # add runtime data from file
            with tempfile.NamedTemporaryFile() as fobj:
                fobj.write(tf.extractfile('runtime_data').read())
                fobj.seek(0)
                if old_format:
                    runtime_data = pd.read_json(fobj)
                    data = runtime_data.sort_index(ascending=True)
                    dc._data_list = data.T.apply(lambda x: x.dropna().to_dict()).tolist()
                else:
                    # NOTE(review): allow_pickle=True executes arbitrary
                    # pickled objects — only read containers from trusted
                    # sources.
                    dc._data_list = np.load(fobj, allow_pickle=True)['arr_0'].tolist()

        # rebuild the set of observables from the stored rows; 'mctrial'
        # is the index column, not an observable
        dc._observables = set([key for data in dc._data_list for key in data])
        dc._observables = dc._observables - {'mctrial'}

        return dc

436 

437 def get_data(self, *args, **kwargs): 

438 warnings.simplefilter('always', DeprecationWarning) 

439 warnings.warn('get_data is deprecated, use get instead', DeprecationWarning) 

440 return self.get(*args, **kwargs) 

441 

442 def get_trajectory(self, *args, **kwargs): 

443 """ Returns trajectory as a list of ASE Atoms objects.""" 

444 warnings.simplefilter('always', DeprecationWarning) 

445 warnings.warn('get_trajectory is deprecated, use get instead', DeprecationWarning) 

446 return self.get('trajectory', *args, **kwargs)