Implementation in TDMPC-2 #9

@youwyu

Hello, this is very smart work. Could you please share how you implemented it in TDMPC-2?

This is the original MLP layer:

import torch.nn as nn

class NormedLinear(nn.Linear):
	# Linear layer followed by optional dropout, LayerNorm, and an activation (Mish by default).
	def __init__(self, *args, dropout=0., act=None, **kwargs):
		super().__init__(*args, **kwargs)
		self.ln = nn.LayerNorm(self.out_features)
		if act is None:
			act = nn.Mish(inplace=False)
		self.act = act
		self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None

	def forward(self, x):
		x = super().forward(x)
		if self.dropout:
			x = self.dropout(x)
		return self.act(self.ln(x))

This is my implementation, but I'm not sure whether it's correct.

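class NormedResidualLinear(nn.Linear):
	# Same as NormedLinear, but adds the linear output back onto the normalized, activated branch.
	def __init__(self, *args, dropout=0., act=None, **kwargs):
		super().__init__(*args, **kwargs)
		self.ln = nn.LayerNorm(self.out_features)
		# Note: the `act` argument is accepted but not used; ReLU is always applied.
		self.act = nn.ReLU(inplace=False)
		self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None

	def forward(self, x):
		x = super().forward(x)
		# The skip branch is the linear output (before dropout), not the layer input.
		res = x
		if self.dropout:
			x = self.dropout(x)
		return res + self.act(self.ln(x))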
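
To make the question concrete, here is a minimal sanity check of what the layer above actually computes. It is only a sketch: the sizes (batch 4, 8 inputs, 16 outputs) and the variable names are arbitrary choices for illustration, and it says nothing about whether this matches the intended method. It confirms that the output equals `h + ReLU(LayerNorm(h))` with `h` the linear output, so the skip connection wraps the norm and activation rather than the whole layer, which is why input and output widths may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormedResidualLinear(nn.Linear):
    """Layer from the post: h = Linear(x); out = h + ReLU(LayerNorm(Dropout(h)))."""

    def __init__(self, *args, dropout=0., act=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.ln = nn.LayerNorm(self.out_features)
        self.act = nn.ReLU(inplace=False)  # `act` is ignored, as in the post
        self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None

    def forward(self, x):
        x = super().forward(x)
        res = x
        if self.dropout:
            x = self.dropout(x)
        return res + self.act(self.ln(x))


torch.manual_seed(0)
layer = NormedResidualLinear(8, 16)  # illustrative sizes; dropout disabled
x = torch.randn(4, 8)

with torch.no_grad():
    out = layer(x)
    # Recompute the linear output by hand to isolate the residual branch.
    h = F.linear(x, layer.weight, layer.bias)
    branch = torch.relu(layer.ln(h))

# Input width (8) differs from output width (16), yet the addition is valid,
# because the skip term is h (post-linear), not x (the layer input).
print(out.shape)
print(torch.allclose(out, h + branch, atol=1e-5))
```

Because the skip term is taken after the linear projection, this differs from a block of the form `x + f(x)`, which would need matching input and output dimensions.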
