# 1. Binary classification (two classes, labels +1 / -1)
import numpy as np
import matplotlib.pyplot as plt
import mlpy
# Fix the global NumPy RNG so the synthetic data (and thus the plot) is
# reproducible.
np.random.seed(0)

# Class +1: 200 correlated 2-D Gaussian samples centred at (1, 5).
mean1, cov1, n1 = [1, 5], [[1, 1], [1, 2]], 200  # 200 samples of class 1
x1 = np.random.multivariate_normal(mean1, cov1, n1)
# Builtin ``int`` instead of ``np.int``: that alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
y1 = np.ones(n1, dtype=int)

# Class -1: 300 isotropic 2-D Gaussian samples centred at (2.5, 2.5).
mean2, cov2, n2 = [2.5, 2.5], [[1, 0], [0, 1]], 300  # 300 samples of class -1
x2 = np.random.multivariate_normal(mean2, cov2, n2)
y2 = -np.ones(n2, dtype=int)

# Stack into a single design matrix (500 rows, 2 features) and label vector.
x = np.concatenate((x1, x2), axis=0)  # concatenate the samples
y = np.concatenate((y1, y2))

# Fit mlpy's linear discriminant analysis classifier.
ldac = mlpy.LDAC()
ldac.learn(x, y)
w = ldac.w()  # return the coefficient
b = ldac.bias()  # return the bias

# The plotted separator is the line w[0]*x + w[1]*y + b = 0, solved for y
# over the observed range of the first feature.
xx = np.arange(np.min(x[:, 0]), np.max(x[:, 0]), 0.01)
yy = -(w[0] * xx + b) / w[1]  # separator line

fig = plt.figure(1)  # plot
# 'o' means circle marker, 'b' means blue, and 'r' means red
plot1 = plt.plot(x1[:, 0], x1[:, 1], 'ob', x2[:, 0], x2[:, 1], 'or')
plot2 = plt.plot(xx, yy, '--k')
plt.show()
# 2. Multi-class classification (three classes, labels 0 / 1 / 2)
# No re-seed here: samples continue from whatever global NumPy RNG state
# precedes this section.

# Class 0: 200 correlated 2-D Gaussian samples centred at (1, 25).
mean1, cov1, n1 = [1, 25], [[1, 1], [1, 2]], 200  # 200 samples of class 0
x1 = np.random.multivariate_normal(mean1, cov1, n1)
# Builtin ``int`` instead of ``np.int`` (alias removed in NumPy 1.24).
y1 = np.zeros(n1, dtype=int)

# Class 1: 300 isotropic 2-D Gaussian samples centred at (2.5, 22.5).
mean2, cov2, n2 = [2.5, 22.5], [[1, 0], [0, 1]], 300  # 300 samples of class 1
x2 = np.random.multivariate_normal(mean2, cov2, n2)
y2 = np.ones(n2, dtype=int)

# Class 2: 200 tighter isotropic 2-D Gaussian samples centred at (5, 28).
mean3, cov3, n3 = [5, 28], [[0.5, 0], [0, 0.5]], 200  # 200 samples of class 2
x3 = np.random.multivariate_normal(mean3, cov3, n3)
y3 = np.full(n3, 2, dtype=int)

# Stack into a single design matrix (700 rows, 2 features) and label vector.
x = np.concatenate((x1, x2, x3), axis=0)  # concatenate the samples
y = np.concatenate((y1, y2, y3))

ldac = mlpy.LDAC()
ldac.learn(x, y)
# The indexing below assumes w has one row of 2 coefficients per class and
# b one bias per class -- TODO confirm against the mlpy LDAC docs.
w = ldac.w()
b = ldac.bias()

# Boundary between classes i and j: the points where their linear scores tie,
#   w[i].p + b[i] == w[j].p + b[j]
# which, solved for the second coordinate, gives
#   y = (x * (w[j][0] - w[i][0]) + b[j] - b[i]) / (w[i][1] - w[j][1])
# Each line is drawn over the full x range (not clipped to the region where
# that pair actually decides the classification).
xx = np.arange(np.min(x[:, 0]), np.max(x[:, 0]), 0.01)
yy1 = (xx * (w[1][0] - w[0][0]) + b[1] - b[0]) / (w[0][1] - w[1][1])  # 0 vs 1
yy2 = (xx * (w[2][0] - w[0][0]) + b[2] - b[0]) / (w[0][1] - w[2][1])  # 0 vs 2
yy3 = (xx * (w[2][0] - w[1][0]) + b[2] - b[1]) / (w[1][1] - w[2][1])  # 1 vs 2

# Use a distinct figure number: with an interactive backend the binary
# example's figure 1 may still be open, and reusing it would overlay the
# two plots.
fig = plt.figure(2)  # plot
plot1 = plt.plot(x1[:, 0], x1[:, 1], 'ob', x2[:, 0], x2[:, 1], 'or',
                 x3[:, 0], x3[:, 1], 'og')
plot2 = plt.plot(xx, yy1, '--k')
plot3 = plt.plot(xx, yy2, '--k')
plot4 = plt.plot(xx, yy3, '--k')
plt.show()
