gtzan.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108
  1. import os
  2. from pathlib import Path
  3. from typing import Optional, Tuple, Union
  4. import torchaudio
  5. from torch import Tensor
  6. from torch.hub import download_url_to_file
  7. from torch.utils.data import Dataset
  8. from torchaudio.datasets.utils import extract_archive
# The following lists prefixed with `filtered_` provide a filtered split
# that:
#
# a. mitigates a known issue with GTZAN (duplicated recordings), and
#
# b. provides a standard split for comparing against other methods
#    (e.g. the one in jordipons/sklearn-audio-transfer-learning).
#
# They are used when GTZAN is initialised with `subset` set to "training",
# "validation" or "testing".
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
  19. gtzan_genres = [
  20. "blues",
  21. "classical",
  22. "country",
  23. "disco",
  24. "hiphop",
  25. "jazz",
  26. "metal",
  27. "pop",
  28. "reggae",
  29. "rock",
  30. ]
  31. filtered_test = [
  32. "blues.00012",
  33. "blues.00013",
  34. "blues.00014",
  35. "blues.00015",
  36. "blues.00016",
  37. "blues.00017",
  38. "blues.00018",
  39. "blues.00019",
  40. "blues.00020",
  41. "blues.00021",
  42. "blues.00022",
  43. "blues.00023",
  44. "blues.00024",
  45. "blues.00025",
  46. "blues.00026",
  47. "blues.00027",
  48. "blues.00028",
  49. "blues.00061",
  50. "blues.00062",
  51. "blues.00063",
  52. "blues.00064",
  53. "blues.00065",
  54. "blues.00066",
  55. "blues.00067",
  56. "blues.00068",
  57. "blues.00069",
  58. "blues.00070",
  59. "blues.00071",
  60. "blues.00072",
  61. "blues.00098",
  62. "blues.00099",
  63. "classical.00011",
  64. "classical.00012",
  65. "classical.00013",
  66. "classical.00014",
  67. "classical.00015",
  68. "classical.00016",
  69. "classical.00017",
  70. "classical.00018",
  71. "classical.00019",
  72. "classical.00020",
  73. "classical.00021",
  74. "classical.00022",
  75. "classical.00023",
  76. "classical.00024",
  77. "classical.00025",
  78. "classical.00026",
  79. "classical.00027",
  80. "classical.00028",
  81. "classical.00029",
  82. "classical.00034",
  83. "classical.00035",
  84. "classical.00036",
  85. "classical.00037",
  86. "classical.00038",
  87. "classical.00039",
  88. "classical.00040",
  89. "classical.00041",
  90. "classical.00049",
  91. "classical.00077",
  92. "classical.00078",
  93. "classical.00079",
  94. "country.00030",
  95. "country.00031",
  96. "country.00032",
  97. "country.00033",
  98. "country.00034",
  99. "country.00035",
  100. "country.00036",
  101. "country.00037",
  102. "country.00038",
  103. "country.00039",
  104. "country.00040",
  105. "country.00043",
  106. "country.00044",
  107. "country.00046",
  108. "country.00047",
  109. "country.00048",
  110. "country.00050",
  111. "country.00051",
  112. "country.00053",
  113. "country.00054",
  114. "country.00055",
  115. "country.00056",
  116. "country.00057",
  117. "country.00058",
  118. "country.00059",
  119. "country.00060",
  120. "country.00061",
  121. "country.00062",
  122. "country.00063",
  123. "country.00064",
  124. "disco.00001",
  125. "disco.00021",
  126. "disco.00058",
  127. "disco.00062",
  128. "disco.00063",
  129. "disco.00064",
  130. "disco.00065",
  131. "disco.00066",
  132. "disco.00069",
  133. "disco.00076",
  134. "disco.00077",
  135. "disco.00078",
  136. "disco.00079",
  137. "disco.00080",
  138. "disco.00081",
  139. "disco.00082",
  140. "disco.00083",
  141. "disco.00084",
  142. "disco.00085",
  143. "disco.00086",
  144. "disco.00087",
  145. "disco.00088",
  146. "disco.00091",
  147. "disco.00092",
  148. "disco.00093",
  149. "disco.00094",
  150. "disco.00096",
  151. "disco.00097",
  152. "disco.00099",
  153. "hiphop.00000",
  154. "hiphop.00026",
  155. "hiphop.00027",
  156. "hiphop.00030",
  157. "hiphop.00040",
  158. "hiphop.00043",
  159. "hiphop.00044",
  160. "hiphop.00045",
  161. "hiphop.00051",
  162. "hiphop.00052",
  163. "hiphop.00053",
  164. "hiphop.00054",
  165. "hiphop.00062",
  166. "hiphop.00063",
  167. "hiphop.00064",
  168. "hiphop.00065",
  169. "hiphop.00066",
  170. "hiphop.00067",
  171. "hiphop.00068",
  172. "hiphop.00069",
  173. "hiphop.00070",
  174. "hiphop.00071",
  175. "hiphop.00072",
  176. "hiphop.00073",
  177. "hiphop.00074",
  178. "hiphop.00075",
  179. "hiphop.00099",
  180. "jazz.00073",
  181. "jazz.00074",
  182. "jazz.00075",
  183. "jazz.00076",
  184. "jazz.00077",
  185. "jazz.00078",
  186. "jazz.00079",
  187. "jazz.00080",
  188. "jazz.00081",
  189. "jazz.00082",
  190. "jazz.00083",
  191. "jazz.00084",
  192. "jazz.00085",
  193. "jazz.00086",
  194. "jazz.00087",
  195. "jazz.00088",
  196. "jazz.00089",
  197. "jazz.00090",
  198. "jazz.00091",
  199. "jazz.00092",
  200. "jazz.00093",
  201. "jazz.00094",
  202. "jazz.00095",
  203. "jazz.00096",
  204. "jazz.00097",
  205. "jazz.00098",
  206. "jazz.00099",
  207. "metal.00012",
  208. "metal.00013",
  209. "metal.00014",
  210. "metal.00015",
  211. "metal.00022",
  212. "metal.00023",
  213. "metal.00025",
  214. "metal.00026",
  215. "metal.00027",
  216. "metal.00028",
  217. "metal.00029",
  218. "metal.00030",
  219. "metal.00031",
  220. "metal.00032",
  221. "metal.00033",
  222. "metal.00038",
  223. "metal.00039",
  224. "metal.00067",
  225. "metal.00070",
  226. "metal.00073",
  227. "metal.00074",
  228. "metal.00075",
  229. "metal.00078",
  230. "metal.00083",
  231. "metal.00085",
  232. "metal.00087",
  233. "metal.00088",
  234. "pop.00000",
  235. "pop.00001",
  236. "pop.00013",
  237. "pop.00014",
  238. "pop.00043",
  239. "pop.00063",
  240. "pop.00064",
  241. "pop.00065",
  242. "pop.00066",
  243. "pop.00069",
  244. "pop.00070",
  245. "pop.00071",
  246. "pop.00072",
  247. "pop.00073",
  248. "pop.00074",
  249. "pop.00075",
  250. "pop.00076",
  251. "pop.00077",
  252. "pop.00078",
  253. "pop.00079",
  254. "pop.00082",
  255. "pop.00088",
  256. "pop.00089",
  257. "pop.00090",
  258. "pop.00091",
  259. "pop.00092",
  260. "pop.00093",
  261. "pop.00094",
  262. "pop.00095",
  263. "pop.00096",
  264. "reggae.00034",
  265. "reggae.00035",
  266. "reggae.00036",
  267. "reggae.00037",
  268. "reggae.00038",
  269. "reggae.00039",
  270. "reggae.00040",
  271. "reggae.00046",
  272. "reggae.00047",
  273. "reggae.00048",
  274. "reggae.00052",
  275. "reggae.00053",
  276. "reggae.00064",
  277. "reggae.00065",
  278. "reggae.00066",
  279. "reggae.00067",
  280. "reggae.00068",
  281. "reggae.00071",
  282. "reggae.00079",
  283. "reggae.00082",
  284. "reggae.00083",
  285. "reggae.00084",
  286. "reggae.00087",
  287. "reggae.00088",
  288. "reggae.00089",
  289. "reggae.00090",
  290. "rock.00010",
  291. "rock.00011",
  292. "rock.00012",
  293. "rock.00013",
  294. "rock.00014",
  295. "rock.00015",
  296. "rock.00027",
  297. "rock.00028",
  298. "rock.00029",
  299. "rock.00030",
  300. "rock.00031",
  301. "rock.00032",
  302. "rock.00033",
  303. "rock.00034",
  304. "rock.00035",
  305. "rock.00036",
  306. "rock.00037",
  307. "rock.00039",
  308. "rock.00040",
  309. "rock.00041",
  310. "rock.00042",
  311. "rock.00043",
  312. "rock.00044",
  313. "rock.00045",
  314. "rock.00046",
  315. "rock.00047",
  316. "rock.00048",
  317. "rock.00086",
  318. "rock.00087",
  319. "rock.00088",
  320. "rock.00089",
  321. "rock.00090",
  322. ]
  323. filtered_train = [
  324. "blues.00029",
  325. "blues.00030",
  326. "blues.00031",
  327. "blues.00032",
  328. "blues.00033",
  329. "blues.00034",
  330. "blues.00035",
  331. "blues.00036",
  332. "blues.00037",
  333. "blues.00038",
  334. "blues.00039",
  335. "blues.00040",
  336. "blues.00041",
  337. "blues.00042",
  338. "blues.00043",
  339. "blues.00044",
  340. "blues.00045",
  341. "blues.00046",
  342. "blues.00047",
  343. "blues.00048",
  344. "blues.00049",
  345. "blues.00073",
  346. "blues.00074",
  347. "blues.00075",
  348. "blues.00076",
  349. "blues.00077",
  350. "blues.00078",
  351. "blues.00079",
  352. "blues.00080",
  353. "blues.00081",
  354. "blues.00082",
  355. "blues.00083",
  356. "blues.00084",
  357. "blues.00085",
  358. "blues.00086",
  359. "blues.00087",
  360. "blues.00088",
  361. "blues.00089",
  362. "blues.00090",
  363. "blues.00091",
  364. "blues.00092",
  365. "blues.00093",
  366. "blues.00094",
  367. "blues.00095",
  368. "blues.00096",
  369. "blues.00097",
  370. "classical.00030",
  371. "classical.00031",
  372. "classical.00032",
  373. "classical.00033",
  374. "classical.00043",
  375. "classical.00044",
  376. "classical.00045",
  377. "classical.00046",
  378. "classical.00047",
  379. "classical.00048",
  380. "classical.00050",
  381. "classical.00051",
  382. "classical.00052",
  383. "classical.00053",
  384. "classical.00054",
  385. "classical.00055",
  386. "classical.00056",
  387. "classical.00057",
  388. "classical.00058",
  389. "classical.00059",
  390. "classical.00060",
  391. "classical.00061",
  392. "classical.00062",
  393. "classical.00063",
  394. "classical.00064",
  395. "classical.00065",
  396. "classical.00066",
  397. "classical.00067",
  398. "classical.00080",
  399. "classical.00081",
  400. "classical.00082",
  401. "classical.00083",
  402. "classical.00084",
  403. "classical.00085",
  404. "classical.00086",
  405. "classical.00087",
  406. "classical.00088",
  407. "classical.00089",
  408. "classical.00090",
  409. "classical.00091",
  410. "classical.00092",
  411. "classical.00093",
  412. "classical.00094",
  413. "classical.00095",
  414. "classical.00096",
  415. "classical.00097",
  416. "classical.00098",
  417. "classical.00099",
  418. "country.00019",
  419. "country.00020",
  420. "country.00021",
  421. "country.00022",
  422. "country.00023",
  423. "country.00024",
  424. "country.00025",
  425. "country.00026",
  426. "country.00028",
  427. "country.00029",
  428. "country.00065",
  429. "country.00066",
  430. "country.00067",
  431. "country.00068",
  432. "country.00069",
  433. "country.00070",
  434. "country.00071",
  435. "country.00072",
  436. "country.00073",
  437. "country.00074",
  438. "country.00075",
  439. "country.00076",
  440. "country.00077",
  441. "country.00078",
  442. "country.00079",
  443. "country.00080",
  444. "country.00081",
  445. "country.00082",
  446. "country.00083",
  447. "country.00084",
  448. "country.00085",
  449. "country.00086",
  450. "country.00087",
  451. "country.00088",
  452. "country.00089",
  453. "country.00090",
  454. "country.00091",
  455. "country.00092",
  456. "country.00093",
  457. "country.00094",
  458. "country.00095",
  459. "country.00096",
  460. "country.00097",
  461. "country.00098",
  462. "country.00099",
  463. "disco.00005",
  464. "disco.00015",
  465. "disco.00016",
  466. "disco.00017",
  467. "disco.00018",
  468. "disco.00019",
  469. "disco.00020",
  470. "disco.00022",
  471. "disco.00023",
  472. "disco.00024",
  473. "disco.00025",
  474. "disco.00026",
  475. "disco.00027",
  476. "disco.00028",
  477. "disco.00029",
  478. "disco.00030",
  479. "disco.00031",
  480. "disco.00032",
  481. "disco.00033",
  482. "disco.00034",
  483. "disco.00035",
  484. "disco.00036",
  485. "disco.00037",
  486. "disco.00039",
  487. "disco.00040",
  488. "disco.00041",
  489. "disco.00042",
  490. "disco.00043",
  491. "disco.00044",
  492. "disco.00045",
  493. "disco.00047",
  494. "disco.00049",
  495. "disco.00053",
  496. "disco.00054",
  497. "disco.00056",
  498. "disco.00057",
  499. "disco.00059",
  500. "disco.00061",
  501. "disco.00070",
  502. "disco.00073",
  503. "disco.00074",
  504. "disco.00089",
  505. "hiphop.00002",
  506. "hiphop.00003",
  507. "hiphop.00004",
  508. "hiphop.00005",
  509. "hiphop.00006",
  510. "hiphop.00007",
  511. "hiphop.00008",
  512. "hiphop.00009",
  513. "hiphop.00010",
  514. "hiphop.00011",
  515. "hiphop.00012",
  516. "hiphop.00013",
  517. "hiphop.00014",
  518. "hiphop.00015",
  519. "hiphop.00016",
  520. "hiphop.00017",
  521. "hiphop.00018",
  522. "hiphop.00019",
  523. "hiphop.00020",
  524. "hiphop.00021",
  525. "hiphop.00022",
  526. "hiphop.00023",
  527. "hiphop.00024",
  528. "hiphop.00025",
  529. "hiphop.00028",
  530. "hiphop.00029",
  531. "hiphop.00031",
  532. "hiphop.00032",
  533. "hiphop.00033",
  534. "hiphop.00034",
  535. "hiphop.00035",
  536. "hiphop.00036",
  537. "hiphop.00037",
  538. "hiphop.00038",
  539. "hiphop.00041",
  540. "hiphop.00042",
  541. "hiphop.00055",
  542. "hiphop.00056",
  543. "hiphop.00057",
  544. "hiphop.00058",
  545. "hiphop.00059",
  546. "hiphop.00060",
  547. "hiphop.00061",
  548. "hiphop.00077",
  549. "hiphop.00078",
  550. "hiphop.00079",
  551. "hiphop.00080",
  552. "jazz.00000",
  553. "jazz.00001",
  554. "jazz.00011",
  555. "jazz.00012",
  556. "jazz.00013",
  557. "jazz.00014",
  558. "jazz.00015",
  559. "jazz.00016",
  560. "jazz.00017",
  561. "jazz.00018",
  562. "jazz.00019",
  563. "jazz.00020",
  564. "jazz.00021",
  565. "jazz.00022",
  566. "jazz.00023",
  567. "jazz.00024",
  568. "jazz.00041",
  569. "jazz.00047",
  570. "jazz.00048",
  571. "jazz.00049",
  572. "jazz.00050",
  573. "jazz.00051",
  574. "jazz.00052",
  575. "jazz.00053",
  576. "jazz.00054",
  577. "jazz.00055",
  578. "jazz.00056",
  579. "jazz.00057",
  580. "jazz.00058",
  581. "jazz.00059",
  582. "jazz.00060",
  583. "jazz.00061",
  584. "jazz.00062",
  585. "jazz.00063",
  586. "jazz.00064",
  587. "jazz.00065",
  588. "jazz.00066",
  589. "jazz.00067",
  590. "jazz.00068",
  591. "jazz.00069",
  592. "jazz.00070",
  593. "jazz.00071",
  594. "jazz.00072",
  595. "metal.00002",
  596. "metal.00003",
  597. "metal.00005",
  598. "metal.00021",
  599. "metal.00024",
  600. "metal.00035",
  601. "metal.00046",
  602. "metal.00047",
  603. "metal.00048",
  604. "metal.00049",
  605. "metal.00050",
  606. "metal.00051",
  607. "metal.00052",
  608. "metal.00053",
  609. "metal.00054",
  610. "metal.00055",
  611. "metal.00056",
  612. "metal.00057",
  613. "metal.00059",
  614. "metal.00060",
  615. "metal.00061",
  616. "metal.00062",
  617. "metal.00063",
  618. "metal.00064",
  619. "metal.00065",
  620. "metal.00066",
  621. "metal.00069",
  622. "metal.00071",
  623. "metal.00072",
  624. "metal.00079",
  625. "metal.00080",
  626. "metal.00084",
  627. "metal.00086",
  628. "metal.00089",
  629. "metal.00090",
  630. "metal.00091",
  631. "metal.00092",
  632. "metal.00093",
  633. "metal.00094",
  634. "metal.00095",
  635. "metal.00096",
  636. "metal.00097",
  637. "metal.00098",
  638. "metal.00099",
  639. "pop.00002",
  640. "pop.00003",
  641. "pop.00004",
  642. "pop.00005",
  643. "pop.00006",
  644. "pop.00007",
  645. "pop.00008",
  646. "pop.00009",
  647. "pop.00011",
  648. "pop.00012",
  649. "pop.00016",
  650. "pop.00017",
  651. "pop.00018",
  652. "pop.00019",
  653. "pop.00020",
  654. "pop.00023",
  655. "pop.00024",
  656. "pop.00025",
  657. "pop.00026",
  658. "pop.00027",
  659. "pop.00028",
  660. "pop.00029",
  661. "pop.00031",
  662. "pop.00032",
  663. "pop.00033",
  664. "pop.00034",
  665. "pop.00035",
  666. "pop.00036",
  667. "pop.00038",
  668. "pop.00039",
  669. "pop.00040",
  670. "pop.00041",
  671. "pop.00042",
  672. "pop.00044",
  673. "pop.00046",
  674. "pop.00049",
  675. "pop.00050",
  676. "pop.00080",
  677. "pop.00097",
  678. "pop.00098",
  679. "pop.00099",
  680. "reggae.00000",
  681. "reggae.00001",
  682. "reggae.00002",
  683. "reggae.00004",
  684. "reggae.00006",
  685. "reggae.00009",
  686. "reggae.00011",
  687. "reggae.00012",
  688. "reggae.00014",
  689. "reggae.00015",
  690. "reggae.00016",
  691. "reggae.00017",
  692. "reggae.00018",
  693. "reggae.00019",
  694. "reggae.00020",
  695. "reggae.00021",
  696. "reggae.00022",
  697. "reggae.00023",
  698. "reggae.00024",
  699. "reggae.00025",
  700. "reggae.00026",
  701. "reggae.00027",
  702. "reggae.00028",
  703. "reggae.00029",
  704. "reggae.00030",
  705. "reggae.00031",
  706. "reggae.00032",
  707. "reggae.00042",
  708. "reggae.00043",
  709. "reggae.00044",
  710. "reggae.00045",
  711. "reggae.00049",
  712. "reggae.00050",
  713. "reggae.00051",
  714. "reggae.00054",
  715. "reggae.00055",
  716. "reggae.00056",
  717. "reggae.00057",
  718. "reggae.00058",
  719. "reggae.00059",
  720. "reggae.00060",
  721. "reggae.00063",
  722. "reggae.00069",
  723. "rock.00000",
  724. "rock.00001",
  725. "rock.00002",
  726. "rock.00003",
  727. "rock.00004",
  728. "rock.00005",
  729. "rock.00006",
  730. "rock.00007",
  731. "rock.00008",
  732. "rock.00009",
  733. "rock.00016",
  734. "rock.00017",
  735. "rock.00018",
  736. "rock.00019",
  737. "rock.00020",
  738. "rock.00021",
  739. "rock.00022",
  740. "rock.00023",
  741. "rock.00024",
  742. "rock.00025",
  743. "rock.00026",
  744. "rock.00057",
  745. "rock.00058",
  746. "rock.00059",
  747. "rock.00060",
  748. "rock.00061",
  749. "rock.00062",
  750. "rock.00063",
  751. "rock.00064",
  752. "rock.00065",
  753. "rock.00066",
  754. "rock.00067",
  755. "rock.00068",
  756. "rock.00069",
  757. "rock.00070",
  758. "rock.00091",
  759. "rock.00092",
  760. "rock.00093",
  761. "rock.00094",
  762. "rock.00095",
  763. "rock.00096",
  764. "rock.00097",
  765. "rock.00098",
  766. "rock.00099",
  767. ]
  768. filtered_valid = [
  769. "blues.00000",
  770. "blues.00001",
  771. "blues.00002",
  772. "blues.00003",
  773. "blues.00004",
  774. "blues.00005",
  775. "blues.00006",
  776. "blues.00007",
  777. "blues.00008",
  778. "blues.00009",
  779. "blues.00010",
  780. "blues.00011",
  781. "blues.00050",
  782. "blues.00051",
  783. "blues.00052",
  784. "blues.00053",
  785. "blues.00054",
  786. "blues.00055",
  787. "blues.00056",
  788. "blues.00057",
  789. "blues.00058",
  790. "blues.00059",
  791. "blues.00060",
  792. "classical.00000",
  793. "classical.00001",
  794. "classical.00002",
  795. "classical.00003",
  796. "classical.00004",
  797. "classical.00005",
  798. "classical.00006",
  799. "classical.00007",
  800. "classical.00008",
  801. "classical.00009",
  802. "classical.00010",
  803. "classical.00068",
  804. "classical.00069",
  805. "classical.00070",
  806. "classical.00071",
  807. "classical.00072",
  808. "classical.00073",
  809. "classical.00074",
  810. "classical.00075",
  811. "classical.00076",
  812. "country.00000",
  813. "country.00001",
  814. "country.00002",
  815. "country.00003",
  816. "country.00004",
  817. "country.00005",
  818. "country.00006",
  819. "country.00007",
  820. "country.00009",
  821. "country.00010",
  822. "country.00011",
  823. "country.00012",
  824. "country.00013",
  825. "country.00014",
  826. "country.00015",
  827. "country.00016",
  828. "country.00017",
  829. "country.00018",
  830. "country.00027",
  831. "country.00041",
  832. "country.00042",
  833. "country.00045",
  834. "country.00049",
  835. "disco.00000",
  836. "disco.00002",
  837. "disco.00003",
  838. "disco.00004",
  839. "disco.00006",
  840. "disco.00007",
  841. "disco.00008",
  842. "disco.00009",
  843. "disco.00010",
  844. "disco.00011",
  845. "disco.00012",
  846. "disco.00013",
  847. "disco.00014",
  848. "disco.00046",
  849. "disco.00048",
  850. "disco.00052",
  851. "disco.00067",
  852. "disco.00068",
  853. "disco.00072",
  854. "disco.00075",
  855. "disco.00090",
  856. "disco.00095",
  857. "hiphop.00081",
  858. "hiphop.00082",
  859. "hiphop.00083",
  860. "hiphop.00084",
  861. "hiphop.00085",
  862. "hiphop.00086",
  863. "hiphop.00087",
  864. "hiphop.00088",
  865. "hiphop.00089",
  866. "hiphop.00090",
  867. "hiphop.00091",
  868. "hiphop.00092",
  869. "hiphop.00093",
  870. "hiphop.00094",
  871. "hiphop.00095",
  872. "hiphop.00096",
  873. "hiphop.00097",
  874. "hiphop.00098",
  875. "jazz.00002",
  876. "jazz.00003",
  877. "jazz.00004",
  878. "jazz.00005",
  879. "jazz.00006",
  880. "jazz.00007",
  881. "jazz.00008",
  882. "jazz.00009",
  883. "jazz.00010",
  884. "jazz.00025",
  885. "jazz.00026",
  886. "jazz.00027",
  887. "jazz.00028",
  888. "jazz.00029",
  889. "jazz.00030",
  890. "jazz.00031",
  891. "jazz.00032",
  892. "metal.00000",
  893. "metal.00001",
  894. "metal.00006",
  895. "metal.00007",
  896. "metal.00008",
  897. "metal.00009",
  898. "metal.00010",
  899. "metal.00011",
  900. "metal.00016",
  901. "metal.00017",
  902. "metal.00018",
  903. "metal.00019",
  904. "metal.00020",
  905. "metal.00036",
  906. "metal.00037",
  907. "metal.00068",
  908. "metal.00076",
  909. "metal.00077",
  910. "metal.00081",
  911. "metal.00082",
  912. "pop.00010",
  913. "pop.00053",
  914. "pop.00055",
  915. "pop.00058",
  916. "pop.00059",
  917. "pop.00060",
  918. "pop.00061",
  919. "pop.00062",
  920. "pop.00081",
  921. "pop.00083",
  922. "pop.00084",
  923. "pop.00085",
  924. "pop.00086",
  925. "reggae.00061",
  926. "reggae.00062",
  927. "reggae.00070",
  928. "reggae.00072",
  929. "reggae.00074",
  930. "reggae.00076",
  931. "reggae.00077",
  932. "reggae.00078",
  933. "reggae.00085",
  934. "reggae.00092",
  935. "reggae.00093",
  936. "reggae.00094",
  937. "reggae.00095",
  938. "reggae.00096",
  939. "reggae.00097",
  940. "reggae.00098",
  941. "reggae.00099",
  942. "rock.00038",
  943. "rock.00049",
  944. "rock.00050",
  945. "rock.00051",
  946. "rock.00052",
  947. "rock.00053",
  948. "rock.00054",
  949. "rock.00055",
  950. "rock.00056",
  951. "rock.00071",
  952. "rock.00072",
  953. "rock.00073",
  954. "rock.00074",
  955. "rock.00075",
  956. "rock.00076",
  957. "rock.00077",
  958. "rock.00078",
  959. "rock.00079",
  960. "rock.00080",
  961. "rock.00081",
  962. "rock.00082",
  963. "rock.00083",
  964. "rock.00084",
  965. "rock.00085",
  966. ]
  967. URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
  968. FOLDER_IN_ARCHIVE = "genres"
  969. _CHECKSUMS = {
  970. "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
  971. }
  972. def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
  973. """
  974. Loads a file from the dataset and returns the raw waveform
  975. as a Torch Tensor, its sample rate as an integer, and its
  976. genre as a string.
  977. """
  978. # Filenames are of the form label.id, e.g. blues.00078
  979. label, _ = fileid.split(".")
  980. # Read wav
  981. file_audio = os.path.join(path, label, fileid + ext_audio)
  982. waveform, sample_rate = torchaudio.load(file_audio)
  983. return waveform, sample_rate, label
  984. class GTZAN(Dataset):
  985. """Create a Dataset for *GTZAN* [:footcite:`tzanetakis_essl_cook_2001`].
  986. Note:
  987. Please see http://marsyas.info/downloads/datasets.html if you are planning to use
  988. this dataset to publish results.
  989. Args:
  990. root (str or Path): Path to the directory where the dataset is found or downloaded.
  991. url (str, optional): The URL to download the dataset from.
  992. (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
  993. folder_in_archive (str, optional): The top-level directory of the dataset.
  994. download (bool, optional):
  995. Whether to download the dataset if it is not found at root path. (default: ``False``).
  996. subset (str or None, optional): Which subset of the dataset to use.
  997. One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
  998. If ``None``, the entire dataset is used. (default: ``None``).
  999. """
  1000. _ext_audio = ".wav"
  1001. def __init__(
  1002. self,
  1003. root: Union[str, Path],
  1004. url: str = URL,
  1005. folder_in_archive: str = FOLDER_IN_ARCHIVE,
  1006. download: bool = False,
  1007. subset: Optional[str] = None,
  1008. ) -> None:
  1009. # super(GTZAN, self).__init__()
  1010. # Get string representation of 'root' in case Path object is passed
  1011. root = os.fspath(root)
  1012. self.root = root
  1013. self.url = url
  1014. self.folder_in_archive = folder_in_archive
  1015. self.download = download
  1016. self.subset = subset
  1017. assert subset is None or subset in ["training", "validation", "testing"], (
  1018. "When `subset` not None, it must take a value from " + "{'training', 'validation', 'testing'}."
  1019. )
  1020. archive = os.path.basename(url)
  1021. archive = os.path.join(root, archive)
  1022. self._path = os.path.join(root, folder_in_archive)
  1023. if download:
  1024. if not os.path.isdir(self._path):
  1025. if not os.path.isfile(archive):
  1026. checksum = _CHECKSUMS.get(url, None)
  1027. download_url_to_file(url, archive, hash_prefix=checksum)
  1028. extract_archive(archive)
  1029. if not os.path.isdir(self._path):
  1030. raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
  1031. if self.subset is None:
  1032. # Check every subdirectory under dataset root
  1033. # which has the same name as the genres in
  1034. # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
  1035. # This lets users remove or move around song files,
  1036. # useful when e.g. they want to use only some of the files
  1037. # in a genre or want to label other files with a different
  1038. # genre.
  1039. self._walker = []
  1040. root = os.path.expanduser(self._path)
  1041. for directory in gtzan_genres:
  1042. fulldir = os.path.join(root, directory)
  1043. if not os.path.exists(fulldir):
  1044. continue
  1045. songs_in_genre = os.listdir(fulldir)
  1046. songs_in_genre.sort()
  1047. for fname in songs_in_genre:
  1048. name, ext = os.path.splitext(fname)
  1049. if ext.lower() == ".wav" and "." in name:
  1050. # Check whether the file is of the form
  1051. # `gtzan_genre`.`5 digit number`.wav
  1052. genre, num = name.split(".")
  1053. if genre in gtzan_genres and len(num) == 5 and num.isdigit():
  1054. self._walker.append(name)
  1055. else:
  1056. if self.subset == "training":
  1057. self._walker = filtered_train
  1058. elif self.subset == "validation":
  1059. self._walker = filtered_valid
  1060. elif self.subset == "testing":
  1061. self._walker = filtered_test
  1062. def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
  1063. """Load the n-th sample from the dataset.
  1064. Args:
  1065. n (int): The index of the sample to be loaded
  1066. Returns:
  1067. (Tensor, int, str): ``(waveform, sample_rate, label)``
  1068. """
  1069. fileid = self._walker[n]
  1070. item = load_gtzan_item(fileid, self._path, self._ext_audio)
  1071. waveform, sample_rate, label = item
  1072. return waveform, sample_rate, label
  1073. def __len__(self) -> int:
  1074. return len(self._walker)