/*
 * JCuda - Java bindings for NVIDIA CUDA driver and runtime API
 * http://www.jcuda.org
 *
 * Copyright 2010 Marco Hutter - http://www.jcuda.org
 */

import static jcuda.driver.JCudaDriver.*;
import static jcuda.jcublas.JCublas.*;
import static jcuda.runtime.JCuda.*;

import java.io.IOException;
import java.util.*;

import jcuda.*;
import jcuda.driver.*;
import jcuda.jcublas.JCublas;
import jcuda.runtime.cudaMemcpyKind;

/**
 * This is a simple example that shows how the interoperability between
 * the CUDA runtime- and driver API may be used with JCuda. <br />
 * <br />
 * The example creates a vector on the device using the runtime API,
 * computes the norm of a vector using JCublas, then inverts all
 * elements of the vector using a kernel that is executed with the
 * driver API, computes the norm of the resulting vector with JCublas,
 * and finally copies the vector back using the driver API.
 */
public class JCudaRuntimeDriverMixSample
{
    public static void main(String args[]) throws IOException
    {
        JCudaDriver.setExceptionsEnabled(true);
        JCublas.setExceptionsEnabled(true);

        // Initialize the driver and create a context for the first device.
        cuInit(0);
        CUcontext context = new CUcontext();
        CUdevice device = new CUdevice();
        cuDeviceGet(device, 0);
        cuCtxCreate(context, 0, device);

        // Load the CUBIN file and obtain the "invertVectorElements" function.
        CUmodule module = new CUmodule();
        cuModuleLoad(module, "invertVectorElements.cubin");
        CUfunction function = new CUfunction();
        cuModuleGetFunction(function, module, "invertVectorElements");

        // Create the input data.
        int n = 5;
        Random random = new Random(0);
        float vector[] = new float[n];
        for (int i=0; i<n; i++)
        {
            vector[i] = random.nextFloat();
        }

        // Copy the vector to the device using the Runtime API
        CUdeviceptr vectorDevice = new CUdeviceptr();
        cudaMalloc(vectorDevice, n * Sizeof.FLOAT);
        cudaMemcpy(vectorDevice, Pointer.to(vector), n * Sizeof.FLOAT,
            cudaMemcpyKind.cudaMemcpyHostToDevice);

        // Use JCublas to compute the vector norm
        cublasInit();
        float norm = cublasSnrm2(n, vectorDevice, 1);

        System.out.println("Input vector    "+Arrays.toString(vector));
        System.out.println("Norm            "+norm);

        // Set up the execution parameters for the kernel
        cuFuncSetBlockShape(function, n, 1, 1);
        Pointer vectorDevicePointer = Pointer.to(vectorDevice);
        Pointer nPointer = Pointer.to(new int[]{n});
        int offset = 0;
        offset = align(offset, Sizeof.POINTER);
        cuParamSetv(function, offset, vectorDevicePointer, Sizeof.POINTER);
        offset += Sizeof.POINTER;
        offset = align(offset, Sizeof.INT);
        cuParamSetv(function, offset, nPointer, Sizeof.INT);
        offset += Sizeof.INT;
        cuParamSetSize(function, offset);

        // Call the kernel function.
        cuLaunch(function);
        cuCtxSynchronize();

        // Use JCublas to compute the norm of the vector that
        // was modified using the kernel
        float invNorm = cublasSnrm2(n, vectorDevice, 1);

        // Copy the vector back to the host using the Driver API
        cuMemcpyDtoH(Pointer.to(vector), vectorDevice, n * Sizeof.FLOAT);

        // Print the results
        System.out.println("Inverted vector "+Arrays.toString(vector));
        System.out.println("Norm            "+invNorm);

        // Clean up
        cuMemFree(vectorDevice);
    }
}


