# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import getpass
import json
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
def _is_interactive() -> bool:
"""Return True when stdin is attached to a terminal."""
return hasattr(sys.stdin, "isatty") and sys.stdin.isatty()
def get_default_config_path() -> Path:
    """Return ``~/.neuralbench/config.json`` resolved against the user's home."""
    home_dir = Path.home()
    return home_dir.joinpath(".neuralbench", "config.json")
def prompt_user_for_path(
key: str, description: str, default_value: str | None = None
) -> str:
"""Prompt user for a path with optional default value."""
if default_value:
response = input(f"{description}\n[Default: {default_value}]: ").strip()
return response if response else default_value
else:
response = input(f"{description}: ").strip()
while not response:
print("This field is required. Please provide a path.")
response = input(f"{description}: ").strip()
return response
[docs]
def setup_config(config_path: Path | None = None) -> dict[str, Any]:
    """
    Set up neuralbench configuration.

    If ``config_path`` is None, the user is asked whether to use the default
    location or a custom one. If a file already exists at the resolved
    location it is loaded and returned unchanged. Otherwise the user is
    prompted for the cache, save and data directories and the W&B host, the
    directories are created, and the new configuration is written as JSON.

    Args:
        config_path: Path to config file. If None, the location is chosen
            interactively (default location or a user-supplied path).

    Returns:
        Dictionary with configuration values.
    """
    if config_path is None:
        print("\nNeuralbench Configuration Setup")
        print("=" * 50)
        response = (
            input(
                "\nConfiguration file not found.\n"
                "Do you want to use the default location (~/.neuralbench/config.json)? [Y/n]: "
            )
            .strip()
            .lower()
        )
        custom_path = ""
        if response in ("n", "no"):
            custom_path = input("Enter custom configuration file path: ").strip()
            if not custom_path:
                # Path("") resolves to the current directory, which "exists"
                # and would then fail to open as a file.
                print("No path entered; using the default location.")
        if custom_path:
            config_path = Path(custom_path).expanduser()
        else:
            config_path = get_default_config_path()

    # If config exists, load and return it
    if config_path.exists():
        with config_path.open(encoding="utf-8") as f:
            config = json.load(f)
        print(f"\nLoaded configuration from: {config_path}")
        return config

    # Config doesn't exist, prompt for values
    print(f"\nSetting up new configuration at: {config_path}")
    print("-" * 50)

    # Username doubles as the default entity name.
    username = getpass.getuser()
    config: dict[str, Any] = {
        "USER": username,
        "ENTITY_NAME": username,
        # Default project name
        "PROJECT_NAME": "neuralbench",
        # SLURM defaults (can be overridden later in config.json)
        "SLURM_PARTITION": "",
        "SLURM_CONSTRAINT": "",
        "N_CPUS": 10,
    }

    # Prompt for paths
    print("\nPlease provide the following paths:")
    print()
    path_prompts = (
        ("CACHE_DIR", "CACHE_DIR - Where to cache intermediate results from experiments"),
        ("SAVE_DIR", "SAVE_DIR - Where to save experiment results"),
        ("DATA_DIR", "DATA_DIR - Where to download and store datasets"),
    )
    for key, description in path_prompts:
        answer = prompt_user_for_path(key, description)
        # Expand a leading "~" so that "~/data" points into the home
        # directory instead of creating a literal "~" folder in the cwd.
        config[key] = os.path.expanduser(answer)

    # Create directories if they don't exist
    for key, _ in path_prompts:
        path = Path(config[key])
        path.mkdir(parents=True, exist_ok=True)
        print(f"Created directory: {path}")

    # Prompt for W&B host
    print()
    wandb_host = input(
        "WANDB_HOST - Weights & Biases server URL\n"
        "(leave empty to skip, e.g. https://wandb.ai/): "
    ).strip()
    config["WANDB_HOST"] = wandb_host

    # Save configuration (JSON is read and written as UTF-8)
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print(f"\nConfiguration saved to: {config_path}")
    print("=" * 50)
    print()
    return config
def _default_config() -> dict[str, Any]:
"""Return a minimal config with temporary paths for non-interactive use."""
import tempfile
base = Path(tempfile.gettempdir()) / "neuralbench"
username = getpass.getuser()
config = {
"USER": username,
"ENTITY_NAME": username,
"PROJECT_NAME": "neuralbench",
"CACHE_DIR": str(base / "cache"),
"SAVE_DIR": str(base / "save"),
"DATA_DIR": str(base / "data"),
"WANDB_HOST": "",
"SLURM_PARTITION": "",
"SLURM_CONSTRAINT": "",
"N_CPUS": 10,
}
for key in ["CACHE_DIR", "SAVE_DIR", "DATA_DIR"]:
Path(str(config[key])).mkdir(parents=True, exist_ok=True)
return config
[docs]
def load_config(config_path: Path | None = None) -> dict[str, Any]:
"""
Load neuralbench configuration.
The config file location is resolved in order:
1. Explicit *config_path* argument
2. ``NEURALBENCH_CONFIG`` environment variable
3. Default ``~/.neuralbench/config.json``
Args:
config_path: Path to config file. If None, checks env var then default.
Returns:
Dictionary with configuration values.
"""
if config_path is None:
env_path = os.environ.get("NEURALBENCH_CONFIG")
if env_path:
config_path = Path(env_path).expanduser()
else:
config_path = get_default_config_path()
if not config_path.exists():
if not _is_interactive():
return _default_config()
return setup_config(config_path)
with config_path.open() as f:
return json.load(f)
# Module-wide cache of the loaded configuration. It starts as None and is
# filled in by get_config() on first use, not at import time.
_config: dict[str, Any] | None = None
[docs]
def get_config() -> dict[str, Any]:
    """Return the cached configuration, loading it on the first call."""
    global _config
    if _config is not None:
        return _config
    # First access: resolve and load the config, then cache it in the module.
    _config = load_config()
    return _config
if TYPE_CHECKING:
    # Declarations for static type checkers only; this branch never runs,
    # so no values are bound here.

    # Identity and project naming
    USER: str
    ENTITY_NAME: str
    PROJECT_NAME: str

    # Filesystem locations
    DATA_DIR: str
    CACHE_DIR: str
    SAVE_DIR: str

    # Experiment tracking
    WANDB_HOST: str

    # Cluster settings
    SLURM_PARTITION: str
    SLURM_CONSTRAINT: str
    N_CPUS: int
def _initialize_module_vars() -> None:
    """Copy values from the loaded config into module-level variables.

    The first six keys are mandatory and raise ``KeyError`` when absent; the
    remaining four fall back to a default when the config does not define
    them.
    """
    cfg = get_config()
    module_ns = globals()

    # Mandatory entries, assigned one at a time in a fixed order.
    for required_key in (
        "USER",
        "ENTITY_NAME",
        "PROJECT_NAME",
        "DATA_DIR",
        "CACHE_DIR",
        "SAVE_DIR",
    ):
        module_ns[required_key] = cfg[required_key]

    # Optional entries with their fallback values.
    for optional_key, fallback in (
        ("WANDB_HOST", ""),
        ("SLURM_PARTITION", ""),
        ("SLURM_CONSTRAINT", ""),
        ("N_CPUS", 10),
    ):
        module_ns[optional_key] = cfg.get(optional_key, fallback)
# Lazy module-level config variables for YAML compatibility
# (``!!python/name:neuralbench.config_manager.DATA_DIR`` etc.).
# Values are resolved on first access via ``__getattr__`` so that
# importing this module does not trigger file I/O or interactive prompts.
_LAZY_CONFIG_KEYS = {
"USER",
"ENTITY_NAME",
"PROJECT_NAME",
"DATA_DIR",
"CACHE_DIR",
"SAVE_DIR",
"WANDB_HOST",
"SLURM_PARTITION",
"SLURM_CONSTRAINT",
"N_CPUS",
}
_initialized = False
def _ensure_initialized() -> None:
global _initialized
if not _initialized:
_initialize_module_vars()
_initialized = True
def __getattr__(name: str) -> Any:
    """Module-level attribute hook that resolves the lazy config variables.

    Names outside ``_LAZY_CONFIG_KEYS`` raise ``AttributeError``. Known names
    trigger the one-time initialization and are then read from the module
    globals.
    """
    if name not in _LAZY_CONFIG_KEYS:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    _ensure_initialized()
    return globals()[name]