134 lines
4.3 KiB
Zig
134 lines
4.3 KiB
Zig
const std = @import("std");
|
|
const World = @import("env.zig").World;
|
|
const env = @import("env.zig"); // Per accedere a costanti come GRID_SIZE
|
|
const Network = @import("modular_network.zig").Network;
|
|
|
|
// --- HYPERPARAMETERS ---

// Exploration schedule (epsilon-greedy).

/// Initial exploration probability: actions start out fully random.
const EPSILON_START: f32 = 1.0;
/// Lower bound that epsilon decays towards.
const EPSILON_END: f32 = 0.05;
/// Amount subtracted from epsilon per training step; kept small because training runs for many steps.
const DECAY_RATE: f32 = 0.0001;

// Learning.

/// Discount factor applied to the bootstrapped next-state value.
const GAMMA: f32 = 0.9;
/// Learning rate handed to the network; kept low because the output layer is linear.
const LR: f32 = 0.005;
|
|
|
|
/// Returns the largest value in `slice`.
///
/// The search starts from negative infinity (the identity element for max),
/// so arbitrarily negative finite values and `+inf` are all handled correctly.
/// NaN entries never compare greater and are therefore skipped.
/// An empty slice yields `-inf`.
fn maxVal(slice: []const f32) f32 {
    var best: f32 = -std.math.inf(f32);
    for (slice) |v| {
        if (v > best) best = v;
    }
    return best;
}
|
|
|
|
/// Returns the index of the largest value in `slice`.
///
/// Ties resolve to the lowest index because the comparison is strict (`>`).
/// The search starts from negative infinity, so arbitrarily negative finite
/// values are handled correctly. Returns 0 for an empty slice, or when every
/// entry is `-inf` or NaN.
fn argmax(slice: []const f32) usize {
    var best: f32 = -std.math.inf(f32);
    var best_idx: usize = 0;
    for (slice, 0..) |v, i| {
        if (v > best) {
            best = v;
            best_idx = i;
        }
    }
    return best_idx;
}
|
|
|
|
/// Writes a JSON snapshot of all ants to `file_path`, overwriting any previous file.
///
/// Output shape:
///   { "grid_size": N, "food": [x, y], "ants": [[x, y], ...],
///     "episode": <episode>, "epsilon": <epsilon, 3 decimals> }
///
/// The whole document is formatted into a 64 KiB stack buffer *before* the
/// target file is created. A formatting failure (`error.OutOfMemory` when the
/// ant list outgrows the buffer) therefore leaves the previous snapshot intact
/// instead of truncating it. It also shortens the window in which another
/// process reading the file could observe it empty.
/// Errors: formatting (`OutOfMemory`) and file create/write errors are propagated.
fn exportAntJSON(world: *const World, file_path: []const u8, episode: usize, epsilon: f32) !void {
    // Scratch memory for the whole document; no heap allocation involved.
    var buffer: [65536]u8 = undefined;
    var fba = std.heap.FixedBufferAllocator.init(&buffer);
    const allocator = fba.allocator();

    // Ant positions as a JSON array of pairs: [[x,y],[x,y],...]
    var ants_json = std.ArrayList(u8){};
    defer ants_json.deinit(allocator);
    try ants_json.appendSlice(allocator, "[");
    for (world.ants, 0..) |ant, i| {
        if (i > 0) try ants_json.appendSlice(allocator, ",");
        try std.fmt.format(ants_json.writer(allocator), "[{d},{d}]", .{ ant.x, ant.y });
    }
    try ants_json.appendSlice(allocator, "]");

    // Full document; lives in `buffer`, so it needs no explicit free.
    const json = try std.fmt.allocPrint(allocator, "{{\n \"grid_size\": {d},\n \"food\": [{d}, {d}],\n \"ants\": {s},\n \"episode\": {d},\n \"epsilon\": {d:.3}\n}}", .{ env.GRID_SIZE, world.food_x, world.food_y, ants_json.items, episode, epsilon });

    // Only now truncate/create the target file and write the finished document.
    const file = try std.fs.cwd().createFile(file_path, .{});
    defer file.close();
    try file.writeAll(json);
}
|
|
|
|
/// Entry point: trains a single shared Q-network ("hive mind") online for all ants.
///
/// On each global step:
/// - pheromones evaporate;
/// - every ant takes one epsilon-greedy action;
/// - the network is trained on that ant's observation. The target vector is
///   the network's current prediction with the taken action's entry replaced
///   by the one-step Q-learning target `reward + GAMMA * max Q(next_obs)`.
///
/// Epsilon decays linearly from EPSILON_START and is clamped at EPSILON_END.
/// Every 10 steps a JSON snapshot is written to "ant_state.json" and the loop
/// pauses for 100 ms. Progress is printed every 100 steps.
/// The loop never exits normally, so the deferred cleanups only run if an
/// error propagates out of it.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var world = World.init(12345);
    var net = Network.init(allocator);
    defer net.deinit();

    // Presumably addLayer(inputs, outputs, seed, flag): 15 -> 40 -> 4.
    // The final width must equal `num_actions` below.
    try net.addLayer(15, 40, 111, true);
    try net.addLayer(40, 4, 222, false);

    // Number of discrete actions (one Q-value per action).
    const num_actions = 4;

    var prng = std.Random.DefaultPrng.init(999);
    const random = prng.random();

    std.debug.print("--- HIVE MIND TRAINING START ---\n", .{});
    std.debug.print("Mappa: {d}x{d} | Formiche: {d}\n", .{ env.GRID_SIZE, env.GRID_SIZE, env.NUM_ANTS });

    var global_step: usize = 0;
    var epsilon: f32 = EPSILON_START;

    while (true) {
        world.evaporatePheromones();

        for (0..env.NUM_ANTS) |i| {
            const current_obs = try world.getAntObservation(allocator, i);
            defer allocator.free(current_obs);

            // Copy Q(s, .) right away. If `net.forward` returns a view into the
            // network's internal buffers, the forward pass on `next_obs` below
            // would otherwise overwrite these values before they are used.
            // The copy doubles as the training target (no per-step heap allocation).
            const q_values = net.forward(current_obs);
            var target_vector: [num_actions]f32 = undefined;
            @memcpy(&target_vector, q_values[0..num_actions]);

            // Epsilon-greedy action selection.
            const action: usize = if (random.float(f32) < epsilon)
                random.intRangeAtMost(usize, 0, num_actions - 1)
            else
                argmax(&target_vector);

            const result = world.stepAnt(i, action);
            const reward = result[0];
            // NOTE(review): if `result` also carries a terminal/"done" flag,
            // terminal transitions should not bootstrap from Q(next_obs).
            // Confirm against World.stepAnt.

            const next_obs = try world.getAntObservation(allocator, i);
            defer allocator.free(next_obs);
            const next_q_values = net.forward(next_obs);

            // One-step Q-learning target for the action actually taken; the
            // other entries keep the network's current prediction.
            target_vector[action] = reward + GAMMA * maxVal(next_q_values);

            _ = try net.train(current_obs, &target_vector, LR);
        }

        global_step += 1;

        // Linear decay, clamped so epsilon never undershoots EPSILON_END.
        epsilon = @max(EPSILON_END, epsilon - DECAY_RATE);

        if (global_step % 10 == 0) {
            try exportAntJSON(&world, "ant_state.json", global_step, epsilon);

            if (global_step % 100 == 0) {
                std.debug.print("Step: {d} | Epsilon: {d:.3} | Cibo: [{d},{d}]\r", .{ global_step, epsilon, world.food_x, world.food_y });
            }

            // Pause 100 ms after every snapshot.
            std.Thread.sleep(100 * 1_000_000);
        }
    }
}
|