266 lines
6.8 KiB
Go
266 lines
6.8 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
|
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
|
"github.com/aws/aws-sdk-go-v2/service/route53"
|
|
r53 "github.com/aws/aws-sdk-go-v2/service/route53/types"
|
|
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
|
|
sm "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
|
|
)
|
|
|
|
type AWSClient struct {
|
|
cfg aws.Config
|
|
}
|
|
|
|
// SecretListEntry is one row of the result returned by ListSecrets.
type SecretListEntry struct {
	// Name is the secret's name.
	Name string
	// Description is the secret's description; it is empty when the secret
	// has none.
	Description string
}
|
|
|
|
// optionFunc mutates an AWSClient while it is being constructed by
// NewAWSClient.
type optionFunc func(*AWSClient)
|
|
|
|
func NewAWSClient(optionFuncs ...optionFunc) *AWSClient {
|
|
awsClient := AWSClient{}
|
|
for _, option := range optionFuncs {
|
|
option(&awsClient)
|
|
}
|
|
return &awsClient
|
|
}
|
|
|
|
func SetConfig(cfg aws.Config) optionFunc {
|
|
return func(c *AWSClient) {
|
|
c.cfg = cfg
|
|
}
|
|
}
|
|
|
|
/* Create config for AWS clients. If region is unset, attempt to set it via IMDS. */
|
|
func CreateAWSConfig() (aws.Config, error) {
|
|
cfg, err := config.LoadDefaultConfig(context.TODO())
|
|
if err != nil {
|
|
return cfg, fmt.Errorf("unable to load default AWS config, %w", err)
|
|
}
|
|
|
|
// try to set the region using IMDS if it's not already set
|
|
if cfg.Region == "" {
|
|
resp, err := imds.NewFromConfig(cfg).GetRegion(context.TODO(), &imds.GetRegionInput{})
|
|
if err != nil {
|
|
return cfg, fmt.Errorf("unable to get AWS region from IMDS, %w", err)
|
|
}
|
|
cfg, _ = config.LoadDefaultConfig(context.TODO(), config.WithRegion(resp.Region))
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
/* Get the running instance id */
|
|
func (c AWSClient) GetInstanceID() (string, error) {
|
|
id, err := imds.NewFromConfig(c.cfg).GetInstanceIdentityDocument(context.TODO(), &imds.GetInstanceIdentityDocumentInput{})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return id.InstanceID, nil
|
|
}
|
|
|
|
// getZoneID returns the zone id given domain
|
|
func (c AWSClient) GetZoneID(domain string) (string, error) {
|
|
r53Client := route53.NewFromConfig(c.cfg)
|
|
output, err := r53Client.ListHostedZonesByName(context.TODO(), &route53.ListHostedZonesByNameInput{
|
|
DNSName: aws.String(domain),
|
|
})
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if len(output.HostedZones) == 0 {
|
|
return "", errors.New("no matching zone")
|
|
}
|
|
|
|
return *output.HostedZones[0].Id, nil
|
|
}
|
|
|
|
// GetMetadata returns the specified path from instance metadata
|
|
func (c AWSClient) GetMetadata(path string) (string, error) {
|
|
imdsClient := imds.NewFromConfig(c.cfg)
|
|
metadata, err := imdsClient.GetMetadata(context.TODO(), &imds.GetMetadataInput{
|
|
Path: path,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
buf := new(strings.Builder)
|
|
_, err = io.Copy(buf, metadata.Content)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return buf.String(), nil
|
|
}
|
|
|
|
func (c AWSClient) StopInstance(instance_id string, hibernate bool) (string, error) {
|
|
si_out, err := ec2.NewFromConfig(c.cfg).StopInstances(context.TODO(), &ec2.StopInstancesInput{
|
|
InstanceIds: []string{instance_id},
|
|
Hibernate: aws.Bool(hibernate),
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(si_out.StoppingInstances[0].CurrentState.Name), nil
|
|
}
|
|
|
|
func (c AWSClient) GetEC2Tag(instance_id string, tag_name string) (string, error) {
|
|
result, err := ec2.NewFromConfig(c.cfg).DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
|
|
InstanceIds: []string{instance_id},
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
for _, reservation := range result.Reservations {
|
|
for _, instance := range reservation.Instances {
|
|
for _, tag := range instance.Tags {
|
|
// Check if the tag key matches the one you're interested in
|
|
if *tag.Key == tag_name {
|
|
return *tag.Value, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func (c AWSClient) GetEC2PublicIP(instance_id string) (string, error) {
|
|
var publicIP string
|
|
result, err := ec2.NewFromConfig(c.cfg).DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
|
|
InstanceIds: []string{instance_id},
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
for _, reservation := range result.Reservations {
|
|
for _, instance := range reservation.Instances {
|
|
if instance.PublicIpAddress != nil {
|
|
publicIP = *instance.PublicIpAddress
|
|
}
|
|
}
|
|
}
|
|
if publicIP == "" {
|
|
err = errors.New("no public ip address is associated with the instance")
|
|
}
|
|
return publicIP, err
|
|
}
|
|
|
|
func (c AWSClient) UpdateA(hostname, ipaddress, zoneID string) error {
|
|
|
|
change := r53.Change{
|
|
Action: "UPSERT",
|
|
ResourceRecordSet: &r53.ResourceRecordSet{
|
|
Name: aws.String(hostname),
|
|
Type: "A",
|
|
TTL: aws.Int64(60),
|
|
ResourceRecords: []r53.ResourceRecord{
|
|
r53.ResourceRecord{Value: aws.String(ipaddress)},
|
|
},
|
|
},
|
|
}
|
|
|
|
_, err := route53.NewFromConfig(c.cfg).ChangeResourceRecordSets(context.TODO(), &route53.ChangeResourceRecordSetsInput{
|
|
HostedZoneId: aws.String(zoneID),
|
|
ChangeBatch: &r53.ChangeBatch{
|
|
Changes: []r53.Change{change},
|
|
},
|
|
})
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
// ListSecrets lists all secrets with tags matching the tags map, returning a sorted list
|
|
// of secrets names and descriptions from the response.
|
|
func (c AWSClient) ListSecrets(tags map[string]string) ([]SecretListEntry, error) {
|
|
client := secretsmanager.NewFromConfig(c.cfg)
|
|
|
|
// construct filters from tags map
|
|
var filters []sm.Filter
|
|
// filters := make([]sm.Filter, 0, len(tags)*2)
|
|
for key, value := range tags {
|
|
filters = append(filters, sm.Filter{
|
|
Key: sm.FilterNameStringTypeTagKey,
|
|
Values: []string{key},
|
|
})
|
|
if value != "" {
|
|
filters = append(filters, sm.Filter{
|
|
Key: sm.FilterNameStringTypeTagValue,
|
|
Values: []string{value},
|
|
})
|
|
}
|
|
}
|
|
|
|
var secretList []SecretListEntry
|
|
|
|
// Set the initial pagination token to an empty string
|
|
var nextToken *string
|
|
|
|
// Loop until there are no more pages
|
|
for {
|
|
// Call ListSecrets with the pagination token
|
|
input := &secretsmanager.ListSecretsInput{
|
|
Filters: filters,
|
|
NextToken: nextToken,
|
|
}
|
|
resp, err := client.ListSecrets(context.TODO(), input)
|
|
if err != nil {
|
|
return secretList, fmt.Errorf("unable to list secrets, %w", err)
|
|
}
|
|
|
|
// Process the secrets in the current page
|
|
for _, secret := range resp.SecretList {
|
|
entry := SecretListEntry{
|
|
Name: *secret.Name,
|
|
}
|
|
if secret.Description != nil {
|
|
entry.Description = *secret.Description
|
|
}
|
|
secretList = append(secretList, entry)
|
|
}
|
|
|
|
// Check if there are more pages
|
|
if resp.NextToken == nil {
|
|
break
|
|
}
|
|
// Set the pagination token for the next iteration
|
|
nextToken = resp.NextToken
|
|
}
|
|
|
|
// sort secretList alphabetically by name
|
|
sort.Slice(secretList, func(i, j int) bool {
|
|
return secretList[i].Name < secretList[j].Name
|
|
})
|
|
|
|
return secretList, nil
|
|
}
|
|
|
|
// GetSecretValue retrieves the secret value for the given secret name.
|
|
func (c AWSClient) GetSecretValue(secretName string) (string, error) {
|
|
client := secretsmanager.NewFromConfig(c.cfg)
|
|
|
|
input := &secretsmanager.GetSecretValueInput{
|
|
SecretId: &secretName,
|
|
}
|
|
|
|
resp, err := client.GetSecretValue(context.TODO(), input)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to get secret value, %w", err)
|
|
}
|
|
|
|
return *resp.SecretString, nil
|
|
}
|