Skip to content
Snippets Groups Projects

obs spaces simplified to boxes in init and reset, superfluous methods deleted;...

Merged Maddila Siva Sri Prasanna requested to merge reimpl-env into main
18 files
+ 375
1134
Compare changes
  • Side-by-side
  • Inline
Files
18
"""Module to implement useful functions like box_flatten
and box_flatten_obs. These will be used in the
NonCategoricalFlatten wrapper, as defined in `wrappers.py`.
"""
from functools import singledispatch
import numpy as np
import gymnasium
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiDiscrete,
MultiBinary,
)
# Custom Flattening logic
@singledispatch
def box_flatten(space: gymnasium.Space, dtype=None) -> Box:
"""
SingleDispatch function to recursively fold a
space into a Box: Currently supports only Box,
Discrete, MultiDiscrete and Dict spaces.
"""
raise NotImplementedError(f"Unknown/Unsupported space: {space=}")
@box_flatten.register(Dict)
def flatten_dict(space: Dict, dtype=None) -> Box:
"""
Function to fold a Dict space into a Box.
"""
# Recursively flatten all the spaces in the keys
list_boxed_spaces = [box_flatten(subsp) for subsp in space.values()]
# return the Box with the concatenated shapes
return Box(
low=np.concatenate([subsp.low for subsp in list_boxed_spaces]),
high=np.concatenate([subsp.high for subsp in list_boxed_spaces]),
dtype=np.int32,
)
@box_flatten.register(Discrete)
def flatten_discrete(space: Discrete, dtype=None) -> Box:
"""Flatten a discrete box, but not as a categorical
space. We consider Discrete as a Box of ints.
"""
return Box(
low=0, high=space.n - 1, dtype=space.dtype if not dtype else dtype
)
@box_flatten.register(MultiDiscrete)
def flatten_multidiscrete(space: MultiDiscrete, dtype=None) -> Box:
"""Flatten a MultiDiscrete box, but not as a categorical
space. We consider MultiDiscrete as a multi-dimensional
Box of ints.
"""
return Box(
low=np.zeros_like(space.nvec).flatten(),
high=np.array(space.nvec).flatten() - 1,
dtype=space.dtype if not dtype else dtype,
)
@box_flatten.register(Box)
def flatten_box(space: Box, dtype=None) -> Box:
"""Flatten of a box should just return the box if it is 1D,
else the equivalent of running np.flatten on its samples."""
return Box(
low=np.array(space.low).flatten(),
high=np.array(space.high).flatten(),
dtype=space.dtype if not dtype else dtype,
)
@box_flatten.register(MultiBinary)
def flatten_multibinary(space: MultiBinary, dtype=None) -> Box:
"""Convert a Binary to a Box"""
return Box(
low=0,
high=1,
shape=np.array(space.shape).flatten(),
dtype=space.dtype if not dtype else dtype,
)
def box_flatten_obs(obs, dtype=None) -> np.array:
"""
Box flattens an observation recursively.
"""
if isinstance(obs, dict):
# Recursively fold.
return np.concatenate(
[box_flatten_obs(val).flatten() for val in obs.values()]
)
dtype = np.int32 if not dtype else dtype # Default to integer type
if isinstance(obs, (np.ndarray, list)) or np.isscalar(obs):
# Return the flattened version
return np.array(obs, dtype=dtype).flatten()
raise NotImplementedError(f"Unknown observation sent: {obs}")
# Custom unflattening logic
def sizeof_space(space: gymnasium.Space):
"""gets the number of elements in the space"""
if isinstance(space, Dict):
return np.sum([sizeof_space(subsp) for subsp in space.values()])
return np.prod(np.array(space.shape).flatten().astype(np.int64))
def parse_obs(obs: np.array, space: gymnasium.Space):
"""Parse the observation into the original space."""
if isinstance(space, (Box, Discrete, MultiDiscrete, MultiBinary)):
return obs.reshape(space.shape).astype(space.dtype)
if not isinstance(space, Dict):
raise NotImplementedError
# Now to implement parsing for dictionaries
begin, end = 0, 0
new_obs = dict.fromkeys(space.keys())
for key, subsp in space.items():
end = begin + sizeof_space(subsp) # Define the chunk to parse
new_obs[key] = parse_obs(obs[begin:end], subsp)
begin = end # update the begin position for next subspace
return new_obs
Loading