diff --git a/app/controlplane/cmd/main.go b/app/controlplane/cmd/main.go index aad2e6eb0..3a369b83e 100644 --- a/app/controlplane/cmd/main.go +++ b/app/controlplane/cmd/main.go @@ -17,7 +17,6 @@ package main import ( "context" - "fmt" "math/rand" _ "net/http/pprof" "os" @@ -25,7 +24,6 @@ import ( "buf.build/go/protovalidate" "github.com/getsentry/sentry-go" - "github.com/nats-io/nats.go" flag "github.com/spf13/pflag" conf "github.com/chainloop-dev/chainloop/app/controlplane/internal/conf/controlplane/config/v1" @@ -35,6 +33,7 @@ import ( "github.com/chainloop-dev/chainloop/app/controlplane/plugins/sdk/v1" "github.com/chainloop-dev/chainloop/pkg/credentials" "github.com/chainloop-dev/chainloop/pkg/credentials/manager" + "github.com/chainloop-dev/chainloop/pkg/natsconn" "github.com/chainloop-dev/chainloop/pkg/servicelogger" "github.com/go-kratos/kratos/v2" @@ -145,7 +144,7 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - app, cleanup, err := wireApp(&bc, credsWriter, logger, availablePlugins) + app, cleanup, err := wireApp(ctx, &bc, credsWriter, logger, availablePlugins) if err != nil { panic(err) } @@ -215,29 +214,23 @@ type app struct { apiTokenStaleRevoker *biz.APITokenStaleRevoker } -// Connection to nats is optional, if not configured, pubsub will be disabled -func newNatsConnection(c *conf.Bootstrap_NatsServer) (*nats.Conn, error) { +// newNatsConfig converts the proto config to a plain natsconn.Config. +func newNatsConfig(c *conf.Bootstrap_NatsServer) *natsconn.Config { uri := c.GetUri() if uri == "" { - return nil, nil + return nil } - var opts []nats.Option - if c.GetAuthentication() != nil { - switch c.GetAuthentication().(type) { - case *conf.Bootstrap_NatsServer_Token: - opts = append(opts, nats.Token(c.GetToken())) - default: - return nil, fmt.Errorf("unsupported nats authentication type: %T", c.GetAuthentication()) - } + cfg := &natsconn.Config{ + URI: uri, + Name: "chainloop-controlplane", } - nc, err := nats.Connect(uri, opts...) - if err != nil { - return nil, fmt.Errorf("failed to connect to nats: %w", err) + if c.GetToken() != "" { + cfg.Token = c.GetToken() } - return nc, nil + return cfg } func filterSensitiveArgs(_ log.Level, keyvals ...interface{}) bool { diff --git a/app/controlplane/cmd/wire.go b/app/controlplane/cmd/wire.go index f72db8db9..cada02b2b 100644 --- a/app/controlplane/cmd/wire.go +++ b/app/controlplane/cmd/wire.go @@ -1,5 +1,5 @@ // -// Copyright 2024-2025 The Chainloop Authors. +// Copyright 2024-2026 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ package main import ( + "context" "time" conf "github.com/chainloop-dev/chainloop/app/controlplane/internal/conf/controlplane/config/v1" @@ -38,13 +39,13 @@ import ( "github.com/chainloop-dev/chainloop/pkg/blobmanager/loader" "github.com/chainloop-dev/chainloop/pkg/cache" "github.com/chainloop-dev/chainloop/pkg/credentials" + "github.com/chainloop-dev/chainloop/pkg/natsconn" "github.com/go-kratos/kratos/v2/log" "github.com/golang-jwt/jwt/v4" "github.com/google/wire" - "github.com/nats-io/nats.go" ) -func wireApp(*conf.Bootstrap, credentials.ReaderWriter, log.Logger, sdk.AvailablePlugins) (*app, func(), error) { +func wireApp(context.Context, *conf.Bootstrap, credentials.ReaderWriter, log.Logger, sdk.AvailablePlugins) (*app, func(), error) { panic( wire.Build( wire.Bind(new(credentials.Reader), new(credentials.ReaderWriter)), @@ -65,7 +66,8 @@ func wireApp(*conf.Bootstrap, credentials.ReaderWriter, log.Logger, sdk.Availabl newProtoValidator, newDataConf, newPolicyProviderConfig, - newNatsConnection, + newNatsConfig, + natsconn.New, cacheProviderSet, auditor.NewAuditLogPublisher, newCASServerOptions, @@ -141,37 +143,40 @@ var cacheProviderSet = wire.NewSet( newPolicyEvalBundleCache, ) -func newClaimsCache(conn *nats.Conn, logger log.Logger) (cache.Cache[*jwt.MapClaims], error) { +func newClaimsCache(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (cache.Cache[*jwt.MapClaims], error) { l := log.NewHelper(logger) backend := "memory" opts := []cache.Option{cache.WithTTL(10 * time.Second), cache.WithLogger(&kratosLogAdapter{h: l}), cache.WithDescription("Cache for JWT claims")} - if conn != nil { + if rc != nil { backend = "nats" - opts = append(opts, cache.WithNATS(conn, "chainloop-jwt-claims")) + opts = append(opts, cache.WithNATS(rc.Conn, "chainloop-jwt-claims")) + opts = append(opts, cache.WithReconnect(rc.Subscribe(ctx))) } l.Infow("msg", "cache initialized", "bucket", "chainloop-jwt-claims", "backend", backend, "ttl", "10s") return cache.New[*jwt.MapClaims](opts...) } -func newMembershipsCache(conn *nats.Conn, logger log.Logger) (cache.Cache[*entities.Membership], error) { +func newMembershipsCache(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (cache.Cache[*entities.Membership], error) { l := log.NewHelper(logger) backend := "memory" opts := []cache.Option{cache.WithTTL(time.Second), cache.WithLogger(&kratosLogAdapter{h: l}), cache.WithDescription("Cache for org memberships")} - if conn != nil { + if rc != nil { backend = "nats" - opts = append(opts, cache.WithNATS(conn, "chainloop-memberships")) + opts = append(opts, cache.WithNATS(rc.Conn, "chainloop-memberships")) + opts = append(opts, cache.WithReconnect(rc.Subscribe(ctx))) } l.Infow("msg", "cache initialized", "bucket", "chainloop-memberships", "backend", backend, "ttl", "1s") return cache.New[*entities.Membership](opts...) } -func newPolicyEvalBundleCache(conn *nats.Conn, logger log.Logger) (cache.Cache[[]byte], error) { +func newPolicyEvalBundleCache(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (cache.Cache[[]byte], error) { l := log.NewHelper(logger) backend := "memory" opts := []cache.Option{cache.WithTTL(24 * time.Hour), cache.WithLogger(&kratosLogAdapter{h: l}), cache.WithDescription("Cache for policy evaluation bundles from CAS")} - if conn != nil { + if rc != nil { backend = "nats" - opts = append(opts, cache.WithNATS(conn, "chainloop-policy-eval-bundles")) + opts = append(opts, cache.WithNATS(rc.Conn, "chainloop-policy-eval-bundles")) + opts = append(opts, cache.WithReconnect(rc.Subscribe(ctx))) } l.Infow("msg", "cache initialized", "bucket", "chainloop-policy-eval-bundles", "backend", backend, "ttl", "24h") return cache.New[[]byte](opts...) diff --git a/app/controlplane/cmd/wire_gen.go b/app/controlplane/cmd/wire_gen.go index 8836763cc..271b72ce9 100644 --- a/app/controlplane/cmd/wire_gen.go +++ b/app/controlplane/cmd/wire_gen.go @@ -7,6 +7,7 @@ package main import ( + "context" "github.com/chainloop-dev/chainloop/app/controlplane/internal/conf/controlplane/config/v1" "github.com/chainloop-dev/chainloop/app/controlplane/internal/dispatcher" "github.com/chainloop-dev/chainloop/app/controlplane/internal/server" @@ -22,10 +23,10 @@ import ( "github.com/chainloop-dev/chainloop/pkg/blobmanager/loader" "github.com/chainloop-dev/chainloop/pkg/cache" "github.com/chainloop-dev/chainloop/pkg/credentials" + "github.com/chainloop-dev/chainloop/pkg/natsconn" "github.com/go-kratos/kratos/v2/log" "github.com/golang-jwt/jwt/v4" "github.com/google/wire" - "github.com/nats-io/nats.go" "time" ) @@ -35,7 +36,7 @@ import ( // Injectors from wire.go: -func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, logger log.Logger, availablePlugins sdk.AvailablePlugins) (*app, func(), error) { +func wireApp(contextContext context.Context, bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, logger log.Logger, availablePlugins sdk.AvailablePlugins) (*app, func(), error) { config := authzConfig() casbinEnforcer, err := authz.NewCasbinEnforcer(config) if err != nil { @@ -60,19 +61,22 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l bootstrap_CASServer := bootstrap.CasServer casServerDefaultOpts := newCASServerOptions(bootstrap_CASServer) bootstrap_NatsServer := bootstrap.NatsServer - conn, err := newNatsConnection(bootstrap_NatsServer) + natsconnConfig := newNatsConfig(bootstrap_NatsServer) + reloadableConnection, cleanup2, err := natsconn.New(natsconnConfig, logger) if err != nil { cleanup() return nil, nil, err } - auditLogPublisher, err := auditor.NewAuditLogPublisher(conn, logger) + auditLogPublisher, err := auditor.NewAuditLogPublisher(contextContext, reloadableConnection, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } auditorUseCase := biz.NewAuditorUseCase(auditLogPublisher, logger) casBackendUseCase, err := biz.NewCASBackendUseCase(casBackendRepo, readerWriter, providers, casServerDefaultOpts, auditorUseCase, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -107,6 +111,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l robotAccountUseCase := biz.NewRootAccountUseCase(robotAccountRepo, workflowRepo, auth, logger) casCredentialsUseCase, err := biz.NewCASCredentialsUseCase(auth) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -116,17 +121,20 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l referrerSharedIndex := bootstrap.ReferrerSharedIndex referrerSharedIndexConfig, err := biz.NewIndexConfig(referrerSharedIndex) if err != nil { + cleanup2() cleanup() return nil, nil, err } referrerUseCase, err := biz.NewReferrerUseCase(referrerRepo, workflowRepo, membershipUseCase, referrerSharedIndexConfig, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } apiTokenJWTConfig := newJWTConfig(auth) apiTokenUseCase, err := biz.NewAPITokenUseCase(apiTokenRepo, apiTokenJWTConfig, authzUseCase, organizationUseCase, auditorUseCase, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -136,24 +144,28 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l v4 := newPolicyProviderConfig(v3) registry, err := policies.NewRegistry(logger, v4...) if err != nil { + cleanup2() cleanup() return nil, nil, err } workflowContractUseCase := biz.NewWorkflowContractUseCase(workflowContractRepo, registry, auditorUseCase, logger) workflowUseCase := biz.NewWorkflowUsecase(workflowRepo, projectsRepo, workflowContractUseCase, auditorUseCase, membershipUseCase, organizationRepo, logger) - cache, err := newMembershipsCache(conn, logger) + cache, err := newMembershipsCache(contextContext, reloadableConnection, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } - cacheCache, err := newClaimsCache(conn, logger) + cacheCache, err := newClaimsCache(contextContext, reloadableConnection, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } orgInvitationRepo := data.NewOrgInvitation(dataData, logger) orgInvitationUseCase, err := biz.NewOrgInvitationUseCase(orgInvitationRepo, membershipRepo, userRepo, auditorUseCase, groupRepo, projectsRepo, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -164,6 +176,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l confServer := bootstrap.Server authService, err := service.NewAuthService(userUseCase, organizationUseCase, membershipUseCase, orgInvitationUseCase, auth, confServer, auditorUseCase, v5...) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -171,18 +184,21 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l workflowRunRepo := data.NewWorkflowRunRepo(dataData, logger) signingUseCase, err := biz.NewChainloopSigningUseCase(bootstrap, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } workflowRunUseCase, err := biz.NewWorkflowRunUseCase(workflowRunRepo, workflowRepo, signingUseCase, auditorUseCase, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } casMappingRepo := data.NewCASMappingRepo(dataData, casBackendRepo, logger) casMappingUseCase := biz.NewCASMappingUseCase(casMappingRepo, membershipUseCase, logger) - cache2, err := newPolicyEvalBundleCache(conn, logger) + cache2, err := newPolicyEvalBundleCache(contextContext, reloadableConnection, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -204,6 +220,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l orgMetricsRepo := data.NewOrgMetricsRepo(dataData, logger) orgMetricsUseCase, err := biz.NewOrgMetricsUseCase(orgMetricsRepo, organizationRepo, workflowUseCase, logger) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -242,6 +259,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l casBackendService := service.NewCASBackendService(casBackendUseCase, providers, v5...) casRedirectService, err := service.NewCASRedirectService(casMappingUseCase, casCredentialsUseCase, bootstrap_CASServer, v5...) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -251,6 +269,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l attestationStateRepo := data.NewAttestationStateRepo(dataData, logger) attestationStateUseCase, err := biz.NewAttestationStateUseCase(attestationStateRepo, workflowRunRepo) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -269,6 +288,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l federatedAuthentication := bootstrap.FederatedAuthentication validator, err := newProtoValidator() if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -318,21 +338,25 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l } grpcServer, err := server.NewGRPCServer(opts) if err != nil { + cleanup2() cleanup() return nil, nil, err } httpServer, err := server.NewHTTPServer(opts, grpcServer) if err != nil { + cleanup2() cleanup() return nil, nil, err } httpMetricsServer, err := server.NewHTTPMetricsServer(opts) if err != nil { + cleanup2() cleanup() return nil, nil, err } httpProfilerServer, err := server.NewHTTPProfilerServer(opts) if err != nil { + cleanup2() cleanup() return nil, nil, err } @@ -341,6 +365,7 @@ func wireApp(bootstrap *conf.Bootstrap, readerWriter credentials.ReaderWriter, l apiTokenStaleRevoker := biz.NewAPITokenStaleRevoker(organizationRepo, apiTokenRepo, apiTokenUseCase, logger) mainApp := newApp(logger, grpcServer, httpServer, httpMetricsServer, httpProfilerServer, workflowRunExpirerUseCase, availablePlugins, userAccessSyncerUseCase, casBackendChecker, apiTokenStaleRevoker, bootstrap) return mainApp, func() { + cleanup2() cleanup() }, nil } @@ -409,37 +434,40 @@ var cacheProviderSet = wire.NewSet( newPolicyEvalBundleCache, ) -func newClaimsCache(conn *nats.Conn, logger log.Logger) (cache.Cache[*jwt.MapClaims], error) { +func newClaimsCache(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (cache.Cache[*jwt.MapClaims], error) { l := log.NewHelper(logger) backend := "memory" opts := []cache.Option{cache.WithTTL(10 * time.Second), cache.WithLogger(&kratosLogAdapter{h: l}), cache.WithDescription("Cache for JWT claims")} - if conn != nil { + if rc != nil { backend = "nats" - opts = append(opts, cache.WithNATS(conn, "chainloop-jwt-claims")) + opts = append(opts, cache.WithNATS(rc.Conn, "chainloop-jwt-claims")) + opts = append(opts, cache.WithReconnect(rc.Subscribe(ctx))) } l.Infow("msg", "cache initialized", "bucket", "chainloop-jwt-claims", "backend", backend, "ttl", "10s") return cache.New[*jwt.MapClaims](opts...) } -func newMembershipsCache(conn *nats.Conn, logger log.Logger) (cache.Cache[*entities.Membership], error) { +func newMembershipsCache(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (cache.Cache[*entities.Membership], error) { l := log.NewHelper(logger) backend := "memory" opts := []cache.Option{cache.WithTTL(time.Second), cache.WithLogger(&kratosLogAdapter{h: l}), cache.WithDescription("Cache for org memberships")} - if conn != nil { + if rc != nil { backend = "nats" - opts = append(opts, cache.WithNATS(conn, "chainloop-memberships")) + opts = append(opts, cache.WithNATS(rc.Conn, "chainloop-memberships")) + opts = append(opts, cache.WithReconnect(rc.Subscribe(ctx))) } l.Infow("msg", "cache initialized", "bucket", "chainloop-memberships", "backend", backend, "ttl", "1s") return cache.New[*entities.Membership](opts...) } -func newPolicyEvalBundleCache(conn *nats.Conn, logger log.Logger) (cache.Cache[[]byte], error) { +func newPolicyEvalBundleCache(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (cache.Cache[[]byte], error) { l := log.NewHelper(logger) backend := "memory" opts := []cache.Option{cache.WithTTL(24 * time.Hour), cache.WithLogger(&kratosLogAdapter{h: l}), cache.WithDescription("Cache for policy evaluation bundles from CAS")} - if conn != nil { + if rc != nil { backend = "nats" - opts = append(opts, cache.WithNATS(conn, "chainloop-policy-eval-bundles")) + opts = append(opts, cache.WithNATS(rc.Conn, "chainloop-policy-eval-bundles")) + opts = append(opts, cache.WithReconnect(rc.Subscribe(ctx))) } l.Infow("msg", "cache initialized", "bucket", "chainloop-policy-eval-bundles", "backend", backend, "ttl", "24h") return cache.New[[]byte](opts...) diff --git a/app/controlplane/pkg/auditor/nats.go b/app/controlplane/pkg/auditor/nats.go index b4568a226..67f776f76 100644 --- a/app/controlplane/pkg/auditor/nats.go +++ b/app/controlplane/pkg/auditor/nats.go @@ -1,5 +1,5 @@ // -// Copyright 2024 The Chainloop Authors. +// Copyright 2024-2026 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,8 +22,8 @@ import ( "strings" "time" + "github.com/chainloop-dev/chainloop/pkg/natsconn" "github.com/go-kratos/kratos/v2/log" - "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" ) @@ -37,20 +37,32 @@ const ( ) type AuditLogPublisher struct { - conn *nats.Conn + rc *natsconn.ReloadableConnection logger *log.Helper } -func NewAuditLogPublisher(conn *nats.Conn, logger log.Logger) (*AuditLogPublisher, error) { +func NewAuditLogPublisher(ctx context.Context, rc *natsconn.ReloadableConnection, logger log.Logger) (*AuditLogPublisher, error) { l := log.NewHelper(log.With(logger, "component", "natsAuditLogPublisher")) - if conn == nil { + if rc == nil { l.Infow("msg", "NATS connection not set, audit log publisher disabled") return nil, nil } - js, err := jetstream.New(conn) + p := &AuditLogPublisher{rc: rc, logger: l} + + if err := p.initJetStream(); err != nil { + return nil, err + } + + go p.watchReconnect(rc.Subscribe(ctx)) + + return p, nil +} + +func (p *AuditLogPublisher) initJetStream() error { + js, err := jetstream.New(p.rc.Conn) if err != nil { - return nil, fmt.Errorf("failed to create jetstream context: %w", err) + return fmt.Errorf("creating jetstream context: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -60,26 +72,33 @@ func NewAuditLogPublisher(conn *nats.Conn, logger log.Logger) (*AuditLogPublishe Name: streamName, Subjects: []string{subjectName}, }); err != nil { - return nil, fmt.Errorf("failed to create stream: %w", err) + return fmt.Errorf("creating stream: %w", err) } - l.Infow("msg", "Stream Created or Updated", "name", streamName, "subject", subjectName) + p.logger.Infow("msg", "stream created or updated", "name", streamName, "subject", subjectName) + + return nil +} - return &AuditLogPublisher{conn, l}, nil +func (p *AuditLogPublisher) watchReconnect(ch <-chan struct{}) { + for range ch { + p.logger.Infow("msg", "NATS reconnected, reinitializing JetStream") + if err := p.initJetStream(); err != nil { + p.logger.Errorw("msg", "failed to reinitialize JetStream after reconnect", "error", err) + } + } } -func (n *AuditLogPublisher) Publish(data *EventPayload) error { - // If the connection is nil, we don't want to publish anything - if n == nil || n.conn == nil { +func (p *AuditLogPublisher) Publish(data *EventPayload) error { + if p == nil || p.rc == nil { return nil } jsonPayload, err := json.Marshal(data) if err != nil { - return fmt.Errorf("failed to marshal event payload: %w", err) + return fmt.Errorf("marshaling event payload: %w", err) } - // Send the event to the specific subject based on the event type "audit.." specificSubject := fmt.Sprintf("%s.%s.%s", baseSubjectName, strings.ToLower(string(data.Data.TargetType)), strings.ToLower(data.Data.ActionType)) - return n.conn.Publish(specificSubject, jsonPayload) + return p.rc.Publish(specificSubject, jsonPayload) } diff --git a/app/controlplane/pkg/biz/testhelpers/database.go b/app/controlplane/pkg/biz/testhelpers/database.go index d2fab4382..e446f1d03 100644 --- a/app/controlplane/pkg/biz/testhelpers/database.go +++ b/app/controlplane/pkg/biz/testhelpers/database.go @@ -1,5 +1,5 @@ // -// Copyright 2024-2025 The Chainloop Authors. +// Copyright 2024-2026 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -146,7 +146,7 @@ func NewTestingUseCases(t *testing.T, opts ...NewTestingUCOpt) *TestingUseCases db := NewTestDatabase(t) log := log.NewStdLogger(io.Discard) - testData, _, err := WireTestData(db, t, log, newArgs.credsReaderWriter, &robotaccount.Builder{}, &conf.Auth{ + testData, _, err := WireTestData(t.Context(), db, t, log, newArgs.credsReaderWriter, &robotaccount.Builder{}, &conf.Auth{ GeneratedJwsHmacSecret: "test", CasRobotAccountPrivateKeyPath: "./testdata/test-key.ec.pem", }, &conf.Bootstrap{}, newArgs.onboardingConfiguration, newArgs.integrations, newArgs.providers) diff --git a/app/controlplane/pkg/biz/testhelpers/wire.go b/app/controlplane/pkg/biz/testhelpers/wire.go index 2d124363b..9f8168620 100644 --- a/app/controlplane/pkg/biz/testhelpers/wire.go +++ b/app/controlplane/pkg/biz/testhelpers/wire.go @@ -1,5 +1,5 @@ // -// Copyright 2024-2025 The Chainloop Authors. +// Copyright 2024-2026 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ package testhelpers import ( + "context" "testing" conf "github.com/chainloop-dev/chainloop/app/controlplane/internal/conf/controlplane/config/v1" @@ -35,13 +36,13 @@ import ( robotaccount "github.com/chainloop-dev/chainloop/internal/robotaccount/cas" backends "github.com/chainloop-dev/chainloop/pkg/blobmanager" "github.com/chainloop-dev/chainloop/pkg/credentials" + "github.com/chainloop-dev/chainloop/pkg/natsconn" "github.com/go-kratos/kratos/v2/log" "github.com/google/wire" - "github.com/nats-io/nats.go" ) // wireTestData init testing data -func WireTestData(*TestDatabase, *testing.T, log.Logger, credentials.ReaderWriter, *robotaccount.Builder, *conf.Auth, *conf.Bootstrap, []*config.OnboardingSpec, sdk.AvailablePlugins, backends.Providers) (*TestingUseCases, func(), error) { +func WireTestData(context.Context, *TestDatabase, *testing.T, log.Logger, credentials.ReaderWriter, *robotaccount.Builder, *conf.Auth, *conf.Bootstrap, []*config.OnboardingSpec, sdk.AvailablePlugins, backends.Providers) (*TestingUseCases, func(), error) { panic( wire.Build( data.ProviderSet, @@ -55,7 +56,7 @@ func WireTestData(*TestDatabase, *testing.T, log.Logger, credentials.ReaderWrite NewPolicyProviderConfig, policies.NewRegistry, authz.NewCasbinEnforcer, - newNatsConnection, + newNatsReloadableConnection, auditor.NewAuditLogPublisher, NewCASBackendConfig, NewCASServerOptions, @@ -87,9 +88,9 @@ func newJWTConfig(conf *conf.Auth) *biz.APITokenJWTConfig { } } -// Connection to nats is optional, if not configured, pubsub will be disabled -func newNatsConnection() (*nats.Conn, error) { - return nil, nil +// newNatsReloadableConnection returns nil in tests (NATS is not available). +func newNatsReloadableConnection() *natsconn.ReloadableConnection { + return nil } func newAuthAllowList(conf *conf.Bootstrap) *pkgConf.AllowList { diff --git a/app/controlplane/pkg/biz/testhelpers/wire_gen.go b/app/controlplane/pkg/biz/testhelpers/wire_gen.go index 997c5d4a9..09e4546d6 100644 --- a/app/controlplane/pkg/biz/testhelpers/wire_gen.go +++ b/app/controlplane/pkg/biz/testhelpers/wire_gen.go @@ -7,6 +7,7 @@ package testhelpers import ( + "context" "github.com/chainloop-dev/chainloop/app/controlplane/internal/conf/controlplane/config/v1" "github.com/chainloop-dev/chainloop/app/controlplane/pkg/auditor" "github.com/chainloop-dev/chainloop/app/controlplane/pkg/authz" @@ -18,8 +19,8 @@ import ( "github.com/chainloop-dev/chainloop/internal/robotaccount/cas" "github.com/chainloop-dev/chainloop/pkg/blobmanager" "github.com/chainloop-dev/chainloop/pkg/credentials" + "github.com/chainloop-dev/chainloop/pkg/natsconn" "github.com/go-kratos/kratos/v2/log" - "github.com/nats-io/nats.go" "testing" ) @@ -30,7 +31,7 @@ import ( // Injectors from wire.go: // wireTestData init testing data -func WireTestData(testDatabase *TestDatabase, t *testing.T, logger log.Logger, readerWriter credentials.ReaderWriter, builder *robotaccount.Builder, auth *conf.Auth, bootstrap *conf.Bootstrap, arg []*v1.OnboardingSpec, availablePlugins sdk.AvailablePlugins, providers backend.Providers) (*TestingUseCases, func(), error) { +func WireTestData(contextContext context.Context, testDatabase *TestDatabase, t *testing.T, logger log.Logger, readerWriter credentials.ReaderWriter, builder *robotaccount.Builder, auth *conf.Auth, bootstrap *conf.Bootstrap, arg []*v1.OnboardingSpec, availablePlugins sdk.AvailablePlugins, providers backend.Providers) (*TestingUseCases, func(), error) { confData := NewConfData(testDatabase, t) databaseConfig := NewDataConfig(confData) dataData, cleanup, err := data.NewData(databaseConfig, logger) @@ -43,12 +44,8 @@ func WireTestData(testDatabase *TestDatabase, t *testing.T, logger log.Logger, r casBackendRepo := data.NewCASBackendRepo(dataData, logger) bootstrap_CASServer := NewCASBackendConfig() casServerDefaultOpts := NewCASServerOptions(bootstrap_CASServer) - conn, err := newNatsConnection() - if err != nil { - cleanup() - return nil, nil, err - } - auditLogPublisher, err := auditor.NewAuditLogPublisher(conn, logger) + reloadableConnection := newNatsReloadableConnection() + auditLogPublisher, err := auditor.NewAuditLogPublisher(contextContext, reloadableConnection, logger) if err != nil { cleanup() return nil, nil, err @@ -229,9 +226,9 @@ func newJWTConfig(conf2 *conf.Auth) *biz.APITokenJWTConfig { } } -// Connection to nats is optional, if not configured, pubsub will be disabled -func newNatsConnection() (*nats.Conn, error) { - return nil, nil +// newNatsReloadableConnection returns nil in tests (NATS is not available). +func newNatsReloadableConnection() *natsconn.ReloadableConnection { + return nil } func newAuthAllowList(conf2 *conf.Bootstrap) *v1.AllowList { diff --git a/pkg/natsconn/natsconn.go b/pkg/natsconn/natsconn.go new file mode 100644 index 000000000..f030bfd2f --- /dev/null +++ b/pkg/natsconn/natsconn.go @@ -0,0 +1,150 @@ +// +// Copyright 2026 The Chainloop Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package natsconn + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/go-kratos/kratos/v2/log" + "github.com/nats-io/nats.go" +) + +// Config holds the connection parameters for NATS. +// Decoupled from protobuf config so this package can be imported externally. +type Config struct { + URI string + Token string + Name string +} + +// ReloadableConnection wraps a NATS connection and provides reconnection +// notifications via a pub/sub fan-out to subscribers. +type ReloadableConnection struct { + *nats.Conn + mu sync.RWMutex + subscribers []chan struct{} + logger *log.Helper +} + +// New creates a ReloadableConnection with automatic reconnection handling. +// Returns (nil, cleanup, nil) when cfg is nil or URI is empty (NATS is optional). +// The cleanup function drains the NATS connection on shutdown. +func New(cfg *Config, logger log.Logger) (*ReloadableConnection, func(), error) { + noop := func() {} + if cfg == nil || cfg.URI == "" { + return nil, noop, nil + } + + l := log.NewHelper(log.With(logger, "component", "natsconn")) + rc := &ReloadableConnection{logger: l} + + opts := []nats.Option{ + nats.MaxReconnects(-1), + nats.ReconnectWait(2 * time.Second), + nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { + l.Warnw("msg", "NATS disconnected", "error", err) + }), + nats.ReconnectHandler(func(nc *nats.Conn) { + l.Infow("msg", "NATS reconnected", "url", nc.ConnectedUrl()) + rc.Broadcast() + }), + } + + if cfg.Name != "" { + opts = append(opts, nats.Name(cfg.Name)) + } + + if cfg.Token != "" { + opts = append(opts, nats.Token(cfg.Token)) + } + + nc, err := nats.Connect(cfg.URI, opts...) + if err != nil { + return nil, noop, fmt.Errorf("connecting to NATS: %w", err) + } + + rc.Conn = nc + l.Infow("msg", "NATS connected", "url", nc.ConnectedUrl()) + + cleanup := func() { + l.Infow("msg", "draining NATS connection") + if err := nc.Drain(); err != nil { + l.Warnw("msg", "failed to drain NATS connection", "error", err) + } + } + + return rc, cleanup, nil +} + +// Subscribe registers for reconnection notifications. The returned channel +// receives a signal each time the NATS connection is re-established. +// The subscription is automatically removed when ctx is cancelled. +// Nil-receiver safe: returns a closed channel. +func (rc *ReloadableConnection) Subscribe(ctx context.Context) <-chan struct{} { + if rc == nil { + ch := make(chan struct{}) + close(ch) + return ch + } + + ch := make(chan struct{}, 1) + + rc.mu.Lock() + rc.subscribers = append(rc.subscribers, ch) + rc.mu.Unlock() + + go func() { + <-ctx.Done() + rc.unsubscribe(ch) + }() + + return ch +} + +func (rc *ReloadableConnection) unsubscribe(ch chan struct{}) { + rc.mu.Lock() + defer rc.mu.Unlock() + + for i, s := range rc.subscribers { + if s == ch { + rc.subscribers = append(rc.subscribers[:i], rc.subscribers[i+1:]...) + close(ch) + return + } + } +} + +// Broadcast notifies all subscribers of a reconnection event. +// Non-blocking: if a subscriber's channel is full, the signal is dropped. +// Nil-receiver safe. +func (rc *ReloadableConnection) Broadcast() { + if rc == nil { + return + } + + rc.mu.RLock() + defer rc.mu.RUnlock() + + for _, ch := range rc.subscribers { + select { + case ch <- struct{}{}: + default: + } + } +} diff --git a/pkg/natsconn/natsconn_test.go b/pkg/natsconn/natsconn_test.go new file mode 100644 index 000000000..1a26d80c9 --- /dev/null +++ b/pkg/natsconn/natsconn_test.go @@ -0,0 +1,178 @@ +// +// Copyright 2026 The Chainloop Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package natsconn + +import ( + "context" + "testing" + "time" + + "github.com/go-kratos/kratos/v2/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + cfg *Config + wantNil bool + wantErr bool + }{ + { + name: "nil config returns nil", + cfg: nil, + wantNil: true, + }, + { + name: "empty URI returns nil", + cfg: &Config{URI: ""}, + wantNil: true, + }, + { + name: "invalid URI returns error", + cfg: &Config{URI: "nats://invalid:99999"}, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rc, cleanup, err := New(tc.cfg, log.DefaultLogger) + if tc.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, cleanup) + defer cleanup() + if tc.wantNil { + assert.Nil(t, rc) + } + }) + } +} + +func TestSubscribeAndBroadcast(t *testing.T) { + // Create a ReloadableConnection without an actual NATS conn — + // Subscribe/Broadcast only manage channels, they don't use the conn. + rc := &ReloadableConnection{ + logger: log.NewHelper(log.DefaultLogger), + } + + ch := rc.Subscribe(t.Context()) + require.NotNil(t, ch) + + // Broadcast should send a signal + rc.Broadcast() + + select { + case <-ch: + // received signal — pass + case <-time.After(time.Second): + require.Fail(t, "expected reconnect signal, got timeout") + } +} + +func TestBroadcastMultipleSubscribers(t *testing.T) { + rc := &ReloadableConnection{ + logger: log.NewHelper(log.DefaultLogger), + } + + ch1 := rc.Subscribe(t.Context()) + ch2 := rc.Subscribe(t.Context()) + ch3 := rc.Subscribe(t.Context()) + + rc.Broadcast() + + for i, ch := range []<-chan struct{}{ch1, ch2, ch3} { + select { + case <-ch: + // received — pass + case <-time.After(time.Second): + require.Failf(t, "subscriber did not receive signal", "subscriber %d", i) + } + } +} + +func TestBroadcastNonBlocking(t *testing.T) { + rc := &ReloadableConnection{ + logger: log.NewHelper(log.DefaultLogger), + } + + ch := rc.Subscribe(t.Context()) + + // Fill the buffered channel + rc.Broadcast() + // Second broadcast should not block even though channel is full + rc.Broadcast() + + // Only one signal should be in the channel + select { + case <-ch: + case <-time.After(time.Second): + require.Fail(t, "expected signal") + } + + // Channel should be empty now + select { + case <-ch: + require.Fail(t, "expected no second signal in channel") + default: + // pass + } +} + +func TestSubscribeContextCancellation(t *testing.T) { + rc := &ReloadableConnection{ + logger: log.NewHelper(log.DefaultLogger), + } + + ctx, cancel := context.WithCancel(t.Context()) + ch := rc.Subscribe(ctx) + + // Cancel context — should unsubscribe and close channel + cancel() + + // Wait for the goroutine to process the cancellation + time.Sleep(50 * time.Millisecond) + + // Channel should be closed + select { + case _, ok := <-ch: + assert.False(t, ok, "channel should be closed after context cancellation") + case <-time.After(time.Second): + require.Fail(t, "channel was not closed after context cancellation") + } + + // Verify subscriber was removed + rc.mu.RLock() + assert.Empty(t, rc.subscribers) + rc.mu.RUnlock() +} + +func TestNilReceiverSafety(t *testing.T) { + var rc *ReloadableConnection + + // These should not panic + assert.NotPanics(t, func() { rc.Broadcast() }) + assert.NotPanics(t, func() { + ch := rc.Subscribe(context.Background()) + // nil receiver returns a closed channel + _, ok := <-ch + assert.False(t, ok) + }) +}