Skip to content

Commit 7634c49

Browse files
committed
➖ remove PL
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
1 parent 146d124 commit 7634c49

File tree

1 file changed

+94
-6
lines changed

1 file changed

+94
-6
lines changed

samba_pytorch/utils.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@
66

77
"""Utility functions for training and inference."""
88

9+
import csv
10+
import os
911
import pickle
1012
import sys
1113
import warnings
1214
from contextlib import contextmanager
15+
from dataclasses import dataclass, field
1316
from functools import partial
1417
from io import BytesIO
1518
from pathlib import Path
16-
from types import MethodType
17-
from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union
19+
from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union
1820

1921
import torch
2022
import torch.nn as nn
2123
import torch.utils._device
22-
from lightning.fabric.loggers import CSVLogger
2324
from torch.serialization import normalize_storage_type
2425

2526

@@ -471,10 +472,99 @@ def __exit__(self, type, value, traceback):
471472
self.zipfile.write_end_of_file()
472473

473474

475+
class FileSystem:
476+
"""Basic filesystem interface that can be extended for different storage backends."""
477+
478+
def open(self, path: str, mode: str, newline: str = "") -> Any:
479+
return open(path, mode, newline=newline)
480+
481+
def exists(self, path: str) -> bool:
482+
return os.path.exists(path)
483+
484+
def makedirs(self, path: str, exist_ok: bool = True) -> None:
485+
os.makedirs(path, exist_ok=exist_ok)
486+
487+
488+
@dataclass
489+
class Experiment:
490+
"""Handles the actual logging of metrics to CSV."""
491+
492+
metrics_file_path: str
493+
metrics: List[Dict[str, Any]] = field(default_factory=list)
494+
_fs: FileSystem = field(default_factory=FileSystem)
495+
496+
def log_metrics(self, metrics: Dict[str, Any]) -> None:
497+
"""Log a dictionary of metrics."""
498+
self.metrics.append(metrics)
499+
500+
def save(self) -> None:
501+
"""Save metrics to CSV file."""
502+
if not self.metrics:
503+
return
504+
505+
# Ensure directory exists
506+
dir_path = os.path.dirname(self.metrics_file_path)
507+
if dir_path:
508+
self._fs.makedirs(dir_path)
509+
510+
# Get all possible keys from all metric dictionaries
511+
keys = sorted({k for m in self.metrics for k in m})
512+
513+
with self._fs.open(self.metrics_file_path, "w", newline="") as f:
514+
writer = csv.DictWriter(f, fieldnames=keys)
515+
writer.writeheader()
516+
writer.writerows(self.metrics)
517+
518+
519+
class CSVLogger:
520+
"""Base CSV Logger class that can be extended for specific logging needs."""
521+
522+
def __init__(
523+
self,
524+
metrics_file_path: str,
525+
fs: Optional[FileSystem] = None,
526+
experiment_class: Optional[type] = None,
527+
):
528+
self._fs = fs or FileSystem()
529+
self.metrics_file_path = metrics_file_path
530+
exp_class = experiment_class or Experiment
531+
self.experiment = exp_class(metrics_file_path=metrics_file_path, _fs=self._fs)
532+
533+
def log_metrics(self, metrics: Dict[str, Any]) -> None:
534+
"""Log metrics to the experiment."""
535+
self.experiment.log_metrics(metrics)
536+
537+
def save(self) -> None:
538+
"""Save the logged metrics to CSV."""
539+
self.experiment.save()
540+
541+
542+
class StepCSVLogger(CSVLogger):
543+
"""CSV Logger that includes step numbers in metrics."""
544+
545+
def __init__(
546+
self,
547+
metrics_file_path: str,
548+
initial_step: int = 0,
549+
fs: Optional[FileSystem] = None,
550+
):
551+
super().__init__(metrics_file_path, fs)
552+
self.current_step = initial_step
553+
554+
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
555+
"""Log metrics with step number."""
556+
if step is not None:
557+
self.current_step = step
558+
559+
metrics_with_step = {"step": self.current_step, **metrics}
560+
super().log_metrics(metrics_with_step)
561+
self.current_step += 1
562+
563+
474564
T = TypeVar("T")
475565

476566

477-
def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T:
567+
def step_csv_logger(*args: Any, cls=CSVLogger, **kwargs: Any) -> T:
478568
logger = cls(*args, **kwargs)
479569

480570
def merge_by(dicts, key):
@@ -499,8 +589,6 @@ def save(self) -> None:
499589
writer.writeheader()
500590
writer.writerows(metrics)
501591

502-
logger.experiment.save = MethodType(save, logger.experiment)
503-
504592
return logger
505593

506594

0 commit comments

Comments
 (0)