#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

template <typename scalar_t>
__global__ void adam_upd_cuda_kernel(
    scalar_t* __restrict__ param,
    const scalar_t* __restrict__ grad,
    scalar_t* __restrict__ exp_avg,
    scalar_t* __restrict__ exp_avg_sq,
    const size_t N,
    const float step_size, const float beta1, const float beta2, const float eps) {

  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index<N) {
    exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index];
    exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index];
    param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps);
  }
}

template <typename scalar_t>
__global__ void masked_adam_upd_cuda_kernel(
    scalar_t* __restrict__ param,
    const scalar_t* __restrict__ grad,
    scalar_t* __restrict__ exp_avg,
    scalar_t* __restrict__ exp_avg_sq,
    const size_t N,
    const float step_size, const float beta1, const float beta2, const float eps) {

  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index<N && grad[index]!=0) {
    exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index];
    exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index];
    param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps);
  }
}

template <typename scalar_t>
__global__ void adam_upd_with_perlr_cuda_kernel(
    scalar_t* __restrict__ param,
    const scalar_t* __restrict__ grad,
    scalar_t* __restrict__ exp_avg,
    scalar_t* __restrict__ exp_avg_sq,
    scalar_t* __restrict__ perlr,
    const size_t N,
    const float step_size, const float beta1, const float beta2, const float eps) {

  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index<N) {
    exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index];
    exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index];
    param[index] -= step_size * perlr[index] * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps);
  }
}

void adam_upd_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    const int step, const float beta1, const float beta2, const float lr, const float eps) {

  const size_t N = param.numel();

  const int threads = 256;
  const int blocks = (N + threads - 1) / threads;

  const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));

  AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_cuda", ([&] {
    adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>(
        param.data<scalar_t>(),
        grad.data<scalar_t>(),
        exp_avg.data<scalar_t>(),
        exp_avg_sq.data<scalar_t>(),
        N, step_size, beta1, beta2, eps);
  }));
}

void masked_adam_upd_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    const int step, const float beta1, const float beta2, const float lr, const float eps) {

  const size_t N = param.numel();

  const int threads = 256;
  const int blocks = (N + threads - 1) / threads;

  const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));

  AT_DISPATCH_FLOATING_TYPES(param.type(), "masked_adam_upd_cuda", ([&] {
    masked_adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>(
        param.data<scalar_t>(),
        grad.data<scalar_t>(),
        exp_avg.data<scalar_t>(),
        exp_avg_sq.data<scalar_t>(),
        N, step_size, beta1, beta2, eps);
  }));
}

void adam_upd_with_perlr_cuda(
    torch::Tensor param,
    torch::Tensor grad,
    torch::Tensor exp_avg,
    torch::Tensor exp_avg_sq,
    torch::Tensor perlr,
    const int step, const float beta1, const float beta2, const float lr, const float eps) {

  const size_t N = param.numel();

  const int threads = 256;
  const int blocks = (N + threads - 1) / threads;

  const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step));

  AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_with_perlr_cuda", ([&] {
    adam_upd_with_perlr_cuda_kernel<scalar_t><<<blocks, threads>>>(
        param.data<scalar_t>(),
        grad.data<scalar_t>(),
        exp_avg.data<scalar_t>(),
        exp_avg_sq.data<scalar_t>(),
        perlr.data<scalar_t>(),
        N, step_size, beta1, beta2, eps);
  }));
}
