6
6
7
7
"""Utility functions for training and inference."""
8
8
9
+ import csv
10
+ import os
9
11
import pickle
10
12
import sys
11
13
import warnings
12
14
from contextlib import contextmanager
15
+ from dataclasses import dataclass , field
13
16
from functools import partial
14
17
from io import BytesIO
15
18
from pathlib import Path
16
- from types import MethodType
17
- from typing import Any , Dict , List , Mapping , Optional , Type , TypeVar , Union
19
+ from typing import Any , Dict , List , Mapping , Optional , TypeVar , Union
18
20
19
21
import torch
20
22
import torch .nn as nn
21
23
import torch .utils ._device
22
- from lightning .fabric .loggers import CSVLogger
23
24
from torch .serialization import normalize_storage_type
24
25
25
26
@@ -471,10 +472,99 @@ def __exit__(self, type, value, traceback):
471
472
self .zipfile .write_end_of_file ()
472
473
473
474
475
+ class FileSystem :
476
+ """Basic filesystem interface that can be extended for different storage backends."""
477
+
478
+ def open (self , path : str , mode : str , newline : str = "" ) -> Any :
479
+ return open (path , mode , newline = newline )
480
+
481
+ def exists (self , path : str ) -> bool :
482
+ return os .path .exists (path )
483
+
484
+ def makedirs (self , path : str , exist_ok : bool = True ) -> None :
485
+ os .makedirs (path , exist_ok = exist_ok )
486
+
487
+
488
+ @dataclass
489
+ class Experiment :
490
+ """Handles the actual logging of metrics to CSV."""
491
+
492
+ metrics_file_path : str
493
+ metrics : List [Dict [str , Any ]] = field (default_factory = list )
494
+ _fs : FileSystem = field (default_factory = FileSystem )
495
+
496
+ def log_metrics (self , metrics : Dict [str , Any ]) -> None :
497
+ """Log a dictionary of metrics."""
498
+ self .metrics .append (metrics )
499
+
500
+ def save (self ) -> None :
501
+ """Save metrics to CSV file."""
502
+ if not self .metrics :
503
+ return
504
+
505
+ # Ensure directory exists
506
+ dir_path = os .path .dirname (self .metrics_file_path )
507
+ if dir_path :
508
+ self ._fs .makedirs (dir_path )
509
+
510
+ # Get all possible keys from all metric dictionaries
511
+ keys = sorted ({k for m in self .metrics for k in m })
512
+
513
+ with self ._fs .open (self .metrics_file_path , "w" , newline = "" ) as f :
514
+ writer = csv .DictWriter (f , fieldnames = keys )
515
+ writer .writeheader ()
516
+ writer .writerows (self .metrics )
517
+
518
+
519
+ class CSVLogger :
520
+ """Base CSV Logger class that can be extended for specific logging needs."""
521
+
522
+ def __init__ (
523
+ self ,
524
+ metrics_file_path : str ,
525
+ fs : Optional [FileSystem ] = None ,
526
+ experiment_class : Optional [type ] = None ,
527
+ ):
528
+ self ._fs = fs or FileSystem ()
529
+ self .metrics_file_path = metrics_file_path
530
+ exp_class = experiment_class or Experiment
531
+ self .experiment = exp_class (metrics_file_path = metrics_file_path , _fs = self ._fs )
532
+
533
+ def log_metrics (self , metrics : Dict [str , Any ]) -> None :
534
+ """Log metrics to the experiment."""
535
+ self .experiment .log_metrics (metrics )
536
+
537
+ def save (self ) -> None :
538
+ """Save the logged metrics to CSV."""
539
+ self .experiment .save ()
540
+
541
+
542
+ class StepCSVLogger (CSVLogger ):
543
+ """CSV Logger that includes step numbers in metrics."""
544
+
545
+ def __init__ (
546
+ self ,
547
+ metrics_file_path : str ,
548
+ initial_step : int = 0 ,
549
+ fs : Optional [FileSystem ] = None ,
550
+ ):
551
+ super ().__init__ (metrics_file_path , fs )
552
+ self .current_step = initial_step
553
+
554
+ def log_metrics (self , metrics : Dict [str , Any ], step : Optional [int ] = None ) -> None :
555
+ """Log metrics with step number."""
556
+ if step is not None :
557
+ self .current_step = step
558
+
559
+ metrics_with_step = {"step" : self .current_step , ** metrics }
560
+ super ().log_metrics (metrics_with_step )
561
+ self .current_step += 1
562
+
563
+
474
564
T = TypeVar ("T" )
475
565
476
566
477
- def step_csv_logger (* args : Any , cls : Type [ T ] = CSVLogger , ** kwargs : Any ) -> T :
567
+ def step_csv_logger (* args : Any , cls = CSVLogger , ** kwargs : Any ) -> T :
478
568
logger = cls (* args , ** kwargs )
479
569
480
570
def merge_by (dicts , key ):
@@ -499,8 +589,6 @@ def save(self) -> None:
499
589
writer .writeheader ()
500
590
writer .writerows (metrics )
501
591
502
- logger .experiment .save = MethodType (save , logger .experiment )
503
-
504
592
return logger
505
593
506
594
0 commit comments