通过 PostgreSQL 数据库的行级锁机制进行全局控制，使所有任务只在某一个节点上执行。

package main

import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/jackc/pgconn"
	"golang.org/x/sync/errgroup"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"
)

// globalJobLockGetPollingDuration is the polling interval used while waiting
// for the global job lock.
const globalJobLockGetPollingDuration = time.Second

// ServiceDeploymentNameGlobalJobLock is the ServiceDeployment name of the row
// that serves as the global job lock.
const ServiceDeploymentNameGlobalJobLock = "global-job-lock"

// ServiceDeployment is the GORM model whose rows are keyed by service name.
// The row named ServiceDeploymentNameGlobalJobLock is the one that
// GetGlobalJobLock locks.
type ServiceDeployment struct {
	Name string `gorm:"primarykey"` // service name (primary key)

	CreatedAt int64 `gorm:"autoCreateTime:milli"` // record creation time, in milliseconds
	UpdatedAt int64 `gorm:"autoUpdateTime:milli"` // record update time, in milliseconds
}

// main opens the PostgreSQL connection, migrates the ServiceDeployment table,
// and then competes for the global job lock. RunJob is only called after
// GetGlobalJobLock has reported success; until then the lock is retried once
// per globalJobLockGetPollingDuration. Errors from individual lock attempts
// are logged and the attempt is retried.
func main() {
	dsn := "host=localhost user=postgres password=secret dbname=test port=5432 sslmode=disable"
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}
	if err := db.AutoMigrate(&ServiceDeployment{}); err != nil {
		log.Fatal(err)
	}

	// lockAcquired is closed exactly once, by the polling goroutine, when the
	// lock has been obtained. Closing a channel gives the waiting goroutine a
	// proper happens-before edge, unlike a plain bool shared between
	// goroutines, which is a data race.
	lockAcquired := make(chan struct{})

	eg, ctx := errgroup.WithContext(context.Background())

	// Poller: keep trying to take the lock until it succeeds or ctx ends.
	eg.Go(func() error {
		ticker := time.NewTicker(globalJobLockGetPollingDuration)
		defer ticker.Stop()

		for {
			exist, err := GetGlobalJobLock(db)
			if err != nil {
				// A failed attempt is not fatal; log it and retry on the next tick.
				log.Println(err)
			}
			if exist {
				close(lockAcquired)
				return nil
			}

			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-ticker.C:
			}
		}
	})

	// Runner: start the job only once the lock is held.
	eg.Go(func() error {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-lockAcquired:
		}
		return RunJob(ctx)
	})

	if err := eg.Wait(); err != nil {
		log.Println(err)
	}
}

// FindOne runs db.First into out and reports whether a record was found.
//
// It returns (true, nil) when a record was loaded, (false, nil) when the query
// failed with gorm.ErrRecordNotFound, and (false, err) for any other error.
func FindOne(db *gorm.DB, out interface{}) (bool, error) {
	switch err := db.First(out).Error; {
	case err == nil:
		return true, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is an expected outcome, not a failure.
		return false, nil
	default:
		return false, err
	}
}

// GetGlobalJobLock tries to take the global job lock without blocking.
//
// It first makes sure the lock row (ServiceDeploymentNameGlobalJobLock) exists,
// then opens a transaction and selects that row with FOR UPDATE NOWAIT.
//
// Returns:
//   - (true, nil) when the row lock was obtained. The transaction is
//     deliberately left open so the lock stays held; it keeps one database
//     connection occupied.
//   - (false, nil) when the row is locked elsewhere (PostgreSQL error code
//     55P03) or the row was not found.
//   - (false, err) for any other failure.
//
// In every case where the lock was not obtained the transaction is rolled back
// so that no connection is leaked across polling attempts.
func GetGlobalJobLock(db *gorm.DB) (bool, error) {
	deployment := ServiceDeployment{
		Name: ServiceDeploymentNameGlobalJobLock,
	}
	if err := db.FirstOrCreate(&deployment, ServiceDeployment{Name: ServiceDeploymentNameGlobalJobLock}).Error; err != nil {
		return false, err
	}

	// The logger is discarded for this session so that the expected
	// "lock not available" error is not printed on every poll.
	tx := db.Session(&gorm.Session{Logger: logger.Discard}).Begin()
	if err := tx.Error; err != nil {
		return false, err
	}

	exist, err := FindOne(tx.Clauses(clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}).Where("name = ?", ServiceDeploymentNameGlobalJobLock), &deployment)
	if exist {
		// Keep the transaction open: the row lock lives as long as it does.
		return true, nil
	}

	// Lock not obtained: always release the transaction. The rollback result
	// is ignored because the error that matters (if any) is already in err.
	_ = tx.Rollback()

	// 55P03 is lock_not_available, i.e. the row is locked by someone else.
	// NOTE(review): errors.As only matches if the driver in use returns this
	// pgconn package's PgError type — confirm against the driver version.
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) && pgErr.Code == "55P03" {
		return false, nil
	}
	return false, err
}

// stopTickerUntil calls stopFunc once per tick of the given duration and
// returns nil as soon as stopFunc reports true. If ctx is done first, it
// returns ctx.Err(). stopFunc is not called before the first tick.
func stopTickerUntil(ctx context.Context, duration time.Duration, stopFunc func() bool) error {
	t := time.NewTicker(duration)
	defer t.Stop()

	for {
		select {
		case <-t.C:
			if stopFunc() {
				return nil
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// RunJob is the placeholder for the actual job logic. main calls it after the
// global job lock has been obtained. It currently does nothing with ctx and
// always returns nil.
func RunJob(ctx context.Context) error {
	// run job
	return nil
}