Benchmark

benchmark

create_dataloader

Creates a dataloader based on the BaselineDataset.

Parameters:
  • path (str) –

    HuggingFace path to the dataset.

Returns:
  • dataloader (DataLoader) –

    dataloader to use for training.

src/benchmark/benchmark.py
def create_dataloader(path: str) -> DataLoader:
    """Creates dataloader based on the BaselineDataset.

    Args:
        path(str): HuggingFace path to dataset.

    Returns:
        dataloader (DataLoader): dataloader to use for training.
    """
    dataset = BaselineDataset(path, source="huggingface")
    dataloader = DataLoader(dataset, batch_size=2048, shuffle=True)
    return dataloader
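
A minimal usage sketch. The import path and the HuggingFace dataset identifier below are assumptions for illustration, not part of the documented API.

from src.benchmark.benchmark import create_dataloader  # assumed import path

# "user/cartpole-expert-dataset" is a hypothetical HuggingFace dataset path.
dataloader = create_dataloader("user/cartpole-expert-dataset")
for batch in dataloader:
    # Each batch holds up to 2048 shuffled samples from the BaselineDataset.
    break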

benchmark_method

Trains a method and evaluates the resulting policy.

Parameters:
  • method (Method) –

    the method class to train and evaluate.

  • environment (Env) –

    environment used to train the method.

  • dataloader (DataLoader) –

    dataloader providing the training data.

  • teacher_reward (Number) –

    teacher (expert) reward used to compute performance.

  • random_reward (Number) –

    random-policy reward used to compute performance.

Returns:
  • metrics (Metrics) –

    resulting metrics for the best checkpoint:
      • aer (str): average episodic reward, formatted as "mean ± std".
      • performance (str): performance relative to the teacher and random rewards, formatted as "mean ± std".

src/benchmark/benchmark.py
def benchmark_method(
    method: Method,
    environment: gym.Env,
    dataloader: DataLoader,
    teacher_reward: Number,
    random_reward: Number
) -> Metrics:
    """Function for training a method and evaluating.

    Args:
        method (Method): class for a method.
        environment (Env): environment to train the method.
        dataloader (DataLoader): dataloader to train the method.
        teacher_reward (Number): teacher reward to compute performance.
        random_reward (Number): random reward to compute performance.

    Returns:
        metrics (Metrics): resulting metrics for best checkpoint.
            aer (Dict[str, str]): average episodic reward.
            performance (Dict[str, str]) performance.
    """
    policy: Method = method(environment, verbose=True, enjoy_criteria=100)
    metrics = policy.train(5000, train_dataset=dataloader) \
        .load() \
        ._enjoy(teacher_reward=teacher_reward, random_reward=random_reward)
    aer = f"{round(metrics['aer'], 4)} ± {round(metrics['aer_std'], 4)}"
    performance = f"{round(metrics['performance'], 4)} ± {round(metrics['performance_std'], 4)}"
    return {"aer": aer, "performance": performance}

benchmark

Benchmarks all methods and environments listed in registers.py.

Parameters:
  • benchmark_methods (List[Method]) –

    list of benchmark methods to run.

src/benchmark/benchmark.py
def benchmark(benchmark_methods: List[Method]) -> None:
    """Benchmark for all methods and environments listed on registers.py

    Args:
        benchmark_methods (List[Method]): list of benchmark methods to run.
    """
    file_name = "benchmark_results_"
    file_name += "-".join([str(method.__name__) for method in benchmark_methods])

    benchmark_results = []
    for environments in tqdm(benchmark_environments, desc="Benchmark Environments"):
        for name, info in environments.items():
            path, random_reward = info.values()

            # Skip the environment if it cannot be created; otherwise the
            # loop body below would run with an undefined environment.
            try:
                environment = gym.make(name)
            except VersionNotFound:
                logging.error("benchmark: Version for environment does not exist")
                continue
            except NameNotFound:
                logging.error("benchmark: Environment name does not exist")
                continue
            except Exception:
                logging.error("benchmark: Generic error raised, probably dependency related")
                continue

            try:
                dataloader = create_dataloader(path)
            except FileNotFoundError:
                logging.error("benchmark: HuggingFace path is not valid")
                continue

            for method in tqdm(benchmark_methods, desc=f"Methods for environment: {name}"):
                try:
                    metrics = benchmark_method(
                        method,
                        environment,
                        dataloader,
                        dataloader.dataset.average_reward,
                        random_reward
                    )
                except Exception as exception:
                    logging.error(
                        "benchmark: Method %s did raise an exception during training",
                        method.__method_name__
                    )
                    logging.error(exception)
                    continue

                benchmark_results.append([
                    name,
                    method.__method_name__,
                    *metrics.values()
                ])

        table = tabulate(
            benchmark_results,
            headers=["Environment", "Method", "AER", "Performance"],
            tablefmt="github"
        )

        with open(f"./{file_name}.md", "w", encoding="utf-8") as _file:
            _file.write(table)
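
Putting it together, the typical entry point is a single call over the registered methods. A minimal sketch, assuming the modules are importable as below.

from src.benchmark.benchmark import benchmark      # assumed import path
from src.benchmark.registers import get_methods

# Runs every registered method on every registered environment and writes a
# GitHub-flavoured results table (Environment, Method, AER, Performance) to
# ./benchmark_results_<method names>.md in the working directory.
benchmark(get_methods(["all"]))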

registers

get_methods

Gets methods from a list of names.

Parameters:
  • names (List[str]) –

    list of method names.

Returns:
  • benchmark_methods (List[Method]) –

    list of methods.

src/benchmark/registers.py
def get_methods(names: List[str]) -> List[Method]:
    """Get methods from string list.

    Args:
        names (List[str]): list of method names.

    Returns:
        benchmark_methods (List[Method]): list of methods.
    """
    if len(names) == 1 and names[0] == "all":
        return benchmark_methods

    # Method classes are registered under their upper-case names in this
    # module; resolve each requested name against the module globals.
    return [globals()[name.upper()] for name in names]
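
A short usage sketch: names are resolved case-insensitively against the upper-case classes registered in registers.py, and the special value "all" returns the full benchmark list. The method names below are hypothetical and must match registered class names.

from src.benchmark.registers import get_methods  # assumed import path

all_methods = get_methods(["all"])     # every registered method
subset = get_methods(["bc", "bco"])    # hypothetical names, resolved as BC and BCO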