import gymnasium
import numpy as np
import socket
from stable_baselines3 import PPO
class GodotEnv(gymnasium.Env):
def __init__(self):
super(GodotEnv, self).__init__()
HOST = '127.0.0.1'
PORT = 12345
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.bind((HOST, PORT))
self.s.listen()
print(f"Server läuft und wartet auf Verbindung auf {HOST}:{PORT}...")
self.conn, self.addr = self.s.accept()
print(f"Verbindung hergestellt von {self.addr}")
self.observation_space = gymnasium.spaces.Box(low=-1000, high=1000, shape=(129,), dtype=np.float32)
self.action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(22,), dtype=np.float32)
def reset(self, seed=None, options=None):
super().reset(seed=seed)
self.conn.sendall(b"reset\n")
return self.receive_data(), {}
def step(self, action):
action_str = ";".join(map(str, action)) + "\n"
self.conn.sendall(action_str.encode())
state = self.receive_data()
info = {}
reward = 0.0
done = False
truncated = False
return state, reward, done, truncated, info
def receive_data(self):
buffer = ""
while True:
try:
data = self.conn.recv(1024)
if not data:
break
buffer += data.decode("ascii", "replace")
if ";" in buffer:
messages = buffer.split(";")
for message in messages[:-1]:
#print(f"Erhaltene Daten: {message}")
try:
state = np.array([float(x) for x in message.split(',') if x], dtype=np.float32)
print(state, state.shape)
return state
except ValueError as e:
continue
buffer = messages[-1]
except socket.error as e:
break
def close(self):
self.conn.close()
#Haupt
env = GodotEnv()
from gymnasium.utils.env_checker import check_env
check_env(env)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10)
model.save("ppo_godot_model")
obs = env.reset()
for i in range(1000):
action, _states = model.predict(obs)
print(action)
obs, rewards, done, info = env.step(action)
if done:
obs = env.reset()
env.close()
Я распечатываю форму np.array, которая возвращается функцией Recieve_data, и могу подтвердить, что большинство массивов имеют форму 129 (это моя желаемая форма для модели ). Но всегда есть фигура, которая ниже той, которая вызывает ошибку.
Я пробовал отсортировать фигуры
Я программирую модель PPO, которая получает данные из сценария Godot 4.3. он отправляет информацию о положении, вращении и т. д. [code]func send_data_to_python(): if is_connected: socket.put_data(get_character_state().to_ascii_buffer()) socket.poll()
func _physics_process(delta: float) -> void: if is_connected: send_data_to_python() [/code] это то, что я получил в моем скрипте Python, который подключен через TCP: [code]import gymnasium import numpy as np import socket from stable_baselines3 import PPO
class GodotEnv(gymnasium.Env): def __init__(self): super(GodotEnv, self).__init__() HOST = '127.0.0.1' PORT = 12345 self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.s.bind((HOST, PORT)) self.s.listen() print(f"Server läuft und wartet auf Verbindung auf {HOST}:{PORT}...") self.conn, self.addr = self.s.accept() print(f"Verbindung hergestellt von {self.addr}") self.observation_space = gymnasium.spaces.Box(low=-1000, high=1000, shape=(129,), dtype=np.float32) self.action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(22,), dtype=np.float32)
def receive_data(self): buffer = "" while True: try: data = self.conn.recv(1024) if not data: break buffer += data.decode("ascii", "replace") if ";" in buffer: messages = buffer.split(";") for message in messages[:-1]: #print(f"Erhaltene Daten: {message}") try: state = np.array([float(x) for x in message.split(',') if x], dtype=np.float32) print(state, state.shape) return state except ValueError as e: continue buffer = messages[-1] except socket.error as e: break
def close(self): self.conn.close()
#Haupt env = GodotEnv() from gymnasium.utils.env_checker import check_env check_env(env) model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10) model.save("ppo_godot_model") obs = env.reset() for i in range(1000): action, _states = model.predict(obs) print(action) obs, rewards, done, info = env.step(action) if done: obs = env.reset() env.close() [/code] Я распечатываю форму np.array, которая возвращается функцией Recieve_data, и могу подтвердить, что большинство массивов имеют форму 129 (это моя желаемая форма для модели ). Но всегда есть фигура, которая ниже той, которая вызывает ошибку. Я пробовал отсортировать фигуры