스푸79 기록 보관소

OpenAI LunarLander-v2 UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. 본문

AI

OpenAI LunarLander-v2 UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow.

스푸79 2024. 8. 15. 08:30

 

OpenAI gymnasium에서 lunarlander를 계속 진행했다.

동작 과정에서 큰 문제는 없었는데

훈련할 때마다 ndarrays로 tensor를 만들면 속도가 심각하게 느려진다는 경고 메시지가 눈에 거슬렸다.

 

UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. 

 

 

변환하기 전 코드는 아래와 같았다.

    states = torch.tensor(batch[0], dtype=torch.float32)   
    actions = torch.tensor(batch[1], dtype=torch.int64).unsqueeze(1) 
    rewards = torch.tensor(batch[2], dtype=torch.float32)
    next_states = torch.tensor(batch[3], dtype=torch.float32)
    dones = torch.tensor(batch[4], dtype=torch.float32)

 

여기서 tensor로 변환되는 batch가 list(zip(*transitions))으로 인해 ndarrays 형태로 변환되기 때문인 것으로 보인다.

    transitions = memory.sample(BATCH_SIZE)
    batch = list(zip(*transitions))

 

변환한 후 코드는 아래와 같다.

    states = torch.tensor(np.array(batch[0]), dtype=torch.float32)   
    #actions의 경우는 2차원 배열로 변환하는 이유는 actions의 형태가 0,1,2,3,4 값, tensor가 불필요
    actions = torch.tensor(np.array(batch[1]), dtype=torch.int64).unsqueeze(1) 
    rewards = torch.tensor(np.array(batch[2]), dtype=torch.float32)
    next_states = torch.tensor(np.array(batch[3]), dtype=torch.float32)
    dones = torch.tensor(np.array(batch[4]), dtype=torch.float32)

 

작성한 DQN 클래스에 save_model, load_model을 추가했다.

def save_model(policy_net, filename):
    torch.save(policy_net.state_dict(), filename)
    print(f"Model saved to {filename}")

def load_model(policy_net, filename):
    policy_net.load_state_dict(torch.load(filename))
    policy_net.eval()  # Set the model to evaluation mode
    print(f"Model loaded from {filename}")

 

이제 훈련시킨 후 해당 모델을 저장한 후

저장된 모델을 다시 불러와서 실행시키는 작업을 하기로 했다.

 

훈련은 총 500회를 진행한 후, model로 저장한 후 평가를 진행했다.

lunar lander의 파일 구성은 아래와 같이 총 4개이다. 전체 코드와 모델을 압축 파일로 첨부하도록 하겠다.

 

dqn_episode_500.pth > train 모델

dqn.py > Q-Network

lunar_lander_eval.py > train 모델을 기반으로 평가하는 코드

lunar_lander_train.py > Q-Network 기반으로 훈련하는 코드

 

훈련 초반 - 사정없이 추락하는 lunar lander의 모습

훈련 후반 - 안정적으로 착륙 지점을 찾아가는 모습

 

lunar_lander.zip
0.07MB