@@ -672,6 +672,108 @@ def forward(self, a):
672672def ViewNegativeStaticModule_basic (module , tu : TestUtils ):
673673 module .forward (tu .rand (1 , 128 ))
674674
675+ class ViewSizeDimFollowedByExpandedOnesModule (torch .nn .Module ):
676+ def __init__ (self ):
677+ super ().__init__ ()
678+
679+ @export
680+ @annotate_args ([
681+ None ,
682+ ([- 1 ], torch .float32 , True ),
683+ ])
684+
685+ def forward (self , a ):
686+ return a .view (a .size (0 ), 1 , 1 , 1 )
687+
688+ @register_test_case (module_factory = lambda : ViewSizeDimFollowedByExpandedOnesModule ())
689+ def ViewSizeDimFollowedByExpandedOnesModule_basic (module , tu : TestUtils ):
690+ module .forward (tu .rand (128 ))
691+
692+ class ViewSizeDimFollowedByCollapsedOnesModule (torch .nn .Module ):
693+ def __init__ (self ):
694+ super ().__init__ ()
695+
696+ @export
697+ @annotate_args ([
698+ None ,
699+ ([- 1 , 1 , 1 , 1 ], torch .float32 , True ),
700+ ])
701+
702+ def forward (self , a ):
703+ return a .view (a .size (0 ))
704+
705+ @register_test_case (module_factory = lambda : ViewSizeDimFollowedByCollapsedOnesModule ())
706+ def ViewSizeDimFollowedByCollapsedOnesModule_basic (module , tu : TestUtils ):
707+ module .forward (tu .rand (128 , 1 , 1 , 1 ))
708+
709+ class ViewSizeDimLedByExpandedOnesModule (torch .nn .Module ):
710+ def __init__ (self ):
711+ super ().__init__ ()
712+
713+ @export
714+ @annotate_args ([
715+ None ,
716+ ([- 1 ], torch .float32 , True ),
717+ ])
718+
719+ def forward (self , a ):
720+ return a .view (1 , 1 , 1 , a .size (0 ))
721+
722+ @register_test_case (module_factory = lambda : ViewSizeDimLedByExpandedOnesModule ())
723+ def ViewSizeDimLedByExpandedOnesModule_basic (module , tu : TestUtils ):
724+ module .forward (tu .rand (128 ))
725+
726+ class ViewSizeDimLedByCollapsedOnesModule (torch .nn .Module ):
727+ def __init__ (self ):
728+ super ().__init__ ()
729+
730+ @export
731+ @annotate_args ([
732+ None ,
733+ ([1 , 1 , 1 , - 1 ], torch .float32 , True ),
734+ ])
735+
736+ def forward (self , a ):
737+ return a .view (a .size (3 ))
738+
739+ @register_test_case (module_factory = lambda : ViewSizeDimLedByCollapsedOnesModule ())
740+ def ViewSizeDimLedByCollapsedOnesModule_basic (module , tu : TestUtils ):
741+ module .forward (tu .rand (1 , 1 , 1 , 128 ))
742+
743+ class ViewSizeDimLedAndFollowedByExpandedOnesModule (torch .nn .Module ):
744+ def __init__ (self ):
745+ super ().__init__ ()
746+
747+ @export
748+ @annotate_args ([
749+ None ,
750+ ([- 1 ], torch .float32 , True ),
751+ ])
752+
753+ def forward (self , a ):
754+ return a .view (1 , 1 , 1 , a .size (0 ), 1 , 1 , 1 )
755+
756+ @register_test_case (module_factory = lambda : ViewSizeDimLedAndFollowedByExpandedOnesModule ())
757+ def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic (module , tu : TestUtils ):
758+ module .forward (tu .rand (128 ))
759+
760+ class ViewSizeDimLedAndFollowedByCollapsedOnesModule (torch .nn .Module ):
761+ def __init__ (self ):
762+ super ().__init__ ()
763+
764+ @export
765+ @annotate_args ([
766+ None ,
767+ ([1 , 1 , 1 , - 1 , 1 , 1 , 1 ], torch .float32 , True ),
768+ ])
769+
770+ def forward (self , a ):
771+ return a .view (a .size (3 ))
772+
773+ @register_test_case (module_factory = lambda : ViewSizeDimLedAndFollowedByCollapsedOnesModule ())
774+ def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic (module , tu : TestUtils ):
775+ module .forward (tu .rand (1 , 1 , 1 , 128 , 1 , 1 , 1 ))
776+
675777# ==============================================================================
676778
677779class ReshapeAliasExpandModule (torch .nn .Module ):
@@ -710,4 +812,4 @@ def forward(self, a):
710812
711813@register_test_case (module_factory = lambda : ReshapeAliasCollapseModule ())
712814def ReshapeAliasCollapseModule_basic (module , tu : TestUtils ):
713- module .forward (tu .rand (2 , 4 ))
815+ module .forward (tu .rand (2 , 4 ))
0 commit comments