package logx import ( "context" "log/slog" "os" "time" ) type Logger struct { handlers []Handler ctxExtractor ContextExtractor } type Handler interface { slog.Handler Close() error } type Option func(*Logger) type ContextExtractor func(context.Context) []slog.Attr func New(opts ...Option) *Logger { l := &Logger{ ctxExtractor: defaultContextExtractor, } for _, opt := range opts { opt(l) } return l } func (l *Logger) WithContext(ctx context.Context) *Logger { newLogger := &Logger{ handlers: make([]Handler, len(l.handlers)), ctxExtractor: l.ctxExtractor, } attrs := l.ctxExtractor(ctx) for i, h := range l.handlers { newHandler := h.WithAttrs(attrs) if handlerWithClose, ok := newHandler.(Handler); ok { newLogger.handlers[i] = handlerWithClose } else { newLogger.handlers[i] = &handlerWrapper{newHandler} } } return newLogger } type handlerWrapper struct { slog.Handler } func (h *handlerWrapper) Close() error { return nil } func (l *Logger) Debug(msg string, args ...any) { l.log(slog.LevelDebug, msg, args...) } func (l *Logger) Info(msg string, args ...any) { l.log(slog.LevelInfo, msg, args...) } func (l *Logger) Warn(msg string, args ...any) { l.log(slog.LevelWarn, msg, args...) } func (l *Logger) Error(msg string, args ...any) { l.log(slog.LevelError, msg, args...) } func (l *Logger) Fatal(msg string, args ...any) { l.log(slog.LevelError, msg, args...) os.Exit(1) } func (l *Logger) log(level slog.Level, msg string, args ...any) { for _, h := range l.handlers { if !h.Enabled(context.Background(), level) { continue } r := slog.NewRecord(time.Now(), level, msg, 0) r.Add(args...) _ = h.Handle(context.Background(), r) } } func (l *Logger) Close() error { for _, h := range l.handlers { if err := h.Close(); err != nil { return err } } return nil } func WithContextExtractor(extractor ContextExtractor) Option { return func(l *Logger) { l.ctxExtractor = extractor } } func defaultContextExtractor(ctx context.Context) []slog.Attr { return []slog.Attr{ slog.String("trace_id", getTraceID(ctx)), } } func getTraceID(ctx context.Context) string { traceID := ctx.Value("trace_id") return traceID.(string) }