diff --git a/cmd/generate-bindings/solana/anchor-go/generator/idl_validate.go b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate.go new file mode 100644 index 00000000..35f41e69 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate.go @@ -0,0 +1,215 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/tools" +) + +// ValidateIDLDerivedIdentifiers checks that names from the IDL produce valid Go identifiers +// after the same transforms used by the Jennifer-based generator. Call this before Generate(). +func ValidateIDLDerivedIdentifiers(i *idl.Idl) error { + if i == nil { + return fmt.Errorf("idl is nil") + } + for ai, acc := range i.Accounts { + ctx := fmt.Sprintf("accounts[%d](name=%q)", ai, acc.Name) + if err := validatePascalIdent(ctx, acc.Name); err != nil { + return err + } + disc := FormatAccountDiscriminatorName(acc.Name) + if err := validateRawIdent(ctx+".discriminatorVar", acc.Name, disc); err != nil { + return err + } + } + for ei, ev := range i.Events { + ctx := fmt.Sprintf("events[%d](name=%q)", ei, ev.Name) + if err := validatePascalIdent(ctx, ev.Name); err != nil { + return err + } + disc := FormatEventDiscriminatorName(ev.Name) + if err := validateRawIdent(ctx+".discriminatorVar", ev.Name, disc); err != nil { + return err + } + } + for ci, co := range i.Constants { + if co.Name == "" { + continue + } + ctx := fmt.Sprintf("constants[%d]", ci) + if err := validateRawIdent(ctx, co.Name, co.Name); err != nil { + return err + } + } + for ixIdx, ix := range i.Instructions { + ctx := fmt.Sprintf("instructions[%d](name=%q)", ixIdx, ix.Name) + if err := validatePascalIdent(ctx, ix.Name); err != nil { + return err + } + disc := FormatInstructionDiscriminatorName(ix.Name) + if err := validateRawIdent(ctx+".discriminatorVar", ix.Name, disc); err != nil { + return err + } + fn := newInstructionFuncName(ix.Name) + if err := validateRawIdent(ctx+".constructor", ix.Name, fn); err != nil { + return err + } + typeName := instructionStructTypeName(ix.Name) + if err := validateRawIdent(ctx+".instructionStructType", ix.Name, typeName); err != nil { + return err + } + for _, arg := range ix.Args { + argCtx := ctx + ".args(name=" + quoteIDL(arg.Name) + ")" + if err := validatePascalIdent(argCtx, arg.Name); err != nil { + return err + } + param := formatParamName(arg.Name) + if err := validateRawIdent(argCtx+".builderParam", arg.Name, param); err != nil { + return err + } + } + for ai, accItem := range ix.Accounts { + switch acc := accItem.(type) { + case *idl.IdlInstructionAccount: + acCtx := fmt.Sprintf("%s.accounts[%d](name=%q)", ctx, ai, acc.Name) + if err := validatePascalIdent(acCtx, acc.Name); err != nil { + return err + } + fieldBase := tools.ToCamelUpper(acc.Name) + if err := validateRawIdent(acCtx+".accountField", acc.Name, fieldBase); err != nil { + return err + } + if acc.Writable { + if err := validateRawIdent(acCtx+".writableFlag", acc.Name, fieldBase+"Writable"); err != nil { + return err + } + } + if acc.Signer { + if err := validateRawIdent(acCtx+".signerFlag", acc.Name, fieldBase+"Signer"); err != nil { + return err + } + } + if acc.Optional { + if err := validateRawIdent(acCtx+".optionalFlag", acc.Name, fieldBase+"Optional"); err != nil { + return err + } + } + param := formatAccountNameParam(acc.Name) + if err := validateRawIdent(acCtx+".builderParam", acc.Name, param); err != nil { + return err + } + case *idl.IdlInstructionAccounts: + return fmt.Errorf("%s.accounts[%d]: composite account groups are not supported", ctx, ai) + default: + return fmt.Errorf("%s.accounts[%d]: unknown account item type %T", ctx, ai, accItem) + } + } + } + for ti, def := range i.Types { + ctx := fmt.Sprintf("types[%d](name=%q)", ti, def.Name) + if err := validatePascalIdent(ctx, def.Name); err != nil { + return err + } + if err := validateTypeDefTy(ctx, def.Name, def.Ty); err != nil { + return err + } + } + return nil +} + +func instructionStructTypeName(instructionName string) string { + lower := strings.ToLower(instructionName) + if strings.HasSuffix(lower, "instruction") { + return tools.ToCamelUpper(instructionName) + } + return tools.ToCamelUpper(instructionName) + "Instruction" +} + +func quoteIDL(s string) string { + return fmt.Sprintf("%q", s) +} + +func validateTypeDefTy(ctx, typeName string, ty idl.IdlTypeDefTy) error { + if ty == nil { + return fmt.Errorf("%s: type definition has nil type body", ctx) + } + switch vv := ty.(type) { + case *idl.IdlTypeDefTyStruct: + fields := vv.Fields + if fields == nil { + return nil + } + switch f := fields.(type) { + case idl.IdlDefinedFieldsNamed: + for fi, field := range f { + fctx := fmt.Sprintf("%s.fields[%d](name=%q)", ctx, fi, field.Name) + if err := validatePascalIdent(fctx, field.Name); err != nil { + return err + } + } + case idl.IdlDefinedFieldsTuple: + _ = f + } + case *idl.IdlTypeDefTyEnum: + enumExported := tools.ToCamelUpper(typeName) + if vv.Variants.IsAllSimple() { + for vi, variant := range vv.Variants { + vctx := fmt.Sprintf("%s.variants[%d](name=%q)", ctx, vi, variant.Name) + if err := validatePascalIdent(vctx, variant.Name); err != nil { + return err + } + combo := formatSimpleEnumVariantName(variant.Name, enumExported) + if err := validateRawIdent(vctx+".simpleEnumConst", variant.Name, combo); err != nil { + return err + } + } + } else { + for vi, variant := range vv.Variants { + vctx := fmt.Sprintf("%s.variants[%d](name=%q)", ctx, vi, variant.Name) + if err := validatePascalIdent(vctx, variant.Name); err != nil { + return err + } + vt := formatComplexEnumVariantTypeName(enumExported, variant.Name) + if err := validateRawIdent(vctx+".complexVariantType", variant.Name, vt); err != nil { + return err + } + if !variant.Fields.IsSome() { + continue + } + switch df := variant.Fields.Unwrap().(type) { + case idl.IdlDefinedFieldsNamed: + for fi, field := range df { + fctx := fmt.Sprintf("%s.fields[%d](name=%q)", vctx, fi, field.Name) + if err := validatePascalIdent(fctx, field.Name); err != nil { + return err + } + } + case idl.IdlDefinedFieldsTuple: + } + } + } + default: + return fmt.Errorf("%s: unsupported IDL type definition shape %T", ctx, ty) + } + return nil +} + +func validatePascalIdent(context, raw string) error { + ident := tools.ToCamelUpper(raw) + return validateRawIdent(context, raw, ident) +} + +func validateRawIdent(context, idlSource, goIdent string) error { + if goIdent == "" { + return fmt.Errorf("%s: empty Go identifier derived from IDL name %q", context, idlSource) + } + if !tools.IsValidIdent(goIdent) { + return fmt.Errorf("%s: IDL name %q yields invalid Go identifier %q (must be a valid Go identifier for generated bindings)", context, idlSource, goIdent) + } + if tools.IsReservedKeyword(goIdent) { + return fmt.Errorf("%s: IDL name %q yields Go reserved keyword %q", context, idlSource, goIdent) + } + return nil +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/idl_validate_test.go b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate_test.go new file mode 100644 index 00000000..dda3a799 --- /dev/null +++ b/cmd/generate-bindings/solana/anchor-go/generator/idl_validate_test.go @@ -0,0 +1,49 @@ +package generator + +import ( + "testing" + + "github.com/gagliardetto/anchor-go/idl" + "github.com/gagliardetto/anchor-go/idl/idltype" + "github.com/gagliardetto/solana-go" + "github.com/stretchr/testify/require" +) + +func testProgramID(t *testing.T) *solana.PublicKey { + t.Helper() + pk, err := solana.PublicKeyFromBase58("ECL8142j2YQAvs9R9geSsRnkVH2wLEi7soJCRyJ74cfL") + require.NoError(t, err) + return &pk +} + +func minimalInstruction(name string) idl.IdlInstruction { + return idl.IdlInstruction{ + Name: name, + Discriminator: idl.IdlDiscriminator{175, 175, 109, 31, 13, 152, 155, 237}, + Accounts: []idl.IdlInstructionAccountItem{}, + Args: []idl.IdlField{}, + } +} + +func TestValidateIDLDerivedIdentifiers_valid(t *testing.T) { + i := &idl.Idl{ + Address: testProgramID(t), + Instructions: []idl.IdlInstruction{minimalInstruction("initialize")}, + } + require.NoError(t, ValidateIDLDerivedIdentifiers(i)) +} + +func TestValidateIDLDerivedIdentifiers_invalidConstantName(t *testing.T) { + i := &idl.Idl{ + Address: testProgramID(t), + Instructions: []idl.IdlInstruction{minimalInstruction("initialize")}, + Constants: []idl.IdlConst{{ + Name: "123bad", + Ty: &idltype.U8{}, + Value: "1", + }}, + } + err := ValidateIDLDerivedIdentifiers(i) + require.Error(t, err) + require.Contains(t, err.Error(), "123bad") +} diff --git a/cmd/generate-bindings/solana/anchor-go/generator/tools.go b/cmd/generate-bindings/solana/anchor-go/generator/tools.go index 6488326b..91dd956a 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/tools.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/tools.go @@ -2,6 +2,7 @@ package generator import ( + "bytes" "os" "path" @@ -22,18 +23,12 @@ const ( ) func WriteFile(outDir string, assetFileName string, file *File) error { - // Save Go assets: assetFilepath := path.Join(outDir, assetFileName) - - // Create file Golang file: - goFile, err := os.Create(assetFilepath) - if err != nil { - panic(err) + var buf bytes.Buffer + if err := file.Render(&buf); err != nil { + return err } - defer goFile.Close() - - // Write generated Golang to file: - return file.Render(goFile) + return os.WriteFile(assetFilepath, buf.Bytes(), 0o644) } func DoGroup(f func(*Group)) *Statement { diff --git a/cmd/generate-bindings/solana/bindgen.go b/cmd/generate-bindings/solana/bindgen.go index 9cb3b3fe..30445272 100644 --- a/cmd/generate-bindings/solana/bindgen.go +++ b/cmd/generate-bindings/solana/bindgen.go @@ -1,11 +1,13 @@ package solana import ( + "bytes" "fmt" "go/token" "log/slog" "os" "path" + "strings" "github.com/gagliardetto/anchor-go/idl" "github.com/gagliardetto/anchor-go/tools" @@ -36,11 +38,6 @@ func GenerateBindings( "outputDir", outputDir, "pathToIdl", pathToIdl, ) - options := generator.GeneratorOptions{ - OutputDir: outputDir, - Package: programName, - ProgramName: programName, - } parsedIdl, err := idl.ParseFromFilepath(pathToIdl) if err != nil { return fmt.Errorf("failed to parse IDL: %w", err) @@ -51,13 +48,11 @@ func GenerateBindings( if err := parsedIdl.Validate(); err != nil { return fmt.Errorf("invalid IDL: %w", err) } - if parsedIdl.Address != nil && !parsedIdl.Address.IsZero() { - // If the IDL has an address, use it as the program ID: - slog.Info("Using IDL address as program ID", "address", parsedIdl.Address.String()) - options.ProgramId = parsedIdl.Address - } else { + if parsedIdl.Address == nil || parsedIdl.Address.IsZero() { return fmt.Errorf("address is empty in idl file: %s", pathToIdl) } + slog.Info("Using IDL address as program ID", "address", parsedIdl.Address.String()) + parsedIdl.Metadata.Name = bin.ToSnakeForSighash(parsedIdl.Metadata.Name) // check that the name is not a reserved keyword: if parsedIdl.Metadata.Name != "" { @@ -75,6 +70,21 @@ func GenerateBindings( } } + packageName, err := normalizeGoPackageName(programName) + if err != nil { + return err + } + if err := generator.ValidateIDLDerivedIdentifiers(parsedIdl); err != nil { + return fmt.Errorf("IDL contains names that cannot be mapped to valid Go identifiers: %w", err) + } + + options := generator.GeneratorOptions{ + OutputDir: outputDir, + Package: packageName, + ProgramName: programName, + ProgramId: parsedIdl.Address, + } + slog.Info("Parsed IDL successfully", "version", parsedIdl.Metadata.Version, "name", parsedIdl.Metadata.Name, @@ -95,28 +105,21 @@ func GenerateBindings( } for _, file := range generatedFiles.Files { - { - // Save assets: - assetFilename := file.Name - assetFilepath := path.Join(options.OutputDir, assetFilename) + assetFilename := file.Name + assetFilepath := path.Join(options.OutputDir, assetFilename) - // Create file: - goFile, err := os.Create(assetFilepath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } + var buf bytes.Buffer + if err := file.File.Render(&buf); err != nil { + return fmt.Errorf("failed to render generated file %q: %w", assetFilename, err) + } - slog.Info("Writing file", - "filepath", assetFilepath, - "name", file.Name, - "modPath", options.ModPath, - ) - err = file.File.Render(goFile) - if err != nil { - goFile.Close() - return fmt.Errorf("failed to render file: %w", err) - } - goFile.Close() + slog.Info("Writing file", + "filepath", assetFilepath, + "name", file.Name, + "modPath", options.ModPath, + ) + if err := os.WriteFile(assetFilepath, buf.Bytes(), 0o644); err != nil { + return fmt.Errorf("failed to write file %q: %w", assetFilepath, err) } } slog.Info("Generation completed successfully", @@ -127,3 +130,26 @@ func GenerateBindings( ) return nil } + +// normalizeGoPackageName maps a contract filename stem or program label to a valid Go package name. +func normalizeGoPackageName(name string) (string, error) { + if strings.TrimSpace(name) == "" { + return "", fmt.Errorf("contract/program name for Go package is empty") + } + var b strings.Builder + for _, r := range strings.ToLower(name) { + if r == '-' { + b.WriteByte('_') + } else { + b.WriteRune(r) + } + } + out := b.String() + if !tools.IsValidIdent(out) { + return "", fmt.Errorf("invalid Go package name after normalization (from contract/program name %q): %q is not a valid Go identifier; use only letters, digits, and underscores, and do not start with a digit", name, out) + } + if tools.IsReservedKeyword(out) { + return "", fmt.Errorf("invalid Go package name: normalized name %q is a Go reserved keyword (from contract/program name %q)", out, name) + } + return out, nil +}