// AI_Zig/src/env.zig
//
// 179 lines
// 4.9 KiB
// Zig

const std = @import("std");
const Allocator = std.mem.Allocator;
/// Side length of the square grid, border cells included. `World.reset` marks
/// every cell with x or y equal to 0 or `GRID_SIZE - 1` as wall, so the
/// playable interior spans 1..GRID_SIZE-2 on each axis.
pub const GRID_SIZE = 8;
// Tile codes stored in `World.grid`.
/// Interior cell holding neither the food nor the ant.
pub const TILE_EMPTY = 0;
/// Border cell; `World.step` refuses to move the ant onto it.
pub const TILE_WALL = 1;
/// Cell that currently holds the food.
pub const TILE_FOOD = 2;
/// Cell that currently holds the ant (written after the food, so it wins when both coincide).
pub const TILE_ANT = 3;
// Action codes accepted by `World.step`.
/// Move one cell towards smaller y.
pub const ACTION_UP = 0;
/// Move one cell towards larger y.
pub const ACTION_DOWN = 1;
/// Move one cell towards smaller x.
pub const ACTION_LEFT = 2;
/// Move one cell towards larger x.
pub const ACTION_RIGHT = 3;
/// Plain record describing an ant by position and liveness.
/// NOTE(review): nothing in the visible part of this file references this
/// type; `World` tracks its ant through its own `ant_x` / `ant_y` fields.
pub const Ant = struct {
    /// Horizontal coordinate.
    x: usize,
    /// Vertical coordinate.
    y: usize,
    /// Liveness flag.
    alive: bool,
};
/// Grid-world environment in which one ant searches for one food tile.
///
/// The outermost ring of the `GRID_SIZE` x `GRID_SIZE` grid is wall. The ant
/// and the food are always placed inside that ring, and `step` never moves
/// the ant onto a wall, so the ant's coordinates stay within
/// 1..GRID_SIZE-2 for the lifetime of the world.
pub const World = struct {
    /// Tile codes (`TILE_*`), indexed as `grid[y][x]`.
    grid: [GRID_SIZE][GRID_SIZE]u8,
    /// `visited[y][x]` is true once the ant has stood on that cell this episode.
    visited: [GRID_SIZE][GRID_SIZE]bool,
    ant_x: usize,
    ant_y: usize,
    food_x: usize,
    food_y: usize,
    /// Number of `step` calls since the last `reset`.
    steps: usize,
    /// Step budget; once `steps` reaches it the episode is reported as done.
    max_steps: usize,
    /// Source of randomness for ant and food placement.
    prng: std.Random.DefaultPrng,

    /// Length of the slice returned by `getObservation`:
    /// nine cells of the 3x3 neighborhood plus one scent value.
    pub const OBSERVATION_LEN = 10;

    /// Creates a world seeded with `seed` and immediately starts an episode.
    pub fn init(seed: u64) World {
        var world = World{
            // Both boards are completely overwritten by `reset` below
            // before anything reads them.
            .grid = undefined,
            .visited = undefined,
            .ant_x = 0,
            .ant_y = 0,
            .food_x = 0,
            .food_y = 0,
            .steps = 0,
            .max_steps = 100,
            .prng = std.Random.DefaultPrng.init(seed),
        };
        world.reset();
        return world;
    }

    /// Starts a new episode: rebuilds the walls, clears the visited map,
    /// places the ant and the food on distinct random interior cells and
    /// zeroes the step counter.
    pub fn reset(self: *World) void {
        const random = self.prng.random();

        for (0..GRID_SIZE) |y| {
            for (0..GRID_SIZE) |x| {
                const on_border = x == 0 or y == 0 or x == GRID_SIZE - 1 or y == GRID_SIZE - 1;
                self.grid[y][x] = if (on_border) TILE_WALL else TILE_EMPTY;
                self.visited[y][x] = false;
            }
        }

        // The order of the random draws (ant x, ant y, then food x, food y)
        // determines what a given seed produces; keep it stable.
        self.ant_x = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
        self.ant_y = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
        self.visited[self.ant_y][self.ant_x] = true;

        // Redraw the food until it does not share the ant's cell.
        while (true) {
            self.food_x = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
            self.food_y = random.intRangeAtMost(usize, 1, GRID_SIZE - 2);
            if (self.food_x != self.ant_x or self.food_y != self.ant_y) break;
        }

        self.steps = 0;
        self.updateGrid();
    }

    /// Repaints the interior of `grid` from the current ant and food
    /// positions. The ant is written last, so it wins if both share a cell.
    fn updateGrid(self: *World) void {
        for (1..GRID_SIZE - 1) |y| {
            for (1..GRID_SIZE - 1) |x| {
                self.grid[y][x] = TILE_EMPTY;
            }
        }
        self.grid[self.food_y][self.food_x] = TILE_FOOD;
        self.grid[self.ant_y][self.ant_x] = TILE_ANT;
    }

    /// Advances the episode by one action and returns `.{ reward, done }`.
    ///
    /// `action` is one of the `ACTION_*` codes. Any other value leaves the
    /// ant where it is; because its current cell is already visited, that
    /// is scored like a revisit.
    ///
    /// Rewards:
    ///   * moving into a wall: -10.0, the ant stays in place;
    ///   * reaching the food: 100.0, episode done;
    ///   * exhausting `max_steps`: -10.0, episode done;
    ///   * otherwise -0.1 per move, plus 0.2 for a fresh cell or minus 0.5
    ///     for a cell that was already visited.
    ///
    /// `done` is true whenever the step budget has been used up, including
    /// on a wall collision, so an agent that keeps walking into walls
    /// cannot keep the episode running forever.
    pub fn step(self: *World, action: usize) struct { f32, bool } {
        self.steps += 1;
        const out_of_steps = self.steps >= self.max_steps;

        var new_x = self.ant_x;
        var new_y = self.ant_y;
        // The ant is always on an interior cell (>= 1), so the unsigned
        // decrements below cannot underflow.
        switch (action) {
            ACTION_UP => new_y -= 1,
            ACTION_DOWN => new_y += 1,
            ACTION_LEFT => new_x -= 1,
            ACTION_RIGHT => new_x += 1,
            else => {},
        }

        if (self.grid[new_y][new_x] == TILE_WALL) {
            // Blocked: the position is unchanged, but the step still counts
            // against the budget.
            return .{ -10.0, out_of_steps };
        }

        var move_reward: f32 = -0.1;
        if (self.visited[new_y][new_x]) {
            move_reward -= 0.5;
        } else {
            move_reward += 0.2;
        }

        self.ant_x = new_x;
        self.ant_y = new_y;
        self.visited[new_y][new_x] = true;
        self.updateGrid();

        if (new_x == self.food_x and new_y == self.food_y) {
            return .{ 100.0, true };
        }
        if (out_of_steps) {
            return .{ -10.0, true };
        }
        return .{ move_reward, false };
    }

    /// Builds the agent's observation: `OBSERVATION_LEN` floats.
    ///
    /// Entries 0..8 describe the 3x3 neighborhood centred on the ant in
    /// row-major order (y outer, x inner): -1.0 for a wall or an
    /// out-of-grid cell, 1.0 for food, -0.5 for a visited cell, 0.0
    /// otherwise. The last entry is the scent at the ant's position, see
    /// `getScent`.
    ///
    /// Caller owns the returned slice and must free it with `allocator`.
    /// Fails only if the allocation fails.
    pub fn getObservation(self: *const World, allocator: Allocator) ![]f32 {
        const obs = try allocator.alloc(f32, OBSERVATION_LEN);

        var idx: usize = 0;
        for (0..3) |row| {
            for (0..3) |col| {
                // Neighbor coordinates shifted by +1 so the arithmetic
                // stays unsigned; the real cell is (shifted - 1).
                const shifted_y = self.ant_y + row;
                const shifted_x = self.ant_x + col;
                const outside = shifted_y == 0 or shifted_y > GRID_SIZE or
                    shifted_x == 0 or shifted_x > GRID_SIZE;
                obs[idx] = if (outside) -1.0 else self.cellSignal(shifted_x - 1, shifted_y - 1);
                idx += 1;
            }
        }

        obs[OBSERVATION_LEN - 1] = self.getScent(self.ant_x, self.ant_y);
        return obs;
    }

    /// Encodes a single in-bounds cell for `getObservation`.
    fn cellSignal(self: *const World, x: usize, y: usize) f32 {
        const tile = self.grid[y][x];
        if (tile == TILE_WALL) return -1.0;
        if (tile == TILE_FOOD) return 1.0;
        return if (self.visited[y][x]) -0.5 else 0.0;
    }

    /// Returns `1 / (d + 1)`, where `d` is the Manhattan distance from
    /// `(x, y)` to the food; this is exactly 1.0 on the food cell.
    fn getScent(self: *const World, x: usize, y: usize) f32 {
        const dx = if (x > self.food_x) x - self.food_x else self.food_x - x;
        const dy = if (y > self.food_y) y - self.food_y else self.food_y - y;
        const dist = dx + dy;
        return 1.0 / (@as(f32, @floatFromInt(dist)) + 1.0);
    }
};