@@ -12,29 +12,32 @@ import (
 
 	"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/command"
+	"gitlab.com/gitlab-org/gitaly/v15/internal/git/repository"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/service"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/transaction"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/featureflag"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/safe"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/structerr"
+	"gitlab.com/gitlab-org/gitaly/v15/internal/tempdir"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/transaction/txinfo"
 	"gitlab.com/gitlab-org/gitaly/v15/internal/transaction/voting"
 	"gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
 	"gitlab.com/gitlab-org/gitaly/v15/streamio"
 )
 
+// RestoreCustomHooks sets the git hooks for a repository. The hooks are sent in
+// a tar archive containing a `custom_hooks` directory. This directory is
+// ultimately extracted to the repository.
 func (s *server) RestoreCustomHooks(stream gitalypb.RepositoryService_RestoreCustomHooksServer) error {
-	if featureflag.TransactionalRestoreCustomHooks.IsEnabled(stream.Context()) {
-		return s.restoreCustomHooksWithVoting(stream)
-	}
+	ctx := stream.Context()
 
 	firstRequest, err := stream.Recv()
 	if err != nil {
 		return structerr.NewInternal("first request failed %w", err)
 	}
 
-	repository := firstRequest.GetRepository()
-	if err := service.ValidateRepository(repository); err != nil {
+	repo := firstRequest.GetRepository()
+	if err := service.ValidateRepository(repo); err != nil {
 		return structerr.NewInvalidArgument("%w", err)
 	}
 
@@ -49,127 +52,138 @@ func (s *server) RestoreCustomHooks(stream gitalypb.RepositoryService_RestoreCus
 		return request.GetData(), err
 	})
 
-	repoPath, err := s.locator.GetPath(repository)
-	if err != nil {
-		return structerr.NewInternal("getting repo path failed %w", err)
-	}
+	if featureflag.TransactionalRestoreCustomHooks.IsEnabled(ctx) {
+		if err := s.restoreCustomHooks(ctx, reader, repo); err != nil {
+			return structerr.NewInternal("setting custom hooks: %w", err)
+		}
 
-	cmdArgs := []string{
-		"-xf",
-		"-",
-		"-C",
-		repoPath,
-		customHooksDir,
+		return stream.SendAndClose(&gitalypb.RestoreCustomHooksResponse{})
 	}
 
-	ctx := stream.Context()
-	cmd, err := command.New(ctx, append([]string{"tar"}, cmdArgs...), command.WithStdin(reader))
+	repoPath, err := s.locator.GetPath(repo)
 	if err != nil {
-		return structerr.NewInternal("Could not untar custom hooks tar %w", err)
+		return structerr.NewInternal("getting repo path failed %w", err)
 	}
 
-	if err := cmd.Wait(); err != nil {
-		return structerr.NewInternal("cmd wait failed: %w", err)
+	if err := extractHooks(ctx, reader, repoPath); err != nil {
+		return structerr.NewInternal("extracting hooks: %w", err)
 	}
 
 	return stream.SendAndClose(&gitalypb.RestoreCustomHooksResponse{})
 }
 
-func (s *server) restoreCustomHooksWithVoting(stream gitalypb.RepositoryService_RestoreCustomHooksServer) error {
-	firstRequest, err := stream.Recv()
+// restoreCustomHooks transactionally and atomically sets the provided custom
+// hooks for the specified repository.
+func (s *server) restoreCustomHooks(ctx context.Context, tar io.Reader, repo repository.GitRepo) (returnedErr error) {
+	repoPath, err := s.locator.GetRepoPath(repo)
 	if err != nil {
-		return structerr.NewInternal("first request failed %w", err)
+		return fmt.Errorf("getting repo path: %w", err)
 	}
 
-	ctx := stream.Context()
+	// The `custom_hooks` directory in the repository is locked to prevent
+	// concurrent modification of hooks.
+	hooksLock, err := safe.NewLockingDirectory(repoPath, customHooksDir)
+	if err != nil {
+		return fmt.Errorf("creating hooks lock: %w", err)
+	}
 
-	repository := firstRequest.GetRepository()
-	if err := service.ValidateRepository(repository); err != nil {
-		return structerr.NewInvalidArgument("%w", err)
+	if err := hooksLock.Lock(); err != nil {
+		return fmt.Errorf("locking hooks: %w", err)
 	}
+	defer func() {
+		// If the `.lock` file is not removed from the `custom_hooks` directory,
+		// future modifications to the repository's hooks will be prevented. If
+		// this occurs, the `.lock` file will have to be manually removed.
+		if err := hooksLock.Unlock(); err != nil {
+			ctxlogrus.Extract(ctx).WithError(err).Warn("failed to unlock hooks")
+		}
+	}()
 
-	repoPath, err := s.locator.GetRepoPath(repository)
+	// Create a temporary directory to write the new hooks to and also
+	// temporarily store the current repository hooks. This enables "atomic"
+	// directory swapping by acting as an intermediary storage location between
+	// moves.
+	tmpDir, err := tempdir.NewWithPrefix(ctx, repo.GetStorageName(), "hooks-", s.locator)
 	if err != nil {
-		return structerr.NewInternal("RestoreCustomHooks: getting repo path failed %w", err)
+		return fmt.Errorf("creating temp directory: %w", err)
+	}
+
+	if err := extractHooks(ctx, tar, tmpDir.Path()); err != nil {
+		return fmt.Errorf("extracting hooks: %w", err)
 	}
 
-	customHooksPath := filepath.Join(repoPath, customHooksDir)
+	tempHooksPath := filepath.Join(tmpDir.Path(), customHooksDir)
 
-	if err = os.MkdirAll(customHooksPath, os.ModePerm); err != nil {
-		return structerr.NewInternal("making custom hooks directory %w", err)
+	// No hooks will be extracted if the tar archive is empty. If this happens
+	// it means the repository should be set with an empty `custom_hooks`
+	// directory. Create `custom_hooks` in the temporary directory so that any
+	// existing repository hooks will be replaced with this empty directory.
+	if err := os.Mkdir(tempHooksPath, os.ModePerm); err != nil && !errors.Is(err, fs.ErrExist) {
+		return fmt.Errorf("making temp hooks directory: %w", err)
 	}
 
-	lockDir, err := safe.NewLockingDirectory(repoPath, customHooksDir)
+	preparedVote, err := newDirectoryVote(tempHooksPath)
 	if err != nil {
-		return structerr.NewInternal("RestoreCustomHooks: creating locking directory: %w", err)
+		return fmt.Errorf("generating prepared vote: %w", err)
 	}
 
-	if err := lockDir.Lock(); err != nil {
-		return structerr.NewInternal("locking directory failed: %w", err)
+	// Cast prepared vote with hash of the extracted archive in the temporary
+	// `custom_hooks` directory.
+	if err := voteCustomHooks(ctx, s.txManager, preparedVote, voting.Prepared); err != nil {
+		return fmt.Errorf("casting prepared vote: %w", err)
 	}
 
+	repoHooksPath := filepath.Join(repoPath, customHooksDir)
+	prevHooksPath := filepath.Join(tmpDir.Path(), "previous_hooks")
+
+	// If the `custom_hooks` directory exists in the repository, move the
+	// current hooks to `previous_hooks` in the temporary directory.
+	if err := os.Rename(repoHooksPath, prevHooksPath); err != nil && !errors.Is(err, fs.ErrNotExist) {
+		return fmt.Errorf("moving current hooks to temp: %w", err)
+	}
+
+	// If an error is returned after this point, the previous hooks need to be
+	// restored. The repository will be left in an altered state if this fails.
 	defer func() {
-		if !lockDir.IsLocked() {
+		if returnedErr == nil {
 			return
 		}
 
-		if err := lockDir.Unlock(); err != nil {
-			ctxlogrus.Extract(ctx).WithError(err).Warn("could not unlock directory")
+		if err := os.RemoveAll(repoHooksPath); err != nil {
+			ctxlogrus.Extract(ctx).WithError(err).Warn("failed reverting to previous hooks")
+			return
 		}
-	}()
-
-	preparedVote := voting.NewVoteHash()
-	if err := voteCustomHooks(ctx, s.txManager, &preparedVote, voting.Prepared); err != nil {
-		return structerr.NewInternal("casting prepared vote: %w", err)
-	}
 
-	reader := streamio.NewReader(func() ([]byte, error) {
-		if firstRequest != nil {
-			data := firstRequest.GetData()
-			firstRequest = nil
-			return data, nil
+		// If the `previous_hooks` directory does not exist, then there are no
+		// hooks to roll back to.
+		if err := os.Rename(prevHooksPath, repoHooksPath); err != nil && !errors.Is(err, fs.ErrNotExist) {
+			ctxlogrus.Extract(ctx).WithError(err).Warn("failed reverting to previous hooks")
 		}
+	}()
 
-		request, err := stream.Recv()
-		return request.GetData(), err
-	})
-
-	cmdArgs := []string{
-		"-xf",
-		"-",
-		"-C",
-		repoPath,
-		customHooksDir,
-	}
-
-	cmd, err := command.New(ctx, append([]string{"tar"}, cmdArgs...), command.WithStdin(reader))
-	if err != nil {
-		return structerr.NewInternal("Could not untar custom hooks tar %w", err)
-	}
-
-	if err := cmd.Wait(); err != nil {
-		return structerr.NewInternal("cmd wait failed: %w", err)
+	// Move `custom_hooks` from the temporary directory to the repository.
+	if err := os.Rename(tempHooksPath, repoHooksPath); err != nil {
+		return fmt.Errorf("moving new hooks to repo: %w", err)
 	}
 
-	committedVote, err := newDirectoryVote(customHooksPath)
+	committedVote, err := newDirectoryVote(repoHooksPath)
 	if err != nil {
-		return structerr.NewInternal("generating committed vote: %w", err)
+		return fmt.Errorf("generating committed vote: %w", err)
 	}
 
+	// Cast committed vote with hash of the extracted archive in the repository
+	// `custom_hooks` directory.
 	if err := voteCustomHooks(ctx, s.txManager, committedVote, voting.Committed); err != nil {
-		return structerr.NewInternal("casting committed vote: %w", err)
+		return fmt.Errorf("casting committed vote: %w", err)
 	}
 
-	if err := lockDir.Unlock(); err != nil {
-		return structerr.NewInternal("committing lock dir %w", err)
-	}
-
-	return stream.SendAndClose(&gitalypb.RestoreCustomHooksResponse{})
+	return nil
 }
 
 // newDirectoryVote creates a voting.VoteHash by walking the specified path and
 // generating a hash based on file name, permissions, and data.
 func newDirectoryVote(basePath string) (*voting.VoteHash, error) {
+	parentDir := filepath.Dir(basePath)
 	voteHash := voting.NewVoteHash()
 
 	if err := filepath.WalkDir(basePath, func(path string, entry fs.DirEntry, err error) error {
@@ -177,11 +191,13 @@ func newDirectoryVote(basePath string) (*voting.VoteHash, error) {
 			return err
 		}
 
-		// Write file name to hash. Since `WalkDir()` output is deterministic
-		// based on lexical order, the path does not need to be included with
-		// the name written to the hash. Any change to the entry's path will
-		// result in a different hash due to the change in walked order.
-		_, _ = voteHash.Write([]byte(entry.Name()))
+		relPath, err := filepath.Rel(parentDir, path)
+		if err != nil {
+			return fmt.Errorf("getting relative path: %w", err)
+		}
+
+		// Write file relative path to hash.
+		_, _ = voteHash.Write([]byte(relPath))
 
 		info, err := entry.Info()
 		if err != nil {
@@ -216,6 +232,8 @@ func newDirectoryVote(basePath string) (*voting.VoteHash, error) {
 	return &voteHash, nil
 }
 
+// voteCustomHooks casts a vote symbolic of the custom hooks received. If there
+// is no transaction voting is skipped.
 func voteCustomHooks(
 	ctx context.Context,
 	txManager transaction.Manager,
@@ -240,3 +258,26 @@ func voteCustomHooks(
 
 	return nil
 }
+
+// extractHooks unpacks a tar file containing custom hooks into a `custom_hooks`
+// directory at the specified path.
+func extractHooks(ctx context.Context, reader io.Reader, path string) error {
+	cmdArgs := []string{
+		"-xf",
+		"-",
+		"-C",
+		path,
+		customHooksDir,
+	}
+
+	cmd, err := command.New(ctx, append([]string{"tar"}, cmdArgs...), command.WithStdin(reader))
+	if err != nil {
+		return fmt.Errorf("executing tar command: %w", err)
+	}
+
+	if err := cmd.Wait(); err != nil {
+		return fmt.Errorf("waiting for tar command completion: %w", err)
+	}
+
+	return nil
+}
