File indexing completed on 2024-04-09 02:22:22
0001 #include <cstddef>
0002
0003 #include <hip/hip_runtime.h>
0004
0005 #include "HeterogeneousTest/ROCmWrapper/interface/DeviceAdditionWrapper.h"
0006 #include "HeterogeneousTest/ROCmOpaque/interface/DeviceAdditionOpaque.h"
0007 #include "HeterogeneousCore/ROCmUtilities/interface/hipCheck.h"
0008
0009 namespace cms::rocmtest {
0010
0011 void opaque_add_vectors_f(const float* in1_h, const float* in2_h, float* out_h, size_t size) {
0012
0013 float* in1_d;
0014 float* in2_d;
0015 float* out_d;
0016 hipCheck(hipMalloc(&in1_d, size * sizeof(float)));
0017 hipCheck(hipMalloc(&in2_d, size * sizeof(float)));
0018 hipCheck(hipMalloc(&out_d, size * sizeof(float)));
0019
0020
0021 hipCheck(hipMemcpy(in1_d, in1_h, size * sizeof(float), hipMemcpyHostToDevice));
0022 hipCheck(hipMemcpy(in2_d, in2_h, size * sizeof(float), hipMemcpyHostToDevice));
0023
0024
0025 hipCheck(hipMemset(out_d, 0, size * sizeof(float)));
0026
0027
0028 wrapper_add_vectors_f(in1_d, in2_d, out_d, size);
0029
0030
0031 hipCheck(hipMemcpy(out_h, out_d, size * sizeof(float), hipMemcpyDeviceToHost));
0032
0033
0034 hipCheck(hipDeviceSynchronize());
0035
0036
0037 hipCheck(hipFree(in1_d));
0038 hipCheck(hipFree(in2_d));
0039 hipCheck(hipFree(out_d));
0040 }
0041
0042 void opaque_add_vectors_d(const double* in1_h, const double* in2_h, double* out_h, size_t size) {
0043
0044 double* in1_d;
0045 double* in2_d;
0046 double* out_d;
0047 hipCheck(hipMalloc(&in1_d, size * sizeof(double)));
0048 hipCheck(hipMalloc(&in2_d, size * sizeof(double)));
0049 hipCheck(hipMalloc(&out_d, size * sizeof(double)));
0050
0051
0052 hipCheck(hipMemcpy(in1_d, in1_h, size * sizeof(double), hipMemcpyHostToDevice));
0053 hipCheck(hipMemcpy(in2_d, in2_h, size * sizeof(double), hipMemcpyHostToDevice));
0054
0055
0056 hipCheck(hipMemset(out_d, 0, size * sizeof(double)));
0057
0058
0059 wrapper_add_vectors_d(in1_d, in2_d, out_d, size);
0060
0061
0062 hipCheck(hipMemcpy(out_h, out_d, size * sizeof(double), hipMemcpyDeviceToHost));
0063
0064
0065 hipCheck(hipDeviceSynchronize());
0066
0067
0068 hipCheck(hipFree(in1_d));
0069 hipCheck(hipFree(in2_d));
0070 hipCheck(hipFree(out_d));
0071 }
0072
0073 }