import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from time import perf_counter
from time import time

import numpy as np
import pennylane as qml

# Benchmark: wall-clock time to compute the adjoint-method Jacobian of a
# StronglyEntanglingLayers circuit on lightning.gpu, and on lightning.qubit
# with 1 and with 32 OpenMP threads.
n_wires = 28
n_layers = 200


def run_benchmark(device_name, n_wires=n_wires, n_layers=n_layers):
    """Return the time in milliseconds to build ``device_name`` and differentiate the circuit.

    The timed span covers device creation, random parameter generation and
    the adjoint-method Jacobian of one PauliZ expectation value per wire.
    """
    start = perf_counter()

    dev = qml.device(device_name, wires=n_wires)

    @qml.qnode(dev, diff_method="adjoint")
    def circuit(weights):
        qml.StronglyEntanglingLayers(weights, wires=list(range(n_wires)))
        return [qml.expval(qml.PauliZ(i)) for i in range(n_wires)]

    param_shape = qml.StronglyEntanglingLayers.shape(
        n_layers=n_layers, n_wires=n_wires
    )
    params = np.random.random(param_shape)
    # A plain NumPy array carries no `requires_grad` flag, so name the
    # trainable argument explicitly instead of relying on PennyLane's defaults.
    qml.jacobian(circuit, argnum=0)(params)

    return (perf_counter() - start) * 1000


def run_isolated(device_name, omp_threads=None):
    """Run ``run_benchmark(device_name)`` in a freshly spawned interpreter.

    OpenMP reads OMP_NUM_THREADS only once, when its runtime starts.
    Reassigning ``os.environ`` inside a process that has already run the
    simulator therefore does not change its thread count. A new "spawn"
    process inherits the updated environment before it loads any simulator
    library. When ``omp_threads`` is None the environment is left untouched.
    The previous OMP_NUM_THREADS value is restored afterwards.

    Returns the elapsed time in milliseconds reported by the child process.
    """
    previous = os.environ.get("OMP_NUM_THREADS")
    if omp_threads is not None:
        os.environ["OMP_NUM_THREADS"] = str(omp_threads)
    try:
        ctx = multiprocessing.get_context("spawn")
        with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
            return pool.submit(run_benchmark, device_name).result()
    finally:
        if previous is None:
            os.environ.pop("OMP_NUM_THREADS", None)
        else:
            os.environ["OMP_NUM_THREADS"] = previous


def main():
    """Time each device configuration in its own process and print the results."""
    runs = [
        ("lightning.gpu:", "lightning.gpu", None),
        ("lightning.qubit with 1 thread:", "lightning.qubit", 1),
        ("lightning.qubit 32 threads:", "lightning.qubit", 32),
    ]
    for label, device_name, omp_threads in runs:
        elapsed_ms = run_isolated(device_name, omp_threads)
        print(label)
        print(f"{round(elapsed_ms, 3)} ms")


# Required: "spawn" workers re-import this module and must not re-run the
# benchmarks themselves.
if __name__ == "__main__":
    main()