@@ -74,9 +74,10 @@ impl<'a> Tokenizer<'a> {
7474 }
7575}
7676
77- #[ derive( Debug , Clone , PartialEq , Eq ) ]
77+ #[ derive( Clone , PartialEq , Eq ) ]
7878enum WorkingToken < T : MatrixNumber > {
7979 Type ( Type < T > ) ,
80+ Function ( Identifier ) ,
8081 UnaryOp ( char ) ,
8182 BinaryOp ( char ) ,
8283 LeftBracket ,
@@ -87,6 +88,7 @@ impl<T: MatrixNumber> Display for WorkingToken<T> {
8788 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
8889 match self {
8990 WorkingToken :: Type ( _) => write ! ( f, "value token" ) ,
91+ WorkingToken :: Function ( _) => write ! ( f, "function token" ) ,
9092 WorkingToken :: UnaryOp ( op) => write ! ( f, "unary operator \" {op}\" " ) ,
9193 WorkingToken :: BinaryOp ( op) => write ! ( f, "binary operator \" {op}\" " ) ,
9294 WorkingToken :: LeftBracket => write ! ( f, "( bracket" ) ,
@@ -114,23 +116,35 @@ fn binary_op<T: MatrixNumber>(left: Type<T>, right: Type<T>, op: char) -> anyhow
114116 ( Type :: Scalar ( l) , Type :: Matrix ( r) ) => Type :: from_matrix_result ( r. checked_mul_scl ( & l) ) ,
115117 } ,
116118 '/' => match ( left, right) {
117- ( Type :: Scalar ( l) , Type :: Scalar ( r) ) => if !r. is_zero ( ) {
118- Type :: from_scalar_option ( l. checked_div ( & r) )
119- } else {
120- bail ! ( "Division by zero!" )
121- } ,
122- ( Type :: Matrix ( _) , Type :: Matrix ( _) ) => bail ! ( "WTF dividing by matrix? You should use the `inv` function (not implemented yet, wait for it...)" ) ,
123- ( Type :: Matrix ( _) , Type :: Scalar ( _) ) => bail ! ( "Diving matrix by scalar is not supported yet..." ) ,
124- ( Type :: Scalar ( _) , Type :: Matrix ( _) ) => bail ! ( "Diving scalar by matrix does not make sense!" ) ,
119+ ( Type :: Scalar ( l) , Type :: Scalar ( r) ) => {
120+ if !r. is_zero ( ) {
121+ Type :: from_scalar_option ( l. checked_div ( & r) )
122+ } else {
123+ bail ! ( "Division by zero!" )
124+ }
125+ }
126+ ( Type :: Matrix ( _) , Type :: Matrix ( _) ) => {
127+ bail ! ( "WTF dividing by matrix? You should use the `inverse` function instead!" )
128+ }
129+ ( Type :: Matrix ( _) , Type :: Scalar ( _) ) => {
130+ bail ! ( "Diving matrix by scalar is not supported yet..." )
131+ }
132+ ( Type :: Scalar ( _) , Type :: Matrix ( _) ) => {
133+ bail ! ( "Diving scalar by matrix does not make sense!" )
134+ }
125135 } ,
126- '^' => if let Type :: Scalar ( exp) = right {
127- let exp = exp. to_usize ( ) . context ( "Exponent should be a nonnegative integer." ) ?;
128- match left {
129- Type :: Scalar ( base) => Type :: from_scalar_option ( checked_pow ( base, exp) ) ,
130- Type :: Matrix ( base) => Type :: from_matrix_result ( base. checked_pow ( exp) ) ,
136+ '^' => {
137+ if let Type :: Scalar ( exp) = right {
138+ let exp = exp
139+ . to_usize ( )
140+ . context ( "Exponent should be a nonnegative integer." ) ?;
141+ match left {
142+ Type :: Scalar ( base) => Type :: from_scalar_option ( checked_pow ( base, exp) ) ,
143+ Type :: Matrix ( base) => Type :: from_matrix_result ( base. checked_pow ( exp) ) ,
144+ }
145+ } else {
146+ bail ! ( "Exponent cannot be a matrix!" ) ;
131147 }
132- } else {
133- bail ! ( "Exponent cannot be a matrix!" ) ;
134148 }
135149 _ => unimplemented ! ( ) ,
136150 }
@@ -155,7 +169,7 @@ fn unary_op<T: MatrixNumber>(arg: Type<T>, op: char) -> anyhow::Result<Type<T>>
155169<unary_op> ::= "+" | "-"
156170<binary_op> ::= "+" | "-" | "*" | "/"
157171<expr> ::= <integer> | <identifier> | <expr> <binary_op> <expr>
158- | "(" <expr> ")" | <unary_op> <expr>
172+ | "(" <expr> ")" | <unary_op> <expr> | <identifier> "(" <expr> ")"
159173 */
160174pub fn parse_expression < T : MatrixNumber > (
161175 raw : & str ,
@@ -185,6 +199,7 @@ pub fn parse_expression<T: MatrixNumber>(
185199 None | Some ( WorkingToken :: LeftBracket )
186200 | Some ( WorkingToken :: BinaryOp ( _) )
187201 | Some ( WorkingToken :: UnaryOp ( _) )
202+ | Some ( WorkingToken :: Function ( _) )
188203 ) ,
189204 Token :: Operator ( _) => matches ! (
190205 previous,
@@ -221,15 +236,18 @@ pub fn parse_expression<T: MatrixNumber>(
221236 outputs. back ( )
222237 }
223238 Token :: Identifier ( id) => {
224- outputs. push_back ( WorkingToken :: Type (
225- env. get ( id)
226- . context ( format ! (
227- "Undefined identifier! Object \" {}\" is unknown." ,
228- id. to_string( )
229- ) ) ?
230- . clone ( ) ,
231- ) ) ;
232- outputs. back ( )
239+ if let Some ( value) = env. get_value ( id) {
240+ outputs. push_back ( WorkingToken :: Type ( value. clone ( ) ) ) ;
241+ outputs. back ( )
242+ } else if env. get_function ( id) . is_some ( ) {
243+ operators. push_front ( WorkingToken :: Function ( id. clone ( ) ) ) ;
244+ operators. front ( )
245+ } else {
246+ bail ! (
247+ "Undefined identifier! Object \" {}\" is unknown." ,
248+ id. to_string( )
249+ )
250+ }
233251 }
234252 Token :: LeftBracket => {
235253 operators. push_front ( WorkingToken :: LeftBracket ) ;
@@ -248,10 +266,11 @@ pub fn parse_expression<T: MatrixNumber>(
248266 bail ! ( "Mismatched brackets!" ) ;
249267 }
250268 if let Some ( op) = operators. pop_front ( ) {
251- if matches ! ( op, WorkingToken :: UnaryOp ( _) ) {
252- outputs. push_back ( op) ;
253- } else {
254- operators. push_front ( op) ;
269+ match op {
270+ WorkingToken :: UnaryOp ( _) | WorkingToken :: Function ( _) => {
271+ outputs. push_back ( op)
272+ }
273+ _ => operators. push_front ( op) ,
255274 }
256275 }
257276 Some ( & WorkingToken :: RightBracket )
@@ -312,6 +331,10 @@ pub fn parse_expression<T: MatrixNumber>(
312331 let arg = val_stack. pop_front ( ) . context ( "Invalid expression!" ) ?;
313332 val_stack. push_front ( unary_op ( arg, op) ?) ;
314333 }
334+ WorkingToken :: Function ( id) => {
335+ let arg = val_stack. pop_front ( ) . context ( "Invalid expression!" ) ?;
336+ val_stack. push_front ( env. get_function ( & id) . unwrap ( ) ( arg) ?) ;
337+ }
315338 _ => unreachable ! ( ) ,
316339 }
317340 }
@@ -546,7 +569,8 @@ mod tests {
546569 }
547570
548571 assert_eq ! (
549- * env. get( & Identifier :: new( "b" . to_string( ) ) . unwrap( ) ) . unwrap( ) ,
572+ * env. get_value( & Identifier :: new( "b" . to_string( ) ) . unwrap( ) )
573+ . unwrap( ) ,
550574 Type :: <i64 >:: Scalar ( 89 )
551575 ) ;
552576 }
@@ -561,8 +585,105 @@ mod tests {
561585 exec ( "a = $ ^ $" ) ;
562586
563587 assert_eq ! (
564- * env. get( & Identifier :: new( "a" . to_string( ) ) . unwrap( ) ) . unwrap( ) ,
588+ * env. get_value( & Identifier :: new( "a" . to_string( ) ) . unwrap( ) )
589+ . unwrap( ) ,
565590 Type :: <i64 >:: Scalar ( 256 )
566591 ) ;
567592 }
593+
594+ #[ test]
595+ fn test_expression_functions ( ) {
596+ let mut env = Environment :: new ( ) ;
597+
598+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
599+ let at = im ! [ 1 , 4 ; 2 , 5 ; 3 , 6 ] ;
600+ let b = im ! [ 1 , 2 ; 3 , 4 ] ;
601+
602+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
603+ env. insert (
604+ Identifier :: new ( "B" . to_string ( ) ) . unwrap ( ) ,
605+ Type :: Matrix ( b. clone ( ) ) ,
606+ ) ;
607+
608+ assert_eq ! (
609+ parse_expression( "transpose(A)" , & env) . unwrap( ) ,
610+ Type :: Matrix ( at)
611+ ) ;
612+ assert_eq ! (
613+ parse_expression( "identity(4)" , & env) . unwrap( ) ,
614+ Type :: Matrix ( Matrix :: identity( 4 ) )
615+ ) ;
616+ assert_eq ! (
617+ parse_expression( "inverse(B)" , & env) . unwrap( ) ,
618+ Type :: Matrix ( b. inverse( ) . unwrap( ) . result)
619+ ) ;
620+ }
621+
622+ #[ test]
623+ fn test_nested_functions ( ) {
624+ let mut env = Environment :: new ( ) ;
625+
626+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
627+ let att = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
628+
629+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
630+
631+ assert_eq ! (
632+ parse_expression( "transpose(transpose(A))" , & env) . unwrap( ) ,
633+ Type :: Matrix ( att)
634+ )
635+ }
636+
637+ #[ test]
638+ fn test_expr_with_function ( ) {
639+ let mut env = Environment :: new ( ) ;
640+
641+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
642+ let b = im ! [ 1 , 2 ; 3 , 4 ] ;
643+
644+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
645+ env. insert (
646+ Identifier :: new ( "B" . to_string ( ) ) . unwrap ( ) ,
647+ Type :: Matrix ( b. clone ( ) ) ,
648+ ) ;
649+
650+ assert_eq ! (
651+ parse_expression( "transpose(A) * B" , & env) . unwrap( ) ,
652+ Type :: Matrix ( im![ 13 , 18 ; 17 , 24 ; 21 , 30 ] )
653+ ) ;
654+ }
655+
656+ #[ test]
657+ fn test_expr_in_function ( ) {
658+ let mut env = Environment :: new ( ) ;
659+
660+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
661+ let i = Matrix :: identity ( 2 ) ;
662+ let at = im ! [ 1 , 4 ; 2 , 5 ; 3 , 6 ] ;
663+
664+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
665+ env. insert ( Identifier :: new ( "I" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( i) ) ;
666+
667+ assert_eq ! (
668+ parse_expression( "transpose(I * A)" , & env) . unwrap( ) ,
669+ Type :: Matrix ( at)
670+ ) ;
671+ }
672+
673+ #[ test]
674+ fn test_complex_nested_function_with_expr ( ) {
675+ let mut env = Environment :: new ( ) ;
676+
677+ let a = im ! [ 1 , 2 , 3 ; 4 , 5 , 6 ] ;
678+
679+ env. insert ( Identifier :: new ( "A" . to_string ( ) ) . unwrap ( ) , Type :: Matrix ( a) ) ;
680+
681+ assert_eq ! (
682+ parse_expression(
683+ "transpose(transpose(identity(2137 - 2135 + 1 - 1 + (42 - 420) * 0) * A) + transpose(identity(2) * A))" ,
684+ & env
685+ ) . unwrap( ) ,
686+ Type :: Matrix ( im![ 2 , 4 , 6 ; 8 , 10 , 12 ] )
687+ ) ;
688+ }
568689}
0 commit comments