diff --git a/ast/if.go b/ast/if.go index e63c90c1..5bfb6765 100644 --- a/ast/if.go +++ b/ast/if.go @@ -186,6 +186,9 @@ func (i *IfStatement) Parse( body.ParseBlock(unit, contractNode, fnNode, statementCtx.Block()) break } + // Edge case for single statement if + // Eg: if (a) b; + body.parseStatements(unit, contractNode, fnNode, statementCtx.GetChild(0)) i.Body = body } diff --git a/ast/parameter.go b/ast/parameter.go index 3827fe72..e3428ab5 100644 --- a/ast/parameter.go +++ b/ast/parameter.go @@ -438,5 +438,5 @@ func (p *Parameter) getStorageLocationFromCtx(ctx *parser.ParameterDeclarationCo } } - return ast_pb.StorageLocation_MEMORY + return ast_pb.StorageLocation_DEFAULT } diff --git a/ast/source_unit.go b/ast/source_unit.go index 9aff89b6..2ff5c8b7 100644 --- a/ast/source_unit.go +++ b/ast/source_unit.go @@ -2,10 +2,11 @@ package ast import ( "fmt" - "github.com/goccy/go-json" "path/filepath" "regexp" + "github.com/goccy/go-json" + v3 "github.com/cncf/xds/go/xds/type/v3" ast_pb "github.com/unpackdev/protos/dist/go/ast" "github.com/unpackdev/solgo" diff --git a/ast/state_variable.go b/ast/state_variable.go index 40f2f3d1..6ec062ae 100644 --- a/ast/state_variable.go +++ b/ast/state_variable.go @@ -22,6 +22,7 @@ type StateVariableDeclaration struct { Visibility ast_pb.Visibility `json:"visibility"` // Visibility of the state variable declaration StorageLocation ast_pb.StorageLocation `json:"storage_location"` // Storage location of the state variable declaration StateMutability ast_pb.Mutability `json:"mutability"` // State mutability of the state variable declaration + Override bool `json:"is_override"` // Indicates if the state variable is an override TypeName *TypeName `json:"type_name"` // Type name of the state variable InitialValue Node[NodeType] `json:"initial_value"` // Initial value of the state variable } @@ -207,6 +208,8 @@ func (v *StateVariableDeclaration) Parse( v.Constant = constantCtx != nil } + v.Override = ctx.GetOverrideSpecifierSet() + typeName := NewTypeName(v.ASTBuilder) typeName.Parse(unit, nil, v.Id, ctx.GetType_()) @@ -270,6 +273,8 @@ func (v *StateVariableDeclaration) ParseGlobal( v.Constant = constantCtx != nil } + v.Override = ctx.GetOverrideSpecifierSet() + typeName := NewTypeName(v.ASTBuilder) typeName.Parse(nil, nil, v.Id, ctx.GetType_()) v.TypeName = typeName diff --git a/ast/storage.go b/ast/storage.go index 1de38773..7ad82da4 100644 --- a/ast/storage.go +++ b/ast/storage.go @@ -57,7 +57,6 @@ func (t *TypeName) StorageSize() (int64, bool) { // Add cases for other node types like struct, enum, etc., as needed. default: panic(fmt.Sprintf("Unhandled node type @ StorageSize: %s", t.NodeType)) - return 0, false // Type not recognized or not handled yet. } } diff --git a/printer/ast_printer/and_operation.go b/printer/ast_printer/and_operation.go new file mode 100644 index 00000000..44ed2b8e --- /dev/null +++ b/printer/ast_printer/and_operation.go @@ -0,0 +1,19 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printAndOperation(node *ast.AndOperation, sb *strings.Builder, depth int) bool { + expressions := []string{} + success := true + for _, exp := range node.GetExpressions() { + s, ok := Print(exp) + success = ok && success + expressions = append(expressions, s) + } + writeSeperatedList(sb, " && ", expressions) + return success +} diff --git a/printer/ast_printer/assignment.go b/printer/ast_printer/assignment.go new file mode 100644 index 00000000..8537aed1 --- /dev/null +++ b/printer/ast_printer/assignment.go @@ -0,0 +1,65 @@ +package ast_printer + +import ( + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" +) + +func getAssignOperatorString(op ast_pb.Operator) string { + switch op { + case ast_pb.Operator_EQUAL: + return "=" + case ast_pb.Operator_PLUS_EQUAL: + return "+=" + case ast_pb.Operator_MINUS_EQUAL: + return "-=" + case ast_pb.Operator_MUL_EQUAL: + return "*=" + case ast_pb.Operator_DIV_EQUAL: + return "/=" + case ast_pb.Operator_MOD_EQUAL: + return "%=" + case ast_pb.Operator_AND_EQUAL: + return "&=" + case ast_pb.Operator_OR_EQUAL: + return "|=" + case ast_pb.Operator_XOR_EQUAL: + return "^=" + case ast_pb.Operator_SHIFT_LEFT_EQUAL: + return "<<=" + case ast_pb.Operator_SHIFT_RIGHT_EQUAL: + return ">>=" + case ast_pb.Operator_BIT_AND_EQUAL: + return "&=" + case ast_pb.Operator_BIT_OR_EQUAL: + return "|=" + case ast_pb.Operator_BIT_XOR_EQUAL: + return "^=" + case ast_pb.Operator_POW_EQUAL: + return "**=" + default: + return "" + } +} + +func printAssignment(node *ast.Assignment, sb *strings.Builder, depth int) bool { + success := true + if node.Expression != nil { + return PrintRecursive(node.Expression, sb, depth) + } + if node.LeftExpression == nil || node.RightExpression == nil { + return false + } + op := getAssignOperatorString(node.Operator) + if op == "" { + success = false + } + success = PrintRecursive(node.LeftExpression, sb, depth) && success + sb.WriteString(" ") + sb.WriteString(op) + sb.WriteString(" ") + success = PrintRecursive(node.RightExpression, sb, depth) && success + return success +} diff --git a/printer/ast_printer/ast_printer.go b/printer/ast_printer/ast_printer.go new file mode 100644 index 00000000..bd05e85b --- /dev/null +++ b/printer/ast_printer/ast_printer.go @@ -0,0 +1,195 @@ +package ast_printer + +import ( + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" + "go.uber.org/zap" +) + +const INDENT_SIZE = 2 + +// Print is a function that prints the AST nodes to source code +func Print(node ast.Node[ast.NodeType]) (string, bool) { + sb := strings.Builder{} + success := PrintRecursive(node, &sb, 0) + return sb.String(), success +} + +// PrintRecursive is a function that prints the AST nodes to source code recursively +func PrintRecursive(node ast.Node[ast.NodeType], sb *strings.Builder, depth int) bool { + if node == nil { + zap.S().Error("Node is nil") + return false + } + switch node := node.(type) { + case *ast.AndOperation: + return printAndOperation(node, sb, depth) + case *ast.BodyNode: + return printBody(node, sb, depth) + case *ast.Conditional: + return printConditional(node, sb, depth) + case *ast.Constructor: + return printConstructor(node, sb, depth) + case *ast.Pragma: + return printPragma(node, sb, depth) + case *ast.Contract: + return printContract(node, sb, depth) + case *ast.Function: + return printFunction(node, sb, depth) + case *ast.Parameter: + return printParameter(node, sb, depth) + case *ast.Assignment: + return printAssignment(node, sb, depth) + case *ast.TypeName: + return printTypeName(node, sb, depth) + case *ast.BinaryOperation: + return printBinaryOperation(node, sb, depth) + case *ast.StateVariableDeclaration: + return printStateVariableDeclaration(node, sb, depth) + case *ast.Emit: + return printEmit(node, sb, depth) + case *ast.ForStatement: + return printFor(node, sb, depth) + case *ast.PrimaryExpression: + return printPrimaryExpression(node, sb, depth) + case *ast.FunctionCall: + return printFunctionCall(node, sb, depth) + case *ast.Import: + return printImport(node, sb, depth) + case *ast.MemberAccessExpression: + return printMemberAccessExpression(node, sb, depth) + case *ast.VariableDeclaration: + return printVariableDeclaration(node, sb, depth) + case *ast.Declaration: + return printDeclaration(node, sb, depth) + case *ast.UnaryPrefix: + return printUnaryPrefix(node, sb, depth) + case *ast.UnarySuffix: + return printUnarySuffix(node, sb, depth) + case *ast.IndexAccess: + return printIndexAccess(node, sb, depth) + case *ast.ReturnStatement: + return printReturn(node, sb, depth) + case *ast.TupleExpression: + return printTupleExpression(node, sb, depth) + case *ast.StructDefinition: + return printStructDefinition(node, sb, depth) + case *ast.IfStatement: + return printIfStatement(node, sb, depth) + case *ast.EnumDefinition: + return printEnumDefinition(node, sb, depth) + case *ast.ModifierDefinition: + return printModifierDefinition(node, sb, depth) + case *ast.EventDefinition: + return printEventDefinition(node, sb, depth) + case *ast.ErrorDefinition: + return printErrorDefinition(node, sb, depth) + case *ast.PayableConversion: + return printPayableConversion(node, sb, depth) + case *ast.RevertStatement: + return printRevertStatement(node, sb, depth) + case *ast.ContinueStatement: + return printContinueStatement(node, sb, depth) + case *ast.InlineArray: + return printInlineArray(node, sb, depth) + default: + if node.GetType() == ast_pb.NodeType_SOURCE_UNIT { + return printSourceUnit(node, sb, depth) + } + zap.S().Errorf("Unknown node type: %T\n", node) + return false + } +} + +func writeSeperatedStrings(sb *strings.Builder, seperator string, s ...string) { + count := 0 + for _, item := range s { + // Skip empty strings + if item == "" { + continue + } + + if count > 0 { + sb.WriteString(seperator) + sb.WriteString(item) + } else { + sb.WriteString(item) + } + count++ + } +} + +func writeSeperatedList(sb *strings.Builder, seperator string, s []string) { + count := 0 + for _, item := range s { + // Skip empty strings + if item == "" { + continue + } + + if count > 0 { + sb.WriteString(seperator) + sb.WriteString(item) + } else { + sb.WriteString(item) + } + count++ + } +} + +func writeStrings(sb *strings.Builder, s ...string) { + for _, item := range s { + sb.WriteString(item) + } +} + +func indentString(s string, depth int) string { + return strings.Repeat(" ", depth*INDENT_SIZE) + s +} + +func getStorageLocationString(storage ast_pb.StorageLocation) string { + switch storage { + case ast_pb.StorageLocation_DEFAULT: + return "" + case ast_pb.StorageLocation_MEMORY: + return "memory" + case ast_pb.StorageLocation_STORAGE: + return "storage" + case ast_pb.StorageLocation_CALLDATA: + return "calldata" + default: + return "" + } +} + +func getVisibilityString(visibility ast_pb.Visibility) string { + switch visibility { + case ast_pb.Visibility_INTERNAL: + return "internal" + case ast_pb.Visibility_PUBLIC: + return "public" + case ast_pb.Visibility_EXTERNAL: + return "external" + case ast_pb.Visibility_PRIVATE: + return "private" + default: + return "" + } +} + +func getStateMutabilityString(mut ast_pb.Mutability) string { + switch mut { + case ast_pb.Mutability_PURE: + return "pure" + case ast_pb.Mutability_VIEW: + return "view" + case ast_pb.Mutability_NONPAYABLE: + return "" + case ast_pb.Mutability_PAYABLE: + return "payable" + default: + return "" + } +} diff --git a/printer/ast_printer/binary.go b/printer/ast_printer/binary.go new file mode 100644 index 00000000..e0475817 --- /dev/null +++ b/printer/ast_printer/binary.go @@ -0,0 +1,57 @@ +package ast_printer + +import ( + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" +) + +func getBinaryOperatorString(op ast_pb.Operator) string { + switch op { + case ast_pb.Operator_ADDITION: + return "+" + case ast_pb.Operator_SUBTRACTION: + return "-" + case ast_pb.Operator_MULTIPLICATION: + return "*" + case ast_pb.Operator_DIVISION: + return "/" + case ast_pb.Operator_MODULO: + return "%" + case ast_pb.Operator_EQUAL: + return "==" + case ast_pb.Operator_NOT_EQUAL: + return "!=" + case ast_pb.Operator_GREATER_THAN: + return ">" + case ast_pb.Operator_GREATER_THAN_OR_EQUAL: + return ">=" + case ast_pb.Operator_LESS_THAN: + return "<" + case ast_pb.Operator_LESS_THAN_OR_EQUAL: + return "<=" + case ast_pb.Operator_OR: + return "||" + // not sure where is the AND operator in ast_pb? + default: + return "" + } +} + +func printBinaryOperation(node *ast.BinaryOperation, sb *strings.Builder, depth int) bool { + ok := true + if node.LeftExpression == nil || node.RightExpression == nil { + return false + } + op := getBinaryOperatorString(node.Operator) + if op == "" { + ok = false + } + ok = PrintRecursive(node.LeftExpression, sb, depth) && ok + sb.WriteString(" ") + sb.WriteString(op) + sb.WriteString(" ") + ok = PrintRecursive(node.RightExpression, sb, depth) && ok + return ok +} diff --git a/printer/ast_printer/body.go b/printer/ast_printer/body.go new file mode 100644 index 00000000..f8f587ef --- /dev/null +++ b/printer/ast_printer/body.go @@ -0,0 +1,43 @@ +package ast_printer + +import ( + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" +) + +var blocks []ast_pb.NodeType = []ast_pb.NodeType{ + ast_pb.NodeType_IF_STATEMENT, + ast_pb.NodeType_FOR_STATEMENT, + ast_pb.NodeType_WHILE_STATEMENT, + ast_pb.NodeType_DO_WHILE_STATEMENT, + ast_pb.NodeType_FUNCTION_DEFINITION, + ast_pb.NodeType_MODIFIER_DEFINITION, + ast_pb.NodeType_CONTRACT_DEFINITION, + ast_pb.NodeType_STRUCT_DEFINITION, + ast_pb.NodeType_ENUM_DEFINITION, +} + +func isBlock(nodeType ast_pb.NodeType) bool { + for _, block := range blocks { + if block == nodeType { + return true + } + } + return false +} + +func printBody(node *ast.BodyNode, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("{\n") + for _, stmt := range node.GetStatements() { + sb.WriteString(indentString("", depth+1)) + success = PrintRecursive(stmt, sb, depth+1) && success + if !isBlock(stmt.GetType()) { + writeStrings(sb, ";\n") + } + } + sb.WriteString(indentString("}\n", depth)) + return success +} diff --git a/printer/ast_printer/conditional.go b/printer/ast_printer/conditional.go new file mode 100644 index 00000000..eb5a25b2 --- /dev/null +++ b/printer/ast_printer/conditional.go @@ -0,0 +1,22 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" + "go.uber.org/zap" +) + +func printConditional(node *ast.Conditional, sb *strings.Builder, depth int) bool { + success := true + if len(node.GetExpressions()) < 3 { + zap.S().Error("Conditional node must have at least 3 expressions") + return false + } + success = PrintRecursive(node.GetExpressions()[0], sb, depth) && success + sb.WriteString(" ? ") + success = PrintRecursive(node.GetExpressions()[1], sb, depth) && success + sb.WriteString(" : ") + success = PrintRecursive(node.GetExpressions()[2], sb, depth) && success + return success +} diff --git a/printer/ast_printer/constructor.go b/printer/ast_printer/constructor.go new file mode 100644 index 00000000..d44d642a --- /dev/null +++ b/printer/ast_printer/constructor.go @@ -0,0 +1,16 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printConstructor(node *ast.Constructor, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("constructor(") + success = printParameterList(node.GetParameters(), sb, depth) && success + sb.WriteString(") ") + success = PrintRecursive(node.GetBody(), sb, depth) && success + return success +} diff --git a/printer/ast_printer/continue.go b/printer/ast_printer/continue.go new file mode 100644 index 00000000..3e6fa60a --- /dev/null +++ b/printer/ast_printer/continue.go @@ -0,0 +1,12 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printContinueStatement(node *ast.ContinueStatement, sb *strings.Builder, depth int) bool { + sb.WriteString("continue") + return true +} diff --git a/printer/ast_printer/contract.go b/printer/ast_printer/contract.go new file mode 100644 index 00000000..9e3706bb --- /dev/null +++ b/printer/ast_printer/contract.go @@ -0,0 +1,29 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printContract(node *ast.Contract, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("contract ") + sb.WriteString(node.GetName()) + baseContracts := []string{} + for _, base := range node.GetBaseContracts() { + baseContracts = append(baseContracts, base.BaseName.GetName()) + } + if len(baseContracts) > 0 { + sb.WriteString(" is ") + writeSeperatedStrings(sb, ", ", baseContracts...) + } + + sb.WriteString(" {\n") + for _, child := range node.GetNodes() { + sb.WriteString(indentString("", depth+1)) + success = PrintRecursive(child, sb, depth+1) && success + } + sb.WriteString("}\n") + return success +} diff --git a/printer/ast_printer/declaration.go b/printer/ast_printer/declaration.go new file mode 100644 index 00000000..60cd8f84 --- /dev/null +++ b/printer/ast_printer/declaration.go @@ -0,0 +1,18 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printDeclaration(node *ast.Declaration, sb *strings.Builder, depth int) bool { + success := true + typeName, ok := Print(node.GetTypeName()) + success = ok && success + ident := node.GetName() + storage := getStorageLocationString(node.GetStorageLocation()) + writeSeperatedStrings(sb, " ", typeName, storage, ident) + + return success +} diff --git a/printer/ast_printer/emit.go b/printer/ast_printer/emit.go new file mode 100644 index 00000000..ffddcde1 --- /dev/null +++ b/printer/ast_printer/emit.go @@ -0,0 +1,23 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printEmit(node *ast.Emit, sb *strings.Builder, depth int) bool { + success := true + args := []string{} + for _, arg := range node.GetArguments() { + s, ok := Print(arg) + success = ok && success + args = append(args, s) + } + sb.WriteString("emit ") + success = PrintRecursive(node.GetExpression(), sb, depth) && success + sb.WriteString("(") + writeSeperatedList(sb, ", ", args) + sb.WriteString(")") + return success +} diff --git a/printer/ast_printer/enum.go b/printer/ast_printer/enum.go new file mode 100644 index 00000000..6078b332 --- /dev/null +++ b/printer/ast_printer/enum.go @@ -0,0 +1,23 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printEnumDefinition(node *ast.EnumDefinition, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("enum ") + sb.WriteString(node.GetName()) + sb.WriteString(" {") + members := []string{} + for _, member := range node.GetMembers() { + s, ok := Print(member) + success = ok && success + members = append(members, s) + } + writeSeperatedList(sb, ", ", members) + sb.WriteString("}\n") + return success +} diff --git a/printer/ast_printer/error.go b/printer/ast_printer/error.go new file mode 100644 index 00000000..9e08345c --- /dev/null +++ b/printer/ast_printer/error.go @@ -0,0 +1,18 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printErrorDefinition(node *ast.ErrorDefinition, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("error ") + sb.WriteString(node.GetName()) + sb.WriteString("(") + success = printParameterList(node.GetParameters(), sb, depth) && success + sb.WriteString(")") + sb.WriteString(";\n") + return success +} diff --git a/printer/ast_printer/event.go b/printer/ast_printer/event.go new file mode 100644 index 00000000..70b8d694 --- /dev/null +++ b/printer/ast_printer/event.go @@ -0,0 +1,18 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printEventDefinition(node *ast.EventDefinition, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("event ") + sb.WriteString(node.GetName()) + sb.WriteString("(") + printParameterList(node.GetParameters(), sb, depth) + sb.WriteString(")") + sb.WriteString(";\n") + return success +} diff --git a/printer/ast_printer/for.go b/printer/ast_printer/for.go new file mode 100644 index 00000000..788bfcd8 --- /dev/null +++ b/printer/ast_printer/for.go @@ -0,0 +1,28 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printFor(node *ast.ForStatement, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("for (") + if node.GetInitialiser() != nil { + success = PrintRecursive(node.GetInitialiser(), sb, depth) && success + } + sb.WriteString("; ") + if node.GetCondition() != nil { + success = PrintRecursive(node.GetCondition(), sb, depth) && success + } + sb.WriteString("; ") + if node.GetClosure() != nil { + success = PrintRecursive(node.GetClosure(), sb, depth) && success + } + sb.WriteString(") ") + if node.GetBody() != nil { + success = PrintRecursive(node.GetBody(), sb, depth) && success + } + return success +} diff --git a/printer/ast_printer/function.go b/printer/ast_printer/function.go new file mode 100644 index 00000000..d651220f --- /dev/null +++ b/printer/ast_printer/function.go @@ -0,0 +1,36 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printFunction(node *ast.Function, sb *strings.Builder, depth int) bool { + success := true + visibility := getVisibilityString(node.GetVisibility()) + funcName := node.GetName() + mutability := getStateMutabilityString(node.GetStateMutability()) + virtual := "" + if node.IsVirtual() { + virtual = "virtual " + } + + writeStrings(sb, "function ", funcName, "(") + printParameterList(node.GetParameters(), sb, depth) + writeSeperatedStrings(sb, " ", ")", visibility, virtual, mutability) + + paramBuilder := strings.Builder{} + printParameterList(node.GetReturnParameters(), ¶mBuilder, depth) + + if paramBuilder.String() != "" { + writeStrings(sb, " returns (", paramBuilder.String(), ")") + } + sb.WriteString(" ") + + if node.GetBody() != nil { + success = PrintRecursive(node.GetBody(), sb, depth) && success + } + + return success +} diff --git a/printer/ast_printer/function_call.go b/printer/ast_printer/function_call.go new file mode 100644 index 00000000..ef147900 --- /dev/null +++ b/printer/ast_printer/function_call.go @@ -0,0 +1,26 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printFunctionCall(node *ast.FunctionCall, sb *strings.Builder, depth int) bool { + success := true + if node.GetExpression() != nil { + success = PrintRecursive(node.GetExpression(), sb, depth) && success + } + args := []string{} + if node.GetArguments() != nil { + for _, arg := range node.GetArguments() { + s, ok := Print(arg) + success = ok && success + args = append(args, s) + } + } + sb.WriteString("(") + writeSeperatedList(sb, ", ", args) + sb.WriteString(")") + return success +} diff --git a/printer/ast_printer/if.go b/printer/ast_printer/if.go new file mode 100644 index 00000000..8fb8502e --- /dev/null +++ b/printer/ast_printer/if.go @@ -0,0 +1,20 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printIfStatement(node *ast.IfStatement, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("if (") + if node.GetCondition() != nil { + success = PrintRecursive(node.GetCondition(), sb, depth) && success + } + sb.WriteString(") ") + if node.GetBody() != nil { + success = PrintRecursive(node.GetBody(), sb, depth) && success + } + return success +} diff --git a/printer/ast_printer/imports.go b/printer/ast_printer/imports.go new file mode 100644 index 00000000..dd84f3b5 --- /dev/null +++ b/printer/ast_printer/imports.go @@ -0,0 +1,25 @@ +package ast_printer + +import ( + "fmt" + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printImport(node *ast.Import, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("import ") + file := fmt.Sprintf("'%s'", node.GetFile()) + if node.UnitAlias != "" { + writeStrings(sb, file, " as ", node.UnitAlias) + } else if len(node.UnitAliases) > 0 { + sb.WriteString("{") + writeSeperatedList(sb, ", ", node.UnitAliases) + writeStrings(sb, "} from ", file) + } else { + writeStrings(sb, file) + } + writeStrings(sb, ";") + return success +} diff --git a/printer/ast_printer/index_access.go b/printer/ast_printer/index_access.go new file mode 100644 index 00000000..66d88c38 --- /dev/null +++ b/printer/ast_printer/index_access.go @@ -0,0 +1,16 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printIndexAccess(node *ast.IndexAccess, sb *strings.Builder, depth int) bool { + success := true + success = PrintRecursive(node.GetBaseExpression(), sb, depth) && success + sb.WriteString("[") + success = PrintRecursive(node.GetIndexExpression(), sb, depth) && success + sb.WriteString("]") + return success +} diff --git a/printer/ast_printer/inline_array.go b/printer/ast_printer/inline_array.go new file mode 100644 index 00000000..4bea71b2 --- /dev/null +++ b/printer/ast_printer/inline_array.go @@ -0,0 +1,28 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" + "go.uber.org/zap" +) + +func printInlineArray(node *ast.InlineArray, sb *strings.Builder, depth int) bool { + success := true + if len(node.GetExpressions()) < 3 { + zap.S().Error("Conditional node must have at least 3 expressions") + return false + } + sb.WriteString("[") + items := []string{} + for _, item := range node.GetExpressions() { + s, ok := Print(item) + if !ok { + success = false + } + items = append(items, s) + } + writeSeperatedList(sb, ", ", items) + sb.WriteString("]") + return success +} diff --git a/printer/ast_printer/member_access.go b/printer/ast_printer/member_access.go new file mode 100644 index 00000000..10a699c3 --- /dev/null +++ b/printer/ast_printer/member_access.go @@ -0,0 +1,15 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printMemberAccessExpression(node *ast.MemberAccessExpression, sb *strings.Builder, depth int) bool { + success := true + success = PrintRecursive(node.GetExpression(), sb, depth) && success + sb.WriteString(".") + sb.WriteString(node.GetMemberName()) + return success +} diff --git a/printer/ast_printer/modifier.go b/printer/ast_printer/modifier.go new file mode 100644 index 00000000..e0faa0cd --- /dev/null +++ b/printer/ast_printer/modifier.go @@ -0,0 +1,30 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printModifierDefinition(node *ast.ModifierDefinition, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("modifier ") + sb.WriteString(node.GetName()) + sb.WriteString("(") + printParameterList(node.GetParameters(), sb, depth) + sb.WriteString(") ") + visibility := getVisibilityString(node.GetVisibility()) + virtual := "" + if node.IsVirtual() { + virtual = "virtual" + } + writeSeperatedStrings(sb, " ", visibility, virtual) + + if node.GetBody() == nil { + sb.WriteString(";\n") + } else { + sb.WriteString(" ") + success = PrintRecursive(node.GetBody(), sb, depth) && success + } + return success +} diff --git a/printer/ast_printer/parameter.go b/printer/ast_printer/parameter.go new file mode 100644 index 00000000..08120c9c --- /dev/null +++ b/printer/ast_printer/parameter.go @@ -0,0 +1,21 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printParameter(node *ast.Parameter, sb *strings.Builder, depth int) bool { + success := true + typeName := "" + ok := true + if node.GetTypeName() != nil { + typeName, ok = Print(node.GetTypeName()) + } + success = ok && success + ident := node.GetName() + storage := getStorageLocationString(node.GetStorageLocation()) + writeSeperatedStrings(sb, " ", typeName, storage, ident) + return success +} diff --git a/printer/ast_printer/parameter_list.go b/printer/ast_printer/parameter_list.go new file mode 100644 index 00000000..8b5723f9 --- /dev/null +++ b/printer/ast_printer/parameter_list.go @@ -0,0 +1,20 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +// parameterList does not have the correct Node interface, we handle it separately +func printParameterList(node *ast.ParameterList, sb *strings.Builder, depth int) bool { + success := true + params := []string{} + for _, param := range node.GetParameters() { + s, ok := Print(param) + success = ok && success + params = append(params, s) + } + writeSeperatedList(sb, ", ", params) + return success +} diff --git a/printer/ast_printer/payable_conversion.go b/printer/ast_printer/payable_conversion.go new file mode 100644 index 00000000..e17a4e39 --- /dev/null +++ b/printer/ast_printer/payable_conversion.go @@ -0,0 +1,21 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printPayableConversion(node *ast.PayableConversion, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("payable(") + args := []string{} + for _, arg := range node.GetArguments() { + s, ok := Print(arg) + success = ok && success + args = append(args, s) + } + sb.WriteString(strings.Join(args, ", ")) + sb.WriteString(")") + return success +} diff --git a/printer/ast_printer/pragma.go b/printer/ast_printer/pragma.go new file mode 100644 index 00000000..cc4a41f9 --- /dev/null +++ b/printer/ast_printer/pragma.go @@ -0,0 +1,13 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printPragma(node *ast.Pragma, sb *strings.Builder, depth int) bool { + sb.WriteString(node.GetText()) + sb.WriteString("\n") + return true +} diff --git a/printer/ast_printer/primary_expression.go b/printer/ast_printer/primary_expression.go new file mode 100644 index 00000000..a84231d9 --- /dev/null +++ b/printer/ast_printer/primary_expression.go @@ -0,0 +1,25 @@ +package ast_printer + +import ( + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" +) + +func printPrimaryExpression(node *ast.PrimaryExpression, sb *strings.Builder, depth int) bool { + s := "" + if node.GetKind() == ast_pb.NodeType_UNICODE_STRING_LITERAL { + s = "\"" + node.GetValue() + "\"" + sb.WriteString(s) + return true + } + + if node.GetValue() == "" { + s = node.GetName() + } else { + s = node.GetValue() + } + sb.WriteString(s) + return true +} diff --git a/printer/ast_printer/return.go b/printer/ast_printer/return.go new file mode 100644 index 00000000..52f675ab --- /dev/null +++ b/printer/ast_printer/return.go @@ -0,0 +1,17 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printReturn(node *ast.ReturnStatement, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("return") + if node.GetExpression() != nil { + sb.WriteString(" ") + success = PrintRecursive(node.GetExpression(), sb, depth) && success + } + return success +} diff --git a/printer/ast_printer/revert.go b/printer/ast_printer/revert.go new file mode 100644 index 00000000..f4977461 --- /dev/null +++ b/printer/ast_printer/revert.go @@ -0,0 +1,26 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printRevertStatement(node *ast.RevertStatement, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("revert") + if node.GetExpression() != nil { + sb.WriteString(" ") + success = PrintRecursive(node.GetExpression(), sb, depth) && success + } + args := []string{} + for _, arg := range node.GetArguments() { + s, ok := Print(arg) + success = success && ok + args = append(args, s) + } + sb.WriteString("(") + writeSeperatedList(sb, ", ", args) + sb.WriteString(")") + return success +} diff --git a/printer/ast_printer/source_unit.go b/printer/ast_printer/source_unit.go new file mode 100644 index 00000000..0a5525cf --- /dev/null +++ b/printer/ast_printer/source_unit.go @@ -0,0 +1,18 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printSourceUnit(node ast.Node[ast.NodeType], sb *strings.Builder, depth int) bool { + success := true + for _, child := range node.GetNodes() { + s, ok := Print(child.(ast.Node[ast.NodeType])) + success = ok && success + sb.WriteString(s) + sb.WriteString("\n") + } + return success +} diff --git a/printer/ast_printer/state_variable.go b/printer/ast_printer/state_variable.go new file mode 100644 index 00000000..6abc8bab --- /dev/null +++ b/printer/ast_printer/state_variable.go @@ -0,0 +1,31 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printStateVariableDeclaration(node *ast.StateVariableDeclaration, sb *strings.Builder, depth int) bool { + success := true + typeName, ok := Print(node.GetTypeName()) + success = ok && success + ident := node.GetName() + storage := getStorageLocationString(node.GetStorageLocation()) + visibility := getVisibilityString(node.GetVisibility()) + override := "" + if typeName == "addresspayable" { + typeName = "address" + storage = "payable" + } + if node.Override { + override = "override" + } + writeSeperatedStrings(sb, " ", typeName, visibility, storage, override, ident) + if node.GetInitialValue() != nil { + sb.WriteString(" = ") + success = PrintRecursive(node.GetInitialValue(), sb, depth) && success + } + sb.WriteString(";\n") + return success +} diff --git a/printer/ast_printer/struct.go b/printer/ast_printer/struct.go new file mode 100644 index 00000000..e230b35f --- /dev/null +++ b/printer/ast_printer/struct.go @@ -0,0 +1,21 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printStructDefinition(node *ast.StructDefinition, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("struct ") + sb.WriteString(node.GetName()) + sb.WriteString(" {\n") + for _, member := range node.GetMembers() { + sb.WriteString(indentString("", depth+1)) + success = PrintRecursive(member, sb, depth) && success + sb.WriteString(";\n") + } + sb.WriteString(indentString("}\n", depth)) + return success +} diff --git a/printer/ast_printer/tuple.go b/printer/ast_printer/tuple.go new file mode 100644 index 00000000..b7678b45 --- /dev/null +++ b/printer/ast_printer/tuple.go @@ -0,0 +1,21 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printTupleExpression(node *ast.TupleExpression, sb *strings.Builder, depth int) bool { + success := true + sb.WriteString("(") + components := []string{} + for _, component := range node.GetComponents() { + s, ok := Print(component) + success = ok && success + components = append(components, s) + } + sb.WriteString(strings.Join(components, ", ")) + sb.WriteString(")") + return success +} diff --git a/printer/ast_printer/type_name.go b/printer/ast_printer/type_name.go new file mode 100644 index 00000000..172316ec --- /dev/null +++ b/printer/ast_printer/type_name.go @@ -0,0 +1,41 @@ +package ast_printer + +import ( + "fmt" + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" +) + +func printTypeName(node *ast.TypeName, sb *strings.Builder, depth int) bool { + success := true + if node.ValueType != nil { + keyType, ok := Print(node.KeyType) + if !ok { + success = false + } + valueType, ok := Print(node.ValueType) + if !ok { + success = false + } + typeName := fmt.Sprintf("mapping(%s => %s)", keyType, valueType) + sb.WriteString(typeName) + } else { + if node.NodeType == ast_pb.NodeType_USER_DEFINED_PATH_NAME { + + ref := node.GetTree().GetById(node.GetReferencedDeclaration()) + if enumType, ok := ref.(*ast.EnumDefinition); ok { + enumName := enumType.GetName() + sb.WriteString(enumName) + } + if structType, ok := ref.(*ast.StructDefinition); ok { + structName := structType.GetName() + sb.WriteString(structName) + } + } else { + sb.WriteString(node.GetName()) + } + } + return success +} diff --git a/printer/ast_printer/unary_prefix.go b/printer/ast_printer/unary_prefix.go new file mode 100644 index 00000000..58f7d1db --- /dev/null +++ b/printer/ast_printer/unary_prefix.go @@ -0,0 +1,13 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printUnaryPrefix(node *ast.UnaryPrefix, sb *strings.Builder, depth int) bool { + sb.WriteString(getUnaryOperatorString(node.GetOperator())) + success := PrintRecursive(node.GetExpression(), sb, depth) + return success +} diff --git a/printer/ast_printer/unary_suffix.go b/printer/ast_printer/unary_suffix.go new file mode 100644 index 00000000..427bc927 --- /dev/null +++ b/printer/ast_printer/unary_suffix.go @@ -0,0 +1,31 @@ +package ast_printer + +import ( + "strings" + + ast_pb "github.com/unpackdev/protos/dist/go/ast" + "github.com/unpackdev/solgo/ast" +) + +func getUnaryOperatorString(op ast_pb.Operator) string { + switch op { + case ast_pb.Operator_NOT: + return "!" + case ast_pb.Operator_BIT_NOT: + return "~" + case ast_pb.Operator_SUBTRACT: + return "-" + case ast_pb.Operator_INCREMENT: + return "++" + case ast_pb.Operator_DECREMENT: + return "--" + default: + return "" + } +} + +func printUnarySuffix(node *ast.UnarySuffix, sb *strings.Builder, depth int) bool { + success := PrintRecursive(node.GetExpression(), sb, depth) + sb.WriteString(getUnaryOperatorString(node.GetOperator())) + return success +} diff --git a/printer/ast_printer/variable.go b/printer/ast_printer/variable.go new file mode 100644 index 00000000..b66547da --- /dev/null +++ b/printer/ast_printer/variable.go @@ -0,0 +1,28 @@ +package ast_printer + +import ( + "strings" + + "github.com/unpackdev/solgo/ast" +) + +func printVariableDeclaration(node *ast.VariableDeclaration, sb *strings.Builder, depth int) bool { + success := true + isTuple := len(node.GetDeclarations()) > 1 + if isTuple { + decls := []string{} + for _, decl := range node.GetDeclarations() { + s, ok := Print(decl) + success = ok && success + decls = append(decls, s) + } + writeSeperatedList(sb, ", ", decls) + } else { + PrintRecursive(node.GetDeclarations()[0], sb, depth) + } + if node.GetInitialValue() != nil { + sb.WriteString(" = ") + PrintRecursive(node.GetInitialValue(), sb, depth) + } + return success +}