package context
import (
"context"
"github.com/aertje/semaphore/semaphore"
)
// prioritizedKey is an unexported key type, so no other package can build a
// context key that collides with the one used here.
type prioritizedKey struct{}

// key is the context key under which WithPrioritized stores the semaphore and
// PrioritizedFromContext looks it up.
var key prioritizedKey
// PrioritizedFromContext returns the *semaphore.Prioritized stored in ctx by
// WithPrioritized. The boolean reports whether ctx carries a value of that type
// under this package's key; when it is false the returned pointer is nil.
func PrioritizedFromContext(ctx context.Context) (*semaphore.Prioritized, bool) {
	if sem, found := ctx.Value(key).(*semaphore.Prioritized); found {
		return sem, true
	}
	return nil, false
}
// WithPrioritized returns a child of ctx that carries s, which can later be
// retrieved with PrioritizedFromContext.
func WithPrioritized(ctx context.Context, s *semaphore.Prioritized) context.Context {
	derived := context.WithValue(ctx, key, s)
	return derived
}
package queue
import (
	"container/heap"
	"sync/atomic"
)
// nextSeq stamps every pushed item with its insertion order so that items of
// equal priority are dequeued first-in, first-out. It is shared by all queues
// and updated atomically because distinct queues may live on different
// goroutines.
var nextSeq uint64

// item is a single element stored in a Q.
type item[T any] struct {
	value    T
	priority int
	seq      uint64 // insertion order; breaks ties between equal priorities
	index    int    // position in the heap slice; -1 once popped
}

// Q is a min-priority queue: the item with the lowest priority value is popped
// first, and items with equal priority are popped in insertion order.
//
// Q implements heap.Interface so container/heap can maintain it; callers should
// use PushItem and PopItem rather than the heap.Interface methods. Q has no
// internal synchronization.
type Q[T any] []*item[T]

// Len implements heap.Interface.
func (q Q[T]) Len() int {
	return len(q)
}

// Less implements heap.Interface. Lower priority values sort first; ties are
// broken by insertion order, because container/heap on its own does not keep
// equal elements in FIFO order.
func (q Q[T]) Less(i, j int) bool {
	if q[i].priority != q[j].priority {
		return q[i].priority < q[j].priority
	}
	return q[i].seq < q[j].seq
}

// Swap implements heap.Interface, do not use this method directly.
func (q Q[T]) Swap(i, j int) {
	q[i], q[j] = q[j], q[i]
	q[i].index = i
	q[j].index = j
}

// Push implements heap.Interface, do not use this method directly.
func (q *Q[T]) Push(x any) {
	it := x.(*item[T])
	it.index = len(*q)
	*q = append(*q, it)
}

// Pop implements heap.Interface, do not use this method directly.
func (q *Q[T]) Pop() any {
	old := *q
	n := len(old)
	it := old[n-1]
	old[n-1] = nil // don't stop the GC from reclaiming the item eventually
	it.index = -1  // for safety
	*q = old[:n-1]
	return it
}

// PushItem adds value with the given priority.
func (q *Q[T]) PushItem(priority int, value T) {
	heap.Push(q, &item[T]{
		value:    value,
		priority: priority,
		seq:      atomic.AddUint64(&nextSeq, 1),
	})
}

// PopItem removes and returns the value with the lowest priority value (oldest
// first among equals). It returns the zero value of T if the queue is empty.
func (q *Q[T]) PopItem() T {
	if q.Len() == 0 {
		var zero T
		return zero
	}
	return heap.Pop(q).(*item[T]).value
}
package semaphore
import (
"container/heap"
"context"
"runtime"
"sync"
"github.com/aertje/semaphore/queue"
)
// entry is one queued acquisition request.
type entry struct {
	// waitChan is signalled by the semaphore when the request is granted.
	waitChan chan<- struct{}
	// cancelChan is closed by the requester once it stops waiting, so the
	// semaphore can skip the request instead of granting it.
	cancelChan <-chan struct{}
}
// Prioritized is a counting semaphore whose waiters are admitted in priority
// order. Its zero value has no queue; create one with NewPrioritized.
type Prioritized struct {
	// maxConcurrency is the slot limit; concurrency is how many slots are held.
	maxConcurrency, concurrency int

	lock    sync.Mutex      // guards concurrency and entries
	entries *queue.Q[entry] // pending acquisition requests
}
// Option configures a Prioritized while NewPrioritized builds it.
type Option func(p *Prioritized)
// WithMaxConcurrency returns an Option that sets how many slots may be held at
// once, replacing the default chosen by NewPrioritized.
//
// NOTE: the value is stored as-is; a value below 1 means no queued request can
// ever be granted.
func WithMaxConcurrency(maxConcurrency int) Option {
	setLimit := func(p *Prioritized) {
		p.maxConcurrency = maxConcurrency
	}
	return setLimit
}
// NewPrioritized builds a Prioritized with an empty request queue. The slot
// limit defaults to runtime.GOMAXPROCS(0) and may be overridden by opts, which
// are applied in order.
func NewPrioritized(opts ...Option) *Prioritized {
	pending := new(queue.Q[entry])
	heap.Init(pending)

	p := &Prioritized{
		entries:        pending,
		maxConcurrency: runtime.GOMAXPROCS(0),
	}
	for i := range opts {
		opts[i](p)
	}
	return p
}
// assessEntries hands free slots to queued requests in priority order.
// It takes s.lock itself, so callers must not already hold it.
func (s *Prioritized) assessEntries() {
	s.lock.Lock()
	defer s.lock.Unlock()
	s.assessEntriesLocked()
}

// assessEntriesLocked is assessEntries for callers that already hold s.lock.
//
// It never blocks. A grant is signalled by closing the request's waitChan
// rather than by sending on it, so a requester that concurrently gave up cannot
// wedge the semaphore while the lock is held.
func (s *Prioritized) assessEntriesLocked() {
	for s.concurrency < s.maxConcurrency && s.entries.Len() > 0 {
		e := s.entries.PopItem()
		select {
		case <-e.cancelChan:
			// The requester abandoned the request; drop it without using a slot.
			continue
		default:
		}
		close(e.waitChan)
		s.concurrency++
	}
}

// acquireInternal queues a request at the given priority and blocks until a
// slot is granted (returns nil) or ctx is done (returns ctx.Err(), with no slot
// held). The force parameter is not used by this implementation.
func (s *Prioritized) acquireInternal(ctx context.Context, priority int, force bool) error {
	waitChan := make(chan struct{})
	cancelChan := make(chan struct{})

	s.lock.Lock()
	s.entries.PushItem(priority, entry{waitChan: waitChan, cancelChan: cancelChan})
	// Granting is non-blocking, so it is safe to try right away under the lock.
	s.assessEntriesLocked()
	s.lock.Unlock()

	select {
	case <-waitChan:
		return nil
	case <-ctx.Done():
	}

	// Grants and cancellations are both decided under s.lock, so exactly one of
	// the two branches below can apply to this request.
	s.lock.Lock()
	select {
	case <-waitChan:
		// The slot was granted concurrently with cancellation: give it back
		// and let the next request in line have it.
		s.concurrency--
		s.assessEntriesLocked()
	default:
		// Still queued: mark it abandoned so it is skipped when popped.
		close(cancelChan)
	}
	s.lock.Unlock()
	return ctx.Err()
}
func (s *Prioritized) AcquireContext(ctx context.Context, priority int) error {
return s.acquireInternal(ctx, priority, false)
}
// Acquire waits, without a deadline, for a slot at the given priority; requests
// with lower priority values are served first.
func (s *Prioritized) Acquire(priority int) {
	// context.Background() is never done, so the only possible result is nil
	// and the error can be dropped.
	bg := context.Background()
	_ = s.acquireInternal(bg, priority, false)
}
// ForceAcquire takes a slot immediately. It does not queue and does not check
// the concurrency limit, so it may push usage above it. Pair it with Release.
func (s *Prioritized) ForceAcquire() {
	s.lock.Lock()
	s.concurrency++
	s.lock.Unlock()
}
// Release returns a slot obtained through Acquire, AcquireContext or
// ForceAcquire and asynchronously lets queued requests claim it.
//
// Calling Release more often than slots were acquired is a programming error.
// It would drive the held-slot count negative and let more than the configured
// number of holders in, so it panics instead.
func (s *Prioritized) Release() {
	s.lock.Lock()
	defer s.lock.Unlock()
	if s.concurrency <= 0 {
		panic("semaphore: released more than held")
	}
	s.concurrency--
	// assessEntries takes s.lock, which is held here, so run it on its own
	// goroutine; it proceeds once this call unlocks.
	go s.assessEntries()
}