diff --git a/internal/stackitprovider/helper.go b/internal/stackitprovider/helper.go index 99bc7fb..a824463 100644 --- a/internal/stackitprovider/helper.go +++ b/internal/stackitprovider/helper.go @@ -7,6 +7,7 @@ import ( stackitdnsclient "github.com/stackitcloud/stackit-sdk-go/services/dns/v1api" "go.uber.org/zap" "sigs.k8s.io/external-dns/endpoint" + "sigs.k8s.io/external-dns/provider" ) // findBestMatchingZone finds the best matching zone for a given record set name. The criteria are @@ -49,18 +50,16 @@ func findRRSet( return nil, false } -// appendDotIfNotExists appends a dot to the end of a string if it doesn't already end with a dot. -func appendDotIfNotExists(s string) string { - if !strings.HasSuffix(s, ".") { - return s + "." - } - - return s -} - // modifyChange modifies a change to ensure it is valid for this stackitprovider. func modifyChange(change *endpoint.Endpoint) { - change.DNSName = appendDotIfNotExists(change.DNSName) + change.DNSName = provider.EnsureTrailingDot(change.DNSName) + + switch change.RecordType { + case endpoint.RecordTypeCNAME, endpoint.RecordTypeMX, endpoint.RecordTypeSRV, endpoint.RecordTypeNS: + for i := range change.Targets { + change.Targets[i] = provider.EnsureTrailingDot(change.Targets[i]) + } + } if change.RecordTTL == 0 { change.RecordTTL = 300 @@ -73,7 +72,7 @@ func getStackitRecordSetPayload(change *endpoint.Endpoint) stackitdnsclient.Crea for i := range change.Targets { content := change.Targets[i] - if change.RecordType == txtRecord { + if change.RecordType == endpoint.RecordTypeTXT { content = formatTXTContent(content) } @@ -96,7 +95,7 @@ func getStackitPartialUpdateRecordSetPayload(change *endpoint.Endpoint) stackitd for i := range change.Targets { content := change.Targets[i] - if change.RecordType == txtRecord { + if change.RecordType == endpoint.RecordTypeTXT { content = formatTXTContent(content) } diff --git a/internal/stackitprovider/helper_test.go b/internal/stackitprovider/helper_test.go index f23672b..93c2f5d 100644 --- a/internal/stackitprovider/helper_test.go +++ b/internal/stackitprovider/helper_test.go @@ -10,29 +10,6 @@ import ( "sigs.k8s.io/external-dns/endpoint" ) -func TestAppendDotIfNotExists(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - s string - want string - }{ - {"No dot at end", "test", "test."}, - {"Dot at end", "test.", "test."}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - if got := appendDotIfNotExists(tt.s); got != tt.want { - t.Errorf("appendDotIfNotExists() = %v, want %v", got, tt.want) - } - }) - } -} - func TestModifyChange(t *testing.T) { t.Parallel() @@ -58,6 +35,43 @@ func TestModifyChange(t *testing.T) { if endpointWithoutTTL.RecordTTL != 300 { t.Errorf("modifyChange() did not set default RecordTTL = %v, want 300", endpointWithoutTTL.RecordTTL) } + + // Hostname-valued record types get a trailing dot appended to every target, + // leaving already-dotted targets untouched. + dottedTargetTests := []struct { + recordType string + targets endpoint.Targets + want endpoint.Targets + }{ + {"CNAME", endpoint.Targets{"foo.mydomain.com", "bar.mydomain.com."}, endpoint.Targets{"foo.mydomain.com.", "bar.mydomain.com."}}, + {"MX", endpoint.Targets{"10 mail.mydomain.com"}, endpoint.Targets{"10 mail.mydomain.com."}}, + {"SRV", endpoint.Targets{"0 5 5060 sip.mydomain.com"}, endpoint.Targets{"0 5 5060 sip.mydomain.com."}}, + {"NS", endpoint.Targets{"ns1.mydomain.com", "ns2.mydomain.com."}, endpoint.Targets{"ns1.mydomain.com.", "ns2.mydomain.com."}}, + } + for _, tt := range dottedTargetTests { + ep := &endpoint.Endpoint{ + DNSName: "test.mydomain.com", + RecordType: tt.recordType, + Targets: tt.targets, + } + modifyChange(ep) + for i := range tt.want { + if ep.Targets[i] != tt.want[i] { + t.Errorf("modifyChange() %s target[%d] = %v, want %v", tt.recordType, i, ep.Targets[i], tt.want[i]) + } + } + } + + // A records keep their IP targets untouched. + nonHostnameEndpoint := &endpoint.Endpoint{ + DNSName: "test.mydomain.com", + RecordType: "A", + Targets: endpoint.Targets{"1.2.3.4"}, + } + modifyChange(nonHostnameEndpoint) + if nonHostnameEndpoint.Targets[0] != "1.2.3.4" { + t.Errorf("modifyChange() changed A target = %v, want 1.2.3.4", nonHostnameEndpoint.Targets[0]) + } } func TestGetStackitRRSetRecordPost(t *testing.T) { diff --git a/internal/stackitprovider/records.go b/internal/stackitprovider/records.go index 51d640c..ba0a4c5 100644 --- a/internal/stackitprovider/records.go +++ b/internal/stackitprovider/records.go @@ -8,8 +8,6 @@ import ( "sigs.k8s.io/external-dns/provider" ) -const txtRecord = "TXT" - // Records returns resource records. func (d *StackitDNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { zones, err := d.zoneFetcherClient.zones(ctx) @@ -117,7 +115,7 @@ func endpointsFromRecords(name, recordType string, ttl endpoint.TTL, records []s rec := &records[i] content := rec.Content - if recordType == txtRecord { + if recordType == endpoint.RecordTypeTXT { content = unformatTXTContent(content) }