项目地址:https://github.com/emgoz/Neural-network-snake
(来源: Snakes, Neural Networks and Genetic Algorithms)
游戏界面如上图。其中红色是食物,其它颜色都是正在存活的蛇。我们的目的是培育出最聪明的蛇。
为了能够定义这样一只蛇,我们需要定义从环境中收集到的信息。我们可以把整个环境作为一个Bitmap输出。但是在这里,我们使用了蛇的视觉的方法。
(来源: Snakes, Neural Networks and Genetic Algorithms)
如上图所示,蛇以头部朝向为中心,向左右两侧各观察120度(代码中 fieldOfView = 2π/3,两侧合计240度),并把视野均分为16个扇区。每个扇区只记录距离最近的一个物体(墙、自己的身体或食物)及其距离,并按物体类型放入对应的输入位置(另外两个类型的位置为0),于是得到16×3=48个输入变量。
(来源: Snakes, Neural Networks and Genetic Algorithms)
将这48个变量作为输入,再接两个各含16个神经元的隐藏层,输出层有两个神经元。每一层都使用非线性的Logistic(sigmoid)激活函数,两个输出之差决定蛇每一步的转向角度。
每条蛇的DNA由1091个字节构成:其中1090个字节是神经网络的权重(含偏置),最后1个字节决定蛇的颜色。
当一条蛇撞墙、撞到自己或者饿死后,会生成一条新的蛇。这条蛇的基因构建分为三步:选择、交叉、突变。
一条蛇的适应度由它的得分(主要来自吃到的食物)和健康度共同构成。健康度反映这条蛇的饥饿程度:一条一直抢不到食物的蛇,健康度会逐渐下降。
(来源: Snakes, Neural Networks and Genetic Algorithms)
当我们要创建一条新的蛇时,首先以适应度为权重选择出两条蛇,然后基于它们的DNA进行后续的操作。
(来源: Snakes, Neural Networks and Genetic Algorithms)
如上图所示,我们首先在随机位置对DNA进行切片,然后在每个切点处交替地从父母双方取用DNA切片,从而构建出新的DNA。有意思的是,在按位交叉(crossover)的实现中,切点可能落在一个字节内部,从而对权重产生更大的扰动;而本项目实际使用的crossoverBytewise只在字节边界处切分。
(来源: Snakes, Neural Networks and Genetic Algorithms)
最后一步是在把父母基因拷贝到后代时进行突变。这里的核心是如何决定突变的概率:越是初期,蛇的适应度越低,越应该更多地突变,以探索更大的空间。因此突变的概率与当前蛇的最大适应度成反比(代码中为 10 / currentMaxFitness)。
整个项目的架构如上图。位于正中间的是我们的蛇,它从右上角获得了构建自己每个球形节点的能力,从右下角获得了神经网络判断能力,从下方获得了DNA的能力。
左侧中间的GameLoop控制了游戏的运行,从上方构建了世界,然后从右方把蛇导入。最后左下方的游戏主程序负责启动整个项目。
让我们分别查看这些核心的代码:
package main;
import gameEngine.GameLoop;
import helpers.KeyboardListener;
import javax.swing.JFrame;
/**
 * Top-level window of the simulation: hosts a {@link GameLoop} component and feeds it
 * keyboard input through a {@link KeyboardListener} registered on the frame.
 */
public class MainWindow extends JFrame {
    /**
     * Entry point of the game. The window is constructed on the Swing event dispatch
     * thread, because Swing components must be created and accessed only from that thread.
     */
    public static void main(String[] args) {
        javax.swing.SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                new MainWindow();
            }
        });
    }

    /**
     * Builds the JFrame-based UI: a maximized window containing the game component.
     */
    public MainWindow() {
        setDefaultCloseOperation(EXIT_ON_CLOSE);
        setSize(1000, 600); // initial (non-maximized) size
        setExtendedState(MAXIMIZED_BOTH);
        setTitle("Neural Net Snake Genetic Algorithm");
        KeyboardListener keyb = new KeyboardListener();
        addKeyListener(keyb);
        add(new GameLoop(keyb));
        setVisible(true);
    }
}
package gameEngine;
import java.awt.Color;
import java.awt.Graphics;
import java.util.LinkedList;
import java.util.concurrent.Semaphore;
/**
 * The game world: its dimensions, the food pellets ("nibbles") and a global clock.
 *
 * <p>Changes to the nibble list made through {@link #newNibble}, {@link #removeNibbles}
 * and {@link #reset} are serialized by a one-permit semaphore.
 */
public class World {
    /** World size in pixels; overwritten on every {@link #update}. */
    public int height, width;
    /** Simulated time in ms, advanced by {@link GameLoop#UPDATEPERIOD} per update. */
    public long clock;
    /** Upper bound on the number of nibbles present at once. */
    public int maxNibbles = 20;
    private Semaphore nibbleProtect = new Semaphore(1);
    private LinkedList<PhysicalCircle> nibbles = new LinkedList<PhysicalCircle>();

    /**
     * Spawns up to {@code n} nibbles at random positions with a small random velocity,
     * without exceeding {@link #maxNibbles} in total.
     *
     * @param n number of nibbles to try to add
     */
    public void newNibble(int n) {
        // acquireUninterruptibly: the permit is always held when we reach release(),
        // unlike the swallowed-InterruptedException pattern, which could release a permit
        // it never acquired and break mutual exclusion.
        nibbleProtect.acquireUninterruptibly();
        try {
            for (int i = 0; i < n && nibbles.size() < maxNibbles; i++) {
                PhysicalCircle nibble = new PhysicalCircle(0, 0, GameLoop.globalCircleRadius);
                nibble.x = Math.random() * (width - 2 * nibble.rad) + nibble.rad;
                nibble.y = Math.random() * (height - 2 * nibble.rad) + nibble.rad;
                nibble.vx = 2 * (Math.random() - .5);
                nibble.vy = 2 * (Math.random() - .5);
                nibble.t = 0;
                nibbles.add(nibble);
            }
        } finally {
            nibbleProtect.release();
        }
    }

    /**
     * @return the live internal nibble list (not a copy)
     */
    public LinkedList<PhysicalCircle> getNibbles() {
        return nibbles;
    }

    /**
     * Score value of a nibble: {@code (int)(5 + 8 * min(exp(-(t - 800) / 2000), 1))},
     * i.e. 13 while {@code p.t <= 800}, decaying toward 5 afterwards.
     *
     * @param p the nibble
     * @return its integer value
     */
    public int calcValue(PhysicalCircle p) {
        return (int) (5 + (8d * Math.min(Math.exp(-(double) (p.t - 800) / 2000d), 1)));
    }

    /**
     * Adopts the new world size, moves every nibble and bounces it off a border inset by
     * 50 px, then advances the clock by one update period.
     *
     * @param w new width in pixels
     * @param h new height in pixels
     */
    public void update(int w, int h) {
        this.width = w;
        this.height = h;
        for (PhysicalCircle p : nibbles) {
            p.updatePosition();
            p.collideWall(50, 50, w - 50, h - 50);
        }
        clock += GameLoop.UPDATEPERIOD;
    }

    /**
     * Draws every nibble as a filled red circle.
     *
     * @param g target graphics context
     */
    public void draw(Graphics g) {
        g.setColor(Color.RED);
        for (PhysicalCircle nibble : nibbles) {
            g.fillOval((int) (nibble.x - nibble.rad), (int) (nibble.y - nibble.rad), (int) (2 * nibble.rad + 1), (int) (2 * nibble.rad + 1));
        }
    }

    /**
     * Removes the given nibbles (matched by identity, since PhysicalCircle does not
     * override equals).
     *
     * @param rem nibbles to remove
     */
    public void removeNibbles(LinkedList<PhysicalCircle> rem) {
        nibbleProtect.acquireUninterruptibly();
        try {
            for (PhysicalCircle p : rem) {
                nibbles.remove(p);
            }
        } finally {
            nibbleProtect.release();
        }
    }

    /**
     * Removes all nibbles and resets the clock to zero.
     */
    public void reset() {
        nibbleProtect.acquireUninterruptibly();
        try {
            nibbles.clear();
        } finally {
            nibbleProtect.release();
        }
        clock = 0;
    }
}
package gameEngine;
import java.awt.Point;
public class PhysicalCircle {
/*
* 创建一个球形,拥有位置、速度、半径等特征
*/
public double x;
public double y;
public double vx = 0;
public double vy = 0;
public double rad;
public long t = 0;
/**
* @return point
*/
public Point toPoint() {
return new Point((int) x, (int) y);
}
/**
*
* @param x
* @param y
* @param rad 半径
*/
public PhysicalCircle(double x, double y, double rad) {
this.x = x;
this.y = y;
this.rad = rad;
}
/**
*
* @param p 初始化用的点
*/
public PhysicalCircle(Point p) {
this.x = p.x;
this.y = p.y;
}
/**
* 计算和墙壁的碰撞
*
* @param xmin left border
* @param ymin top border
* @param xmax right border
* @param ymax bottom border
*/
public void collideWall(double xmin, double ymin, double xmax, double ymax) {
if (x - rad < xmin) {
x = xmin + rad;
vx = -vx * .9;
}
if (x + rad > xmax) {
x = xmax - rad;
vx = -vx * .9;
}
if (y - rad < ymin) {
y = ymin + rad;
vy = -vy * .9;
}
if (y + rad > ymax) {
y = ymax - rad;
vy = -vy * .9;
}
}
/**
* 限制速度
*
* @param maxspeed 最高速度
* @param fadeout 衰减因子
*/
public void constrainSpeed(double maxspeed, double fadeout) {
vx *= fadeout;
vy *= fadeout;
if (Math.abs(vx) < .001)
vx = 0;
if (Math.abs(vy) < .001)
vy = 0;
if (vx > maxspeed)
vx = maxspeed;
if (vx < -maxspeed)
vx = -maxspeed;
if (vy > maxspeed)
vy = maxspeed;
if (vy < -maxspeed)
vy = -maxspeed;
}
/**
* 基于速度更新位置
*/
public void updatePosition() {
x += vx;
y += vy;
}
/**
* 判断球形的相互碰撞
*
* @param o 另一个球形
*/
public void collideStatic(PhysicalCircle o) {
if (this == o)
return;
double s = this.rad + o.rad;
double d = Math.sqrt((this.x - o.x) * (this.x - o.x) + (this.y - o.y) * (this.y - o.y));
double a = Math.atan2(this.y - o.y, this.x - o.x);
if (d < s) {
this.x = o.x + s * Math.cos(a);
this.y = o.y + s * Math.sin(a);
}
}
/**
* 判断和另一个球形的碰撞,并且弹开
*
* @param o 另一个球形
* @param speed 限速
*/
public void collideBouncy(PhysicalCircle o, double speed) {
if (this == o)
return;
double s = this.rad + o.rad;
double d = Math.sqrt((this.x - o.x) * (this.x - o.x) + (this.y - o.y) * (this.y - o.y));
double a = Math.atan2(this.y - o.y, this.x - o.x);
if (d < s) {
this.x = o.x + s * Math.cos(a);
this.y = o.y + s * Math.sin(a);
this.vx -= (o.x - this.x) * 2 / d * speed / 5;
this.vy -= (o.y - this.y) * 2 / d * speed / 5;
}
}
/**
* 让球形跟着跑
*
* @param o 另一个球
*/
public void followBouncy(PhysicalCircle o) {
if (this == o)
return;
double s = this.rad + o.rad;
double a = Math.atan2(this.y - o.y, this.x - o.x);
this.vx += (o.x + s * Math.cos(a) - this.x) / s / 32;
this.vy += (o.y + s * Math.sin(a) - this.y) / s / 32;
this.x += (o.x + s * Math.cos(a) - this.x) / s * 24 + o.vx * .24;
this.y += (o.y + s * Math.sin(a) - this.y) / s * 24 + o.vy * .24;
}
/**
* 让球形挨着跑
*
* @param o 另一个球
*/
public void followStatic(PhysicalCircle o) {
if (this == o)
return;
double s = this.rad + o.rad;
double a = Math.atan2(this.y - o.y, this.x - o.x);
this.x = (o.x + s * Math.cos(a));
this.y = (o.y + s * Math.sin(a));
}
/**
* 判断两个球是否足够接近
*
* @param o 另一个球
* @param thresholdDistance 最近距离
* @return 是否足够接近
*/
public boolean isColliding(PhysicalCircle o, double thresholdDistance) {
double d = Math.sqrt((this.x - o.x) * (this.x - o.x) + (this.y - o.y) * (this.y - o.y));
double s = this.rad + o.rad;
return d < s + thresholdDistance;
}
public double getAbsoluteVelocity() {
return Math.sqrt(vx * vx + vy * vy);
}
public double getDistanceTo(PhysicalCircle o) {
return Math.sqrt((this.x - o.x) * (this.x - o.x) + (this.y - o.y) * (this.y - o.y)) - (this.rad - o.rad) / 2;
}
public double getAngleTo(PhysicalCircle o) {
return Math.atan2(o.y - this.y, o.x - this.x);
}
}
package gameEngine;
import genetics.DNA;
import helpers.DoubleMath;
import java.awt.Color;
import java.awt.Graphics;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import neuralNetwork.NeuralNet;
import neuralNetwork.Stage;
/**
 * A snake agent: a chain of {@link PhysicalCircle} segments steered by a neural network.
 * The network weights and the snake's colour are decoded from its {@link DNA}.
 */
public class Snake {
    // Movement constants
    public static final double maximumForwardSpeed = 5;
    public static final double maximumAngularSpeed = Math.PI / 32d;
    public static final double wallCollisionThreshold = 4;
    // Vision constants: objects within fieldOfView radians on either side of the heading are seen.
    public static final double maximumSightDistance = 600;
    public static final double fieldOfView = Math.PI * 2 / 3;
    // Neural network constants: FOVDIVISIONS sectors per side, 3 object types per sector.
    public static final int FOVDIVISIONS = 8;
    public static final int FIRSTSTAGESIZE = FOVDIVISIONS * 2 * 3;
    public static final int stageSizes[] = new int[] { FIRSTSTAGESIZE, 16, 16, 2 };
    public static final boolean isNNSymmetric = false;
    // Scoring constants
    public static final double nibblebonus = 20;
    public static final int healthbonus = 10; // health gained per nibble eaten
    public static final double healthdecrement = .02; // health lost per update
    // Optional features
    public final boolean displayCuteEyes = false; // easter egg: draw eyes on the head
    public final boolean snakeInertia = false; // true: segments use followBouncy instead of followStatic
    // Per-snake state
    public ArrayList<PhysicalCircle> snakeSegments = new ArrayList<PhysicalCircle>(100);
    public DNA dna;
    public NeuralNet brainNet;
    public double age = 0;
    public double angle; // heading in radians; normalized to [0, 2π) on every update
    public double score;
    public boolean isDead;
    public float hue;
    public double deathFade = 180; // draw alpha; decreases by 0.6 per update once dead
    public double health;

    /**
     * Creates a one-segment snake at a random position inside the walls, heading toward
     * the centre of the world.
     *
     * @param dna
     *            genome to use; if null, random DNA is generated
     * @param world
     *            the world; its size bounds the spawn position
     */
    public Snake(DNA dna, World world) {
        double x = Math.random() * (world.width - 2 * wallCollisionThreshold - 2 * GameLoop.globalCircleRadius) + wallCollisionThreshold
                + GameLoop.globalCircleRadius;
        double y = Math.random() * (world.height - 2 * wallCollisionThreshold - 2 * GameLoop.globalCircleRadius) + wallCollisionThreshold
                + GameLoop.globalCircleRadius;
        // One byte per network coefficient plus a trailing byte for the colour.
        int dnalength = NeuralNet.calcNumberOfCoeffs(stageSizes, isNNSymmetric) + 1;
        if (dna == null) {
            this.dna = new DNA(true, dnalength);
        } else {
            this.dna = dna;
        }
        snakeSegments.clear();
        snakeSegments.add(new PhysicalCircle(x, y, GameLoop.globalCircleRadius));
        this.angle = Math.atan2(world.height / 2 - y, world.width / 2 - x);
        // Build the brain and load its weights from the DNA.
        brainNet = new NeuralNet(stageSizes);
        reloadFromDNA();
        score = 0;
        deathFade = 180;
        isDead = false;
        health = healthbonus * 3 / 2;
        age = 0;
    }

    /**
     * Loads the network coefficients from the DNA, and takes the hue from its last byte.
     */
    public void reloadFromDNA() {
        if (isNNSymmetric)
            brainNet.loadCoeffsSymmetrical(this.dna.data);
        else
            brainNet.loadCoeffs(this.dna.data);
        this.hue = (float) this.dna.data[this.dna.data.length - 1] / 256f;
    }

    /**
     * Advances the snake by one step. The snake steers with its brain, checks for wall and
     * self collisions, moves, eats overlapping nibbles and updates its health.
     *
     * @param world
     *            the game world
     * @return false if the snake died during this step, true otherwise (including when it
     *         was already dead and is only fading out)
     */
    public boolean update(World world) {
        if (isDead) {
            deathFade -= .6;
            return true;
        }
        age += .1;
        // Longer snakes move and turn more slowly.
        double slowdown = 49d / (48d + snakeSegments.size());
        PhysicalCircle head = snakeSegments.get(0);
        // Steering: the neural network proposes an angular offset.
        double angleIncrement = brain(world);
        angle += slowdown * angleIncrement;
        angle = DoubleMath.doubleModulo(angle, Math.PI * 2);
        // Wall collision
        if (head.x - head.rad < wallCollisionThreshold
                || head.x + head.rad > world.width - wallCollisionThreshold
                || head.y - head.rad < wallCollisionThreshold
                || head.y + head.rad > world.height - wallCollisionThreshold) {
            die();
        }
        // Main movement: the head gets a fresh velocity and every segment follows its predecessor.
        head.vx = maximumForwardSpeed * slowdown * Math.cos(angle);
        head.vy = maximumForwardSpeed * slowdown * Math.sin(angle);
        PhysicalCircle previous = head;
        for (int i = 0; i < snakeSegments.size(); i++) {
            PhysicalCircle c = snakeSegments.get(i);
            if (snakeInertia) {
                c.followBouncy(previous);
            } else {
                c.followStatic(previous);
            }
            c.updatePosition();
            for (int j = 0; j < i; j++) {
                c.collideStatic(snakeSegments.get(j));
            }
            previous = c;
            // Segment 1 always touches the head, so self-collision checks start at index 2.
            if (i > 1 && head.isColliding(c, 0)) {
                die();
                break;
            }
        }
        if (isDead) {
            // A snake that died this step must not eat or gain score afterwards.
            return false;
        }
        // Eating
        LinkedList<PhysicalCircle> nibblesToRemove = new LinkedList<PhysicalCircle>();
        int nibbleEatCount = 0;
        for (PhysicalCircle nibble : world.getNibbles()) {
            if (head.isColliding(nibble, -10)) {
                score += world.calcValue(nibble);
                PhysicalCircle tail = snakeSegments.get(snakeSegments.size() - 1);
                snakeSegments.add(new PhysicalCircle(tail.x, tail.y, nibble.rad));
                nibblesToRemove.add(nibble);
                nibbleEatCount++;
            }
        }
        score += nibbleEatCount * nibblebonus;
        world.newNibble(nibbleEatCount);
        world.removeNibbles(nibblesToRemove);
        // Health: food refills it (capped at 3 * healthbonus), time drains it.
        health += nibbleEatCount * healthbonus;
        if (health > 3 * healthbonus)
            health = 3 * healthbonus;
        health -= healthdecrement;
        if (health <= 0) {
            die();
        }
        return !isDead;
    }

    /**
     * Marks the snake dead and halves its score. Idempotent, so a snake hitting several
     * death conditions in the same step (e.g. a wall corner) is penalized only once.
     */
    private void die() {
        if (!isDead) {
            isDead = true;
            score /= 2;
        }
    }

    /**
     * Fitness function.
     *
     * @return score plus a quarter of the current health
     */
    public double getFitness() {
        return score + health / 4;
    }

    /**
     * What the snake perceives in one vision sector: the nearest distance and its object type.
     */
    public class Thing {
        public double distance = maximumSightDistance;
        public int type = 0;
        // wall = 0
        // self = 1
        // food = 2
    }

    /**
     * Runs the brain on the current surroundings.
     *
     * @param world
     *            the game world
     * @return the angular offset, clamped to [-maximumAngularSpeed, maximumAngularSpeed]
     */
    public double brain(World world) {
        // Every sector starts out "seeing" a wall at maximum distance.
        Thing input[] = new Thing[FOVDIVISIONS * 2];
        for (int i = 0; i < FOVDIVISIONS * 2; i++)
            input[i] = new Thing();
        // Look for food
        input = updateVisualInput(input, world.getNibbles(), 2);
        // Look at own body
        input = updateVisualInput(input, snakeSegments, 1);
        // Look for walls: the border is sampled as a ring of tiny circles.
        int step = (int) (maximumSightDistance * Math.sin(fieldOfView / (FOVDIVISIONS * 1d))) / 20;
        LinkedList<PhysicalCircle> walls = new LinkedList<PhysicalCircle>();
        for (int x = 0; x < world.width; x += step) {
            walls.add(new PhysicalCircle(x, 0, 1));
            walls.add(new PhysicalCircle(x, world.height, 1));
        }
        for (int y = 0; y < world.height; y += step) {
            walls.add(new PhysicalCircle(0, y, 1));
            walls.add(new PhysicalCircle(world.width, y, 1));
        }
        input = updateVisualInput(input, walls, 0);
        // Encode as network input: closeness in [0, signalMultiplier] at a type-dependent slot.
        double stageA[] = new double[FIRSTSTAGESIZE]; // zero-initialized
        if (isNNSymmetric) {
            for (int i = 0; i < FOVDIVISIONS; i++) {
                stageA[input[i].type * FOVDIVISIONS + i] = Stage.signalMultiplier * (maximumSightDistance - input[i].distance) / maximumSightDistance;
                stageA[FIRSTSTAGESIZE - 1 - (input[i + FOVDIVISIONS].type * FOVDIVISIONS + i)] = Stage.signalMultiplier
                        * (maximumSightDistance - input[i + FOVDIVISIONS].distance) / maximumSightDistance;
            }
        } else {
            for (int i = 0; i < FOVDIVISIONS; i++) {
                stageA[input[i].type * FOVDIVISIONS * 2 + i] = Stage.signalMultiplier * (maximumSightDistance - input[i].distance) / maximumSightDistance;
                stageA[input[i + FOVDIVISIONS].type * FOVDIVISIONS * 2 + FOVDIVISIONS * 2 - 1 - i] = Stage.signalMultiplier
                        * (maximumSightDistance - input[i + FOVDIVISIONS].distance) / maximumSightDistance;
            }
        }
        double output[] = brainNet.calc(stageA);
        double delta = output[0] - output[1];
        double angleIncrement = 10 * maximumAngularSpeed / Stage.signalMultiplier * delta;
        if (angleIncrement > maximumAngularSpeed)
            angleIncrement = maximumAngularSpeed;
        if (angleIncrement < -maximumAngularSpeed)
            angleIncrement = -maximumAngularSpeed;
        return angleIncrement;
    }

    /**
     * Records, for each vision sector, the nearest of {@code objects} if it is closer than
     * what the sector already holds.
     *
     * @param input
     *            per-sector perception, updated in place
     * @param objects
     *            objects to test
     * @param type
     *            object type: 0 = wall, 1 = self, 2 = food
     * @return {@code input}
     */
    private Thing[] updateVisualInput(Thing input[], List<PhysicalCircle> objects, int type) {
        PhysicalCircle head = snakeSegments.get(0);
        for (PhysicalCircle n : objects) {
            if (head == n)
                continue;
            double a = DoubleMath.signedDoubleModulo(head.getAngleTo(n) - angle, Math.PI * 2);
            double d = head.getDistanceTo(n);
            int sector;
            if (a >= 0 && a < fieldOfView) {
                sector = (int) (a * FOVDIVISIONS / fieldOfView); // sectors 0 .. FOVDIVISIONS-1
            } else if (a <= 0 && -a < fieldOfView) {
                sector = (int) (-a * FOVDIVISIONS / fieldOfView) + FOVDIVISIONS; // FOVDIVISIONS .. 2*FOVDIVISIONS-1
            } else {
                continue; // outside the field of view
            }
            if (d < input[sector].distance) {
                input[sector].distance = d;
                input[sector].type = type;
            }
        }
        return input;
    }

    /**
     * Draws the snake, with colour saturation fading along the body and alpha = deathFade.
     *
     * @param g
     *            target graphics context
     */
    public void draw(Graphics g) {
        // Color requires alpha in [0, 255].
        int alpha = Math.max(0, Math.min(255, (int) deathFade));
        for (int i = 0; i < snakeSegments.size(); i++) {
            Color c = new Color(Color.HSBtoRGB(hue, 1 - (float) i / ((float) snakeSegments.size() + 1f), 1));
            g.setColor(new Color(c.getRed(), c.getGreen(), c.getBlue(), alpha));
            PhysicalCircle p = snakeSegments.get(i);
            g.fillOval((int) (p.x - p.rad), (int) (p.y - p.rad), (int) (2 * p.rad + 1), (int) (2 * p.rad + 1));
        }
        // Eyes, placed perpendicular to the head's velocity (undefined when it is not moving).
        PhysicalCircle p = snakeSegments.get(0);
        double speed = p.getAbsoluteVelocity();
        if (displayCuteEyes && speed > 0) {
            double dist = p.rad / 2.3;
            double ox = p.vy * dist / speed;
            double oy = p.vx * dist / speed;
            double size = p.rad / 3.5;
            g.setColor(new Color(255, 255, 255, alpha));
            g.fillOval((int) (p.x + ox - size), (int) (p.y - oy - size), (int) (size * 2 + 1), (int) (size * 2 + 1));
            g.fillOval((int) (p.x - ox - size), (int) (p.y + oy - size), (int) (size * 2 + 1), (int) (size * 2 + 1));
            size = p.rad / 6;
            g.setColor(new Color(0, 0, 0, alpha));
            g.fillOval((int) (p.x + ox - size), (int) (p.y - oy - size), (int) (size * 2 + 1), (int) (size * 2 + 1));
            g.fillOval((int) (p.x - ox - size), (int) (p.y + oy - size), (int) (size * 2 + 1), (int) (size * 2 + 1));
        }
    }
}
package gameEngine;
import genetics.DNA;
import helpers.KeyboardListener;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import javax.swing.JComponent;
/**
 * Swing component that owns the simulation: the world, the snake population, the genetic
 * algorithm that breeds replacements, and the statistics overlay. A background thread
 * advances the simulation roughly every {@link #UPDATEPERIOD} ms while holding the
 * {@code snakes} lock. {@link #paintComponent} takes the same lock when drawing.
 */
public class GameLoop extends JComponent {
    // Update period in ms
    public static final long UPDATEPERIOD = 8;
    public double per = UPDATEPERIOD; // wall-clock duration of the last simulation step, in ms
    // Constants
    public static final int globalCircleRadius = 20;
    public static final int numSnakes = 8;
    public static final int numNibbles = 4;
    // Genetic algorithm parameters
    public static double mutationrate = .02;
    public double currentGeneration = 0;
    // Snakes and world
    public World world = new World();
    public LinkedList<Snake> snakes = new LinkedList<Snake>();
    public LinkedList<Snake> backupSnakes = new LinkedList<Snake>();
    // Best genome seen so far
    public DNA bestDna = null;
    public double bestscore = 0;
    // Statistics
    public LinkedList<Double> fitnessTimeline = new LinkedList<Double>();
    public double currentMaxFitness = 0;
    // UI state
    public boolean singleSnakeModeActive = false;
    public boolean displayStatisticsActive = false;
    public boolean simulationPaused = false;

    /**
     * Creates the component and starts the simulation thread.
     *
     * @param keyb keyboard state polled once per simulation step
     */
    public GameLoop(final KeyboardListener keyb) {
        world.height = 200;
        world.width = 300;
        new Thread(new Runnable() {
            private long simulationLastMillis;
            private long statisticsLastMillis;

            @Override
            public void run() {
                simulationLastMillis = System.currentTimeMillis() + 100;
                statisticsLastMillis = 0;
                while (true) {
                    if (System.currentTimeMillis() - simulationLastMillis > UPDATEPERIOD) {
                        synchronized (snakes) {
                            long currentTime = System.currentTimeMillis();
                            // Keyboard control
                            char keyCode = (char) keyb.getKey();
                            switch (keyCode) {
                            case ' ': // show a single snake built from the best DNA
                                if (!singleSnakeModeActive) {
                                    singleSnakeModeActive = true;
                                    displayStatisticsActive = false;
                                    backupSnakes.clear();
                                    backupSnakes.addAll(snakes);
                                    snakes.clear();
                                    snakes.add(new Snake(bestDna, world));
                                }
                                break;
                            case 'A': // pause
                                simulationPaused = true;
                                break;
                            case 'B': // resume
                                simulationPaused = false;
                                break;
                            case 'C': // show statistics
                                displayStatisticsActive = true;
                                break;
                            case 'D': // hide statistics
                                displayStatisticsActive = false;
                                break;
                            }
                            // Create the first generation
                            if (snakes.isEmpty()) {
                                firstGeneration(numSnakes);
                                world.newNibble(numNibbles);
                            }
                            // Simulation step
                            if (!simulationPaused) {
                                int deadCount = 0;
                                world.update(getWidth(), getHeight());
                                synchronized (fitnessTimeline) {
                                    // Once per simulated second, record the max fitness seen in that window.
                                    if (world.clock - statisticsLastMillis > 1000 && !singleSnakeModeActive) {
                                        fitnessTimeline.addLast(currentMaxFitness);
                                        currentMaxFitness = 0;
                                        if (fitnessTimeline.size() >= world.width / 2) {
                                            fitnessTimeline.removeFirst();
                                        }
                                        statisticsLastMillis = world.clock;
                                    }
                                }
                                for (Snake s : snakes) {
                                    if (!s.update(world)) {
                                        deadCount++;
                                    }
                                    if (s.getFitness() > currentMaxFitness)
                                        currentMaxFitness = s.getFitness();
                                    if (s.getFitness() > bestscore) {
                                        bestscore = s.getFitness();
                                        bestDna = s.dna;
                                    }
                                }
                                if (deadCount > 0 && singleSnakeModeActive) {
                                    // The demo snake died: restore the population.
                                    singleSnakeModeActive = false;
                                    snakes.clear();
                                    snakes.addAll(backupSnakes);
                                } else {
                                    // Replace every snake that died this step.
                                    for (int i = 0; i < deadCount; i++) {
                                        newSnake();
                                        currentGeneration += 1 / (double) numSnakes;
                                    }
                                }
                                // Drop dead snakes whose fade-out has finished.
                                Iterator<Snake> it = snakes.iterator();
                                while (it.hasNext()) {
                                    Snake s = it.next();
                                    if (s.deathFade <= 0) {
                                        it.remove();
                                    }
                                }
                            } else {
                                // While paused, keep evaluating the first snake's brain
                                // (presumably so the network display shows current activations).
                                snakes.get(0).brain(world);
                            }
                            repaint();
                            per = System.currentTimeMillis() - currentTime;
                            simulationLastMillis += UPDATEPERIOD;
                        }
                    } else {
                        // Not yet time for the next step: sleep briefly instead of spinning a core.
                        try {
                            Thread.sleep(1);
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return; // an interrupt ends the simulation thread
                        }
                    }
                }
            }
        }).start();
    }

    /**
     * Replaces the population with {@code n} random snakes and clears the world.
     *
     * @param n
     *            number of snakes
     */
    public void firstGeneration(int n) {
        snakes.clear();
        for (int i = 0; i < n; i++) {
            snakes.add(new Snake(null, world));
        }
        world.reset();
    }

    /**
     * Builds the mating pool for fitness-proportional selection. Each snake appears
     * {@code (int)(fitness * 100 / maxFitness)} times. If no snake has positive fitness,
     * every snake appears exactly once (uniform selection), so the pool is empty only when
     * there are no snakes at all.
     *
     * @return snakes to draw parents from, repeated by weight
     */
    public ArrayList<Snake> makeMatingpool() {
        ArrayList<Snake> matingpool = new ArrayList<Snake>();
        double maxscore = 0;
        for (Snake s : snakes) {
            if (s.getFitness() > maxscore) {
                maxscore = s.getFitness();
            }
        }
        if (maxscore > 0) {
            for (Snake s : snakes) {
                int amount = (int) (s.getFitness() * 100 / maxscore);
                for (int i = 0; i < amount; i++) {
                    matingpool.add(s);
                }
            }
        }
        // Without this fallback, a population with no positive fitness (e.g. all snakes
        // starved at once) yields an empty pool, and newSnake() would throw.
        if (matingpool.isEmpty()) {
            matingpool.addAll(snakes);
        }
        return matingpool;
    }

    /**
     * Adds one snake bred by selection, crossover and mutation. The mutation rate is
     * inversely proportional to the current max fitness.
     */
    public void newSnake() {
        // Math.random() < p is always true for p >= 1, so clamping to 1 keeps the crossover
        // behaviour unchanged and avoids Infinity when currentMaxFitness is 0.
        mutationrate = currentMaxFitness > 0 ? Math.min(1d, 10 / currentMaxFitness) : 1d;
        ArrayList<Snake> matingpool = makeMatingpool();
        if (matingpool.isEmpty()) {
            snakes.add(new Snake(null, world)); // no parents available: start from random DNA
            return;
        }
        int idx1 = (int) (Math.random() * matingpool.size());
        int idx2 = (int) (Math.random() * matingpool.size());
        DNA parentA = matingpool.get(idx1).dna;
        DNA parentB = matingpool.get(idx2).dna;
        snakes.add(new Snake(parentA.crossoverBytewise(parentB, mutationrate), world));
    }

    /**
     * Draws the background, the optional statistics overlay, the neural network (in
     * single-snake mode), the snakes and the food.
     */
    @Override
    protected void paintComponent(Graphics g) {
        super.paintComponent(g);
        // Background
        g.setColor(Color.black);
        g.fillRect(0, 0, getWidth(), getHeight());
        // Statistics
        if (displayStatisticsActive) {
            g.setColor(Color.DARK_GRAY);
            g.setFont(new Font("Arial", 0, 64));
            g.drawString("t = " + Long.toString(world.clock / 1000), 20, 105);
            g.drawString("g = " + Integer.toString((int) currentGeneration), 20, 205);
            g.setFont(new Font("Arial", 0, 32));
            g.drawString("Mut. Prob.: " + String.format("%1$,.3f", mutationrate), 20, 305);
            g.drawString("Max fitness: " + Integer.toString((int) currentMaxFitness), 20, 355);
            // Fitness timeline
            synchronized (fitnessTimeline) {
                if (!fitnessTimeline.isEmpty()) {
                    double last = fitnessTimeline.getFirst();
                    int x = 0;
                    double limit = getHeight();
                    if (limit < bestscore)
                        limit = bestscore;
                    for (Double d : fitnessTimeline) {
                        g.setColor(new Color(0, 1, 0, .5f));
                        g.drawLine(x, (int) (getHeight() - getHeight() * last / limit), x + 2, (int) (getHeight() - getHeight() * d / limit));
                        last = d;
                        x += 2;
                    }
                }
            }
        }
        // The snakes list is mutated by the simulation thread under this lock, so it must
        // be held for getFirst() too (it could otherwise observe the list mid-swap).
        synchronized (snakes) {
            // Neural network
            if (singleSnakeModeActive && !snakes.isEmpty()) {
                snakes.getFirst().brainNet.display(g, 0, world.width, world.height);
            }
            // Snakes and food
            for (Snake s : snakes)
                s.draw(g);
            world.draw(g);
        }
    }
}
package genetics;
import java.util.Arrays;
import java.util.Random;
public class DNA {
/**
* 实现DNA建模、交叉、变异
*/
public Random random = new Random();
public byte data[];
public DNA(boolean random, int size){
data = new byte[size];
for (int i = 0; i < data.length; i++){
data[i] = random?(byte)Math.floor(Math.random()*256d):0;
}
}
/**
* 基于两个DNA生成新的DNA,包括交叉和变异,byte级别
*/
public DNA crossoverNoise(DNA other, double mutationprob){ //按照byte变异
DNA newdna = new DNA(false, data.length);
int numswaps = data.length/10;
int swaps[] = new int[numswaps+1];
for (int i = 0; i < swaps.length-1; i++){
swaps[i] = (int)Math.floor(Math.random()*data.length);
}
swaps[numswaps] = data.length; //save last
Arrays.sort(swaps);
int swapidx = 0;
boolean that = true;
for (int i = 0; i < data.length; i++){
if (i >= swaps[swapidx]){
swapidx++;
that = !that;
}
byte d = 0;
if (that){
d = this.data[i];
}
else {
d = other.data[i];
}
d += (byte)(random.nextGaussian()*mutationprob*256);
newdna.data[i] = d;
}
return newdna;
}
/**
* Gaussian变异
*/
public void mutateNoise(double prob, double mag){
for (int i = 0; i < data.length; i++){
if (Math.random() < prob) data[i] += (byte)(random.nextGaussian()*mag*256);
}
}
/**
* 基于两个DNA生成新的DNA,包括交叉和变异,bit级别
*/
public DNA crossover(DNA other, double mutationprob){
DNA newdna = new DNA(false, data.length);
int numswaps = data.length/8;
int swaps[] = new int[numswaps+1];
for (int i = 0; i < swaps.length-1; i++){
swaps[i] = (int)Math.floor(Math.random()*8*data.length);
}
swaps[numswaps] = 8*data.length;
Arrays.sort(swaps);
int swapidx = 0;
boolean that = true;
for (int i = 0; i < 8*data.length; i++){
if (i >= swaps[swapidx]){
swapidx++;
that = !that;
}
int bit = 0;
if (that){
bit = ((this.data[i/8] >> (i%8)) & 1);
}
else {
bit = ((other.data[i/8] >> (i%8)) & 1);
}
if (Math.random() < mutationprob) bit = 1-bit;
newdna.data[i/8] |= (bit << (i%8));
}
return newdna;
}
/**
* 基于两个DNA生成新的DNA,包括交叉和变异;byte级别交叉,bit级别变异
*/
public DNA crossoverBytewise(DNA other, double mutationprob){
DNA newdna = new DNA(false, data.length);
int numswaps = data.length/8;
int swaps[] = new int[numswaps+1];
for (int i = 0; i < swaps.length-1; i++){
swaps[i] = 8*(int)Math.floor(Math.random()*data.length);
}
swaps[numswaps] = 8*data.length; //save last
Arrays.sort(swaps);
int swapidx = 0;
boolean that = true;
for (int i = 0; i < 8*data.length; i++){
if (i >= swaps[swapidx]){
swapidx++;
that = !that;
}
int bit = 0;
if (that){
bit = ((this.data[i/8] >> (i%8)) & 1);
}
else {
bit = ((other.data[i/8] >> (i%8)) & 1);
}
if (Math.random() < mutationprob) bit = 1-bit;
newdna.data[i/8] |= (bit << (i%8));
}
return newdna;
}
}
package neuralNetwork;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
/**
 * Feed-forward neural network made of {@link Stage}s. Stage 0 holds the inputs; every
 * later stage holds byte coefficients with a trailing bias column.
 */
public class NeuralNet {
    public Stage stages[];

    /**
     * @param stageSizes
     *            number of neurons per stage, e.g. {48,16,16,2}
     */
    public NeuralNet(int stageSizes[]) {
        stages = new Stage[stageSizes.length];
        Stage prev = null;
        for (int i = 0; i < stageSizes.length; i++) {
            stages[i] = new Stage(prev, stageSizes[i]);
            prev = stages[i];
        }
    }

    /**
     * Loads coefficients sequentially (stage by stage, row by row) from {@code coeffs}.
     * Extra trailing bytes are ignored.
     *
     * @param coeffs values in -128 .. +127; must hold at least
     *               {@link #calcNumberOfCoeffs} entries
     */
    public void loadCoeffs(byte coeffs[]) {
        int idx = 0;
        for (int s = 1; s < stages.length; s++) {
            for (int i = 0; i < stages[s].coeffs.length; i++) {
                for (int j = 0; j < stages[s].coeffs[0].length; j++) {
                    stages[s].coeffs[i][j] = coeffs[idx++];
                }
            }
        }
    }

    /**
     * Loads coefficients point-symmetrically: each byte sets {@code coeffs[i][j]} and its
     * mirror {@code coeffs[n-1-i][m-1-j]}. Every non-input stage must have an even size;
     * otherwise an error is printed and loading stops.
     *
     * @param coeffs values in -128 .. +127
     */
    public void loadCoeffsSymmetrical(byte coeffs[]) {
        int idx = 0;
        for (int s = 1; s < stages.length; s++) {
            if (stages[s].coeffs.length % 2 == 1) {
                System.err.println("Symmetrical Net without even sized stages. Bad.");
                return;
            }
            for (int i = 0; i < (stages[s].coeffs.length) / 2; i++) {
                for (int j = 0; j < stages[s].coeffs[0].length; j++) {
                    stages[s].coeffs[i][j] = coeffs[idx];
                    stages[s].coeffs[stages[s].coeffs.length - 1 - i][stages[s].coeffs[0].length - 1 - j] = coeffs[idx++];
                }
            }
        }
    }

    /**
     * Runs a forward pass.
     *
     * @param input values for the first stage
     * @return the output stage's array (live internal array, not a copy)
     */
    public double[] calc(double input[]) {
        for (int i = 0; i < input.length; i++) {
            stages[0].output[i] = input[i];
        }
        for (int i = 1; i < stages.length; i++) {
            stages[i].calc();
        }
        return stages[stages.length - 1].output;
    }

    /**
     * Number of coefficients needed for a network, i.e. the DNA length without extras.
     *
     * @param stageSizes  network layout, e.g. {48,16,16,2}
     * @param symmetrical whether coefficients are loaded symmetrically (halves the count)
     * @return required coefficient count; 0 for fewer than two stages
     */
    public static int calcNumberOfCoeffs(int stageSizes[], boolean symmetrical) {
        int sum = 0;
        if (stageSizes.length < 2)
            return 0;
        for (int i = 1; i < stageSizes.length; i++) {
            if (symmetrical)
                sum += (stageSizes[i] * (stageSizes[i - 1] + 1) + 1) / 2;
            else
                sum += stageSizes[i] * (stageSizes[i - 1] + 1);
        }
        return sum;
    }

    @Override
    public String toString() {
        StringBuilder k = new StringBuilder();
        for (int s = 1; s < stages.length; s++) {
            k.append("\nStage ").append(s).append(": \n").append(stages[s].toString());
        }
        return k.toString();
    }

    /** Clamps a colour component into [0, 1], the range accepted by the float Color API. */
    private static float clamp01(float v) {
        return Math.max(0f, Math.min(1f, v));
    }

    /**
     * Draws the network. Synapses with |weight| < 48 are skipped. The rest are coloured
     * red (negative) or green (positive), with brightness set by the source neuron's
     * activation. Neurons are circles whose brightness is their output.
     *
     * @param g     target graphics context (must be a Graphics2D)
     * @param alpha transparency 0 .. 1 (currently unused)
     * @param w     width
     * @param h     height
     */
    public void display(Graphics g, float alpha, double w, double h) {
        Graphics2D g2 = (Graphics2D) g;
        int d = 20;
        // Synapses
        for (int s = 1; s < stages.length; s++) {
            int x1 = (s) * (int) (w / (stages.length + 1));
            int x2 = (s + 1) * (int) (w / (stages.length + 1));
            for (int i = 0; i < stages[s].coeffs.length; i++) {
                for (int j = 0; j < stages[s].coeffs[0].length - 1; j++) {
                    int c = stages[s].coeffs[i][j];
                    if (Math.abs(c) < 48)
                        continue;
                    g2.setStroke(new BasicStroke(Math.abs(c) * 3 / 129));
                    int y1 = (j + 1) * (int) (h / (stages[s - 1].output.length + 1));
                    int y2 = (i + 1) * (int) (h / (stages[s].output.length + 1));
                    // Input-stage signals can exceed signalMultiplier (e.g. negative sight
                    // distances), and new Color(float, float, float) throws outside [0, 1].
                    float b = clamp01((float) (stages[s - 1].output[j] / Stage.signalMultiplier));
                    if (c < 0)
                        g.setColor(new Color(b, 0, 0));
                    else
                        g.setColor(new Color(0, b, 0));
                    g2.drawLine(x1, y1, x2, y2);
                }
            }
        }
        // Neurons
        for (int s = 0; s < stages.length; s++) {
            int x = (s + 1) * (int) (w / (stages.length + 1));
            d = (int) (h / (stages[s].output.length + 7));
            for (int i = 0; i < stages[s].output.length; i++) {
                int y = (i + 1) * (int) (h / (stages[s].output.length + 1));
                float output = clamp01((float) (stages[s].output[i] / Stage.signalMultiplier * .8 + .2));
                g.setColor(new Color(Color.HSBtoRGB(.6f, 1, output)));
                g.fillOval(x - d / 2, y - d / 2, d, d);
            }
        }
    }
}
package neuralNetwork;
/**
 * One layer of the network. Each neuron computes a weighted sum of the previous
 * layer's outputs plus a bias, then applies a scaled logistic function.
 */
public class Stage {
    /** Output scale: neuron outputs lie in (0, signalMultiplier). */
    public static final double signalMultiplier = .1;
    public Stage prev;
    public double output[];
    /** coeffs[i][j]: weight from input j to neuron i; the last column is the bias. */
    public byte coeffs[][];

    /**
     * @param prev previous stage, or null for the input stage (which gets no coefficients)
     * @param size number of neurons
     */
    public Stage(Stage prev, int size) {
        this.prev = prev;
        this.output = new double[size];
        this.coeffs = (prev == null) ? new byte[0][0] : new byte[size][prev.output.length + 1];
    }

    /**
     * Recomputes this stage's outputs from the previous stage. Does nothing for the input stage.
     */
    public void calc() {
        if (prev == null) {
            return;
        }
        for (int neuron = 0; neuron < coeffs.length; neuron++) {
            int biasIndex = coeffs[0].length - 1;
            byte[] weights = coeffs[neuron];
            double acc = 0;
            for (int in = 0; in < biasIndex; in++) {
                acc += weights[in] * prev.output[in];
            }
            acc += weights[biasIndex] * signalMultiplier;
            output[neuron] = sigmoid(acc);
        }
    }

    /** Logistic function with slope 1/2, scaled to (0, signalMultiplier). */
    public static double sigmoid(double x) {
        return signalMultiplier / (1 + Math.exp(-x / 2d));
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("[");
        for (byte[] row : coeffs) {
            text.append("[");
            for (int j = 0; j < coeffs[0].length; j++) {
                text.append(Byte.toString(row[j])).append(" ");
            }
            text.append("]\n ");
        }
        text.append("]\n");
        return text.toString();
    }
}
package helpers;
/**
 * Floating-point modulo helpers.
 */
public class DoubleMath {
    /**
     * Floored modulo: returns {@code a mod b} in [0, b). Intended for b > 0.
     *
     * <p>Uses {@link Math#floor} rather than an int cast. The former approach
     * ({@code k = (int)(a/b); if (a < 0) k--;}) returned {@code b} instead of 0 for negative
     * exact multiples, and overflowed once {@code |a/b|} exceeded the int range.
     *
     * @param a dividend
     * @param b divisor (positive)
     * @return the remainder in [0, b)
     */
    public static double doubleModulo(double a, double b) {
        double r = a - b * Math.floor(a / b);
        // Rounding can leave r a hair outside [0, b) (e.g. a tiny negative a gives a + b == b).
        if (r >= b)
            r -= b;
        if (r < 0)
            r += b;
        return r;
    }

    /**
     * Signed modulo: returns {@code a mod b} mapped into [-b/2, b/2).
     *
     * @param a dividend
     * @param b divisor (positive)
     * @return the remainder in [-b/2, b/2)
     */
    public static double signedDoubleModulo(double a, double b) {
        double c = doubleModulo(a, b);
        if (c >= b / 2)
            c -= b;
        return c;
    }
}
package helpers;
import java.awt.event.KeyEvent;
import java.awt.event.KeyListener;
/**
 * Remembers the key code of the currently pressed key (0 after release).
 *
 * <p>Key events are delivered on the AWT event thread, while {@link #getKey()} may be
 * polled from another thread (e.g. a simulation loop). The field is therefore
 * {@code volatile} so that updates are visible across threads.
 */
public class KeyboardListener implements KeyListener {
    private volatile int code = 0;

    @Override
    public void keyPressed(KeyEvent e) {
        code = e.getKeyCode();
    }

    @Override
    public void keyReleased(KeyEvent e) {
        code = 0;
    }

    @Override
    public void keyTyped(KeyEvent arg0) {
    }

    /** @return the last pressed key code (a {@code KeyEvent.VK_*} value), or 0 if none is held */
    public int getKey() {
        return code;
    }
}
综上所述,我们学习了一个基于神经网络和遗传算法的人工智能贪吃蛇的实现。接下来可以进一步学习更多的实现方法,以掌握不同人工智能算法之间的异同和取舍。