fmindex_test.cu 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047
  1. /*
  2. * nvbio
  3. * Copyright (c) 2011-2014, NVIDIA CORPORATION. All rights reserved.
  4. *
  5. * Redistribution and use in source and binary forms, with or without
  6. * modification, are permitted provided that the following conditions are met:
  7. * * Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * * Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in the
  11. * documentation and/or other materials provided with the distribution.
  12. * * Neither the name of the NVIDIA CORPORATION nor the
  13. * names of its contributors may be used to endorse or promote products
  14. * derived from this software without specific prior written permission.
  15. *
  16. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  17. * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  18. * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  19. * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  20. * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  21. * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  22. * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  23. * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  24. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  25. * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  26. */
  27. // fmindex_test.cpp
  28. //
  29. #define MOD_NAMESPACE
  30. #define MOD_NAMESPACE_NAME fmitest
  31. #define MOD_NAMESPACE_BEGIN namespace fmitest {
  32. #define MOD_NAMESPACE_END }
  33. //#define NVBIO_CUDA_DEBUG
  34. //#define NVBIO_CUDA_ASSERTS
  35. #include <nvbio/basic/omp.h>
  36. #include <stdio.h>
  37. #include <stdlib.h>
  38. #include <vector>
  39. #include <algorithm>
  40. #include <nvbio/basic/timer.h>
  41. #include <nvbio/basic/console.h>
  42. #include <nvbio/basic/dna.h>
  43. #include <nvbio/basic/cached_iterator.h>
  44. #include <nvbio/basic/packedstream.h>
  45. #include <nvbio/basic/deinterleaved_iterator.h>
  46. #include <nvbio/fmindex/bwt.h>
  47. #include <nvbio/fmindex/ssa.h>
  48. #include <nvbio/fmindex/fmindex.h>
  49. #include <nvbio/fmindex/backtrack.h>
  50. #include <nvbio/io/sequence/sequence.h>
  51. #include <nvbio/io/fmindex/fmindex.h>
  52. using namespace nvbio;
  // Empty placeholder type: used in this file as the SSA template argument of
  // fm_index when the index is built before any sampled suffix array exists.
  struct ssa_nop
  {
  };
  54. namespace { // anonymous namespace
  // Kernel: one thread per query.
  //
  // Thread t wraps the packed genome words in a cached 2-bit stream, calls
  // match() on the QUERY_LEN symbols starting at offset input[t], and writes
  // locate( fmi, range.x ) for the returned range into output[t].
  // Threads with t >= n_queries do nothing. OCC_INTERVAL and genome_length
  // are accepted but not used by the body.
  template <uint32 OCC_INTERVAL, typename FMIndexType, typename word_type>
  __global__ void locate_kernel(
      const uint32        n_queries,
      const uint32        QUERY_LEN,
      const uint32        genome_length,
      const word_type*    genome_stream,
      const FMIndexType   fmi,
      const uint32*       input,
      uint32*             output)
  {
      typedef typename FMIndexType::index_type                          index_type;
      typedef typename FMIndexType::range_type                          range_type;
      typedef const_cached_iterator<const word_type*>                   cached_stream_type;
      typedef PackedStream<cached_stream_type,uint8,2,true,index_type>  genome_stream_type;

      const uint32 query_id = blockIdx.x*blockDim.x + threadIdx.x;

      // the grid is rounded up to a whole number of blocks: mask the tail
      if (query_id < n_queries)
      {
          const cached_stream_type cached_words( genome_stream );
          const genome_stream_type packed_genome( cached_words );

          const range_type sa_range = match(
              fmi,
              packed_genome + input[ query_id ],
              QUERY_LEN );

          output[ query_id ] = uint32( locate( fmi, sa_range.x ) );
      }
  }
  // Compare the GPU-built SSA against the CPU-built one.
  //
  // The device samples (ssa_device.m_ssa) are copied into a host vector and
  // checked entry by entry against ssa.m_ssa; the first difference is
  // reported on stderr and terminates the process with exit code 1.
  template <typename SSA_device, typename SSA_host>
  void test_ssa(
      const SSA_device&   ssa_device,
      const SSA_host&     ssa)
  {
      typedef typename SSA_device::value_type value_type;

      // bring the device samples over to the host
      thrust::host_vector<value_type> device_samples = ssa_device.m_ssa;

      uint32 slot = 0;
      while (slot < device_samples.size())
      {
          const bool same = (device_samples[slot] == ssa.m_ssa[slot]);
          if (!same)
          {
              fprintf(stderr, " \nerror : expected SSA[%u] = %u, got: %u\n", slot, (uint32)ssa.m_ssa[slot], (uint32)device_samples[slot]);
              exit(1);
          }
          ++slot;
      }
  }
  // Host-side buffers shared by the synthetic tests.
  // (field notes describe how synthetic_test in this file fills them)
  template <typename index_type>
  struct HostData
  {
      uint32                          primary;        // value returned by gen_bwt_from_sa

      thrust::host_vector<index_type> text;           // packed 2-bit text words
      thrust::host_vector<index_type> bwt;            // packed 2-bit BWT words
      thrust::host_vector<index_type> occ;            // occurrence table
      thrust::host_vector<index_type> bwt_occ;        // interleaved bwt/occ words (only built when WORDS == OCC_WORDS)
      thrust::host_vector<index_type> L2;             // 5-entry cumulative symbol counts

      thrust::host_vector<uint32>     count_table;    // 256-entry table filled by gen_bwt_count_table
      thrust::host_vector<uint32>     input;          // query start offsets into the text
      thrust::host_vector<uint32>     output;         // per-query locate() results computed on the host
  };
  // Device-side mirror of HostData: same fields, stored in thrust device vectors.
  template <typename index_type>
  struct DeviceData
  {
      uint32                            primary;

      thrust::device_vector<index_type> text;
      thrust::device_vector<index_type> bwt;
      thrust::device_vector<index_type> occ;
      thrust::device_vector<index_type> bwt_occ;
      thrust::device_vector<index_type> L2;

      thrust::device_vector<uint32>     count_table;
      thrust::device_vector<uint32>     input;
      thrust::device_vector<uint32>     output;

      // Copy-construct every buffer from its host counterpart.
      //
      // \param data    host buffers to upload
      DeviceData(const HostData<index_type>& data)
          : primary    ( data.primary ),
            text       ( data.text ),
            bwt        ( data.bwt ),
            occ        ( data.occ ),
            bwt_occ    ( data.bwt_occ ),
            L2         ( data.L2 ),
            count_table( data.count_table ),
            input      ( data.input ),
            output     ( data.output )
      {}
  };
  // Run the device half of a synthetic test for one bwt/occ layout.
  //
  // Steps:
  //   1. build a device rank dictionary / FM-index over the given iterators;
  //   2. build the SSA on the GPU and compare it with the CPU one (test_ssa);
  //   3. launch locate_kernel over all queries and time it with CUDA events;
  //   4. compare the device results with host_data.output, exiting with
  //      code 1 on the first mismatch.
  //
  // \param REQS         number of queries
  // \param LEN          text length, in symbols
  // \param PLEN         pattern length, in symbols
  // \param host_data    host buffers, including the expected outputs
  // \param ssa          SSA built on the cpu, used as reference
  // \param device_data  device buffers (the output vector is overwritten)
  // \param occ_it       iterator over the occurrence table
  // \param bwt_it       iterator over the packed BWT words
  template <uint32 OCC_INT, uint32 SA_INT, typename BwtIterator, typename OccIterator, typename SSA, typename index_type>
  void do_synthetic_test_device(
      const uint32                    REQS,
      const uint32                    LEN,
      const uint32                    PLEN,
      const HostData<index_type>&     host_data,
      const SSA&                      ssa,
      DeviceData<index_type>&         device_data,
      const OccIterator               occ_it,
      const BwtIterator               bwt_it)
  {
      typedef cuda::ldg_pointer<uint32> count_table_type;
      const count_table_type count_table( thrust::raw_pointer_cast( &device_data.count_table.front() ) );

      typedef PackedStream<BwtIterator,uint8,2u,true,index_type> bwt_type;
      typedef rank_dictionary< 2u, OCC_INT, bwt_type, OccIterator, count_table_type > rank_dict_type;
      rank_dict_type rank_dict(
          bwt_type( bwt_it ),
          occ_it,
          count_table );

      typedef SSA_index_multiple_context<SA_INT,const index_type*> ssa_type;
      typedef fm_index< rank_dict_type, ssa_type > fm_index_type;

      // temporary FM-index with an empty SSA context, only used to build the SSA
      fm_index_type temp_fmi(
          LEN,
          device_data.primary,
          thrust::raw_pointer_cast( &device_data.L2.front() ),
          rank_dict,
          ssa_type() );

      //SSA_value_multiple_device ssa_device( ssa );
      //SSA_index_multiple_device<SA_INT> ssa_device( ssa );
      fprintf(stderr, " SSA gpu... started\n" );
      Timer timer;
      timer.start();
      SSA_index_multiple_device<SA_INT,index_type> ssa_device( temp_fmi );
      timer.stop();
      fprintf(stderr, " SSA gpu... done: %.3fs\n", timer.seconds() );

      // test the gpu SSA against the cpu one
      test_ssa( ssa_device, ssa );

      fprintf(stderr, " gpu alignment... started\n");

      // the actual FM-index, now using the device SSA
      fm_index_type fmi(
          LEN,
          device_data.primary,
          thrust::raw_pointer_cast( &device_data.L2.front() ),
          rank_dict,
          ssa_device.get_context() );

      cudaEvent_t start, stop;
      cudaEventCreate( &start );
      cudaEventCreate( &stop );
      cudaEventRecord( start, 0 );

      const uint32 BLOCK_SIZE = 256;
      const uint32 n_blocks   = (REQS + BLOCK_SIZE-1) / BLOCK_SIZE;

      locate_kernel<OCC_INT> <<<n_blocks,BLOCK_SIZE>>>(
          REQS,
          PLEN,
          LEN,
          thrust::raw_pointer_cast( &device_data.text.front() ),
          fmi,
          thrust::raw_pointer_cast( &device_data.input.front() ),
          thrust::raw_pointer_cast( &device_data.output.front() ) );

      // cudaThreadSynchronize() is deprecated: use the device-wide equivalent
      cudaDeviceSynchronize();

      // report launch/execution failures, same call used by backtrack_test
      nvbio::cuda::check_error("locate_kernel");

      float time;
      cudaEventRecord( stop, 0 );
      cudaEventSynchronize( stop );
      cudaEventElapsedTime( &time, start, stop );

      // release the events: this function runs several times per test
      cudaEventDestroy( start );
      cudaEventDestroy( stop );

      fprintf(stderr, " gpu alignment... done: %.1fms, A/s: %.2f M\n", time, REQS/(time*1000.0f) );

      // fetch the results and compare them with the host reference
      thrust::host_vector<uint32> output_h( device_data.output );
      for (uint32 i = 0; i < REQS; ++i)
      {
          if (host_data.output[i] != output_h[i])
          {
              fprintf(stderr, "\nerror : mismatch at %u: expected %u, got %u\n", i, host_data.output[i], output_h[i] );
              exit(1);
          }
      }
  }
  // Device-side synthetic test, 32-bit index overload.
  //
  // Uploads the host buffers and runs do_synthetic_test_device twice: once
  // with separate bwt/occ tables and, when WORDS == OCC_WORDS, once more with
  // the interleaved bwt/occ table. Both layouts are read as uint4 vectors.
  // Any exception is reported on stderr and ends the process with exit code 1.
  //
  // \param REQS         number of queries
  // \param LEN          text length, in symbols
  // \param PLEN         pattern length, in symbols
  // \param WORDS        number of packed BWT words
  // \param OCC_WORDS    number of occurrence table words
  // \param host_data    host buffers, including the expected outputs
  // \param ssa          SSA built on the cpu, used as reference
  template <uint32 OCC_INT, uint32 SA_INT, typename SSA>
  void synthetic_test_device(
      const uint32            REQS,
      const uint32            LEN,
      const uint32            PLEN,
      const uint32            WORDS,
      const uint32            OCC_WORDS,
      const HostData<uint32>& host_data,
      const SSA&              ssa)
  {
      try
      {
          DeviceData<uint32> device_data( host_data );

          // test an FM-index with separate bwt/occ tables
          {
              typedef cuda::ldg_pointer<uint4> iterator_type;
              iterator_type occ_it( (const uint4*)thrust::raw_pointer_cast( &device_data.occ.front() ) );
              iterator_type bwt_it( (const uint4*)thrust::raw_pointer_cast( &device_data.bwt.front() ) );

              do_synthetic_test_device<OCC_INT, SA_INT>(
                  REQS,
                  LEN,
                  PLEN,
                  host_data,
                  ssa,
                  device_data,
                  occ_it,
                  bwt_it );
          }

          // test an FM-index with interleaved bwt/occ tables
          if (WORDS == OCC_WORDS)
          {
              typedef cuda::ldg_pointer<uint4> bwt_occ_texture;
              bwt_occ_texture bwt_occ_tex( (const uint4*)thrust::raw_pointer_cast( &device_data.bwt_occ.front() ) );

              typedef deinterleaved_iterator<2,0,bwt_occ_texture> bwt_iterator;
              typedef deinterleaved_iterator<2,1,bwt_occ_texture> occ_iterator;
              occ_iterator occ_it( bwt_occ_tex );
              bwt_iterator bwt_it( bwt_occ_tex );

              do_synthetic_test_device<OCC_INT, SA_INT>(
                  REQS,
                  LEN,
                  PLEN,
                  host_data,
                  ssa,
                  device_data,
                  occ_it,
                  bwt_it );
          }
      }
      // catch by const reference: catching by value slices derived exceptions
      // and what() would no longer return their message
      catch (const std::exception& e)
      {
          fprintf(stderr, " \nerror : exception caught : %s\n", e.what());
          exit(1);
      }
      catch (...)
      {
          fprintf(stderr, " \nerror : unknown exception\n");
          exit(1);
      }
  }
  // Device-side synthetic test, 64-bit index overload.
  //
  // Uploads the host buffers and runs do_synthetic_test_device twice: once
  // with separate bwt/occ tables and, when WORDS == OCC_WORDS, once more with
  // the interleaved bwt/occ table. Both layouts are read as uint64 words.
  // Any exception is reported on stderr and ends the process with exit code 1.
  //
  // \param REQS         number of queries
  // \param LEN          text length, in symbols
  // \param PLEN         pattern length, in symbols
  // \param WORDS        number of packed BWT words
  // \param OCC_WORDS    number of occurrence table words
  // \param host_data    host buffers, including the expected outputs
  // \param ssa          SSA built on the cpu, used as reference
  template <uint32 OCC_INT, uint32 SA_INT, typename SSA>
  void synthetic_test_device(
      const uint32            REQS,
      const uint32            LEN,
      const uint32            PLEN,
      const uint32            WORDS,
      const uint32            OCC_WORDS,
      const HostData<uint64>& host_data,
      const SSA&              ssa)
  {
      try
      {
          DeviceData<uint64> device_data( host_data );

          // test an FM-index with separate bwt/occ tables
          {
              typedef cuda::ldg_pointer<uint64> iterator_type;
              iterator_type occ_it( (const uint64*)thrust::raw_pointer_cast( &device_data.occ.front() ) );
              iterator_type bwt_it( (const uint64*)thrust::raw_pointer_cast( &device_data.bwt.front() ) );

              do_synthetic_test_device<OCC_INT, SA_INT>(
                  REQS,
                  LEN,
                  PLEN,
                  host_data,
                  ssa,
                  device_data,
                  occ_it,
                  bwt_it );
          }

          // test an FM-index with interleaved bwt/occ tables
          if (WORDS == OCC_WORDS)
          {
              typedef cuda::ldg_pointer<uint64> bwt_occ_texture;
              bwt_occ_texture bwt_occ_tex( (const uint64*)thrust::raw_pointer_cast( &device_data.bwt_occ.front() ) );

              typedef deinterleaved_iterator<2,0,bwt_occ_texture> bwt_iterator;
              typedef deinterleaved_iterator<2,1,bwt_occ_texture> occ_iterator;
              occ_iterator occ_it( bwt_occ_tex );
              bwt_iterator bwt_it( bwt_occ_tex );

              do_synthetic_test_device<OCC_INT, SA_INT>(
                  REQS,
                  LEN,
                  PLEN,
                  host_data,
                  ssa,
                  device_data,
                  occ_it,
                  bwt_it );
          }
      }
      // catch by const reference: catching by value slices derived exceptions
      // and what() would no longer return their message
      catch (const std::exception& e)
      {
          fprintf(stderr, " \nerror : exception caught : %s\n", e.what());
          exit(1);
      }
      catch (...)
      {
          fprintf(stderr, " \nerror : unknown exception\n");
          exit(1);
      }
  }
  // Run the alignment test on the cpu.
  //
  // For each of the REQS queries, matches the PLEN symbols of `text` starting
  // at data.input[q] against the FM-index and stores locate( fmi, range.x )
  // in data.output[q]. An empty range (range.x > range.y) is reported on
  // stderr and ends the process with exit code 1. Progress is printed every
  // 1024 queries, followed by the total time and throughput.
  //
  // \param REQS    number of queries
  // \param PLEN    pattern length, in symbols
  // \param text    text the patterns are extracted from
  // \param fmi     FM-index to search
  // \param data    host buffers: reads `input`, writes `output`
  template <
      typename TextType,
      typename FMIndexType,
      typename index_type>
  void synthetic_test_host(
      const uint32            REQS,
      const uint32            PLEN,
      const TextType          text,
      const FMIndexType       fmi,
      HostData<index_type>&   data)
  {
      typedef typename FMIndexType::range_type range_type;

      fprintf(stderr, " cpu alignment... started" );

      Timer timer;
      timer.start();

      uint32 q = 0;
      while (q < REQS)
      {
          // refresh the progress line once every 1024 queries
          if ((q % 1024u) == 0u)
              fprintf(stderr, "\r cpu alignment... started: %.1f%% ", 100.0f*float(q)/float(REQS) );

          const range_type sa_range = match( fmi, text + data.input[q], PLEN );

          if (sa_range.x > sa_range.y)
          {
              fprintf(stderr, " \nerror: unable to match pattern %u\n", data.input[q]);
              exit(1);
          }

          data.output[q] = uint32( locate( fmi, sa_range.x ) );
          ++q;
      }

      timer.stop();

      fprintf(stderr, "\n cpu alignment... done: %.1fms, A/s: %.2f M\n", timer.seconds()*1000.0f, REQS/(timer.seconds()*1.0e6f) );
  }
  359. } // anonymous namespace
  // Build an FM-index over a random 2-bit text and test it on cpu and gpu.
  //
  // Steps:
  //   1. generate a random text of LEN symbols, its suffix array and its BWT;
  //   2. build the occurrence table, the L2 table and the count table, plus
  //      the interleaved bwt/occ table when the two have the same size;
  //   3. build the SSA on the cpu and check its samples against the SA;
  //   4. match/locate up to 1000 patterns taken from the text and verify that
  //      every located position spells the pattern;
  //   5. run the host and device alignment tests on sorted, then shuffled,
  //      query positions.
  // Every failure is reported on stderr and ends the process with exit code 1.
  //
  // \param LEN      text length, in symbols (must be at least PLEN+2 = 10)
  // \param QUERIES  maximum number of alignment queries (must be non-zero)
  template <typename index_type>
  void synthetic_test(const uint32 LEN, const uint32 QUERIES)
  {
      fprintf(stderr, " %u-bits synthetic test\n", uint32(sizeof(index_type)*8));

      const uint32 OCC_INT      = sizeof(index_type) == sizeof(uint32) ? 64 : 128;
      const uint32 SA_INT       = 32;
      const uint32 SYM_PER_WORD = 4*sizeof(index_type);
      const uint32 PLEN         = 8;

      // LEN-PLEN-1 below would wrap around for shorter texts, and at least one
      // query is needed because data.input[0] is written unconditionally
      if (LEN < PLEN + 2u || QUERIES == 0u)
      {
          fprintf(stderr, " \nerror : synthetic test needs at least %u symbols and one query (got %u symbols, %u queries)\n", PLEN + 2u, LEN, QUERIES);
          exit(1);
      }

      const uint32 REQS      = nvbio::min( uint32(LEN-PLEN-1u), QUERIES );
      const uint32 WORDS     = (LEN+SYM_PER_WORD-1)/SYM_PER_WORD;
      const uint32 OCC_WORDS = ((LEN+OCC_INT-1) / OCC_INT) * 4;

      Timer timer;

      const uint64 memory_footprint =
          sizeof(index_type)*WORDS +
          sizeof(index_type)*WORDS +
          sizeof(index_type)*OCC_WORDS +
          sizeof(index_type)*uint64(LEN+SA_INT)/SA_INT;

      fprintf(stderr, " memory : %.1f MB\n", float(memory_footprint)/float(1024*1024));

      HostData<index_type> data;
      data.text.resize( align<4>(WORDS), 0u );
      data.bwt.resize( align<4>(WORDS), 0u );
      data.occ.resize( align<4>(OCC_WORDS), 0u );
      data.L2.resize( 5 );
      data.count_table.resize( 256 );
      data.input.resize( REQS );
      data.output.resize( REQS );

      // fill the text with random 2-bit symbols
      typedef PackedStream<index_type*,uint8,2,true,index_type> stream_type;
      stream_type text( &data.text[0] );
      for (uint32 i = 0; i < LEN; ++i)
          text[i] = (rand() % 4);

      // print the string
      if (LEN < 64)
      {
          char string[64];
          dna_to_string(
              text,
              text + LEN,
              string );

          fprintf(stderr, " string : %s\n", string);
      }

      // generate the suffix array
      std::vector<int32> sa( LEN+1, 0u );
      gen_sa( LEN, text, &sa[0] );

      stream_type bwt( &data.bwt[0] );
      data.primary = gen_bwt_from_sa( LEN, text, &sa[0], bwt );

      // set sa[0] to -1 so as to get a modulo for free
      sa[0] = -1;

      // print the bwt
      if (LEN < 64)
      {
          char string[64];
          dna_to_string(
              bwt,
              bwt + LEN,
              string );

          fprintf(stderr, " bwt : %s\n", string);
      }
      fprintf(stderr," primary : %d\n", data.primary );

      // build the occurrence table
      build_occurrence_table<2u,OCC_INT>(
          bwt,
          bwt + LEN,
          &data.occ[0],
          &data.L2[1] );

      // transform the L2 table into a cumulative sum
      data.L2[0] = 0;
      for (uint32 c = 0; c < 4; ++c)
          data.L2[c+1] += data.L2[c];

      // print the L2
      if (LEN < 64)
      {
          for (uint32 i = 0; i < 5; ++i)
              fprintf(stderr, " L2[%u] : %u\n", i, uint32( data.L2[i] ));
      }

      // generate the count table
      gen_bwt_count_table( &data.count_table[0] );

      // build the interleaved bwt/occ array
      if (WORDS == OCC_WORDS)
      {
          fprintf(stderr, " building interleaved bwt/occ... started\n" );
          data.bwt_occ.resize( WORDS*2 );
          if (sizeof(index_type) == 4)
          {
              // 32-bit words are interleaved in groups of 4 (one uint4 each);
              // OCC_WORDS is a multiple of 4, hence so is WORDS in this branch
              for (uint32 w = 0; w < WORDS; w += 4)
              {
                  data.bwt_occ[ w*2+0 ] = data.bwt[ w+0 ];
                  data.bwt_occ[ w*2+1 ] = data.bwt[ w+1 ];
                  data.bwt_occ[ w*2+2 ] = data.bwt[ w+2 ];
                  data.bwt_occ[ w*2+3 ] = data.bwt[ w+3 ];
                  data.bwt_occ[ w*2+4 ] = data.occ[ w+0 ];
                  data.bwt_occ[ w*2+5 ] = data.occ[ w+1 ];
                  data.bwt_occ[ w*2+6 ] = data.occ[ w+2 ];
                  data.bwt_occ[ w*2+7 ] = data.occ[ w+3 ];
              }
          }
          else
          {
              for (uint32 w = 0; w < WORDS; ++w)
              {
                  data.bwt_occ[ w*2+0 ] = data.bwt[ w ];
                  data.bwt_occ[ w*2+1 ] = data.occ[ w ];
              }
          }
          fprintf(stderr, " building interleaved bwt/occ... done\n" );
      }

      typedef PackedStream<const index_type*,uint8,2u,true,index_type> bwt_type;
      typedef rank_dictionary<2u, OCC_INT, bwt_type, const index_type*, const uint32*> rank_dict_type;

      // temporary FM-index without an SSA, only used to build the SSA itself
      typedef fm_index<rank_dict_type, ssa_nop> temp_fm_index_type;
      temp_fm_index_type temp_fmi(
          LEN,
          data.primary,
          &data.L2[0],
          rank_dict_type(
              bwt_type( &data.bwt[0] ),
              &data.occ[0],
              &data.count_table[0] ),
          ssa_nop() );

  #if 0
      // test the Sampled Suffix Array class
      typedef SSA_value_multiple SSA_type;
      SSA_value_multiple ssa( temp_fmi, SA_INT );
      SSA_value_multiple::context_type ssa_context = ssa.get_context();
  #else
      // test the Sampled Suffix Array class
      typedef SSA_index_multiple<SA_INT,index_type> SSA_type;

      timer.start();
      SSA_type ssa( temp_fmi );
      timer.stop();
      fprintf(stderr, " SSA cpu time: %.3fs\n", timer.seconds() );

      typename SSA_type::context_type ssa_context = ssa.get_context();
  #endif

      // every sample the SSA can fetch must agree with the full suffix array
      fprintf(stderr, " SSA test... started\n" );
      for (uint32 i = 1; i < LEN; ++i)
      {
          index_type val;
          if (ssa_context.fetch( index_type(i), val ) && (val != (uint32)sa[i]))
          {
              fprintf(stderr, " SSA mismatch at %u: expected %d, got: %u\n", i, uint32( sa[i] ), uint32( val ));
              exit(1);
          }
      }
      fprintf(stderr, " SSA test... done\n" );

      typedef fm_index<rank_dict_type, typename SSA_type::context_type> fm_index_type;
      fm_index_type fmi(
          LEN,
          data.primary,
          &data.L2[0],
          rank_dict_type(
              bwt_type( &data.bwt[0] ),
              &data.occ[0],
              &data.count_table[0] ),
          ssa_context );

      typedef typename fm_index_type::range_type range_type;

      uint8 pattern[PLEN];
      char  pattern_str[PLEN+1];

      // never start a pattern past LEN-PLEN: a fixed count of 1000 would read
      // beyond the end of the text whenever LEN < 1007
      const uint32 N_ALIGN_TESTS = nvbio::min( uint32(1000u), uint32(LEN - PLEN + 1u) );

      fprintf(stderr, " alignment test... started:" );
      for (uint32 i = 0; i < N_ALIGN_TESTS; ++i)
      {
          fprintf(stderr, "\r alignment test... started: %.1f%% ", 100.0f*float(i)/float(N_ALIGN_TESTS) );

          // extract the pattern starting at text position i
          for (uint32 j = 0; j < PLEN; ++j)
              pattern[j] = text[i+j];

          dna_to_string(
              pattern,
              pattern + PLEN,
              pattern_str );

          range_type range = match(
              fmi,
              pattern,
              PLEN );

          if (range.x > range.y)
          {
              fprintf(stderr, " \nerror : searching for %s @ %u, resulted in (%u,%u)\n", pattern_str, i, uint32( range.x ), uint32( range.y ));
              exit(1);
          }

          // locate at most the first 11 alignments (rows range.x .. range.x+10)
          range.y = nvbio::min( range.x + 10u, range.y );
          for (index_type x = range.x; x <= range.y; ++x)
          {
              const uint32 prefix = locate( fmi, x );
              if (prefix >= LEN)
              {
                  const range_type inv = inv_psi( fmi, x );
                  fprintf(stderr, " \nerror : searching for %s @ %u, resulted in prefix out of bounds: %u (= sa[%u] + %u)\n", pattern_str, i, prefix, uint32(inv.x), uint32(inv.y));
                  exit(1);
              }

              // the text at the located position must spell the pattern
              char found_str[PLEN+1];
              dna_to_string(
                  text + prefix,
                  text + prefix + PLEN,
                  found_str );

              if (strcmp( found_str, pattern_str ) != 0)
              {
                  const range_type inv = inv_psi( fmi, x );
                  fprintf(stderr, " \nerror : locating %s @ %u at SA=%u in SA(%u,%u), resulted in %s @ %u (= sa[%u] + %u)\n", pattern_str, i, uint32( x ), uint32( range.x ), uint32( range.y ), found_str, prefix, uint32(inv.x), uint32(inv.y));
                  exit(1);
              }
              /*{
                  const uint2 inv = inv_psi( fmi, x );
                  fprintf(stderr, " locating %s @ %u at %u, matched at %u (= sa[%u] + %u)\n", pattern_str, i, x, prefix, inv.x, inv.y);
              }*/
          }
      }
      fprintf(stderr, "\n alignment test... done\n" );

      // generate query positions by accumulating random steps, modulo LEN - PLEN
      const uint32 SPARSITY = 100;

      data.input[0] = 0;
      for (uint32 i = 1; i < REQS; ++i)
          data.input[i] = (data.input[i-1] + (rand() % SPARSITY)) % (LEN - PLEN);

      fprintf(stderr, " sorted alignment tests... started\n" );
      synthetic_test_host(
          REQS,
          PLEN,
          text,
          fmi,
          data );

      synthetic_test_device<OCC_INT,SA_INT>(
          REQS,
          LEN,
          PLEN,
          WORDS,
          OCC_WORDS,
          data,
          ssa );
      fprintf(stderr, " sorted alignment tests... done\n" );

      fprintf(stderr, " shuffled alignment tests... started\n" );

      // shuffle the queries: swap each slot with a random slot at or after it
      for (uint32 i = 0; i < REQS; ++i)
      {
          const uint32 j = i + rand() % (REQS - i);
          std::swap( data.input[i], data.input[j] );
      }

      synthetic_test_host(
          REQS,
          PLEN,
          text,
          fmi,
          data );

      synthetic_test_device<OCC_INT,SA_INT>(
          REQS,
          LEN,
          PLEN,
          WORDS,
          OCC_WORDS,
          data,
          ssa );
      fprintf(stderr, " shuffled alignment tests... done\n" );
  }
  //
  // A backtracking delegate used to count the total number of occurrences.
  //
  // Each call adds the size of the reported range (range.y + 1 - range.x) to
  // a counter shared by all callers. The update is atomic on both paths:
  // atomicAdd in device code, an OpenMP atomic update in host code (the host
  // loops in this file call it from inside `#pragma omp parallel for`).
  //
  struct CountDelegate
  {
      // constructor
      //
      // \param count   pointer to the global counter
      NVBIO_FORCEINLINE NVBIO_HOST_DEVICE
      CountDelegate(uint32* count) : m_count( count ) {}

      // main functor operator
      //
      // \param range   inclusive range [range.x, range.y] to add to the counter
      NVBIO_FORCEINLINE NVBIO_HOST_DEVICE
      void operator() (const uint2 range) const
      {
      #if defined(NVBIO_DEVICE_COMPILATION)
          atomicAdd( m_count, range.y + 1u - range.x );
      #else
          // a plain += would race when several OpenMP threads share the counter;
          // the pragma is ignored when OpenMP is not enabled
          #pragma omp atomic
          *m_count += range.y + 1u - range.x;
      #endif
      }

  private:
      uint32* m_count; // global counter
  };
  //
  // k-mer counting core, callable from both host and device code.
  //
  // Runs hamming_backtrack on one read and lets a CountDelegate add every
  // reported range to *count.
  //
  // \param read_id      read id
  // \param reads        reads view
  // \param fmi          FM-index
  // \param len          pattern length
  // \param seed         exact-matching seed length
  // \param mismatches   number of allowed mismatches after the seed
  // \param count        global output counter
  //
  template <typename ReadsView, typename FMIndexType>
  NVBIO_FORCEINLINE NVBIO_HOST_DEVICE
  void count_core(
      const uint32        read_id,
      const ReadsView     reads,
      const FMIndexType   fmi,
      const uint32        len,
      const uint32        seed,
      const uint32        mismatches,
      uint32*             count)
  {
      // not referenced below; kept so ReadsView must still expose this type
      typedef typename ReadsView::sequence_stream_type read_stream_type;

      // scratch space handed to the backtracker: 32*4 uint4 entries
      uint4 backtrack_stack[32*4];

      CountDelegate delegate( count );

      hamming_backtrack(
          fmi,
          reads.get_read( read_id ).begin(),
          len,
          seed,
          mismatches,
          backtrack_stack,
          delegate );
  }
  //
  // k-mer counting kernel: one thread per read.
  //
  // Thread t calls count_core on read t; threads with t >= reads.size() do
  // nothing, so the grid may be rounded up to a whole number of blocks.
  //
  // \param reads        reads view
  // \param fmi          FM-index
  // \param len          pattern length
  // \param seed         exact-matching seed length
  // \param mismatches   number of allowed mismatches after the seed
  // \param count        global output counter
  //
  template <typename ReadsView, typename FMIndexType>
  __global__
  void count_kernel(
      const ReadsView     reads,
      const FMIndexType   fmi,
      const uint32        len,
      const uint32        seed,
      const uint32        mismatches,
      uint32*             count)
  {
      const uint32 read_id = blockIdx.x*blockDim.x + threadIdx.x;

      if (read_id < reads.size())
          count_core( read_id, reads, fmi, len, seed, mismatches, count );
  }
  // Time one count_kernel launch over all reads and print its throughput.
  // The label is printed as (len,mismatches,seed).
  template <typename ReadsView, typename FMIndexType>
  static void backtrack_test_gpu(
      const ReadsView     reads,
      const FMIndexType   fmi,
      const uint32        n_reads,
      const uint32        len,
      const uint32        seed,
      const uint32        mismatches,
      uint32*             d_count)
  {
      const uint32 blockdim = 128;
      const uint32 n_blocks = (n_reads + blockdim - 1) / blockdim;

      cudaEvent_t start, stop;
      cudaEventCreate( &start );
      cudaEventCreate( &stop );
      cudaEventRecord( start, 0 );

      count_kernel<<<n_blocks,blockdim>>>(
          reads,
          fmi,
          len,
          seed,
          mismatches,
          d_count );

      // cudaThreadSynchronize() is deprecated: use the device-wide equivalent
      cudaDeviceSynchronize();
      nvbio::cuda::check_error("count_kernel");

      float time;
      cudaEventRecord( stop, 0 );
      cudaEventSynchronize( stop );
      cudaEventElapsedTime( &time, start, stop );

      // release the events so that repeated runs do not leak them
      cudaEventDestroy( start );
      cudaEventDestroy( stop );

      fprintf(stderr, " gpu backtracking (%u,%u,%u)... done: %.1fms, A/s: %.3f M\n", len, mismatches, seed, time, n_reads/(time*1000.0f) );
  }

  // Time count_core over all reads on the cpu (OpenMP) and print its throughput.
  // The label is printed as (len,mismatches,seed).
  template <typename ReadsView, typename FMIndexType>
  static void backtrack_test_cpu(
      const ReadsView     reads,
      const FMIndexType   fmi,
      const uint32        n_reads,
      const uint32        len,
      const uint32        seed,
      const uint32        mismatches)
  {
      Timer timer;
      timer.start();

      uint32 counter = 0;

      // reduction: each thread accumulates into its own copy of the counter,
      // so no two threads ever update the same location
      #pragma omp parallel for reduction(+:counter)
      for (int i = 0; i < (int)n_reads; ++i)
      {
          count_core(
              i,
              reads,
              fmi,
              len,
              seed,
              mismatches,
              &counter );
      }

      timer.stop();

      const float time = timer.seconds() * 1000.0f;

      fprintf(stderr, " cpu backtracking (%u,%u,%u)... done: %.1fms, A/s: %.3f M\n", len, mismatches, seed, time, n_reads/(time*1000.0f) );
  }

  //
  // run a set of backtracking tests with real data
  //
  // Loads the FM-index from index_file and one batch of up to n_reads reads
  // from reads_name, then times k-mer counting on the gpu and on the cpu for
  // three settings: 20-mers with no mismatches, 32-mers with 1 mismatch, and
  // 50-mers with 2 mismatches after a 25-symbol exact seed.
  // A missing index only produces a warning; a missing or empty reads file is
  // an error and ends the process with exit code 1.
  //
  // \param index_file   FM-index file name
  // \param reads_name   reads file name
  // \param n_reads      maximum number of reads to load
  //
  void backtrack_test(const char* index_file, const char* reads_name, const uint32 n_reads)
  {
      io::FMIndexDataHost h_fmi;
      if (h_fmi.load( index_file, io::FMIndexData::FORWARD ))
      {
          typedef io::FMIndexData::partial_fm_index_type host_fmindex_type;
          typedef io::FMIndexDataDevice::fm_index_type cuda_fmindex_type;

          io::FMIndexDataDevice d_fmi( h_fmi, io::FMIndexDataDevice::FORWARD );

          host_fmindex_type h_fmindex = h_fmi.partial_index();
          cuda_fmindex_type d_fmindex = d_fmi.index();

          io::SequenceDataStream* reads_file = io::open_sequence_file(
              reads_name,
              io::Phred,
              n_reads,
              50 );

          if (reads_file == NULL)
          {
              log_error(stderr, "unable to load \"%s\"\n", reads_name);
              exit(1);
          }

          // create a host-side read batch
          io::SequenceDataHost h_reads_data;

          // load a batch
          if (io::next( DNA_N, &h_reads_data, reads_file, n_reads ) == 0)
          {
              log_error(stderr, "unable to fetch reads from file \"%s\"\n", reads_name);
              exit(1);
          }

          // create a device-side read batch
          const io::SequenceDataDevice d_reads_data( h_reads_data );

          // create the read accessors
          typedef io::SequenceDataAccess<DNA_N> read_access_type;
          const read_access_type h_reads_view( h_reads_data );
          const read_access_type d_reads_view( d_reads_data );

          // device counter shared by all gpu runs (accumulated, never read back)
          thrust::device_vector<uint32> counter(1);
          counter[0] = 0;
          uint32* d_counter = thrust::raw_pointer_cast( &counter.front() );

          const uint32 n_device_reads = uint32( d_reads_data.size() );
          const uint32 n_host_reads   = uint32( h_reads_data.size() );

          // 20-mers, distance=0
          backtrack_test_gpu( d_reads_view, d_fmindex, n_device_reads, 20u, 0u, 0u, d_counter );
          backtrack_test_cpu( h_reads_view, h_fmindex, n_host_reads,   20u, 0u, 0u );

          // 32-mers, distance=1
          backtrack_test_gpu( d_reads_view, d_fmindex, n_device_reads, 32u, 0u, 1u, d_counter );
          backtrack_test_cpu( h_reads_view, h_fmindex, n_host_reads,   32u, 0u, 1u );

          // 50-mers, distance=2, seed=25
          backtrack_test_gpu( d_reads_view, d_fmindex, n_device_reads, 50u, 25u, 2u, d_counter );
          backtrack_test_cpu( h_reads_view, h_fmindex, n_host_reads,   50u, 25u, 2u );

          delete reads_file;
      }
      else
          log_warning(stderr, "unable to load \"%s\"\n", index_file);
  }
  // Entry point of the FM-index test.
  //
  // Recognized options (each takes one value):
  //   -synth-length N       synthetic text length, in thousands of symbols
  //   -synth-queries N      number of synthetic queries, in thousands
  //   -backtrack-queries N  number of backtracking reads, in multiples of 1024
  //   -index NAME           FM-index file name for the backtracking test
  //   -reads NAME           reads file name for the backtracking test
  //   -threads N            number of OpenMP threads
  // Unknown arguments are skipped. An option given without its value is
  // reported with a warning and ends the option scan.
  //
  // \return 0 (failures terminate the process inside the individual tests)
  int fmindex_test(int argc, char* argv[])
  {
      uint32 synth_len         = 10000000;
      uint32 synth_queries     = 64*1024;
      const char* index_name   = "./data/human.NCBI36/Human.NCBI36";
      const char* reads_name   = "./data/SRR493095_1.fastq.gz";
      uint32 backtrack_queries = 64*1024;
      uint32 threads           = omp_get_num_procs();

      for (int i = 0; i < argc; ++i)
      {
          const char* option = argv[i];

          const bool known =
              strcmp( option, "-synth-length" )      == 0 ||
              strcmp( option, "-synth-queries" )     == 0 ||
              strcmp( option, "-backtrack-queries" ) == 0 ||
              strcmp( option, "-index" )             == 0 ||
              strcmp( option, "-reads" )             == 0 ||
              strcmp( option, "-threads" )           == 0;

          if (!known)
              continue;

          // every option consumes the next argument: make sure there is one
          // instead of reading past the end of argv
          if (i + 1 >= argc)
          {
              log_warning(stderr, "option \"%s\" requires a value\n", option);
              break;
          }

          const char* value = argv[++i];

          if (strcmp( option, "-synth-length" ) == 0)
              synth_len = atoi( value )*1000;
          else if (strcmp( option, "-synth-queries" ) == 0)
              synth_queries = atoi( value )*1000;
          else if (strcmp( option, "-backtrack-queries" ) == 0)
              backtrack_queries = atoi( value ) * 1024;
          else if (strcmp( option, "-index" ) == 0)
              index_name = value;
          else if (strcmp( option, "-reads" ) == 0)
              reads_name = value;
          else if (strcmp( option, "-threads" ) == 0)
              threads = atoi( value );
      }

      omp_set_num_threads( threads );

      fprintf(stderr, "FM-index test... started\n");

      // the synthetic tests are skipped when either size is zero
      if (synth_len && synth_queries)
      {
          synthetic_test<uint32>( synth_len, synth_queries );
          synthetic_test<uint64>( synth_len, synth_queries );
      }

      if (backtrack_queries)
          backtrack_test( index_name, reads_name, backtrack_queries );

      fprintf(stderr, "FM-index test... done\n");
      return 0;
  }