File size: 5,712 Bytes
6eaa88c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
#include "common.cuh"
#include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
scaled_fp8_conversion_vec<scalar_t, true>(
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;
float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
}
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float token_scale;
if (tid == 0) {
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();
// Note that we don't use inverted scales so we can match FBGemm impl.
if (can_vectorize) {
scaled_fp8_conversion_vec<scalar_t, false>(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion<false>(
static_cast<float>(token_input[i]), token_scale);
}
}
}
} // namespace vllm
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
}
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
});
}
|