modified: splanegrid.py
[GalaxyCodeBases.git] / python / salus / pygrid / splanegrid.py
blob0e70056dd217732d20853dc2ca90cd40694fe167
1 #!/usr/bin/env python3
2 # pip3 install python-graphblas speedict dinopy fast-matrix-market tqdm
4 #from numba import jit
5 import concurrent.futures
6 import sys
7 import os
8 import io
9 import functools
10 import re
11 import argparse
12 import pathlib
13 import gzip
14 import graphblas as gb
15 import dinopy
16 import speedict
17 import fast_matrix_market
18 import tqdm
19 #from collections import defaultdict
20 import time
22 import pprint
# Shared pretty-printer, kept around for ad-hoc debugging dumps.
pp = pprint.PrettyPrinter(indent=4)

# Process-wide state that main() fills in:
#   spatialDB - key/value store mapping encoded spatial barcodes to their coordinates
#   mgBoolMtx - barcode x grid-cell boolean membership matrix
spatialDB = mgBoolMtx = None
def eprint(*args, **kwargs) -> None:
    """Print *args* to standard error and flush immediately.

    Accepts print()'s keyword arguments except ``file`` and ``flush``,
    which are fixed here (passing either raises TypeError).
    """
    stream = sys.stderr
    print(*args, file=stream, flush=True, **kwargs)
def init_argparse() -> argparse.ArgumentParser:
    """Build the command-line parser.

    Options:
      -b/--bin          grid cell size in pixels (required)
      -i/--spatial      spatial barcode coordinate file
      -r / -f           scSeq input: a directory, or three explicit files
                        (mutually exclusive, exactly one is required)
      -s/--split-zones  zone split mode: 0, 4 or 5
      -o/--output-path  output directory
      -z/--gzip         gzip output toggle (--no-gzip to turn off)
      -n/--dryrun       stop after validating inputs
      -v/--version      print the version and exit
    """
    ap = argparse.ArgumentParser(
        description='merge scSeq data with spBarcode coordinates that gridded by given binning size ',
        epilog='Contact: <huxs@salus-bio.com>',
    )
    add = ap.add_argument
    add('-b', '--bin', type=int, required=True, help='grid binning pixels')
    add('-i', '--spatial', type=pathlib.Path, default='spatial.txt',
        metavar='txt', help='For spatial.txt[.gz]')

    scInput = ap.add_mutually_exclusive_group(required=True)
    scInput.add_argument('-r', '--scseq-path', type=pathlib.Path, dest='scSeqPath')
    scInput.add_argument(
        '-f', '--scseq-files', type=pathlib.Path, nargs=3, action='extend',
        metavar='<f>', dest='scSeqFiles',
        help='matrix.mtx[.gz] barcodes.tsv[.gz] features.tsv[.gz]',
    )

    add('-s', '--split-zones', dest='zones', type=int, choices=[0, 4, 5],
        default=0, help='split to 4 or 5 zones, default 0=off')
    add('-o', '--output-path', type=pathlib.Path, default='./gridded/', dest='outpath')
    add('-z', '--gzip', action=argparse.BooleanOptionalAction, default=True,
        help='Output gzipped files, default on', dest='gzip')
    add('-n', '--dryrun', '--dry-run', action='store_true', dest='dryrun')
    add('-v', '--version', action='version', version=f"{ap.prog} version 1.0.0")
    return ap
def checkFile(PathList, suffixStrs):
    """Return the first existing path built by appending a suffix to a base path.

    Candidates are tried base-path-major: every suffix in *suffixStrs* is
    tried on the first entry of *PathList* before moving to the next entry.
    Returns a pathlib.Path, or None when no candidate exists.
    """
    candidates = (
        pathlib.Path(base.as_posix() + ext)
        for base in PathList
        for ext in suffixStrs
    )
    return next((cand for cand in candidates if cand.exists()), None)
def db_options():
    """Return speedict.Options for the on-disk spatial-barcode store.

    RocksDB's bulk presets optimize_level_style_compaction() and
    increase_parallelism() overwrite several individual knobs: write buffer
    size, L0 compaction trigger, level-base and target file sizes, and
    background jobs. They are therefore applied FIRST, and the hand-tuned
    values after them. When they were applied last, the 256 MB / 1 GB /
    4-file settings below were silently replaced.
    """
    opt = speedict.Options(raw_mode=False)
    # create table
    opt.create_if_missing(True)
    # Presets first (by Galaxy): level-style compaction sized for a 512 MB
    # memtable budget, and background parallelism scaled to the CPU count.
    opt.set_compaction_style(speedict.DBCompactionStyle.level())
    opt.optimize_level_style_compaction(0x20000000)  # 512 MB
    opt.increase_parallelism(os.cpu_count())
    # --- explicit overrides of the presets above ---
    # config to more jobs
    opt.set_max_background_jobs(os.cpu_count())
    # configure mem-table to a large value (256 MB)
    opt.set_write_buffer_size(0x10000000)
    opt.set_level_zero_file_num_compaction_trigger(4)
    # configure l0 and l1 size, let them have the same size (1 GB)
    opt.set_max_bytes_for_level_base(0x40000000)
    # 256 MB file size
    opt.set_target_file_size_base(0x10000000)
    # use a smaller compaction multiplier
    opt.set_max_bytes_for_level_multiplier(4.0)
    # use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
    opt.set_prefix_extractor(speedict.SliceTransform.create_max_len_prefix(8))
    # set to plain-table
    opt.set_plain_table_factory(speedict.PlainTableFactoryOptions())
    # NOTE(review): in RocksDB, optimize_level_style_compaction() also fills a
    # per-level compression list that takes precedence over this setting;
    # confirm which codec is actually intended here.
    opt.set_compression_type(speedict.DBCompressionType.snappy())
    return opt
def fileOpener(filename):
    """Open *filename* for reading as UTF-8 text, transparently gunzipping.

    Compression is detected from the gzip magic bytes (1f 8b), not from the
    file name. The returned io.TextIOWrapper owns every underlying handle,
    so closing it (e.g. through ``with``) also releases the OS file.
    GzipFile does not close a caller-supplied ``fileobj``, so the gzip
    stream is opened by path instead of wrapping an already-open file.
    """
    with open(filename, 'rb') as probe:
        isGzip = probe.read(2) == b'\x1f\x8b'
    if isGzip:
        # gzip.open() given a path opens, owns and later closes the raw file.
        binFh = gzip.open(filename, mode='rb')
    else:
        binFh = open(filename, 'rb')
    return io.TextIOWrapper(binFh, encoding='utf-8', line_buffering=True)
def cmpGridID(a, b, binPixels=None, db=None, gridRangeY=None):
    """Three-way compare two spatial barcode keys by the grid cell they fall in.

    Usable as an old-style comparator with functools.cmp_to_key(). Each key
    is looked up in *db*. Its stored [X, Y, ...] pixel position is binned by
    *binPixels* and linearised as gridID = Xgrid * gridRangeY + Ygrid, the
    same layout updateBarcodesID() uses.

    Parameters default to module state: *db* to spatialDB, *gridRangeY* to
    gridRangeCnt[1], and *binPixels* to ``args.bin``. The last one only works
    if a module-level ``args`` exists; main() keeps ``args`` local, so pass
    *binPixels* explicitly.

    Returns -1, 0 or 1 as a's grid ID is less than, equal to, or greater than b's.
    """
    if db is None:
        db = spatialDB
    if gridRangeY is None:
        gridRangeY = gridRangeCnt[1]
    if binPixels is None:
        binPixels = args.bin  # NOTE: requires a module-level ``args``

    def _gridID(key):
        # Bin the stored pixel coordinates and linearise them X-major.
        value = db[key]
        return (value[0] // binPixels) * gridRangeY + (value[1] // binPixels)

    agridID = _gridID(a)
    bgridID = _gridID(b)
    # Python 3 has no cmp(); this is the standard replacement.
    return (agridID > bgridID) - (agridID < bgridID)
# Filled in by readSpatial(): the longest barcode seen, and the
# [minX, maxX, minY, maxY] coordinate extent of the spatial file.
maxBarcodeLen = 0
SpatialBarcodeRange_xXyY = [0] * 4
# (gridRangeX, gridRangeY, gridCnt); set by main() once the extent is known.
gridRangeCnt = tuple()
def readSpatial(infile, db):
    """Load spatial barcode coordinates into *db* and record their extent.

    Every non-blank line of *infile* (plain or gzip text, opened through
    fileOpener) is whitespace-split as ``<barcode> <X> <Y> [ignored ...]``.
    X and Y are parsed as float and truncated to int. The barcode is 2-bit
    encoded with dinopy.conversion.encode_twobit and stored as
    ``db[intSeq] = [X, Y, None, None]``.

    Updates module globals:
      * maxBarcodeLen - raised to the longest barcode seen;
      * SpatialBarcodeRange_xXyY - overwritten in place with
        [minX, maxX, minY, maxY] over the records read (untouched if none).

    Returns the number of records stored (0 for an empty file).
    """
    global maxBarcodeLen
    global SpatialBarcodeRange_xXyY
    recordCnt = 0
    # The extent is seeded from the first record. Using 0 as an "unset"
    # marker would let a genuine 0 coordinate be overwritten later
    # (e.g. X values 0 then 5 would give minX=5).
    minX = maxX = minY = maxY = None
    pbar = tqdm.tqdm(desc='Spatial', total=BarcodesCnt, ncols=70, mininterval=0.5, maxinterval=10, unit='', unit_scale=True, dynamic_ncols=True)
    try:
        with fileOpener(infile) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                [ seq, Xpos, Ypos, *_ ] = fields
                maxBarcodeLen = max(maxBarcodeLen, len(seq))
                theXpos = int(float(Xpos))
                theYpos = int(float(Ypos))
                if recordCnt == 0:
                    minX = maxX = theXpos
                    minY = maxY = theYpos
                else:
                    minX = min(minX, theXpos)
                    maxX = max(maxX, theXpos)
                    minY = min(minY, theYpos)
                    maxY = max(maxY, theYpos)
                intSeq = dinopy.conversion.encode_twobit(seq)
                db[intSeq] = [ theXpos, theYpos, None, None ]
                recordCnt += 1
                if not recordCnt % 1000:
                    pbar.update(recordCnt - pbar.n)
        pbar.update(recordCnt - pbar.n)
    finally:
        pbar.close()
    if recordCnt:
        # In-place update so every holder of this list sees the new extent.
        SpatialBarcodeRange_xXyY[:] = [minX, maxX, minY, maxY]
    return recordCnt
def updateBarcodesID(infile, db, binPixels):
    """Mark, in mgBoolMtx, the grid cell of every scSeq barcode in *infile*.

    Line k (0-based) of *infile* is taken as barcode/column k of the scSeq
    matrix. Its first token is 2-bit encoded and looked up in *db*. The line
    is split on whitespace and common punctuation, so a suffix such as "-1"
    is dropped. When the barcode is found, its stored pixel position
    [X, Y, ...] is binned by *binPixels* into
    gridID = Xgrid * gridRangeY + Ygrid, and ``mgBoolMtx[k, gridID]`` is set
    to True.

    Reads the module globals gridRangeCnt (gridRangeX, gridRangeY, gridCnt)
    and mgBoolMtx, which main() prepares beforehand.

    Returns the number of barcodes that were not found in *db*.
    """
    global gridRangeCnt
    global mgBoolMtx
    (_, gridRangeY, _) = gridRangeCnt
    missingCnt = 0
    RePattern = re.compile("[-_|,./\\:;`\'!~!@#$%^&*()+= \t\r\n]+")
    pbar = tqdm.tqdm(desc='Barcodes', total=BarcodesCnt, ncols=70, mininterval=0.5, maxinterval=10, unit='', unit_scale=True, dynamic_ncols=True)
    try:
        with fileOpener(infile) as f:
            for index, line in enumerate(f):
                [ seq, *_ ] = RePattern.split(line)
                intSeq = dinopy.conversion.encode_twobit(seq)
                # get() instead of key_may_exist(): RocksDB documents the latter
                # as a hint that may return true for absent keys, after which
                # db[intSeq] raised KeyError. get() also works with a plain dict.
                thisValue = db.get(intSeq)
                if thisValue is None:
                    # `++missingCnt` is two unary pluses in Python (a no-op),
                    # so misses were never counted before.
                    missingCnt += 1
                else:
                    Xgrid = thisValue[0] // binPixels
                    Ygrid = thisValue[1] // binPixels
                    gridID = Xgrid * gridRangeY + Ygrid
                    mgBoolMtx[index, gridID] << True
                pbar.update(index - pbar.n + 1)
    finally:
        pbar.close()
    return missingCnt
# Matrix Market header values published by checkmtx():
# row count (genes), column count (barcodes) and stored non-zeros.
GenesCnt = BarcodesCnt = mtxNNZ = 0
def checkmtx(mtxfile) -> None:
    """Read only the Matrix Market header of *mtxfile* and publish its size.

    Sets the module globals GenesCnt (rows), BarcodesCnt (columns) and
    mtxNNZ (stored non-zeros) from fast_matrix_market.read_header().
    """
    global GenesCnt, BarcodesCnt, mtxNNZ
    header = fast_matrix_market.read_header(mtxfile)
    GenesCnt, BarcodesCnt, mtxNNZ = header.nrows, header.ncols, header.nnz
def write2gzip(outfile):
    """Open *outfile* for binary gzip writing at compression level 1 (fastest)."""
    return gzip.open(outfile, 'wb', 1)
192 def mkcopy(fromFile, toFile):
193 if toFile.exists():
194 if toFile.samefile(fromFile):
195 return 0
196 else:
197 return 1
198 else:
199 try:
200 toFile.hardlink_to(fromFile)
201 return 0
202 except OSError as error :
203 eprint(error)
204 try:
205 toFile.symlink_to(fromFile)
206 return 0
207 except OSError as error :
208 eprint(error)
209 return 1
def mkGridSpatial(spFile, scBarcodeFile, gridRangeCnt):
    """Write one synthetic barcode per grid cell plus that cell's grid coordinates.

    *gridRangeCnt* is (gridRangeX, gridRangeY, gridCnt). For every cell ID i
    in range(gridCnt):
      * ``Barcode<i>`` (zero-padded to the digit width of gridCnt) is written
        to *scBarcodeFile*;
      * ``<barcode> <Xgrid> <Ygrid>`` is written to *spFile*.
    Both outputs are gzip text at compression level 1.

    Cell IDs use the layout gridID = Xgrid * gridRangeY + Ygrid, so
    (Xgrid, Ygrid) = divmod(gridID, gridRangeY). The previous decoding
    computed the same two numbers but wrote them in swapped X/Y order.
    """
    gridRangeY = gridRangeCnt[1]
    numLen = len(str(gridRangeCnt[2]))
    # `with` guarantees both gzip trailers are written even on error.
    with gzip.open(spFile, mode='wt', compresslevel=1) as spFh:
        with gzip.open(scBarcodeFile, mode='wt', compresslevel=1) as scFh:
            for i in range(gridRangeCnt[2]):
                barcodeStr = "Barcode{:0{}d}".format(i, numLen)
                Xgrid, Ygrid = divmod(i, gridRangeY)
                print(barcodeStr, file=scFh)
                print(barcodeStr, str(Xgrid), str(Ygrid), file=spFh)
def main() -> None:
    """Command-line entry point: grid scSeq barcodes onto the spatial plane.

    Steps (outputs go to ``args.outpath``):
      1. resolve the input files and report the matrix header
         (``--dryrun`` stops here);
      2. link the features file into the output directory;
      3. load the spatial barcode coordinates into an on-disk speedict store
         and size the grid from the maximum X/Y coordinate and ``--bin``;
      4. in a worker process, write one synthetic barcode per grid cell plus
         its grid coordinates;
      5. build the barcode x grid-cell boolean matrix from the scSeq barcodes
         file and save it as mgBoolMtx.mtx.gz;
      6. multiply the gene x barcode matrix by it and write the resulting
         gene x grid-cell matrix.

    NOTE(review): --split-zones is parsed and reported but not used yet.
    NOTE(review): with --no-gzip only the output *names* lose ".gz"; the
    matrix (write2gzip) and the mkGridSpatial outputs are still
    gzip-compressed, and the features file is linked as-is. Confirm intent.
    """
    global spatialDB, SpatialBarcodeRange_xXyY, gridRangeCnt, mgBoolMtx
    parser = init_argparse()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)
    args = parser.parse_args()
    eprint('[!]GridBin=[',args.bin,'], SplitZone:[',args.zones,']. OutPath:[',args.outpath,']',sep='')
    scFileNameTuple = ('matrix.mtx', 'barcodes.tsv', 'features.tsv', 'genes.tsv')
    spFileNameList = ['spatial.txt', *scFileNameTuple[0:3]]
    if args.scSeqPath is None:
        scSeqFiles = tuple(args.scSeqFiles)
    else:
        scSeqFiles = tuple(args.scSeqPath.joinpath(x) for x in scFileNameTuple)
    FileDotExts = ('', '.gz')
    spNameTuple = ('spatial', 'matrix', 'barcodes', 'features')
    spStandardNameDict = dict(zip(spNameTuple, ['.'.join((fn, 'gz')) if args.gzip else fn for fn in spFileNameList]))
    InFileDict = {}
    InFileDict['spatial'] = checkFile([args.spatial], FileDotExts)
    InFileDict['matrix'] = checkFile([scSeqFiles[0]], FileDotExts)
    InFileDict['barcodes'] = checkFile([scSeqFiles[1]], FileDotExts)
    # With a directory input this tries features.tsv first, then genes.tsv.
    InFileDict['features'] = checkFile(scSeqFiles[2:], FileDotExts)
    eprint('[!]Confirmed Input Files:[',', '.join([ str(x) if x else '<Missing>' for x in InFileDict.values() ]),'].',sep='')
    for fname in spNameTuple:
        if InFileDict[fname] is None:
            eprint('[x]The',fname,'file is missing !\n')
            sys.exit(1)
    OutFileDict = {fname: args.outpath.joinpath(spStandardNameDict[fname]) for fname in spNameTuple}
    OutFileDict['Rdict'] = args.outpath.joinpath('_rdict').as_posix()
    OutFileDict['mgBoolMtx'] = args.outpath.joinpath('mgBoolMtx.mtx.gz').as_posix()
    args.outpath.mkdir(parents=True, exist_ok=True)
    eprint('[!]Output Files:[',', '.join([ OutFileDict[x].as_posix() for x in spNameTuple]),'].',sep='')
    checkmtx(InFileDict['matrix'])
    eprint('[!]Matrix Size: Gene count(nrows)=',GenesCnt,', Barcode count(ncols)=',BarcodesCnt,', Values(nnz)=',mtxNNZ,'.',sep='')
    if args.dryrun:
        sys.exit(0)
    if mkcopy(InFileDict['features'], OutFileDict['features']):
        eprint('[x]Failed to link features file to [',OutFileDict['features'].as_posix(),'].',sep='')

    start = time.perf_counter()
    eprint('[!]Reading spatial file ...')
    spatialDB = speedict.Rdict(OutFileDict['Rdict'], db_options())
    lineCnt = readSpatial(InFileDict['spatial'], spatialDB)
    eprint('[!]Finished with [',lineCnt,'] records. X∈[',','.join(map(str,SpatialBarcodeRange_xXyY[0:2])),'], Y∈[',','.join(map(str,SpatialBarcodeRange_xXyY[2:4])),'].',sep='')
    SpatialGridRange_xXyY = [ (x // args.bin) for x in SpatialBarcodeRange_xXyY ]
    # The grid is anchored at pixel 0 (not at the minimum coordinate), so each
    # axis spans grid cells 0..max.
    (gridRangeX, gridRangeY) = (1+SpatialGridRange_xXyY[1], 1+SpatialGridRange_xXyY[3])
    gridRangeCnt = (gridRangeX, gridRangeY, gridRangeX * gridRangeY)
    eprint('[!]Gridded by Bin [',args.bin,'], GridSize=','×'.join(map(str,(gridRangeX,gridRangeY))),'=',str(gridRangeCnt[2]),'.',sep='' )
    mgBoolMtx = gb.Matrix(bool, BarcodesCnt, gridRangeCnt[2])
    end1p = time.perf_counter()
    eprint("\tElapsed {}s".format((end1p - start)))

    # A single background job, so one worker process is enough. Keep the
    # future so a failure in mkGridSpatial is reported instead of dropped.
    executor = concurrent.futures.ProcessPoolExecutor(max_workers=1)
    gridFuture = executor.submit(mkGridSpatial, OutFileDict['spatial'], OutFileDict['barcodes'], gridRangeCnt)

    eprint('[!]Reading barcodes file ...')
    missingCnt = updateBarcodesID(InFileDict['barcodes'], spatialDB, args.bin)
    eprint('[!]Finished with [',missingCnt,'] missing barcodes.',sep='')
    fh = write2gzip(OutFileDict['mgBoolMtx'])
    gb.io.mmwrite(target=fh, matrix=mgBoolMtx)
    fh.close()
    # The speedict store is closed but kept on disk (not destroyed) so that
    # interrupted runs can be resumed.
    spatialDB.close()
    end2p = time.perf_counter()
    eprint("\tElapsed {}s".format((end2p - end1p)))

    eprint('[!]Reading Matrix file ...')
    scMtx = gb.io.mmread(source=InFileDict['matrix'], engine='fmm')
    outGrid = scMtx.mxm(mgBoolMtx)  # lazy expression, evaluated by .new()
    end3p = time.perf_counter()
    eprint("\tElapsed {}s".format((end3p - end2p)))

    eprint('[!]Calculating Matrix file ...')
    outGridResult = outGrid.new()
    end4p = time.perf_counter()
    eprint("\tElapsed {}s".format((end4p - end3p)))

    eprint('[!]Writing Matrix file ...')
    fh = write2gzip(OutFileDict['matrix'])
    gb.io.mmwrite(target=fh, matrix=outGridResult)
    fh.close()
    end5p = time.perf_counter()
    eprint("\tElapsed {}s".format((end5p - end4p)))

    executor.shutdown(wait=True)
    # Re-raises any exception raised inside mkGridSpatial.
    gridFuture.result()
    eprint('[!]All done !')
    sys.exit(0)
if __name__ == "__main__":
    # Select the SuiteSparse GraphBLAS backend (non-blocking mode) before
    # main() creates any graphblas objects.
    gb.init("suitesparse", blocking=False)
    main()

# Example invocations:
#   time ./splanegrid.py -b20 -f matrix2.mtx.gz barcodes.tsv.gz features.tsv.gz -i spatial.txt.gz
#   ./splanegrid.py -b20 -i GSE166635_RAW/GSM5076750_HCC2.barcodes.spatial.txt -f GSE166635_RAW/GSM5076750_HCC2.matrix.mtx.gz GSE166635_RAW/GSM5076750_HCC2.barcodes.tsv.gz GSE166635_RAW/GSM5076750_HCC2.features.tsv.gz