import numpy as np
import periodictable
from copy import deepcopy, copy
from .latmatcher import LatMatch, SR


class PipelineLatMatch:
    """Build a host/guest bilayer structure from two 3D lattices and their atoms.

    Layer A is the host and layer B the guest. The 2D supercell is either
    supplied by the caller (``sc_vec``) or obtained from
    ``LatMatch(...).supercell()``. Both layers are tiled into that supercell,
    and B is stacked above A with an interlayer gap ``dz``.

    Atoms are ``[symbol, position]`` pairs whose position has at least three
    components (x, y, z); presumably numpy arrays — confirm against callers.
    """

    def __init__(self, Alat3D, Blat3D, Aatoms3D=None, Batoms3D=None, dim=10,
                 sc_vec=None, optimize_angle=True, optimize_strain=True, dz=4):
        """
        :param Alat3D: lattice matrix of the host (indexed ``[row][col]``, at
            least 3x3); its top-left 2x2 block is the in-plane lattice.
        :param Blat3D: lattice matrix of the guest, same layout as ``Alat3D``.
        :param Aatoms3D: host atoms as ``[symbol, position]`` pairs (required).
        :param Batoms3D: guest atoms as ``[symbol, position]`` pairs (required).
        :param dim: ``scdim`` forwarded to ``LatMatch`` when searching a supercell.
        :param sc_vec: optional 2x2 supercell; when given, no search is done.
        :param optimize_angle: forwarded to ``LatMatch``.
        :param optimize_strain: forwarded to ``LatMatch``.
        :param dz: gap inserted between the top of A and the bottom of B.
        """
        self.Alat3D = Alat3D
        self.Blat3D = Blat3D
        self.Alat = self.Alat3D[:2, :2]
        self.Blat = self.Blat3D[:2, :2]
        self.dz = dz
        if sc_vec is not None:
            self.sc_vec = sc_vec
            # TODO: This should be computed. Until then rotate_guest treats a
            # missing result as "apply no transformation to the guest".
            self.rez = None
        else:
            self.matcher4 = LatMatch(scdim=dim, reference=self.Alat, target=self.Blat,
                                     optimize_angle=optimize_angle,
                                     optimize_strain=optimize_strain)
            self.sc_vec = self.matcher4.supercell()
            self.rez = self.matcher4.result
        # In-plane supercell embedded in 3D. The out-of-plane length adds the
        # gap to both input cells' [2][2] entries (plus a small margin).
        self.sc_vec3 = np.array([
            [self.sc_vec[0][0], self.sc_vec[0][1], 0],
            [self.sc_vec[1][0], self.sc_vec[1][1], 0],
            [0, 0, self.dz + self.Alat3D[2][2] + self.Blat3D[2][2] + 0.001],
        ])
        self.Aatoms3D, self.Batoms3D = self.atom_shift(Aatoms3D, Batoms3D)
        self.superA_xyz = []
        self.superB_xyz = []

    def atom_shift(self, Aatoms3D, Batoms3D):
        """Stack the guest layer above the host layer along z.

        The host is translated so its lowest atom sits at z = 0.001. The guest
        is translated so its lowest atom sits ``dz + 0.001`` above the shifted
        host's highest atom. Inputs are deep-copied and not modified.

        :return: ``(shifted_A, shifted_B)``, also stored on ``self``.
        :raises TypeError: if either atom list is None.
        :raises ValueError: if either atom list is empty (from ``min``).
        """
        if Aatoms3D is None or Batoms3D is None:
            raise TypeError("Aatoms3D and Batoms3D must be provided")
        shift_Aatoms3D = deepcopy(Aatoms3D)
        shift_Batoms3D = deepcopy(Batoms3D)

        za = [atom[1][-1] for atom in Aatoms3D]
        zb = [atom[1][-1] for atom in Batoms3D]

        # Host: move lowest atom to z = 0.001.
        min_za = min(za)
        za = [z - min_za + 0.001 for z in za]
        # Guest: move lowest atom to (host top) + dz + 0.001.
        min_zb = min(zb)
        top_a = max(za)
        zb = [z - min_zb + top_a + self.dz + 0.001 for z in zb]

        for atom, z in zip(shift_Aatoms3D, za):
            atom[1][-1] = z
        for atom, z in zip(shift_Batoms3D, zb):
            atom[1][-1] = z
        self.Aatoms3D = shift_Aatoms3D
        self.Batoms3D = shift_Batoms3D
        return self.Aatoms3D, self.Batoms3D

    def compute_super_atoms(self, grid_dim=(10, 10, 0)):
        """Return the host and guest atoms that fall inside one supercell.

        For each layer:
        - build a grid of periodic images;
        - convert in-plane coordinates to fractional supercell coordinates;
        - keep one supercell's worth of atoms;
        - convert back to the identity (xyz) basis.

        The guest grid is first transformed by ``rotate_guest``.

        :param grid_dim: image translations per sign along each lattice
            vector; must be large enough for the grid to cover the supercell.
        :return: ``(superA_xyz, superB_xyz)``, also stored on ``self``.
        """
        # Host: grid -> A basis -> supercell basis -> select -> xyz basis.
        atoms = atoms_to_greed(self.Aatoms3D, lat_v=self.Alat3D, dim=grid_dim)
        atoms_a = atom_change_basis2D(atoms, new_basis=self.Alat, old_basis=np.identity(2))
        atoms_A = atom_change_basis2D(atoms_a, new_basis=self.sc_vec, old_basis=self.Alat)
        superA = supar_atoms(atoms_A)
        superA_xyz = atom_change_basis2D(superA, new_basis=np.identity(2), old_basis=self.sc_vec)

        # Guest: grid -> transform -> supercell basis -> select -> xyz basis.
        atoms = atoms_to_greed(self.Batoms3D, lat_v=self.Blat3D, dim=grid_dim)
        _, atoms_b = rotate_guest(self.rez, self.Blat3D, atoms)
        atoms_B = atom_change_basis2D(atoms_b, new_basis=self.sc_vec, old_basis=np.identity(2))
        superB = supar_atoms(atoms_B)
        superB_xyz = atom_change_basis2D(superB, new_basis=np.identity(2), old_basis=self.sc_vec)

        self.superA_xyz = uniq_list(superA_xyz)
        self.superB_xyz = uniq_list(superB_xyz)
        return self.superA_xyz, self.superB_xyz

    def get_new_structure(self, dist):
        """Assemble the bilayer as a plain dictionary.

        :param dist: offset added to every guest position; expected to be a
            3-component vector (e.g. ``[0, 0, d]``) — lists and arrays both
            add element-wise.
        :return: dict with ``lattice_vectors`` (``sc_vec3``), ``pbc``,
            ``positions`` (lists, host atoms first then guest),
            ``host_guest`` labels, and ``atoms`` (atomic numbers, ``None`` for
            symbols ``periodictable`` does not know).
        """
        structure = {
            "lattice_vectors": self.sc_vec3,
            "pbc": [True, True, False],
            "positions": [],
            "host_guest": [],
            "atoms": [],
        }
        superA_xyz, superB_xyz = self.compute_super_atoms()
        offset = np.asarray(dist)

        atomic_numbers = []
        positions = []
        for element in superA_xyz:
            structure["host_guest"].append("host")
            positions.append(np.asarray(element[1]).tolist())
            atomic_numbers.append(_atomic_number(element[0]))
        for element in superB_xyz:
            structure["host_guest"].append("guest")
            # Vector addition; `list + dist` would concatenate for list input.
            positions.append((np.asarray(element[1]) + offset).tolist())
            atomic_numbers.append(_atomic_number(element[0]))

        structure["atoms"] = atomic_numbers
        structure["positions"] = positions
        return structure


def _atomic_number(symbol):
    """Return the atomic number of ``symbol`` via ``periodictable``.

    Prints a warning and returns None when ``periodictable`` has no element
    by that name.
    """
    number = getattr(getattr(periodictable, symbol, None), "number", None)
    if number is None:
        print(f"Warning: Atomic number for element '{symbol}' not found.")
    return number


def is_element_in_list(element, lst):
    """Return True if ``lst`` holds an atom with the same symbol as ``element``
    and a position equal to it within ``np.allclose`` tolerance."""
    return any(
        item[0] == element[0] and np.allclose(item[1], element[1])
        for item in lst
    )


def uniq_list(super_brut):
    """Return the atoms of ``super_brut`` without near-duplicates.

    Keeps the first occurrence and the input order. Quadratic in the number
    of atoms.
    """
    uniq = []
    for c in super_brut:
        if not is_element_in_list(c, uniq):
            uniq.append(c)
    return uniq


def atoms_to_greed(atoms, lat_v, dim):
    """Construct a grid of periodic images of ``atoms``.

    For each axis k = 0, 1, 2 in turn, every atom collected so far is
    translated by ``+i`` and ``-i`` times lattice vector k, for
    i = 1..dim[k]. Lattice vector k is read as column k of ``lat_v``
    (``lat_v[0][k], lat_v[1][k], lat_v[2][k]``).

    :param atoms: ``[symbol, position]`` pairs; copied, not modified.
    :param lat_v: lattice matrix (at least 3x3).
    :param dim: translations per sign along each of the three axes.
    :return: the original atoms first, followed by their images.
    """
    atom_list = deepcopy(atoms)
    for axis in range(3):
        # Fresh per axis so images from earlier axes are not appended twice.
        new_atoms = []
        vec = [lat_v[0][axis], lat_v[1][axis], lat_v[2][axis]]
        for i in range(1, dim[axis] + 1):
            for atom in atom_list:
                for sign in (1, -1):
                    new_atom = deepcopy(atom)
                    for c in range(3):
                        new_atom[1][c] += sign * i * vec[c]
                    new_atoms.append(new_atom)
        atom_list.extend(new_atoms)
    return atom_list


# atoms in new basis:
def atom_change_basis2D(atoms, new_basis, old_basis=np.identity(2)):
    """Re-express the in-plane (x, y) coordinates of ``atoms`` in another basis.

    The first two components ``r`` of each position are replaced by
    ``inv(new_basis) @ old_basis @ r``. Only the top-left 2x2 block of each
    basis is used, which means basis vectors are matrix columns. Other
    components are left unchanged.

    :return: a new list of deep-copied atoms; the input is not modified.
    """
    new_b = np.array([[new_basis[0][0], new_basis[0][1]],
                      [new_basis[1][0], new_basis[1][1]]])
    old_b = np.array([[old_basis[0][0], old_basis[0][1]],
                      [old_basis[1][0], old_basis[1][1]]])
    change_base = np.linalg.inv(new_b) @ old_b
    new_atoms = []
    for atom in atoms:
        new_atom = deepcopy(atom)
        nd = change_base @ np.array(new_atom[1][:2])
        new_atom[1][0] = nd[0]
        new_atom[1][1] = nd[1]
        new_atoms.append(new_atom)
    return new_atoms


def supar_atoms(atoms, eps=0.01):
    """Select the atoms whose fractional (x, y) lie in one supercell.

    Keeps atoms with ``-eps <= x < 1 - eps`` and likewise for y. The window
    is exactly one period long, so an atom and its image shifted by a
    supercell vector are never both kept. ``eps`` only shifts the window so
    atoms sitting on the 0 boundary (up to round-off) fall inside it.

    :return: deep copies of the selected atoms, in input order.
    """
    lo, hi = -eps, 1 - eps
    return [
        deepcopy(atom)
        for atom in atoms
        if lo <= atom[1][0] < hi and lo <= atom[1][1] < hi
    ]


def rotate_guest(rez, Blat3D, atoms):
    """Apply the in-plane transform ``SR(s1, s2, theta)`` to the guest layer.

    :param rez: ``(s1, s2, theta)`` as taken from ``LatMatch``'s ``result``,
        or None to apply the identity. None is what ``PipelineLatMatch``
        stores when the caller supplies ``sc_vec``.
    :param Blat3D: guest lattice matrix (3x3).
    :param atoms: ``[symbol, position]`` pairs.
    :return: ``(oBlat3D, new_atoms)``. Positions in ``new_atoms`` are
        ``Tr @ position`` with ``Tr`` the 3x3 embedding of the transform.
    """
    Tr = np.eye(3)
    if rez is not None:
        s1, s2, theta = rez
        Tr[:2, :2] = SR(s1, s2, theta)
    # NOTE(review): this transforms the rows of Blat3D, while atoms_to_greed
    # and atom_change_basis2D read lattice vectors as columns — confirm the
    # intended convention. compute_super_atoms does not use this value.
    oBlat3D = (Tr @ np.asarray(Blat3D).T).T
    new_atoms = [[atom[0], Tr @ atom[1]] for atom in atoms]
    return oBlat3D, new_atoms