Commit f42507a

upd
1 parent 45c7486 commit f42507a

File tree

4 files changed: +433 -1 lines changed

__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,2 +1,4 @@
 from .creation import cm_to_eids
-from .tensor_ops import delete_indices
+from .tensor_ops import delete_indices
+from .parsing import parse_graph_data
+from .parsing import parse_graph_data_torch
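
With these re-exports, downstream code can import the parsers from the package top level; a minimal sketch, with `pdb_graphs` as a stand-in for the package name, which this diff does not show:

# `pdb_graphs` is a placeholder; the diff does not reveal the package directory name.
from pdb_graphs import parse_graph_data, parse_graph_data_torch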

examples.ipynb

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@

In [193]:
import os
import math

import numba
import numpy as np
import atomium
import torch as th

from scipy.spatial import distance_matrix
# `import Bio` alone does not pull in the SeqUtils submodule, so import it directly.
from Bio.SeqUtils import IUPACData

# atomium reports residue names in upper case, hence the upper-cased keys.
protein_letters_3to1 = IUPACData.protein_letters_3to1_extended
protein_letters_3to1 = {k.upper(): v for k, v in protein_letters_3to1.items()}

In [239]:
@numba.njit(parallel=True)
def numba_jit_scalar_distance_parallel(xyz):
    # Pairwise Euclidean distances: compute the upper triangle once and
    # mirror it. numba only parallelises the outermost prange, so the
    # inner loop is a plain range; `range(i, rows)` visits every (i, j)
    # pair exactly once, leaving no entry of the matrix unwritten.
    rows = xyz.shape[0]
    output = np.empty((rows, rows), dtype=np.float32)
    for i in numba.prange(rows):
        for j in range(i, rows):
            tmp = 0.0
            tmp += (xyz[i, 0] - xyz[j, 0])**2
            tmp += (xyz[i, 1] - xyz[j, 1])**2
            tmp += (xyz[i, 2] - xyz[j, 2])**2
            tmp = math.sqrt(tmp)
            output[i, j] = tmp
            output[j, i] = tmp
    return output

In [107]:
def get_atom_xyz(atoms, atom_name):
    # Return the coordinates of the named atom, or NaNs if it is absent
    # (e.g. glycine has no CB).
    for a in atoms:
        if a.name == atom_name:
            return a.location
    return (np.nan, np.nan, np.nan)

def get_ss_label(residue):
    '''
    Secondary-structure label from atomium: H (helix), E (strand), C (coil).
    '''
    if residue.helix:
        return 'H'
    elif residue.strand:
        return 'E'
    else:
        return 'C'

In [237]:
def parse_graph_data_numba(path_pdb, chain):

    if not os.path.isfile(path_pdb):
        raise FileNotFoundError('no such file', path_pdb)
    file = atomium.open(path_pdb)
    chain = file.model.chain(chain)
    preparation_dict = dict()
    for i, r in enumerate(chain.residues()):
        r_atoms = r.atoms()
        preparation_dict[i] = {'aa': protein_letters_3to1[r.name],
                               'charge': r.charge,
                               'CA': get_atom_xyz(r_atoms, 'CA'),
                               'CB': get_atom_xyz(r_atoms, 'CB'),
                               'ss_label': get_ss_label(r)
                               }

    ca_xyz = np.asarray(list(map(lambda v: v['CA'], preparation_dict.values())), dtype=np.float32)
    sequence = list(map(lambda v: v['aa'], preparation_dict.values()))
    ca_ca_matrix = numba_jit_scalar_distance_parallel(ca_xyz)
    return ca_ca_matrix, sequence

In [176]:
def parse_graph_data_torch(path_pdb, chain):

    if not os.path.isfile(path_pdb):
        raise FileNotFoundError('no such file', path_pdb)
    file = atomium.open(path_pdb)
    chain = file.model.chain(chain)
    preparation_dict = dict()
    for i, r in enumerate(chain.residues()):
        r_atoms = r.atoms()
        preparation_dict[i] = {'aa': protein_letters_3to1[r.name],
                               'charge': r.charge,
                               'CA': get_atom_xyz(r_atoms, 'CA'),
                               'CB': get_atom_xyz(r_atoms, 'CB'),
                               'ss_label': get_ss_label(r)
                               }

    ca_xyz = th.FloatTensor(list(map(lambda v: v['CA'], preparation_dict.values())))
    sequence = list(map(lambda v: v['aa'], preparation_dict.values()))

    ca_ca_matrix = th.cdist(ca_xyz, ca_xyz)
    return ca_ca_matrix, sequence

In [184]:
def parse_graph_data(path_pdb, chain):

    if not os.path.isfile(path_pdb):
        raise FileNotFoundError('no such file', path_pdb)
    file = atomium.open(path_pdb)
    chain = file.model.chain(chain)
    preparation_dict = dict()
    for i, r in enumerate(chain.residues()):
        r_atoms = r.atoms()
        preparation_dict[i] = {'aa': protein_letters_3to1[r.name],
                               'charge': r.charge,
                               'CA': get_atom_xyz(r_atoms, 'CA'),
                               'CB': get_atom_xyz(r_atoms, 'CB'),
                               'ss_label': get_ss_label(r)
                               }

    ca_xyz = np.asarray(list(map(lambda v: v['CA'], preparation_dict.values())), dtype=np.float32)
    sequence = list(map(lambda v: v['aa'], preparation_dict.values()))

    ca_ca_matrix = distance_matrix(ca_xyz, ca_xyz)
    return ca_ca_matrix, sequence

In [169]:
path = '/home/db/localpdb/mirror/ea/pdb6eac.ent.gz'
chain = 'A'

In [185]:
%timeit parse_graph_data(path, chain)
4.11 s ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [186]:
%timeit parse_graph_data_torch(path, chain)
2.99 s ± 7.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [240]:
%timeit parse_graph_data_numba(path, chain)
3.05 s ± 9.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
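
The timings suggest PDB parsing, not the distance computation, dominates the runtime, which is why the three variants land within a few percent of one another. They should also agree numerically; a minimal sanity check, not a cell from the original notebook, assuming the chain has no missing CA atoms (NaN coordinates would fail allclose):

# Sketch (not in the original notebook): the three backends should agree
# on the CA-CA matrix up to float32 round-off.
d_scipy, seq = parse_graph_data(path, chain)
d_torch, _ = parse_graph_data_torch(path, chain)
d_numba, _ = parse_graph_data_numba(path, chain)
assert np.allclose(d_scipy, d_torch.numpy(), atol=1e-3)
assert np.allclose(d_scipy, d_numba, atol=1e-3)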

In [241]:
a, b = parse_graph_data_torch(path, chain)

In [244]:
a.element_size() * a.nelement() / 1024   # tensor size in KiB: 476*476 float32 entries
Out[244]: 885.0625

In [ ]:

(Notebook metadata: Python 3 kernel, Python 3.7.9, nbformat 4.5)
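
The CA-CA distance matrix is the raw material for graph construction; below is a sketch of turning it into edge indices with a distance cutoff. The 8 Å threshold is an illustrative assumption, and `cm_to_eids` from the package presumably handles this step, though its signature is not shown in this commit.

# Sketch: threshold the CA-CA matrix into a contact map and extract edges.
# The 8 A cutoff is an assumption, not taken from the diff.
d, seq = parse_graph_data_torch(path, chain)
contacts = (d > 0) & (d < 8.0)   # exclude the zero diagonal (self-edges)
src, dst = th.nonzero(contacts, as_tuple=True)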

files_io.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
import json
import os
import pickle
import gzip

'''
Storage for small, frequently used IO helpers.
'''


def get_json(file):
    '''
    Read a JSON file into a Python object.
    '''
    assert isinstance(file, str)

    with open(file, 'r') as f:
        data = json.load(f)
    return data

def save_json(file, data):
    '''
    Save a dict to a JSON file.
    '''
    assert isinstance(file, str)
    assert isinstance(data, dict)

    with open(file, 'w') as f:
        json.dump(data, f)


def load_gpickle(file):
    '''
    Return the contents of a pickle file; the file is treated as
    gzip-compressed when its name ends with .gz.
    '''

    assert isinstance(file, str)
    assert os.path.isfile(file), f'no such file {file}'

    if file.endswith('.gz'):
        with gzip.open(file, 'rb') as f:
            data = pickle.load(f)
    else:
        with open(file, 'rb') as f:
            data = pickle.load(f)
    return data

def save_gpickle(obj, file):
    '''
    Pickle `obj` to `file`; the output is gzip-compressed when the
    file name ends with .gz.
    '''

    assert isinstance(file, str)
    assert os.path.isdir(os.path.dirname(file)), f'no such directory: {file}'

    if file.endswith('.gz'):
        with gzip.open(file, 'wb') as f:
            pickle.dump(obj, f)
    else:
        with open(file, 'wb') as f:
            pickle.dump(obj, f)
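
A round-trip sketch for the pickle helpers (the path is illustrative):

# Illustrative usage of the helpers above; the path is made up.
record = {'pdb': '6eac', 'chain': 'A'}
save_gpickle(record, '/tmp/record.p.gz')
assert load_gpickle('/tmp/record.p.gz') == record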
