DEV Community

Discussion on: Help Me Understand This Vectorized Logic

Ryan Palo

I've got an article in the works that walks through this in more detail, but I wanted to post my solution in case somebody ran into the same or similar issue.

The main problem was that reshape was the wrong method to use. It gave me the right dimensions, but it jumbled up the individual numbers and didn't keep the "records" together.

After some experimenting in a REPL with simpler cases, I discovered that what I really wanted was swapaxes. It keeps the numbers in the correct order, but lets you pivot an array into other dimensions (e.g. roll your 2D array into a plane in the third dimension).
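A small REPL-style sketch of the difference, using a made-up 2x3 array (the names `records`, `reshaped`, and `swapped` are just for illustration, not from my actual code):

```python
import numpy as np

# Two hypothetical "records" with three features each
records = np.array([[1, 2, 3],
                    [4, 5, 6]])

# reshape reads the values in flat order (1, 2, 3, 4, 5, 6) and pours them
# into the new shape, so a record's values get split across rows
reshaped = records.reshape((3, 2))

# swapaxes exchanges the two axes, so each column is one complete record
swapped = np.swapaxes(records, 0, 1)

print(reshaped.tolist())  # [[1, 2], [3, 4], [5, 6]]
print(swapped.tolist())   # [[1, 4], [2, 5], [3, 6]]
```

Both results have shape (3, 2), but only the swapped one still has 1, 2, 3 and 4, 5, 6 lined up as intact records.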

So what I ended up with is:

import numpy as np


def vectorized_distance(vec1, vec2):
    # Euclidean distance, summing squared differences over axis 1 (the features)
    return np.sqrt(np.sum((vec2 - vec1)**2, axis=1))


def nearest_neighbor_classify(data, neighbors, labels):
    # Reshape data so that broadcasting works right
    data_rows, data_cols = data.shape
    data = data.reshape((data_rows, data_cols, 1))
    neighbor_rows, neighbor_cols = neighbors.shape
    flipped_neighbors = np.swapaxes(
        neighbors.reshape((neighbor_rows, neighbor_cols, 1)),
        0, 2)

    # Now data is (n x m x 1) and flipped_neighbors is (1 x m x k), where n
    # and k are the number of data points and neighbors. Broadcasting produces
    # an (n x m x k) array, and `np.sum` squashes axis 1, so we get an
    # (n x k) point-to-point distance matrix

    distances = vectorized_distance(data, flipped_neighbors)

    # The index of the smallest value for each row is the index of the prediction
    closest_neighbor_indices = distances.argmin(axis=1)

    return labels[closest_neighbor_indices]
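To sanity-check the shapes, here is a self-contained sketch of the same broadcasting idea on made-up data. It uses `np.newaxis` and `.T` instead of the reshape/swapaxes calls, which should give the same (n x m x 1) and (1 x m x k) layouts but is not the code above verbatim. `pairwise_distances` and the sample arrays are hypothetical names chosen for the example.

```python
import numpy as np


def pairwise_distances(data, neighbors):
    # data: (n, m) -> (n, m, 1)
    # neighbors: (k, m) -> transposed to (m, k) -> (1, m, k)
    diff = neighbors.T[np.newaxis, :, :] - data[:, :, np.newaxis]
    # Broadcasting gives (n, m, k); summing over axis 1 collapses the
    # feature axis into an (n, k) distance matrix
    return np.sqrt(np.sum(diff ** 2, axis=1))


# Made-up points: three labeled neighbors and two points to classify
neighbors = np.array([[0.0, 0.0], [10.0, 10.0], [0.0, 10.0]])
labels = np.array(["a", "b", "c"])
data = np.array([[1.0, 1.0], [9.0, 8.0]])

distances = pairwise_distances(data, neighbors)
predictions = labels[distances.argmin(axis=1)]

print(distances.shape)       # (2, 3)
print(predictions.tolist())  # ['a', 'b']
```

Each row of `distances` belongs to one data point and each column to one neighbor, so `argmin(axis=1)` picks the closest neighbor per data point.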
 
Evan Oman

Ah, is this a column-major vs. row-major ordering issue?

Ryan Palo

Yeah, or maybe the multidimensional version of that, although I tried numpy’s different ordering strategies and none of them seemed to work quite right.
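A quick check on a made-up 2x3 array, assuming the "ordering strategies" in question are the `order="C"` and `order="F"` options of `reshape` (the array and variable names here are only illustrative):

```python
import numpy as np

neighbors = np.arange(6).reshape((2, 3))  # two records, three features

# Target layout: shape (1, 3, 2), where record i sits at target[0, :, i]
target = np.swapaxes(neighbors.reshape((2, 3, 1)), 0, 2)

# Reshape straight to the target shape with each ordering option
c_order = neighbors.reshape((1, 3, 2), order="C")
f_order = neighbors.reshape((1, 3, 2), order="F")

print(np.array_equal(c_order, target))  # False
print(np.array_equal(f_order, target))  # False
```

For this example, at least, neither ordering reproduces the swapaxes result, even though all three arrays have the same shape.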