1
0
Fork 0
QuaPy/distribution_matching/tmp/compile_test.py

28 lines
430 B
Python

from time import time
from sklearn.metrics.pairwise import rbf_kernel
import torch
import numpy as np
def rbf(X, Y):
    """Return the matrix product of ``X`` with the transpose of ``Y``.

    NOTE(review): despite its name, this function does not evaluate an RBF
    (Gaussian) kernel; it only computes ``X @ Y.T``. Confirm whether that is
    intentional for this benchmark.

    Args:
        X: Left operand; must support ``@`` with ``Y.T``.
        Y: Right operand; must expose a ``.T`` attribute.

    Returns:
        The result of ``X @ Y.T``.
    """
    Y_transposed = Y.T
    return X @ Y_transposed
@torch.compile
def rbf_comp(X, Y):
    """Return ``X @ Y.T``, with the function wrapped by ``torch.compile``.

    NOTE(review): like the uncompiled sibling, this computes a plain matrix
    product rather than an RBF kernel, despite the name.

    Args:
        X: Left operand; must support ``@`` with ``Y.T``.
        Y: Right operand; must expose a ``.T`` attribute.

    Returns:
        The result of ``X @ Y.T``.
    """
    Y_transposed = Y.T
    return X @ Y_transposed
def measure_time(X, Y, func):
    """Call ``func(X, Y)`` once, print how long it took, and return the duration.

    Args:
        X: First positional argument forwarded to ``func``.
        Y: Second positional argument forwarded to ``func``.
        func: Callable invoked as ``func(X, Y)``; its return value is discarded.

    Returns:
        The elapsed time of the single call, in seconds, as a float.
    """
    # perf_counter is the clock intended for measuring short durations: it is
    # monotonic and uses the highest available resolution, unlike time.time(),
    # which follows the (adjustable) system wall clock. Imported locally so the
    # function does not depend on a change to the file's import block.
    from time import perf_counter

    started = perf_counter()
    func(X, Y)
    elapsed = perf_counter() - started

    print(f'took {elapsed:.3}s')
    return elapsed
# Two 1000x100 inputs drawn from NumPy's global random generator (no seed is
# set here, so the values differ between runs).
X = np.random.rand(1000, 100)
Y = np.random.rand(1000, 100)

# Time the plain function first, then the torch.compile-wrapped one.
# NOTE(review): the first call to a torch.compile'd function typically
# triggers compilation, so the second timing likely includes that one-off
# cost — confirm this is what the comparison is meant to show.
for _kernel in (rbf, rbf_comp):
    measure_time(X, Y, _kernel)