179 lines
4.9 KiB
Zig
179 lines
4.9 KiB
Zig
const std = @import("std");
|
|
const Allocator = std.mem.Allocator;
|
|
|
|
/// Side length, in cells, of the square world grid (walls included).
pub const GRID_SIZE = 8;
|
|
// Tile codes stored in `World.grid`.

/// Walkable cell with nothing on it.
pub const TILE_EMPTY = 0;
/// Impassable cell; `World.reset` places these on the outer ring of the grid.
pub const TILE_WALL = 1;
/// Cell that currently holds the food.
pub const TILE_FOOD = 2;
/// Cell that currently holds the ant.
pub const TILE_ANT = 3;
|
|
|
|
// Action codes accepted by `World.step`.

/// Move one cell towards smaller `y`.
pub const ACTION_UP = 0;
/// Move one cell towards larger `y`.
pub const ACTION_DOWN = 1;
/// Move one cell towards smaller `x`.
pub const ACTION_LEFT = 2;
/// Move one cell towards larger `x`.
pub const ACTION_RIGHT = 3;
|
|
|
|
/// Plain record holding an ant's grid coordinates and a liveness flag.
pub const Ant = struct {
    /// Column index.
    x: usize,
    /// Row index.
    y: usize,
    /// Whether the ant is alive.
    alive: bool,
};
|
|
|
|
/// Grid world in which a single ant searches for one piece of food.
///
/// The outermost ring of cells is wall. `reset` places the ant and the food on
/// two distinct interior cells. `grid` is a rendered view that is rebuilt from
/// the ant/food coordinates after every successful move; `visited` remembers
/// which cells the ant has occupied since the last `reset`.
pub const World = struct {
    /// Tile code (`TILE_*`) of every cell, indexed `[y][x]`.
    grid: [GRID_SIZE][GRID_SIZE]u8,
    /// True for every cell the ant has occupied this episode, indexed `[y][x]`.
    visited: [GRID_SIZE][GRID_SIZE]bool,
    /// Current ant column.
    ant_x: usize,
    /// Current ant row.
    ant_y: usize,
    /// Food column.
    food_x: usize,
    /// Food row.
    food_y: usize,
    /// Number of `step` calls since the last `reset`.
    steps: usize,
    /// Step budget; once `steps` reaches it the episode is reported as done.
    max_steps: usize,
    /// Random source used by `reset` to place the ant and the food.
    prng: std.Random.DefaultPrng,

    comptime {
        // `reset` needs two distinct interior cells (one for the ant, one for
        // the food); with fewer, its placement loop could never terminate.
        if (GRID_SIZE < 4) @compileError("GRID_SIZE must be at least 4");
    }

    /// Creates a world whose random placement is driven by `seed`, with a
    /// step budget of 100, and immediately starts the first episode.
    pub fn init(seed: u64) World {
        var world = World{
            // Both arrays are fully written by `reset` below.
            .grid = undefined,
            .visited = undefined,
            .ant_x = 0,
            .ant_y = 0,
            .food_x = 0,
            .food_y = 0,
            .steps = 0,
            .max_steps = 100,
            .prng = std.Random.DefaultPrng.init(seed),
        };
        world.reset();
        return world;
    }

    /// Starts a new episode: rebuilds the walls, clears the visited map,
    /// places the ant and the food on distinct random interior cells and
    /// zeroes the step counter.
    pub fn reset(self: *World) void {
        const random = self.prng.random();

        for (0..GRID_SIZE) |y| {
            for (0..GRID_SIZE) |x| {
                const on_border = x == 0 or y == 0 or
                    x == GRID_SIZE - 1 or y == GRID_SIZE - 1;
                const tile: u8 = if (on_border) TILE_WALL else TILE_EMPTY;
                self.grid[y][x] = tile;
                self.visited[y][x] = false;
            }
        }

        // Draw order (ant x, ant y, then food x/y pairs) is kept stable so a
        // given seed always yields the same layout.
        self.ant_x = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
        self.ant_y = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
        self.visited[self.ant_y][self.ant_x] = true;

        // Re-draw the food until it does not land on the ant.
        while (true) {
            self.food_x = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
            self.food_y = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
            if (self.food_x != self.ant_x or self.food_y != self.ant_y) break;
        }

        self.steps = 0;
        self.updateGrid();
    }

    /// Re-renders the interior of `grid` from the current ant and food
    /// coordinates. Border cells are left untouched. The ant is written last,
    /// so it wins if both share a cell.
    fn updateGrid(self: *World) void {
        for (1..GRID_SIZE - 1) |y| {
            for (1..GRID_SIZE - 1) |x| {
                self.grid[y][x] = TILE_EMPTY;
            }
        }
        self.grid[self.food_y][self.food_x] = TILE_FOOD;
        self.grid[self.ant_y][self.ant_x] = TILE_ANT;
    }

    /// Applies one `ACTION_*` move and returns `.{ reward, done }`.
    ///
    /// Outcomes, in order of precedence:
    /// - target cell is a wall: the ant stays put, reward `-10`; `done` is
    ///   true only if the step budget is now exhausted;
    /// - ant reaches the food: reward `100`, `done = true`;
    /// - step budget exhausted: reward `-10`, `done = true`;
    /// - otherwise: reward `-0.1 + 0.2` for a cell not visited before or
    ///   `-0.1 - 0.5` for a revisited one, `done = false`.
    ///
    /// An `action` that is not one of the four `ACTION_*` codes leaves the ant
    /// where it is and is scored as a revisit of its current cell.
    pub fn step(self: *World, action: usize) struct { f32, bool } {
        self.steps += 1;
        // Computed up front so that every exit path honours the budget,
        // including the wall-collision path, which previously could keep an
        // episode running forever.
        const out_of_steps = self.steps >= self.max_steps;

        var new_x = self.ant_x;
        var new_y = self.ant_y;
        // The ant is always on an interior cell, so +/- 1 stays in bounds.
        switch (action) {
            ACTION_UP => new_y -= 1,
            ACTION_DOWN => new_y += 1,
            ACTION_LEFT => new_x -= 1,
            ACTION_RIGHT => new_x += 1,
            else => {},
        }

        if (self.grid[new_y][new_x] == TILE_WALL) {
            return .{ -10.0, out_of_steps };
        }

        // Exploration shaping: small bonus for new cells, penalty for loops.
        const base_cost: f32 = -0.1;
        const move_reward = if (self.visited[new_y][new_x])
            base_cost - 0.5
        else
            base_cost + 0.2;

        self.ant_x = new_x;
        self.ant_y = new_y;
        self.visited[new_y][new_x] = true;
        self.updateGrid();

        if (new_x == self.food_x and new_y == self.food_y) {
            return .{ 100.0, true };
        }
        if (out_of_steps) {
            return .{ -10.0, true };
        }
        return .{ move_reward, false };
    }

    /// Builds the 10-element observation vector for the current state.
    ///
    /// Elements 0..8 are the 3x3 neighbourhood centred on the ant, row by row
    /// from `(x-1, y-1)` to `(x+1, y+1)`, encoded by `cellSignal`. Element 9 is
    /// the food scent at the ant's position (see `getScent`).
    ///
    /// The slice is allocated with `allocator`; the caller owns it and must
    /// free it. Returns an error only if the allocation fails.
    pub fn getObservation(self: *World, allocator: Allocator) ![]f32 {
        const obs = try allocator.alloc(f32, 10);

        const ax: i32 = @intCast(self.ant_x);
        const ay: i32 = @intCast(self.ant_y);

        var idx: usize = 0;
        var dy: i32 = -1;
        while (dy <= 1) : (dy += 1) {
            var dx: i32 = -1;
            while (dx <= 1) : (dx += 1) {
                obs[idx] = self.cellSignal(ax + dx, ay + dy);
                idx += 1;
            }
        }

        // `idx` is 9 here: the slot right after the neighbourhood.
        obs[idx] = self.getScent(self.ant_x, self.ant_y);
        return obs;
    }

    /// Encodes one cell for the observation vector: `-1` for a wall or any
    /// coordinate outside the grid, `1` for food, `-0.5` for an already
    /// visited cell (this includes the ant's own cell), `0` otherwise.
    fn cellSignal(self: *const World, px_i: i32, py_i: i32) f32 {
        if (px_i < 0 or py_i < 0 or px_i >= GRID_SIZE or py_i >= GRID_SIZE) {
            return -1.0;
        }
        const px: usize = @intCast(px_i);
        const py: usize = @intCast(py_i);

        const content = self.grid[py][px];
        if (content == TILE_WALL) return -1.0;
        if (content == TILE_FOOD) return 1.0;
        return if (self.visited[py][px]) -0.5 else 0.0;
    }

    /// Food scent at `(x, y)`: `1 / (d + 1)` where `d` is the Manhattan
    /// distance to the food, so the value is `1` on the food cell and decays
    /// with distance.
    fn getScent(self: *const World, x: usize, y: usize) f32 {
        // Unsigned absolute differences, ordered to avoid underflow.
        const dx = if (x > self.food_x) x - self.food_x else self.food_x - x;
        const dy = if (y > self.food_y) y - self.food_y else self.food_y - y;
        const dist = dx + dy;
        return 1.0 / (@as(f32, @floatFromInt(dist)) + 1.0);
    }
};
|