Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ cc_library(
cc_test(
name = "flash_attention_test",
srcs = ["gemma/flash_attention_test.cc"],
linkstatic = True,
deps = [
":activations",
":attention",
Expand Down Expand Up @@ -601,6 +602,7 @@ cc_test(
name = "kv_transcoding_test",
srcs = ["gemma/kv_transcoding_test.cc"],
deps = [
":basics",
":configs",
":kv_transcoding",
"//testing/base/public:gunit_main",
Expand Down Expand Up @@ -680,6 +682,7 @@ cc_library(
],
textual_hdrs = [
"gemma/gemma-inl.h",
"gemma/flash_attention_arm-inl.h",
],
deps = [
":activations",
Expand All @@ -690,6 +693,7 @@ cc_library(
":mat",
":matmul",
":matmul_env",
":matmul_static",
":ops",
":query",
":tensor_stats",
Expand Down Expand Up @@ -748,6 +752,7 @@ cc_library(
":flash_structs",
":gemma_args",
":kv_cache",
":kv_transcoding",
":mat",
":matmul_env",
":model_store",
Expand Down
3 changes: 3 additions & 0 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
{"flash_transposed_qs", AttentionImpl::kFlashTransposedQs},
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
{"flash_transposed_qs_int16", AttentionImpl::kFlashTransposedQsInt16},
{"flash_matrix_accumulation", AttentionImpl::kFlashMatrixAccumulation},
};

std::string GetAttentionImplName(AttentionImpl impl) {
Expand Down Expand Up @@ -768,6 +769,8 @@ std::string KVEncodingToString(KVEncoding encoding) {
return "Int8";
case KVEncoding::kInt8TwoTranspositions:
return "Int8TwoTranspositions";
case KVEncoding::kBF16MatrixAccumulation:
return "BF16MatrixAccumulation";
default:
return "Unknown";
}
Expand Down
2 changes: 2 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ enum class KVEncoding {
kBF16TwoTranspositions = 4,
kInt8 = 5,
kInt8TwoTranspositions = 6,
kBF16MatrixAccumulation = 7,
};

// Returns a string representation of the KVEncoding.
Expand All @@ -104,6 +105,7 @@ enum class AttentionImpl {
kFlashTransposedQs,
kFlashTransposedQsBF16,
kFlashTransposedQsInt16,
kFlashMatrixAccumulation,
kSentinel,
};

Expand Down
222 changes: 33 additions & 189 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
#include <stdint.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <cstring>
#include <limits>
#include <type_traits>
#include <vector>

#include "compression/types.h" // GEMMA_DISABLED_TARGETS
Expand Down Expand Up @@ -52,6 +52,8 @@
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/attention.h"
#include "gemma/flash_attention.h"
#include "gemma/flash_attention_arm-inl.h"
#include "ops/matmul-inl.h"
#include "ops/ops-inl.h"
#include "hwy/contrib/math/fast_math-inl.h"
Expand All @@ -60,7 +62,6 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

static constexpr float kNegInf = -std::numeric_limits<float>::max() / 64.0f;

// Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
Expand Down Expand Up @@ -570,7 +571,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
const DF4 df4;
using VF4 = hn::Vec<DF4>;
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
VF4 new_max = hn::Set(df4, kNegInf);
VF4 new_max = hn::Set(df4, kMaskedLogitVal);
VF max_0 = hn::Zero(df), max_1 = hn::Zero(df), max_2 = hn::Zero(df),
max_3 = hn::Zero(df);
max_0 = hn::Max(x_0_p0, x_0_p1);
Expand All @@ -595,10 +596,10 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
new_max = hn::Mul(cap, hn::FastTanh(df4, hn::Mul(new_max, one_over_cap)));
}
VF4 local_max = new_max;
VF4 old_max_vf = hn::Set(df4, kNegInf);
VF4 old_max_vf = hn::Set(df4, kMaskedLogitVal);
old_max_vf = hn::LoadU(df4, old_max);
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
auto changed_max = hn::Gt(new_max, hn::Set(df4, kMaskedLogitVal));
hn::StoreU(new_max, df4, old_max);
auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
const VF new_max_i = hn::Set(df, old_max[i]);
Expand Down Expand Up @@ -699,7 +700,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
const DF8 df8;
using VF8 = hn::Vec<DF8>;
static_assert(kNumQueries >= 1 && kNumQueries <= 8);
VF8 new_max = hn::Set(df8, kNegInf);
VF8 new_max = hn::Set(df8, kMaskedLogitVal);
VF max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7 = hn::Zero(df);
max_0 = hn::Max(x_0_p0, x_0_p1);
if constexpr (kNumQueries >= 2) {
Expand Down Expand Up @@ -737,10 +738,10 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
new_max = hn::Mul(cap, hn::FastTanh(df8, hn::Mul(new_max, one_over_cap)));
}
VF8 local_max = new_max;
VF8 old_max_vf = hn::Set(df8, kNegInf);
VF8 old_max_vf = hn::Set(df8, kMaskedLogitVal);
old_max_vf = hn::LoadU(df8, old_max);
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf));
auto changed_max = hn::Gt(new_max, hn::Set(df8, kMaskedLogitVal));
hn::StoreU(new_max, df8, old_max);

auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
Expand Down Expand Up @@ -1235,182 +1236,6 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
#endif
}

template <int kVTileSize, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap,
VF& x0, VF& x1, VF& x2, VF& x3, VF& x4,
VF& x5, VF& x6, VF& x7) {
if (att_cap > 0.0f) {
VF cap = hn::Set(df, att_cap);
VF one_over_cap_vec = hn::Set(df, one_over_cap);
x0 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x0, one_over_cap_vec)));
if constexpr (kVTileSize >= 2) {
x1 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x1, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 3) {
x2 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x2, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 4) {
x3 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x3, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 5) {
x4 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x4, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 6) {
x5 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x5, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 7) {
x6 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x6, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 8) {
x7 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x7, one_over_cap_vec)));
}
}
}

template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename DU,
class VU = hn::Vec<DU>>
static HWY_NOINLINE void ApplyMasking(
DF df, DU du, size_t position,
const size_t* HWY_RESTRICT first_pos_per_query,
const size_t* HWY_RESTRICT last_pos_per_query, VF& x0_p0, VF& x0_p1,
VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0,
VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0,
VF& x7_p1) {
VU lane_indices = hn::Iota(du, 0);
HWY_LANES_CONSTEXPR size_t kTileSize = hn::Lanes(df);
auto per_lane_pos_p0 = hn::Add(hn::Set(du, position), lane_indices);
auto per_lane_pos_p1 =
hn::Add(hn::Set(du, position + kTileSize), lane_indices);

VF neg_inf = hn::Set(df, kNegInf);

auto apply_mask_for_query = [&](int query_idx, VF& x_p0, VF& x_p1) HWY_ATTR {
const size_t first_pos = first_pos_per_query[query_idx];
const size_t last_pos = last_pos_per_query[query_idx];

auto valid_tokens_mask_p0 = hn::Ge(per_lane_pos_p0, hn::Set(du, first_pos));
valid_tokens_mask_p0 = hn::And(
valid_tokens_mask_p0, hn::Le(per_lane_pos_p0, hn::Set(du, last_pos)));
x_p0 =
hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p0), x_p0, neg_inf);

auto valid_tokens_mask_p1 = hn::Ge(per_lane_pos_p1, hn::Set(du, first_pos));
valid_tokens_mask_p1 = hn::And(
valid_tokens_mask_p1, hn::Le(per_lane_pos_p1, hn::Set(du, last_pos)));
x_p1 =
hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p1), x_p1, neg_inf);
};

if constexpr (kNumQueries >= 1) {
apply_mask_for_query(0, x0_p0, x0_p1);
}
if constexpr (kNumQueries >= 2) {
apply_mask_for_query(1, x1_p0, x1_p1);
}
if constexpr (kNumQueries >= 3) {
apply_mask_for_query(2, x2_p0, x2_p1);
}
if constexpr (kNumQueries >= 4) {
apply_mask_for_query(3, x3_p0, x3_p1);
}
if constexpr (kNumQueries >= 5) {
apply_mask_for_query(4, x4_p0, x4_p1);
}
if constexpr (kNumQueries >= 6) {
apply_mask_for_query(5, x5_p0, x5_p1);
}
if constexpr (kNumQueries >= 7) {
apply_mask_for_query(6, x6_p0, x6_p1);
}
if constexpr (kNumQueries >= 8) {
apply_mask_for_query(7, x7_p0, x7_p1);
}
}

template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
VF& x0_p1, VF& x1_p0, VF& x1_p1,
VF& x2_p0, VF& x2_p1, VF& x3_p0,
VF& x3_p1, VF& x4_p0, VF& x4_p1,
VF& x5_p0, VF& x5_p1, VF& x6_p0,
VF& x6_p1, VF& x7_p0, VF& x7_p1) {
const size_t kTileSize = hn::Lanes(df);
const PackedSpan<const BF16> scales_span =
MakeConstSpan(scales, 2 * kTileSize);
VF scales_p0, scales_p1;
Decompress2(df, scales_span, 0, scales_p0, scales_p1);
if constexpr (kNumQueries >= 1) {
x0_p0 = hn::Mul(x0_p0, scales_p0);
x0_p1 = hn::Mul(x0_p1, scales_p1);
}
if constexpr (kNumQueries >= 2) {
x1_p0 = hn::Mul(x1_p0, scales_p0);
x1_p1 = hn::Mul(x1_p1, scales_p1);
}
if constexpr (kNumQueries >= 3) {
x2_p0 = hn::Mul(x2_p0, scales_p0);
x2_p1 = hn::Mul(x2_p1, scales_p1);
}
if constexpr (kNumQueries >= 4) {
x3_p0 = hn::Mul(x3_p0, scales_p0);
x3_p1 = hn::Mul(x3_p1, scales_p1);
}
if constexpr (kNumQueries >= 5) {
x4_p0 = hn::Mul(x4_p0, scales_p0);
x4_p1 = hn::Mul(x4_p1, scales_p1);
}
if constexpr (kNumQueries >= 6) {
x5_p0 = hn::Mul(x5_p0, scales_p0);
x5_p1 = hn::Mul(x5_p1, scales_p1);
}
if constexpr (kNumQueries >= 7) {
x6_p0 = hn::Mul(x6_p0, scales_p0);
x6_p1 = hn::Mul(x6_p1, scales_p1);
}
if constexpr (kNumQueries >= 8) {
x7_p0 = hn::Mul(x7_p0, scales_p0);
x7_p1 = hn::Mul(x7_p1, scales_p1);
}
}

template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void ApplyQuantizationScale(
DF df, const float* HWY_RESTRICT q_scales, size_t query_idx, VF& x0_p0,
VF& x0_p1, VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1,
VF& x4_p0, VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0,
VF& x7_p1) {
auto apply_scale = [&](size_t i, VF& x_p0, VF& x_p1) HWY_ATTR {
size_t scale_idx = query_idx + i;
VF s = hn::Set(df, q_scales[scale_idx]);
x_p0 = hn::Mul(x_p0, s);
x_p1 = hn::Mul(x_p1, s);
};

if constexpr (kNumQueries >= 1) {
apply_scale(0, x0_p0, x0_p1);
}
if constexpr (kNumQueries >= 2) {
apply_scale(1, x1_p0, x1_p1);
}
if constexpr (kNumQueries >= 3) {
apply_scale(2, x2_p0, x2_p1);
}
if constexpr (kNumQueries >= 4) {
apply_scale(3, x3_p0, x3_p1);
}
if constexpr (kNumQueries >= 5) {
apply_scale(4, x4_p0, x4_p1);
}
if constexpr (kNumQueries >= 6) {
apply_scale(5, x5_p0, x5_p1);
}
if constexpr (kNumQueries >= 7) {
apply_scale(6, x6_p0, x6_p1);
}
if constexpr (kNumQueries >= 8) {
apply_scale(7, x7_p0, x7_p1);
}
}

// Performs tiled flash attention for arbitrary number of queries
// It depends on kv being tiled.
Expand Down Expand Up @@ -1638,22 +1463,27 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}
MatPtrT<float> offset_out = att_out_per_query[loop_idx];
const size_t group_offset = query_idx % kNumQueriesPerLoop;
if (group_offset > 0) {
offset_out.SetPtr(offset_out.Row(group_offset), offset_out.Stride());
}

if constexpr (IsF32<Q_T>()) {
MulByConstAndAddTileUpTo8<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, offset_out);
} else if constexpr (IsInt16<Q_T>()) {
MulByConstAndAddTileUpTo8_BF16_Int16<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx],
q_scales_s);
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, offset_out, q_scales_s);
} else {
MulByConstAndAddTileUpTo8_BF16<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, offset_out);
}
};

Expand Down Expand Up @@ -1740,6 +1570,20 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
}

void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation(
hwy::Span<const MatPtr> kvs, size_t q_count,
const BF16* HWY_RESTRICT q_base,
hwy::Span<const size_t> start_pos_per_query,
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
float* HWY_RESTRICT max_logits) {
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro(
kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query,
att_cap, att_out, exp_denominator_sums, max_logits);
});
}

// Implements flash attention for a strip of tiles of size 1, 4 or 8 query
// vectors by 2NF positions in K.
// It iterates through tiles in K from `params.min_start_pos / 2NF * 2NF` up to
Expand Down
Loading
Loading