# RLVE_Gym / server/RLVE_Gym_environment.py
# (Source page header: last commit c4bedee "misc" by ZhiyuanZeng.)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
RLVE-Gym Environment Implementation.
"""
import os
from typing import Optional, Tuple
import random
from openenv_core.env_server.interfaces import Environment
from models import RlveGymState, RlveGymAction, RlveGymObservation
from server.Gym.environment import VerifiableEnvironment
from server.Gym.parameter_controller import ParameterController
from server.Gym.environments import identifier2environment
from server.Gym.parameter_controllers import identifier2controller
class RlveGymEnvironment(Environment):
    """
    Wrap any verifiable environment from RLVE-Gym behind the OpenEnv ``Environment`` API.

    Each ``reset()`` builds a fresh problem from the registry entry named by
    ``environment_identifier``, at the configured ``difficulty`` and the current seed.
    ``step()`` sends a model output to that problem's verifier and accumulates
    per-problem statistics that are exposed through ``state``.
    """

    def __init__(
        self,
        environment_identifier: Optional[str] = None,
        difficulty: Optional[int] = None,
        answer_markers: Optional[Tuple[str, str]] = None,
        initial_seed: Optional[int] = None,
    ):
        """
        Initialize the RLVE_Gym environment.

        Any argument left as ``None`` is read from an environment variable instead:
        ``RLVEGYM_ENVIRONMENT_IDENTIFIER`` (default "Multiplication"),
        ``RLVEGYM_DIFFICULTY`` (default "0"),
        ``RLVEGYM_ANSWER_MARKER_START`` / ``RLVEGYM_ANSWER_MARKER_END``
        (defaults "<answer>" / "</answer>"), and ``RLVEGYM_INITIAL_SEED`` (default "0").

        Args:
            environment_identifier (str): The environment's identifier. Check server/Gym/environments/__init__.py for detailed usage.
            difficulty (int): The difficulty of generated problems.
            answer_markers (Tuple[str] of length 2): How the environment extracts the final answer from a model output.
            initial_seed (int): The initial seed to use when generating the first problem. Whenever reset() attempts generation, the seed will be incremented by 1.

        Raises:
            ValueError: If ``RLVEGYM_DIFFICULTY`` or ``RLVEGYM_INITIAL_SEED`` is consulted and does not hold an integer string.
        """
        if environment_identifier is None:
            environment_identifier = os.getenv("RLVEGYM_ENVIRONMENT_IDENTIFIER", default="Multiplication")
        self.environment_identifier = environment_identifier

        if difficulty is None:
            difficulty = int(os.getenv("RLVEGYM_DIFFICULTY", default="0"))
        self.difficulty = difficulty

        if answer_markers is None:
            answer_markers = (
                os.getenv("RLVEGYM_ANSWER_MARKER_START", default=r"<answer>"),
                os.getenv("RLVEGYM_ANSWER_MARKER_END", default=r"</answer>"),
            )
        self.answer_markers = answer_markers

        if initial_seed is None:
            initial_seed = int(os.getenv("RLVEGYM_INITIAL_SEED", default="0"))

        self._state = RlveGymState(
            seed=initial_seed,
            problem_input=None,
            num_samples=0,
            sum_accuracy=0,
        )
        # Problem instance from the most recent successful reset(); None when no usable problem exists.
        self.problem = None

    def _reset_failure(self, message: str) -> RlveGymObservation:
        """
        Discard any current problem and build a failed reset observation.

        Clearing the problem keeps ``state.problem_input`` consistent with the
        documented meaning of None (generation not run, or failed) and prevents
        ``step()`` from verifying against a problem from an earlier reset().

        Args:
            message (str): The explanation placed in the observation.

        Returns:
            RlveGymObservation: success=False, with no problem input, verifier result or reward.
        """
        self.problem = None
        self._state.problem_input = None
        self._state.num_samples = self._state.sum_accuracy = 0
        return RlveGymObservation(
            problem_input=None,
            verifier_result=None,
            success=False,
            message=message,
            reward=None,
        )

    def reset(self) -> RlveGymObservation:
        """
        Reset the environment.

        On any failure the previous problem (if any) is discarded. The seed is advanced by one
        on every generation attempt, whether the generator succeeds, returns a falsy value, or
        raises; validation and initialization failures leave the seed unchanged.

        Returns:
            problem_input (Optional[str]): The input of the problem; if it is None, it means that the problem generation has not been run or has failed.
            verifier_result (Optional[dict]): Contains reward as the raw reward, accuracy as the 0/1 correctness, and format_score as the 0/1 format correctness; if it is None, it means that the verification has failed.
            success (bool): True or False indicates whether the operation succeeded.
            message (str): The explanation of success.
            reward (Optional[float]): The value is verifier_result["reward"] when verifier_result is not None (otherwise, reward is also None).
        """
        if (self.environment_identifier not in identifier2environment) or (
            self.environment_identifier not in identifier2controller
        ):
            return self._reset_failure("Invalid environment identifier.")
        if not (isinstance(self.difficulty, int) and self.difficulty >= 0):
            return self._reset_failure("Difficulty should be a non-negative integer.")
        if not (isinstance(self._state.seed, int) and self._state.seed >= 0):
            return self._reset_failure("Seed should be a non-negative integer.")

        try:
            problem: VerifiableEnvironment = identifier2environment[self.environment_identifier](
                answer_markers=self.answer_markers
            )
        except Exception as e:
            return self._reset_failure(f"Failed to initialize environment: {e}")

        try:
            controller: ParameterController = identifier2controller[self.environment_identifier]()
            # Call update() once per difficulty level (presumably each call moves the
            # controller up one level — TODO confirm against ParameterController).
            for _ in range(self.difficulty):
                controller.update()
            # Seeding makes the parameter choice reproducible for a given seed. Note that this
            # reseeds the process-wide ``random`` module as a side effect.
            random.seed(self._state.seed)
            parameter = random.choice(controller.get_parameter_list())
        except Exception as e:
            return self._reset_failure(f"Failed to select problem parameters: {e}")

        generation_error = None
        try:
            generated = problem.generator(seed=self._state.seed, parameter=parameter)
            problem_input = problem.prompt_generator() if generated else None
        except Exception as e:
            generated, problem_input, generation_error = False, None, e

        # Consume the seed on every generation attempt so that a seed which makes generation
        # fail (or raise) is not retried forever by subsequent reset() calls.
        self._state.seed += 1
        self._state.num_samples = self._state.sum_accuracy = 0

        if generation_error is not None:
            return self._reset_failure(f"Problem generation failed with error: {generation_error}")
        if not generated:
            return self._reset_failure(
                "Problem generation failed. Please try decreasing difficulty or changing seed."
            )

        self._state.problem_input = problem_input
        self.problem = problem
        return RlveGymObservation(
            problem_input=self._state.problem_input,
            verifier_result=None,
            success=True,
            message="Problem generated successfully.",
            reward=None,
        )

    def step(self, action: RlveGymAction) -> RlveGymObservation:  # type: ignore[override]
        """
        Execute a step in the environment by verifying the model output.

        Args:
            action (RlveGymAction): Contains a single field:
                - output (str): The model's output to get verified.

        Returns:
            problem_input (Optional[str]): The input of the problem; if it is None, it means that the problem generation has not been run or has failed.
            verifier_result (Optional[dict]): Contains reward as the raw reward, accuracy as the 0/1 correctness, and format_score as the 0/1 format correctness; if it is None, it means that the verification has failed.
            success (bool): True or False indicates whether the operation succeeded.
            message (str): The explanation of success.
            reward (Optional[float]): The value is verifier_result["reward"] when verifier_result is not None (otherwise, reward is also None).
        """
        if self.problem is None:
            return RlveGymObservation(
                problem_input=None,
                verifier_result=None,
                success=False,
                message="Problem not ready. Please reset the environment.",
                reward=None,
            )
        try:
            verifier_result = self.problem.verifier(action.output)
        except Exception as e:
            return RlveGymObservation(
                problem_input=self._state.problem_input,
                verifier_result=None,
                success=False,
                message=f"Verification failed with error: {e}",
                reward=None,
            )

        # Read both fields before updating the counters so that a malformed verifier result
        # cannot leave num_samples incremented without a matching sum_accuracy update.
        accuracy = verifier_result["accuracy"]
        reward = verifier_result["reward"]
        self._state.num_samples += 1
        self._state.sum_accuracy += accuracy
        return RlveGymObservation(
            problem_input=self._state.problem_input,
            verifier_result=verifier_result,
            success=True,
            message="Verification completed.",
            reward=reward,
        )

    @property
    def state(self) -> RlveGymState:
        """
        Get the current environment state.

        Returns:
            seed (int): The seed to use when running reset().
            problem_input (Optional[str]): The input of the problem; if it is None, it means that the problem generation has not been run, or it failed.
            num_samples (int) and sum_accuracy (int): The statistics of the result of `step(action)` so far for the current problem (the number of outputs sent to the verifier and the number of correct ones).
        """
        return self._state