Run the benchmark for all methods and environments listed in registers.py.
Parameters:

- benchmark_methods (List[Method]) – list of benchmark methods to run.
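A minimal usage sketch, assuming the module path and the method class shown below; these names are illustrative and should be replaced with the methods actually registered in registers.py:

```python
# Hypothetical usage sketch: the import paths and the method class are assumptions.
from src.benchmark.benchmark import benchmark
from src.benchmark.methods import BehaviouralCloning  # hypothetical Method class

# benchmark() expects the Method classes themselves (it reads __name__ and
# __method_name__ from each class), not instances.
benchmark([BehaviouralCloning])
```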
Source code in src/benchmark/benchmark.py:
```python
def benchmark(benchmark_methods: List[Method]) -> None:
    """Run the benchmark for all methods and environments listed in registers.py.

    Args:
        benchmark_methods (List[Method]): list of benchmark methods to run.
    """
    file_name = "benchmark_results_"
    file_name += "-".join([str(method.__name__) for method in benchmark_methods])
    benchmark_results = []
    for environments in tqdm(benchmark_environments, desc="Benchmark Environments"):
        for name, info in environments.items():
            path, random_reward = info.values()
            # Skip this environment if it cannot be created.
            try:
                environment = gym.make(name)
            except VersionNotFound:
                logging.error("benchmark: Version for environment does not exist")
                continue
            except NameNotFound:
                logging.error("benchmark: Environment name does not exist")
                continue
            except Exception:
                logging.error("benchmark: Generic error raised, probably dependency related")
                continue
            # Skip this environment if its dataset cannot be loaded.
            try:
                dataloader = create_dataloader(path)
            except FileNotFoundError:
                logging.error("benchmark: HuggingFace path is not valid")
                continue
            for method in tqdm(benchmark_methods, desc=f"Methods for environment: {name}"):
                try:
                    metrics = benchmark_method(
                        method,
                        environment,
                        dataloader,
                        dataloader.dataset.average_reward,
                        random_reward,
                    )
                except Exception as exception:
                    logging.error(
                        "benchmark: Method %s raised an exception during training",
                        method.__method_name__,
                    )
                    logging.error(exception)
                    continue
                benchmark_results.append([
                    name,
                    method.__method_name__,
                    *metrics.values(),
                ])
    table = tabulate(
        benchmark_results,
        headers=["Environment", "Method", "AER", "Performance"],
        tablefmt="github",
    )
    with open(f"./{file_name}.md", "w", encoding="utf-8") as _file:
        _file.write(table)
```
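The final step formats the accumulated rows with tabulate and tablefmt="github", producing a GitHub-flavored Markdown table that is written to the benchmark_results_*.md file named after the selected methods. A standalone sketch of that step, with placeholder values rather than real benchmark numbers:

```python
from tabulate import tabulate

# Rows in the same shape as benchmark_results:
# [environment name, method name, *metrics.values()]
rows = [
    ["CartPole-v1", "ExampleMethod", 0.87, 0.95],  # illustrative values only
]

table = tabulate(rows, headers=["Environment", "Method", "AER", "Performance"], tablefmt="github")
print(table)
# Output is a pipe-delimited Markdown table with a header row and an
# alignment row, ready to be written to a .md file.
```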