#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

/*
   Points sampling helper functions.
 */
template <typename scalar_t>
__global__ void infer_t_minmax_cuda_kernel(
        scalar_t* __restrict__ rays_o,
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ xyz_min,
        scalar_t* __restrict__ xyz_max,
        const float near, const float far, const int n_rays,
        scalar_t* __restrict__ t_min,
        scalar_t* __restrict__ t_max) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int offset = i_ray * 3;
    float vx = ((rays_d[offset  ]==0) ? 1e-6 : rays_d[offset  ]);
    float vy = ((rays_d[offset+1]==0) ? 1e-6 : rays_d[offset+1]);
    float vz = ((rays_d[offset+2]==0) ? 1e-6 : rays_d[offset+2]);
    float ax = (xyz_max[0] - rays_o[offset  ]) / vx;
    float ay = (xyz_max[1] - rays_o[offset+1]) / vy;
    float az = (xyz_max[2] - rays_o[offset+2]) / vz;
    float bx = (xyz_min[0] - rays_o[offset  ]) / vx;
    float by = (xyz_min[1] - rays_o[offset+1]) / vy;
    float bz = (xyz_min[2] - rays_o[offset+2]) / vz;
    t_min[i_ray] = max(min(max(max(min(ax, bx), min(ay, by)), min(az, bz)), far), near);
    t_max[i_ray] = max(min(min(min(max(ax, bx), max(ay, by)), max(az, bz)), far), near);
  }
}

template <typename scalar_t>
__global__ void infer_n_samples_cuda_kernel(
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ t_min,
        scalar_t* __restrict__ t_max,
        const float stepdist,
        const int n_rays,
        int64_t* __restrict__ n_samples) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int offset = i_ray * 3;
    const float rnorm = sqrt(
            rays_d[offset  ]*rays_d[offset  ] +\
            rays_d[offset+1]*rays_d[offset+1] +\
            rays_d[offset+2]*rays_d[offset+2]);
    // at least 1 point for easier implementation in the later sample_pts_on_rays_cuda
    n_samples[i_ray] = max(ceil((t_max[i_ray]-t_min[i_ray]) * rnorm / stepdist), 1.);
  }
}

template <typename scalar_t>
__global__ void infer_ray_start_dir_cuda_kernel(
        scalar_t* __restrict__ rays_o,
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ t_min,
        const int n_rays,
        scalar_t* __restrict__ rays_start,
        scalar_t* __restrict__ rays_dir) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int offset = i_ray * 3;
    const float rnorm = sqrt(
            rays_d[offset  ]*rays_d[offset  ] +\
            rays_d[offset+1]*rays_d[offset+1] +\
            rays_d[offset+2]*rays_d[offset+2]);
    rays_start[offset  ] = rays_o[offset  ] + rays_d[offset  ] * t_min[i_ray];
    rays_start[offset+1] = rays_o[offset+1] + rays_d[offset+1] * t_min[i_ray];
    rays_start[offset+2] = rays_o[offset+2] + rays_d[offset+2] * t_min[i_ray];
    rays_dir  [offset  ] = rays_d[offset  ] / rnorm;
    rays_dir  [offset+1] = rays_d[offset+1] / rnorm;
    rays_dir  [offset+2] = rays_d[offset+2] / rnorm;
  }
}


std::vector<torch::Tensor> infer_t_minmax_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
        const float near, const float far) {
  const int n_rays = rays_o.size(0);
  auto t_min = torch::empty({n_rays}, rays_o.options());
  auto t_max = torch::empty({n_rays}, rays_o.options());

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "infer_t_minmax_cuda", ([&] {
    infer_t_minmax_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data<scalar_t>(),
        rays_d.data<scalar_t>(),
        xyz_min.data<scalar_t>(),
        xyz_max.data<scalar_t>(),
        near, far, n_rays,
        t_min.data<scalar_t>(),
        t_max.data<scalar_t>());
  }));

  return {t_min, t_max};
}

torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
  const int n_rays = t_min.size(0);
  auto n_samples = torch::empty({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;
  AT_DISPATCH_FLOATING_TYPES(t_min.type(), "infer_n_samples_cuda", ([&] {
    infer_n_samples_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_d.data<scalar_t>(),
        t_min.data<scalar_t>(),
        t_max.data<scalar_t>(),
        stepdist,
        n_rays,
        n_samples.data<int64_t>());
  }));
  return n_samples;
}

std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
  const int n_rays = rays_o.size(0);
  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;
  auto rays_start = torch::empty_like(rays_o);
  auto rays_dir = torch::empty_like(rays_o);
  AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "infer_ray_start_dir_cuda", ([&] {
    infer_ray_start_dir_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data<scalar_t>(),
        rays_d.data<scalar_t>(),
        t_min.data<scalar_t>(),
        n_rays,
        rays_start.data<scalar_t>(),
        rays_dir.data<scalar_t>());
  }));
  return {rays_start, rays_dir};
}

/*
   Sampling query points on rays.
 */
__global__ void __set_1_at_ray_seg_start(
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ N_steps_cumsum,
        const int n_rays) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(0<idx && idx<n_rays) {
    ray_id[N_steps_cumsum[idx-1]] = 1;
  }
}

__global__ void __set_step_id(
        int64_t* __restrict__ step_id,
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ N_steps_cumsum,
        const int total_len) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if(idx<total_len) {
      const int rid = ray_id[idx];
      step_id[idx] = idx - ((rid!=0) ? N_steps_cumsum[rid-1] : 0);
    }
}

template <typename scalar_t>
__global__ void sample_pts_on_rays_cuda_kernel(
        scalar_t* __restrict__ rays_start,
        scalar_t* __restrict__ rays_dir,
        scalar_t* __restrict__ xyz_min,
        scalar_t* __restrict__ xyz_max,
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ step_id,
        const float stepdist, const int total_len,
        scalar_t* __restrict__ rays_pts,
        bool* __restrict__ mask_outbbox) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<total_len) {
    const int i_ray = ray_id[idx];
    const int i_step = step_id[idx];

    const int offset_p = idx * 3;
    const int offset_r = i_ray * 3;
    const float dist = stepdist * i_step;
    const float px = rays_start[offset_r  ] + rays_dir[offset_r  ] * dist;
    const float py = rays_start[offset_r+1] + rays_dir[offset_r+1] * dist;
    const float pz = rays_start[offset_r+2] + rays_dir[offset_r+2] * dist;
    rays_pts[offset_p  ] = px;
    rays_pts[offset_p+1] = py;
    rays_pts[offset_p+2] = pz;
    mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) | \
                        (xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
  }
}

std::vector<torch::Tensor> sample_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const float near, const float far, const float stepdist) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  // Compute ray-bbox intersection
  auto t_minmax = infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
  auto t_min = t_minmax[0];
  auto t_max = t_minmax[1];

  // Compute the number of points required.
  // Assign ray index and step index to each.
  auto N_steps = infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
  auto N_steps_cumsum = N_steps.cumsum(0);
  const int total_len = N_steps.sum().item<int>();
  auto ray_id = torch::zeros({total_len}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  __set_1_at_ray_seg_start<<<(n_rays+threads-1)/threads, threads>>>(
        ray_id.data<int64_t>(), N_steps_cumsum.data<int64_t>(), n_rays);
  ray_id.cumsum_(0);
  auto step_id = torch::empty({total_len}, ray_id.options());
  __set_step_id<<<(total_len+threads-1)/threads, threads>>>(
        step_id.data<int64_t>(), ray_id.data<int64_t>(), N_steps_cumsum.data<int64_t>(), total_len);

  // Compute the global xyz of each point
  auto rays_start_dir = infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
  auto rays_start = rays_start_dir[0];
  auto rays_dir = rays_start_dir[1];

  auto rays_pts = torch::empty({total_len, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
  auto mask_outbbox = torch::empty({total_len}, torch::dtype(torch::kBool).device(torch::kCUDA));

  AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "sample_pts_on_rays_cuda", ([&] {
    sample_pts_on_rays_cuda_kernel<scalar_t><<<(total_len+threads-1)/threads, threads>>>(
        rays_start.data<scalar_t>(),
        rays_dir.data<scalar_t>(),
        xyz_min.data<scalar_t>(),
        xyz_max.data<scalar_t>(),
        ray_id.data<int64_t>(),
        step_id.data<int64_t>(),
        stepdist, total_len,
        rays_pts.data<scalar_t>(),
        mask_outbbox.data<bool>());
  }));
  return {rays_pts, mask_outbbox, ray_id, step_id, N_steps, t_min, t_max};
}

template <typename scalar_t>
__global__ void sample_ndc_pts_on_rays_cuda_kernel(
        const scalar_t* __restrict__ rays_o,
        const scalar_t* __restrict__ rays_d,
        const scalar_t* __restrict__ xyz_min,
        const scalar_t* __restrict__ xyz_max,
        const int N_samples, const int n_rays,
        scalar_t* __restrict__ rays_pts,
        bool* __restrict__ mask_outbbox) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<N_samples*n_rays) {
    const int i_ray = idx / N_samples;
    const int i_step = idx % N_samples;

    const int offset_p = idx * 3;
    const int offset_r = i_ray * 3;
    const float dist = ((float)i_step) / (N_samples-1);
    const float px = rays_o[offset_r  ] + rays_d[offset_r  ] * dist;
    const float py = rays_o[offset_r+1] + rays_d[offset_r+1] * dist;
    const float pz = rays_o[offset_r+2] + rays_d[offset_r+2] * dist;
    rays_pts[offset_p  ] = px;
    rays_pts[offset_p+1] = py;
    rays_pts[offset_p+2] = pz;
    mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) | \
                        (xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
  }
}

std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const int N_samples) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  auto rays_pts = torch::empty({n_rays, N_samples, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
  auto mask_outbbox = torch::empty({n_rays, N_samples}, torch::dtype(torch::kBool).device(torch::kCUDA));

  AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "sample_ndc_pts_on_rays_cuda", ([&] {
    sample_ndc_pts_on_rays_cuda_kernel<scalar_t><<<(n_rays*N_samples+threads-1)/threads, threads>>>(
        rays_o.data<scalar_t>(),
        rays_d.data<scalar_t>(),
        xyz_min.data<scalar_t>(),
        xyz_max.data<scalar_t>(),
        N_samples, n_rays,
        rays_pts.data<scalar_t>(),
        mask_outbbox.data<bool>());
  }));
  return {rays_pts, mask_outbbox};
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t norm3(const scalar_t x, const scalar_t y, const scalar_t z) {
  return sqrt(x*x + y*y + z*z);
}

template <typename scalar_t>
__global__ void sample_bg_pts_on_rays_cuda_kernel(
        const scalar_t* __restrict__ rays_o,
        const scalar_t* __restrict__ rays_d,
        const scalar_t* __restrict__ t_max,
        const float bg_preserve,
        const int N_samples, const int n_rays,
        scalar_t* __restrict__ rays_pts) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<N_samples*n_rays) {
    const int i_ray = idx / N_samples;
    const int i_step = idx % N_samples;

    const int offset_p = idx * 3;
    const int offset_r = i_ray * 3;
    /* Original pytorch implementation
    ori_t_outer = t_max[:,None] - 1 + 1 / torch.linspace(1, 0, N_outer+1)[:-1]
    ori_ray_pts_outer = (rays_o[:,None,:] + rays_d[:,None,:] * ori_t_outer[:,:,None]).reshape(-1,3)
    t_outer = ori_ray_pts_outer.norm(dim=-1)
    R_outer = t_outer / ori_ray_pts_outer.abs().amax(1)
    # r = R * R / t
    o2i_p = R_outer.pow(2) / t_outer.pow(2) * (1-self.bg_preserve) + R_outer / t_outer * self.bg_preserve
    ray_pts_outer = (ori_ray_pts_outer * o2i_p[:,None]).reshape(len(rays_o), -1, 3)
   */
    const float t_inner = t_max[i_ray];
    const float ori_t_outer = t_inner - 1. + 1. / (1. - ((float)i_step) / N_samples);
    const float ori_ray_pts_x =  rays_o[offset_r  ] + rays_d[offset_r  ] * ori_t_outer;
    const float ori_ray_pts_y =  rays_o[offset_r+1] + rays_d[offset_r+1] * ori_t_outer;
    const float ori_ray_pts_z =  rays_o[offset_r+2] + rays_d[offset_r+2] * ori_t_outer;
    const float t_outer = norm3(ori_ray_pts_x, ori_ray_pts_y, ori_ray_pts_z);
    const float ori_ray_pts_m = max(abs(ori_ray_pts_x), max(abs(ori_ray_pts_y), abs(ori_ray_pts_z)));
    const float R_outer = t_outer / ori_ray_pts_m;
    const float o2i_p = R_outer*R_outer / (t_outer*t_outer) * (1.-bg_preserve) + R_outer / t_outer * bg_preserve;
    const float px = ori_ray_pts_x * o2i_p;
    const float py = ori_ray_pts_y * o2i_p;
    const float pz = ori_ray_pts_z * o2i_p;
    rays_pts[offset_p  ] = px;
    rays_pts[offset_p+1] = py;
    rays_pts[offset_p+2] = pz;
  }
}

torch::Tensor sample_bg_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
        const float bg_preserve, const int N_samples) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  auto rays_pts = torch::empty({n_rays, N_samples, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));

  AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "sample_bg_pts_on_rays_cuda", ([&] {
    sample_bg_pts_on_rays_cuda_kernel<scalar_t><<<(n_rays*N_samples+threads-1)/threads, threads>>>(
        rays_o.data<scalar_t>(),
        rays_d.data<scalar_t>(),
        t_max.data<scalar_t>(),
        bg_preserve,
        N_samples, n_rays,
        rays_pts.data<scalar_t>());
  }));
  return rays_pts;
}


/*
   MaskCache lookup to skip known freespace.
 */

static __forceinline__ __device__
bool check_xyz(int i, int j, int k, int sz_i, int sz_j, int sz_k) {
  return (0 <= i) && (i < sz_i) && (0 <= j) && (j < sz_j) && (0 <= k) && (k < sz_k);
}


template <typename scalar_t>
__global__ void maskcache_lookup_cuda_kernel(
    bool* __restrict__ world,
    scalar_t* __restrict__ xyz,
    bool* __restrict__ out,
    scalar_t* __restrict__ xyz2ijk_scale,
    scalar_t* __restrict__ xyz2ijk_shift,
    const int sz_i, const int sz_j, const int sz_k, const int n_pts) {

  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    const int offset = i_pt * 3;
    const int i = round(xyz[offset  ] * xyz2ijk_scale[0] + xyz2ijk_shift[0]);
    const int j = round(xyz[offset+1] * xyz2ijk_scale[1] + xyz2ijk_shift[1]);
    const int k = round(xyz[offset+2] * xyz2ijk_scale[2] + xyz2ijk_shift[2]);
    if(check_xyz(i, j, k, sz_i, sz_j, sz_k)) {
      out[i_pt] = world[i*sz_j*sz_k + j*sz_k + k];
    }
  }
}

torch::Tensor maskcache_lookup_cuda(
        torch::Tensor world,
        torch::Tensor xyz,
        torch::Tensor xyz2ijk_scale,
        torch::Tensor xyz2ijk_shift) {

  const int sz_i = world.size(0);
  const int sz_j = world.size(1);
  const int sz_k = world.size(2);
  const int n_pts = xyz.size(0);

  auto out = torch::zeros({n_pts}, torch::dtype(torch::kBool).device(torch::kCUDA));
  if(n_pts==0) {
    return out;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(xyz.type(), "maskcache_lookup_cuda", ([&] {
    maskcache_lookup_cuda_kernel<scalar_t><<<blocks, threads>>>(
        world.data<bool>(),
        xyz.data<scalar_t>(),
        out.data<bool>(),
        xyz2ijk_scale.data<scalar_t>(),
        xyz2ijk_shift.data<scalar_t>(),
        sz_i, sz_j, sz_k, n_pts);
  }));

  return out;
}


/*
    Ray marching helper function.
 */
template <typename scalar_t>
__global__ void raw2alpha_cuda_kernel(
    scalar_t* __restrict__ density,
    const float shift, const float interval, const int n_pts,
    scalar_t* __restrict__ exp_d,
    scalar_t* __restrict__ alpha) {

  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    const scalar_t e = exp(density[i_pt] + shift); // can be inf
    exp_d[i_pt] = e;
    alpha[i_pt] = 1 - pow(1 + e, -interval);
  }
}

template <typename scalar_t>
__global__ void raw2alpha_nonuni_cuda_kernel(
    scalar_t* __restrict__ density,
    const float shift, scalar_t* __restrict__ interval, const int n_pts,
    scalar_t* __restrict__ exp_d,
    scalar_t* __restrict__ alpha) {

  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    const scalar_t e = exp(density[i_pt] + shift); // can be inf
    exp_d[i_pt] = e;
    alpha[i_pt] = 1 - pow(1 + e, -interval[i_pt]);
  }
}

std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval) {

  const int n_pts = density.size(0);
  auto exp_d = torch::empty_like(density);
  auto alpha = torch::empty_like(density);
  if(n_pts==0) {
    return {exp_d, alpha};
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(density.type(), "raw2alpha_cuda", ([&] {
    raw2alpha_cuda_kernel<scalar_t><<<blocks, threads>>>(
        density.data<scalar_t>(),
        shift, interval, n_pts,
        exp_d.data<scalar_t>(),
        alpha.data<scalar_t>());
  }));

  return {exp_d, alpha};
}

std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval) {

  const int n_pts = density.size(0);
  auto exp_d = torch::empty_like(density);
  auto alpha = torch::empty_like(density);
  if(n_pts==0) {
    return {exp_d, alpha};
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(density.type(), "raw2alpha_cuda", ([&] {
    raw2alpha_nonuni_cuda_kernel<scalar_t><<<blocks, threads>>>(
        density.data<scalar_t>(),
        shift, interval.data<scalar_t>(), n_pts,
        exp_d.data<scalar_t>(),
        alpha.data<scalar_t>());
  }));

  return {exp_d, alpha};
}

template <typename scalar_t>
__global__ void raw2alpha_backward_cuda_kernel(
    scalar_t* __restrict__ exp_d,
    scalar_t* __restrict__ grad_back,
    const float interval, const int n_pts,
    scalar_t* __restrict__ grad) {

  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    grad[i_pt] = min(exp_d[i_pt], 1e10) * pow(1+exp_d[i_pt], -interval-1) * interval * grad_back[i_pt];
  }
}

template <typename scalar_t>
__global__ void raw2alpha_nonuni_backward_cuda_kernel(
    scalar_t* __restrict__ exp_d,
    scalar_t* __restrict__ grad_back,
    scalar_t* __restrict__ interval, const int n_pts,
    scalar_t* __restrict__ grad) {

  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    grad[i_pt] = min(exp_d[i_pt], 1e10) * pow(1+exp_d[i_pt], -interval[i_pt]-1) * interval[i_pt] * grad_back[i_pt];
  }
}

torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, const float interval) {

  const int n_pts = exp_d.size(0);
  auto grad = torch::empty_like(exp_d);
  if(n_pts==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(exp_d.type(), "raw2alpha_backward_cuda", ([&] {
    raw2alpha_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        exp_d.data<scalar_t>(),
        grad_back.data<scalar_t>(),
        interval, n_pts,
        grad.data<scalar_t>());
  }));

  return grad;
}

torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, torch::Tensor interval) {

  const int n_pts = exp_d.size(0);
  auto grad = torch::empty_like(exp_d);
  if(n_pts==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(exp_d.type(), "raw2alpha_backward_cuda", ([&] {
    raw2alpha_nonuni_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        exp_d.data<scalar_t>(),
        grad_back.data<scalar_t>(),
        interval.data<scalar_t>(), n_pts,
        grad.data<scalar_t>());
  }));

  return grad;
}

template <typename scalar_t>
__global__ void alpha2weight_cuda_kernel(
    scalar_t* __restrict__ alpha,
    const int n_rays,
    scalar_t* __restrict__ weight,
    scalar_t* __restrict__ T,
    scalar_t* __restrict__ alphainv_last,
    int64_t* __restrict__ i_start,
    int64_t* __restrict__ i_end) {

  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int i_s = i_start[i_ray];
    const int i_e_max = i_end[i_ray];

    float T_cum = 1.;
    int i;
    for(i=i_s; i<i_e_max; ++i) {
      T[i] = T_cum;
      weight[i] = T_cum * alpha[i];
      T_cum *= (1. - alpha[i]);
      if(T_cum<1e-3) {
        i+=1;
        break;
      }
    }
    i_end[i_ray] = i;
    alphainv_last[i_ray] = T_cum;
  }
}

__global__ void __set_i_for_segment_start_end(
        int64_t* __restrict__ ray_id,
        const int n_pts,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if(0<index && index<n_pts && ray_id[index]!=ray_id[index-1]) {
    i_start[ray_id[index]] = index;
    i_end[ray_id[index-1]] = index;
  }
}

std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {

  const int n_pts = alpha.size(0);
  const int threads = 256;

  auto weight = torch::zeros_like(alpha);
  auto T = torch::ones_like(alpha);
  auto alphainv_last = torch::ones({n_rays}, alpha.options());
  auto i_start = torch::zeros({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  auto i_end = torch::zeros({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  if(n_pts==0) {
    return {weight, T, alphainv_last, i_start, i_end};
  }

  __set_i_for_segment_start_end<<<(n_pts+threads-1)/threads, threads>>>(
          ray_id.data<int64_t>(), n_pts, i_start.data<int64_t>(), i_end.data<int64_t>());
  i_end[ray_id[n_pts-1]] = n_pts;

  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(alpha.type(), "alpha2weight_cuda", ([&] {
    alpha2weight_cuda_kernel<scalar_t><<<blocks, threads>>>(
        alpha.data<scalar_t>(),
        n_rays,
        weight.data<scalar_t>(),
        T.data<scalar_t>(),
        alphainv_last.data<scalar_t>(),
        i_start.data<int64_t>(),
        i_end.data<int64_t>());
  }));

  return {weight, T, alphainv_last, i_start, i_end};
}

template <typename scalar_t>
__global__ void alpha2weight_backward_cuda_kernel(
    scalar_t* __restrict__ alpha,
    scalar_t* __restrict__ weight,
    scalar_t* __restrict__ T,
    scalar_t* __restrict__ alphainv_last,
    int64_t* __restrict__ i_start,
    int64_t* __restrict__ i_end,
    const int n_rays,
    scalar_t* __restrict__ grad_weights,
    scalar_t* __restrict__ grad_last,
    scalar_t* __restrict__ grad) {

  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int i_s = i_start[i_ray];
    const int i_e = i_end[i_ray];

    float back_cum = grad_last[i_ray] * alphainv_last[i_ray];
    for(int i=i_e-1; i>=i_s; --i) {
      grad[i] = grad_weights[i] * T[i] - back_cum / (1-alpha[i] + 1e-10);
      back_cum += grad_weights[i] * weight[i];
    }
  }
}

torch::Tensor alpha2weight_backward_cuda(
        torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
        torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
        torch::Tensor grad_weights, torch::Tensor grad_last) {

  auto grad = torch::zeros_like(alpha);
  if(n_rays==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(alpha.type(), "alpha2weight_backward_cuda", ([&] {
    alpha2weight_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        alpha.data<scalar_t>(),
        weight.data<scalar_t>(),
        T.data<scalar_t>(),
        alphainv_last.data<scalar_t>(),
        i_start.data<int64_t>(),
        i_end.data<int64_t>(),
        n_rays,
        grad_weights.data<scalar_t>(),
        grad_last.data<scalar_t>(),
        grad.data<scalar_t>());
  }));

  return grad;
}
