@@ -67,14 +67,56 @@ NEML2TensorCompute::NEML2TensorCompute(const InputParameters & params)
6767 NEML2Utils ::assertNEML2Enabled ();
6868
6969#ifdef NEML2_ENABLED
70- for (const auto & [swift_input_name , neml2_input_name ] :
71- getParam < TensorInputBufferName , std ::string > ("swift_inputs ", "neml2_inputs "))
70+ const auto inputs = getParam < TensorInputBufferName , std ::string > ("swift_inputs ", "neml2_inputs ");
71+ std ::map < neml2 ::LabeledAxisAccessor , TensorInputBufferName > lookup_swift_name ;
72+ const auto model_inputs = _model .consumed_items ();
73+
74+ // current inputs
75+ for (const auto & [swift_input_name , neml2_input_name ] : inputs )
7276 {
73- const auto * input_buffer = & getInputBufferByName < > (swift_input_name );
74- _input_mapping .emplace_back (
75- input_buffer , neml2 ::LabeledAxisAccessor (NEML2Utils ::parseVariableName (neml2_input_name )));
77+ const auto neml2_input =
78+ neml2 ::LabeledAxisAccessor (NEML2Utils ::parseVariableName (neml2_input_name ));
79+
80+ // populate reverse lookup map
81+ if (lookup_swift_name .find (neml2_input ) != lookup_swift_name .end ())
82+ mooseError ("Repeated NEML2 input " , neml2_input_name );
83+ lookup_swift_name [neml2_input ] = swift_input_name ;
84+
85+ // the user should only specify current neml2 axis
86+ if (!neml2_input .is_state () && !neml2_input .is_force ())
87+ mooseError ("Specify only current forces or states as inputs. Old forces and states are "
88+ "automatically coupled when needed." );
89+
90+ // add input if the model requires it
91+ if (model_inputs .count (neml2_input ))
92+ {
93+ const auto * input_buffer = & getInputBufferByName < > (swift_input_name );
94+ const auto type = _model .input_variable (neml2_input ).type ();
95+ _input_mapping .emplace_back (input_buffer , type , neml2_input );
96+ }
7697 }
7798
99+ // old state inputs
100+ for (const auto & neml2_input : model_inputs )
101+ if (neml2_input .is_old_state ())
102+ {
103+ // check if we couple the current state
104+ auto it = lookup_swift_name .find (neml2_input .current ());
105+ if (it == lookup_swift_name .end ())
106+ mooseError ("The model requires " ,
107+ neml2_input ,
108+ " but no tensor buffer is assigned to " ,
109+ neml2_input .current (),
110+ "." );
111+ const auto & swift_input_name = it -> second ;
112+
113+ const auto * old_states = & getBufferOldByName < > (swift_input_name , 1 );
114+ // we also get the current state here just to step zero, when no old state exists!
115+ const auto * input_buffer = & getInputBufferByName < > (swift_input_name );
116+ const auto type = _model .input_variable (neml2_input ).type ();
117+ _old_input_mapping .emplace_back (old_states , input_buffer , type , neml2_input );
118+ }
119+
78120 for (const auto & [neml2_output_name , swift_output_name ] :
79121 getParam < std ::string , TensorInputBufferName > ("neml2_outputs" , "swift_outputs" ))
80122 {
@@ -99,20 +141,30 @@ NEML2TensorCompute::computeBuffer()
99141{
100142#ifdef NEML2_ENABLED
101143 neml2 ::ValueMap in ;
102- for (const auto & [ tensor_ptr , label ] : _input_mapping )
144+ auto insert_tensor = [ & in , this ] (const auto & tensor , auto type , const auto & label )
103145 {
104146 // convert tensors on the fly at runtime
105- auto sizes = tensor_ptr -> sizes ();
106- mooseInfoRepeated (name (), " sizes size " , sizes .size (), " is " , Moose ::stringify (sizes ));
107- if (sizes .size () == _dim )
108- in [label ] = neml2 ::Scalar (* tensor_ptr );
109- else if (sizes .size () == _dim + 1 )
110- in [label ] = neml2 ::Vec (* tensor_ptr , _domain .getShape ());
111- else if (sizes .size () == _dim + 3 )
112- in [label ] = neml2 ::R2 (* tensor_ptr , _domain .getShape ());
147+ auto sizes = tensor .sizes ();
148+ if (sizes .size () == _dim && type == neml2 ::TensorType ::kScalar )
149+ in [label ] = neml2 ::Scalar (tensor );
150+ else if (sizes .size () == _dim + 1 && type == neml2 ::TensorType ::kVec )
151+ in [label ] = neml2 ::Vec (tensor , _domain .getShape ());
152+ else if (sizes .size () == _dim + 3 && type == neml2 ::TensorType ::kR2 )
153+ in [label ] = neml2 ::R2 (tensor , _domain .getShape ());
113154 else
114- mooseError ("Unsupported tensor dimension" );
115- }
155+ mooseError ("Unsupported/mismatching tensor dimension" );
156+ };
157+
158+ // insert current state
159+ for (const auto & [current_state , type , label ] : _input_mapping )
160+ insert_tensor (* current_state , type , label );
161+
162+ // insert old state
163+ for (const auto & [old_states , current_state , type , label ] : _old_input_mapping )
164+ if (old_states -> empty ())
165+ insert_tensor (* current_state , type , label );
166+ else
167+ insert_tensor ((* old_states )[0 ], type , label );
116168
117169 auto out = _model .value (in );
118170
0 commit comments