AI_Zig/src/main.zig
2026-02-03 14:55:50 +01:00

134 lines
4.3 KiB
Zig

const std = @import("std");
const World = @import("env.zig").World;
const env = @import("env.zig"); // Per accedere a costanti come GRID_SIZE
const Network = @import("modular_network.zig").Network;
// --- HYPERPARAMETERS ---
// Discount factor: main() multiplies the best next-state Q-value by this
// before adding it to the reward.
const GAMMA: f32 = 0.9;
// Learning rate handed to net.train() in main().
const LR: f32 = 0.005; // Kept low because the output is linear
// Exploration rate that main() starts from.
const EPSILON_START: f32 = 1.0;
// Exploration floor: main() stops decaying epsilon once it is no longer above this.
const EPSILON_END: f32 = 0.05;
// Amount subtracted from epsilon once per global step in main().
const DECAY_RATE: f32 = 0.0001; // Slower decay, since we run many steps
// Helpers for finding the maximum value and its index (argmax)
/// Returns the largest value in `slice`.
///
/// The running maximum starts at negative infinity rather than a finite
/// sentinel, so every finite input (including values below -1.0e20) yields an
/// element of the slice. NaN entries never compare greater and are skipped.
/// An empty slice (or one containing only NaN) yields negative infinity.
fn maxVal(slice: []const f32) f32 {
    var best: f32 = -std.math.inf(f32);
    for (slice) |v| {
        if (v > best) best = v;
    }
    return best;
}
/// Returns the index of the largest value in `slice`.
///
/// Ties resolve to the lowest index because only a strictly greater value
/// replaces the current best. The running maximum starts at negative infinity
/// rather than a finite sentinel, so inputs below -1.0e20 are still ranked
/// correctly. NaN entries never compare greater and are skipped. Returns 0 for
/// an empty slice.
fn argmax(slice: []const f32) usize {
    var best: f32 = -std.math.inf(f32);
    var best_idx: usize = 0;
    for (slice, 0..) |v, i| {
        if (v > best) {
            best = v;
            best_idx = i;
        }
    }
    return best_idx;
}
// Updated export with multi-ant support
/// Writes a JSON snapshot of the simulation to `file_path`.
///
/// The document contains `grid_size` (from `env.GRID_SIZE`), the `food`
/// position, an `ants` array of `[x,y]` pairs (one per entry in `world.ants`),
/// the `episode` counter and the current `epsilon` (three decimals).
///
/// The whole document is formatted into a 64 KiB stack buffer first; the file
/// is only created (and thereby truncated) once formatting has succeeded, so a
/// formatting failure does not leave an empty state file behind.
///
/// Errors: `error.OutOfMemory` if the document does not fit in the buffer,
/// plus any error from creating or writing the file.
fn exportAntJSON(world: *World, file_path: []const u8, episode: usize, epsilon: f32) !void {
    // Fixed-size scratch space: no heap traffic for a periodic export.
    var buffer: [65536]u8 = undefined;
    var fba = std.heap.FixedBufferAllocator.init(&buffer);
    const allocator = fba.allocator();

    // Build the complete document in a single list so the buffer holds only
    // one copy of it.
    var json = std.ArrayList(u8){};
    defer json.deinit(allocator);

    try std.fmt.format(
        json.writer(allocator),
        "{{\n \"grid_size\": {d},\n \"food\": [{d}, {d}],\n \"ants\": [",
        .{ env.GRID_SIZE, world.food_x, world.food_y },
    );
    // Ant coordinates as [[x,y],[x,y],...]
    for (world.ants, 0..) |ant, i| {
        if (i > 0) try json.appendSlice(allocator, ",");
        try std.fmt.format(json.writer(allocator), "[{d},{d}]", .{ ant.x, ant.y });
    }
    try std.fmt.format(
        json.writer(allocator),
        "],\n \"episode\": {d},\n \"epsilon\": {d:.3}\n}}",
        .{ episode, epsilon },
    );

    // Touch the file only now that the content is known to be complete.
    const file = try std.fs.cwd().createFile(file_path, .{});
    defer file.close();
    try file.writeAll(json.items);
}
/// Entry point: runs an endless training loop in which every ant in the world
/// shares one Q-network.
///
/// Per global step: pheromones evaporate, then each ant observes, picks an
/// action epsilon-greedily, steps, and the network is trained on the one-step
/// target `reward + GAMMA * max(Q(next_obs))` for the chosen action. Epsilon
/// decays by `DECAY_RATE` per step down to `EPSILON_END`. Every 10 steps the
/// world is exported to "ant_state.json" and the loop sleeps 100 ms; every 100
/// steps a progress line is printed.
///
/// The loop never terminates on its own; the process is expected to be
/// stopped externally.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();
    defer _ = gpa.deinit();

    var world = World.init(12345);

    var net = Network.init(allocator);
    defer net.deinit();
    // 15 -> 40 -> 4 network; the last layer's width matches `num_actions` below.
    // NOTE(review): the third and fourth arguments look like a seed and an
    // activation switch - confirm against modular_network.zig.
    try net.addLayer(15, 40, 111, true);
    try net.addLayer(40, 4, 222, false);

    var prng = std.Random.DefaultPrng.init(999);
    const random = prng.random();

    std.debug.print("--- HIVE MIND TRAINING START ---\n", .{});
    std.debug.print("Mappa: {d}x{d} | Formiche: {d}\n", .{ env.GRID_SIZE, env.GRID_SIZE, env.NUM_ANTS });

    // Number of discrete actions an ant can take (width of the output layer).
    const num_actions = 4;

    var global_step: usize = 0;
    var epsilon: f32 = EPSILON_START;
    while (true) {
        world.evaporatePheromones();

        for (0..env.NUM_ANTS) |i| {
            const current_obs = try world.getAntObservation(allocator, i);
            defer allocator.free(current_obs);

            const q_values = net.forward(current_obs);

            // Snapshot the current-state Q-values right away, on the stack.
            // NOTE(review): `q_values` is never freed here, so it presumably
            // points into a buffer owned by the network; if so, the second
            // forward() below would overwrite it. Copying first is correct
            // either way - confirm against modular_network.zig.
            var target_vector: [num_actions]f32 = undefined;
            for (&target_vector, 0..) |*slot, j| slot.* = q_values[j];

            // Epsilon-greedy action selection.
            const action: usize = if (random.float(f32) < epsilon)
                random.intRangeAtMost(usize, 0, num_actions - 1)
            else
                argmax(q_values);

            const result = world.stepAnt(i, action);
            const reward = result[0];

            const next_obs = try world.getAntObservation(allocator, i);
            defer allocator.free(next_obs);
            const next_q_values = net.forward(next_obs);

            // Only the chosen action's entry moves toward the one-step target;
            // the other entries keep the network's own prediction.
            target_vector[action] = reward + GAMMA * maxVal(next_q_values);

            _ = try net.train(current_obs, &target_vector, LR);
        }

        global_step += 1;

        // Linear decay, clamped so epsilon never drops below the floor.
        epsilon = @max(EPSILON_END, epsilon - DECAY_RATE);

        if (global_step % 10 == 0) {
            try exportAntJSON(&world, "ant_state.json", global_step, epsilon);
            if (global_step % 100 == 0) {
                std.debug.print("Step: {d} | Epsilon: {d:.3} | Cibo: [{d},{d}]\r", .{ global_step, epsilon, world.food_x, world.food_y });
            }
            // 100 ms pause after each export.
            std.Thread.sleep(100 * 1_000_000);
        }
    }
}