diff --git a/apps/agent/internal/pipeline/runner.go b/apps/agent/internal/pipeline/runner.go index 86c1615..0db0b24 100644 --- a/apps/agent/internal/pipeline/runner.go +++ b/apps/agent/internal/pipeline/runner.go @@ -138,16 +138,23 @@ func (r *Runner) Run(ctx context.Context, req *backupv1.RunBackup) (completed *b return nil, fmt.Errorf("pipeline: no driver registered for db_type=%s", driverKey) } - // Unwrap the DEK once. The plaintext DEK never leaves this function. - dek, err := r.dekResolver.Unwrap(ctx, req.EncryptedDek) - if err != nil { - return nil, fmt.Errorf("pipeline: unwrap DEK: %w", err) - } - defer wipe(dek) + // Resolve the encryption stage. Jobs with encryption_enabled=false + // arrive with EncryptedDek=nil; in that case we wire the compressed + // stream straight to the uploader without ever materialising a + // plaintext DEK or instantiating an encryptor. + encryptEnabled := len(req.EncryptedDek) > 0 + var encryptor *Encryptor + if encryptEnabled { + dek, err := r.dekResolver.Unwrap(ctx, req.EncryptedDek) + if err != nil { + return nil, fmt.Errorf("pipeline: unwrap DEK: %w", err) + } + defer wipe(dek) - encryptor, err := NewEncryptor(dek) - if err != nil { - return nil, fmt.Errorf("pipeline: build encryptor: %w", err) + encryptor, err = NewEncryptor(dek) + if err != nil { + return nil, fmt.Errorf("pipeline: build encryptor: %w", err) + } } // Smoke-validate the driver before we burn upload time on a dead db. @@ -291,9 +298,20 @@ func (r *Runner) Run(ctx context.Context, req *backupv1.RunBackup) (completed *b errs <- nil }() - // Stage 3 — encrypt. + // Stage 3 — encrypt (skipped when the job has encryption disabled; + // in that case the compressed bytes are passed through unchanged). go func() { defer encryptedPW.Close() + if encryptor == nil { + if _, err := io.Copy(encryptedPW, compressedPR); err != nil { + _ = encryptedPW.CloseWithError(err) + _ = compressedPR.CloseWithError(err) + errs <- fmt.Errorf("encrypt: passthrough copy: %w", err) + return + } + errs <- nil + return + } if _, err := encryptor.Stream(compressedPR, encryptedPW); err != nil { _ = encryptedPW.CloseWithError(err) _ = compressedPR.CloseWithError(err) diff --git a/apps/agent/internal/pipeline/runner_test.go b/apps/agent/internal/pipeline/runner_test.go index 2724fd9..0459785 100644 --- a/apps/agent/internal/pipeline/runner_test.go +++ b/apps/agent/internal/pipeline/runner_test.go @@ -263,3 +263,44 @@ func TestRunner_DEKWrongLength(t *testing.T) { _, err := runner.Run(context.Background(), req) require.Error(t, err) } + +// TestRunner_HappyPath_EncryptionDisabled verifies that a RunBackup +// arriving without a DEK (encryption_enabled=false on the job) skips +// the encrypt stage entirely and uploads the compressed bytes as-is. +func TestRunner_HappyPath_EncryptionDisabled(t *testing.T) { + plaintext := append([]byte(PgDumpMagic), make([]byte, 1<<10)...) + _, err := rand.Read(plaintext[len(PgDumpMagic):]) + require.NoError(t, err) + + driver := &fakeDriver{name: "pg_dump", payload: plaintext, version: "PostgreSQL 16.2"} + job := &backupv1.BackupJobSpec{Id: "j", TargetId: "t"} + target := &backupv1.Target{Id: "t", Type: backupv1.DbType_POSTGRESQL, Connection: &backupv1.ConnectionConfig{Host: "x"}} + lookups := &simpleLookups{job: job, target: target} + + var received bytes.Buffer + srv := startFakeS3(t, &received) + defer srv.Close() + + runner := NewRunner( + map[string]Driver{"postgresql": driver}, + NewUploaderWithClient(srv.Client()), + WithTargetLookup(lookups), + WithJobLookup(lookups), + ) + req := &backupv1.RunBackup{ + JobId: "j", RunId: "r", + // No EncryptedDek — encryption disabled. + UploadCreds: &backupv1.S3UploadCreds{PresignedPutUrl: srv.URL + "/r.enc", FinalS3Key: "k"}, + } + completed, err := runner.Run(context.Background(), req) + require.NoError(t, err) + require.Empty(t, completed.EncryptedDek, "no DEK should be reported back when encryption is disabled") + + // The uploaded blob is the raw zstd stream — decompress directly. + zr, err := zstd.NewReader(&received) + require.NoError(t, err) + defer zr.Close() + round, err := io.ReadAll(zr) + require.NoError(t, err) + require.Equal(t, plaintext, round) +}