"""Config."""
from __future__ import annotations
import random
from typing import Annotated, Any, Awaitable, Callable, Literal, NewType, Optional
from annotated_types import Ge
from pydantic import BaseModel, Field, FilePath, NewPath, model_validator
PositiveInt = Annotated[int, Ge(0)]
Milliseconds = NewType("Milliseconds", PositiveInt)
Seconds = NewType("Seconds", PositiveInt)
[docs]
class RetryConfig(BaseModel):
"""Retry configuration for the service."""
max_attempts: PositiveInt = 3
wait_exp_max: PositiveInt = 10
wait_exp_min: PositiveInt = 4
wait_exp_mul: PositiveInt = 1
[docs]
class RateLimiterConfig(BaseModel):
"""Retry configuration for the service."""
max_rate: PositiveInt = 10
time_period: Seconds = Seconds(60)
[docs]
class CacheConfig(BaseModel):
use: bool = Field(default=False, description="Whether to cache requests.")
cache_type: Literal["json", "pickle", "remote"] = Field(
default="json", description="The type of cache to use."
)
path: FilePath | NewPath = Field(
default="cache.json",
description="The path of the file to use for the cache. Defaults to 'cache.json'. Unused for remote cache.",
)
write_interval: PositiveInt = Field(
default=20 * 60,
description="The interval to write the cache in seconds. Defaults to 20 minutes.",
)
write_periodically: bool = Field(
default=True,
description="Whether to write the cache to disk periodically. Defaults to True.",
)
save_state: Optional[Callable[[dict], Awaitable[None]]] = Field(
description="A function to save the cache state. Only used for remote cache.",
default=None,
)
load_state: Optional[Callable[[], Awaitable[Any]]] = Field(
description="A function to load the cache state. Only used for remote cache.",
default=None,
)
[docs]
@model_validator(mode="after")
def validate(self) -> CacheConfig: # type: ignore
if self.cache_type == "remote" and not self.save_state and not self.load_state:
raise ValueError(
"Remote cache requires save_state and load_state functions."
)
if self.cache_type == "json" and str(self.path).split(".")[1] not in (
"json",
"jsonl",
"json.gz",
):
raise ValueError("JSON cache requires a .json file.")
if self.cache_type == "pickle" and str(self.path).split(".")[1] not in (
"pkl",
"pickle",
):
raise ValueError("Pickle cache requires a .pkl file.")
return self
[docs]
class DelayConfig(BaseModel):
"""Delay configuration for the service."""
amount: Milliseconds = Field(
default=Milliseconds(0),
description="The total amount of delay in milliseconds.",
)
type: Literal["constant", "random"] = Field(
default="random",
description="The type of delay. Either constant or random. Defaults to random.",
)
[docs]
def get(self):
if self.type == "constant":
return self.amount / 1000
return random.randint(0, self.amount) / 1000
[docs]
class ServiceConfig(BaseModel):
"""Global configuration for the service."""
retry: RetryConfig = Field(
default_factory=RetryConfig, description="The retry configuration."
)
deduplication: bool = Field(
default=True, description="Whether to deduplicate requests."
)
max_concurrency: PositiveInt = Field(
default=10, description="The maximum number of concurrent requests."
)
limiter: RateLimiterConfig | None = Field(
description="The rate limiter configuration", default=None
)
cache: CacheConfig = Field(
description="The cache configuration", default_factory=CacheConfig
)
delay: DelayConfig = Field(
description="The delay configuration", default_factory=DelayConfig
)
[docs]
class ProxyConfig(BaseModel):
"""Proxy configuration for the service."""
host: str = Field(description="The proxy host.")
port: int = Field(description="The proxy port.")
username: Optional[str] = Field(description="The proxy username.", default=None)
password: Optional[str] = Field(description="The proxy password.", default=None)
[docs]
@classmethod
def from_url(cls, url: str) -> ProxyConfig:
if "://" in url:
url = url.split("://")[1]
if "@" in url:
auth, url = url.split("@")
username, password = auth.split(":")
else:
username = None
password = None
host, port = url.split(":")
return cls(host=host, port=int(port), username=username, password=password)
@property
def url(self) -> str:
if self.username and self.password:
return f"http://{self.username}:{self.password}@{self.host}:{self.port}"
return f"http://{self.host}:{self.port}"
[docs]
class PlaywrightConfig(BaseModel):
browser: Literal["chromium", "firefox", "webkit"] = Field(
description="The browser to use.", default="chromium"
)
headless: bool = Field(description="Whether to run in headless mode.", default=True)
slow_mo: PositiveInt = Field(
description="The slow motion delay in milliseconds.", default=0
)
device: Optional[dict[str, Any]] = Field(
description="The devices to use.", default=None
)