package client

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sort"
	"strings"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
	"github.com/aws/aws-sdk-go-v2/service/ec2"
	"github.com/aws/aws-sdk-go-v2/service/route53"
	r53 "github.com/aws/aws-sdk-go-v2/service/route53/types"
	"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
	sm "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
)

// AWSClient bundles an aws.Config with helper methods that call IMDS, EC2,
// Route 53 and Secrets Manager. Each method builds its service client from
// the stored config on every call.
type AWSClient struct {
	cfg aws.Config
}

// SecretListEntry is one row of the result returned by ListSecrets.
type SecretListEntry struct {
	Name        string
	Description string
}

// optionFunc mutates an AWSClient while NewAWSClient is constructing it.
type optionFunc func(*AWSClient)

// NewAWSClient returns a client with every supplied option applied in order.
// Without SetConfig the client holds a zero-value aws.Config.
func NewAWSClient(optionFuncs ...optionFunc) *AWSClient {
	awsClient := AWSClient{}
	for _, option := range optionFuncs {
		option(&awsClient)
	}
	return &awsClient
}

// SetConfig returns an option that stores cfg on the client.
func SetConfig(cfg aws.Config) optionFunc {
	return func(c *AWSClient) {
		c.cfg = cfg
	}
}

// CreateAWSConfig loads the default AWS configuration. If that configuration
// has no region, the region is looked up through IMDS and the configuration
// is loaded again with that region applied.
//
// An error is returned when either load fails or when the IMDS lookup fails.
func CreateAWSConfig() (aws.Config, error) {
	cfg, err := config.LoadDefaultConfig(context.TODO())
	if err != nil {
		return cfg, fmt.Errorf("unable to load default AWS config, %w", err)
	}
	if cfg.Region != "" {
		return cfg, nil
	}

	// No region configured: ask the instance metadata service for it.
	resp, err := imds.NewFromConfig(cfg).GetRegion(context.TODO(), &imds.GetRegionInput{})
	if err != nil {
		return cfg, fmt.Errorf("unable to get AWS region from IMDS, %w", err)
	}

	// The reload can fail just like the first load; surface that instead of
	// handing back an unusable config together with a nil error.
	regionCfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(resp.Region))
	if err != nil {
		return cfg, fmt.Errorf("unable to load AWS config for region %q, %w", resp.Region, err)
	}
	return regionCfg, nil
}

// GetInstanceID returns the instance ID from the IMDS instance identity
// document.
func (c AWSClient) GetInstanceID() (string, error) {
	doc, err := imds.NewFromConfig(c.cfg).GetInstanceIdentityDocument(
		context.TODO(), &imds.GetInstanceIdentityDocumentInput{})
	if err != nil {
		return "", err
	}
	return doc.InstanceID, nil
}

// GetZoneID returns the ID of the Route 53 hosted zone whose name equals
// domain. The comparison ignores case and a trailing dot on either side.
//
// It returns the error "no matching zone" when no zone in the response has
// that name.
func (c AWSClient) GetZoneID(domain string) (string, error) {
	output, err := route53.NewFromConfig(c.cfg).ListHostedZonesByName(context.TODO(),
		&route53.ListHostedZonesByNameInput{
			DNSName: aws.String(domain),
		})
	if err != nil {
		return "", err
	}

	// ListHostedZonesByName only starts the listing at DNSName; when the
	// requested zone does not exist the first entry is some other zone, so
	// the name has to be checked rather than trusting HostedZones[0].
	want := strings.TrimSuffix(domain, ".")
	for _, zone := range output.HostedZones {
		if zone.Id == nil {
			continue
		}
		if strings.EqualFold(strings.TrimSuffix(aws.ToString(zone.Name), "."), want) {
			return *zone.Id, nil
		}
	}
	return "", errors.New("no matching zone")
}

// GetMetadata returns the content found at path in the instance metadata
// service as a string.
func (c AWSClient) GetMetadata(path string) (string, error) {
	metadata, err := imds.NewFromConfig(c.cfg).GetMetadata(context.TODO(), &imds.GetMetadataInput{
		Path: path,
	})
	if err != nil {
		return "", err
	}
	// Content is a ReadCloser owned by the caller; release it once read.
	defer metadata.Content.Close()

	var buf strings.Builder
	if _, err := io.Copy(&buf, metadata.Content); err != nil {
		return "", err
	}
	return buf.String(), nil
}

// StopInstance requests a stop (optionally with hibernation) of the given
// instance and returns the current state name reported for it in the
// response.
//
// An error is returned when the call fails or when the response carries no
// state change for the instance.
func (c AWSClient) StopInstance(instance_id string, hibernate bool) (string, error) {
	out, err := ec2.NewFromConfig(c.cfg).StopInstances(context.TODO(), &ec2.StopInstancesInput{
		InstanceIds: []string{instance_id},
		Hibernate:   aws.Bool(hibernate),
	})
	if err != nil {
		return "", err
	}
	// Guard the index: an empty or incomplete response must not panic.
	if len(out.StoppingInstances) == 0 || out.StoppingInstances[0].CurrentState == nil {
		return "", fmt.Errorf("stop request for instance %q returned no state change", instance_id)
	}
	return string(out.StoppingInstances[0].CurrentState.Name), nil
}

// GetEC2Tag returns the value of the tag named tag_name on the given
// instance. When the instance has no such tag it returns an empty string and
// a nil error.
func (c AWSClient) GetEC2Tag(instance_id string, tag_name string) (string, error) {
	result, err := ec2.NewFromConfig(c.cfg).DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
		InstanceIds: []string{instance_id},
	})
	if err != nil {
		return "", err
	}

	for _, reservation := range result.Reservations {
		for _, instance := range reservation.Instances {
			for _, tag := range instance.Tags {
				// Key and Value are pointers; skip entries without a key and
				// treat a missing value as empty rather than dereferencing nil.
				if tag.Key != nil && *tag.Key == tag_name {
					return aws.ToString(tag.Value), nil
				}
			}
		}
	}
	return "", nil
}

// GetEC2PublicIP returns the public IP address reported for the given
// instance. If several instances in the response carry a public address, the
// last one seen wins.
//
// An error is returned when the call fails or no public address is present.
func (c AWSClient) GetEC2PublicIP(instance_id string) (string, error) {
	result, err := ec2.NewFromConfig(c.cfg).DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
		InstanceIds: []string{instance_id},
	})
	if err != nil {
		return "", err
	}

	var publicIP string
	for _, reservation := range result.Reservations {
		for _, instance := range reservation.Instances {
			if instance.PublicIpAddress != nil {
				publicIP = *instance.PublicIpAddress
			}
		}
	}
	if publicIP == "" {
		return "", errors.New("no public ip address is associated with the instance")
	}
	return publicIP, nil
}

// UpdateA upserts an A record for hostname pointing at ipaddress, with a TTL
// of 60 seconds, in the hosted zone identified by zoneID.
func (c AWSClient) UpdateA(hostname, ipaddress, zoneID string) error {
	change := r53.Change{
		Action: r53.ChangeActionUpsert,
		ResourceRecordSet: &r53.ResourceRecordSet{
			Name: aws.String(hostname),
			Type: r53.RRTypeA,
			TTL:  aws.Int64(60),
			ResourceRecords: []r53.ResourceRecord{
				{Value: aws.String(ipaddress)},
			},
		},
	}

	_, err := route53.NewFromConfig(c.cfg).ChangeResourceRecordSets(context.TODO(),
		&route53.ChangeResourceRecordSetsInput{
			HostedZoneId: aws.String(zoneID),
			ChangeBatch: &r53.ChangeBatch{
				Changes: []r53.Change{change},
			},
		})
	return err
}

// ListSecrets lists the secrets selected by filters built from tags and
// returns their names and descriptions sorted by name.
//
// Every key in tags adds a tag-key filter; every non-empty value adds a
// tag-value filter. All result pages are fetched. If a page request fails,
// the entries collected so far (unsorted) are returned together with the
// error.
func (c AWSClient) ListSecrets(tags map[string]string) ([]SecretListEntry, error) {
	client := secretsmanager.NewFromConfig(c.cfg)

	// Translate the tags map into Secrets Manager filters.
	var filters []sm.Filter
	for key, value := range tags {
		filters = append(filters, sm.Filter{
			Key:    sm.FilterNameStringTypeTagKey,
			Values: []string{key},
		})
		if value != "" {
			filters = append(filters, sm.Filter{
				Key:    sm.FilterNameStringTypeTagValue,
				Values: []string{value},
			})
		}
	}

	var secretList []SecretListEntry

	// nextToken is nil for the first page and carries the continuation token
	// afterwards; a nil token in a response marks the last page.
	var nextToken *string
	for {
		resp, err := client.ListSecrets(context.TODO(), &secretsmanager.ListSecretsInput{
			Filters:   filters,
			NextToken: nextToken,
		})
		if err != nil {
			return secretList, fmt.Errorf("unable to list secrets, %w", err)
		}

		for _, secret := range resp.SecretList {
			// Name and Description are pointers; ToString maps nil to "".
			secretList = append(secretList, SecretListEntry{
				Name:        aws.ToString(secret.Name),
				Description: aws.ToString(secret.Description),
			})
		}

		if resp.NextToken == nil {
			break
		}
		nextToken = resp.NextToken
	}

	sort.Slice(secretList, func(i, j int) bool {
		return secretList[i].Name < secretList[j].Name
	})
	return secretList, nil
}

// GetSecretValue returns the string value of the secret identified by
// secretName.
//
// An error is returned when the request fails or when the response contains
// no string value (SecretString is nil).
func (c AWSClient) GetSecretValue(secretName string) (string, error) {
	resp, err := secretsmanager.NewFromConfig(c.cfg).GetSecretValue(context.TODO(),
		&secretsmanager.GetSecretValueInput{
			SecretId: aws.String(secretName),
		})
	if err != nil {
		return "", fmt.Errorf("unable to get secret value, %w", err)
	}
	// SecretString is a pointer and may be nil; fail cleanly instead of
	// panicking on the dereference.
	if resp.SecretString == nil {
		return "", fmt.Errorf("secret %q has no string value", secretName)
	}
	return *resp.SecretString, nil
}