Skip to content

Machine learning examples #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions examples/common/idxio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python


#######################################################
# Copyright (c) 2024, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################
def reverse_char(b: int) -> int:
b = (b & 0xF0) >> 4 | (b & 0x0F) << 4
b = (b & 0xCC) >> 2 | (b & 0x33) << 2
b = (b & 0xAA) >> 1 | (b & 0x55) << 1
return b


# http://stackoverflow.com/a/9144870/2192361
def reverse(x: int) -> int:
x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1)
x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2)
x = ((x >> 4) & 0x0F0F0F0F) | ((x & 0x0F0F0F0F) << 4)
x = ((x >> 8) & 0x00FF00FF) | ((x & 0x00FF00FF) << 8)
x = ((x >> 16) & 0xFFFF) | ((x & 0xFFFF) << 16)
return x


def read_idx(name: str) -> tuple[list[int], list[float]]:
with open(name, "rb") as f:
# In the C++ version, bytes the size of 4 chars are being read
# May not work properly in machines where a char is not 1 byte
bytes_read = f.read(4)
bytes_read = bytearray(bytes_read)

if bytes_read[2] != 8:
raise RuntimeError("Unsupported data type")

numdims = bytes_read[3]
elemsize = 1

# Read the dimensions
elem = 1
dims = [0] * numdims
for i in range(numdims):
bytes_read = bytearray(f.read(4))

# Big endian to little endian
for j in range(4):
bytes_read[j] = reverse_char(bytes_read[j])
bytes_read_int = int.from_bytes(bytes_read, "little")
dim = reverse(bytes_read_int)

elem = elem * dim
dims[i] = dim

# Read the data
cdata = f.read(elem * elemsize)
cdata_list = list(cdata)
data = [float(cdata_elem) for cdata_elem in cdata_list]

return (dims, data)


if __name__ == "__main__":
# Example usage of reverse_char
byte_value = 0b10101010
reversed_byte = reverse_char(byte_value)
print(f"Original byte: {byte_value:08b}, Reversed byte: {reversed_byte:08b}")

# Example usage of reverse
int_value = 0x12345678
reversed_int = reverse(int_value)
print(f"Original int: {int_value:032b}, Reversed int: {reversed_int:032b}")
216 changes: 216 additions & 0 deletions examples/machine_learning/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#!/usr/bin/env python

#######################################################
# Copyright (c) 2024, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

import sys
import time

from mnist_common import display_results, setup_mnist

import arrayfire as af


def accuracy(predicted: af.Array, target: af.Array) -> float | complex:
"""Calculates the accuracy of the predictions compares to the actual target"""
_, tlabels = af.imax(target, axis=1)
_, plabels = af.imax(predicted, axis=1)
return 100 * af.count(plabels == tlabels) / tlabels.size


def abserr(predicted: af.Array, target: af.Array) -> float | complex:
"""Calculates the mean absolute error (MAE), scaled by 100"""
return 100 * af.sum(af.abs(predicted - target)) / predicted.size


def predict_prob(X: af.Array, Weights: af.Array) -> af.Array:
"""Predict (probability) based on given parameters"""
Z = af.matmul(X, Weights)
return af.sigmoid(Z)


def predict_log_prob(X: af.Array, Weights: af.Array) -> af.Array:
"""# Predict (log probability) based on given parameters"""
return af.log(predict_prob(X, Weights))


def predict_class(X: af.Array, Weights: af.Array) -> af.Array:
"""Give most likely class based on given parameters"""
probs = predict_prob(X, Weights)
_, classes = af.imax(probs, axis=1)
return classes


def cost(Weights: af.Array, X: af.Array, Y: af.Array, lambda_param: float = 1.0) -> tuple[af.Array, af.Array]:
"""Calculate the cost of predictions made with a given set of weights"""
# Number of samples
m = Y.shape[0]

dim0 = Weights.shape[0]
dim1 = Weights.shape[1] if len(Weights.shape) > 1 else 1
dim2 = Weights.shape[2] if len(Weights.shape) > 2 else 1
dim3 = Weights.shape[3] if len(Weights.shape) > 3 else 1
# Make the lambda corresponding to Weights(0) == 0
lambdat = af.constant(lambda_param, (dim0, dim1, dim2, dim3))

# No regularization for bias weights
lambdat[0, :] = 0

# Get the prediction
H = predict_prob(X, Weights)

# Cost of misprediction
Jerr = -1 * af.sum(Y * af.log(H) + (1 - Y) * af.log(1 - H), axis=0)

# Regularization cost
Jreg = 0.5 * af.sum(lambdat * Weights * Weights, axis=0)

# Total cost
J = (Jerr + Jreg) / m

# Find the gradient of cost
D = H - Y
dJ = (af.matmul(X, D, af.MatProp.TRANS) + lambdat * Weights) / m

return J, dJ


def train(
X: af.Array,
Y: af.Array,
alpha: float = 0.1,
lambda_param: float = 1.0,
maxerr: float = 0.01,
maxiter: int = 1000,
verbose: bool = False,
) -> af.Array: # noqa :E501
"""Train a machine learning model using gradient descent to minimize the cost function."""
# Initialize parameters to 0
Weights = af.constant(0, (X.shape[1], Y.shape[1]))

for i in range(maxiter):
# Get the cost and gradient
J, dJ = cost(Weights, X, Y, lambda_param)

err = af.max(af.abs(J))
if err < maxerr: # type: ignore[operator]
print("Iteration {0:4d} Err: {1:4f}".format(i + 1, err)) # type: ignore[str-format]
print("Training converged")
return Weights

if verbose and ((i + 1) % 10 == 0):
print("Iteration {0:4d} Err: {1:4f}".format(i + 1, err)) # type: ignore[str-format]

# Update the parameters via gradient descent
Weights = Weights - alpha * dJ

if verbose:
print("Training stopped after {0:d} iterations".format(maxiter))

return Weights


def benchmark_logistic_regression(train_feats: af.Array, train_targets: af.Array, test_feats: af.Array) -> None:
t0 = time.time()
Weights = train(train_feats, train_targets, 0.1, 1.0, 0.01, 1000)
af.eval(Weights)
af.sync()
t1 = time.time()
dt = t1 - t0
print("Training time: {0:4.4f} s".format(dt))

t0 = time.time()
iters = 100
for i in range(iters):
test_outputs = predict_prob(test_feats, Weights)
af.eval(test_outputs)
af.sync()
t1 = time.time()
dt = t1 - t0
print("Prediction time: {0:4.4f} s".format(dt / iters))


def logit_demo(console: bool, perc: int) -> None:
"""Demo of one vs all logistic regression"""
# Load mnist data
frac = float(perc) / 100.0
mnist_data = setup_mnist(frac, True)
num_classes = mnist_data[0] # noqa: F841
num_train = mnist_data[1]
num_test = mnist_data[2]
train_images = mnist_data[3]
test_images = mnist_data[4]
train_targets = mnist_data[5]
test_targets = mnist_data[6]

# Reshape images into feature vectors
feature_length = int(train_images.size / num_train)
train_feats = af.transpose(af.moddims(train_images, (feature_length, num_train)))

test_feats = af.transpose(af.moddims(test_images, (feature_length, num_test)))

train_targets = af.transpose(train_targets)
test_targets = af.transpose(test_targets)

num_train = train_feats.shape[0]
num_test = test_feats.shape[0]

# Add a bias that is always 1
train_bias = af.constant(1, (num_train, 1))
test_bias = af.constant(1, (num_test, 1))
train_feats = af.join(1, train_bias, train_feats)
test_feats = af.join(1, test_bias, test_feats)

# Train logistic regression parameters
Weights = train(
train_feats,
train_targets,
0.1, # learning rate
1.0, # regularization constant
0.01, # max error
1000, # max iters
True, # verbose mode
) # noqa: E124

af.eval(Weights)
af.sync()

# Predict the results
train_outputs = predict_prob(train_feats, Weights)
test_outputs = predict_prob(test_feats, Weights)

print("Accuracy on training data: {0:2.2f}".format(accuracy(train_outputs, train_targets))) # type: ignore[str-format] # noqa :E501
print("Accuracy on testing data: {0:2.2f}".format(accuracy(test_outputs, test_targets))) # type: ignore[str-format] # noqa :E501
print("Maximum error on testing data: {0:2.2f}".format(abserr(test_outputs, test_targets))) # type: ignore[str-format] # noqa :E501

benchmark_logistic_regression(train_feats, train_targets, test_feats)

if not console:
test_outputs = af.transpose(test_outputs)
# Get 20 random test images
display_results(test_images, test_outputs, af.transpose(test_targets), 20, True)


def main() -> None:
argc = len(sys.argv)

device = int(sys.argv[1]) if argc > 1 else 0
console = sys.argv[2][0] == "-" if argc > 2 else False
perc = int(sys.argv[3]) if argc > 3 else 60

try:
af.set_device(device)
af.info()
logit_demo(console, perc)
except Exception as e:
print("Error: ", str(e))


if __name__ == "__main__":
main()
99 changes: 99 additions & 0 deletions examples/machine_learning/mnist_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python

#######################################################
# Copyright (c) 2024, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################

import os

import arrayfire as af

# sys.path.insert(0, '../common')
from examples.common.idxio import read_idx


def classify(arr: af.Array, k: int, expand_labels: bool) -> str:
ret_str = ""
if expand_labels:
vec = af.cast(arr[:, k], af.f32)
h_vec = vec.to_list()
data = []

for i in range(vec.size):
data.append((h_vec[i], i))

data = sorted(data, key=lambda pair: pair[0], reverse=True) # type: ignore[arg-type,return-value]

ret_str = str(data[0][1])

else:
ret_str = str(int(af.cast(arr[k], af.float32).scalar())) # type: ignore[arg-type]

return ret_str


def setup_mnist(frac: float, expand_labels: bool) -> tuple[int, int, int, af.Array, af.Array, af.Array, af.Array]:
root_path = os.path.dirname(os.path.abspath(__file__))
file_path = root_path + "/../../assets/examples/data/mnist/"
idims, idata = read_idx(file_path + "images-subset")
ldims, ldata = read_idx(file_path + "labels-subset")

idims.reverse()
images = af.Array(idata, af.float32, tuple(idims))

R = af.randu((10000, 1))
cond = R < min(frac, 0.8)
train_indices = af.where(cond)
test_indices = af.where(~cond)

train_images = af.lookup(images, train_indices, axis=2) / 255
test_images = af.lookup(images, test_indices, axis=2) / 255

num_classes = 10
num_train = train_images.shape[2]
num_test = test_images.shape[2]

if expand_labels:
train_labels = af.constant(0, (num_classes, num_train))
test_labels = af.constant(0, (num_classes, num_test))

h_train_idx = train_indices.copy()
h_test_idx = test_indices.copy()

ldata = list(map(int, ldata))

for i in range(num_train):
ldata_ind = ldata[h_train_idx[i].scalar()] # type: ignore[index]
train_labels[ldata_ind, i] = 1 # type: ignore[index]

for i in range(num_test):
ldata_ind = ldata[h_test_idx[i].scalar()] # type: ignore[index]
test_labels[ldata_ind, i] = 1 # type: ignore[index]

else:
labels = af.Array(idata, af.float32, tuple(idims))
train_labels = labels[train_indices]
test_labels = labels[test_indices]

return (num_classes, num_train, num_test, train_images, test_images, train_labels, test_labels)


def display_results(
test_images: af.Array, test_output: af.Array, test_actual: af.Array, num_display: int, expand_labels: bool
) -> None: # noqa: E501
for i in range(num_display):
print("Predicted: ", classify(test_output, i, expand_labels))
print("Actual: ", classify(test_actual, i, expand_labels))

img = af.cast((test_images[:, :, i] > 0.1), af.u8)
flattened_img = af.moddims(img, (img.size,)).to_list()
for j in range(28):
for k in range(28):
print("\u2588" if flattened_img[j * 28 + k] > 0 else " ", end="") # type: ignore[operator]
print()
input()