symbolic_opset9.py 180 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205
  1. """This file exports ONNX ops for opset 9.
  2. Opset 9 is supported by ONNX release 1.4.1
  3. release on 01/23/19
  4. """
  5. import functools
  6. import math
  7. import sys
  8. import warnings
  9. from typing import List, Optional, Tuple, Union
  10. import torch
  11. import torch._C._onnx as _C_onnx
  12. import torch.nn.modules.utils
  13. import torch.onnx
  14. from torch import _C
  15. # This import monkey-patches graph manipulation methods on Graph, used for the
  16. # ONNX symbolics
  17. from torch.onnx import _patch_torch # noqa: F401
  18. from torch.onnx import symbolic_helper
  19. from torch.onnx._globals import GLOBALS
  20. # EDITING THIS FILE? READ THIS FIRST!
  21. # see Note [Edit Symbolic Files] in symbolic_helper.py
# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)?  There are
# some moving parts to implementing the ONNX translation in this case:
#
#   - By the time we get the scalar in a symbolic function here, it is no longer
#     a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
#     want it to be a zero dim tensor but this change has not happened yet.)
#     However, the type of this scalar is *exactly* what the user wrote in
#     Python, which may not match the tensor it is being added to.  PyTorch
#     will do implicit conversions on scalars; however, ONNX will not, so
#     we must do the conversion ourselves.  This is what _if_scalar_type_as
#     does.
#
#   - Dispatch to these functions takes advantage of an outrageous coincidence
#     between the tensor and scalar name.  When we add two tensors together,
#     you get the dispatch:
#
#       add(*[self, other], **{"alpha": alpha})
#
#     When you add a tensor and a scalar, you get the dispatch:
#
#       add(*[self], **{"other": other, "alpha": alpha})
#
#     By having the argument name line up with the name of the scalar attribute
#     if it exists, we can write a single function for both overloads.
#
  49. # used to represent "missing" optional inputs
  50. def unused(g):
  51. n = g.op("prim::Constant")
  52. n.setType(_C.OptionalType.ofTensor())
  53. return n
  54. def _shape_as_tensor(g, input):
  55. return g.op("Shape", input)
  56. def _reshape_from_tensor(g, input, shape):
  57. if isinstance(shape, list):
  58. shape = g.op("Concat", *shape, axis_i=0)
  59. return reshape(g, input, shape)
  60. def reshape(g, self, shape):
  61. return symbolic_helper._reshape_helper(g, self, shape)
  62. def reshape_as(g, self, other):
  63. shape = g.op("Shape", other)
  64. return reshape(g, self, shape)
  65. def add(g, self, other, alpha=None):
  66. if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
  67. return symbolic_helper._onnx_opset_unsupported_detailed(
  68. "Add", 9, 11, "Add between list of tensors not supported"
  69. )
  70. # default alpha arg is to allow no-alpha add (aten add st overload no alpha)
  71. if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
  72. return symbolic_helper._unimplemented("add", "alpha != 1")
  73. return g.op("Add", self, other)
  74. def sub(g, self, other, alpha=None):
  75. # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha)
  76. if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
  77. return symbolic_helper._unimplemented("sub", "alpha != 1")
  78. return g.op("Sub", self, other)
  79. def rsub(g, self, other, alpha=None):
  80. return sub(g, other, self, alpha=alpha)
  81. def mul(g, self, other):
  82. return g.op("Mul", self, other)
  83. def div(g, self, other, *args):
  84. if len(args) == 0:
  85. return true_divide(g, self, other)
  86. else:
  87. return _div_rounding_mode(g, self, other, *args)
  88. @symbolic_helper.parse_args("v", "v", "v", "f")
  89. def addcmul(g, self, tensor1, tensor2, value=1.0):
  90. value_tens = g.op("Constant", value_t=torch.tensor([value]))
  91. return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens))
  92. @symbolic_helper.parse_args("v", "v", "s")
  93. def _div_rounding_mode(g, self, other, rounding_mode):
  94. if rounding_mode is None:
  95. return true_divide(g, self, other)
  96. elif rounding_mode == "floor":
  97. return _floor_divide(g, self, other)
  98. elif rounding_mode == "trunc":
  99. return _trunc_divide(g, self, other)
  100. else:
  101. raise RuntimeError(
  102. f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"'
  103. )
  104. def _trunc_divide(g, self, other):
  105. out = g.op("Div", self, other)
  106. # the correct operation is truncate, which is not supported in ONNX,
  107. # we cannot call floor since it will behave differently for negative numbers
  108. # (eg. -0.1 should become -0 )
  109. # - if scalar_type information are not available, assume that
  110. # we need to call floor (treat as float)
  111. out = g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
  112. # Matching PyTorch's behavior:
  113. # - if self is fp the output's type is self's type
  114. # - if self is not fp and other is fp, the output is of type "Float"
  115. # - self is not fp and other is not fp, the output's type is self's output type
  116. # - the output type defaults to Float
  117. scalar_type = self.type().scalarType()
  118. if scalar_type is not None:
  119. if (
  120. not symbolic_helper._is_fp(self)
  121. and other.type().scalarType() is not None
  122. and symbolic_helper._is_fp(other)
  123. ):
  124. out = g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  125. else:
  126. out = g.op(
  127. "Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx[scalar_type]
  128. )
  129. else:
  130. out = g.op("Cast", out, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  131. return out
  132. def _floor_divide(g, self, other):
  133. if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
  134. out = true_divide(g, self, other)
  135. return g.op("Floor", out)
  136. else:
  137. # Integer division does trunction rounding
  138. div = g.op("Div", self, other)
  139. # Division is negative if: self < 0 != other < 0
  140. zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
  141. negative = g.op(
  142. "Xor",
  143. symbolic_helper._lt_helper(g, self, zero),
  144. symbolic_helper._lt_helper(g, other, zero),
  145. )
  146. # For negative numbers with self % other != 0, subtract 1 to round down instead of up
  147. mod = g.op("Sub", self, g.op("Mul", div, other))
  148. fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
  149. one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
  150. fixup = g.op("Mul", fixup_mask, one)
  151. return g.op("Sub", div, fixup)
  152. def floor_divide(g, self, other):
  153. # Deprecated behavior, floor_divide actually truncates
  154. return _trunc_divide(g, self, other)
  155. def floordiv(g, self, other):
  156. return floor_divide(g, self, other)
  157. def true_divide(g, self, other):
  158. """Division where both inputs are cast to floating types
  159. If both inputs are floating, performs div as usual
  160. If only one input is a floating type, the other input is cast to its type
  161. If neither input is a floating type, both inputs are cast to the default scalar type
  162. """
  163. # Case 1: either values are floating
  164. # Performs div as usual.
  165. # Implicit casting will be handled in scalar type analysis pass.
  166. if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
  167. return g.op("Div", self, other)
  168. # Case 2: neither is floating
  169. # Casts both inputs to the default scalar type
  170. scalar_type = torch.get_default_dtype()
  171. onnx_scalar_type = symbolic_helper.cast_pytorch_to_onnx["Float"]
  172. assert scalar_type is torch.float or scalar_type is torch.double
  173. if torch.get_default_dtype() is torch.double:
  174. onnx_scalar_type = symbolic_helper.cast_pytorch_to_onnx["Double"]
  175. self = g.op("Cast", self, to_i=onnx_scalar_type)
  176. other = g.op("Cast", other, to_i=onnx_scalar_type)
  177. return g.op("Div", self, other)
  178. def reciprocal(g, self):
  179. # torch.reciprocal implicitly casts to float, so we do the same.
  180. if not symbolic_helper._is_fp(self):
  181. self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  182. return g.op("Reciprocal", self)
  183. @symbolic_helper.parse_args("v", "i")
  184. def cat(g, tensor_list, dim):
  185. tensors = symbolic_helper._unpack_list(tensor_list)
  186. return g.op("Concat", *tensors, axis_i=dim)
  187. @symbolic_helper.parse_args("v", "i")
  188. def stack(g, tensor_list, dim):
  189. unsqueezed = [
  190. symbolic_helper._unsqueeze_helper(g, t, [dim])
  191. for t in symbolic_helper._unpack_list(tensor_list)
  192. ]
  193. return g.op("Concat", *unsqueezed, axis_i=dim)
  194. def _list(g, self):
  195. return self
  196. def mm(g, self, other):
  197. # Create a dummy C tensor. Only needed for API purposes, the value is
  198. # since beta = 0
  199. C = g.op("Constant", value_t=torch.tensor([1]))
  200. return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
  201. def bmm(g, self, other):
  202. return g.op("MatMul", self, other)
  203. def matmul(g, self, other):
  204. return g.op("MatMul", self, other)
  205. @symbolic_helper.parse_args("v", "v", "v", "t", "t")
  206. def addmm(g, self, mat1, mat2, beta, alpha):
  207. dtype = None
  208. self_dtype = symbolic_helper._try_get_scalar_type(self)
  209. mat1_dtype = symbolic_helper._try_get_scalar_type(mat1)
  210. mat2_dtype = symbolic_helper._try_get_scalar_type(mat2)
  211. if self_dtype is not None:
  212. dtype = self_dtype
  213. elif mat1_dtype is not None:
  214. dtype = mat1_dtype
  215. elif mat2_dtype is not None:
  216. dtype = mat2_dtype
  217. mat1_rank = symbolic_helper._get_tensor_rank(mat1)
  218. mat2_rank = symbolic_helper._get_tensor_rank(mat2)
  219. def isNotNoneAnd(v, u):
  220. return v is not None and v != u
  221. if dtype is not None and (isNotNoneAnd(mat1_rank, 2) or isNotNoneAnd(mat2_rank, 2)):
  222. dtype = symbolic_helper.scalar_type_to_onnx.index(
  223. symbolic_helper.cast_pytorch_to_onnx[dtype]
  224. )
  225. dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]
  226. res1 = g.op("MatMul", mat1, mat2)
  227. res2 = self
  228. alpha = symbolic_helper._scalar(alpha)
  229. beta = symbolic_helper._scalar(beta)
  230. if alpha != 1:
  231. alpha = g.op("Constant", value_t=torch.tensor(alpha, dtype=dtype))
  232. res1 = g.op("Mul", res1, alpha)
  233. if beta != 1:
  234. beta = g.op(
  235. "Constant",
  236. value_t=torch.tensor(symbolic_helper._scalar(beta), dtype=dtype),
  237. )
  238. res2 = g.op("Mul", res2, beta)
  239. return g.op("Add", res1, res2)
  240. return g.op(
  241. "Gemm",
  242. mat1,
  243. mat2,
  244. self,
  245. beta_f=symbolic_helper._scalar(beta),
  246. alpha_f=symbolic_helper._scalar(alpha),
  247. )
  248. def neg(g, self):
  249. return g.op("Neg", self)
  250. def sqrt(g, self):
  251. return g.op("Sqrt", self)
  252. def rsqrt(g, self):
  253. return g.op(
  254. "Div", symbolic_helper._if_scalar_type_as(g, torch.ones(1), self), sqrt(g, self)
  255. )
  256. def tanh(g, self):
  257. return g.op("Tanh", self)
  258. def sin(g, self):
  259. return g.op("Sin", self)
  260. def cos(g, self):
  261. return g.op("Cos", self)
  262. def tan(g, self):
  263. return g.op("Tan", self)
  264. def asin(g, self):
  265. return g.op("Asin", self)
  266. def acos(g, self):
  267. return g.op("Acos", self)
  268. def atan(g, self):
  269. return g.op("Atan", self)
  270. # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
  271. @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
  272. def sigmoid(g, self):
  273. return g.op("Sigmoid", self)
  274. def sign(g, self):
  275. return g.op("Sign", self)
  276. def _slice(g, input, axes, starts, ends):
  277. assert len(starts) == len(ends)
  278. if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
  279. return input
  280. return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
  281. def _maybe_cast_reduce_op_input(g, self):
  282. dtype = self.type().scalarType()
  283. # This check only covers traced modules where dtype is present
  284. if dtype is not None:
  285. # pytorch reduce-ops cast all other integral types to int64
  286. if not symbolic_helper._is_fp(self) and not (dtype == "Long"):
  287. self = _cast_Long(g, self, False) # type: ignore[name-defined]
  288. return self
  289. def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
  290. def symbolic(g, self, dim=None, keepdim=None):
  291. self = _maybe_cast_reduce_op_input(g, self)
  292. if dim is None:
  293. # all-reduce path
  294. return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
  295. else:
  296. # dim-reduce path
  297. desc = "is" if allow_multi_dim_support else "i"
  298. dim, keepdim = symbolic_helper._get_const(
  299. dim, desc, "dim"
  300. ), symbolic_helper._get_const(keepdim, "i", "keepdim")
  301. dim_list = dim if allow_multi_dim_support else [dim]
  302. return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
  303. return symbolic
  304. def overload_by_arg_count(fn):
  305. @functools.wraps(fn)
  306. def wrapper(g, *args):
  307. overloads = fn(g, *args)
  308. last_exception = None
  309. for overload in overloads:
  310. arg_descriptors = overload._arg_descriptors
  311. if len(arg_descriptors) == len(args):
  312. return overload(g, *args)
  313. raise NotImplementedError("Unknown aten::{} signature".format(fn.__name__))
  314. return wrapper
  315. def _reduce_with_dtype(onnx_op, name, allow_multi_dim_support=True):
  316. symbolic = _reduce_op_symbolic(
  317. onnx_op, allow_multi_dim_support=allow_multi_dim_support
  318. )
  319. @overload_by_arg_count
  320. def reduce(g, *args, **kwargs):
  321. @symbolic_helper.parse_args("v", "none")
  322. def reduce_nodim(g, self, dtype):
  323. if dtype.node().kind() == "onnx::Constant":
  324. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  325. self = g.op(
  326. "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
  327. )
  328. elif dtype.node().kind() != "prim::Constant":
  329. return symbolic_helper._unimplemented(name, "dtype")
  330. return symbolic(g, self)
  331. dim_desc = "is" if allow_multi_dim_support else "i"
  332. @symbolic_helper.parse_args("v", dim_desc, "i", "none")
  333. def reduce_dim(g, self, dim, keepdim, dtype):
  334. if dtype.node().kind() == "onnx::Constant":
  335. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  336. self = g.op(
  337. "Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
  338. )
  339. elif dtype.node().kind() != "prim::Constant":
  340. return symbolic_helper._unimplemented(name, "dtype")
  341. return symbolic(g, self, dim, keepdim)
  342. return reduce_nodim, reduce_dim
  343. return reduce
  344. sum = _reduce_with_dtype("ReduceSum", "sum")
  345. mean = _reduce_with_dtype("ReduceMean", "mean")
  346. # torch.prod does not support multidimensional "dim"
  347. prod = _reduce_with_dtype("ReduceProd", "prod", allow_multi_dim_support=False)
  348. @symbolic_helper.parse_args("v", "i", "none")
  349. def cumsum(g, input, dim, dtype):
  350. if symbolic_helper.is_caffe2_aten_fallback():
  351. if dtype.node().kind() != "prim::Constant":
  352. return symbolic_helper._unimplemented(name, "dtype")
  353. return g.at("cumsum", input, dim_i=dim)
  354. else:
  355. symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11)
  356. def _sample_dirichlet(g, self, generator):
  357. if symbolic_helper.is_caffe2_aten_fallback():
  358. if not symbolic_helper._is_none(generator):
  359. return symbolic_helper._unimplemented(
  360. "_sample_dirichlet", "We are not able to export generator"
  361. )
  362. return g.at("_sample_dirichlet", self)
  363. else:
  364. return symbolic_helper._onnx_unsupported("_sample_dirichlet")
  365. def _standard_gamma(g, self, generator):
  366. if symbolic_helper.is_caffe2_aten_fallback():
  367. if not symbolic_helper._is_none(generator):
  368. return symbolic_helper._unimplemented(
  369. "_standard_gamma", "We are not able to export generator"
  370. )
  371. return g.at("_standard_gamma", self)
  372. else:
  373. return symbolic_helper._onnx_unsupported("_standard_gamma")
  374. def t(g, self):
  375. return g.op("Transpose", self, perm_i=(1, 0))
  376. def expand(g, self, size, implicit):
  377. size = symbolic_helper._maybe_get_const(size, "is")
  378. if not symbolic_helper._is_value(size):
  379. size = g.op("Constant", value_t=torch.LongTensor(size))
  380. elif symbolic_helper._is_packed_list(size):
  381. # Expand with -1 dim value means dim is unchanged.
  382. # Since onnx::expand supports two-way broadcasting,
  383. # -1 dim value can be exported to onnx as 1
  384. size = symbolic_helper._reshape_helper(
  385. g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
  386. )
  387. dtype = symbolic_helper.ScalarType.INT64
  388. ones = ones_like(g, size, dtype)
  389. neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
  390. size = where(g, g.op("Equal", size, neg_ones), ones, size)
  391. return g.op("Expand", self, size)
def expand_as(g, self, other):
    """Emit an ONNX ``Expand`` of ``self`` to the ``Shape`` of ``other``.

    When ``self`` resolves to a constant ``torch.Tensor``, every dimension along
    which the tensor equals its own mean (i.e. is constant along that
    dimension) is collected, and ``self`` is replaced by a ``Constant`` holding
    the mean over the collected dimensions, cast back to the original dtype.
    """
    self_t = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(self_t, torch.Tensor):
        orig_type = self_t.dtype
        # Means are computed in double precision and cast back afterwards.
        self_t = self_t.to(torch.double)
        dims = []
        for d in range(self_t.dim()):
            # True when broadcasting the mean along d reproduces the tensor,
            # i.e. all entries along d are identical.
            if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t):
                dims.append(d)
                # NOTE(review): the nesting of this line is reconstructed; the
                # listing this was taken from lost its indentation -- confirm.
                self = g.op("Constant", value_t=self_t.mean(dims).to(orig_type))
    shape = g.op("Shape", other)
    return g.op("Expand", self, shape)
  404. @symbolic_helper.parse_args("v", "v", "i", "b", "v")
  405. def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
  406. if scale_grad_by_freq and GLOBALS.training_mode:
  407. raise RuntimeError(
  408. "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
  409. "for training mode. ONNX does not support scaling the gradients."
  410. )
  411. if padding_idx >= 0 and GLOBALS.training_mode:
  412. warnings.warn(
  413. "Warning: ONNX export of embedding with padding_idx >= 0 "
  414. "for training mode. "
  415. "ONNX does not support not updating the embedding vector at padding_idx during training."
  416. )
  417. return g.op("Gather", weight, indices)
  418. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
  419. def embedding_bag(
  420. g,
  421. embedding_matrix,
  422. indices,
  423. offsets,
  424. scale_grad_by_freq,
  425. mode,
  426. sparse,
  427. per_sample_weights,
  428. include_last_offset,
  429. padding_idx,
  430. ):
  431. if not symbolic_helper._is_none(per_sample_weights):
  432. return symbolic_helper._onnx_unsupported(
  433. "embedding_bag with per_sample_weights"
  434. )
  435. if symbolic_helper.is_caffe2_aten_fallback():
  436. return g.at(
  437. "embedding_bag",
  438. embedding_matrix,
  439. indices,
  440. offsets,
  441. outputs=4,
  442. scale_grad_by_freq_i=scale_grad_by_freq,
  443. mode_i=mode,
  444. sparse_i=sparse,
  445. include_last_offset_i=include_last_offset,
  446. padding_idx_i=padding_idx,
  447. )
  448. else:
  449. return symbolic_helper._onnx_unsupported("embedding_bag")
  450. def size(g, self, dim=None):
  451. if dim is None:
  452. return g.op("Shape", self)
  453. if symbolic_helper._maybe_get_const(dim, "i") < 0:
  454. rank = symbolic_helper._get_tensor_rank(self)
  455. if rank is not None:
  456. dim = symbolic_helper._maybe_get_const(dim, "i") + rank
  457. dim = g.op("Constant", value_t=torch.tensor(dim))
  458. return symbolic_helper._size_helper(g, self, dim)
  459. @symbolic_helper.parse_args("v", "i", "i")
  460. def transpose(g, self, dim0, dim1):
  461. if dim0 == dim1: # micro-optimization
  462. return self
  463. # NB: Transpose in ONNX is actually a Permute
  464. rank = symbolic_helper._get_tensor_rank(self)
  465. if rank is not None:
  466. axes = list(range(rank))
  467. axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
  468. return g.op("Transpose", self, perm_i=axes)
  469. else:
  470. # if we don't have dim information we cannot
  471. # output a permute so use ATen instead
  472. if symbolic_helper.is_caffe2_aten_fallback():
  473. return g.at(
  474. "transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1
  475. )
  476. else:
  477. raise RuntimeError(
  478. "Unsupported: ONNX export of transpose for tensor " "of unknown rank."
  479. )
  480. @symbolic_helper.parse_args("v", "is")
  481. def permute(g, self, dims):
  482. if dims == list(range(0, len(dims))):
  483. return self
  484. return g.op("Transpose", self, perm_i=dims)
  485. def view(g, self, size):
  486. return reshape(g, self, size)
  487. def view_as(g, self, other):
  488. shape = g.op("Shape", other)
  489. return reshape(g, self, shape)
  490. @symbolic_helper.parse_args("v", "i", "i", "i")
  491. def unsafe_chunk(g, self, chunks, dim, _outputs=None):
  492. if _outputs is None:
  493. return symbolic_helper._onnx_opset_unsupported_detailed(
  494. "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported"
  495. )
  496. size = symbolic_helper._get_tensor_dim_size(self, dim)
  497. if size is None:
  498. return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
  499. split_size = (size + chunks - 1) // chunks
  500. splits = [split_size] * (size // split_size)
  501. leftover = size % split_size
  502. if leftover:
  503. splits.append(leftover)
  504. return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)
  505. @symbolic_helper.parse_args("v", "v", "v", "i")
  506. def split(g, self, split_size_or_sizes, dim, _outputs=None):
  507. if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
  508. return symbolic_helper._onnx_opset_unsupported_detailed(
  509. "split", 9, 11, "Dynamic number of outputs not supported"
  510. )
  511. split_val = split_size_or_sizes.node()["value"]
  512. if split_val.dim() > 0:
  513. return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)
  514. split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
  515. dim = symbolic_helper._get_const(dim, "i", "dim")
  516. size = symbolic_helper._get_tensor_dim_size(self, dim)
  517. if size is None:
  518. if _outputs is not None:
  519. size = split_size * _outputs
  520. else:
  521. return symbolic_helper._onnx_opset_unsupported_detailed(
  522. "split", 9, 11, "Unknown dimension size not supported"
  523. )
  524. splits = [split_size] * (size // split_size)
  525. leftover = size % split_size
  526. if leftover:
  527. splits.append(leftover)
  528. return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)
  529. def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
  530. return split(g, self, split_size_or_sizes, dim, _outputs)
  531. @symbolic_helper.parse_args("v", "is", "i", "i")
  532. def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
  533. if not symbolic_helper._is_split_static(split_sizes, _outputs):
  534. return symbolic_helper._onnx_opset_unsupported_detailed(
  535. "split_with_sizes", 9, 11, "Dynamic number of outputs not supported"
  536. )
  537. return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)
  538. def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
  539. return split_with_sizes(g, self, split_sizes, dim, _outputs)
  540. @symbolic_helper.parse_args("v", "i", "i")
  541. def unbind(g, self, dim=0, _outputs=None):
  542. if _outputs is None:
  543. return symbolic_helper._onnx_opset_unsupported_detailed(
  544. "unbind", 9, 11, "Dynamic number of outputs not supported"
  545. )
  546. outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
  547. outputs = [outputs] if _outputs == 1 else outputs
  548. squeezed_outputs = [
  549. symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs
  550. ]
  551. return squeezed_outputs
  552. @symbolic_helper.parse_args("v", "i", "v")
  553. def select(g, self, dim, index):
  554. index = symbolic_helper._maybe_get_scalar(index)
  555. if (not symbolic_helper._is_value(index)) and (index < 0):
  556. if index == -1:
  557. end_index = 9223372036854775807
  558. else:
  559. end_index = index + 1
  560. slice_node = symbolic_helper._slice_helper(
  561. g, self, axes=[dim], starts=[index], ends=[end_index]
  562. )
  563. return symbolic_helper._squeeze_helper(g, slice_node, [dim])
  564. else:
  565. return g.op("Gather", self, index, axis_i=dim)
  566. def square(g, self):
  567. return g.op("Mul", self, self)
  568. def squeeze(g, self, dim=None):
  569. if dim is None:
  570. return g.op("Squeeze", self)
  571. squeeze_dim = symbolic_helper._get_const(dim, "i", "dim")
  572. # Handle negative dims
  573. if squeeze_dim < 0:
  574. rank = symbolic_helper._get_tensor_rank(self)
  575. if rank is not None:
  576. warnings.warn(
  577. "ONNX export squeeze with negative axis "
  578. + str(squeeze_dim)
  579. + " might cause the onnx model to be incorrect. "
  580. + "Negative axis is not supported in ONNX. "
  581. + "Axis is converted to "
  582. + str(squeeze_dim + rank)
  583. + " based on input shape at export time. "
  584. + "Passing an tensor of different rank in execution will be incorrect."
  585. )
  586. squeeze_dim += rank
  587. else:
  588. return symbolic_helper._unimplemented(
  589. "squeeze", "negative axis with unknown input rank"
  590. )
  591. dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim)
  592. if dim_size is None:
  593. warnings.warn(
  594. "This model contains a squeeze operation on dimension "
  595. + str(squeeze_dim)
  596. + " on an input "
  597. + "with unknown shape. Note that if the size of dimension "
  598. + str(squeeze_dim)
  599. + " of the input "
  600. + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
  601. + "non-singleton dimensions, it is recommended to export this model using opset "
  602. + "version 11 or higher."
  603. )
  604. return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
  605. if dim_size > 1:
  606. warnings.warn(
  607. "This model contains a squeeze operation on dimension "
  608. + str(squeeze_dim)
  609. + ". The size of "
  610. + "this dimension in the given input is "
  611. + str(dim_size)
  612. + ". The model will "
  613. + "be exported without the squeeze node. If the model is intended to be used with dynamic "
  614. + "input shapes, please use opset version 11 to "
  615. + "export the model."
  616. )
  617. return self
  618. warnings.warn(
  619. "This model contains a squeeze operation on dimension "
  620. + str(squeeze_dim)
  621. + ". If the model is "
  622. + "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
  623. )
  624. return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
  625. def prelu(g, self, weight):
  626. self_rank = symbolic_helper._get_tensor_rank(self)
  627. if self_rank is not None:
  628. if self_rank > 2:
  629. # make weight unidirectional broadcastable
  630. weight = symbolic_helper._unsqueeze_helper(
  631. g, weight, list(range(1, self_rank - 1))
  632. )
  633. elif self_rank == 0:
  634. # weight is always rank 1. torch allows scalar self, and ONNX is ambiguous
  635. # about whether this is allowed, but some implementations enforce
  636. # rank(self) >= rank(weight), which makes sense.
  637. self = symbolic_helper._unsqueeze_helper(g, self, [0])
  638. self_rank = 1
  639. weight_rank = symbolic_helper._get_tensor_rank(weight)
  640. if self_rank is not None and weight_rank is not None:
  641. assert (
  642. self_rank >= weight_rank
  643. ), "rank(x) should be >= rank(slope) but got {} < {}".format(
  644. self_rank, weight_rank
  645. )
  646. return g.op("PRelu", self, weight)
  647. def silu(g, input):
  648. return g.op("Mul", input, g.op("Sigmoid", input))
  649. def mish(g, input):
  650. return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))
  651. def op_with_optional_float_cast(g, op_name, *args, **kwargs):
  652. """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
  653. This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
  654. operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
  655. `Clip<int>(INPUT)` (opset version < 12).
  656. Args:
  657. g (torch._C.Graph): graph to write the ONNX representation into.
  658. op_name (str): operator name in ONNX.
  659. *args (tuple): operands to the operator.
  660. **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
  661. indicating the smallest opset version to trigger such casting behavior and "target_float_t"
  662. (optional, "Float" by default) indicating the data type of internal operator.
  663. Returns:
  664. Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator.
  665. """
  666. opset_before = kwargs.pop("opset_before", None)
  667. target_float_t = kwargs.pop("target_float_t", "Float")
  668. inputs = list(args)
  669. dtype_0 = inputs[0].type().scalarType()
  670. require_cast = not symbolic_helper._is_fp(inputs[0]) and (
  671. opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
  672. )
  673. if require_cast:
  674. for input in inputs:
  675. if input.isCompleteTensor() and input.type().scalarType() != dtype_0:
  676. raise RuntimeError(
  677. f"Inputs of {op_name} must have same dtype. Got {dtype_0} and {input.type().scalarType()}"
  678. )
  679. for i, input in enumerate(inputs):
  680. if input.isCompleteTensor() and not symbolic_helper._is_fp(input):
  681. inputs[i] = g.op(
  682. "Cast",
  683. input,
  684. to_i=symbolic_helper.cast_pytorch_to_onnx[target_float_t],
  685. )
  686. self = g.op(op_name, *inputs, **kwargs)
  687. if require_cast:
  688. self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype_0])
  689. return self
  690. @symbolic_helper.quantized_args(True)
  691. def relu(g, input):
  692. return op_with_optional_float_cast(g, "Relu", input, opset_before=14)
  693. @symbolic_helper.quantized_args(True)
  694. def relu6(g, input):
  695. relu = op_with_optional_float_cast(g, "Relu", input, opset_before=14)
  696. return clamp_max(g, relu, 6)
  697. def ceil(g, input):
  698. return g.op("Ceil", input)
  699. def floor(g, input):
  700. return g.op("Floor", input)
  701. def _len(g, self):
  702. sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
  703. return symbolic_helper._squeeze_helper(g, sz_0, [0])
  704. @symbolic_helper.parse_args("v", "t", "t")
  705. def threshold(g, self, threshold, value):
  706. # See Note [Export inplace]
  707. if symbolic_helper._scalar(threshold) != 0:
  708. return symbolic_helper._unimplemented("threshold", "non-zero threshold")
  709. if symbolic_helper._scalar(value) != 0:
  710. return symbolic_helper._unimplemented("threshold", "non-zero value")
  711. return g.op("Relu", self)
  712. def leaky_relu(g, input, negative_slope, inplace=False):
  713. negative_slope = symbolic_helper._get_const(negative_slope, "t", "negative_slope")
  714. # See Note [Export inplace]
  715. # TODO: Talk to ONNX about unconditional cast of scalar to float
  716. return g.op("LeakyRelu", input, alpha_f=symbolic_helper._scalar(negative_slope))
  717. @symbolic_helper.parse_args("v", "i")
  718. def glu(g, input, dim):
  719. dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
  720. if dim_size is not None:
  721. assert dim_size % 2 == 0
  722. first, second = g.op("Split", input, axis_i=dim, outputs=2)
  723. return g.op("Mul", first, g.op("Sigmoid", second))
  724. @symbolic_helper.parse_args("v", "i", "none")
  725. def softmax(g, input, dim, dtype=None):
  726. # Softmax does normalization at vector level.
  727. # PyTorch and ONNX use different strategies to split the input tensor into vectors.
  728. # Thus dim and axis have different meanings.
  729. # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
  730. # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
  731. # If input is a 2 x 3 tensor:
  732. # input = [[1.0, 1.0, 1.0],
  733. # [1.0, 1,0, 1,0]]
  734. # with dim = 0, the result is:
  735. # result = [[0.5, 0.5, 0.5],
  736. # [0.5, 0.5, 0.5]]
  737. # with axis = 0, the result is:
  738. # result = [[0.167, 0.167, 0.167],
  739. # [0.167, 0.167, 0.167]]
  740. # So only when dim and axis both equal to ndim - 1 (the last dimension),
  741. # their semantics are equivalent.
  742. # So use softmax when dim and axis both equal to ndim - 1,
  743. # otherwise transpose the input to put the vectors to be normalized to the last dimension.
  744. # When input rank is not known at export time we compute softmax using a subgraph
  745. # with other operators
  746. input_dim = symbolic_helper._get_tensor_rank(input)
  747. if input_dim is not None:
  748. # TODO: remove this as onnx opset 11 spec allows negative axes
  749. if dim < 0:
  750. dim = input_dim + dim
  751. is_transpose_required = input_dim != dim + 1
  752. if is_transpose_required:
  753. axes = list(range(input_dim))
  754. axes[dim], axes[-1] = axes[-1], axes[dim]
  755. input = g.op("Transpose", input, perm_i=axes)
  756. dim = input_dim - 1
  757. softmax = g.op("Softmax", input, axis_i=dim)
  758. if dtype and dtype.node().kind() != "prim::Constant":
  759. parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  760. softmax = g.op(
  761. "Cast", softmax, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
  762. )
  763. if is_transpose_required:
  764. softmax = g.op("Transpose", softmax, perm_i=axes)
  765. return softmax
  766. # Apply max normalization.
  767. input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1))
  768. exp = g.op("Exp", input)
  769. sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim])
  770. softmax = g.op("Div", exp, sum)
  771. if dtype and dtype.node().kind() != "prim::Constant":
  772. parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  773. softmax = g.op(
  774. "Cast", softmax, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
  775. )
  776. return softmax
  777. def softplus(g, self, beta, threshold):
  778. beta_const = symbolic_helper._maybe_get_const(beta, "f")
  779. if beta_const != 1:
  780. return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta)
  781. return g.op("Softplus", self)
  782. def get_pool_ceil_padding(input, kernel_size, stride, padding):
  783. sizes = symbolic_helper._get_tensor_sizes(input)
  784. dim = sizes[-len(padding) :] if sizes is not None else None
  785. if dim is None or any([i is None for i in dim]):
  786. return symbolic_helper._unimplemented(name, "input size not accessible")
  787. ceiled_output_dim = [
  788. int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i])))
  789. + 1
  790. for i in range(0, len(padding))
  791. ]
  792. # ensure last pooling starts inside
  793. ceiled_output_dim = [
  794. ceiled_output_dim[i] - 1
  795. if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
  796. else ceiled_output_dim[i]
  797. for i in range(0, len(ceiled_output_dim))
  798. ]
  799. padding_ceil = [
  800. 0
  801. if (stride[i] == 1)
  802. else (
  803. kernel_size[i]
  804. - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1))
  805. )
  806. for i in range(0, len(padding))
  807. ]
  808. # ensure padding is not > kernel_size
  809. padding_ceil = [
  810. (
  811. int(padding_ceil[i])
  812. if padding_ceil[i] < kernel_size[i] - 1
  813. else int(kernel_size[i] - 1)
  814. )
  815. if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
  816. else int(padding_ceil[i])
  817. for i in range(0, len(padding_ceil))
  818. ]
  819. return padding_ceil
  820. def _max_pool(name, tuple_fn, ndims, return_indices):
  821. @symbolic_helper.quantized_args(True, False, False, False, False, False)
  822. @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
  823. def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
  824. if set(tuple_fn(dilation)) != {1}:
  825. return symbolic_helper._unimplemented(name, "dilation")
  826. if not stride:
  827. stride = kernel_size
  828. padding = tuple(tuple_fn(padding))
  829. if ceil_mode:
  830. padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
  831. padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
  832. else:
  833. padding = padding * 2
  834. kwargs = {
  835. "kernel_shape_i": tuple_fn(kernel_size),
  836. "pads_i": padding,
  837. "strides_i": tuple_fn(stride),
  838. }
  839. # easy but hacky way to get flattened indices values
  840. # to be used to convert the indices values to non-flattened.
  841. # In ONNX the indices are computed as a flatten 1-D tensor,
  842. # so the values in indices are in [0, N x C x D1 x ... x Dn).
  843. # To convert the indices to the same format used by Pytorch,
  844. # we first execute a maxpool with a kernel and stride of 1 on the same input.
  845. # This will result in a tensor of indices in which each index will have it's own value.
  846. # Using this tensor as a reference, we extract the first index of each axis and substract
  847. # it from each index of this axis in the indices to convert.
  848. # This step will result in a tensor were each dimension has values of indices within
  849. # the dimension it is in.
  850. # For more information :
  851. # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
  852. if return_indices:
  853. r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
  854. _, flattened_indices = g.op(
  855. "MaxPool",
  856. input,
  857. outputs=2,
  858. kernel_shape_i=[1 for _ in range(ndims)],
  859. strides_i=[1 for _ in range(ndims)],
  860. )
  861. # convert indices to have non-flattened indices values
  862. s = symbolic_helper._slice_helper(
  863. g,
  864. flattened_indices,
  865. axes=[2 + i for i in range(ndims)],
  866. starts=tuple_fn(0),
  867. ends=tuple_fn(1),
  868. )
  869. indices = sub(g, indices, s)
  870. return r, indices
  871. else:
  872. r = g.op("MaxPool", input, outputs=1, **kwargs)
  873. return r
  874. return symbolic_fn
  875. max_pool1d = _max_pool(
  876. "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
  877. )
  878. max_pool2d = _max_pool(
  879. "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
  880. )
  881. max_pool3d = _max_pool(
  882. "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
  883. )
  884. max_pool1d_with_indices = _max_pool(
  885. "max_pool1d_with_indices",
  886. torch.nn.modules.utils._single,
  887. 1,
  888. return_indices=True,
  889. )
  890. max_pool2d_with_indices = _max_pool(
  891. "max_pool2d_with_indices",
  892. torch.nn.modules.utils._pair,
  893. 2,
  894. return_indices=True,
  895. )
  896. max_pool3d_with_indices = _max_pool(
  897. "max_pool3d_with_indices",
  898. torch.nn.modules.utils._triple,
  899. 3,
  900. return_indices=True,
  901. )
  902. def _avg_pool(name, tuple_fn):
  903. @symbolic_helper.quantized_args(True)
  904. @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
  905. def symbolic_fn(
  906. g,
  907. input: _C.Value,
  908. kernel_size: Tuple[int, ...],
  909. stride: Tuple[int, ...],
  910. padding: Union[int, Tuple[int, ...]],
  911. ceil_mode: int,
  912. count_include_pad: int,
  913. divisor_override=None,
  914. ):
  915. if not stride:
  916. stride = kernel_size
  917. padding = symbolic_helper._avgpool_helper(
  918. tuple_fn, padding, kernel_size, stride, divisor_override, name
  919. )
  920. adjusted_padding = padding
  921. if count_include_pad:
  922. input = g.op(
  923. "Pad",
  924. input,
  925. pads_i=((0,) * 2 + padding) * 2,
  926. mode_s="constant",
  927. value_f=0.0,
  928. )
  929. adjusted_padding = (0,) * len(padding)
  930. if ceil_mode:
  931. padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
  932. adjusted_padding = adjusted_padding + tuple(
  933. a + b for (a, b) in zip(padding_ceil, adjusted_padding)
  934. )
  935. else:
  936. adjusted_padding = adjusted_padding * 2
  937. output = g.op(
  938. "AveragePool",
  939. input,
  940. kernel_shape_i=tuple_fn(kernel_size),
  941. strides_i=tuple_fn(stride),
  942. pads_i=adjusted_padding,
  943. )
  944. return output
  945. return symbolic_fn
  946. avg_pool1d = _avg_pool("avg_pool1d", torch.nn.modules.utils._single)
  947. avg_pool2d = _avg_pool("avg_pool2d", torch.nn.modules.utils._pair)
  948. avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
  949. def _adaptive_pool(name, type, tuple_fn, fn=None):
  950. @symbolic_helper.quantized_args(True, False)
  951. def symbolic_fn(g, input, output_size):
  952. # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
  953. # by executing a GlobalPool.
  954. # It is also supported for cases where the output size is a factor of the input size.
  955. # For these cases the stride and kernel size are uniform along all the indices of
  956. # the same dimension, which makes it possible to export it to ONNX.
  957. # for MaxPool, GlobalMaxPool does not return indices,
  958. # so we try using max_poolxd_with_indices, and if it is not possible
  959. # (input is not a complete tensor or output size not factor of input size)
  960. # then we call GlobalAveragePool and return None for the indices
  961. try:
  962. output_size = symbolic_helper._parse_arg(output_size, "is")
  963. except Exception:
  964. return symbolic_helper._onnx_unsupported(
  965. "adaptive pooling, since output_size is not constant."
  966. )
  967. if output_size == [1] * len(output_size) and type == "AveragePool":
  968. return g.op("GlobalAveragePool", input)
  969. sizes = symbolic_helper._get_tensor_sizes(input)
  970. try:
  971. dim = sizes[2:]
  972. except Exception:
  973. dim = None
  974. if dim is None or any([i is None for i in dim]):
  975. if output_size == [1] * len(output_size):
  976. return g.op("GlobalMaxPool", input), None
  977. return symbolic_helper._unimplemented(name, "input size not accessible")
  978. # verify if output size % input size = 0 for all dim
  979. mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
  980. if mod != [0] * len(mod):
  981. if output_size == [1] * len(output_size):
  982. return g.op("GlobalMaxPool", input), None
  983. return symbolic_helper._unimplemented(
  984. name, "output size that are not factor of input size"
  985. )
  986. k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
  987. # call max_poolxd_with_indices to get indices in the output
  988. if type == "MaxPool":
  989. return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
  990. output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))
  991. return output
  992. return symbolic_fn
  993. adaptive_avg_pool1d = _adaptive_pool(
  994. "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
  995. )
  996. adaptive_avg_pool2d = _adaptive_pool(
  997. "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
  998. )
  999. adaptive_avg_pool3d = _adaptive_pool(
  1000. "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
  1001. )
  1002. adaptive_max_pool1d = _adaptive_pool(
  1003. "adaptive_max_pool1d",
  1004. "MaxPool",
  1005. torch.nn.modules.utils._single,
  1006. max_pool1d_with_indices,
  1007. )
  1008. adaptive_max_pool2d = _adaptive_pool(
  1009. "adaptive_max_pool2d",
  1010. "MaxPool",
  1011. torch.nn.modules.utils._pair,
  1012. max_pool2d_with_indices,
  1013. )
  1014. adaptive_max_pool3d = _adaptive_pool(
  1015. "adaptive_max_pool3d",
  1016. "MaxPool",
  1017. torch.nn.modules.utils._triple,
  1018. max_pool3d_with_indices,
  1019. )
  1020. # Generate paddings in ONNX order based on pad in pytorch.
  1021. # Args:
  1022. # dim: the dimension of the tensor.
  1023. # pad: the paddings in pytorch.
  1024. # The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
  1025. def _prepare_onnx_paddings(dim, pad):
  1026. assert isinstance(dim, int)
  1027. # The desired order of paddings is
  1028. # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
  1029. # n is the dimension of input.
  1030. # assume zero-dimensions in the beginning
  1031. paddings = list(pad[:]) + [0] * (dim * 2 - len(pad))
  1032. # reverse order and collate first beginnings and then ends
  1033. paddings = paddings[-2::-2] + paddings[-1::-2]
  1034. return paddings
  1035. def _convert_padding_node(padding):
  1036. padding = symbolic_helper._maybe_get_const(padding, "is")
  1037. if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding):
  1038. input_list = symbolic_helper._unpack_list(padding)
  1039. try:
  1040. padding = [
  1041. symbolic_helper._get_const(v, "i", "padding") for v in input_list
  1042. ]
  1043. except Exception:
  1044. return symbolic_helper._onnx_opset_unsupported_detailed(
  1045. "Pad", 9, 11, "The sizes of the padding must be constant"
  1046. )
  1047. return padding
  1048. def constant_pad_nd(g, input, padding, value):
  1049. mode = "constant"
  1050. try:
  1051. value = symbolic_helper._get_const(value, "f", "value")
  1052. except Exception:
  1053. return symbolic_helper._onnx_opset_unsupported_detailed(
  1054. "Pad", 9, 11, "The value for the padding must be constant"
  1055. )
  1056. padding = _convert_padding_node(padding)
  1057. paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
  1058. return op_with_optional_float_cast(
  1059. g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11
  1060. )
  1061. def _pad_circular(g, input, pad):
  1062. padding = _convert_padding_node(pad)
  1063. assert len(padding) % 2 == 0
  1064. ndim = len(padding) // 2
  1065. cur = input
  1066. for idx in range(ndim):
  1067. pad_l = padding[-(2 * idx + 1)]
  1068. pad_r = padding[-(2 * idx + 2)]
  1069. tensors = []
  1070. if pad_l > 0:
  1071. left = symbolic_helper._slice_helper(
  1072. g, cur, axes=[2 + idx], starts=[-(pad_l + 1)], ends=[-1]
  1073. )
  1074. tensors.append(left)
  1075. if pad_l < 0 or pad_r < 0:
  1076. middle = symbolic_helper._slice_helper(
  1077. g,
  1078. cur,
  1079. axes=[2 + idx],
  1080. starts=[max(0, -pad_l)],
  1081. ends=[-(1 + max(0, -pad_r))],
  1082. )
  1083. tensors.append(middle)
  1084. else:
  1085. tensors.append(cur)
  1086. if pad_r > 0:
  1087. right = symbolic_helper._slice_helper(
  1088. g, cur, axes=[2 + idx], starts=[0], ends=[pad_r]
  1089. )
  1090. tensors.append(right)
  1091. cur = g.op("Concat", *tensors, axis_i=(2 + idx))
  1092. return cur
  1093. def reflection_pad(g, input, padding):
  1094. mode = "reflect"
  1095. padding = _convert_padding_node(padding)
  1096. paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
  1097. return op_with_optional_float_cast(
  1098. g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
  1099. )
  1100. def replication_pad(g, input, padding):
  1101. mode = "edge"
  1102. padding = _convert_padding_node(padding)
  1103. paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
  1104. return op_with_optional_float_cast(
  1105. g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11
  1106. )
  1107. reflection_pad1d = reflection_pad
  1108. reflection_pad2d = reflection_pad
  1109. reflection_pad3d = reflection_pad
  1110. replication_pad1d = replication_pad
  1111. replication_pad2d = replication_pad
  1112. replication_pad3d = replication_pad
  1113. def pad(g, input, pad, mode, value):
  1114. mode = symbolic_helper._parse_arg(mode, "s")
  1115. if mode == "replicate":
  1116. return replication_pad(g, input, pad)
  1117. elif mode == "reflect":
  1118. return reflection_pad(g, input, pad)
  1119. elif mode == "constant":
  1120. return constant_pad_nd(g, input, pad, value)
  1121. elif mode == "circular":
  1122. return _pad_circular(g, input, pad)
  1123. else:
  1124. raise RuntimeError(f"Unrecognized padding mode {mode}")
  1125. def _interpolate(name, dim, interpolate_mode):
  1126. def symbolic_fn(g, input, output_size, *args):
  1127. scales, align_corners = symbolic_helper._get_interpolate_attributes(
  1128. g, interpolate_mode, args
  1129. )
  1130. symbolic_helper._interpolate_warning(interpolate_mode)
  1131. align_corners = symbolic_helper._maybe_get_scalar(align_corners)
  1132. if align_corners:
  1133. return symbolic_helper._unimplemented(name, "align_corners == True")
  1134. if scales is None:
  1135. scales = symbolic_helper._interpolate_size_to_scales(
  1136. g, input, output_size, dim
  1137. )
  1138. return g.op("Upsample", input, scales, mode_s=interpolate_mode)
  1139. return symbolic_fn
  1140. upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
  1141. upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
  1142. upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
  1143. upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
  1144. upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
  1145. upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
  1146. def __interpolate(
  1147. g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
  1148. ):
  1149. scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
  1150. g, input, size, scale_factor, mode, align_corners
  1151. )
  1152. return g.op("Upsample", input, scales, mode_s=mode)
  1153. def bitwise_not(g, inp):
  1154. if inp.type().scalarType() != "Bool":
  1155. raise NotImplementedError(
  1156. "ONNX export does NOT support exporting bitwise Not "
  1157. + "for non-boolean input values"
  1158. )
  1159. return g.op("Not", inp)
  1160. def wrap_logical_op_with_cast_to(to_type):
  1161. def decorator(fn):
  1162. def wrap_with_cast(g, input, other):
  1163. return g.op(
  1164. "Cast",
  1165. fn(g, input, other),
  1166. to_i=symbolic_helper.cast_pytorch_to_onnx[to_type],
  1167. )
  1168. return wrap_with_cast
  1169. return decorator
  1170. def wrap_logical_op_with_cast_to_and_from(to_type):
  1171. def decorator(fn):
  1172. def wrap_with_cast(g, input, other):
  1173. to_cast_func = globals()["_cast_{}".format(to_type)]
  1174. from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn)
  1175. return from_cast_func(
  1176. g, to_cast_func(g, input, False), to_cast_func(g, other, False)
  1177. )
  1178. return wrap_with_cast
  1179. return decorator
  1180. def wrap_logical_op_with_negation(func):
  1181. def wrap_with_not(g, input, other):
  1182. return g.op("Not", func(g, input, other))
  1183. return wrap_with_not
  1184. def __not_(g, self):
  1185. if self.type().scalarType() != "Bool":
  1186. raise NotImplementedError(
  1187. "ONNX export does NOT support exporting bitwise Not "
  1188. + "for non-boolean input values"
  1189. )
  1190. return g.op("Not", self)
  1191. def eq(g, self, other):
  1192. if isinstance(self.type(), _C.DeviceObjType) and isinstance(
  1193. other.type(), _C.DeviceObjType
  1194. ):
  1195. # ONNX doesn't have devices, so consider them all to be equal.
  1196. # The no-op check for equality will get constant-folded.
  1197. return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool))
  1198. return g.op("Equal", self, other)
  1199. @wrap_logical_op_with_negation
  1200. def ne(g, self, other):
  1201. return eq(g, self, other)
  1202. def gt(g, input, other):
  1203. return gt_impl(g, input, other)
  1204. def gt_impl(g, input, other):
  1205. if (
  1206. input.type().scalarType() is not None
  1207. and input.type().scalarType() == "Bool"
  1208. and other.type().scalarType() is not None
  1209. and other.type().scalarType() == "Bool"
  1210. ):
  1211. input = g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx["Int"])
  1212. other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Int"])
  1213. return g.op("Greater", input, other)
  1214. def lt(g, input, other):
  1215. return lt_impl(g, input, other)
  1216. def lt_impl(g, input, other):
  1217. if (
  1218. input.type().scalarType() is not None
  1219. and input.type().scalarType() == "Bool"
  1220. and other.type().scalarType() is not None
  1221. and other.type().scalarType() == "Bool"
  1222. ):
  1223. input = g.op("Cast", input, to_i=symbolic_helper.cast_pytorch_to_onnx["Int"])
  1224. other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Int"])
  1225. return g.op("Less", input, other)
  1226. @wrap_logical_op_with_negation
  1227. def ge(g, input, other):
  1228. return lt_impl(g, input, other)
  1229. @wrap_logical_op_with_negation
  1230. def le(g, input, other):
  1231. return gt_impl(g, input, other)
  1232. def __and_(g, input, other):
  1233. if input.type().scalarType() == "Bool" and other.type().scalarType() == "Bool":
  1234. return g.op("And", input, other)
  1235. else:
  1236. raise NotImplementedError(
  1237. "ONNX export does NOT support exporting bitwise AND "
  1238. + "for non-boolean input values"
  1239. )
  1240. def __or_(g, input, other):
  1241. if input.type().scalarType() == "Bool" and other.type().scalarType() == "Bool":
  1242. return g.op("Or", input, other)
  1243. else:
  1244. raise NotImplementedError(
  1245. "ONNX export does NOT support exporting bitwise OR "
  1246. + "for non-boolean input values"
  1247. )
  1248. def __xor_(g, input, other):
  1249. if input.type().scalarType() == "Bool" and other.type().scalarType() == "Bool":
  1250. return g.op("Xor", input, other)
  1251. else:
  1252. raise NotImplementedError(
  1253. "ONNX export does NOT support exporting bitwise XOR "
  1254. + "for non-boolean input values"
  1255. )
  1256. @wrap_logical_op_with_cast_to_and_from("Bool")
  1257. def logical_and(g, input, other):
  1258. return g.op("And", input, other)
  1259. @wrap_logical_op_with_cast_to_and_from("Bool")
  1260. def logical_or(g, input, other):
  1261. return g.op("Or", input, other)
  1262. @wrap_logical_op_with_cast_to_and_from("Bool")
  1263. def logical_xor(g, input, other):
  1264. return g.op("Xor", input, other)
  1265. def __rshift_(g, self, other):
  1266. # make sure to cast other to self's type
  1267. # (when self is long, make sure that other is not float)
  1268. if other.type().scalarType() != self.type().scalarType():
  1269. other = g.op(
  1270. "Cast",
  1271. other,
  1272. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  1273. )
  1274. two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
  1275. # exponent (same type as self) has to be float or double in onnx::Pow
  1276. if not symbolic_helper._is_fp(self):
  1277. other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  1278. two_pow = g.op("Pow", two, other)
  1279. two_pow = g.op(
  1280. "Cast",
  1281. two_pow,
  1282. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  1283. )
  1284. rshift = g.op("Div", self, two_pow)
  1285. return rshift
  1286. def __lshift_(g, self, other):
  1287. # make sure to cast other to self's type
  1288. # (when self is long, make sure that other is not float)
  1289. if other.type().scalarType() != self.type().scalarType():
  1290. other = g.op(
  1291. "Cast",
  1292. other,
  1293. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  1294. )
  1295. two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
  1296. # exponent (same type as self) has to be float or double in onnx::Pow
  1297. if not symbolic_helper._is_fp(self):
  1298. other = g.op("Cast", other, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"])
  1299. two_pow = g.op("Pow", two, other)
  1300. two_pow = g.op(
  1301. "Cast",
  1302. two_pow,
  1303. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  1304. )
  1305. lshift = g.op("Mul", self, two_pow)
  1306. return lshift
  1307. @symbolic_helper.parse_args("v", "v", "v", "i")
  1308. def where(g, condition, self=None, other=None, _outputs=None):
  1309. # Assumes that torch.where's first argument takes only Bool and Byte tensors.
  1310. if condition.type().scalarType() != "Bool":
  1311. condition = g.op(
  1312. "Cast", condition, to_i=symbolic_helper.cast_pytorch_to_onnx["Bool"]
  1313. )
  1314. if self is None:
  1315. condition = nonzero(g, condition)
  1316. return symbolic_helper._unbind_helper(
  1317. g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
  1318. )
  1319. return g.op("Where", condition, self, other)
  1320. @symbolic_helper.parse_args("v", "i", "none")
  1321. def log_softmax(g, input, dim, dtype=None):
  1322. # PyTorch dim and ONNX axis have different meanings.
  1323. # See Softmax comment for details.
  1324. # TODO: remove this as onnx opset 11 spec allows negative axes
  1325. input_dim = symbolic_helper._get_tensor_rank(input)
  1326. if input_dim is None:
  1327. return symbolic_helper._unimplemented(
  1328. "dim",
  1329. "ONNX and PyTorch use different strategies to split the input. "
  1330. "Input rank must be known at export time.",
  1331. )
  1332. if dim < 0:
  1333. dim = input_dim + dim
  1334. is_transpose_required = input_dim != dim + 1
  1335. # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases.
  1336. if is_transpose_required:
  1337. axes = list(range(input_dim))
  1338. axes[dim], axes[-1] = axes[-1], axes[dim]
  1339. input = g.op("Transpose", input, perm_i=axes)
  1340. dim = input_dim - 1
  1341. return_op = g.op("LogSoftmax", input, axis_i=dim)
  1342. if dtype and dtype.node().kind() != "prim::Constant":
  1343. parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  1344. return_op = g.op(
  1345. "Cast", return_op, to_i=symbolic_helper.scalar_type_to_onnx[parsed_dtype]
  1346. )
  1347. if is_transpose_required:
  1348. return_op = g.op("Transpose", return_op, perm_i=axes)
  1349. return return_op
  1350. @symbolic_helper.parse_args(
  1351. "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
  1352. )
  1353. def _convolution(
  1354. g,
  1355. input,
  1356. weight,
  1357. bias,
  1358. stride,
  1359. padding,
  1360. dilation,
  1361. transposed,
  1362. output_padding,
  1363. groups,
  1364. benchmark,
  1365. deterministic,
  1366. cudnn_enabled,
  1367. allow_tf32=None,
  1368. ):
  1369. weight_size = symbolic_helper._get_tensor_sizes(weight)
  1370. try:
  1371. kernel_shape = weight_size[2:]
  1372. except Exception:
  1373. kernel_shape = None
  1374. if kernel_shape is None or any([i is None for i in kernel_shape]):
  1375. raise RuntimeError(
  1376. "Unsupported: ONNX export of convolution for kernel " "of unknown shape."
  1377. )
  1378. args = [input, weight]
  1379. # ONNX only supports 1D bias
  1380. if (
  1381. not symbolic_helper._is_none(bias)
  1382. and symbolic_helper._get_tensor_rank(bias) == 1
  1383. ):
  1384. args.append(bias)
  1385. kwargs = {
  1386. "kernel_shape_i": weight_size[2:],
  1387. "strides_i": stride,
  1388. # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
  1389. # symmetric padding
  1390. "pads_i": padding + padding,
  1391. "dilations_i": dilation,
  1392. "group_i": groups,
  1393. }
  1394. if any(o != 0 for o in output_padding):
  1395. # ONNX supports both output_shape and output_padding. they are equivalent expressive.
  1396. # output_padding is more straightforward, so we use it here.
  1397. # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
  1398. assert transposed
  1399. assert len(stride) == len(output_padding)
  1400. kwargs["output_padding_i"] = output_padding
  1401. n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
  1402. if (
  1403. not symbolic_helper._is_none(bias)
  1404. and symbolic_helper._get_tensor_rank(bias) != 1
  1405. ):
  1406. return g.op("Add", n, bias)
  1407. else:
  1408. return n
  1409. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i")
  1410. def conv1d(g, input, weight, bias, stride, padding, dilation, groups):
  1411. return _convolution(
  1412. g,
  1413. input,
  1414. weight,
  1415. bias,
  1416. stride,
  1417. padding,
  1418. dilation,
  1419. False,
  1420. (),
  1421. groups,
  1422. None,
  1423. None,
  1424. None,
  1425. None,
  1426. )
  1427. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i")
  1428. def conv2d(g, input, weight, bias, stride, padding, dilation, groups):
  1429. return _convolution(
  1430. g,
  1431. input,
  1432. weight,
  1433. bias,
  1434. stride,
  1435. padding,
  1436. dilation,
  1437. False,
  1438. (),
  1439. groups,
  1440. None,
  1441. None,
  1442. None,
  1443. None,
  1444. )
  1445. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i")
  1446. def conv3d(g, input, weight, bias, stride, padding, dilation, groups):
  1447. return _convolution(
  1448. g,
  1449. input,
  1450. weight,
  1451. bias,
  1452. stride,
  1453. padding,
  1454. dilation,
  1455. False,
  1456. (),
  1457. groups,
  1458. None,
  1459. None,
  1460. None,
  1461. None,
  1462. )
  1463. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
  1464. def conv_transpose1d(
  1465. g, input, weight, bias, stride, padding, output_padding, groups, dilation
  1466. ):
  1467. return _convolution(
  1468. g,
  1469. input,
  1470. weight,
  1471. bias,
  1472. stride,
  1473. padding,
  1474. dilation,
  1475. True,
  1476. output_padding,
  1477. groups,
  1478. None,
  1479. None,
  1480. None,
  1481. None,
  1482. )
  1483. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
  1484. def conv_transpose2d(
  1485. g, input, weight, bias, stride, padding, output_padding, groups, dilation
  1486. ):
  1487. return _convolution(
  1488. g,
  1489. input,
  1490. weight,
  1491. bias,
  1492. stride,
  1493. padding,
  1494. dilation,
  1495. True,
  1496. output_padding,
  1497. groups,
  1498. None,
  1499. None,
  1500. None,
  1501. None,
  1502. )
  1503. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
  1504. def conv_transpose3d(
  1505. g, input, weight, bias, stride, padding, output_padding, groups, dilation
  1506. ):
  1507. return _convolution(
  1508. g,
  1509. input,
  1510. weight,
  1511. bias,
  1512. stride,
  1513. padding,
  1514. dilation,
  1515. True,
  1516. output_padding,
  1517. groups,
  1518. None,
  1519. None,
  1520. None,
  1521. None,
  1522. )
  1523. @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
  1524. def batch_norm(
  1525. g,
  1526. input,
  1527. weight,
  1528. bias,
  1529. running_mean,
  1530. running_var,
  1531. training,
  1532. momentum,
  1533. eps,
  1534. cudnn_enabled,
  1535. ):
  1536. symbolic_helper.check_training_mode(training, "batch_norm")
  1537. if (
  1538. torch.is_autocast_enabled()
  1539. and not symbolic_helper.args_have_same_dtype(
  1540. [input, weight, bias, running_mean, running_var]
  1541. )
  1542. and GLOBALS.export_onnx_opset_version < 15
  1543. ):
  1544. return symbolic_helper._onnx_opset_unsupported_detailed(
  1545. "BatchNormalization",
  1546. 9,
  1547. 15,
  1548. "All input tensors must have the same `dtype`."
  1549. " Turn off Autocast or export using opset version 15.",
  1550. )
  1551. weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
  1552. g, input, weight, bias, running_mean, running_var
  1553. )
  1554. out = g.op(
  1555. "BatchNormalization",
  1556. input,
  1557. weight,
  1558. bias,
  1559. running_mean,
  1560. running_var,
  1561. epsilon_f=eps,
  1562. momentum_f=1 - momentum,
  1563. outputs=1 if not training else 5,
  1564. )
  1565. if not training:
  1566. return out
  1567. else:
  1568. res, new_running_mean, new_running_var, saved_mean, saved_var = out
  1569. new_running_mean.setType(running_mean.type())
  1570. new_running_var.setType(running_var.type())
  1571. saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
  1572. saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
  1573. return res
  1574. @symbolic_helper.parse_args("v", "is", "v", "v", "f", "i")
  1575. def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
  1576. if symbolic_helper.is_caffe2_aten_fallback():
  1577. return g.at(
  1578. "layer_norm",
  1579. input,
  1580. weight,
  1581. bias,
  1582. normalized_shape_i=normalized_shape,
  1583. eps_f=eps,
  1584. cudnn_enable_i=cudnn_enable,
  1585. )
  1586. axes = [-i for i in range(len(normalized_shape), 0, -1)]
  1587. two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
  1588. eps_cst = symbolic_helper._generate_wrapped_number(g, eps)
  1589. mean = g.op("ReduceMean", input, axes_i=axes)
  1590. numerator = sub(g, input, mean)
  1591. # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula
  1592. variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes)
  1593. denominator = sqrt(g, add(g, variance, eps_cst))
  1594. layer_norm = g.op("Div", numerator, denominator)
  1595. if not (weight is None or symbolic_helper._is_none(weight)):
  1596. layer_norm = mul(g, layer_norm, weight)
  1597. if not (bias is None or symbolic_helper._is_none(bias)):
  1598. layer_norm = add(g, layer_norm, bias)
  1599. return layer_norm
  1600. @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
  1601. def instance_norm(
  1602. g,
  1603. input,
  1604. weight,
  1605. bias,
  1606. running_mean,
  1607. running_var,
  1608. use_input_stats,
  1609. momentum,
  1610. eps,
  1611. cudnn_enabled,
  1612. ):
  1613. symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
  1614. channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
  1615. if weight is None or symbolic_helper._is_none(weight):
  1616. if channel_size is None:
  1617. raise RuntimeError(
  1618. "Unsupported: ONNX export of instance_norm for unknown " "channel size."
  1619. )
  1620. weight_value = torch.tensor([1.0] * channel_size).type(
  1621. "torch." + input.type().scalarType() + "Tensor"
  1622. )
  1623. weight = g.op("Constant", value_t=weight_value)
  1624. if bias is None or symbolic_helper._is_none(bias):
  1625. if channel_size is None:
  1626. raise RuntimeError(
  1627. "Unsupported: ONNX export of instance_norm for unknown " "channel size."
  1628. )
  1629. bias_value = torch.tensor([0.0] * channel_size).type(
  1630. "torch." + input.type().scalarType() + "Tensor"
  1631. )
  1632. bias = g.op("Constant", value_t=bias_value)
  1633. if (
  1634. running_mean is None
  1635. or symbolic_helper._is_none(running_mean)
  1636. or running_var is None
  1637. or symbolic_helper._is_none(running_var)
  1638. ):
  1639. return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
  1640. else:
  1641. input_size = symbolic_helper._get_tensor_sizes(input)
  1642. # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm.
  1643. # For more information instance_norm():
  1644. # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542
  1645. input_size_reshape = input_size.copy()
  1646. n = input_size[0]
  1647. if n is None:
  1648. raise RuntimeError(
  1649. "Unsupported: ONNX export of instance_norm training for unknown "
  1650. "batch size."
  1651. )
  1652. c = input_size[1]
  1653. input_size_reshape[0] = 1
  1654. input_size_reshape[1] = n * c
  1655. weight_ = repeat(
  1656. g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
  1657. )
  1658. bias_ = repeat(
  1659. g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
  1660. )
  1661. running_mean_ = repeat(
  1662. g,
  1663. running_mean,
  1664. g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
  1665. )
  1666. running_var_ = repeat(
  1667. g,
  1668. running_var,
  1669. g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)),
  1670. )
  1671. input_reshaped = g.op(
  1672. "Reshape",
  1673. input,
  1674. g.op("Constant", value_t=torch.LongTensor(input_size_reshape)),
  1675. )
  1676. out = batch_norm(
  1677. g,
  1678. input_reshaped,
  1679. weight_,
  1680. bias_,
  1681. running_mean_,
  1682. running_var_,
  1683. use_input_stats,
  1684. momentum,
  1685. eps,
  1686. cudnn_enabled,
  1687. )
  1688. return view(g, out, g.op("Constant", value_t=torch.tensor(input_size)))
  1689. @symbolic_helper.parse_args("v", "i", "i", "i")
  1690. def unfold(g, input, dimension, size, step):
  1691. if symbolic_helper.is_caffe2_aten_fallback():
  1692. return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
  1693. sizes = symbolic_helper._get_tensor_sizes(input)
  1694. try:
  1695. sizedim = sizes[dimension]
  1696. except Exception:
  1697. sizedim = None
  1698. if sizedim is not None:
  1699. low_indices = range(0, sizedim, step)
  1700. hi_indices = range(size, sizedim + 1, step)
  1701. stack = [
  1702. symbolic_helper._slice_helper(
  1703. g, input, axes=[dimension], starts=[low], ends=[hi]
  1704. )
  1705. for low, hi in zip(low_indices, hi_indices)
  1706. ]
  1707. ndim = len(sizes)
  1708. perm = list(range(0, ndim))
  1709. perm.append(perm.pop(dimension))
  1710. unsqueeze = [
  1711. symbolic_helper._unsqueeze_helper(
  1712. g, g.op("Transpose", t, perm_i=perm), [dimension]
  1713. )
  1714. for t in stack
  1715. ]
  1716. return g.op("Concat", *unsqueeze, axis_i=dimension)
  1717. else:
  1718. return symbolic_helper._unimplemented("Unfold", "input size not accessible")
  1719. @symbolic_helper.parse_args("v", "t", "t", "t")
  1720. def elu(g, input, alpha, scale, input_scale):
  1721. if scale and scale != 1.0:
  1722. return symbolic_helper._unimplemented("scale", "does not support scale in Elu")
  1723. if input_scale and input_scale != 1.0:
  1724. return symbolic_helper._unimplemented(
  1725. "input_scale", "does not support input_scale in Elu"
  1726. )
  1727. # See Note [Export inplace]
  1728. return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha))
  1729. def selu(g, input):
  1730. return g.op("Selu", input)
  1731. @symbolic_helper.parse_args("v", "i", "v")
  1732. def index_select(g, self, dim, index):
  1733. # In case of a scalar index, index_select returns a tensor with the same rank as the input.
  1734. # To match this behavior in ONNX, we make index a 1D tensor so that the following gather
  1735. # also produces a tensor with the same rank as the input.
  1736. return symbolic_helper._select_helper(g, self, dim, index)
  1737. def index_put(g, self, indices_list_value, values, accumulate):
  1738. if symbolic_helper._is_packed_list(indices_list_value):
  1739. indices_list = symbolic_helper._unpack_list(indices_list_value)
  1740. else:
  1741. indices_list = [indices_list_value]
  1742. if symbolic_helper.is_caffe2_aten_fallback():
  1743. args = [self] + indices_list + [values, accumulate]
  1744. return g.at("index_put", *args)
  1745. accumulate = symbolic_helper._parse_arg(accumulate, "b")
  1746. if len(indices_list) == 0:
  1747. if accumulate:
  1748. return add(g, self, values)
  1749. else:
  1750. return values
  1751. else:
  1752. symbolic_helper._onnx_opset_unsupported("index_put", 9, 11)
  1753. def index_fill(g, self, dim, index, value):
  1754. dim_value = symbolic_helper._parse_arg(dim, "i")
  1755. if symbolic_helper.is_caffe2_aten_fallback():
  1756. return g.at(
  1757. "index_fill",
  1758. self,
  1759. index,
  1760. value,
  1761. overload_name="int_Scalar",
  1762. dim_i=dim_value,
  1763. )
  1764. expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
  1765. g, self, dim, index
  1766. )
  1767. value = symbolic_helper._maybe_get_scalar(value)
  1768. value = symbolic_helper._if_scalar_type_as(g, value, self)
  1769. expanded_value = expand(g, value, expanded_index_shape, None)
  1770. return scatter(g, self, dim, expanded_index, expanded_value)
  1771. def index_copy(g, self, dim, index, source):
  1772. dim_value = symbolic_helper._parse_arg(dim, "i")
  1773. if symbolic_helper.is_caffe2_aten_fallback():
  1774. return g.at("index_copy", self, index, source, dim_i=dim_value)
  1775. expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
  1776. g, self, dim, index
  1777. )
  1778. return scatter(g, self, dim, expanded_index, source)
  1779. @symbolic_helper.parse_args("v", "v", "b", "b")
  1780. def bucketize(g, self, boundaries, out_int32=False, right=False):
  1781. out_type = _C_onnx.TensorProtoDataType.INT64
  1782. if out_int32:
  1783. out_type = _C_onnx.TensorProtoDataType.INT32
  1784. # A tensor expanded_boundaries is created such that it
  1785. # contains a copy of boundaries for each element of self.
  1786. new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0)
  1787. # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops
  1788. # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
  1789. unsqueeze_axes = list(range(1, symbolic_helper._get_tensor_rank(self) + 1))
  1790. expanded_boundaries = expand(
  1791. g,
  1792. symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes),
  1793. new_shape,
  1794. None,
  1795. )
  1796. # Compare each element of self to boundaries to get a tensor
  1797. # with leading 1s and trailing 0s.
  1798. # e.g., 4 > [1, 3, 4] = [1, 1, 0]
  1799. # The index of the last 1 is the bucket where the element should go.
  1800. if right:
  1801. cond = ge(g, self, expanded_boundaries)
  1802. else:
  1803. cond = gt(g, self, expanded_boundaries)
  1804. cond_out = g.op("Cast", cond, to_i=out_type)
  1805. # Sum to get the number of 1s corresponding to each element,
  1806. # which is the same as the bucket index.
  1807. # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2
  1808. return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
  1809. def type_as(g, self, other):
  1810. self_dtype = symbolic_helper._try_get_scalar_type(self)
  1811. other_dtype = symbolic_helper._try_get_scalar_type(other)
  1812. if self_dtype == other_dtype and self_dtype is not None:
  1813. return self
  1814. if other_dtype is not None:
  1815. return g.op(
  1816. "Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[other_dtype]
  1817. )
  1818. else:
  1819. if symbolic_helper.is_caffe2_aten_fallback():
  1820. # We don't know the type of other, bail by emitting ATen
  1821. return g.at("type_as", self, other)
  1822. else:
  1823. raise RuntimeError(
  1824. "Unsupported: ONNX export of type_as for tensor "
  1825. "of unknown dtype. Please check if the dtype of the "
  1826. "parameter passed to the type_as function is correct."
  1827. )
  1828. @symbolic_helper.parse_args("v", "v", "i", "f")
  1829. def cosine_similarity(g, x1, x2, dim, eps):
  1830. if symbolic_helper.is_caffe2_aten_fallback():
  1831. return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
  1832. cross = symbolic_helper._reducesum_helper(
  1833. g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0
  1834. )
  1835. x1_l2 = symbolic_helper._reducesum_helper(
  1836. g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0
  1837. )
  1838. x2_l2 = symbolic_helper._reducesum_helper(
  1839. g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0
  1840. )
  1841. div_tens = max(
  1842. g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps]))
  1843. )
  1844. return div(g, cross, div_tens)
  1845. def pairwise_distance(g, input1, input2, p, eps, keepdim):
  1846. if not symbolic_helper._is_value(eps):
  1847. eps = g.op("Constant", value_t=torch.tensor([eps]))
  1848. inv_p = div(
  1849. g,
  1850. g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)),
  1851. add(g, p, eps),
  1852. )
  1853. summation = symbolic_helper._reducesum_helper(
  1854. g,
  1855. pow(g, sub(g, input1, input2), p),
  1856. axes_i=[-1],
  1857. keepdims_i=symbolic_helper._parse_arg(keepdim, "i"),
  1858. )
  1859. return pow(g, summation, inv_p)
  1860. # ignore clone operators that are inserted by PyTorch autograd
  1861. def clone(g, input, unused_memory_format):
  1862. return input
  1863. def abs(g, self):
  1864. return g.op("Abs", self)
  1865. def log(g, self):
  1866. return g.op("Log", self)
  1867. def log1p(g, self):
  1868. return log(
  1869. g, add(g, symbolic_helper._if_scalar_type_as(g, torch.ones(1), self), self)
  1870. )
  1871. def log10(g, self):
  1872. _ln10 = 2.30258509299404568401
  1873. return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10])))
  1874. def pow(g, self, exponent):
  1875. f_dtype = self_dtype = self.type().scalarType()
  1876. if not symbolic_helper._is_fp(self):
  1877. f_dtype = "Float"
  1878. self = g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[f_dtype])
  1879. if not symbolic_helper._is_fp(exponent):
  1880. exponent = g.op(
  1881. "Cast", exponent, to_i=symbolic_helper.cast_pytorch_to_onnx[f_dtype]
  1882. )
  1883. pow = g.op("Pow", self, exponent)
  1884. return pow
  1885. def clamp(g, self, min, max):
  1886. # min or max may be None that we need to dispatch to
  1887. # Clip separately, as ONNX does not have None syntax
  1888. if symbolic_helper._is_none(min):
  1889. return clamp_max(g, self, max)
  1890. elif symbolic_helper._is_none(max):
  1891. return clamp_min(g, self, min)
  1892. else:
  1893. if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max):
  1894. return op_with_optional_float_cast(
  1895. g,
  1896. "Clip",
  1897. self,
  1898. min_f=symbolic_helper._parse_arg(min, "f"),
  1899. max_f=symbolic_helper._parse_arg(max, "f"),
  1900. opset_before=12,
  1901. )
  1902. else:
  1903. return clamp_max(g, clamp_min(g, self, min), max)
  1904. @symbolic_helper.parse_args("v", "v")
  1905. def clamp_min(g, self, min):
  1906. if symbolic_helper._is_constant(min):
  1907. return op_with_optional_float_cast(
  1908. g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12
  1909. )
  1910. else:
  1911. dtype = self.type().scalarType()
  1912. min = g.op("Cast", min, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  1913. return op_with_optional_float_cast(g, "Max", self, min, opset_before=12)
  1914. @symbolic_helper.parse_args("v", "v")
  1915. def clamp_max(g, self, max):
  1916. if symbolic_helper._is_constant(max):
  1917. return op_with_optional_float_cast(
  1918. g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12
  1919. )
  1920. else:
  1921. dtype = self.type().scalarType()
  1922. max = g.op("Cast", max, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  1923. return op_with_optional_float_cast(g, "Min", self, max, opset_before=12)
  1924. # torch.max (same for torch.min) actually has two interfaces smashed together:
  1925. # torch.max(x, dim, keepdim) and torch.max(x, y)
  1926. def max(g, self, dim_or_y=None, keepdim=None):
  1927. # torch.max(input)
  1928. if dim_or_y is None and keepdim is None:
  1929. return g.op("ReduceMax", self, keepdims_i=0)
  1930. # torch.max(input, other)
  1931. if keepdim is None:
  1932. return op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)
  1933. # torch.max(input, dim, keepdim)
  1934. else:
  1935. dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
  1936. keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
  1937. max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
  1938. indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
  1939. return max, indices
  1940. def maximum(g, input, other):
  1941. return max(g, input, dim_or_y=other)
  1942. def min(g, self, dim_or_y=None, keepdim=None):
  1943. # torch.min(input)
  1944. if dim_or_y is None and keepdim is None:
  1945. return g.op("ReduceMin", self, keepdims_i=0)
  1946. # torch.min(input, other)
  1947. if keepdim is None:
  1948. return op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)
  1949. # torch.min(input, dim, keepdim)
  1950. else:
  1951. dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
  1952. keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
  1953. min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
  1954. indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
  1955. return min, indices
  1956. def minimum(g, input, other):
  1957. return min(g, input, dim_or_y=other)
  1958. @symbolic_helper.parse_args("v", "is", "i")
  1959. def amax(g, self, dim, keepdim):
  1960. return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim)
  1961. @symbolic_helper.parse_args("v", "is", "i")
  1962. def amin(g, self, dim, keepdim):
  1963. return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim)
  1964. @symbolic_helper.parse_args("v", "v", "i")
  1965. def aminmax(g, self, dim, keepdim):
  1966. reduce_kwargs = {"keepdims_i": keepdim}
  1967. if not symbolic_helper._is_none(dim):
  1968. dim = symbolic_helper._get_const(dim, "i", "dim")
  1969. reduce_kwargs["axes_i"] = [dim]
  1970. return g.op("ReduceMin", self, **reduce_kwargs), g.op(
  1971. "ReduceMax", self, **reduce_kwargs
  1972. )
  1973. def exp(g, self):
  1974. return g.op("Exp", self)
  1975. @symbolic_helper.parse_args("v", "f", "i")
  1976. def dropout(g, input, p, train):
  1977. symbolic_helper.check_training_mode(train, "dropout")
  1978. # in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
  1979. if not train:
  1980. return input
  1981. warnings.warn(
  1982. "Dropout is a training op and should not be exported in inference mode. "
  1983. "For inference, make sure to call eval() on the model and to export it with param training=False."
  1984. )
  1985. r, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
  1986. return r
  1987. def _unsupported_dropout(name):
  1988. @symbolic_helper.parse_args("v", "f", "i")
  1989. def feature_dropout(g, input, p, train):
  1990. # NB: In inference mode, FeatureDropout is exported as an identity op.
  1991. if train:
  1992. return symbolic_helper._unimplemented(name, "training mode")
  1993. return input
  1994. return feature_dropout
  1995. feature_dropout = _unsupported_dropout("feature_dropout")
  1996. alpha_dropout = _unsupported_dropout("alpha_dropout")
  1997. feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout")
  1998. # See Note [Export inplace]
  1999. dropout_ = dropout
  2000. feature_dropout_ = feature_dropout
  2001. alpha_dropout_ = alpha_dropout
  2002. feature_alpha_dropout_ = feature_alpha_dropout
  2003. @symbolic_helper.parse_args("v", "t", "is", "i")
  2004. def norm(g, self, p, dim, keepdim):
  2005. if p == 1:
  2006. f = _reduce_op_symbolic("ReduceL1")
  2007. elif p == 2:
  2008. f = _reduce_op_symbolic("ReduceL2")
  2009. else:
  2010. raise RuntimeError("ONNX export only p-norms with p of 1 or 2")
  2011. return f(g, self, dim=dim, keepdim=keepdim)
  2012. @symbolic_helper.parse_args("v", "v", "v", "i")
  2013. def conv_tbc(g, input, weight, bias, pad):
  2014. if symbolic_helper.is_caffe2_aten_fallback():
  2015. return g.at("conv_tbc", input, weight, bias, pad_i=pad)
  2016. else:
  2017. # input must have 3 dimensions, see:
  2018. # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
  2019. # input = (time, batch, in_channels)
  2020. # weight = (kernel_width, in_channels, out_channels)
  2021. # bias = (out_channels,)
  2022. input = g.op("Transpose", input, perm_i=[1, 2, 0])
  2023. weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
  2024. conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
  2025. return g.op("Transpose", conv, perm_i=[2, 0, 1])
  2026. @symbolic_helper.parse_args("v", "i", "i")
  2027. def _unique(g, input, sorted, return_inverse):
  2028. if symbolic_helper.is_caffe2_aten_fallback():
  2029. return g.at(
  2030. "_unique",
  2031. input,
  2032. sorted_i=sorted,
  2033. return_inverse_i=return_inverse,
  2034. outputs=2,
  2035. )
  2036. else:
  2037. return symbolic_helper._onnx_unsupported("_unique")
  2038. @symbolic_helper.parse_args("v", "i", "i", "i")
  2039. def _unique2(g, input, sorted, return_inverse, return_counts):
  2040. if symbolic_helper.is_caffe2_aten_fallback():
  2041. return g.at(
  2042. "_unique2",
  2043. input,
  2044. sorted_i=sorted,
  2045. return_inverse_i=return_inverse,
  2046. return_counts_i=return_counts,
  2047. outputs=3,
  2048. )
  2049. else:
  2050. symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11)
  2051. # TODO(justinchuby): Clean up this function generation magic by defining the functions
  2052. # explicitly.
  2053. for k, v in symbolic_helper.cast_pytorch_to_onnx.items():
  2054. name = "_cast_{}".format(k)
  2055. globals()[name] = symbolic_helper.parse_args("v", "i")(
  2056. functools.partial(symbolic_helper._cast_func_template, v)
  2057. )
  2058. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  2059. def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None):
  2060. return zeros(g, sizes, dtype, layout, device, pin_memory)
  2061. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  2062. def empty_like(
  2063. g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None
  2064. ):
  2065. return zeros_like(g, input, dtype, layout, device, pin_memory)
  2066. def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False):
  2067. self_dtype = symbolic_helper._try_get_scalar_type(self)
  2068. if dtype is None and self_dtype is not None:
  2069. dtype = self_dtype
  2070. dtype = symbolic_helper.scalar_type_to_onnx.index(
  2071. symbolic_helper.cast_pytorch_to_onnx[dtype]
  2072. )
  2073. return empty(g, sizes, dtype, layout, device, pin_memory)
  2074. def scalar_tensor(g, scalar, dtype, *options):
  2075. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  2076. if dtype is None:
  2077. dtype = symbolic_helper.ScalarType.FLOAT
  2078. scalar = g.op("Cast", scalar, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2079. return scalar
  2080. def tensor(g, data, dtype=None, device=None, requires_grad=False):
  2081. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  2082. if symbolic_helper._is_packed_list(data):
  2083. if dtype is None:
  2084. dtype = symbolic_helper._unpack_list(data)[0].type().scalarType() # type: ignore[attr-defined]
  2085. # TODO(justinchuby): Remove type ignore after #81112 is checked in.
  2086. dtype = symbolic_helper.scalar_type_to_onnx.index(
  2087. symbolic_helper.cast_pytorch_to_onnx[dtype]
  2088. )
  2089. input_list = list()
  2090. for t in symbolic_helper._unpack_list(data):
  2091. shape_reference = g.op("Constant", value_t=torch.LongTensor([1]))
  2092. t = symbolic_helper._reshape_helper(g, t, shape_reference)
  2093. t = g.op("Cast", t, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2094. input_list.append(t)
  2095. return g.op("Concat", *input_list, axis_i=0)
  2096. else:
  2097. if dtype is None:
  2098. dtype = data.type().scalarType()
  2099. dtype = symbolic_helper.scalar_type_to_onnx.index(
  2100. symbolic_helper.cast_pytorch_to_onnx[dtype]
  2101. )
  2102. if symbolic_helper._is_list(data) and (
  2103. symbolic_helper._is_tensor_list(data)
  2104. or symbolic_helper._is_scalar_list(data)
  2105. ):
  2106. data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
  2107. return g.op("Cast", data, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2108. def as_tensor(g, data, dtype=None, device=None):
  2109. return tensor(g, data, dtype, device)
  2110. @symbolic_helper.parse_args("v", "i", "v", "v", "v")
  2111. def zeros(g, sizes, dtype, layout, device, pin_memory=False):
  2112. # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
  2113. if dtype is None:
  2114. dtype = symbolic_helper.ScalarType.FLOAT
  2115. sizes_ = symbolic_helper._maybe_get_const(sizes, "is")
  2116. if isinstance(sizes_, list) and len(sizes_) == 0:
  2117. sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
  2118. return g.op(
  2119. "ConstantOfShape",
  2120. sizes,
  2121. value_t=torch.tensor(
  2122. [0], dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  2123. ),
  2124. )
  2125. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  2126. def zeros_like(
  2127. g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None
  2128. ):
  2129. shape = g.op("Shape", input)
  2130. if dtype is None:
  2131. dtype = symbolic_helper.ScalarType.FLOAT
  2132. return g.op(
  2133. "ConstantOfShape",
  2134. shape,
  2135. value_t=torch.tensor(
  2136. [0], dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  2137. ),
  2138. )
  2139. def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False):
  2140. self_dtype = symbolic_helper._try_get_scalar_type(self)
  2141. if dtype is None and self_dtype is not None:
  2142. dtype = self_dtype
  2143. dtype = symbolic_helper.scalar_type_to_onnx.index(
  2144. symbolic_helper.cast_pytorch_to_onnx[dtype]
  2145. )
  2146. return zeros(g, sizes, dtype, layout, device, pin_memory)
  2147. @symbolic_helper.parse_args("v", "i", "v", "v", "v")
  2148. def ones(g, sizes, dtype, layout, device, pin_memory=False):
  2149. if dtype is None:
  2150. dtype = symbolic_helper.ScalarType.FLOAT
  2151. sizes_ = symbolic_helper._maybe_get_const(sizes, "is")
  2152. if isinstance(sizes_, list) and len(sizes_) == 0:
  2153. sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
  2154. return g.op(
  2155. "ConstantOfShape",
  2156. sizes,
  2157. value_t=torch.tensor(
  2158. [1], dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  2159. ),
  2160. )
  2161. @symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
  2162. def ones_like(
  2163. g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None
  2164. ):
  2165. shape = g.op("Shape", input)
  2166. if dtype is None:
  2167. dtype = symbolic_helper.ScalarType.FLOAT
  2168. return g.op(
  2169. "ConstantOfShape",
  2170. shape,
  2171. value_t=torch.tensor(
  2172. [1], dtype=symbolic_helper.scalar_type_to_pytorch_type[dtype]
  2173. ),
  2174. )
  2175. def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False):
  2176. self_dtype = symbolic_helper._try_get_scalar_type(self)
  2177. if dtype is None and self_dtype is not None:
  2178. dtype = self_dtype
  2179. dtype = symbolic_helper.scalar_type_to_onnx.index(
  2180. symbolic_helper.cast_pytorch_to_onnx[dtype]
  2181. )
  2182. return ones(g, sizes, dtype, layout, device, pin_memory)
  2183. def full(g, sizes, value, dtype, layout, device, pin_memory=False):
  2184. const_value = symbolic_helper._maybe_get_const(value, "t")
  2185. if symbolic_helper._is_value(const_value):
  2186. dtype = symbolic_helper.ScalarType.FLOAT if dtype is None else dtype
  2187. tmp = zeros(g, sizes, dtype, layout, device)
  2188. return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
  2189. else:
  2190. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  2191. dtype = symbolic_helper.ScalarType.FLOAT if dtype is None else dtype
  2192. sizes_ = symbolic_helper._maybe_get_const(sizes, "is")
  2193. if isinstance(sizes_, list) and len(sizes_) == 0:
  2194. sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
  2195. return g.op(
  2196. "ConstantOfShape",
  2197. sizes,
  2198. value_t=const_value.view(1).to(
  2199. symbolic_helper.scalar_type_to_pytorch_type[dtype]
  2200. ),
  2201. )
  2202. def full_like(
  2203. g,
  2204. input,
  2205. fill_value,
  2206. dtype=None,
  2207. layout=None,
  2208. device=None,
  2209. pin_memory=False,
  2210. memory_format=None,
  2211. ):
  2212. fill_value = symbolic_helper._maybe_get_const(fill_value, "f")
  2213. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  2214. dtype = symbolic_helper.ScalarType.FLOAT if dtype is None else dtype
  2215. if symbolic_helper._is_value(fill_value):
  2216. tmp = zeros_like(g, input, dtype, layout, device)
  2217. fill_value = g.op(
  2218. "Cast", fill_value, to_i=symbolic_helper.scalar_type_to_onnx[dtype]
  2219. )
  2220. return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1)))
  2221. else:
  2222. shape = g.op("Shape", input)
  2223. return g.op(
  2224. "ConstantOfShape",
  2225. shape,
  2226. value_t=torch.tensor([fill_value]).to(
  2227. symbolic_helper.scalar_type_to_pytorch_type[dtype]
  2228. ),
  2229. )
  2230. def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False):
  2231. self_dtype = symbolic_helper._try_get_scalar_type(self)
  2232. if dtype is None and self_dtype is not None:
  2233. dtype = self_dtype
  2234. dtype = symbolic_helper.scalar_type_to_onnx.index(
  2235. symbolic_helper.cast_pytorch_to_onnx[dtype]
  2236. )
  2237. return full(g, size, fill_value, dtype, layout, device, pin_memory)
  2238. def eye(g, *args):
  2239. if len(args) == 5:
  2240. # aten::eye(n, dtype, layout, device, pin_memory)
  2241. n, dtype, layout, device, pin_memory = args
  2242. dim_size = symbolic_helper._unsqueeze_helper(g, n, [0])
  2243. shape = g.op("Concat", dim_size, dim_size, axis_i=0)
  2244. tensor = zeros(g, shape, dtype, layout, device)
  2245. return g.op("EyeLike", tensor)
  2246. elif len(args) == 6:
  2247. # aten::eye(n, m, dtype, layout, device, pin_memory)
  2248. n, m, dtype, layout, device, pin_memory = args
  2249. shape = g.op(
  2250. "Concat",
  2251. symbolic_helper._unsqueeze_helper(g, n, [0]),
  2252. symbolic_helper._unsqueeze_helper(g, m, [0]),
  2253. axis_i=0,
  2254. )
  2255. tensor = zeros(g, shape, dtype, layout, device)
  2256. return g.op("EyeLike", tensor)
  2257. else:
  2258. raise NotImplementedError("Unknown aten::eye signature")
  2259. def slice(g, self, *args):
  2260. if len(args) == 4:
  2261. # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
  2262. dim, start, end, step = args
  2263. step = symbolic_helper._parse_arg(step, "i")
  2264. if step != 1:
  2265. raise RuntimeError("step!=1 is currently not supported")
  2266. is_start_none = (
  2267. start.node().kind() == "prim::Constant"
  2268. and start.type().kind() == "NoneType"
  2269. )
  2270. is_end_none = (
  2271. end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
  2272. )
  2273. is_start_onnx_const = start.node().kind() == "onnx::Constant"
  2274. is_end_onnx_const = end.node().kind() == "onnx::Constant"
  2275. if (
  2276. ((not is_start_none) and (not is_start_onnx_const))
  2277. or ((not is_end_none) and (not is_end_onnx_const))
  2278. or dim.node().kind() != "onnx::Constant"
  2279. ):
  2280. if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
  2281. raise RuntimeError(
  2282. "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
  2283. "is a deprecated experimental op. Please use statically allocated "
  2284. "variables or export to a higher opset version."
  2285. )
  2286. else:
  2287. start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0])
  2288. end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0])
  2289. dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0])
  2290. return g.op(
  2291. "DynamicSlice",
  2292. self,
  2293. start_unsqueezed,
  2294. end_unsqueezed,
  2295. dim_unsqueezed,
  2296. )
  2297. else:
  2298. start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
  2299. end = (
  2300. 9223372036854775807
  2301. if is_end_none
  2302. else symbolic_helper._parse_arg(end, "i")
  2303. )
  2304. dim = symbolic_helper._parse_arg(dim, "i")
  2305. return symbolic_helper._slice_helper(
  2306. g, self, axes=[dim], starts=[start], ends=[end]
  2307. )
  2308. elif len(args) == 3:
  2309. # aten::slice(t[] l, int start, int end, int step) -> t[]
  2310. start, end, step = args
  2311. dim = 0
  2312. is_start_none = (
  2313. start.node().kind() == "prim::Constant"
  2314. and start.type().kind() == "NoneType"
  2315. )
  2316. is_end_none = (
  2317. end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
  2318. )
  2319. start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
  2320. end = (
  2321. 9223372036854775807 if is_end_none else symbolic_helper._parse_arg(end, "i")
  2322. )
  2323. return symbolic_helper._slice_helper(
  2324. g, self, axes=[dim], starts=[start], ends=[end]
  2325. )
  2326. else:
  2327. raise NotImplementedError("Unknown aten::slice signature")
  2328. @symbolic_helper.parse_args("v", "f", "f")
  2329. def hardtanh(g, self, min_val, max_val):
  2330. return op_with_optional_float_cast(
  2331. g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12
  2332. )
  2333. @symbolic_helper.parse_args("v")
  2334. def hardswish(g, self):
  2335. hs = hardsigmoid(g, self)
  2336. return g.op("Mul", self, hs)
  2337. # Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
  2338. @symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
  2339. @symbolic_helper.parse_args("v")
  2340. def hardsigmoid(g, self):
  2341. # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid.
  2342. # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
  2343. return g.op("HardSigmoid", self, alpha_f=1 / 6)
  2344. @symbolic_helper.parse_args("v")
  2345. def tanhshrink(g, self):
  2346. return g.op("Sub", self, tanh(g, self))
  2347. @symbolic_helper.parse_args("v", "f")
  2348. def hardshrink(g, self, lambd):
  2349. lambd_op = g.op("Constant", value_t=torch.FloatTensor([lambd]))
  2350. cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op)))
  2351. return g.op("Where", cond, self, g.op("Constant", value_t=torch.FloatTensor([0])))
  2352. @symbolic_helper.parse_args("v", "f")
  2353. def softshrink(g, self, lambd):
  2354. lambd_op = g.op("Constant", value_t=torch.FloatTensor([lambd]))
  2355. gt_cond = gt(g, self, lambd_op)
  2356. gt_out = g.op(
  2357. "Where",
  2358. gt_cond,
  2359. sub(g, self, lambd_op),
  2360. g.op("Constant", value_t=torch.FloatTensor([0])),
  2361. )
  2362. lt_cond = lt(g, self, neg(g, lambd_op))
  2363. lt_out = g.op(
  2364. "Where",
  2365. lt_cond,
  2366. add(g, self, lambd_op),
  2367. g.op("Constant", value_t=torch.FloatTensor([0])),
  2368. )
  2369. return add(g, gt_out, lt_out)
  2370. def alias(g, self):
  2371. return self
  2372. @symbolic_helper.parse_args("v", "i")
  2373. def unsqueeze(g, self, dim):
  2374. # Handle negative dim
  2375. if dim < 0:
  2376. rank = symbolic_helper._get_tensor_rank(self)
  2377. if rank is not None:
  2378. warnings.warn(
  2379. "ONNX export unsqueeze with negative axis "
  2380. + str(dim)
  2381. + " might cause the onnx model to be incorrect. "
  2382. + "Negative axis is not supported in ONNX. "
  2383. + "Axis is converted to "
  2384. + str(dim + rank + 1)
  2385. + " based on input shape at export time. "
  2386. + "Passing an tensor of different rank in execution will be incorrect."
  2387. )
  2388. dim = dim + rank + 1
  2389. else:
  2390. return symbolic_helper._unimplemented(
  2391. "unsqueeze", "negative axis with unknown input rank"
  2392. )
  2393. return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])
  2394. @symbolic_helper.parse_args("v", "i", "i", "none")
  2395. def sort(g, self, dim, decending, out=None):
  2396. if out is not None:
  2397. symbolic_helper._unimplemented(
  2398. "Sort", "Out parameter is not supported for sort"
  2399. )
  2400. self_sizes = symbolic_helper._get_tensor_sizes(self)
  2401. try:
  2402. dim_size = self_sizes[dim]
  2403. except Exception:
  2404. dim_size = None
  2405. if dim_size is None:
  2406. return symbolic_helper._unimplemented("Sort", "input size not accessible")
  2407. return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
  2408. def numel(g, self):
  2409. shape = g.op("Shape", self)
  2410. return g.op("ReduceProd", shape, keepdims_i=0)
  2411. @symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
  2412. def topk(g, self, k, dim, largest, sorted, out=None):
  2413. if out is not None:
  2414. symbolic_helper._unimplemented(
  2415. "TopK", "Out parameter is not supported for topk"
  2416. )
  2417. if not largest:
  2418. symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported")
  2419. return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
  2420. def to(g, self, *args):
  2421. def is_aten_to_device_only(args):
  2422. if len(args) == 4:
  2423. # aten::to(Tensor, Device, bool, bool, memory_format)
  2424. return (
  2425. args[0].node().kind() == "prim::device"
  2426. or args[0].type().isSubtypeOf(_C.ListType.ofInts())
  2427. or isinstance(args[0].type(), _C.DeviceObjType)
  2428. )
  2429. elif len(args) == 5:
  2430. # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
  2431. # When dtype is None, this is a aten::to(device) call
  2432. dtype = symbolic_helper._get_const(args[1], "i", "dtype")
  2433. return dtype is None
  2434. elif len(args) in (6, 7):
  2435. # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
  2436. # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
  2437. # When dtype is None, this is a aten::to(device) call
  2438. dtype = symbolic_helper._get_const(args[0], "i", "dtype")
  2439. return dtype is None
  2440. return False
  2441. # ONNX doesn't have a concept of a device, so we ignore device-only casts
  2442. if is_aten_to_device_only(args):
  2443. return self
  2444. if len(args) == 4:
  2445. # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
  2446. # In this case, the constant value is a tensor not int,
  2447. # so symbolic_helper._maybe_get_const(args[0], 'i') would not work.
  2448. dtype = args[0]
  2449. if (
  2450. symbolic_helper._is_value(args[0])
  2451. and args[0].node().kind() == "onnx::Constant"
  2452. ):
  2453. tval = args[0].node()["value"]
  2454. if isinstance(tval, torch.Tensor):
  2455. if len(tval.shape) == 0:
  2456. tval = tval.item()
  2457. dtype = int(tval)
  2458. else:
  2459. dtype = tval
  2460. if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor):
  2461. # aten::to(Tensor, Tensor, bool, bool, memory_format)
  2462. dtype = args[0].type().scalarType()
  2463. return g.op("Cast", self, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  2464. else:
  2465. # aten::to(Tensor, ScalarType, bool, bool, memory_format)
  2466. # memory_format is ignored
  2467. return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2468. elif len(args) == 5:
  2469. # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
  2470. dtype = symbolic_helper._get_const(args[1], "i", "dtype")
  2471. # memory_format is ignored
  2472. return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2473. elif len(args) == 6:
  2474. # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
  2475. dtype = symbolic_helper._get_const(args[0], "i", "dtype")
  2476. # Layout, device and memory_format are ignored
  2477. return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2478. elif len(args) == 7:
  2479. # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
  2480. dtype = symbolic_helper._get_const(args[0], "i", "dtype")
  2481. # Layout, device and memory_format are ignored
  2482. return g.op("Cast", self, to_i=symbolic_helper.scalar_type_to_onnx[dtype])
  2483. else:
  2484. return symbolic_helper._onnx_unsupported("Unknown aten::to signature")
  2485. def repeat(g, self, repeats):
  2486. dtype = symbolic_helper.ScalarType.INT64
  2487. shape_ = ones_like(g, repeats, dtype)
  2488. self = g.op("Expand", self, shape_)
  2489. return g.op("Tile", self, repeats)
  2490. def repeat_interleave(g, self, repeats, dim=None, output_size=None):
  2491. input = self
  2492. # if dim is None flatten
  2493. # By default, use the flattened input array, and return a flat output array
  2494. if symbolic_helper._is_none(dim):
  2495. input = symbolic_helper._reshape_helper(
  2496. g, self, g.op("Constant", value_t=torch.tensor([-1]))
  2497. )
  2498. dim = 0
  2499. else:
  2500. dim = symbolic_helper._maybe_get_scalar(dim)
  2501. repeats_dim = symbolic_helper._get_tensor_rank(repeats)
  2502. repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
  2503. input_sizes = symbolic_helper._get_tensor_sizes(input)
  2504. if repeats_dim is None:
  2505. raise RuntimeError(
  2506. "Unsupported: ONNX export of repeat_interleave for unknown repeats rank."
  2507. )
  2508. if repeats_sizes is None:
  2509. raise RuntimeError(
  2510. "Unsupported: ONNX export of repeat_interleave for unknown repeats size."
  2511. )
  2512. if input_sizes is None:
  2513. raise RuntimeError(
  2514. "Unsupported: ONNX export of repeat_interleave for unknown input size."
  2515. )
  2516. input_sizes_temp = input_sizes.copy()
  2517. for idx, input_size in enumerate(input_sizes):
  2518. if input_size is None:
  2519. input_sizes[idx], input_sizes_temp[idx] = 0, -1
  2520. # Cases where repeats is an int or single value tensor
  2521. if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
  2522. if not symbolic_helper._is_tensor(repeats):
  2523. repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
  2524. if input_sizes[dim] == 0:
  2525. return symbolic_helper._onnx_opset_unsupported_detailed(
  2526. "repeat_interleave",
  2527. 9,
  2528. 13,
  2529. "Unsupported along dimension with unknown input size",
  2530. )
  2531. else:
  2532. reps = input_sizes[dim]
  2533. repeats = expand(
  2534. g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None
  2535. )
  2536. # Cases where repeats is a 1 dim Tensor
  2537. elif repeats_dim == 1:
  2538. if input_sizes[dim] == 0:
  2539. return symbolic_helper._onnx_opset_unsupported_detailed(
  2540. "repeat_interleave",
  2541. 9,
  2542. 13,
  2543. "Unsupported along dimension with unknown input size",
  2544. )
  2545. if repeats_sizes[0] is None:
  2546. return symbolic_helper._onnx_opset_unsupported_detailed(
  2547. "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats"
  2548. )
  2549. assert (
  2550. repeats_sizes[0] == input_sizes[dim]
  2551. ), "repeats must have the same size as input along dim"
  2552. reps = repeats_sizes[0]
  2553. else:
  2554. raise RuntimeError("repeats must be 0-dim or 1-dim tensor")
  2555. final_splits = list()
  2556. r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0)
  2557. i_splits = symbolic_helper._repeat_interleave_split_helper(g, input, reps, dim)
  2558. input_sizes[dim], input_sizes_temp[dim] = -1, 1
  2559. for idx, r_split in enumerate(r_splits):
  2560. i_split = unsqueeze(g, i_splits[idx], dim + 1)
  2561. r_concat = [
  2562. g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
  2563. r_split,
  2564. g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
  2565. ]
  2566. r_concat = g.op("Concat", *r_concat, axis_i=0)
  2567. i_split = expand(g, i_split, r_concat, None)
  2568. i_split = symbolic_helper._reshape_helper(
  2569. g,
  2570. i_split,
  2571. g.op("Constant", value_t=torch.LongTensor(input_sizes)),
  2572. allowzero=0,
  2573. )
  2574. final_splits.append(i_split)
  2575. return g.op("Concat", *final_splits, axis_i=dim)
  2576. @symbolic_helper.parse_args("v", "i")
  2577. def pixel_shuffle(g, self, upscale_factor):
  2578. dims = symbolic_helper._get_tensor_sizes(self)
  2579. if len(dims) != 4:
  2580. return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input")
  2581. if any(i is None for i in dims[1:]):
  2582. after_view = symbolic_helper._reshape_helper(
  2583. g,
  2584. symbolic_helper._unsqueeze_helper(g, self, [2, 3]),
  2585. g.op(
  2586. "Constant",
  2587. value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]),
  2588. ),
  2589. allowzero=0,
  2590. )
  2591. after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
  2592. # For dynamic input shapes, two reshapes are performed
  2593. reshape_h = symbolic_helper._reshape_helper(
  2594. g,
  2595. after_transpose,
  2596. g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
  2597. allowzero=0,
  2598. )
  2599. reshape_w = symbolic_helper._reshape_helper(
  2600. g,
  2601. reshape_h,
  2602. g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
  2603. allowzero=0,
  2604. )
  2605. return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5])
  2606. else:
  2607. output_channel = dims[1] // upscale_factor // upscale_factor
  2608. after_view = symbolic_helper._reshape_helper(
  2609. g,
  2610. self,
  2611. g.op(
  2612. "Constant",
  2613. value_t=torch.tensor(
  2614. [
  2615. -1,
  2616. output_channel,
  2617. upscale_factor,
  2618. upscale_factor,
  2619. dims[2],
  2620. dims[3],
  2621. ]
  2622. ),
  2623. ),
  2624. allowzero=0,
  2625. )
  2626. after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
  2627. return symbolic_helper._reshape_helper(
  2628. g,
  2629. after_transpose,
  2630. g.op(
  2631. "Constant",
  2632. value_t=torch.tensor(
  2633. [
  2634. -1,
  2635. output_channel,
  2636. dims[2] * upscale_factor,
  2637. dims[3] * upscale_factor,
  2638. ]
  2639. ),
  2640. ),
  2641. allowzero=0,
  2642. )
  2643. @symbolic_helper.parse_args("v", "i")
  2644. def pixel_unshuffle(g, self, downscale_factor):
  2645. dims = symbolic_helper._get_tensor_sizes(self)
  2646. if len(dims) != 4:
  2647. return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input")
  2648. if any(i is None for i in dims[1:]):
  2649. # For dynamic input shapes, two reshapes are performed
  2650. reshape_h = symbolic_helper._reshape_helper(
  2651. g,
  2652. symbolic_helper._unsqueeze_helper(g, self, [3]),
  2653. g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
  2654. allowzero=0,
  2655. )
  2656. reshape_w = symbolic_helper._reshape_helper(
  2657. g,
  2658. reshape_h,
  2659. g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
  2660. allowzero=0,
  2661. )
  2662. after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
  2663. final_reshape = symbolic_helper._reshape_helper(
  2664. g,
  2665. after_transpose,
  2666. g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
  2667. allowzero=0,
  2668. )
  2669. return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])
  2670. else:
  2671. output_channel = dims[1] * downscale_factor * downscale_factor
  2672. after_view = symbolic_helper._reshape_helper(
  2673. g,
  2674. self,
  2675. g.op(
  2676. "Constant",
  2677. value_t=torch.tensor(
  2678. [
  2679. -1,
  2680. dims[1],
  2681. dims[2] // downscale_factor,
  2682. downscale_factor,
  2683. dims[3] // downscale_factor,
  2684. downscale_factor,
  2685. ]
  2686. ),
  2687. ),
  2688. allowzero=0,
  2689. )
  2690. after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
  2691. return symbolic_helper._reshape_helper(
  2692. g,
  2693. after_transpose,
  2694. g.op(
  2695. "Constant",
  2696. value_t=torch.tensor(
  2697. [
  2698. -1,
  2699. output_channel,
  2700. dims[2] // downscale_factor,
  2701. dims[3] // downscale_factor,
  2702. ]
  2703. ),
  2704. ),
  2705. allowzero=0,
  2706. )
def _generic_rnn(
    g,
    variant,
    input,
    initial_states,
    all_weights,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first=None,
    batch_sizes=None,
):
    """Build the ONNX RNN/GRU/LSTM nodes shared by all recurrent exporters.

    One ONNX recurrent node is emitted per layer; the output of each layer
    feeds the next one.

    Args:
        g: Graph being built; new nodes are created through ``g.op``.
        variant: ``"LSTM"``, ``"GRU"``, or a string starting with ``"RNN"``
            whose text from index 4 onward names the activation
            (e.g. ``"RNN_TANH"``).
        input: Input sequence value.
        initial_states: ``h0`` for RNN/GRU, or an ``(h0, c0)`` pair for LSTM.
        all_weights: Flat list of weight (and bias) values for every layer
            and direction.
        has_biases: Truthy when each layer/direction carries 4 entries in
            ``all_weights`` (weights + biases) rather than 2.
        num_layers: Number of stacked layers.
        dropout: Dropout value; export is refused when both this and
            ``train`` are truthy.
        train: Training-mode flag.
        bidirectional: Truthy for bidirectional networks.
        batch_first: When truthy the input and output are transposed between
            (batch, seq, feat) and (seq, batch, feat).
        batch_sizes: Optional value passed as the ``sequence_lens`` input of
            the ONNX node; an ``unused(g)`` placeholder is used when ``None``.

    Returns:
        ``(output, h_n)`` for RNN/GRU, ``(output, h_n, c_n)`` for LSTM, or
        the result of ``symbolic_helper._unimplemented`` for unsupported
        configurations.
    """
    # NOTE: this warning is emitted unconditionally on every call.
    warnings.warn(
        "Exporting a model to ONNX with a batch_size other than 1, "
        + "with a variable length with "
        + variant
        + " can cause an error "
        + "when running the ONNX model with a different batch size. "
        + "Make sure to save the model with a batch size of 1, "
        + "or define the initial states (h0/c0) as inputs of the model. "
    )

    onnxActivations = [
        "Relu",
        "Tanh",
        "Sigmoid",
        "Affine",
        "LeakyRelu",
        "ThresholdedRelu",
        "ScaledTanh",
        "HardSigmoid",
        "Elu",
        "Softsign",
        "Softplus",
    ]
    # Lower-cased activation name -> ONNX spelling of that activation.
    variantToOnnxActivationMap = dict(
        zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations)
    )
    weights_per_layer = 4 if has_biases else 2
    # this means that projections are used inside LSTM, so need to tell user that it's not supported
    if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * (
        1 + bidirectional
    ):
        return symbolic_helper._unimplemented("LSTM", "LSTMs with projections")
    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    # Group the flat weight list into one chunk per (layer, direction).
    layer_weights = [
        all_weights[i : i + weights_per_layer]
        for i in range(0, len(all_weights), weights_per_layer)
    ]
    if batch_first:
        # batch, seq, feat -> seq, batch, feat
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if dropout and train:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "dropout in training mode"
        )

    if variant.startswith("RNN"):
        # e.g. "RNN_TANH" -> activation "Tanh"; variant collapses to plain "RNN".
        nonlinearity = variantToOnnxActivationMap[variant[4:].lower()]
        variant = "RNN"

    # The hidden size is read from dimension 1 of the second weight entry.
    w_hh = all_weights[1]
    hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1)
    if hidden_size is None:
        return symbolic_helper._unimplemented("RNN/GRU/LSTM", "unknown hidden size")

    unidirectional = not bidirectional

    prev_output = input

    h_outs = []
    if variant == "RNN" or variant == "GRU":
        h0 = initial_states
    elif variant == "LSTM":
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = unused(g) if batch_sizes is None else batch_sizes

    # reform_permutation lists (start, end) gate-block intervals, in units of
    # hidden_size, in the order they are concatenated for ONNX.
    if variant == "GRU":
        # pytorch is reset, input, hidden
        # onnx is input, reset, hidden
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == "LSTM":
        # pytorch is input, forget, cell, output.
        # onnx is input, output, forget, cell.
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    def reform_weights(g, w, n, intervals):
        # Slice w along axis 0 into [x * n, y * n) blocks and concatenate
        # them in the order given by intervals.
        slices = [
            symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n])
            for x, y in intervals
        ]
        return g.op("Concat", *slices, axis_i=0)

    def transform_weights_no_bias(layer_index):
        # Returns (weight_ih, weight_hh), gate-reordered for GRU/LSTM and
        # each unsqueezed at axis 0.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh = [
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            ]
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)
        )

    def transform_weights(layer_index):
        # Returns (weight_ih, weight_hh, bias_concat); the two biases are
        # concatenated along axis 0 and everything is unsqueezed at axis 0.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh, bias_ih, bias_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh, bias_ih, bias_hh = [
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            ]
        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0])
            for x in (weight_ih, weight_hh, bias_concat)
        )

    def retrieve_state(x, start, end):
        # With a single layer the state is used as-is; otherwise the
        # [start, end) slice along axis 0 is taken.
        return (
            x
            if num_layers == 1
            else symbolic_helper._slice_helper(
                g, x, axes=[0], starts=[start], ends=[end]
            )
        )

    for i in range(num_layers):
        if unidirectional:
            if weights_per_layer == 4:
                weight_ih, weight_hh, bias_concat = transform_weights(i)
            else:
                weight_ih, weight_hh = transform_weights_no_bias(i)
                bias_concat = unused(g)

            state_indices = i, i + 1
        else:
            # Forward weights sit at index 2 * i, backward at 2 * i + 1.
            if weights_per_layer == 4:
                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
            else:
                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
                bias_concat = unused(g)

            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

        inputs.append(retrieve_state(h0, *state_indices))
        if variant == "LSTM":
            inputs.append(retrieve_state(c0, *state_indices))

        extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
        if variant == "RNN":
            # One activation name per direction.
            if bidirectional:
                activation = [nonlinearity, nonlinearity]
            else:
                activation = [nonlinearity]

            prev_output, h_out = g.op(
                "RNN",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                activations_s=activation,
                **extra_kwargs,
            )
        elif variant == "GRU":
            prev_output, h_out = g.op(
                "GRU",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                linear_before_reset_i=1,
                **extra_kwargs,
            )
        elif variant == "LSTM":
            prev_output, h_out, c_out = g.op(
                "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs
            )

        if bidirectional:
            # The ONNX RNN/GRU/LSTM produce an output of dimensions
            # seq_len, num_directions, batch, hidden_size
            # We have to convert to match pytorch's expected
            # seq_len, batch, num_directions * hidden_size
            # by first moving num_directions before hidden_size with
            # Transpose, and then combining it with hidden_size
            # with Reshape.
            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
            prev_output = symbolic_helper._reshape_helper(
                g,
                prev_output,
                g.op("Constant", value_t=torch.LongTensor([0, 0, -1])),
                allowzero=0,
            )
        else:
            # Drop the num_directions axis (axis 1).
            prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

        h_outs.append(h_out)
        if variant == "LSTM":
            c_outs.append(c_out)
    if batch_first:
        # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
    # A single layer returns its state directly; otherwise per-layer states
    # are concatenated along axis 0.
    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
    if variant == "RNN" or variant == "GRU":
        return prev_output, h_outs
    elif variant == "LSTM":
        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
        return prev_output, h_outs, c_outs
  2906. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
  2907. def _lstm_full(
  2908. g,
  2909. input,
  2910. hidden_v,
  2911. weight_v,
  2912. has_biases,
  2913. num_layers,
  2914. dropout,
  2915. train,
  2916. bidirectional,
  2917. batch_first,
  2918. ):
  2919. hidden, weight = symbolic_helper._unpack_list(
  2920. hidden_v
  2921. ), symbolic_helper._unpack_list(weight_v)
  2922. return _generic_rnn(
  2923. g,
  2924. "LSTM",
  2925. input,
  2926. hidden,
  2927. weight,
  2928. has_biases,
  2929. num_layers,
  2930. dropout,
  2931. train,
  2932. bidirectional,
  2933. batch_first,
  2934. )
  2935. @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
  2936. def _lstm_packed(
  2937. g,
  2938. input,
  2939. batch_sizes,
  2940. hidden_v,
  2941. weight_v,
  2942. has_biases,
  2943. num_layers,
  2944. dropout,
  2945. train,
  2946. bidirectional,
  2947. ):
  2948. hidden, weight = symbolic_helper._unpack_list(
  2949. hidden_v
  2950. ), symbolic_helper._unpack_list(weight_v)
  2951. return _generic_rnn(
  2952. g,
  2953. "LSTM",
  2954. input,
  2955. hidden,
  2956. weight,
  2957. has_biases,
  2958. num_layers,
  2959. dropout,
  2960. train,
  2961. bidirectional,
  2962. batch_sizes=batch_sizes,
  2963. )
  2964. def lstm(g, *args):
  2965. if symbolic_helper._is_tensor_list(args[3]):
  2966. return _lstm_packed(g, *args)
  2967. else:
  2968. return _lstm_full(g, *args)
  2969. def lstm_cell(g, self, hidden, w_ih, w_hh, b_ih, b_hh):
  2970. input = symbolic_helper._unsqueeze_helper(g, self, [0])
  2971. hidden = symbolic_helper._unpack_list(hidden)
  2972. hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden]
  2973. weight = (
  2974. (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh)
  2975. )
  2976. has_biases = True if symbolic_helper._is_tensor(b_ih) else False
  2977. _, h_outs, c_outs = _generic_rnn(
  2978. g,
  2979. "LSTM",
  2980. input,
  2981. hidden,
  2982. weight,
  2983. has_biases,
  2984. num_layers=1,
  2985. dropout=0,
  2986. train=0,
  2987. bidirectional=False,
  2988. batch_first=False,
  2989. )
  2990. return symbolic_helper._squeeze_helper(
  2991. g, h_outs, [0]
  2992. ), symbolic_helper._squeeze_helper(g, c_outs, [0])
  2993. def _one_hidden_rnn(kind):
  2994. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
  2995. def _rnn_full(
  2996. g,
  2997. input,
  2998. hidden,
  2999. weight_v,
  3000. has_biases,
  3001. num_layers,
  3002. dropout,
  3003. train,
  3004. bidirectional,
  3005. batch_first,
  3006. ):
  3007. weight = symbolic_helper._unpack_list(weight_v)
  3008. return _generic_rnn(
  3009. g,
  3010. kind,
  3011. input,
  3012. hidden,
  3013. weight,
  3014. has_biases,
  3015. num_layers,
  3016. dropout,
  3017. train,
  3018. bidirectional,
  3019. batch_first,
  3020. )
  3021. @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
  3022. def _rnn_packed(
  3023. g,
  3024. input,
  3025. batch_sizes,
  3026. hidden,
  3027. weight_v,
  3028. has_biases,
  3029. num_layers,
  3030. dropout,
  3031. train,
  3032. bidirectional,
  3033. ):
  3034. weight = symbolic_helper._unpack_list(weight_v)
  3035. return _generic_rnn(
  3036. g,
  3037. kind,
  3038. input,
  3039. hidden,
  3040. weight,
  3041. has_biases,
  3042. num_layers,
  3043. dropout,
  3044. train,
  3045. bidirectional,
  3046. batch_sizes=batch_sizes,
  3047. )
  3048. def symbolic(g, *args):
  3049. if symbolic_helper._is_tensor_list(args[3]):
  3050. return _rnn_packed(g, *args)
  3051. else:
  3052. return _rnn_full(g, *args)
  3053. return symbolic
# Symbolic functions for the single-hidden-state recurrent ops, each built by
# _one_hidden_rnn with the variant string it forwards to _generic_rnn.
gru = _one_hidden_rnn("GRU")
rnn_tanh = _one_hidden_rnn("RNN_TANH")
rnn_relu = _one_hidden_rnn("RNN_RELU")
  3057. @symbolic_helper.parse_args("v", "i")
  3058. def _dim_arange(g, like, dim):
  3059. like_shape = g.op("Shape", like)
  3060. stop = g.op(
  3061. "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
  3062. )
  3063. if symbolic_helper.is_caffe2_aten_fallback():
  3064. return g.op("_caffe2::Range", stop)
  3065. else:
  3066. # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
  3067. return arange(g, stop, 4, None, None, None)
  3068. def detach(g, input):
  3069. # Erase aten::detach nodes because ONNX is inference only
  3070. return input
  3071. @symbolic_helper.parse_args("v", "i")
  3072. def contiguous(g, input, memory_format):
  3073. if memory_format > 2: # allower values are any, preserve and contiguous_format
  3074. raise RuntimeError("onnx memory_format support is not implemented")
  3075. return input
  3076. @symbolic_helper.parse_args("v", "v", "i")
  3077. def _pack_padded_sequence(g, input, lengths, batch_first):
  3078. # Currently there is no PackPadded operator in ONNX. We rely on an
  3079. # optimization pass to remove this later. It is an error if all
  3080. # PackPadded operators cannot be optimized out.
  3081. if batch_first:
  3082. input = g.op("Transpose", input, perm_i=[1, 0, 2])
  3083. if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
  3084. raise RuntimeError("Lengths must be a Tensor for ONNX export")
  3085. # We know it's a TensorType so this check is now safe.
  3086. # It's really only necessary because those operators expand to something that
  3087. # only works with int32 types in Caffe2...
  3088. if lengths.type().scalarType() != "Int":
  3089. lengths = _cast_Int(g, lengths, False) # type: ignore[name-defined]
  3090. return g.op("prim::PackPadded", input, lengths, outputs=2)
  3091. @symbolic_helper.parse_args("v", "v", "i", "t", "v")
  3092. def _pad_packed_sequence(
  3093. g, data, batch_sizes, batch_first, padding_value, total_length
  3094. ):
  3095. # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence
  3096. # It is only useful/used when training using data_parallel model, so
  3097. # It shouldn't be relevant for ONNX anyway
  3098. data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
  3099. if batch_first:
  3100. data = g.op("Transpose", data, perm_i=[1, 0, 2])
  3101. return data, lengths
  3102. def randn(g, shapes, dtype, *options):
  3103. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  3104. if dtype is None:
  3105. dtype = symbolic_helper.ScalarType.FLOAT
  3106. shape = symbolic_helper._maybe_get_const(shapes, "is")
  3107. if symbolic_helper._is_value(shape):
  3108. shape_const = g.op(
  3109. "ConstantOfShape",
  3110. shapes,
  3111. value_t=torch.tensor(
  3112. [0], dtype=symbolic_helper.scalar_type_to_pytorch_type[6]
  3113. ),
  3114. )
  3115. return g.op(
  3116. "RandomNormalLike",
  3117. shape_const,
  3118. dtype_i=symbolic_helper.scalar_type_to_onnx[dtype],
  3119. )
  3120. return g.op("RandomNormal", shape_i=shape)
  3121. def rand(g, shapes, dtype, *options):
  3122. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  3123. if dtype is None:
  3124. dtype = symbolic_helper.ScalarType.FLOAT
  3125. shape = symbolic_helper._maybe_get_const(shapes, "is")
  3126. if symbolic_helper._is_value(shape):
  3127. shape_const = g.op(
  3128. "ConstantOfShape",
  3129. shapes,
  3130. value_t=torch.tensor(
  3131. [0], dtype=symbolic_helper.scalar_type_to_pytorch_type[6]
  3132. ),
  3133. )
  3134. return g.op(
  3135. "RandomUniformLike",
  3136. shape_const,
  3137. dtype_i=symbolic_helper.scalar_type_to_onnx[dtype],
  3138. )
  3139. return g.op("RandomUniform", shape_i=shape)
  3140. def randn_like(
  3141. g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None
  3142. ):
  3143. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  3144. if dtype is None:
  3145. dtype = symbolic_helper.ScalarType.FLOAT
  3146. return g.op(
  3147. "RandomNormalLike", self, dtype_i=symbolic_helper.scalar_type_to_onnx[dtype]
  3148. )
  3149. def rand_like(
  3150. g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None
  3151. ):
  3152. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  3153. if dtype is None:
  3154. dtype = symbolic_helper.ScalarType.FLOAT
  3155. return g.op(
  3156. "RandomUniformLike", self, dtype_i=symbolic_helper.scalar_type_to_onnx[dtype]
  3157. )
  3158. @symbolic_helper.parse_args("v", "f", "f", "i", "none")
  3159. def rrelu(g, input, lower, upper, training, generator):
  3160. p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
  3161. return g.op("PRelu", input, p)
  3162. def bernoulli(g, input, generator=None, out=None):
  3163. if out is not None:
  3164. symbolic_helper._unimplemented(
  3165. "Bernoulli", "out parameter is not supported for bernoulli"
  3166. )
  3167. if generator is not None and not symbolic_helper._is_none(generator):
  3168. symbolic_helper._unimplemented(
  3169. "Bernoulli", "generator is not supported for bernoulli"
  3170. )
  3171. dtype = symbolic_helper._try_get_scalar_type(input)
  3172. if dtype is None:
  3173. return symbolic_helper._unimplemented("Bernoulli", "input dtype not accessible")
  3174. p = g.op(
  3175. "RandomUniformLike",
  3176. input,
  3177. high_f=1.0,
  3178. low_f=0.0,
  3179. dtype_i=symbolic_helper.cast_pytorch_to_onnx[dtype],
  3180. )
  3181. output = g.op("Less", p, input)
  3182. return g.op("Cast", output, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
  3183. @symbolic_helper.parse_args("v")
  3184. def log_sigmoid(g, input):
  3185. p = g.op("Sigmoid", input)
  3186. return g.op("Log", p)
  3187. @symbolic_helper.parse_args("v")
  3188. def erf(g, input):
  3189. return g.op("Erf", input)
  3190. @symbolic_helper.quantized_args(True, False, False)
  3191. @symbolic_helper.parse_args("v", "i", "i")
  3192. def flatten(g, input, start_dim, end_dim):
  3193. dim = symbolic_helper._get_tensor_rank(input)
  3194. if dim is None:
  3195. return symbolic_helper._unimplemented(
  3196. "dim",
  3197. "ONNX and PyTorch use different strategies to split the input. "
  3198. "Input rank must be known at export time.",
  3199. )
  3200. # TODO: remove this as onnx opset 11 spec allows negative axes
  3201. if end_dim < 0:
  3202. end_dim = dim + end_dim
  3203. # use ONNX's Flatten operator for cases where the output shape is 2D
  3204. if start_dim == 1 and end_dim == dim - 1:
  3205. return g.op("Flatten", input, axis_i=start_dim)
  3206. if start_dim == 0 and end_dim == dim - 2:
  3207. return g.op("Flatten", input, axis_i=end_dim + 1)
  3208. return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)
  3209. @symbolic_helper.parse_args("v")
  3210. def nonzero(g, input):
  3211. """Emitted from `torch.nonzero(x, as_tuple=False)`"""
  3212. return t(g, g.op("NonZero", input))
  3213. # Emitted from `torch.nonzero(x, as_tuple=True)`
  3214. def nonzero_numpy(g, input, _outputs=None):
  3215. return unbind(g, nonzero(g, input), 1, _outputs=_outputs)
  3216. @symbolic_helper.parse_args("v")
  3217. def isnan(g, input):
  3218. output = g.op("IsNaN", input)
  3219. return output
  3220. def _any(g, *args):
  3221. # aten::any(Tensor self)
  3222. if len(args) == 1:
  3223. input = args[0]
  3224. dim, keepdim = None, 0
  3225. # aten::any(Tensor self, int dim, bool keepdim)
  3226. else:
  3227. input, dim, keepdim = args
  3228. dim = [symbolic_helper._parse_arg(dim, "i")]
  3229. keepdim = symbolic_helper._parse_arg(keepdim, "i")
  3230. input = _cast_Long(g, input, False) # type: ignore[name-defined]
  3231. input_sum = symbolic_helper._reducesum_helper(
  3232. g, input, axes_i=dim, keepdims_i=keepdim
  3233. )
  3234. return gt(g, input_sum, g.op("Constant", value_t=torch.LongTensor([0])))
  3235. def _all(g, *args):
  3236. input = g.op("Not", args[0])
  3237. # aten::all(Tensor self)
  3238. if len(args) == 1:
  3239. return g.op("Not", _any(g, input))
  3240. # aten::all(Tensor self, int dim, bool keepdim)
  3241. else:
  3242. return g.op("Not", _any(g, input, args[1], args[2]))
  3243. @symbolic_helper.parse_args("v", "i", "i", "i")
  3244. def narrow(g, input, dim, start, length):
  3245. return symbolic_helper._slice_helper(
  3246. g, input, axes=[dim], starts=[start], ends=[start + length]
  3247. )
  3248. def argmax(g, input, dim, keepdim):
  3249. if symbolic_helper._is_none(dim):
  3250. flattened = symbolic_helper._reshape_helper(
  3251. g, input, g.op("Constant", value_t=torch.tensor([-1]))
  3252. )
  3253. return g.op("ArgMax", flattened, axis_i=0, keepdims_i=False)
  3254. else:
  3255. dim = symbolic_helper._parse_arg(dim, "i")
  3256. keepdim = symbolic_helper._parse_arg(keepdim, "i")
  3257. return g.op("ArgMax", input, axis_i=dim, keepdims_i=keepdim)
  3258. def argmin(g, input, dim, keepdim):
  3259. if symbolic_helper._is_none(dim):
  3260. flattened = symbolic_helper._reshape_helper(
  3261. g, input, g.op("Constant", value_t=torch.tensor([-1]))
  3262. )
  3263. return g.op("ArgMin", flattened, axis_i=0, keepdims_i=False)
  3264. else:
  3265. dim = symbolic_helper._parse_arg(dim, "i")
  3266. keepdim = symbolic_helper._parse_arg(keepdim, "i")
  3267. return g.op("ArgMin", input, axis_i=dim, keepdims_i=keepdim)
  3268. @symbolic_helper.parse_args("v", "i", "v", "v")
  3269. def scatter(g, self, dim, index, src):
  3270. src_type = src.type().scalarType()
  3271. src = symbolic_helper._maybe_get_scalar(src)
  3272. if symbolic_helper._is_value(src):
  3273. return g.op("Scatter", self, index, src, axis_i=dim)
  3274. else:
  3275. # Check if scalar "src" has same type as self (PyTorch allows different
  3276. # type for scalar src (but not when src is tensor)). If not, insert Cast node.
  3277. if self.type().scalarType() != src_type:
  3278. src = g.op(
  3279. "Cast",
  3280. src,
  3281. to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
  3282. )
  3283. return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim)
  3284. @symbolic_helper.parse_args("v", "i", "v", "v")
  3285. def scatter_add(g, self, dim, index, src):
  3286. dtype = symbolic_helper._try_get_scalar_type(self)
  3287. if dtype is None:
  3288. return symbolic_helper._unimplemented(
  3289. "scatter_add", "input dtype not accessible"
  3290. )
  3291. dtype = symbolic_helper.scalar_type_to_onnx.index(
  3292. symbolic_helper.cast_pytorch_to_onnx[dtype]
  3293. )
  3294. dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]
  3295. sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
  3296. if sizes:
  3297. to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=dtype))
  3298. else:
  3299. dtype = symbolic_helper.scalar_type_to_pytorch_type.index(dtype)
  3300. to_add = zeros_like(g, self, dtype)
  3301. to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src)
  3302. return add(g, self, to_add)
  3303. def log2(g, self):
  3304. _ln2 = 0.693147180559945309
  3305. return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln2])))
  3306. def is_floating_point(g, self):
  3307. if symbolic_helper._is_fp(self):
  3308. return g.op("Constant", value_t=torch.BoolTensor([1]))
  3309. return g.op("Constant", value_t=torch.BoolTensor([0]))
  3310. def __is_(g, self, other):
  3311. if symbolic_helper._is_none(other):
  3312. if symbolic_helper._is_none(self):
  3313. return g.op("Constant", value_t=torch.BoolTensor([1]))
  3314. return g.op("Constant", value_t=torch.BoolTensor([0]))
  3315. return eq(g, self, other)
  3316. @wrap_logical_op_with_negation
  3317. def __isnot_(g, self, other):
  3318. return __is_(g, self, other)
  3319. def one_hot(g, self, num_classes):
  3320. values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
  3321. # onnxruntime supports limited type combinations for OneHot.
  3322. if num_classes.type().scalarType() in ("Byte", "Char", "Int", "Short"):
  3323. num_classes = g.op(
  3324. "Cast", num_classes, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"]
  3325. )
  3326. return g.op("OneHot", self, num_classes, values, axis_i=-1)
  3327. @symbolic_helper.parse_args("v", "i", "v", "v")
  3328. def gather(g, self, dim, index, sparse_grad=False):
  3329. if symbolic_helper._maybe_get_const(sparse_grad, "i"):
  3330. return symbolic_helper._unimplemented("gather", "sparse_grad == True")
  3331. # NOTE: This workaround is needed since GatherElement is only supported
  3332. # since opset 11, and Gather in ONNX is not the same as torch.gather.
  3333. dtype = self.type().scalarType()
  3334. values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
  3335. depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
  3336. index = g.op(
  3337. "Cast",
  3338. g.op("OneHot", index, depth, values, axis_i=dim),
  3339. to_i=symbolic_helper.cast_pytorch_to_onnx[dtype],
  3340. )
  3341. mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index)
  3342. return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0)
  3343. @symbolic_helper.parse_args("v", "is", "i", "i")
  3344. def _var_mean(g, input, dim, correction, keepdim):
  3345. if dim is None:
  3346. mean = g.op("ReduceMean", input, keepdims_i=0)
  3347. t_mean = mean
  3348. num_elements = numel(g, input)
  3349. else:
  3350. mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
  3351. t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
  3352. redudced_dims = g.op("Shape", input)
  3353. # dim could contain one or multiple dimensions
  3354. redudced_dims = g.op(
  3355. "Gather",
  3356. redudced_dims,
  3357. g.op("Constant", value_t=torch.tensor(dim)),
  3358. axis_i=0,
  3359. )
  3360. num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
  3361. sub_v = g.op("Sub", input, t_mean)
  3362. sqr_sub = g.op("Mul", sub_v, sub_v)
  3363. keepdim_mean = 0 if dim is None else keepdim
  3364. var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
  3365. # Correct bias in calculating variance, by dividing it over (N - correction) instead on N
  3366. if correction is None:
  3367. correction = 1
  3368. if correction != 0:
  3369. num_elements = g.op(
  3370. "Cast", num_elements, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"]
  3371. )
  3372. one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
  3373. mul = g.op("Mul", var, num_elements)
  3374. var = g.op("Div", mul, g.op("Sub", num_elements, one))
  3375. return var, mean
  3376. def std(g, input, *args):
  3377. var, _ = var_mean(g, input, *args)
  3378. return g.op("Sqrt", var)
  3379. def var(g, input, *args):
  3380. var, _ = var_mean(g, input, *args)
  3381. return var
  3382. # var_mean (and all variance-related functions) has multiple signatures, so need to manually figure
  3383. # out the correct arguments:
  3384. # aten::var_mean(Tensor self, bool unbiased)
  3385. # aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
  3386. # aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)
  3387. def var_mean(g, input, *args):
  3388. if len(args) == 1:
  3389. return _var_mean(g, input, None, args[0], None)
  3390. else:
  3391. return _var_mean(g, input, *args)
  3392. def std_mean(g, input, *args):
  3393. var, mean = var_mean(g, input, *args)
  3394. return g.op("Sqrt", var), mean
  3395. @symbolic_helper.parse_args("v", "is", "i")
  3396. def logsumexp(g, input, dim, keepdim):
  3397. return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim)
  def arange(g, *args):
      """Export the ``aten::arange`` overloads, selected by argument count.

      Handled counts are 2 or 5 (end only), 4 or 7 (start, end, step) and
      6 (start, end). Any other count raises ``NotImplementedError``. When the
      Caffe2 ATen fallback is active the call is forwarded to ``g.at``.
      """
      if symbolic_helper.is_caffe2_aten_fallback():
          return g.at("arange", *args)

      def _const_dtype(dtype_arg):
          # Ask the helper for the dtype argument as an "i" constant.
          return symbolic_helper._maybe_get_const(dtype_arg, "i")

      def _ceil_if_float(length):
          # Floating-point lengths are rounded up and cast to the ONNX type
          # registered for scalar type index 4 (presumably Long -- confirm).
          if not symbolic_helper._is_fp(length):
              return length
          return g.op(
              "Cast",
              g.op("Ceil", length),
              to_i=symbolic_helper.scalar_type_to_onnx[4],
          )

      def _positions(length, *ones_args):
          # Indices of the non-zero entries of a ones tensor of the given
          # length, with axis 1 squeezed away.
          filled = ones(g, length, *ones_args)
          return symbolic_helper._squeeze_helper(g, nonzero(g, filled), [1])

      def _cast_to(tensor, scalar_type):
          return g.op(
              "Cast", tensor, to_i=symbolic_helper.scalar_type_to_onnx[scalar_type]
          )

      arg_count = len(args)

      if arg_count in (2, 5):
          # 2: aten::arange(Scalar end, Tensor out)
          # 5: aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
          dtype = None if arg_count == 2 else _const_dtype(args[1])
          dtype, end, start, step = symbolic_helper._arange_cast_helper(
              g, end=args[0], dtype=dtype
          )
          end = symbolic_helper._unsqueeze_helper(g, end, [0])
          length = _ceil_if_float(end)
          return _cast_to(_positions(length, dtype, None, None), dtype)

      if arg_count in (4, 7):
          # 4: aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
          # 7: aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
          dtype = None if arg_count == 4 else _const_dtype(args[3])
          dtype, end, start, step = symbolic_helper._arange_cast_helper(
              g, start=args[0], end=args[1], step=args[2], dtype=dtype
          )
          step = symbolic_helper._unsqueeze_helper(g, step, [0])
          end = symbolic_helper._unsqueeze_helper(g, end, [0])
          start = symbolic_helper._unsqueeze_helper(g, start, [0])
          length = _ceil_if_float(g.op("Div", g.op("Sub", end, start), step))
          positions = _positions(length, None, None, None)
          # value = position * step + start
          scaled = g.op("Mul", positions, step)
          return _cast_to(g.op("Add", scaled, start), dtype)

      if arg_count == 6:
          # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
          dtype = _const_dtype(args[2])
          dtype, end, start, step = symbolic_helper._arange_cast_helper(
              g, start=args[0], end=args[1], dtype=dtype
          )
          end = symbolic_helper._unsqueeze_helper(g, end, [0])
          start = symbolic_helper._unsqueeze_helper(g, start, [0])
          length = _ceil_if_float(g.op("Sub", end, start))
          positions = _positions(length, dtype, *(args[3:]))
          return _cast_to(g.op("Add", positions, start), dtype)

      raise NotImplementedError(
          "Unknown aten::arange signature taking " + str(len(args)) + " arguments."
      )
  def linspace(g, start, end, steps, dtype, layout, device, pin_memory):
      """Export ``aten::linspace`` as ``range(steps) * step + start``.

      The step is ``(end - start) / (steps - 1)``. ``dtype``, ``layout``,
      ``device`` and ``pin_memory`` are accepted but not used here.
      """
      index_range = symbolic_helper._arange_helper(g, steps, None)
      span = sub(g, end, start)
      one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
      interval_count = sub(g, steps, one)
      step = div(g, span, interval_count)
      scaled = mul(g, index_range, step)
      return add(g, scaled, start)
  def lift(g, self):
      """Export ``at::lift`` by returning the input value untouched.

      For ONNX tracing purposes the op is a no-op, so no node is emitted.
      """
      unchanged = self
      return unchanged
  def masked_fill(g, self, mask, value):
      """Export ``aten::masked_fill`` as an ONNX ``Where`` node.

      The mask is cast to Bool; where it is true the (scalar) fill value, cast
      to match ``self`` by ``_if_scalar_type_as``, is selected, otherwise
      ``self`` is kept.
      """
      bool_mask = _cast_Bool(g, mask, False)  # type: ignore[name-defined]
      scalar_value = symbolic_helper._maybe_get_scalar(value)
      typed_value = symbolic_helper._if_scalar_type_as(g, scalar_value, self)
      return g.op("Where", bool_mask, typed_value, self)
  def index(g, self, index):
      """Export ``aten::index`` (tensor indexing, including advanced indexing).

      ``index`` is either a single index value or a packed list with one entry
      per axis, where a None entry stands for ``:``. Byte/Bool entries are
      first turned into positions via ``nonzero``. A single index becomes a
      select on axis 0; one advanced index becomes ``index_select``; several
      advanced indices are lowered to Transpose/Flatten/Gather/Reshape.

      Raises:
          RuntimeError: a mask index is exported with opset version below 9.
          NotImplementedError: several advanced indices are used on a tensor
              whose rank is unknown.
      """
      if symbolic_helper.is_caffe2_aten_fallback():
          return g.at("index", self, index, overload_name="Tensor")

      def _mask_to_positions(candidate):
          # Leave None placeholders and non-mask indices as they are.
          if symbolic_helper._is_none(candidate):
              return candidate
          if candidate.type().scalarType() not in ("Byte", "Bool"):
              return candidate
          if GLOBALS.export_onnx_opset_version < 9:
              raise RuntimeError(
                  "Exporting masked indices are only supported after ONNX opset 9."
              )
          warnings.warn(
              "Exporting aten::index operator with indices of type Byte. "
              "Only 1-D indices are supported. In any other case, "
              "this will produce an incorrect ONNX graph."
          )
          return symbolic_helper._squeeze_helper(g, nonzero(g, candidate), [1])

      if symbolic_helper._is_packed_list(index):
          raw_indices = symbolic_helper._unpack_list(index)
      else:
          raw_indices = [index]
      indices = [_mask_to_positions(entry) for entry in raw_indices]

      if len(indices) == 1:
          return symbolic_helper._select_helper(
              g, self, 0, indices[0], apply_reshape=False
          )

      # Several index entries. Each one is either
      #   1. prim::Constant() -- the ":" of python indexing, e.g. tensor[:, :]
      #   2. prim::Constant[value=...] or a tensor output -- advanced indexing,
      #      e.g. tensor[[0, 1], [2, 0]].
      # See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
      #
      # General case:
      #   t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
      # with rank m+n, where {x_i} are the axes that have a tensor index and
      # {y_i} the axes indexed with ":". Transposing to
      #   t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
      # would allow a gatherND, but ONNX has none here, so t is flattened to
      #   t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
      # and a 1-d gather is done with
      #   tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
      # followed by a reshape and transpose back.
      adv_axes = [
          axis
          for axis, entry in enumerate(indices)
          if not symbolic_helper._is_none(entry)
      ]

      if len(adv_axes) == 0:
          return self
      if len(adv_axes) == 1:
          return index_select(g, self, adv_axes[0], indices[adv_axes[0]])

      rank = symbolic_helper._get_tensor_rank(self)
      if rank is None:
          raise NotImplementedError(
              "Unsupported aten::index operator of advanced indexing on tensor of unknown rank, "
              + "try turning on shape and type propagate during export: "
              + "torch.onnx._export(..., propagate=True)."
          )
      # TODO: If indexing is supported natively in ONNX in future opsets,
      # update the warning to recommend exporting with higher opset version.
      warnings.warn(
          "Exporting aten::index operator of advanced indexing in opset "
          + str(GLOBALS.export_onnx_opset_version)
          + " is achieved by combination of multiple ONNX operators, "
          + "including Reshape, Transpose, Concat, and Gather. "
          + "If indices include negative values, the exported graph will produce incorrect results."
      )

      adv_count = len(adv_axes)
      first_adv, last_adv = adv_axes[0], adv_axes[-1]
      basic_axes = [axis for axis in range(rank) if axis not in adv_axes]

      # One single-element shape tensor per axis of the input.
      shape_tensor = _shape_as_tensor(g, self)
      axis_sizes = [
          g.op(
              "Gather",
              shape_tensor,
              g.op("Constant", value_t=torch.LongTensor([axis])),
              axis_i=0,
          )
          for axis in range(rank)
      ]

      # Move the advanced-index axes to the front and flatten around them.
      self = g.op("Transpose", self, perm_i=adv_axes + basic_axes)
      self = g.op("Flatten", self, axis_i=adv_count)

      # Accumulate the flat index from the last advanced axis backwards.
      # The index tensors broadcast while accumulating, which also yields
      # the shape of the final sub-array.
      flat_index = indices[last_adv]
      multiplier = axis_sizes[last_adv]
      for axis in reversed(adv_axes[:-1]):
          weighted = g.op("Mul", indices[axis], multiplier)
          flat_index = g.op("Add", flat_index, weighted)
          multiplier = g.op("Mul", multiplier, axis_sizes[axis])

      # perform gather
      self = index_select(g, self, 0, flat_index)
      flat_index_shape = _shape_as_tensor(g, flat_index)

      # Where the sub-array lands depends on whether the advanced indices
      # are consecutive; see
      # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
      if adv_axes == list(range(first_adv, last_adv + 1)):
          # unfold regular index axes
          folded_shape_parts = [g.op("Constant", value_t=torch.LongTensor([-1]))] + [
              axis_sizes[axis] for axis in basic_axes
          ]
          folded_shape = g.op("Concat", *folded_shape_parts, axis_i=0)
          self = symbolic_helper._reshape_helper(g, self, folded_shape)

          # Transpose folded advanced indexed axis to its original location.
          restore_perm = (
              list(range(1, first_adv + 1))
              + [0]
              + list(range(first_adv + 1, rank - adv_count + 1))
          )
          self = g.op("Transpose", self, perm_i=restore_perm)

          # unfold advanced index axes
          final_shape_parts = (
              [axis_sizes[axis] for axis in range(first_adv)]
              + [flat_index_shape]
              + [
                  axis_sizes[axis]
                  for axis in range(first_adv, rank)
                  if axis not in adv_axes
              ]
          )
          final_shape = g.op("Concat", *final_shape_parts, axis_i=0)
      else:
          final_shape = g.op(
              "Concat",
              flat_index_shape,
              *[axis_sizes[axis] for axis in basic_axes],
              axis_i=0,
          )
      return symbolic_helper._reshape_helper(g, self, final_shape)
  @symbolic_helper.parse_args("v", "v", "is", "i", "v")
  def linalg_norm(g, self, ord, dim, keepdim, dtype):
      """Export ``aten::linalg_norm`` by routing to the vector or matrix norm.

      Conditions follow
      https://pytorch.org/docs/stable/generated/torch.linalg.norm.html

      A float ``ord`` is parsed only when the input is treated as a vector:
      ``dim`` is None and the (possibly flattened) input has rank 1, or ``dim``
      names exactly one axis. In those cases a missing ``ord`` defaults to 2.
      Everything else goes to ``linalg_matrix_norm``; when ``dim`` is None and
      the rank is not 1, ``dim`` becomes ``[0, 1]``.
      """
      ord_value = None
      if dim is None:
          if symbolic_helper._is_none(ord):
              self = symbolic_helper._reshape_helper(g, self, [-1])
              ord = g.op("Constant", value_t=torch.LongTensor([2]))
          self_dim = symbolic_helper._get_tensor_rank(self)
          if self_dim is None:
              return symbolic_helper._unimplemented(
                  "dim", "Input rank must be known at export time."
              )
          if self_dim == 1:
              ord_value = symbolic_helper._parse_arg(ord, "f")
          else:
              dim = [0, 1]
      else:
          if len(dim) == 1:
              if symbolic_helper._is_none(ord):
                  ord = g.op("Constant", value_t=torch.LongTensor([2]))
              ord_value = symbolic_helper._parse_arg(ord, "f")
      # Compare against None rather than truthiness: a parsed ord of 0.0 is a
      # vector-norm request and must reach linalg_vector_norm (which reports
      # ord=0 as unsupported) instead of falling through to the matrix norm.
      if ord_value is not None:
          return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
      return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
  @symbolic_helper.parse_args("v", "f", "is", "i", "v")
  def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
      """Export ``aten::linalg_vector_norm``.

      Conditions follow
      https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html

      ``ord`` of +inf / -inf maps to ReduceMax / ReduceMin of ``|x|``; ``ord``
      of 0 is reported as unsupported; any other value is computed as
      ``sum(|x| ** ord) ** (1 / ord)``. When ``dim`` is None the input is
      flattened first and ``keepdim`` is reset to None.
      """
      if dim is None:
          self = symbolic_helper._reshape_helper(g, self, [-1])
          keepdim = None

      if ord == math.inf:
          return g.op("ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
      if ord == -math.inf:
          return g.op("ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)
      if ord == 0:
          return symbolic_helper._onnx_opset_unsupported_detailed(
              "linalg_vector_norm", 9, 11, "ord=0 not supported"
          )

      exponent = g.op("Constant", value_t=torch.FloatTensor([ord]))
      powered = g.op("Pow", g.op("Abs", self), exponent)
      summed = symbolic_helper._reducesum_helper(
          g, powered, axes_i=dim, keepdims_i=keepdim
      )
      one = g.op("Constant", value_t=torch.FloatTensor([1]))
      inverse_exponent = g.op("Div", one, exponent)
      return g.op("Pow", summed, inverse_exponent)
  @symbolic_helper.parse_args("v", "v", "is", "i", "v")
  def linalg_matrix_norm(g, self, ord, dim, keepdim, dtype):
      """Export ``aten::linalg_matrix_norm``.

      Conditions follow
      https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html

      ``"fro"`` (or an ``ord`` that parses to None) uses ``frobenius_norm``;
      ``"nuc"`` and ``ord`` of 2 / -2 are reported as unimplemented. The
      remaining cases (1 / -1 and inf / -inf) sum absolute values along one
      axis and take the max (positive ``ord``) or min of the sums along the
      other. Note that ``dim`` is adjusted in place.
      """
      ord_value = symbolic_helper._parse_arg(ord, "s")
      if ord_value == "fro":
          return frobenius_norm(g, self, dim, keepdim)
      if ord_value == "nuc":
          return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc")

      ord_value = symbolic_helper._parse_arg(ord, "f")
      if ord_value is None:
          return frobenius_norm(g, self, dim, keepdim)
      if ord_value in (2, -2):
          # ord = 2/-2 unimplemented due to lack of operators
          # used to calculate singular values
          return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2")

      # Wrap the dim vector to handle negative dim values
      self_dim = symbolic_helper._get_tensor_rank(self)
      if self_dim is None:
          return symbolic_helper._unimplemented(
              "linalg.matrix_norm", "Input rank must be known at export time."
          )

      # Shared implementation for ord = 1/-1 and ord = inf/-inf
      for position in (0, 1):
          if dim[position] < 0:
              dim[position] += self_dim

      if ord_value in (math.inf, -math.inf):
          # The infinity norms reduce over the two axes in the opposite order.
          dim[0], dim[1] = dim[1], dim[0]

      if dim[1] > dim[0] and not keepdim:
          # The first reduction drops an axis, shifting the second one down.
          dim[1] -= 1

      abs_sum = symbolic_helper._reducesum_helper(
          g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim
      )
      # ``max`` / ``min`` here are this module's symbolic functions.
      pick_extreme = max if ord_value > 0 else min
      result, _indices = pick_extreme(
          g,
          abs_sum,
          dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
          keepdim=keepdim,
      )
      return result
  @symbolic_helper.parse_args("v", "v", "i")
  def linalg_cross(g, input, other, dim=-1):
      """Export ``aten::linalg_cross`` by delegating to ``cross`` (``dim`` defaults to -1)."""
      product = cross(g, input, other, dim)
      return product
  @symbolic_helper.parse_args("v", "is", "i")
  def frobenius_norm(g, self, dim=None, keepdim=False):
      """Export ``aten::frobenius_norm`` as ``Sqrt(ReduceSum(x * x))``.

      ``dim`` and ``keepdim`` are passed to the reduce-sum helper as its axes
      and keepdims settings.
      """
      squared = g.op("Mul", self, self)
      sum_of_squares = symbolic_helper._reducesum_helper(
          g, squared, axes_i=dim, keepdims_i=keepdim
      )
      return g.op("Sqrt", sum_of_squares)
  @symbolic_helper.parse_args("v", "i", "b", "v")
  def multinomial(g, input, num_samples, replacement=False, generator=None):
      """Export ``aten::multinomial`` as ONNX ``Multinomial`` on ``log(input)``.

      A non-None generator, and sampling without replacement when more than
      one sample is requested, are both reported through ``_unimplemented``.
      The output dtype attribute is the ONNX type registered for "Long".
      """
      generator_given = not (
          generator is None or symbolic_helper._is_none(generator)
      )
      if generator_given:
          symbolic_helper._unimplemented(
              "Multinomial", "generator is not supported for multinomial"
          )

      if num_samples > 1 and not replacement:
          symbolic_helper._unimplemented(
              "Multinomial",
              "replacement=False when num_samples > 1 is not supported for multinomial",
          )

      log_probabilities = log(g, input)
      long_type = symbolic_helper.cast_pytorch_to_onnx["Long"]
      return g.op(
          "Multinomial",
          log_probabilities,
          dtype_i=long_type,
          sample_size_i=num_samples,
      )
  def baddbmm(g, self, batch1, batch2, beta, alpha):
      """Export ``aten::baddbmm`` as ``alpha * (batch1 @ batch2) + beta * self``.

      ``alpha`` and ``beta`` are cast to the scalar type of ``self`` before
      being multiplied in.
      """
      self_scalar_type = self.type().scalarType()
      product = matmul(g, batch1, batch2)

      alpha_cast = g.op(
          "Cast", alpha, to_i=symbolic_helper.cast_pytorch_to_onnx[self_scalar_type]
      )
      scaled_product = mul(g, product, alpha_cast)

      beta_cast = g.op(
          "Cast", beta, to_i=symbolic_helper.cast_pytorch_to_onnx[self_scalar_type]
      )
      scaled_self = mul(g, self, beta_cast)

      return add(g, scaled_product, scaled_self)
  @symbolic_helper.parse_args("v", "s")
  def meshgrid(g, tensor_list, indexing: Optional[str] = None):
      """Export ``aten::meshgrid`` as a list of expanded 1-D tensors.

      Every input is flattened to 1-D, reshaped so that it varies along its
      own axis only, and expanded to the concatenation of all input lengths.

      Args:
          tensor_list: packed list value holding the input tensors.
          indexing: ``"ij"`` (used when None) or ``"xy"``. With ``"xy"`` the
              first two inputs, and afterwards the first two outputs, are
              swapped.

      Raises:
          ValueError: ``indexing`` is neither ``"ij"`` nor ``"xy"``.
      """
      if indexing is None:
          indexing = "ij"
      elif indexing not in {"ij", "xy"}:
          raise ValueError(f"Unsupported indexing: {indexing}")

      # The packed list is a graph value and cannot be subscripted, so unpack
      # it into a Python list before reordering anything.
      unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list)
      if indexing == "xy":
          # Slice assignment swaps the first two entries and is a no-op when
          # fewer than two tensors were given.
          unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1]

      tensors = [
          symbolic_helper._reshape_helper(
              g, t, g.op("Constant", value_t=torch.LongTensor([-1]))
          )
          for t in unpacked_tensor_list
      ]
      tensors_shape = [g.op("Shape", t) for t in tensors]
      out_shape = g.op("Concat", *tensors_shape, axis_i=0)

      out = []
      for i, t in enumerate(tensors):
          # Shape of all ones except this tensor's own length at position i.
          shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(
              tensors
          )
          shape_i[i] = tensors_shape[i]
          t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
          out.append(g.op("Expand", t_reshaped, out_shape))

      if indexing == "xy":
          out[:2] = out[1::-1]
      return g.op("prim::ListConstruct", *out)
  def remainder(g, input, other):
      """Export ``aten::remainder`` as ``input - floor_divide(input, other) * other``."""
      floored_quotient = _floor_divide(g, input, other)
      multiple = g.op("Mul", floored_quotient, other)
      return g.op("Sub", input, multiple)
  @symbolic_helper.parse_args("v", "s")
  def gelu(g, self: torch._C.Value, approximate: str = "none"):
      """Export ``aten::gelu``.

      With ``approximate == "tanh"`` the result is
      ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))``.
      Any other value uses the erf form ``x * (erf(x / sqrt(2)) + 1) * 0.5``.
      All constants are created as double tensors.
      """
      if approximate != "tanh":
          sqrt_two = 1.4142135623730951
          scaled = g.op("Div", self, torch.tensor(sqrt_two, dtype=torch.double))
          erf_value = g.op("Erf", scaled)
          one_const = g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
          erf_plus_one = add(g, erf_value, one_const)
          weighted = mul(g, self, erf_plus_one)
          half_const = g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double))
          return mul(g, weighted, half_const)

      beta = torch.tensor(math.sqrt(2 / math.pi), dtype=torch.double)
      kappa = torch.tensor(0.044715, dtype=torch.double)
      one = torch.tensor(1.0, dtype=torch.double)
      half = torch.tensor(0.5, dtype=torch.double)

      squared = mul(g, self, self)
      cubed = mul(g, self, squared)
      polynomial = add(g, self, mul(g, kappa, cubed))
      inner = mul(g, beta, polynomial)
      one_plus_tanh = add(g, one, g.op("Tanh", inner))
      gated = mul(g, self, one_plus_tanh)
      return mul(g, half, gated)
  @symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
  def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
      """Export ``aten::group_norm`` through ONNX ``InstanceNormalization``.

      The input is reshaped to ``[0, num_groups, -1]``, instance-normalized
      with unit scale and zero shift, reshaped back to the input shape, and
      then scaled and shifted by ``weight`` and ``bias`` (1 and 0 when they
      are None). With the Caffe2 ATen fallback the op is forwarded to ``g.at``.
      An unknown input rank is reported as unimplemented.
      """
      if symbolic_helper.is_caffe2_aten_fallback():
          return g.at(
              "group_norm",
              input,
              weight,
              bias,
              num_groups_i=num_groups,
              eps_f=eps,
              cudnn_enabled_i=cudnn_enabled,
          )

      channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
      if channel_size is not None:
          assert channel_size % num_groups == 0

      input_rank = symbolic_helper._get_tensor_rank(input)
      if input_rank is None:
          return symbolic_helper._unimplemented("group_norm", "unknown input rank")

      # A 0 in the target shape leaves that dimension unchanged.
      grouped_shape = [0, num_groups, -1]
      input_reshaped = symbolic_helper._reshape_helper(
          g, input, g.op("Constant", value_t=torch.LongTensor(grouped_shape))
      )

      # C is always divisible by num_groups. Because of the shape difference,
      # the real weight and bias are applied after the instance norm and the
      # reshape back; the norm itself gets neutral parameters.
      tensor_type = "torch." + input.type().scalarType() + "Tensor"
      unit_scale = g.op(
          "Constant",
          value_t=torch.tensor([1.0] * num_groups).type(tensor_type),
      )
      zero_shift = g.op(
          "Constant",
          value_t=torch.tensor([0.0] * num_groups).type(tensor_type),
      )

      norm_reshaped = g.op(
          "InstanceNormalization",
          input_reshaped,
          unit_scale,
          zero_shift,
          epsilon_f=eps,
      )
      norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input))

      if weight is None or weight.node().mustBeNone():
          weight = g.op("Constant", value_t=torch.tensor([1.0]).type(tensor_type))
      if bias is None or bias.node().mustBeNone():
          bias = g.op("Constant", value_t=torch.tensor([0.0]).type(tensor_type))

      # Norm has shape [N, C, *] so weight and bias are reshaped to [C, *]
      trailing_axes = list(range(1, input_rank - 1))
      scaled = mul(
          g, norm, symbolic_helper._unsqueeze_helper(g, weight, trailing_axes)
      )
      shift = symbolic_helper._unsqueeze_helper(g, bias, trailing_axes)
      return add(g, scaled, shift)
  @symbolic_helper.parse_args("v", "v", "i")
  def _weight_norm(g, weight_v, weight_g, dim):
      """Export ``aten::_weight_norm`` as ``weight_g * weight_v / ||weight_v||``.

      The L2 norm is taken over every axis except ``dim`` and keeps the
      reduced dimensions. If the rank of ``weight_v`` is unknown, the op is
      forwarded to ``g.at`` under the Caffe2 ATen fallback and raises
      ``RuntimeError`` otherwise.
      """
      rank = symbolic_helper._get_tensor_rank(weight_v)
      if rank is None:
          if symbolic_helper.is_caffe2_aten_fallback():
              return g.at("_weight_norm", weight_v, weight_g, dim_i=dim)
          raise RuntimeError(
              "Unsupported: ONNX export of _weight_norm for tensor " "of unknown rank."
          )

      # W = g * ((v) / ||v||)
      # Compute norm_except_dim for the l2 norm; dim = None means all dims.
      # torch's weight_norm module sets dim = -1 if it's None, which clashes
      # with using negative axes to address dims from the back.
      # TODO: Might need a fix in torch group_norm module
      reduce_axes = list(range(rank))
      if dim is not None:
          if dim < -1:
              dim += rank
          if dim != -1:
              reduce_axes.remove(dim)

      v_norm = norm(g, weight_v, 2, reduce_axes, 1)
      direction = g.op("Div", weight_v, v_norm)
      return g.op("Mul", direction, weight_g)
  def dim(g, self):
      """Implement the dim functionality available for a pytorch tensor in ONNX"""
      # There is no direct dim operator in this opset, so two ops are chained:
      # Size applied to the output of Shape.
      shape_of_self = g.op("Shape", self)
      return g.op("Size", shape_of_self)
  def __getitem_(g, self, i):
      """Export ``aten::__getitem__`` as ``select`` of entry ``i`` along axis 0."""
      axis_zero = g.op("Constant", value_t=torch.tensor([0]))
      return select(g, self, axis_zero, i)
  def item(g, self):
      """Export ``aten::item`` by returning the input value untouched."""
      unchanged = self
      return unchanged
  def take(g, self, index):
      """Export ``aten::take``: gather from the flattened input, shaped like ``index``."""
      flat_shape = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
      flattened = symbolic_helper._reshape_helper(g, self, flat_shape)
      gathered = index_select(g, flattened, 0, index)
      return reshape_as(g, gathered, index)
  def _kl_div_log_target_impl(g, input, target):
      """Pointwise KL term for a log-space target: ``exp(target) * (target - input)``."""
      difference = sub(g, target, input)
      target_prob = exp(g, target)
      return mul(g, target_prob, difference)
  def _kl_div_non_log_target_impl(g, input, target):
      """Pointwise KL term for a plain target: ``target * (log(target) - input)``.

      Positions where ``target`` is not greater than 0 are replaced by zeros.
      """
      log_target = log(g, target)
      difference = sub(g, log_target, input)
      positive_term = mul(g, target, difference)
      zero_fill = zeros_like(g, positive_term)
      is_positive = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
      return where(g, is_positive, positive_term, zero_fill)
  @symbolic_helper.parse_args("v", "v", "i", "b")
  def kl_div(g, input, target, reduction, log_target):
      """Export ``aten::kl_div``.

      The pointwise term comes from the log-target or non-log-target helper.
      ``reduction`` 0 returns it as is, 1 applies ``ReduceMean``, 2 applies a
      reduce-sum (both without keeping dims); any other value is reported as
      unsupported.
      """
      pointwise_impl = (
          _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
      )
      output = pointwise_impl(g, input, target)

      if reduction == 0:
          return output
      if reduction == 1:
          return g.op("ReduceMean", output, keepdims_i=0)
      if reduction == 2:
          return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
      return symbolic_helper._onnx_unsupported(
          "kl_div with reduction other than none, mean, or sum."
      )
  @symbolic_helper.parse_args("v", "v", "is", "i")
  def as_strided(g, self, sizes, strides, offset=None):
      """Export ``aten::as_strided`` as a ``Gather`` on the flattened input.

      The gather indices are ``offset + sum_i(arange(sizes[i]) * strides[i])``,
      with each ``arange`` reshaped to vary along its own axis so the terms
      broadcast to the requested shape.

      Args:
          sizes: output sizes. If they resolve to constants the index tensor
              is computed at export time, otherwise it is built in the graph.
          strides: constant stride per output axis.
          offset: optional constant offset; None or 0 adds nothing.
      """
      sizes = symbolic_helper._maybe_get_const(sizes, "is")
      rank = len(strides)
      self_1d = symbolic_helper._reshape_helper(
          g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
      )
      ind: Optional[torch.Tensor]
      if not symbolic_helper._is_value(sizes):
          # Constant sizes: compute the whole index tensor now.
          ind = torch.tensor([0], dtype=torch.long)
          for i, (size, stride) in enumerate(zip(sizes, strides)):
              r_size = [1] * rank
              r_size[i] = -1
              ind = ind + torch.arange(size).view(r_size) * stride
          if offset:
              ind = ind + offset
          return g.op("Gather", self_1d, g.op("Constant", value_t=ind))
      else:
          # Dynamic sizes: emit the same computation as graph nodes.
          ind = None
          for i, stride in enumerate(strides):
              r_size = [1] * rank
              r_size[i] = -1
              size = select(
                  g,
                  sizes,
                  g.op("Constant", value_t=torch.tensor([0])),
                  g.op("Constant", value_t=torch.tensor(i)),
              )
              tmp_ind = symbolic_helper._reshape_helper(
                  g,
                  arange(g, size, 4, None, None, None),
                  g.op("Constant", value_t=torch.tensor(r_size)),
              )
              tmp_ind = g.op(
                  "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
              )
              if ind is None:
                  ind = tmp_ind
              else:
                  ind = g.op("Add", ind, tmp_ind)
          if offset:
              # The constant payload must be the ``value_t`` attribute; passed
              # positionally it would become an input of the Constant node.
              ind = g.op(
                  "Add", ind, g.op("Constant", value_t=torch.tensor([offset]))
              )
          return g.op("Gather", self_1d, ind)
  def __derive_index(g, index, start, step):
      """Export ``aten::__derive_index`` as ``start + index * step``."""
      scaled_index = g.op("Mul", index, step)
      return g.op("Add", start, scaled_index)
  # The reference implementation of the aten op lives in
  # pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp:
  # if (step > 0 && lo < hi) {
  #   push(stack, 1 + (hi - 1 - lo) / step);
  # } else if (step < 0 && lo > hi) {
  #   push(stack, 1 + (lo - 1 - hi) / (0 - step));
  # } else {
  #   push(stack, 0);
  # }
  def __range_length(g, lo, hi, step):
      """Export ``aten::__range_length`` as ``ceil((hi - lo) / step)`` cast to Long."""
      span = g.op("Sub", hi, lo)
      step_count = g.op("Ceil", true_divide(g, span, step))
      long_type = symbolic_helper.cast_pytorch_to_onnx["Long"]
      return g.op("Cast", step_count, to_i=long_type)
  def linear(g, input, weight, bias):
      """Export ``aten::linear``.

      The weight is transposed first. A rank-2 input with a bias goes through
      ``addmm`` with alpha and beta of 1; otherwise ``matmul`` is used and the
      bias, if present, is added afterwards.
      """
      rank = symbolic_helper._get_tensor_rank(input)
      weight = t(g, weight)

      if rank == 2 and not bias.node().mustBeNone():
          alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
          beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
          return addmm(g, bias, input, weight, alpha, beta)

      output = matmul(g, input, weight)
      if bias.node().mustBeNone():
          return output
      return add(g, bias, output)
  @symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
  def hann_window(
      g,
      window_length,
      periodic=True,
      dtype=None,
      layout=None,
      device=None,
      pin_memory=None,
      requires_grad=False,
  ):
      """Export ``aten::hann_window`` as ``sin(pi * n / N) ** 2``.

      ``n`` is ``arange(window_length)`` and ``N`` is ``window_length``, or
      ``window_length - 1`` when ``periodic`` is False. The computation runs
      in Float and the result is cast to the requested dtype. When ``dtype``
      is None the default dtype is used, replaced by ``torch.float`` if it is
      not floating point.
      """
      if dtype is None:
          default_dtype = torch.get_default_dtype()
          if not default_dtype or not default_dtype.is_floating_point:
              default_dtype = torch.float
          # Convert the torch dtype to its scalar type index.
          dtype = symbolic_helper.scalar_type_to_pytorch_type.index(default_dtype)

      positions = arange(g, window_length, 4, None, None, None)
      angle = g.op(
          "Cast", positions, to_i=symbolic_helper.cast_pytorch_to_onnx["Float"]
      )
      pi_const = g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float))
      angle = mul(g, pi_const, angle)

      if periodic is False:
          one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
          window_length = sub(g, window_length, one)
      angle = div(g, angle, window_length)

      window = square(g, sin(g, angle))
      return g.op(
          "Cast",
          window,
          to_i=symbolic_helper.scalar_type_to_onnx[dtype],
      )
  def mv(g, self, vec):
      """Export ``aten::mv`` (matrix-vector product) through ``matmul``."""
      product = matmul(g, self, vec)
      return product
  def dot(g, self, other):
      """Export ``aten::dot`` through ``matmul``."""
      product = matmul(g, self, other)
      return product
  @symbolic_helper.parse_args("v", "v")
  def fill(g, self, value):
      """Export ``aten::fill`` as ``full_like(self, value)`` with the dtype of ``self``.

      The scalar type name of ``self`` is mapped to its index in
      ``scalar_type_to_onnx``; if the name is None, ``ScalarType.FLOAT`` is
      used.
      """
      scalar_name = self.type().scalarType()
      if scalar_name is not None:
          dtype = symbolic_helper.scalar_type_to_onnx.index(
              symbolic_helper.cast_pytorch_to_onnx[scalar_name]
          )
      else:
          dtype = symbolic_helper.ScalarType.FLOAT
      return full_like(g, self, value, dtype)
  def index_add(g, self, dim, index, other, alpha=None):
      """Export ``aten::index_add`` through ``scatter_add``.

      ``other`` is padded with trailing axes up to the rank of ``self`` and
      expanded against a slice of ``self`` that has size 1 along ``dim``;
      ``index`` is unsqueezed to the rank of ``self`` and expanded to match.
      A warning about duplicated index values is always emitted.

      Raises:
          NotImplementedError: ``dim`` is not a constant, a rank is unknown,
              or ``other`` is known to be larger than ``self`` along ``dim``.
      """
      warnings.warn(
          "Warning: ONNX export does not support duplicated values in 'index' field, "
          + "this will cause the ONNX model to be incorrect."
      )

      # ONNX does not support "alpha" argument, unlike aten index_add
      # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context
      if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
          return symbolic_helper._unimplemented("index_add", "alpha != 1")

      dim = symbolic_helper._maybe_get_const(dim, "i")
      if dim is None:
          raise NotImplementedError(
              "ONNX export does NOT support exporting 'index_add_()' function with "
              + "unknown 'dim' value."
          )

      self_rank = symbolic_helper._get_tensor_rank(self)
      other_rank = symbolic_helper._get_tensor_rank(other)
      if self_rank is None or other_rank is None:
          raise NotImplementedError(
              "ONNX export does NOT support exporting 'index_add_()' function while "
              + "the rank of self tensor or tensor to be added is unknown."
          )

      if other_rank != self_rank:
          # Append trailing axes to ``other`` until the ranks agree.
          for _ in range(self_rank - other_rank):
              other = symbolic_helper._unsqueeze_helper(
                  g, other, [symbolic_helper._get_tensor_rank(other)]
              )

      other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim)
      self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
      sizes_known = (other_dim_size is not None) and (self_dim_size is not None)
      if sizes_known and other_dim_size > self_dim_size:
          raise NotImplementedError(
              "ONNX export does NOT support exporting 'index_add_()' function with "
              + "duplicated values in 'index' parameter yet."
          )

      # Build a reference tensor shaped like ``self`` except for size 1 along
      # ``dim``, so that the other dimensions of ``other`` expand as expected.
      all_axes = list(range(self_rank))
      slice_starts = [0] * self_rank
      slice_ends = [1 if axis == dim else sys.maxsize for axis in all_axes]
      reference = symbolic_helper._slice_helper(
          g, self, axes=all_axes, starts=slice_starts, ends=slice_ends
      )
      other = expand_as(g, other, reference)

      # Give ``index`` one leading axis per dimension before ``dim`` ...
      for _ in range(dim):
          index = symbolic_helper._unsqueeze_helper(g, index, [0])
      # ... and one trailing axis per dimension after it.
      for _ in range(self_rank - dim - 1):
          index = symbolic_helper._unsqueeze_helper(
              g, index, [symbolic_helper._get_tensor_rank(index)]
          )

      expanded_index = expand_as(g, index, other)
      return scatter_add(g, self, dim, expanded_index, other)
  @symbolic_helper.parse_args("v", "is", "is")
  def roll(g, self, shifts, dims):
      """Export ``aten::roll`` as one slice-and-concat per ``(shift, dim)`` pair.

      For each pair the slice starting at ``-shift`` is placed in front of
      the slice ending at ``-shift`` along that dimension. ``shifts`` and
      ``dims`` must have the same length.
      """
      assert len(shifts) == len(dims)

      rolled = self
      for shift, axis in zip(shifts, dims):
          tail = symbolic_helper._slice_helper(
              g, rolled, axes=[axis], starts=[-shift], ends=[sys.maxsize]
          )
          head = symbolic_helper._slice_helper(
              g, rolled, axes=[axis], starts=[0], ends=[-shift]
          )
          rolled = g.op("Concat", tail, head, axis_i=axis)
      return rolled
  @symbolic_helper.parse_args("v", "v", "i")
  def cross(g, input, other, dim=None):
      """Export ``aten::cross`` using two pairs of rolls along ``dim``.

      With A = [a, b, c] and B = [d, e, f] the result is
      [(b*f - c*e), (c*d - a*f), (a*e - b*d)].
      """
      dim = symbolic_helper._get_dim_for_cross(input, dim)

      # First pair of rolls:
      # A' = [b, c, a], B' = [f, d, e], giving (b*f, c*d, a*e)
      input_by_two = roll(g, input, [2], [dim])
      other_by_one = roll(g, other, [1], [dim])

      # Second pair of rolls:
      # A' = [c, a, b], B' = [e, f, d], giving (c*e, a*f, b*d)
      input_by_one = roll(g, input, [1], [dim])
      other_by_two = roll(g, other, [2], [dim])

      # The cross product is the difference of the two products.
      positive_terms = mul(g, input_by_two, other_by_one)
      negative_terms = mul(g, input_by_one, other_by_two)
      return sub(g, positive_terms, negative_terms)
  4155. def cdist(g, x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
  4156. # X1.shape = (B * P * D), X2.shape = (B * R * D)
  4157. # In order to respect numpy style broadcasting as demonstrated in
  4158. # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
  4159. # we unsqueeze both input tensors
  4160. # Currently we ignore the 'compute_mode' variable as we use default to
  4161. # using matrix multiplication to calculate the euclidean distance
  4162. rank = symbolic_helper._get_tensor_rank(x1)
  4163. broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1])
  4164. broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2])
  4165. return pairwise_distance(
  4166. g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False
  4167. )
  4168. def broadcast_tensors(g, self):
  4169. all_tensors = symbolic_helper._unpack_list(self)
  4170. t_with_final_shape = zeros_like(g, all_tensors[0])
  4171. # Add operator supports multidirectional broadcasting. So we leverage this function
  4172. # to infer the final shape generated by the broadcast.
  4173. for t in all_tensors:
  4174. t_with_final_shape = add(g, t_with_final_shape, t)
  4175. t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors]
  4176. return g.op("prim::ListConstruct", *t_list)
  4177. class Prim:
  4178. domain = "prim"
  4179. @staticmethod
  4180. def ConstantSplit(g, self, split_size, dim):
  4181. size = symbolic_helper._get_tensor_dim_size(self, dim)
  4182. if size is None:
  4183. return symbolic_helper._unimplemented(
  4184. "prim::ConstantSplit", "unknown dimension size"
  4185. )
  4186. splits = [split_size] * (size // split_size)
  4187. leftover = size % split_size
  4188. if leftover:
  4189. splits.append(leftover)
  4190. return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
  4191. # TODO: It would be better to export this as a chunk directly, as this is
  4192. # less sensitive to changes in input size.
  4193. # TODO: Once we have proper scoping, stop reimplementing chunk, delete this
  4194. # method, and use the desugared version
  4195. @staticmethod
  4196. def ConstantChunk(g, self, chunks, dim):
  4197. dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
  4198. if dim_size is None:
  4199. return symbolic_helper._unimplemented(
  4200. "prim::ConstantChunk", "unknown dimension size"
  4201. )
  4202. split_size = (dim_size + chunks - 1) // chunks
  4203. return Prim.ConstantSplit(g, self, split_size, dim)
  4204. @staticmethod
  4205. def shape(g, self):
  4206. return g.op("Shape", self)
  4207. @staticmethod
  4208. def max(g, self, other):
  4209. return op_with_optional_float_cast(g, "Max", self, other, opset_before=12)
  4210. @staticmethod
  4211. def min(g, self, other=None):
  4212. if not other:
  4213. if symbolic_helper._is_packed_list(self):
  4214. self = stack(g, self, g.op("Constant", value_t=torch.tensor([0])))
  4215. return min(g, self)
  4216. return min(g, self, other)
  4217. @staticmethod
  4218. def data(g, self):
  4219. return self
  4220. @staticmethod
  4221. def ListConstruct(g, *inputs, **kwargs):
  4222. return None
  4223. @staticmethod
  4224. def ListUnpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]:
  4225. if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
  4226. # Cancel the previous node if it is ListConstruct by returning its inputs
  4227. # TODO(justinchuby): Use a public method in the helper module
  4228. return symbolic_helper._unpack_list(inputs[0])
  4229. return None
  4230. @staticmethod
  4231. def TupleConstruct(g, *inputs, **kwargs):
  4232. return None
  4233. @staticmethod
  4234. def Uninitialized(g, *inputs, **kwargs):
  4235. return None
  4236. # exists to refine the type of the Value
  4237. # if x is an optional Tensor, unchecked_cast will cast
  4238. # x to Tensor, so the rest of the graph knows that x is a Tensor
  4239. # this doesn't do anything in runtime and is a noop in ONNX
  4240. @staticmethod
  4241. def unchecked_cast(g, self):
  4242. return self
  4243. @staticmethod
  4244. def dtype(g, self):
  4245. dtype = symbolic_helper._try_get_scalar_type(self)
  4246. if dtype is None:
  4247. dtype = "Float"
  4248. dtype = symbolic_helper.scalar_type_to_onnx.index(
  4249. symbolic_helper.cast_pytorch_to_onnx[dtype]
  4250. )
  4251. return g.op("Constant", value_t=torch.tensor(dtype))
  4252. # tolist is currently supported only for 1D input tensors.
  4253. # dim_val and elem_ty_val represent dimension and type annotations
  4254. # that need to match dimension and type of the input tensor.
  4255. @staticmethod
  4256. def tolist(g, input, dim_val, elem_ty_val):
  4257. dim = symbolic_helper._maybe_get_const(dim_val, "i")
  4258. if dim > 1:
  4259. return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1")
  4260. return input
  4261. # -----------------------------------------------------------------------------
  4262. # Symbolic functions that need extra context
  4263. # -----------------------------------------------------------------------------
  4264. @staticmethod
  4265. def device(ctx: torch.onnx.SymbolicContext, g, *inputs, **kwargs):
  4266. n = ctx.cur_node
  4267. if n.output().type().kind() == "DeviceObjType":
  4268. return None
  4269. return symbolic_helper._unimplemented(
  4270. "prim::device", "output type is not `DeviceObjType`."
  4271. )
  4272. @staticmethod
  4273. def Loop(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs):
  4274. n = ctx.cur_node
  4275. env = ctx.env
  4276. params_dict = ctx.params_dict
  4277. operator_export_type = GLOBALS.operator_export_type
  4278. opset_version = GLOBALS.export_onnx_opset_version
  4279. new_op_outputs = g.op("Loop", *inputs, outputs=n.outputsSize())
  4280. new_node = (
  4281. new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
  4282. )
  4283. for b in n.blocks():
  4284. new_block = new_node.addBlock()
  4285. # Copy input metadata to subblock
  4286. #
  4287. # prim::Loop(iter, cond, input_1, ..., input_n)
  4288. # block0(iter, input_1, ..., input_n)
  4289. #
  4290. # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`.
  4291. for i, b_in in enumerate(b.inputs()):
  4292. if i == 0 and i < len(inputs):
  4293. b_in.setType(inputs[i].type())
  4294. # For optional block inputs, they may switch between None not-None inside
  4295. # the loop body, so if the loop input is not optional, the block input may
  4296. # still need to be optional.
  4297. if (
  4298. i > 0
  4299. and (i + 1) < len(inputs)
  4300. and not isinstance(b_in.type(), _C.OptionalType)
  4301. ):
  4302. b_in.setType(inputs[i + 1].type())
  4303. torch._C._jit_pass_onnx_block(
  4304. b, new_block, operator_export_type, env, False # type:ignore[arg-type]
  4305. )
  4306. new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
  4307. new_node, opset_version
  4308. )
  4309. # Run shape type inference for Loop after subblock is converted.
  4310. if GLOBALS.onnx_shape_inference:
  4311. torch._C._jit_pass_onnx_node_shape_type_inference(
  4312. new_node, params_dict, opset_version
  4313. )
  4314. return new_op_outputs
  4315. @staticmethod
  4316. def If(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs):
  4317. n = ctx.cur_node
  4318. block = ctx.onnx_block
  4319. env = ctx.env
  4320. params_dict = ctx.params_dict
  4321. operator_export_type = GLOBALS.operator_export_type
  4322. opset_version = GLOBALS.export_onnx_opset_version
  4323. static_if = inputs[0].node().kind() == "onnx::Constant"
  4324. if static_if:
  4325. # Fold static if
  4326. #
  4327. # The torch IR
  4328. # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu),
  4329. # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ...
  4330. # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
  4331. # %21 : Long(device=cpu) = aten::eq(%20, %64)
  4332. # %22 : Long(device=cpu) = prim::If(%21)
  4333. # block0():
  4334. # %23 : Long(device=cpu) = aten::is_floating_point(%input.1)
  4335. # -> (%23)
  4336. # block1():
  4337. # -> (%65)
  4338. # %input.53 : Tensor, %weight : Tensor = prim::If(%22)
  4339. # block0():
  4340. # -> (%embedding_matrix.1, %input.1)
  4341. # block1():
  4342. # -> (%input.1, %embedding_matrix.1)
  4343. # %26 : int[] = aten::size(%input.53)
  4344. #
  4345. # The converted ONNX graph
  4346. # %10 : Bool(device=cpu) = onnx::Constant[value={0}]()
  4347. # %14 : Bool(device=cpu) = onnx::Equal(%13, %8)
  4348. # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  4349. # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1)
  4350. input_flag = inputs[0].node()["value"].tolist()
  4351. const_value = (
  4352. all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
  4353. )
  4354. block_idx = 0 if const_value else 1
  4355. current_b = list(n.blocks())[block_idx]
  4356. env = torch._C._jit_pass_onnx_block(
  4357. current_b,
  4358. block,
  4359. operator_export_type, # type:ignore[arg-type]
  4360. env, # type:ignore[arg-type]
  4361. True,
  4362. )
  4363. if_output_list = list(n.outputs())
  4364. current_b_list = list(current_b.outputs())
  4365. final_b_list = []
  4366. for idx in range(len(if_output_list)):
  4367. if current_b_list[idx] not in env:
  4368. raise RuntimeError(
  4369. "The sub block ATen output {}"
  4370. " is not in env.".format(current_b_list[idx])
  4371. ) # type:ignore[operator]
  4372. onnx_b = env[current_b_list[idx]]
  4373. final_b_list.append(onnx_b)
  4374. return final_b_list
  4375. else:
  4376. new_op_outputs = g.op("If", *inputs, outputs=n.outputsSize())
  4377. new_node = (
  4378. new_op_outputs[0].node()
  4379. if n.outputsSize() > 1
  4380. else new_op_outputs.node()
  4381. )
  4382. for b in n.blocks():
  4383. new_block = new_node.addBlock()
  4384. torch._C._jit_pass_onnx_block(
  4385. b,
  4386. new_block,
  4387. operator_export_type, # type:ignore[arg-type]
  4388. env,
  4389. False,
  4390. )
  4391. new_op_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
  4392. new_node, opset_version
  4393. )
  4394. # Run shape type inference for If after subblock is converted.
  4395. if GLOBALS.onnx_shape_inference:
  4396. torch._C._jit_pass_onnx_node_shape_type_inference(
  4397. new_node, params_dict, opset_version
  4398. )
  4399. return new_op_outputs
  4400. @staticmethod
  4401. def Constant(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs):
  4402. n = ctx.cur_node
  4403. if n.mustBeNone():
  4404. return None
  4405. # This must go before checking for string values, because some device constants
  4406. # have string values, but we want to keep them as unconverted Device types so
  4407. # that eq() can work on them.
  4408. if isinstance(n.output().type(), _C.DeviceObjType):
  4409. return None
  4410. if n.kindOf("value") == "t":
  4411. return g.op("Constant", value_t=n["value"])
  4412. if n.kindOf("value") == "s":
  4413. return g.op("Constant", value_s=n["value"])
  4414. elif n.output().type().isSubtypeOf(
  4415. _C.ListType.ofInts()
  4416. ) or n.output().type().isSubtypeOf(_C.ListType.ofFloats()):
  4417. return g.op("Constant", value_t=torch.tensor(n["value"]))
  4418. else:
  4419. raise RuntimeError(
  4420. "Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
  4421. n.kindOf("value")
  4422. )
  4423. )
  4424. class Onnx:
  4425. domain = "onnx"
  4426. # -----------------------------------------------------------------------------
  4427. # Symbolic functions that need extra context
  4428. # -----------------------------------------------------------------------------
  4429. @staticmethod
  4430. def Placeholder(ctx: torch.onnx.SymbolicContext, g, *inputs, **attrs):
  4431. n = ctx.cur_node
  4432. block = ctx.onnx_block
  4433. env = ctx.env
  4434. return torch._C._jit_onnx_convert_pattern_from_subblock(block, n, env)