Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2This module provides the StructureContainer class. 

3""" 

4 

5import tarfile 

6import tempfile 

7 

8from typing import BinaryIO, List, TextIO, Tuple, Union 

9 

10import numpy as np 

11import ase.db 

12from ase import Atoms 

13 

14from icet import ClusterSpace 

15from icet.input_output.logging_tools import logger 

16logger = logger.getChild('structure_container') 

17 

18 

class StructureContainer:
    """
    This class serves as a container for structure objects as well as their fit
    properties and cluster vectors.

    Parameters
    ----------
    cluster_space : icet.ClusterSpace
        cluster space used for evaluating the cluster vectors

    Example
    -------
    The following snippet illustrates the initialization
    and usage of a StructureContainer object. The construction
    of a structure container is convenient for compiling the
    data needed to train a cluster expansion, i.e., a sensing
    matrix and target energies::

        >>> from ase.build import bulk
        >>> from icet import ClusterSpace, StructureContainer
        >>> from icet.tools import enumerate_structures
        >>> from random import random

        >>> # create cluster space
        >>> prim = bulk('Au')
        >>> cs = ClusterSpace(prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Pd']])

        >>> # build structure container
        >>> sc = StructureContainer(cs)
        >>> for structure in enumerate_structures(prim, range(5), ['Au', 'Pd']):
        ...     sc.add_structure(structure,
        ...                      properties={'my_random_energy': random()})
        >>> print(sc)

        >>> # fetch sensing matrix and target energies
        >>> A, y = sc.get_fit_data(key='my_random_energy')
    """

    def __init__(self, cluster_space: ClusterSpace):

        if not isinstance(cluster_space, ClusterSpace):
            raise TypeError('cluster_space must be a ClusterSpace object')

        self._cluster_space = cluster_space
        self._structure_list = []

    def __len__(self) -> int:
        """ Number of structures in the container. """
        return len(self._structure_list)

    def __getitem__(self, ind: int):
        """
        Returns the fit structure stored at position `ind`.

        Iteration over the container (``for fs in self``) relies on this
        method raising IndexError once `ind` runs past the end of the list.
        """
        return self._structure_list[ind]

    def get_structure_indices(self, user_tag: str = None) -> List[int]:
        """
        Get structure indices via user_tag

        Parameters
        ----------
        user_tag
            user_tag used for selecting structures; by default (``None``)
            the indices of all structures are returned

        Returns
        -------
        list of integers
            List of structure's indices
        """
        return [i for i, s in enumerate(self) if user_tag is None or s.user_tag == user_tag]

    def _get_string_representation(self, print_threshold: int = None) -> str:
        """
        String representation of the structure container that provides an
        overview of the structures in the container.

        Parameters
        ----------
        print_threshold
            if the number of structures exceeds this number print dots

        Returns
        -------
        multi-line string
            string representation of the structure container
        """

        if len(self) == 0:
            return 'Empty StructureContainer'

        # Number of structures to print before cutting and printing dots;
        # len(self) + 2 can never be reached by the check in the row loop
        # below, i.e., the table is printed in full.
        if print_threshold is None or print_threshold >= len(self):
            print_threshold = len(self) + 2

        # format specifiers for fields in table
        def get_format(val):
            if isinstance(val, float):
                return '{:9.4f}'
            else:
                return '{}'

        # table headers
        default_headers = ['index', 'user_tag', 'n_atoms', 'chemical formula']
        property_headers = self.available_properties
        headers = default_headers + property_headers

        # collect the table data; properties missing for a structure are
        # rendered as empty cells
        str_table = []
        for i, fs in enumerate(self):
            default_data = [i, fs.user_tag, len(fs), fs.structure.get_chemical_formula()]
            property_data = [fs.properties.get(key, '') for key in property_headers]
            str_row = [get_format(d).format(d) for d in default_data + property_data]
            str_table.append(str_row)
        str_table = np.array(str_table)

        # find maximum widths for each column
        widths = []
        for i in range(str_table.shape[1]):
            data_width = max(len(val) for val in str_table[:, i])
            header_width = len(headers[i])
            widths.append(max([data_width, header_width]))

        total_width = sum(widths) + 3 * len(headers)
        row_format = ' | '.join('{:' + str(width) + '}' for width in widths)

        # Make string representation of table
        s = []
        s += ['{s:=^{n}}'.format(s=' Structure Container ', n=total_width)]
        s += ['Total number of structures: {}'.format(len(self))]
        s += [''.center(total_width, '-')]
        s += [row_format.format(*headers)]
        s += [''.center(total_width, '-')]
        for i, fs_data in enumerate(str_table, start=1):
            s += [row_format.format(*fs_data)]
            if i + 1 >= print_threshold:
                # truncate: print dots followed by the very last row
                s += [' ...']
                s += [row_format.format(*str_table[-1])]
                break
        s += [''.center(total_width, '=')]
        s = '\n'.join(s)

        return s

    def __repr__(self) -> str:
        """ String representation. """
        return self._get_string_representation(print_threshold=50)

    def print_overview(self, print_threshold: int = None):
        """
        Prints a list of structures in the structure container.

        Parameters
        ----------
        print_threshold
            if the number of structures exceeds this number print dots
        """
        print(self._get_string_representation(print_threshold=print_threshold))

    def add_structure(self, structure: Atoms, user_tag: str = None,
                      properties: dict = None, allow_duplicate: bool = True,
                      sanity_check: bool = True):
        """
        Adds a structure to the structure container.

        Parameters
        ----------
        structure
            the atomic structure to be added
        user_tag
            custom user tag to label structure
        properties
            scalar properties. If properties are not specified the structure
            object will be checked for an attached ASE calculator object
            with a calculated potential energy
        allow_duplicate
            whether or not to add the structure if there already exists a
            structure with identical cluster-vector
        sanity_check
            whether or not to carry out a sanity check before adding the
            structure. This includes checking occupations and volume.

        Raises
        ------
        TypeError
            if `structure` is not an ASE Atoms object or if `user_tag` is
            neither ``None`` nor a string
        ValueError
            if `allow_duplicate` is False and the container already holds
            a structure with the same cluster vector
        """

        # structure must have a proper format and label
        if not isinstance(structure, Atoms):
            raise TypeError('structure must be an ASE Atoms object not {}'.format(type(structure)))

        if user_tag is not None:
            if not isinstance(user_tag, str):
                raise TypeError('user_tag must be a string not {}.'.format(type(user_tag)))

        if sanity_check:
            self._cluster_space.assert_structure_compatibility(structure)

        # check for properties in attached calculator; the energy is only
        # picked up if it is already available, and it is stored per atom
        if properties is None:
            properties = {}
            if structure.calc is not None:
                if not structure.calc.calculation_required(structure, ['energy']):
                    energy = structure.get_potential_energy()
                    properties['energy'] = energy / len(structure)

        # check if there exist structures with identical cluster vectors
        structure_copy = structure.copy()
        cv = self._cluster_space.get_cluster_vector(structure_copy)
        if not allow_duplicate:
            for i, fs in enumerate(self):
                if np.allclose(cv, fs.cluster_vector):
                    # FitStructure.user_tag is always a string, hence an
                    # unset tag shows up as 'None'
                    msg = '{} and {} have identical cluster vectors'.format(
                        user_tag if user_tag is not None else 'Input structure',
                        fs.user_tag if fs.user_tag != 'None' else 'structure')
                    msg += ' at index {}'.format(i)
                    raise ValueError(msg)

        # add structure
        structure = FitStructure(structure_copy, user_tag, cv, properties)
        self._structure_list.append(structure)

    def get_condition_number(self, structure_indices: List[int] = None,
                             key: str = 'energy') -> float:
        """ Returns the condition number for the sensing matrix.

        A very large condition number can be a sign of multicollinearity,
        read more here https://en.wikipedia.org/wiki/Condition_number

        Parameters
        ----------
        structure_indices
            list of structure indices; by default (``None``) the
            method will return all fit data available.
        key
            key of properties dictionary

        Returns
        -------
        condition number of the sensing matrix
        """
        return np.linalg.cond(self.get_fit_data(structure_indices, key)[0])

    def get_fit_data(self, structure_indices: List[int] = None,
                     key: str = 'energy') -> Tuple[np.ndarray, np.ndarray]:
        """
        Returns fit data for all structures. The cluster vectors and
        target properties for all structures are stacked into numpy arrays.

        Parameters
        ----------
        structure_indices
            list of structure indices; by default (``None``) the
            method will return all fit data available.
        key
            key of properties dictionary

        Returns
        -------
        cluster vectors and target properties for desired structures

        Raises
        ------
        Exception
            if the selection is empty, i.e., there is no fit data to return
        KeyError
            if a selected structure has no property named `key`
        """
        if structure_indices is None:
            selection = self._structure_list
        else:
            selection = [self._structure_list[i] for i in structure_indices]

        # An empty selection yields no fit data (the lists can be empty
        # but are never None).
        if not selection:
            raise Exception('No available fit data for {}'
                            .format(structure_indices))

        cv_list = [fs.cluster_vector for fs in selection]
        prop_list = [fs.properties[key] for fs in selection]

        return np.array(cv_list), np.array(prop_list)

    @property
    def cluster_space(self) -> ClusterSpace:
        """Cluster space used to calculate the cluster vectors."""
        return self._cluster_space

    @property
    def available_properties(self) -> List[str]:
        """List of the available properties."""
        return sorted(set([p for fs in self for p in fs.properties.keys()]))

    def write(self, outfile: Union[str, BinaryIO, TextIO]):
        """
        Writes structure container to a file.

        The container is stored as a tar archive with two members:
        ``cluster_space`` (written by the cluster space itself) and
        ``database`` (an ASE database holding the fit structures).

        Parameters
        ----------
        outfile
            output file name or file object; a file object must be opened
            for writing in binary mode since a tar archive is written to it
        """
        # The context managers guarantee that both temporary files are
        # closed and removed, also if writing fails half-way.
        with tempfile.NamedTemporaryFile() as temp_cs_file, \
                tempfile.NamedTemporaryFile() as temp_db_file:

            # Write cluster space to tempfile
            self.cluster_space.write(temp_cs_file.name)

            # Write fit structures as an ASE db in tempfile; for an empty
            # container the (empty) tempfile is archived as is
            if self._structure_list:
                db = ase.db.connect(temp_db_file.name, type='db', append=False)

                for fit_structure in self._structure_list:
                    data_dict = {'user_tag': fit_structure.user_tag,
                                 'properties': fit_structure.properties,
                                 'cluster_vector': fit_structure.cluster_vector}
                    db.write(fit_structure.structure, data=data_dict)

            # tarfile expects file objects via `fileobj`, not via `name`
            if hasattr(outfile, 'write'):
                target = dict(fileobj=outfile)
            else:
                target = dict(name=outfile)

            with tarfile.open(mode='w', **target) as handle:
                handle.add(temp_db_file.name, arcname='database')
                handle.add(temp_cs_file.name, arcname='cluster_space')

    @staticmethod
    def read(infile: Union[str, BinaryIO, TextIO]):
        """
        Reads StructureContainer object from file.

        Parameters
        ----------
        infile
            file from which to read; either a file name or a file object
            that exposes the underlying file name via its `name` attribute

        Returns
        -------
        StructureContainer
            structure container restored from the archive

        Raises
        ------
        TypeError
            if the file is not a tar file
        """
        # file objects are resolved to the name of the underlying file
        if hasattr(infile, 'read'):
            filename = infile.name
        else:
            filename = infile

        if not tarfile.is_tarfile(filename):
            raise TypeError('{} is not a tar file'.format(filename))

        # The ASE database has to live in a real file; the context manager
        # ensures the temporary copy is removed once all rows are read.
        with tempfile.NamedTemporaryFile() as temp_db_file:
            with tarfile.open(mode='r', name=filename) as tar_file:
                cs_file = tar_file.extractfile('cluster_space')
                temp_db_file.write(tar_file.extractfile('database').read())
                # make sure the data is on disk before ASE opens the file
                temp_db_file.flush()
                cluster_space = ClusterSpace.read(cs_file)
            database = ase.db.connect(temp_db_file.name, type='db')

            structure_container = StructureContainer(cluster_space)
            fit_structures = []
            for row in database.select():
                data = row.data
                fit_structure = FitStructure(row.toatoms(),
                                             user_tag=data['user_tag'],
                                             cv=data['cluster_vector'],
                                             properties=data['properties'])
                fit_structures.append(fit_structure)

        # cluster vectors are taken from the file rather than recomputed
        structure_container._structure_list = fit_structures

        return structure_container

365 

366 

class FitStructure:
    """
    This class holds a supercell along with its properties and cluster
    vector.

    Entries of the properties dictionary can also be read as attributes,
    e.g., ``fit_structure.energy``, unless the name is already taken by a
    regular attribute, property or method of the object.

    Attributes
    ----------
    structure : Atoms
        supercell structure
    user_tag : str
        custom user tag
    cluster_vector : np.ndarray
        calculated cluster vector for actual structure
    properties : dict
        dictionary of properties
    """

    def __init__(self, structure: Atoms, user_tag: str,
                 cv: np.ndarray, properties: dict = None):
        self._structure = structure
        self._user_tag = user_tag
        self._cluster_vector = cv
        # A new dict is created per instance; a dict literal as default
        # argument would be shared between all instances.
        self.properties = {} if properties is None else properties

    @property
    def cluster_vector(self) -> np.ndarray:
        """calculated cluster vector"""
        return self._cluster_vector

    @property
    def structure(self) -> Atoms:
        """atomic structure"""
        return self._structure

    @property
    def user_tag(self) -> str:
        """structure label"""
        # always a string, i.e., an unset tag is returned as 'None'
        return str(self._user_tag)

    def __getattr__(self, key):
        """ Accesses properties if possible and returns value. """
        # This method is only called after regular attribute lookup has
        # failed. The properties are fetched from the instance dict rather
        # than via self.properties since the latter would call this method
        # again (without end) whenever `properties` has not been set yet,
        # e.g., while an instance is being restored by copy or pickle.
        properties = self.__dict__.get('properties')
        if properties is not None and key in properties:
            return properties[key]
        # raises the regular AttributeError
        return super().__getattribute__(key)

    def __len__(self) -> int:
        """ Number of sites in the structure. """
        return len(self._structure)