aws-mgmt-go/pkg/client/aws.go
2024-12-12 20:22:11 -05:00

266 lines
6.8 KiB
Go

package client
import (
"context"
"errors"
"fmt"
"io"
"sort"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/route53"
r53 "github.com/aws/aws-sdk-go-v2/service/route53/types"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
sm "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
)
// AWSClient holds the aws.Config used to construct the EC2, Route 53,
// Secrets Manager and IMDS clients created by its methods.
type AWSClient struct {
	// cfg is installed by SetConfig; otherwise it is the zero aws.Config.
	cfg aws.Config
}
// SecretListEntry is the subset of a secret's metadata that ListSecrets
// reports for each secret.
type SecretListEntry struct {
	Name        string // the secret's name
	Description string // the secret's description; empty when none is set
}
// optionFunc mutates an AWSClient during construction; NewAWSClient applies
// each one it receives, in order.
type optionFunc func(client *AWSClient)
// NewAWSClient allocates an AWSClient and applies every option to it in the
// order given. Without options the client carries a zero-value aws.Config.
func NewAWSClient(optionFuncs ...optionFunc) *AWSClient {
	client := new(AWSClient)
	for i := range optionFuncs {
		optionFuncs[i](client)
	}
	return client
}
// SetConfig returns an option for NewAWSClient that stores cfg on the client;
// every method then builds its service client from that configuration.
func SetConfig(cfg aws.Config) optionFunc {
	apply := func(client *AWSClient) {
		client.cfg = cfg
	}
	return apply
}
// CreateAWSConfig loads the default AWS configuration. If that configuration
// has no region, it asks the EC2 instance metadata service (IMDS) for the
// region and loads the default configuration again with that region set.
//
// Every failure is reported, including a failure of the second load. Before,
// that error was discarded and a zero aws.Config was returned with a nil
// error.
func CreateAWSConfig() (aws.Config, error) {
	ctx := context.TODO()
	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return cfg, fmt.Errorf("unable to load default AWS config, %w", err)
	}
	if cfg.Region != "" {
		return cfg, nil
	}
	resp, err := imds.NewFromConfig(cfg).GetRegion(ctx, &imds.GetRegionInput{})
	if err != nil {
		return cfg, fmt.Errorf("unable to get AWS region from IMDS, %w", err)
	}
	// Reload (instead of assigning cfg.Region) so the whole default loading
	// chain runs with the region already set.
	regional, err := config.LoadDefaultConfig(ctx, config.WithRegion(resp.Region))
	if err != nil {
		return cfg, fmt.Errorf("unable to load AWS config for region %q, %w", resp.Region, err)
	}
	return regional, nil
}
// GetInstanceID returns the instance ID from the instance identity document
// served by IMDS, i.e. the ID of the instance this code is running on.
func (c AWSClient) GetInstanceID() (string, error) {
	client := imds.NewFromConfig(c.cfg)
	doc, err := client.GetInstanceIdentityDocument(context.TODO(), &imds.GetInstanceIdentityDocumentInput{})
	if err != nil {
		return "", err
	}
	return doc.InstanceID, nil
}
// GetZoneID returns the ID of the Route 53 hosted zone named domain.
//
// Names are compared case-insensitively and without regard to a trailing dot,
// so "example.com" matches the zone "example.com.". If several zones share
// the name, the first one listed is returned. An error is returned when no
// zone has exactly that name.
func (c AWSClient) GetZoneID(domain string) (string, error) {
	r53Client := route53.NewFromConfig(c.cfg)
	output, err := r53Client.ListHostedZonesByName(context.TODO(), &route53.ListHostedZonesByNameInput{
		DNSName: aws.String(domain),
	})
	if err != nil {
		return "", err
	}
	// ListHostedZonesByName does not filter. It lists zones in name order
	// starting at DNSName, so when no exact match exists the first entry
	// belongs to some other domain. Check the name explicitly.
	want := strings.TrimSuffix(domain, ".")
	for _, zone := range output.HostedZones {
		name := strings.TrimSuffix(aws.ToString(zone.Name), ".")
		if strings.EqualFold(name, want) {
			return aws.ToString(zone.Id), nil
		}
	}
	return "", errors.New("no matching zone")
}
// GetMetadata returns, as a string, the instance metadata (IMDS) content
// found at path.
func (c AWSClient) GetMetadata(path string) (string, error) {
	imdsClient := imds.NewFromConfig(c.cfg)
	metadata, err := imdsClient.GetMetadata(context.TODO(), &imds.GetMetadataInput{
		Path: path,
	})
	if err != nil {
		return "", err
	}
	// Content is an io.ReadCloser; it was previously never closed.
	defer metadata.Content.Close()
	buf := new(strings.Builder)
	if _, err := io.Copy(buf, metadata.Content); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// StopInstance stops the EC2 instance instance_id, hibernating it when
// hibernate is true. It returns the instance's current state name as
// reported in the StopInstances response.
//
// An error is returned if the response contains no state for the instance,
// instead of panicking on an empty slice or a nil state.
func (c AWSClient) StopInstance(instance_id string, hibernate bool) (string, error) {
	si_out, err := ec2.NewFromConfig(c.cfg).StopInstances(context.TODO(), &ec2.StopInstancesInput{
		InstanceIds: []string{instance_id},
		Hibernate:   aws.Bool(hibernate),
	})
	if err != nil {
		return "", err
	}
	if len(si_out.StoppingInstances) == 0 || si_out.StoppingInstances[0].CurrentState == nil {
		return "", fmt.Errorf("no state returned for stopped instance %q", instance_id)
	}
	return string(si_out.StoppingInstances[0].CurrentState.Name), nil
}
// GetEC2Tag returns the value of the tag named tag_name on the EC2 instance
// instance_id.
//
// If no such tag is found, it returns "" with a nil error. A caller therefore
// cannot tell a missing tag from a tag whose value is empty.
func (c AWSClient) GetEC2Tag(instance_id string, tag_name string) (string, error) {
	result, err := ec2.NewFromConfig(c.cfg).DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
		InstanceIds: []string{instance_id},
	})
	if err != nil {
		return "", err
	}
	for _, reservation := range result.Reservations {
		for _, instance := range reservation.Instances {
			for _, tag := range instance.Tags {
				// Key and Value are *string; skip nil keys and read the value
				// nil-safely instead of dereferencing directly.
				if tag.Key != nil && *tag.Key == tag_name {
					return aws.ToString(tag.Value), nil
				}
			}
		}
	}
	return "", nil
}
// GetEC2PublicIP returns the public IPv4 address of the EC2 instance
// instance_id. Among the described instances, the last non-nil
// PublicIpAddress wins. If the resulting address is empty, an error is
// returned.
func (c AWSClient) GetEC2PublicIP(instance_id string) (string, error) {
	out, err := ec2.NewFromConfig(c.cfg).DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
		InstanceIds: []string{instance_id},
	})
	if err != nil {
		return "", err
	}
	ip := ""
	for r := range out.Reservations {
		for i := range out.Reservations[r].Instances {
			if addr := out.Reservations[r].Instances[i].PublicIpAddress; addr != nil {
				ip = *addr
			}
		}
	}
	if ip != "" {
		return ip, nil
	}
	return "", errors.New("no public ip address is associated with the instance")
}
// UpdateA upserts an A record for hostname, pointing at ipaddress with a
// 60-second TTL, in the Route 53 hosted zone zoneID.
func (c AWSClient) UpdateA(hostname, ipaddress, zoneID string) error {
	record := &r53.ResourceRecordSet{
		Name:            aws.String(hostname),
		Type:            r53.RRTypeA,
		TTL:             aws.Int64(60),
		ResourceRecords: []r53.ResourceRecord{{Value: aws.String(ipaddress)}},
	}
	input := &route53.ChangeResourceRecordSetsInput{
		HostedZoneId: aws.String(zoneID),
		ChangeBatch: &r53.ChangeBatch{
			Changes: []r53.Change{{
				Action:            r53.ChangeActionUpsert,
				ResourceRecordSet: record,
			}},
		},
	}
	_, err := route53.NewFromConfig(c.cfg).ChangeResourceRecordSets(context.TODO(), input)
	return err
}
// ListSecrets lists the secrets that carry every tag in tags. It returns
// their names and descriptions sorted by name.
//
// A non-empty map value requires that exact tag value. An empty value only
// requires the key to be present. A nil or empty map lists all secrets.
//
// The tag-key and tag-value filters sent to the API narrow the listing on the
// server. They are independent prefix matches, though: a tag-value filter is
// not tied to a particular key. Each returned secret is therefore re-checked
// against tags before it is included. On error the result is nil.
func (c AWSClient) ListSecrets(tags map[string]string) ([]SecretListEntry, error) {
	client := secretsmanager.NewFromConfig(c.cfg)
	paginator := secretsmanager.NewListSecretsPaginator(client, &secretsmanager.ListSecretsInput{
		Filters: secretTagFilters(tags),
	})
	var secretList []SecretListEntry
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(context.TODO())
		if err != nil {
			return nil, fmt.Errorf("unable to list secrets, %w", err)
		}
		for _, secret := range page.SecretList {
			if !secretHasTags(secret.Tags, tags) {
				continue
			}
			secretList = append(secretList, SecretListEntry{
				Name:        aws.ToString(secret.Name),
				Description: aws.ToString(secret.Description),
			})
		}
	}
	sort.Slice(secretList, func(i, j int) bool {
		return secretList[i].Name < secretList[j].Name
	})
	return secretList, nil
}

// secretTagFilters builds ListSecrets filters from tags. It emits one tag-key
// filter per key, plus a tag-value filter for each non-empty value. It returns
// nil for an empty map so no filters are sent.
func secretTagFilters(tags map[string]string) []sm.Filter {
	if len(tags) == 0 {
		return nil
	}
	filters := make([]sm.Filter, 0, len(tags)*2)
	for key, value := range tags {
		filters = append(filters, sm.Filter{
			Key:    sm.FilterNameStringTypeTagKey,
			Values: []string{key},
		})
		if value != "" {
			filters = append(filters, sm.Filter{
				Key:    sm.FilterNameStringTypeTagValue,
				Values: []string{value},
			})
		}
	}
	return filters
}

// secretHasTags reports whether tagList contains every key in want. When the
// wanted value is non-empty, the tag's value must also be exactly equal.
func secretHasTags(tagList []sm.Tag, want map[string]string) bool {
	have := make(map[string]string, len(tagList))
	for _, tag := range tagList {
		have[aws.ToString(tag.Key)] = aws.ToString(tag.Value)
	}
	for key, value := range want {
		got, ok := have[key]
		if !ok || (value != "" && got != value) {
			return false
		}
	}
	return true
}
// GetSecretValue returns the SecretString of the secret identified by
// secretName.
//
// An error is returned if the call fails or if the response has no string
// value, e.g. for a secret stored as SecretBinary. Before, that case
// dereferenced a nil pointer.
func (c AWSClient) GetSecretValue(secretName string) (string, error) {
	client := secretsmanager.NewFromConfig(c.cfg)
	resp, err := client.GetSecretValue(context.TODO(), &secretsmanager.GetSecretValueInput{
		SecretId: aws.String(secretName),
	})
	if err != nil {
		return "", fmt.Errorf("unable to get secret value, %w", err)
	}
	if resp.SecretString == nil {
		return "", fmt.Errorf("secret %q has no string value", secretName)
	}
	return *resp.SecretString, nil
}