diff --git a/BUILD.bazel b/BUILD.bazel index a2d7803a..7697c16f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -141,6 +141,7 @@ cc_library( cc_test( name = "flash_attention_test", srcs = ["gemma/flash_attention_test.cc"], + linkstatic = True, deps = [ ":activations", ":attention", @@ -601,6 +602,7 @@ cc_test( name = "kv_transcoding_test", srcs = ["gemma/kv_transcoding_test.cc"], deps = [ + ":basics", ":configs", ":kv_transcoding", "//testing/base/public:gunit_main", @@ -680,6 +682,7 @@ cc_library( ], textual_hdrs = [ "gemma/gemma-inl.h", + "gemma/flash_attention_arm-inl.h", ], deps = [ ":activations", @@ -690,6 +693,7 @@ cc_library( ":mat", ":matmul", ":matmul_env", + ":matmul_static", ":ops", ":query", ":tensor_stats", @@ -748,6 +752,7 @@ cc_library( ":flash_structs", ":gemma_args", ":kv_cache", + ":kv_transcoding", ":mat", ":matmul_env", ":model_store", diff --git a/gemma/configs.cc b/gemma/configs.cc index 19c7c26e..df432557 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -736,6 +736,7 @@ constexpr std::pair kAttentionImplNameToEnum[] = { {"flash_transposed_qs", AttentionImpl::kFlashTransposedQs}, {"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16}, {"flash_transposed_qs_int16", AttentionImpl::kFlashTransposedQsInt16}, + {"flash_matrix_accumulation", AttentionImpl::kFlashMatrixAccumulation}, }; std::string GetAttentionImplName(AttentionImpl impl) { @@ -768,6 +769,8 @@ std::string KVEncodingToString(KVEncoding encoding) { return "Int8"; case KVEncoding::kInt8TwoTranspositions: return "Int8TwoTranspositions"; + case KVEncoding::kBF16MatrixAccumulation: + return "BF16MatrixAccumulation"; default: return "Unknown"; } diff --git a/gemma/configs.h b/gemma/configs.h index 89cc9906..474cdef1 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -90,6 +90,7 @@ enum class KVEncoding { kBF16TwoTranspositions = 4, kInt8 = 5, kInt8TwoTranspositions = 6, + kBF16MatrixAccumulation = 7, }; // Returns a string representation of the KVEncoding. @@ -104,6 +105,7 @@ enum class AttentionImpl { kFlashTransposedQs, kFlashTransposedQsBF16, kFlashTransposedQsInt16, + kFlashMatrixAccumulation, kSentinel, }; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 9c4bd1b2..a925581b 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -17,11 +17,11 @@ #include #include -#include #include #include -#include +#include #include +#include #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS @@ -52,6 +52,8 @@ // After highway.h #include "compression/compress-inl.h" #include "gemma/attention.h" +#include "gemma/flash_attention.h" +#include "gemma/flash_attention_arm-inl.h" #include "ops/matmul-inl.h" #include "ops/ops-inl.h" #include "hwy/contrib/math/fast_math-inl.h" @@ -60,7 +62,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -static constexpr float kNegInf = -std::numeric_limits::max() / 64.0f; // Updates q in place for RMSNorm and positional encoding. void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, @@ -570,7 +571,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( const DF4 df4; using VF4 = hn::Vec; static_assert(kNumQueries >= 1 && kNumQueries <= 4); - VF4 new_max = hn::Set(df4, kNegInf); + VF4 new_max = hn::Set(df4, kMaskedLogitVal); VF max_0 = hn::Zero(df), max_1 = hn::Zero(df), max_2 = hn::Zero(df), max_3 = hn::Zero(df); max_0 = hn::Max(x_0_p0, x_0_p1); @@ -595,10 +596,10 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( new_max = hn::Mul(cap, hn::FastTanh(df4, hn::Mul(new_max, one_over_cap))); } VF4 local_max = new_max; - VF4 old_max_vf = hn::Set(df4, kNegInf); + VF4 old_max_vf = hn::Set(df4, kMaskedLogitVal); old_max_vf = hn::LoadU(df4, old_max); new_max = hn::Max(new_max, old_max_vf); - auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf)); + auto changed_max = hn::Gt(new_max, hn::Set(df4, kMaskedLogitVal)); hn::StoreU(new_max, df4, old_max); auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR { const VF new_max_i = hn::Set(df, old_max[i]); @@ -699,7 +700,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( const DF8 df8; using VF8 = hn::Vec; static_assert(kNumQueries >= 1 && kNumQueries <= 8); - VF8 new_max = hn::Set(df8, kNegInf); + VF8 new_max = hn::Set(df8, kMaskedLogitVal); VF max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7 = hn::Zero(df); max_0 = hn::Max(x_0_p0, x_0_p1); if constexpr (kNumQueries >= 2) { @@ -737,10 +738,10 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( new_max = hn::Mul(cap, hn::FastTanh(df8, hn::Mul(new_max, one_over_cap))); } VF8 local_max = new_max; - VF8 old_max_vf = hn::Set(df8, kNegInf); + VF8 old_max_vf = hn::Set(df8, kMaskedLogitVal); old_max_vf = hn::LoadU(df8, old_max); new_max = hn::Max(new_max, old_max_vf); - auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf)); + auto changed_max = hn::Gt(new_max, hn::Set(df8, kMaskedLogitVal)); hn::StoreU(new_max, df8, old_max); auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR { @@ -1235,182 +1236,6 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16( #endif } -template > -static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap, - VF& x0, VF& x1, VF& x2, VF& x3, VF& x4, - VF& x5, VF& x6, VF& x7) { - if (att_cap > 0.0f) { - VF cap = hn::Set(df, att_cap); - VF one_over_cap_vec = hn::Set(df, one_over_cap); - x0 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x0, one_over_cap_vec))); - if constexpr (kVTileSize >= 2) { - x1 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x1, one_over_cap_vec))); - } - if constexpr (kVTileSize >= 3) { - x2 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x2, one_over_cap_vec))); - } - if constexpr (kVTileSize >= 4) { - x3 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x3, one_over_cap_vec))); - } - if constexpr (kVTileSize >= 5) { - x4 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x4, one_over_cap_vec))); - } - if constexpr (kVTileSize >= 6) { - x5 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x5, one_over_cap_vec))); - } - if constexpr (kVTileSize >= 7) { - x6 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x6, one_over_cap_vec))); - } - if constexpr (kVTileSize >= 8) { - x7 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x7, one_over_cap_vec))); - } - } -} - -template , typename DU, - class VU = hn::Vec> -static HWY_NOINLINE void ApplyMasking( - DF df, DU du, size_t position, - const size_t* HWY_RESTRICT first_pos_per_query, - const size_t* HWY_RESTRICT last_pos_per_query, VF& x0_p0, VF& x0_p1, - VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0, - VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, - VF& x7_p1) { - VU lane_indices = hn::Iota(du, 0); - HWY_LANES_CONSTEXPR size_t kTileSize = hn::Lanes(df); - auto per_lane_pos_p0 = hn::Add(hn::Set(du, position), lane_indices); - auto per_lane_pos_p1 = - hn::Add(hn::Set(du, position + kTileSize), lane_indices); - - VF neg_inf = hn::Set(df, kNegInf); - - auto apply_mask_for_query = [&](int query_idx, VF& x_p0, VF& x_p1) HWY_ATTR { - const size_t first_pos = first_pos_per_query[query_idx]; - const size_t last_pos = last_pos_per_query[query_idx]; - - auto valid_tokens_mask_p0 = hn::Ge(per_lane_pos_p0, hn::Set(du, first_pos)); - valid_tokens_mask_p0 = hn::And( - valid_tokens_mask_p0, hn::Le(per_lane_pos_p0, hn::Set(du, last_pos))); - x_p0 = - hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p0), x_p0, neg_inf); - - auto valid_tokens_mask_p1 = hn::Ge(per_lane_pos_p1, hn::Set(du, first_pos)); - valid_tokens_mask_p1 = hn::And( - valid_tokens_mask_p1, hn::Le(per_lane_pos_p1, hn::Set(du, last_pos))); - x_p1 = - hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p1), x_p1, neg_inf); - }; - - if constexpr (kNumQueries >= 1) { - apply_mask_for_query(0, x0_p0, x0_p1); - } - if constexpr (kNumQueries >= 2) { - apply_mask_for_query(1, x1_p0, x1_p1); - } - if constexpr (kNumQueries >= 3) { - apply_mask_for_query(2, x2_p0, x2_p1); - } - if constexpr (kNumQueries >= 4) { - apply_mask_for_query(3, x3_p0, x3_p1); - } - if constexpr (kNumQueries >= 5) { - apply_mask_for_query(4, x4_p0, x4_p1); - } - if constexpr (kNumQueries >= 6) { - apply_mask_for_query(5, x5_p0, x5_p1); - } - if constexpr (kNumQueries >= 7) { - apply_mask_for_query(6, x6_p0, x6_p1); - } - if constexpr (kNumQueries >= 8) { - apply_mask_for_query(7, x7_p0, x7_p1); - } -} - -template > -static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0, - VF& x0_p1, VF& x1_p0, VF& x1_p1, - VF& x2_p0, VF& x2_p1, VF& x3_p0, - VF& x3_p1, VF& x4_p0, VF& x4_p1, - VF& x5_p0, VF& x5_p1, VF& x6_p0, - VF& x6_p1, VF& x7_p0, VF& x7_p1) { - const size_t kTileSize = hn::Lanes(df); - const PackedSpan scales_span = - MakeConstSpan(scales, 2 * kTileSize); - VF scales_p0, scales_p1; - Decompress2(df, scales_span, 0, scales_p0, scales_p1); - if constexpr (kNumQueries >= 1) { - x0_p0 = hn::Mul(x0_p0, scales_p0); - x0_p1 = hn::Mul(x0_p1, scales_p1); - } - if constexpr (kNumQueries >= 2) { - x1_p0 = hn::Mul(x1_p0, scales_p0); - x1_p1 = hn::Mul(x1_p1, scales_p1); - } - if constexpr (kNumQueries >= 3) { - x2_p0 = hn::Mul(x2_p0, scales_p0); - x2_p1 = hn::Mul(x2_p1, scales_p1); - } - if constexpr (kNumQueries >= 4) { - x3_p0 = hn::Mul(x3_p0, scales_p0); - x3_p1 = hn::Mul(x3_p1, scales_p1); - } - if constexpr (kNumQueries >= 5) { - x4_p0 = hn::Mul(x4_p0, scales_p0); - x4_p1 = hn::Mul(x4_p1, scales_p1); - } - if constexpr (kNumQueries >= 6) { - x5_p0 = hn::Mul(x5_p0, scales_p0); - x5_p1 = hn::Mul(x5_p1, scales_p1); - } - if constexpr (kNumQueries >= 7) { - x6_p0 = hn::Mul(x6_p0, scales_p0); - x6_p1 = hn::Mul(x6_p1, scales_p1); - } - if constexpr (kNumQueries >= 8) { - x7_p0 = hn::Mul(x7_p0, scales_p0); - x7_p1 = hn::Mul(x7_p1, scales_p1); - } -} - -template > -static HWY_INLINE void ApplyQuantizationScale( - DF df, const float* HWY_RESTRICT q_scales, size_t query_idx, VF& x0_p0, - VF& x0_p1, VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, - VF& x4_p0, VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, - VF& x7_p1) { - auto apply_scale = [&](size_t i, VF& x_p0, VF& x_p1) HWY_ATTR { - size_t scale_idx = query_idx + i; - VF s = hn::Set(df, q_scales[scale_idx]); - x_p0 = hn::Mul(x_p0, s); - x_p1 = hn::Mul(x_p1, s); - }; - - if constexpr (kNumQueries >= 1) { - apply_scale(0, x0_p0, x0_p1); - } - if constexpr (kNumQueries >= 2) { - apply_scale(1, x1_p0, x1_p1); - } - if constexpr (kNumQueries >= 3) { - apply_scale(2, x2_p0, x2_p1); - } - if constexpr (kNumQueries >= 4) { - apply_scale(3, x3_p0, x3_p1); - } - if constexpr (kNumQueries >= 5) { - apply_scale(4, x4_p0, x4_p1); - } - if constexpr (kNumQueries >= 6) { - apply_scale(5, x5_p0, x5_p1); - } - if constexpr (kNumQueries >= 7) { - apply_scale(6, x6_p0, x6_p1); - } - if constexpr (kNumQueries >= 8) { - apply_scale(7, x7_p0, x7_p1); - } -} // Performs tiled flash attention for arbitrary number of queries // It depends on kv being tiled. @@ -1638,22 +1463,27 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); } + MatPtrT offset_out = att_out_per_query[loop_idx]; + const size_t group_offset = query_idx % kNumQueriesPerLoop; + if (group_offset > 0) { + offset_out.SetPtr(offset_out.Row(group_offset), offset_out.Stride()); + } + if constexpr (IsF32()) { MulByConstAndAddTileUpTo8( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, - x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]); + x_6_p_1, x_7_p_0, x_7_p_1, v_tile, offset_out); } else if constexpr (IsInt16()) { MulByConstAndAddTileUpTo8_BF16_Int16( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, - x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx], - q_scales_s); + x_6_p_1, x_7_p_0, x_7_p_1, v_tile, offset_out, q_scales_s); } else { MulByConstAndAddTileUpTo8_BF16( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, - x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]); + x_6_p_1, x_7_p_0, x_7_p_1, v_tile, offset_out); } }; @@ -1740,6 +1570,20 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits); } +void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( + hwy::Span kvs, size_t q_count, + const BF16* HWY_RESTRICT q_base, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, const float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { + CallUpcastedKVs(kvs, [&](const auto& kv_t) { + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( + kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query, + att_cap, att_out, exp_denominator_sums, max_logits); + }); +} + // Implements flash attention for a strip of tiles of size 1, 4 or 8 query // vectors by 2NF positions in K. // It iterates through tiles in K from `params.min_start_pos / 2NF * 2NF` up to diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 7e83a823..3c9f4fbe 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -36,46 +36,54 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void RMSNormAndPositionalEncoding( \ - size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ - const MatPtr& query_norm_scale, size_t layer_idx, \ - const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ - \ - size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ - size_t total_tasks, size_t target_parallelism); \ - \ - void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, const MatPtr& query_norm_scale, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - ThreadingContext& ctx, AttentionImpl attention_impl); \ - \ - void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( \ - hwy::Span kvs, size_t q_count, \ - const float* HWY_RESTRICT q_base, \ - hwy::Span start_pos_per_query, \ - hwy::Span last_pos_per_query, const float att_cap, \ - MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ - float* HWY_RESTRICT max_logits); \ - \ - void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( \ - hwy::Span kvs, size_t q_count, \ - const BF16* HWY_RESTRICT q_base, \ - hwy::Span start_pos_per_query, \ - hwy::Span last_pos_per_query, const float att_cap, \ - MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ - float* HWY_RESTRICT max_logits); \ - \ - void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( \ - hwy::Span kvs, size_t q_count, \ - const int16_t* HWY_RESTRICT q_base, hwy::Span q_scales, \ - hwy::Span start_pos_per_query, \ - hwy::Span last_pos_per_query, const float att_cap, \ - MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ - float* HWY_RESTRICT max_logits); \ - \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding( \ + size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ + \ + size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ + size_t total_tasks, size_t target_parallelism); \ + \ + void FlashAttention(size_t num_tokens, size_t target_parallelism, \ + size_t layer_idx, const MatPtr& query_norm_scale, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx, AttentionImpl attention_impl); \ + \ + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( \ + hwy::Span kvs, size_t q_count, \ + const float* HWY_RESTRICT q_base, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( \ + hwy::Span kvs, size_t q_count, \ + const BF16* HWY_RESTRICT q_base, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( \ + hwy::Span kvs, size_t q_count, \ + const int16_t* HWY_RESTRICT q_base, hwy::Span q_scales, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( \ + hwy::Span kvs, size_t q_count, \ + const BF16* HWY_RESTRICT q_base, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/gemma/flash_attention_arm-inl.h b/gemma/flash_attention_arm-inl.h new file mode 100644 index 00000000..96222a46 --- /dev/null +++ b/gemma/flash_attention_arm-inl.h @@ -0,0 +1,1010 @@ +// Include guard for non-SIMD code. +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_INL_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "gemma/flash_attention.h" +#include "gemma/kv_cache.h" +#include "util/basics.h" +#include "util/threading_context.h" +#include "hwy/base.h" + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_INL_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_TOGGLE +#endif + +#include "compression/compress-inl.h" +#include "ops/matmul-inl.h" +#include "ops/ops-inl.h" +#include "hwy/contrib/math/fast_math-inl.h" + +#ifndef BENCHMARK_BLOCK_SIZE +#define BENCHMARK_BLOCK_SIZE 128 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +struct TileAttentionGroupParams { + size_t smallest_start_pos; + size_t largest_last_pos; + size_t num_loops; + size_t position; + float one_over_cap; + float max_v_scale; + hwy::AlignedVector pos_data; + hwy::Span min_start_pos_per_group; + hwy::Span max_start_pos_per_group; + hwy::Span min_last_pos_per_group; + hwy::Span max_last_pos_per_group; + std::vector> att_out_per_query; + + TileAttentionGroupParams(size_t q_count, size_t kNumQueriesPerLoop, + size_t step_size, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, + float att_cap, MatPtrT& att_out, + bool create_att_out_per_query = false) { + smallest_start_pos = std::numeric_limits::max(); + largest_last_pos = std::numeric_limits::min(); + for (size_t i = 0; i < start_pos_per_query.size(); ++i) { + smallest_start_pos = std::min(smallest_start_pos, start_pos_per_query[i]); + largest_last_pos = std::max(largest_last_pos, last_pos_per_query[i]); + } + num_loops = hwy::DivCeil(q_count, kNumQueriesPerLoop); + + pos_data.resize(num_loops * 4); + min_start_pos_per_group = hwy::Span(pos_data.data(), num_loops); + max_start_pos_per_group = + hwy::Span(pos_data.data() + num_loops, num_loops); + min_last_pos_per_group = + hwy::Span(pos_data.data() + 2 * num_loops, num_loops); + max_last_pos_per_group = + hwy::Span(pos_data.data() + 3 * num_loops, num_loops); + + for (size_t i = 0; i < num_loops; ++i) { + size_t min_start = std::numeric_limits::max(); + size_t max_start = 0; + size_t min_last = std::numeric_limits::max(); + size_t max_last = 0; + for (size_t j = 0; j < kNumQueriesPerLoop; ++j) { + if (i * kNumQueriesPerLoop + j < q_count) { + size_t q_idx = i * kNumQueriesPerLoop + j; + min_start = std::min(min_start, start_pos_per_query[q_idx]); + max_start = std::max(max_start, start_pos_per_query[q_idx]); + min_last = std::min(min_last, last_pos_per_query[q_idx]); + max_last = std::max(max_last, last_pos_per_query[q_idx]); + } + } + min_start_pos_per_group[i] = min_start; + max_start_pos_per_group[i] = max_start; + min_last_pos_per_group[i] = min_last; + max_last_pos_per_group[i] = max_last; + } + + constexpr int kTileSize = gcpp::KVCache::kTileSize; + const size_t base_pos = + smallest_start_pos - (smallest_start_pos % kTileSize); + const size_t rem = smallest_start_pos % kTileSize; + const size_t num_skipped_sub_tiles = rem / step_size; + position = base_pos + num_skipped_sub_tiles * step_size; + one_over_cap = att_cap > 0.0f ? 1.0f / att_cap : 0.0f; + max_v_scale = 1.0f; + + if (create_att_out_per_query) { + const size_t qkv_dim = att_out.Cols(); + att_out_per_query.reserve(num_loops); + for (size_t i = 0; i < num_loops; ++i) { + att_out_per_query.emplace_back("att_out", + Extents2D(kNumQueriesPerLoop, qkv_dim)); + att_out_per_query.back().SetPtr(att_out.Row(i * kNumQueriesPerLoop), + att_out.Stride()); + } + } + } +}; + +template +HWY_INLINE hn::Vec LoadAndDuplicateQueries(D d, + const T* HWY_RESTRICT q_ptr) { +#if HWY_HAVE_CONSTEXPR_LANES + constexpr size_t N = hn::Lanes(d); + if constexpr (N <= 8) { + return hn::LoadU(d, q_ptr); + } else { + return hn::LoadDup128(d, q_ptr); + } +#else + return hn::LoadDup128(d, q_ptr); +#endif +} + +template > +HWY_INLINE void QDotKTilexUpTo8MatrixAccumulation( + DF df, const T_IN* HWY_RESTRICT q_group, const T_IN* tile_base, + size_t current_pos, size_t qkv_dim, VF& C00, VF& C01, VF& C02, VF& C03, + VF& C10, VF& C11, VF& C12, VF& C13, VF& C20, VF& C21, VF& C22, VF& C23, + VF& C30, VF& C31, VF& C32, VF& C33) { + using D_ACC = hwy::If(), hn::FixedTag, DF>; + const D_ACC d_acc; + using VecAcc = hn::Vec; + + VecAcc acc00 = hn::Zero(d_acc), acc01 = hn::Zero(d_acc), + acc02 = hn::Zero(d_acc), acc03 = hn::Zero(d_acc); + VecAcc acc10 = hn::Zero(d_acc), acc11 = hn::Zero(d_acc), + acc12 = hn::Zero(d_acc), acc13 = hn::Zero(d_acc); + VecAcc acc20 = hn::Zero(d_acc), acc21 = hn::Zero(d_acc), + acc22 = hn::Zero(d_acc), acc23 = hn::Zero(d_acc); + VecAcc acc30 = hn::Zero(d_acc), acc31 = hn::Zero(d_acc), + acc32 = hn::Zero(d_acc), acc33 = hn::Zero(d_acc); + + using D_INPUT = hn::Repartition; + const D_INPUT d_input; + using VecInput = hn::Vec; + + HWY_LANES_CONSTEXPR size_t step_size = hn::Lanes(d_input); + HWY_LANES_CONSTEXPR size_t ch_step = 4; // Always 4 for 8x4 layout + + size_t g0 = (current_pos / 8) % 4; + const T_IN* k_ptr0; + const T_IN* k_ptr1; + const T_IN* k_ptr2; + const T_IN* k_ptr3; + + if (step_size == 32) { + k_ptr0 = tile_base + 0; + k_ptr1 = tile_base + 32; + k_ptr2 = tile_base + 64; + k_ptr3 = tile_base + 96; + } else if (step_size == 16) { + k_ptr0 = tile_base + g0 * 32; + k_ptr1 = tile_base + (g0 + 1) * 32; + k_ptr2 = nullptr; + k_ptr3 = nullptr; + } else { // step_size == 8 + k_ptr0 = tile_base + g0 * 32; + k_ptr1 = nullptr; + k_ptr2 = nullptr; + k_ptr3 = nullptr; + } + + size_t ch_base_k = 0; + size_t ch_base_q = 0; + for (size_t ch_base = 0; ch_base < qkv_dim; ch_base += ch_step) { + VecInput B0, B1, B2, B3; + if (step_size == 32) { // 512-bit native path + B0 = hn::LoadU(d_input, k_ptr0 + ch_base_k); + B1 = hn::LoadU(d_input, k_ptr1 + ch_base_k); + B2 = hn::LoadU(d_input, k_ptr2 + ch_base_k); + B3 = hn::LoadU(d_input, k_ptr3 + ch_base_k); + } else if (step_size == 16) { // 256-bit fallback path + B0 = hn::LoadU(d_input, k_ptr0 + ch_base_k); + B1 = hn::LoadU(d_input, k_ptr0 + ch_base_k + 16); // Group 0, second half + B2 = hn::LoadU(d_input, k_ptr1 + ch_base_k); + B3 = hn::LoadU(d_input, k_ptr1 + ch_base_k + 16); // Group 1, second half + } else if (step_size == 8) { // 128-bit fallback path + B0 = hn::LoadU(d_input, k_ptr0 + ch_base_k); + B1 = hn::LoadU(d_input, + k_ptr0 + ch_base_k + 8); // Group 0, second quarter + B2 = hn::LoadU(d_input, + k_ptr0 + ch_base_k + 16); // Group 0, third quarter + B3 = hn::LoadU(d_input, + k_ptr0 + ch_base_k + 24); // Group 0, fourth quarter + } + + if constexpr (kNumQueries >= 1) { + const auto A0 = LoadAndDuplicateQueries(d_input, q_group + ch_base_q); + acc00 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B0, acc00); + acc01 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B1, acc01); + acc02 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B2, acc02); + acc03 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B3, acc03); + } + if constexpr (kNumQueries >= 3) { + const auto A1 = + LoadAndDuplicateQueries(d_input, q_group + qkv_dim * 2 + ch_base_q); + acc10 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B0, acc10); + acc11 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B1, acc11); + acc12 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B2, acc12); + acc13 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B3, acc13); + } + if constexpr (kNumQueries >= 5) { + const auto A2 = + LoadAndDuplicateQueries(d_input, q_group + qkv_dim * 4 + ch_base_q); + acc20 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B0, acc20); + acc21 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B1, acc21); + acc22 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B2, acc22); + acc23 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B3, acc23); + } + if constexpr (kNumQueries >= 7) { + const auto A3 = + LoadAndDuplicateQueries(d_input, q_group + qkv_dim * 6 + ch_base_q); + acc30 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B0, acc30); + acc31 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B1, acc31); + acc32 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B2, acc32); + acc33 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B3, acc33); + } + ch_base_k += ch_step * 32; // Stride 128 elements for 8x4 layout + ch_base_q += ch_step * 2; + } + + auto convert_and_reduce = [&](VF& C, VecAcc acc) HWY_ATTR { + if constexpr (!IsInt8()) { + C = acc; + } else { + C = hn::ConvertTo(df, acc); + } + }; + + convert_and_reduce(C00, acc00); + convert_and_reduce(C01, acc01); + convert_and_reduce(C02, acc02); + convert_and_reduce(C03, acc03); + convert_and_reduce(C10, acc10); + convert_and_reduce(C11, acc11); + convert_and_reduce(C12, acc12); + convert_and_reduce(C13, acc13); + convert_and_reduce(C20, acc20); + convert_and_reduce(C21, acc21); + convert_and_reduce(C22, acc22); + convert_and_reduce(C23, acc23); + convert_and_reduce(C30, acc30); + convert_and_reduce(C31, acc31); + convert_and_reduce(C32, acc32); + convert_and_reduce(C33, acc33); +} + +// Symmetrical VLA Implementation +template +HWY_INLINE V ConcatLowerLower_VLA(DF df, V q1, V q0) { + using D64 = hn::Repartition; + const D64 d64; + auto q0_64 = hn::BitCast(d64, q0); + auto q1_64 = hn::BitCast(d64, q1); + auto interleaved = hn::ConcatEven(d64, q1_64, q0_64); + return hn::BitCast(df, interleaved); +} + +template +HWY_INLINE V ConcatUpperUpper_VLA(DF df, V q1, V q0) { + using D64 = hn::Repartition; + const D64 d64; + auto q0_64 = hn::BitCast(d64, q0); + auto q1_64 = hn::BitCast(d64, q1); + auto interleaved = hn::ConcatOdd(d64, q1_64, q0_64); + return hn::BitCast(df, interleaved); +} + +template , + typename DBF_T = hn::ScalableTag, typename KV_T> +HWY_INLINE void TileFlashAttentionSVBlock( + size_t q_base_idx, size_t position, size_t current_kv_start_offset, + size_t current_kv_idx, size_t actual_steps, size_t qkv_dim, + const float* HWY_RESTRICT scales_old, + const float* HWY_RESTRICT q_scales_new, + const float* HWY_RESTRICT softmax_buf, + const hwy::Span>& kvs, + float* HWY_RESTRICT C_accumulators) { + namespace hn = hwy::HWY_NAMESPACE; + using BF16 = hwy::bfloat16_t; + const DF_T df; + const DBF_T dbf; + const hn::FixedTag df4; + using VBF = hn::Vec; + using VF = hn::Vec; + using VF4 = hn::Vec; + + constexpr size_t kBlockSize = BENCHMARK_BLOCK_SIZE; + + // Pre-pack Q into BF16 to avoid scaling and demoting in the inner loop. + const size_t kNumPackedVectors = ((kNumQueries + 1) / 2) * 2; + HWY_ALIGN BF16 q_packed[kBlockSize * 8 * hn::MaxLanes(dbf)]; + + for (size_t step_idx = 0; step_idx < actual_steps; ++step_idx) { + auto load_scaled_q = [&](size_t qp, VF& v0, VF& v1) HWY_ATTR { + const float* ptr = + softmax_buf + (q_base_idx + qp) * kBlockSize + step_idx * kStepSize; + VF qs = hn::Set(df, q_scales_new[qp]); + v0 = hn::Mul(hn::LoadU(df, ptr + 0), qs); + v1 = hn::Mul(hn::LoadU(df, ptr + hn::Lanes(df)), qs); + }; + + VF q0_l = hn::Zero(df), q0_h = hn::Zero(df), q1_l = hn::Zero(df), + q1_h = hn::Zero(df); + VF q2_l = hn::Zero(df), q2_h = hn::Zero(df), q3_l = hn::Zero(df), + q3_h = hn::Zero(df); + VF q4_l = hn::Zero(df), q4_h = hn::Zero(df), q5_l = hn::Zero(df), + q5_h = hn::Zero(df); + VF q6_l = hn::Zero(df), q6_h = hn::Zero(df), q7_l = hn::Zero(df), + q7_h = hn::Zero(df); + + if constexpr (kNumQueries >= 1) load_scaled_q(0, q0_l, q0_h); + if constexpr (kNumQueries >= 2) load_scaled_q(1, q1_l, q1_h); + if constexpr (kNumQueries >= 3) load_scaled_q(2, q2_l, q2_h); + if constexpr (kNumQueries >= 4) load_scaled_q(3, q3_l, q3_h); + if constexpr (kNumQueries >= 5) load_scaled_q(4, q4_l, q4_h); + if constexpr (kNumQueries >= 6) load_scaled_q(5, q5_l, q5_h); + if constexpr (kNumQueries >= 7) load_scaled_q(6, q6_l, q6_h); + if constexpr (kNumQueries >= 8) load_scaled_q(7, q7_l, q7_h); + + auto pack_and_store_pair = [&](size_t pair_idx, VF ql0, VF qh0, VF ql1, + VF qh1) HWY_ATTR { + BF16* dst = q_packed + step_idx * kNumPackedVectors * kStepSize + + pair_idx * 2 * kStepSize; + using D64 = hn::Repartition; + const D64 d64; + using D64_half = hn::Half; + const D64_half d64_half; + using dbf_half_t = hn::Half; + const dbf_half_t dbf_half; + + auto ql0_bf = hn::DemoteTo(dbf_half, ql0); + auto ql1_bf = hn::DemoteTo(dbf_half, ql1); + + auto qh0_bf = hn::DemoteTo(dbf_half, qh0); + auto qh1_bf = hn::DemoteTo(dbf_half, qh1); + + hn::Vec A0, A1; + if constexpr (kStepSize > 8) { + auto ql0_64 = hn::BitCast(d64_half, ql0_bf); + auto ql1_64 = hn::BitCast(d64_half, ql1_bf); + // This interleaves within 128-bit block so it's fast. + auto lo_l = hn::InterleaveLower(d64_half, ql0_64, ql1_64); + auto hi_l = hn::InterleaveUpper(d64_half, ql0_64, ql1_64); + A0 = hn::BitCast(dbf, hn::Combine(d64, hi_l, lo_l)); + + auto qh0_64 = hn::BitCast(d64_half, qh0_bf); + auto qh1_64 = hn::BitCast(d64_half, qh1_bf); + auto lo_h = hn::InterleaveLower(d64_half, qh0_64, qh1_64); + auto hi_h = hn::InterleaveUpper(d64_half, qh0_64, qh1_64); + A1 = hn::BitCast(dbf, hn::Combine(d64, hi_h, lo_h)); + } else { + A0 = hn::Combine(dbf, ql1_bf, ql0_bf); + A1 = hn::Combine(dbf, qh1_bf, qh0_bf); + } + + hn::StoreU(A0, dbf, dst + 0); + hn::StoreU(A1, dbf, dst + kStepSize); + }; + + if constexpr (kNumQueries >= 1) + pack_and_store_pair(0, q0_l, q0_h, q1_l, q1_h); + if constexpr (kNumQueries >= 3) + pack_and_store_pair(1, q2_l, q2_h, q3_l, q3_h); + if constexpr (kNumQueries >= 5) + pack_and_store_pair(2, q4_l, q4_h, q5_l, q5_h); + if constexpr (kNumQueries >= 7) + pack_and_store_pair(3, q6_l, q6_h, q7_l, q7_h); + } + + // Pre-compute V tile pointers to avoid row lookups in the inner loop. + const BF16* v_ptrs[8]; + size_t start_tile_idx = + (position - current_kv_start_offset) / KVCache::kTileSize; + + // Optimized tile pointer loading + size_t num_tiles_in_block = + (actual_steps * kStepSize + KVCache::kTileSize - 1) / KVCache::kTileSize; + for (size_t t = 0; t < num_tiles_in_block; ++t) { + const BF16* tile_base = reinterpret_cast( + kvs[current_kv_idx].RowBytes(start_tile_idx + t)); + v_ptrs[t] = tile_base + qkv_dim * 32; + } + + // Step-Dependent Pre-computation + const BF16* step_q_ptrs[BENCHMARK_BLOCK_SIZE / 8]; + const BF16* step_v_tiles[BENCHMARK_BLOCK_SIZE / 8]; + size_t step_offsets_even[BENCHMARK_BLOCK_SIZE / 8]; + size_t step_offsets_odd[BENCHMARK_BLOCK_SIZE / 8]; + + for (size_t step_idx = 0; step_idx < actual_steps; ++step_idx) { + step_q_ptrs[step_idx] = q_packed + step_idx * kNumPackedVectors * kStepSize; + size_t step_pos = position + step_idx * kStepSize; + size_t global_token_pos = step_pos - current_kv_start_offset; + size_t tile_idx = global_token_pos / 32 - start_tile_idx; + step_v_tiles[step_idx] = v_ptrs[tile_idx]; + size_t t = step_pos % 32; + size_t t_odd = t + kStepSize / 2; + step_offsets_even[step_idx] = (t / 16) * 64 + ((t % 16) / 4) * 8; + step_offsets_odd[step_idx] = (t_odd / 16) * 64 + ((t_odd % 16) / 4) * 8; + } + + for (size_t ch_base = 0; ch_base < qkv_dim; ch_base += 8) { + // Initialize accumulators to 0 + VF C00 = hn::Zero(df), C01 = hn::Zero(df), C02 = hn::Zero(df), + C03 = hn::Zero(df); + VF C10 = hn::Zero(df), C11 = hn::Zero(df), C12 = hn::Zero(df), + C13 = hn::Zero(df); + VF C20 = hn::Zero(df), C21 = hn::Zero(df), C22 = hn::Zero(df), + C23 = hn::Zero(df); + VF C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df), + C33 = hn::Zero(df); + + size_t ch_g = ch_base / 4; + size_t v_offset = ch_g * 128; + size_t v_next_offset = v_offset + 128; + + for (size_t step_idx = 0; step_idx < actual_steps; ++step_idx) { + // Retrieve pre-computed values + const BF16* q_ptr = step_q_ptrs[step_idx]; + const BF16* v_tile_b_curr = step_v_tiles[step_idx]; + size_t offset_even = step_offsets_even[step_idx]; + size_t offset_odd = step_offsets_odd[step_idx]; + + const BF16* v_ptr = v_tile_b_curr + v_offset; + const BF16* v_ptr_next = v_tile_b_curr + v_next_offset; + + VBF B_even0, B_even1, B_even2, B_even3; + VBF B_odd0, B_odd1, B_odd2, B_odd3; + + B_even0 = hn::LoadU(dbf, v_ptr + offset_even); + B_odd0 = hn::LoadU(dbf, v_ptr + offset_odd); + B_even1 = hn::LoadU(dbf, v_ptr + offset_even + 32); + B_odd1 = hn::LoadU(dbf, v_ptr + offset_odd + 32); + + B_even2 = hn::LoadU(dbf, v_ptr_next + offset_even); + B_odd2 = hn::LoadU(dbf, v_ptr_next + offset_odd); + B_even3 = hn::LoadU(dbf, v_ptr_next + offset_even + 32); + B_odd3 = hn::LoadU(dbf, v_ptr_next + offset_odd + 32); + + // Even halves first (A0, A2, A4, A6) + if constexpr (kNumQueries >= 1) { + const auto A0 = hn::LoadU(dbf, q_ptr + 0); + C00 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even0, C00); + C01 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even1, C01); + C02 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even2, C02); + C03 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even3, C03); + } + if constexpr (kNumQueries >= 3) { + const auto A2 = hn::LoadU(dbf, q_ptr + 2 * kStepSize); + C10 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even0, C10); + C11 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even1, C11); + C12 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even2, C12); + C13 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even3, C13); + } + if constexpr (kNumQueries >= 5) { + const auto A4 = hn::LoadU(dbf, q_ptr + 4 * kStepSize); + C20 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even0, C20); + C21 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even1, C21); + C22 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even2, C22); + C23 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even3, C23); + } + if constexpr (kNumQueries >= 7) { + const auto A6 = hn::LoadU(dbf, q_ptr + 6 * kStepSize); + C30 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even0, C30); + C31 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even1, C31); + C32 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even2, C32); + C33 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even3, C33); + } + + // Odd halves second (A1, A3, A5, A7) + if constexpr (kNumQueries >= 1) { + const auto A1 = hn::LoadU(dbf, q_ptr + kStepSize); + C00 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd0, C00); + C01 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd1, C01); + C02 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd2, C02); + C03 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd3, C03); + } + if constexpr (kNumQueries >= 3) { + const auto A3 = hn::LoadU(dbf, q_ptr + 3 * kStepSize); + C10 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd0, C10); + C11 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd1, C11); + C12 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd2, C12); + C13 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd3, C13); + } + if constexpr (kNumQueries >= 5) { + const auto A5 = hn::LoadU(dbf, q_ptr + 5 * kStepSize); + C20 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd0, C20); + C21 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd1, C21); + C22 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd2, C22); + C23 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd3, C23); + } + if constexpr (kNumQueries >= 7) { + const auto A7 = hn::LoadU(dbf, q_ptr + 7 * kStepSize); + C30 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd0, C30); + C31 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd1, C31); + C32 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd2, C32); + C33 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd3, C33); + } + } + + // Reduce accumulators to 128-bit and add scaled old values + auto reduce_and_store_pair = [&](size_t r, size_t col_g, VF acc0, + VF acc1) HWY_ATTR { + acc0 = SumReduceSegments(df, acc0); + acc1 = SumReduceSegments(df, acc1); + + // Safe extraction of the lowest 128-bits (4 lanes) via aligned stack + // spill/load to ensure cross-compilation safety on x86/AVX targets. + HWY_ALIGN float temp0[hn::MaxLanes(df)]; + HWY_ALIGN float temp1[hn::MaxLanes(df)]; + hn::StoreU(acc0, df, temp0); + hn::StoreU(acc1, df, temp1); + VF4 acc0_red = hn::LoadU(df4, temp0); + VF4 acc1_red = hn::LoadU(df4, temp1); + + // Load and scale old values (128-bit) + float scale0 = scales_old[2 * r]; + float scale1 = (2 * r + 1 < kNumQueries) ? scales_old[2 * r + 1] : 1.0f; + VF4 s0 = hn::Set(df4, scale0); + VF4 s1 = hn::Set(df4, scale1); + + const float* row0 = + C_accumulators + (q_base_idx + 2 * r) * qkv_dim + ch_base + col_g * 4; + VF4 old0 = hn::Mul(hn::LoadU(df4, row0), s0); + + VF4 old1; + if (2 * r + 1 < kNumQueries) { + const float* row1 = C_accumulators + + (q_base_idx + 2 * r + 1) * qkv_dim + ch_base + + col_g * 4; + old1 = hn::Mul(hn::LoadU(df4, row1), s1); + } else { + old1 = hn::Zero(df4); + } + + // Interleave old values to match accumulator layout + VF4 old_acc0 = hn::ConcatLowerLower(df4, old1, old0); + VF4 old_acc1 = hn::ConcatUpperUpper(df4, old1, old0); + + // Add to reduced accumulators + VF4 final_acc0 = hn::Add(acc0_red, old_acc0); + VF4 final_acc1 = hn::Add(acc1_red, old_acc1); + + // Undo interleaving and store back + float* out = + C_accumulators + (q_base_idx + 2 * r) * qkv_dim + ch_base + col_g * 4; + VF4 q0 = ConcatLowerLower_VLA(df4, final_acc1, final_acc0); + hn::StoreU(q0, df4, out); + + VF4 q1; + if (2 * r + 1 < kNumQueries) { + float* out1 = C_accumulators + (q_base_idx + 2 * r + 1) * qkv_dim + + ch_base + col_g * 4; + q1 = ConcatUpperUpper_VLA(df4, final_acc1, final_acc0); + hn::StoreU(q1, df4, out1); + } else { + q1 = hn::Zero(df4); + } + }; + + if constexpr (kNumQueries >= 1) { + reduce_and_store_pair(0, 0, C00, C01); + reduce_and_store_pair(0, 1, C02, C03); + } + if constexpr (kNumQueries >= 3) { + reduce_and_store_pair(1, 0, C10, C11); + reduce_and_store_pair(1, 1, C12, C13); + } + if constexpr (kNumQueries >= 5) { + reduce_and_store_pair(2, 0, C20, C21); + reduce_and_store_pair(2, 1, C22, C23); + } + if constexpr (kNumQueries >= 7) { + reduce_and_store_pair(3, 0, C30, C31); + reduce_and_store_pair(3, 1, C32, C33); + } + } +} + +template > +HWY_INLINE void UpdateOnlineSoftmaxSingleQuery( + DF df, float* HWY_RESTRICT q_logits, size_t actual_block_size, + size_t q, // global query index + float* HWY_RESTRICT max_logits, float* HWY_RESTRICT exp_denominator_sums, + size_t q_offset, // local query index in the group (0..7) + float* HWY_RESTRICT scales_old, float* HWY_RESTRICT q_scales_new) { + float block_max = kMaskedLogitVal; + VF v_max0 = hn::Set(df, block_max); + VF v_max1 = v_max0; + VF v_max2 = v_max0; + VF v_max3 = v_max0; + + size_t t = 0; + const size_t L_f = hn::Lanes(df); + const size_t unroll_step = 4 * L_f; + for (; t + unroll_step <= actual_block_size; t += unroll_step) { + v_max0 = hn::Max(v_max0, hn::LoadU(df, q_logits + t)); + v_max1 = hn::Max(v_max1, hn::LoadU(df, q_logits + t + L_f)); + v_max2 = hn::Max(v_max2, hn::LoadU(df, q_logits + t + 2 * L_f)); + v_max3 = hn::Max(v_max3, hn::LoadU(df, q_logits + t + 3 * L_f)); + } + for (; t + L_f <= actual_block_size; t += L_f) { + v_max0 = hn::Max(v_max0, hn::LoadU(df, q_logits + t)); + } + v_max0 = hn::Max(hn::Max(v_max0, v_max1), hn::Max(v_max2, v_max3)); + if (t < actual_block_size) { + const size_t remaining = actual_block_size - t; + VF v_logits = hn::LoadN(df, q_logits + t, remaining); + auto mask = hn::FirstN(df, remaining); + VF masked_logits = + hn::IfThenElse(mask, v_logits, hn::Set(df, kMaskedLogitVal)); + v_max0 = hn::Max(v_max0, masked_logits); + } + block_max = std::max(block_max, hn::ReduceMax(df, v_max0)); + + float old_m = max_logits[q]; + float old_sum = exp_denominator_sums[q]; + float new_m = std::max(old_m, block_max); + + float block_sum = 0.0f; + VF v_sum0 = hn::Zero(df); + VF v_sum1 = hn::Zero(df); + VF v_sum2 = hn::Zero(df); + VF v_sum3 = hn::Zero(df); + VF v_new_m = hn::Set(df, new_m); + + t = 0; + for (; t + unroll_step <= actual_block_size; t += unroll_step) { + VF v_logits0 = hn::LoadU(df, q_logits + t); + VF v_logits1 = hn::LoadU(df, q_logits + t + L_f); + VF v_logits2 = hn::LoadU(df, q_logits + t + 2 * L_f); + VF v_logits3 = hn::LoadU(df, q_logits + t + 3 * L_f); + + VF v_exp0 = hn::FastExpMinusOrZero(df, hn::Sub(v_logits0, v_new_m)); + VF v_exp1 = hn::FastExpMinusOrZero(df, hn::Sub(v_logits1, v_new_m)); + VF v_exp2 = hn::FastExpMinusOrZero(df, hn::Sub(v_logits2, v_new_m)); + VF v_exp3 = hn::FastExpMinusOrZero(df, hn::Sub(v_logits3, v_new_m)); + + hn::StoreU(v_exp0, df, q_logits + t); + hn::StoreU(v_exp1, df, q_logits + t + L_f); + hn::StoreU(v_exp2, df, q_logits + t + 2 * L_f); + hn::StoreU(v_exp3, df, q_logits + t + 3 * L_f); + + v_sum0 = hn::Add(v_sum0, v_exp0); + v_sum1 = hn::Add(v_sum1, v_exp1); + v_sum2 = hn::Add(v_sum2, v_exp2); + v_sum3 = hn::Add(v_sum3, v_exp3); + } + for (; t + L_f <= actual_block_size; t += L_f) { + VF v_logits = hn::LoadU(df, q_logits + t); + VF v_exp = hn::FastExpMinusOrZero(df, hn::Sub(v_logits, v_new_m)); + hn::StoreU(v_exp, df, q_logits + t); + v_sum0 = hn::Add(v_sum0, v_exp); + } + v_sum0 = hn::Add(hn::Add(v_sum0, v_sum1), hn::Add(v_sum2, v_sum3)); + if (t < actual_block_size) { + const size_t remaining = actual_block_size - t; + auto mask = hn::FirstN(df, remaining); + VF v_logits = hn::LoadN(df, q_logits + t, remaining); + VF v_exp = hn::FastExpMinusOrZero(df, hn::Sub(v_logits, v_new_m)); + hn::StoreN(v_exp, df, q_logits + t, remaining); + v_sum0 = hn::Add(v_sum0, hn::IfThenElseZero(mask, v_exp)); + } + block_sum = hn::ReduceSum(df, v_sum0); + + float exp_diff = 1.0f; + if (old_m != new_m) { + const hn::CappedTag d1; + auto v_diff = hn::Set(d1, old_m - new_m); + auto v_exp = hn::FastExpMinusOrZero(d1, v_diff); + exp_diff = hn::GetLane(v_exp); + } + float new_sum = old_sum * exp_diff + block_sum; + + float scale_old = (new_sum > 0.0f) ? (old_sum * exp_diff) / new_sum : 1.0f; + float q_scale = (new_sum > 0.0f) ? (1.0f / new_sum) : 0.0f; + + scales_old[q_offset] = scale_old; + q_scales_new[q_offset] = q_scale; + max_logits[q] = new_m; + exp_denominator_sums[q] = new_sum; +} + +template +HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( + const hwy::Span> kvs, size_t q_count, + const Q_T* HWY_RESTRICT q_base, hwy::Span q_scales, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { + using BF16 = hwy::bfloat16_t; + + const size_t qkv_dim = att_out.Cols(); + using DF = hn::ScalableTag; + using DBF = hn::ScalableTag; + using DU = hn::ScalableTag; + const DF df; + const DBF dbf; + const DU du; + using VF = hn::Vec; + const size_t step_size = hn::Lanes(dbf); + if (step_size > 32) { + HWY_ABORT( + "Unsupported step size (vector width) %zu. Only up to 512-bit (32 " + "lanes) is supported.", + step_size); + } + const float one_over_cap = 1.0f / att_cap; + + constexpr int kNumQueriesPerLoop = 8; + constexpr size_t kTileSize = 32; + constexpr size_t kBlockSize = BENCHMARK_BLOCK_SIZE; + const size_t kStepsPerTile = + std::max(size_t(1), KVCache::kTileSize / step_size); + + TileAttentionGroupParams preamble(q_count, kNumQueriesPerLoop, step_size, + start_pos_per_query, last_pos_per_query, + att_cap, att_out); + + const size_t largest_last_pos = preamble.largest_last_pos; + + const auto min_start_pos_per_group = preamble.min_start_pos_per_group; + const auto max_start_pos_per_group = preamble.max_start_pos_per_group; + const auto min_last_pos_per_group = preamble.min_last_pos_per_group; + const auto max_last_pos_per_group = preamble.max_last_pos_per_group; + + hwy::AlignedVector C_accumulators(hwy::RoundUpTo(q_count, 8) * qkv_dim, + 0.0f); + hwy::AlignedVector softmax_buf(q_count * kBlockSize, kMaskedLogitVal); + + size_t current_kv_idx = 0; + size_t current_kv_start_offset = 0; + size_t position = 0; + + while (position <= largest_last_pos) { + std::fill(softmax_buf.begin(), softmax_buf.end(), kMaskedLogitVal); + size_t remaining_tokens = largest_last_pos - position + 1; + + while (position - current_kv_start_offset >= + kvs[current_kv_idx].Rows() * KVCache::kTileSize) { + current_kv_start_offset += + kvs[current_kv_idx].Rows() * KVCache::kTileSize; + current_kv_idx++; + } + size_t kv_rows = kvs[current_kv_idx].Rows() * KVCache::kTileSize; + size_t kv_remaining = kv_rows - (position - current_kv_start_offset); + + size_t actual_block_size = + std::min(kBlockSize, hwy::RoundUpTo(remaining_tokens, step_size)); + actual_block_size = std::min(actual_block_size, kv_remaining); + + size_t actual_steps = actual_block_size / step_size; + [[maybe_unused]] size_t actual_M = + hwy::DivCeil(actual_block_size, kTileSize); + + size_t macro_tile_start_pos = position; + + auto inner_loop_qk = [&](size_t query_idx, + size_t step_idx) HWY_ATTR { + size_t loop_idx = query_idx / kNumQueriesPerLoop; + size_t step_pos = position + step_idx * step_size; + if (step_pos + step_size <= min_start_pos_per_group[loop_idx] || + step_pos > max_last_pos_per_group[loop_idx]) { + float* softmax_buf_ptr = + softmax_buf.data() + query_idx * kBlockSize + step_idx * step_size; + for (int q = 0; q < kNumQueries; ++q) { + for (int t = 0; t < step_size; ++t) { + softmax_buf_ptr[q * kBlockSize + t] = kMaskedLogitVal; + } + } + return; + } + + size_t tile = step_idx / kStepsPerTile; + size_t s_idx = + (position - current_kv_start_offset) / KVCache::kTileSize + tile; + + const size_t current_pos = macro_tile_start_pos + step_idx * step_size; + const BF16* tile_base = + reinterpret_cast(kvs[current_kv_idx].RowBytes(s_idx)); + + const BF16* q_group = q_base + query_idx * qkv_dim; + + VF x_0_p_0 = hn::Zero(df), x_0_p_1 = hn::Zero(df), x_1_p_0 = hn::Zero(df), + x_1_p_1 = hn::Zero(df); + VF x_2_p_0 = hn::Zero(df), x_2_p_1 = hn::Zero(df), x_3_p_0 = hn::Zero(df), + x_3_p_1 = hn::Zero(df); + VF x_4_p_0 = hn::Zero(df), x_4_p_1 = hn::Zero(df), x_5_p_0 = hn::Zero(df), + x_5_p_1 = hn::Zero(df); + VF x_6_p_0 = hn::Zero(df), x_6_p_1 = hn::Zero(df), x_7_p_0 = hn::Zero(df), + x_7_p_1 = hn::Zero(df); + + VF C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, + C32, C33; + QDotKTilexUpTo8MatrixAccumulation( + df, q_group, tile_base, current_pos, qkv_dim, C00, C01, C02, C03, C10, + C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33); + auto pack_queries = [&](VF c_left, VF c_right, VF& x_even, + VF& x_odd) HWY_ATTR { + using D64 = hn::Repartition; + const D64 d64; + + auto c_left_64 = hn::BitCast(d64, c_left); + auto c_right_64 = hn::BitCast(d64, c_right); + + auto even_scrambled = hn::ConcatEven(d64, c_right_64, c_left_64); + auto odd_scrambled = hn::ConcatOdd(d64, c_right_64, c_left_64); + + x_even = hn::BitCast(df, even_scrambled); + x_odd = hn::BitCast(df, odd_scrambled); + }; + if constexpr (kNumQueries >= 1) pack_queries(C00, C01, x_0_p_0, x_1_p_0); + if constexpr (kNumQueries >= 1) pack_queries(C02, C03, x_0_p_1, x_1_p_1); + if constexpr (kNumQueries >= 3) pack_queries(C10, C11, x_2_p_0, x_3_p_0); + if constexpr (kNumQueries >= 3) pack_queries(C12, C13, x_2_p_1, x_3_p_1); + if constexpr (kNumQueries >= 5) pack_queries(C20, C21, x_4_p_0, x_5_p_0); + if constexpr (kNumQueries >= 5) pack_queries(C22, C23, x_4_p_1, x_5_p_1); + if constexpr (kNumQueries >= 7) pack_queries(C30, C31, x_6_p_0, x_7_p_0); + if constexpr (kNumQueries >= 7) pack_queries(C32, C33, x_6_p_1, x_7_p_1); + + constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); + constexpr int kSecondHalfAmountOfQueries = + kNumQueries - kFirstHalfAmountOfQueries; + ApplySoftCap( + df, att_cap, one_over_cap, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, + x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1); + if constexpr (kNumQueries > 4) { + ApplySoftCap( + df, att_cap, one_over_cap, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, + x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } + + if (current_pos < max_start_pos_per_group[loop_idx] || + current_pos + step_size - 1 > min_last_pos_per_group[loop_idx]) { + ApplyMasking( + df, du, current_pos, start_pos_per_query.data() + query_idx, + last_pos_per_query.data() + query_idx, x_0_p_0, x_0_p_1, x_1_p_0, + x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, + x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } + + float* softmax_buf_ptr = + softmax_buf.data() + query_idx * kBlockSize + step_idx * step_size; + auto store_logits = [&](const VF& x_p0, const VF& x_p1, + size_t q) HWY_ATTR { + hn::StoreU(x_p0, df, softmax_buf_ptr + q * kBlockSize + 0); + hn::StoreU(x_p1, df, softmax_buf_ptr + q * kBlockSize + hn::Lanes(df)); + }; + + if constexpr (kNumQueries >= 1) store_logits(x_0_p_0, x_0_p_1, 0); + if constexpr (kNumQueries >= 2) store_logits(x_1_p_0, x_1_p_1, 1); + if constexpr (kNumQueries >= 3) store_logits(x_2_p_0, x_2_p_1, 2); + if constexpr (kNumQueries >= 4) store_logits(x_3_p_0, x_3_p_1, 3); + if constexpr (kNumQueries >= 5) store_logits(x_4_p_0, x_4_p_1, 4); + if constexpr (kNumQueries >= 6) store_logits(x_5_p_0, x_5_p_1, 5); + if constexpr (kNumQueries >= 7) store_logits(x_6_p_0, x_6_p_1, 6); + if constexpr (kNumQueries >= 8) store_logits(x_7_p_0, x_7_p_1, 7); + }; + + for (size_t step_idx = 0; step_idx < actual_steps; ++step_idx) { + size_t query_idx = 0; + for (; query_idx + kNumQueriesPerLoop <= q_count; + query_idx += kNumQueriesPerLoop) { + inner_loop_qk.template operator()(query_idx, + step_idx); + } + if (query_idx < q_count) { + size_t rem = q_count - query_idx; + if (rem >= 8) { + inner_loop_qk.template operator()<8>(query_idx, step_idx); + query_idx += 8; + rem -= 8; + } + if (rem >= 4) { + inner_loop_qk.template operator()<4>(query_idx, step_idx); + query_idx += 4; + rem -= 4; + } + switch (rem) { + case 1: + inner_loop_qk.template operator()<1>(query_idx, step_idx); + break; + case 2: + inner_loop_qk.template operator()<2>(query_idx, step_idx); + break; + case 3: + inner_loop_qk.template operator()<3>(query_idx, step_idx); + break; + } + } + } + + for (size_t q_base_idx = 0; q_base_idx < q_count; q_base_idx += 8) { + size_t actual_q_count = std::min(size_t(8), q_count - q_base_idx); + + HWY_ALIGN float scales_old[8]; + HWY_ALIGN float q_scales_new[8]; + for (int i = 0; i < 8; ++i) { + scales_old[i] = 1.0f; + q_scales_new[i] = 0.0f; + } + + for (size_t q_offset = 0; q_offset < actual_q_count; ++q_offset) { + size_t q = q_base_idx + q_offset; + if (position + actual_block_size <= start_pos_per_query[q] || + position > last_pos_per_query[q]) { + // Skip update for completely masked query in this block. + // scales_old[q_offset] remains 1.0f, q_scales_new[q_offset] remains + // 0.0f. + continue; + } + float* q_logits = softmax_buf.data() + q * kBlockSize; + UpdateOnlineSoftmaxSingleQuery( + df, q_logits, actual_block_size, q, max_logits, + exp_denominator_sums, q_offset, scales_old, q_scales_new); + } + + auto call_sv_block = [&]() HWY_ATTR { + HWY_LANES_CONSTEXPR size_t step_size = hn::Lanes(dbf); + if constexpr (HWY_HAVE_CONSTEXPR_LANES) { + TileFlashAttentionSVBlock( + q_base_idx, position, current_kv_start_offset, current_kv_idx, + actual_steps, qkv_dim, scales_old, q_scales_new, + softmax_buf.data(), kvs, C_accumulators.data()); + } else { + if (step_size == 32) { + TileFlashAttentionSVBlock( + q_base_idx, position, current_kv_start_offset, current_kv_idx, + actual_steps, qkv_dim, scales_old, q_scales_new, + softmax_buf.data(), kvs, C_accumulators.data()); + } else if (step_size == 16) { + TileFlashAttentionSVBlock( + q_base_idx, position, current_kv_start_offset, current_kv_idx, + actual_steps, qkv_dim, scales_old, q_scales_new, + softmax_buf.data(), kvs, C_accumulators.data()); + } else { // step_size == 8 (guaranteed by top validation) + TileFlashAttentionSVBlock( + q_base_idx, position, current_kv_start_offset, current_kv_idx, + actual_steps, qkv_dim, scales_old, q_scales_new, + softmax_buf.data(), kvs, C_accumulators.data()); + } + } + }; + + if (actual_q_count >= 8) { + call_sv_block.template operator()<8>(); + } else if (actual_q_count >= 7) { + call_sv_block.template operator()<7>(); + } else if (actual_q_count >= 6) { + call_sv_block.template operator()<6>(); + } else if (actual_q_count >= 5) { + call_sv_block.template operator()<5>(); + } else if (actual_q_count >= 4) { + call_sv_block.template operator()<4>(); + } else if (actual_q_count >= 3) { + call_sv_block.template operator()<3>(); + } else if (actual_q_count >= 2) { + call_sv_block.template operator()<2>(); + } else if (actual_q_count >= 1) { + call_sv_block.template operator()<1>(); + } + } + + position += actual_block_size; + } + + for (size_t qi = 0; qi < q_count; ++qi) { + float* out = att_out.Row(qi); + const float* accum = C_accumulators.data() + qi * qkv_dim; + for (size_t d = 0; d < qkv_dim; d += hn::Lanes(df)) { + VF v = hn::LoadU(df, accum + d); + hn::StoreU(v, df, out + d); + } + } +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_ARM_TOGGLE diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 30534ec8..75092b63 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -13,9 +13,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS +#endif // HWY_DISABLED_TARGETS + +#include +#include + +#include // std::max +#include // std::abs #include #include #include +#include #include #include @@ -28,23 +38,12 @@ #include "gemma/weights.h" #include "ops/matmul.h" #include "ops/ops.h" -#include "util/test_util.h" -#include "hwy/nanobenchmark.h" -#ifndef HWY_DISABLED_TARGETS -#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS -#endif // HWY_DISABLED_TARGETS - -#include -#include - -#include // std::max -#include // std::abs -#include - #include "util/mat.h" +#include "util/test_util.h" #include "util/threading_context.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/nanobenchmark.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -248,6 +247,7 @@ std::unique_ptr> MakeCopyOfMat(const MatPtrT& mat, } void AssertClose(const MatPtrT& a, const MatPtrT& b) { + size_t failures = 0; // Avoid comparing the padding bytes, which are uninitialized. for (size_t r = 0; r < a.Rows(); ++r) { const float* HWY_RESTRICT a_row = a.Row(r); @@ -257,11 +257,20 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { if (rel_abs_delta > 0.0f) { rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); } - EXPECT_LT(rel_abs_delta, 1e-3) - << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," - << c << "]=" << b_row[c]; + if (rel_abs_delta >= 1e-3) { + if (failures < 5) { + EXPECT_LT(rel_abs_delta, 1e-3) + << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," + << c << "]=" << b_row[c]; + } + failures++; + } } } + if (failures > 5) { + ADD_FAILURE() << "Truncated " << (failures - 5) + << " additional failures in AssertClose."; + } } template @@ -486,14 +495,14 @@ const std::vector att_out_gold = { template void RunTiledFlashAttentionTest(gcpp::KVEncoding kv_encoding, - AttentionImpl attention_impl, - float tol, float tol_exp, float tol_max) { + AttentionImpl attention_impl, float tol, + float tol_exp, float tol_max, + float att_cap = 10.0f) { size_t qkv_dim = 64; size_t kv_seq_len = 60; // number of tokens we will attend to. // Not divisible by tiles size to test the padding logic. size_t padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); - float att_cap = 10.0f; size_t num_queries = 8; size_t num_queries_per_timestep = 4; size_t num_tokens = num_queries / num_queries_per_timestep; @@ -517,9 +526,9 @@ void RunTiledFlashAttentionTest(gcpp::KVEncoding kv_encoding, using DF = hn::ScalableTag; const DF df; HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); - size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); - std::vector exp_denominator_sums(num_queries_rounded_to_laness); - std::vector max_logits(num_queries_rounded_to_laness); + size_t num_queries_rounded_to_lanes = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums(num_queries_rounded_to_lanes); + std::vector max_logits(num_queries_rounded_to_lanes); for (size_t i = 0; i < num_queries; ++i) { hwy::ZeroBytes(att_out.Row(i), qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); @@ -565,6 +574,17 @@ void RunTiledFlashAttentionTest(gcpp::KVEncoding kv_encoding, hwy::Span(start_pos_per_query), hwy::Span(last_pos_per_query), att_cap, att_out, exp_denominator_sums.data(), max_logits.data()); + } else if (attention_impl == AttentionImpl::kFlashMatrixAccumulation) { + size_t num_queries_rounded = hwy::RoundUpTo(num_queries, 2); + hwy::AlignedVector bf16_queries(num_queries_rounded * qkv_dim); + CompressAndTransposeQueriesMatrixAccumulation( + q_all.data(), bf16_queries.data(), num_queries, qkv_dim); + + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( + kvs, num_queries, bf16_queries.data(), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); } else { DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( kvs, num_queries, q_all.data(), @@ -573,27 +593,58 @@ void RunTiledFlashAttentionTest(gcpp::KVEncoding kv_encoding, exp_denominator_sums.data(), max_logits.data()); } - PrintMatPtr(att_out); + // PrintMatPtr(att_out); + size_t failures = 0; for (size_t i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], tol_exp) - << "i=" << i; - EXPECT_NEAR(max_logits[i], max_logits_gold[i], tol_max) << "i=" << i; + + float diff_exp = + std::abs(exp_denominator_sums[i] - exp_denominator_sums_gold[i]); + if (diff_exp >= tol_exp) { + if (failures < 5) { + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], + tol_exp) + << "i=" << i; + } + failures++; + } + + float diff_max = std::abs(max_logits[i] - max_logits_gold[i]); + if (diff_max >= tol_max) { + if (failures < 5) { + EXPECT_NEAR(max_logits[i], max_logits_gold[i], tol_max) << "i=" << i; + } + failures++; + } + for (size_t j = 0; j < qkv_dim; ++j) { if (j == 0 && attention_impl == AttentionImpl::kFlashTransposedQsBF16 && kv_encoding == gcpp::KVEncoding::kBF16TwoTranspositions) { std::cerr << "att_out[0][" << j << "]=" << att_out.Row(i)[j] << " gold=" << att_out_gold[i * qkv_dim + j] << "\n"; } - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], tol); + + float diff_out = + std::abs(att_out.Row(i)[j] - att_out_gold[i * qkv_dim + j]); + if (diff_out >= tol) { + if (failures < 5) { + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], tol) + << "i=" << i << " j=" << j; + } + failures++; + } } } + if (failures > 5) { + ADD_FAILURE() << "Truncated " << (failures - 5) + << " additional failures in RunTiledFlashAttentionTest."; + } } void TestTiledFlashAttention() { RunTiledFlashAttentionTest(gcpp::KVEncoding::kF32, - AttentionImpl::kFlash, 1e-5f, 1e-3f, 1e-6f); + AttentionImpl::kFlash, 1e-5f, 1e-3f, 5e-6f); } void TestTiledFlashAttentionBF16() { @@ -620,6 +671,189 @@ void TestTiledFlashAttentionInt8Int16() { 5e-3f, 2e-2f, 1e-3f); } +void TestTiledFlashAttentionBF16MatrixAccumulation() { + const hn::ScalableTag dbf; + if (hn::Lanes(dbf) > 32) { + GTEST_SKIP() << "Skipping MatrixAccumulation test for target with register " + "size > 512-bit."; + return; + } + + const float tol = 2.0e-3f; + const float tol_exp = 4e-2f; + const float tol_max = 1e-3f; + + RunTiledFlashAttentionTest(gcpp::KVEncoding::kBF16MatrixAccumulation, + AttentionImpl::kFlashMatrixAccumulation, tol, + tol_exp, tol_max, 0.0f); +} + +void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { + const hn::ScalableTag dbf; + if (hn::Lanes(dbf) > 32) { + GTEST_SKIP() << "Skipping MatrixAccumulation test for target with register " + "size > 512-bit."; + return; + } + + size_t qkv_dim = 64; + size_t kv_seq_len = 2048; // number of tokens we will attend to. + size_t padded_kv_seq_len = + hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + float att_cap = 0.0f; + size_t num_queries = 37; + size_t num_queries_per_timestep = 1; + size_t num_tokens = num_queries / num_queries_per_timestep; + size_t kv_seq_end = + kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + + // Set up reference BF16 + MatStorageT kv_ref( + "kv_ref", + Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, + 2 * qkv_dim * gcpp::KVCache::kTileSize), + ctx.allocator, MatPadding::kPacked); + PopulateTestKVCache(kv_ref, gcpp::KVEncoding::kBF16TwoTranspositions, + qkv_dim); + + AlignedFloatVector q_all_ref = PopulateTestQueries(num_queries, qkv_dim); + std::vector> bf16_queries_ref(num_queries * + qkv_dim); + CompressQueriesBF16Contiguous(q_all_ref.data(), qkv_dim, num_queries, + bf16_queries_ref.data()); + + MatStorageT att_out_ref("att_out_ref", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + + HWY_LANES_CONSTEXPR size_t lanes = 4; + size_t num_queries_rounded_to_lanes = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums_ref(num_queries_rounded_to_lanes); + std::vector max_logits_ref(num_queries_rounded_to_lanes); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out_ref.Row(i), + qkv_dim * sizeof(decltype(att_out_ref.Row(i)[0]))); + exp_denominator_sums_ref[i] = 0.0f; + max_logits_ref[i] = -std::numeric_limits::max() / 2.0f; + } + + // Set up Matrix Accumulation + size_t tile_size_bytes = *gcpp::GetTileSizeBytes( + gcpp::KVEncoding::kBF16MatrixAccumulation, qkv_dim); + size_t tile_size_in_elements = tile_size_bytes / sizeof(BF16); + MatStorageT kv("kv", + Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, + tile_size_in_elements), + ctx.allocator, MatPadding::kPacked); + PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16MatrixAccumulation, qkv_dim); + + AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); + size_t num_queries_rounded = hwy::RoundUpTo(num_queries, 2); + std::vector> bf16_queries( + num_queries_rounded * qkv_dim); + CompressAndTransposeQueriesMatrixAccumulation( + q_all.data(), bf16_queries.data(), num_queries, qkv_dim); + + MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + + std::vector exp_denominator_sums(num_queries_rounded_to_lanes); + std::vector max_logits(num_queries_rounded_to_lanes); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + ssize_t query_last_pos = kv_seq_end + token_idx; + ssize_t query_start_pos = + std::max(query_last_pos - 100000 + 1, static_cast(0)); + for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep; + ++q_head_idx) { + start_pos_per_query.push_back(query_start_pos); + last_pos_per_query.push_back(query_last_pos); + } + } + + hwy::Span kvs_ref(&kv_ref, 1); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kvs_ref, num_queries, bf16_queries_ref.data(), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out_ref, + exp_denominator_sums_ref.data(), max_logits_ref.data()); + + hwy::Span kvs(&kv, 1); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( + kvs, num_queries, bf16_queries.data(), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + + float max_abs_err = 0.0f; + float mse_sum = 0.0f; + float dot_prod = 0.0f; + float norm_ref = 0.0f; + float norm_out = 0.0f; + + const float tol_exp = 1e-1f; + const float tol_max = 1e-4f; + + size_t failures = 0; + for (size_t i = 0; i < num_queries; ++i) { + float diff_exp = + std::abs(exp_denominator_sums[i] - exp_denominator_sums_ref[i]); + if (diff_exp >= tol_exp) { + if (failures < 5) { + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_ref[i], + tol_exp) + << "i=" << i; + } + failures++; + } + + float diff_max = std::abs(max_logits[i] - max_logits_ref[i]); + if (diff_max >= tol_max) { + if (failures < 5) { + EXPECT_NEAR(max_logits[i], max_logits_ref[i], tol_max) << "i=" << i; + } + failures++; + } + + for (size_t j = 0; j < qkv_dim; ++j) { + float v_ref = att_out_ref.Row(i)[j]; + float v_out = att_out.Row(i)[j]; + float diff = std::abs(v_ref - v_out); + max_abs_err = std::max(max_abs_err, diff); + mse_sum += diff * diff; + dot_prod += v_ref * v_out; + norm_ref += v_ref * v_ref; + norm_out += v_out * v_out; + if (diff >= 1e-1f) { + if (failures < 5) { + EXPECT_NEAR(v_out, v_ref, 1e-1f) << "i=" << i << " j=" << j; + } + failures++; + } + } + } + float cosine_sim = dot_prod / (std::sqrt(norm_ref) * std::sqrt(norm_out)); + float mse = mse_sum / (num_queries * qkv_dim); + std::cerr << "=== Numerical Verification Results (Q:32, KV:2048) ===\n" + << " Cosine Similarity: " << cosine_sim << "\n" + << " Max Absolute Error: " << max_abs_err << "\n" + << " Mean Squared Error: " << mse << "\n"; + if (HWY_NATIVE_PER_BLOCK_2X2_MATMUL_BF16) { + std::cerr << " Using native PerBlock2x2MatMul\n"; + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -631,10 +865,18 @@ namespace gcpp { HWY_BEFORE_TEST(FlashAttentionTest); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttention); + HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionBF16); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, + TestTiledFlashAttentionBF16MatrixAccumulation); + +HWY_EXPORT_AND_TEST_P( + FlashAttentionTest, + TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8BF16); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8Int16); + HWY_AFTER_TEST(); } // namespace gcpp diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 3d58421f..e0f9bdc8 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -79,14 +79,19 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, // clang-format off if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs || runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsInt16 || - runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 + runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 || + runtime_config.attention_impl == AttentionImpl::kFlashMatrixAccumulation ) { // clang-format on const size_t num_tiles = hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize); tiled_seq_len = num_tiles * kTileSize; Type kv_cache_type; - if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 + if (runtime_config.attention_impl == + AttentionImpl::kFlashMatrixAccumulation) { + kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16); + } else if (runtime_config.attention_impl == + AttentionImpl::kFlashTransposedQsBF16 ) { kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16); } else if (runtime_config.attention_impl == @@ -124,6 +129,10 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, } Extents2D extents(total_num_tiles, tile_length); compact_kv_cache_ptr = MatPtr("kv_tiled", kv_cache_type, extents); + if (runtime_config.attention_impl == + AttentionImpl::kFlashMatrixAccumulation) { + compact_kv_cache_ptr.SetLayout(MatPtr::Layout::kBF16MatrixAccumulation); + } compact_kv_cache.AllocateFor(compact_kv_cache_ptr, allocator, MatPadding::kPacked); total_num_tiles = 0; @@ -138,6 +147,10 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, Extents2D(num_tiles_per_kv_head, tile_length)); kv_ptr.SetPtr(compact_kv_cache_ptr.RowBytes(total_num_tiles), compact_kv_cache_ptr.Stride()); + if (runtime_config.attention_impl == + AttentionImpl::kFlashMatrixAccumulation) { + kv_ptr.SetLayout(MatPtr::Layout::kBF16MatrixAccumulation); + } kv_head_ptrs.emplace_back(std::move(kv_ptr)); total_num_tiles += num_tiles_per_kv_head; } diff --git a/gemma/kv_transcoding.cc b/gemma/kv_transcoding.cc index ca06f0df..da635e5d 100644 --- a/gemma/kv_transcoding.cc +++ b/gemma/kv_transcoding.cc @@ -1,10 +1,12 @@ #include "gemma/kv_transcoding.h" +#include + #include +#include #include #include #include -#include #include "compression/types.h" #include "gemma/activations.h" @@ -26,6 +28,7 @@ std::optional GetTileSizeBytes(gcpp::KVEncoding encoding, kTileSize * 2 * sizeof(gcpp::KV_microscale_t); case gcpp::KVEncoding::kBF16: case gcpp::KVEncoding::kBF16TwoTranspositions: + case gcpp::KVEncoding::kBF16MatrixAccumulation: return qkv_dim * kTileSize * 2 * sizeof(gcpp::BF16); case gcpp::KVEncoding::kF32: case gcpp::KVEncoding::kF32TwoTranspositions: @@ -47,7 +50,6 @@ inline size_t KOffset(bool transposed, size_t qkv_dim, size_t dim, inline size_t VOffset(bool transposed, size_t qkv_dim, size_t dim, size_t token) { - HWY_DASSERT(dim < qkv_dim && token < kTileSize); return transposed ? ((token / 2) * qkv_dim * 2 + dim * 2 + (token % 2)) : (token * qkv_dim + dim); } @@ -121,6 +123,52 @@ void EncodeTileBF16(bool transposed, size_t qkv_dim, const DecodedTile& decoded, }); } +void EncodeTileBF16MatrixAccumulation(size_t qkv_dim, + const DecodedTile& decoded, + hwy::Span out_encoded_tile_data) { + gcpp::BF16* data = + HWY_RCAST_ALIGNED(gcpp::BF16*, out_encoded_tile_data.data()); + const size_t tile_size = decoded.tile_size; + const size_t v_start = qkv_dim * tile_size; + const size_t num_groups = tile_size / 8; + const size_t num_ch_groups = qkv_dim / 4; + + for (size_t ch_g = 0; ch_g < num_ch_groups; ++ch_g) { + for (size_t g = 0; g < num_groups; ++g) { + size_t base_offset = ch_g * 128 + g * 32; + // Pack K (8x4 block, token-major) + for (size_t t_in_g = 0; t_in_g < 8; ++t_in_g) { + size_t token = g * 8 + t_in_g; + for (size_t ch_in_g = 0; ch_in_g < 4; ++ch_in_g) { + size_t dim = ch_g * 4 + ch_in_g; + float val = decoded.k_elem(token, dim); + data[base_offset + t_in_g * 4 + ch_in_g] = + hwy::ConvertScalarTo(val); + } + } + } + } + + // Pack V (Contiguous SV-Blocked Layout) + for (size_t t = 0; t < tile_size; ++t) { + for (size_t c = 0; c < qkv_dim; ++c) { + float val = decoded.v_elem(t, c); + size_t g_t = t / 16; + size_t g_c = (c % 4) / 2; + size_t sub_block = (c / 4) * 4 + g_t * 2 + g_c; + + size_t t_prime = t % 16; + size_t c_prime = c % 2; + size_t g_t4 = t_prime / 4; + size_t t_double_prime = t_prime % 4; + size_t block_offset = g_t4 * 8 + c_prime * 4 + t_double_prime; + + size_t v_offset = sub_block * 32 + block_offset; + data[v_start + v_offset] = hwy::ConvertScalarTo(val); + } + } +} + void EncodeTileInt8(bool transposed, size_t qkv_dim, const DecodedTile& decoded, hwy::Span out_encoded_tile_data) { int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data()); @@ -200,6 +248,50 @@ void DecodeTileBF16(bool transposed, size_t qkv_dim, }); } +void DecodeTileBF16MatrixAccumulation(size_t qkv_dim, + hwy::Span encoded_tile_data, + DecodedTile* out) { + const gcpp::BF16* data = + HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded_tile_data.data()); + const size_t tile_size = out->tile_size; + const size_t v_start = qkv_dim * tile_size; + const size_t num_groups = tile_size / 8; + const size_t num_ch_groups = qkv_dim / 4; + + for (size_t ch_g = 0; ch_g < num_ch_groups; ++ch_g) { + for (size_t g = 0; g < num_groups; ++g) { + size_t base_offset = ch_g * 128 + g * 32; + // Unpack K (8x4 block, token-major) + for (size_t t_in_g = 0; t_in_g < 8; ++t_in_g) { + size_t token = g * 8 + t_in_g; + for (size_t ch_in_g = 0; ch_in_g < 4; ++ch_in_g) { + size_t dim = ch_g * 4 + ch_in_g; + out->k_elem(token, dim) = hwy::ConvertScalarTo( + data[base_offset + t_in_g * 4 + ch_in_g]); + } + } + } + } + + // Unpack V (Contiguous SV-Blocked Layout) + for (size_t t = 0; t < tile_size; ++t) { + for (size_t c = 0; c < qkv_dim; ++c) { + size_t g_t = t / 16; + size_t g_c = (c % 4) / 2; + size_t sub_block = (c / 4) * 4 + g_t * 2 + g_c; + + size_t t_prime = t % 16; + size_t c_prime = c % 2; + size_t g_t4 = t_prime / 4; + size_t t_double_prime = t_prime % 4; + size_t block_offset = g_t4 * 8 + c_prime * 4 + t_double_prime; + + size_t v_offset = sub_block * 32 + block_offset; + out->v_elem(t, c) = hwy::ConvertScalarTo(data[v_start + v_offset]); + } + } +} + void DecodeTileInt8(bool transposed, size_t qkv_dim, hwy::Span encoded_tile_data, DecodedTile* out) { const int8_t* k_data = @@ -262,6 +354,10 @@ bool DecodeTile(KVEncoding encoding, hwy::Span encoded_tile_data, DecodeTileBF16(transposed, qkv_dim, encoded_tile_data, out); return true; } + case gcpp::KVEncoding::kBF16MatrixAccumulation: { + DecodeTileBF16MatrixAccumulation(qkv_dim, encoded_tile_data, out); + return true; + } case gcpp::KVEncoding::kInt8: case gcpp::KVEncoding::kInt8TwoTranspositions: { DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out); @@ -292,6 +388,10 @@ bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded, EncodeTileBF16(transposed, qkv_dim, decoded, out_encoded_tile_data); return true; } + case gcpp::KVEncoding::kBF16MatrixAccumulation: { + EncodeTileBF16MatrixAccumulation(qkv_dim, decoded, out_encoded_tile_data); + return true; + } case gcpp::KVEncoding::kInt8: case gcpp::KVEncoding::kInt8TwoTranspositions: { EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data); diff --git a/gemma/kv_transcoding.h b/gemma/kv_transcoding.h index 67610d47..7abe4065 100644 --- a/gemma/kv_transcoding.h +++ b/gemma/kv_transcoding.h @@ -19,8 +19,8 @@ std::optional GetTileSizeBytes(gcpp::KVEncoding encoding, // Layout: K is [tile_size, qkv_dim] contiguous, V is [tile_size, qkv_dim] // contiguous. struct DecodedTile { - std::vector> k; - std::vector> v; + hwy::AlignedVector k; + hwy::AlignedVector v; size_t qkv_dim = 0; size_t tile_size = 0; @@ -65,6 +65,31 @@ bool TranscodeTile(gcpp::KVEncoding src_encoding, gcpp::KVEncoding dst_encoding, hwy::Span dst_data, size_t qkv_dim); +inline size_t KMatrixAccumulationOffset_BF16(size_t qkv_dim, size_t dim, + size_t token) { + size_t g = token / 8; + size_t t_in_g = token % 8; + size_t ch_g = dim / 4; + size_t ch_in_g = dim % 4; + + return ch_g * 128 + g * 32 + t_in_g * 4 + ch_in_g; +} + +inline size_t VMatrixAccumulationOffset_BF16(size_t qkv_dim, size_t token, + size_t dim) { + size_t g_t = token / 16; + size_t g_c = (dim % 4) / 2; + size_t sub_block = (dim / 4) * 4 + g_t * 2 + g_c; + + size_t t_prime = token % 16; + size_t c_prime = dim % 2; + size_t g_t4 = t_prime / 4; + size_t t_double_prime = t_prime % 4; + size_t block_offset = g_t4 * 8 + c_prime * 4 + t_double_prime; + + return sub_block * 32 + block_offset; +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_ diff --git a/gemma/kv_transcoding_test.cc b/gemma/kv_transcoding_test.cc index 5f8bb556..72f114e8 100644 --- a/gemma/kv_transcoding_test.cc +++ b/gemma/kv_transcoding_test.cc @@ -9,6 +9,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "gemma/configs.h" +#include "util/basics.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // For hwy::Span @@ -90,6 +91,7 @@ INSTANTIATE_TEST_SUITE_P( EncodingTestCase{gcpp::KVEncoding::kF32TwoTranspositions, 1e-6f}, EncodingTestCase{gcpp::KVEncoding::kBF16, 0.05f}, EncodingTestCase{gcpp::KVEncoding::kBF16TwoTranspositions, 0.05f}, + EncodingTestCase{gcpp::KVEncoding::kBF16MatrixAccumulation, 0.05f}, EncodingTestCase{gcpp::KVEncoding::kInt8, 0.1f}, EncodingTestCase{gcpp::KVEncoding::kInt8TwoTranspositions, 0.1f})); @@ -356,5 +358,75 @@ TEST(KVEncodingTest, LayoutValidationInt8TwoTranspositions) { EXPECT_EQ(data[v_start + 7], 127); } +TEST(KVEncodingTest, LayoutValidationBF16MatrixAccumulation) { + constexpr size_t kTileSize = 32; + constexpr size_t qkv_dim = 4; + gcpp::KVEncoding encoding = gcpp::KVEncoding::kBF16MatrixAccumulation; + + DecodedTile original(qkv_dim, kTileSize); + for (size_t token = 0; token < kTileSize; ++token) { + for (size_t dim = 0; dim < qkv_dim; ++dim) { + original.k_elem(token, dim) = dim * kTileSize + token + 1; + original.v_elem(token, dim) = + token * qkv_dim + dim + 1 + qkv_dim * kTileSize; + } + } + + size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); + std::vector encoded(size); + + ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, + hwy::Span(encoded.data(), encoded.size()))); + + const gcpp::BF16* data = reinterpret_cast(encoded.data()); + + // K Layout (8x4 block, token-major) + // base_offset = ch_g * 128 + g * 32. + // For qkv_dim = 4, ch_g = 0 is the only channel group. + // For g = 0 (tokens 0-7), base_offset = 0. + // K[t, c] is at offset t * 4 + c. + // original.k_elem(t, c) = c * 32 + t + 1. + // For t=0, c=0: original.k_elem(0,0) = 1. Offset = 0. + // For t=0, c=1: original.k_elem(0,1) = 33. Offset = 1. + // For t=1, c=0: original.k_elem(1,0) = 2. Offset = 4. + // For t=7, c=3: original.k_elem(7,3) = 3*32 + 7 + 1 = 104. Offset = 7 * 4 + 3 + // = 31. + EXPECT_NEAR(hwy::ConvertScalarTo(data[0]), 1.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[1]), 33.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[4]), 2.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[31]), 104.0f, 0.05f); + + // For g = 1 (tokens 8-15), base_offset = 32. + // K[t_in_g, c] is at 32 + t_in_g * 4 + c. + // t=8 (t_in_g=0), c=0: original.k_elem(8,0) = 9. Offset = 32. + EXPECT_NEAR(hwy::ConvertScalarTo(data[32]), 9.0f, 0.05f); + + // V Layout (Contiguous SV-Blocked Layout) + // For qkv_dim = 4, ch_g = 0. + // V[t, c] is at v_start + sub_block * 32 + block_offset + // where: + // g_t = t / 16, g_c = c / 2 + // sub_block = g_t * 2 + g_c + // t' = t % 16, c' = c % 2 + // g_t4 = t' / 4, t'' = t' % 4 + // block_offset = g_t4 * 8 + c' * 4 + t'' + // original.v_elem(t, c) = t * 4 + c + 1 + 128. + // v_start = 4 * 32 = 128 elements of BF16. + size_t v_start = qkv_dim * kTileSize; + // For t=0, c=0: original.v_elem(0,0) = 129. Offset = v_start + 0. + // For t=1, c=0: original.v_elem(1,0) = 133. Offset = v_start + 1. + // For t=0, c=1: original.v_elem(0,1) = 130. Offset = v_start + 4. + // For t=1, c=1: original.v_elem(1,1) = 134. Offset = v_start + 5. + // For t=2, c=0: original.v_elem(2,0) = 137. Offset = v_start + 2. + // For t=7, c=3: original.v_elem(7,3) = 7*4 + 3 + 1 + 128 = 160. + // Offset = v_start + 47. + EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 0]), 129.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 1]), 133.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 4]), 130.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 5]), 134.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 2]), 137.0f, 0.05f); + EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 47]), 160.0f, 0.05f); +} + } // namespace } // namespace gcpp diff --git a/gemma/tiled_attention.cc b/gemma/tiled_attention.cc index cd942661..d3b39174 100644 --- a/gemma/tiled_attention.cc +++ b/gemma/tiled_attention.cc @@ -14,6 +14,7 @@ #include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/kv_cache.h" +#include "gemma/kv_transcoding.h" #include "ops/matmul.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -152,6 +153,9 @@ static HWY_INLINE void ComputeQKVTransposedTile( std::vector kv_ptrs = qbatch.KV(query_idx).cache->GetPointers( layer_idx, kv_head, kv_heads, start_pos, is_global_layer); + const size_t v_offset = qkv_dim * KVCache::kTileSize; + const size_t tile_span_size = 2 * qkv_dim * KVCache::kTileSize; + const size_t k_size = qkv_dim * KVCache::kTileSize; size_t tile_offset = 0; if (!is_global_layer) { tile_offset = start_pos / KVCache::kTileSize; @@ -173,13 +177,10 @@ static HWY_INLINE void ComputeQKVTransposedTile( tile_ptr = HWY_RCAST_ALIGNED( KV_T*, kv_ptrs[kv_ptr_idx].RowBytes(relative_tile_idx - absolute_rows)); - PackedSpan tile_packed_span{tile_ptr, - 2 * qkv_dim * KVCache::kTileSize}; + PackedSpan tile_packed_span{tile_ptr, tile_span_size}; - DecompressAndZeroPad(df, tile_packed_span, 0, k_tile_vec, - qkv_dim * KVCache::kTileSize); - DecompressAndZeroPad(df, tile_packed_span, - qkv_dim * KVCache::kTileSize, v_tile_vec, + DecompressAndZeroPad(df, tile_packed_span, 0, k_tile_vec, k_size); + DecompressAndZeroPad(df, tile_packed_span, v_offset, v_tile_vec, qkv_dim * KVCache::kTileSize); size_t token_in_tile_idx = current_token_idx; @@ -255,7 +256,21 @@ static HWY_INLINE void ComputeQKVTransposedTile( v_cache_values = v_buf; } - if (is_transposed_qs) { + const MatPtr& compact_kv_cache_ptr = + qbatch.KV(query_idx).cache->compact_kv_cache_ptr; + if (compact_kv_cache_ptr.GetType() == Type::kBF16 && + compact_kv_cache_ptr.GetLayout() == + MatPtr::Layout::kBF16MatrixAccumulation) { + for (size_t dim = 0; dim < qkv_dim; ++dim) { + size_t k_offset = gcpp::KMatrixAccumulationOffset_BF16( + qkv_dim, dim, in_tile_idx); + k_tile_vec[k_offset] = k_f32[dim]; + + size_t v_offset = gcpp::VMatrixAccumulationOffset_BF16( + qkv_dim, in_tile_idx, dim); + v_tile_vec[v_offset] = v_cache_values[dim]; + } + } else if (is_transposed_qs) { const int in_tile_idx_mod_2 = in_tile_idx % 2; for (int dim = 0; dim < qkv_dim; dim += 2) { const int dim_mod_2 = dim % 2; @@ -284,11 +299,11 @@ static HWY_INLINE void ComputeQKVTransposedTile( token_in_tile_idx++; } - Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls, - tile_packed_span, 0); - if (is_transposed_qs) { + Compress(k_tile_vec, k_size, tls, tile_packed_span, 0); + if (is_transposed_qs || + attention_impl == AttentionImpl::kFlashMatrixAccumulation) { Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls, - tile_packed_span, qkv_dim * KVCache::kTileSize); + tile_packed_span, v_offset); } current_token_idx = token_in_tile_idx; } @@ -399,6 +414,60 @@ static HWY_INLINE void MaybeResizeMatStorage(MatStorageT& mat_storage, } } +template +HWY_INLINE void CompressAndTransposeQueriesMatrixAccumulationImpl( + QueryProvider query_provider, BF16* packed_queries, size_t num_queries, + size_t qkv_dim) { + HWY_DASSERT(qkv_dim % 4 == 0); + + namespace hn = hwy::HWY_NAMESPACE; + const hn::FixedTag df; + const hn::FixedTag dbf16; + constexpr size_t kL = 4; + + size_t p = 0; + for (; p < num_queries / 2; ++p) { + const float* q0 = query_provider(2 * p); + const float* q1 = query_provider(2 * p + 1); + BF16* out = packed_queries + 2 * p * qkv_dim; + + for (size_t d = 0; d < qkv_dim; d += kL) { + auto v0 = hn::LoadU(df, q0 + d); + auto v1 = hn::LoadU(df, q1 + d); + auto A = hn::OrderedDemote2To(dbf16, v0, v1); + hn::StoreU(A, dbf16, out + d * 2); + } + } + + if (num_queries % 2 != 0) { + const float* q0 = query_provider(2 * p); + BF16* out = packed_queries + 2 * p * qkv_dim; + auto zero = hn::Zero(df); + + for (size_t d = 0; d < qkv_dim; d += kL) { + auto v0 = hn::LoadU(df, q0 + d); + auto A = hn::OrderedDemote2To(dbf16, v0, zero); + hn::StoreU(A, dbf16, out + d * 2); + } + } +} + +void CompressAndTransposeQueriesMatrixAccumulation(const float* raw_queries, + BF16* packed_queries, + size_t num_queries, + size_t qkv_dim) { + CompressAndTransposeQueriesMatrixAccumulationImpl( + [&](size_t idx) { return raw_queries + idx * qkv_dim; }, packed_queries, + num_queries, qkv_dim); +} + +void CompressAndTransposeQueriesMatrixAccumulationNonContiguous( + hwy::Span input, BF16* packed_queries, size_t qkv_dim) { + CompressAndTransposeQueriesMatrixAccumulationImpl( + [&](size_t idx) { return input[idx]; }, packed_queries, input.size(), + qkv_dim); +} + // clang-format off // Schedules TiledFlashAttention for all heads, tokens and batch. // Returns partial results in the same order as queries in `activations.q`. @@ -616,6 +685,7 @@ void LocalAttentionForAllHeadsTokensAndBatch( hwy::Span(last_pos_per_query), activations.config.att_cap, att_out, exp_denominator_sums.data(), max_logits.data()); + } else if (attention_impl == AttentionImpl::kFlashTransposedQsInt16) { HWY_DASSERT(activations.int16_queries != nullptr); HWY_DASSERT(activations.q_scales != nullptr); @@ -632,6 +702,18 @@ void LocalAttentionForAllHeadsTokensAndBatch( hwy::Span(last_pos_per_query), activations.config.att_cap, att_out, exp_denominator_sums.data(), max_logits.data()); + } else if (attention_impl == AttentionImpl::kFlashMatrixAccumulation) { + HWY_DASSERT(activations.bf16_queries != nullptr); + BF16* bf16_queries_ptr = activations.bf16_queries->data() + + task_idx * num_queries * qkv_dim; + CompressAndTransposeQueriesMatrixAccumulationNonContiguous( + queries_ptrs_span, bf16_queries_ptr, qkv_dim); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( + kv_ptrs, num_queries, bf16_queries_ptr, + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), + activations.config.att_cap, att_out, exp_denominator_sums.data(), + max_logits.data()); } else { HWY_DASSERT(activations.float_queries != nullptr); float* contiguous_queries_ptr = activations.float_queries->data() + @@ -725,10 +807,11 @@ void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, "query heads must be a multiple of key-value heads"); (void)layer_config; // only used in HWY_DASSERT - if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) { + const Type kv_type = qbatch.KV(0).cache->compact_kv_cache_ptr.GetType(); + if (kv_type == Type::kBF16) { ComputeQKVTransposedTile(num_tokens, layer_idx, layer, attention_impl, activations, qbatch, flags, env); - } else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kF32) { + } else if (kv_type == Type::kF32) { ComputeQKVTransposedTile(num_tokens, layer_idx, layer, attention_impl, activations, qbatch, flags, env); diff --git a/gemma/tiled_attention.h b/gemma/tiled_attention.h index bc06bf6c..e8524a72 100644 --- a/gemma/tiled_attention.h +++ b/gemma/tiled_attention.h @@ -16,33 +16,38 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - MatMulEnv& env, int flags); \ - void LocalAttentionForAllHeadsTokensAndBatch( \ - AttentionImpl attention_impl, const size_t num_tokens, \ - const size_t layer_idx, const LayerWeightsPtrs& layer, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - ThreadingContext& ctx); \ - \ - void CompressQueriesBF16(hwy::Span input, int qkv_dim, \ - BF16* HWY_RESTRICT output); \ - void CompressQueriesBF16Contiguous(const float* HWY_RESTRICT input, \ - int qkv_dim, size_t num_queries, \ - BF16* HWY_RESTRICT output); \ - \ - void CompressQueriesInt16(hwy::Span input, int qkv_dim, \ - int16_t* HWY_RESTRICT output, \ - float* HWY_RESTRICT scale); \ - \ - void CompressQueriesInt16Contiguous(const float* HWY_RESTRICT input, \ - int qkv_dim, size_t num_queries, \ - int16_t* HWY_RESTRICT output, \ - float* HWY_RESTRICT scale); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + MatMulEnv& env, int flags); \ + void LocalAttentionForAllHeadsTokensAndBatch( \ + AttentionImpl attention_impl, const size_t num_tokens, \ + const size_t layer_idx, const LayerWeightsPtrs& layer, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + \ + void CompressQueriesBF16(hwy::Span input, int qkv_dim, \ + BF16* HWY_RESTRICT output); \ + void CompressQueriesBF16Contiguous(const float* HWY_RESTRICT input, \ + int qkv_dim, size_t num_queries, \ + BF16* HWY_RESTRICT output); \ + \ + void CompressQueriesInt16(hwy::Span input, int qkv_dim, \ + int16_t* HWY_RESTRICT output, \ + float* HWY_RESTRICT scale); \ + \ + void CompressQueriesInt16Contiguous(const float* HWY_RESTRICT input, \ + int qkv_dim, size_t num_queries, \ + int16_t* HWY_RESTRICT output, \ + float* HWY_RESTRICT scale); \ + \ + void CompressAndTransposeQueriesMatrixAccumulation(const float* raw_queries, \ + BF16* packed_queries, \ + size_t num_queries, \ + size_t qkv_dim); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc index 5e09c2a3..bca0f192 100644 --- a/gemma/tiled_attention_test.cc +++ b/gemma/tiled_attention_test.cc @@ -97,13 +97,19 @@ struct AttentionTestEnv { bool transposed = attention_impl == AttentionImpl::kFlashTransposedQsBF16; gcpp::KVEncoding encoding; - const Type type = kv_caches.back().compact_kv_cache_ptr.GetType(); + const MatPtr& compact_kv = kv_caches.back().compact_kv_cache_ptr; + const Type type = compact_kv.GetType(); + const MatPtr::Layout layout = compact_kv.GetLayout(); if (type == Type::kInt8) { encoding = transposed ? gcpp::KVEncoding::kInt8TwoTranspositions : gcpp::KVEncoding::kInt8; } else if (type == Type::kBF16) { - encoding = transposed ? gcpp::KVEncoding::kBF16TwoTranspositions - : gcpp::KVEncoding::kBF16; + if (layout == MatPtr::Layout::kBF16MatrixAccumulation) { + encoding = gcpp::KVEncoding::kBF16MatrixAccumulation; + } else { + encoding = transposed ? gcpp::KVEncoding::kBF16TwoTranspositions + : gcpp::KVEncoding::kBF16; + } } else { encoding = transposed ? gcpp::KVEncoding::kF32TwoTranspositions : gcpp::KVEncoding::kF32; @@ -123,12 +129,15 @@ struct AttentionTestEnv { } else { FillMatPtrT(kv_caches.back().kv_cache); } + } + + for (size_t q = 0; q < qbatch_size; ++q) { all_queries.Append({ .prompt = PromptTokens({1, 2, 3}), .mutable_pos = static_cast(last_pos), .initial_pos = 0, .prefix_end = 0, - .kv_cache = kv_caches.back().ToPtr(), + .kv_cache = kv_caches[q].ToPtr(), }); } @@ -730,6 +739,43 @@ void TestAttentionMultipleTokensBF16() { } } +void TestAttentionMultipleTokensBF16MatrixAccumulation() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashMatrixAccumulation; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, kTiledFlags); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-1, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + void TestAttentionMultipleTokensInt8() { int qkv_dim = 64; int kv_seq_len = 64; @@ -783,6 +829,9 @@ HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestCompressQueries); // TestLocalAttentionForAllHeadsTokensAndBatch); HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens); HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16); +HWY_EXPORT_AND_TEST_P(TiledAttentionTest, + TestAttentionMultipleTokensBF16MatrixAccumulation); + // HWY_EXPORT_AND_TEST_P(TiledAttentionTest, // TestAttentionMultipleTokensAttentionWindowSizeEdgeCase); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index d5d1b49e..753570ee 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -23,6 +23,7 @@ #include #include +#include #include #include // std::enable_if_t #include @@ -1835,12 +1836,207 @@ HWY_API VF PerBlock2x2MatMulMaybeEmulate(DN dn, VBF a, VBF b, VF c) { const auto b2 = hn::BitCast(dbf, hn::Per4LaneBlockShuffle<3, 1, 3, 1>(b_f)); VF sum1 = hn::Zero(dn); - VF sum0 = hn::ReorderWidenMulAccumulate(dn, a1, b1, c, sum1); + VF sum0 = hn::ReorderWidenMulAccumulate(dn, a1, b1, hn::Zero(dn), sum1); sum0 = hn::ReorderWidenMulAccumulate(dn, a2, b2, sum0, sum1); - return hn::RearrangeToOddPlusEven(sum0, sum1); + return hn::Add(hn::RearrangeToOddPlusEven(sum0, sum1), c); #endif } +static constexpr float kMaskedLogitVal = + -std::numeric_limits::max() / 64.0f; + +template +HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap, VF& x0, + VF& x1, VF& x2, VF& x3, VF& x4, VF& x5, VF& x6, + VF& x7) { + if (att_cap > 0.0f) { + VF cap = hn::Set(df, att_cap); + VF one_over_cap_vec = hn::Set(df, one_over_cap); + x0 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x0, one_over_cap_vec))); + if constexpr (kVTileSize >= 2) { + x1 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x1, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 3) { + x2 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x2, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 4) { + x3 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x3, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 5) { + x4 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x4, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 6) { + x5 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x5, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 7) { + x6 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x6, one_over_cap_vec))); + } + if constexpr (kVTileSize >= 8) { + x7 = hn::Mul(cap, hn::CallFastTanh(df, hn::Mul(x7, one_over_cap_vec))); + } + } +} + +template , + typename DU = hn::ScalableTag, class VU = hn::Vec> +HWY_NOINLINE void ApplyMasking(DF df, DU du, size_t position, + const size_t* HWY_RESTRICT first_pos_per_query, + const size_t* HWY_RESTRICT last_pos_per_query, + VF& x0_p0, VF& x0_p1, VF& x1_p0, VF& x1_p1, + VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, + VF& x4_p0, VF& x4_p1, VF& x5_p0, VF& x5_p1, + VF& x6_p0, VF& x6_p1, VF& x7_p0, VF& x7_p1) { + VU lane_indices = hn::Iota(du, 0); + HWY_LANES_CONSTEXPR size_t kTileSize = hn::Lanes(df); + auto per_lane_pos_p0 = hn::Add(hn::Set(du, position), lane_indices); + auto per_lane_pos_p1 = + hn::Add(hn::Set(du, position + kTileSize), lane_indices); + + VF neg_inf = hn::Set(df, kMaskedLogitVal); + + auto apply_mask_for_query = [&](int query_idx, VF& x_p0, VF& x_p1) HWY_ATTR { + const size_t first_pos = first_pos_per_query[query_idx]; + const size_t last_pos = last_pos_per_query[query_idx]; + + auto valid_tokens_mask_p0 = hn::Ge(per_lane_pos_p0, hn::Set(du, first_pos)); + valid_tokens_mask_p0 = hn::And( + valid_tokens_mask_p0, hn::Le(per_lane_pos_p0, hn::Set(du, last_pos))); + x_p0 = + hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p0), x_p0, neg_inf); + + auto valid_tokens_mask_p1 = hn::Ge(per_lane_pos_p1, hn::Set(du, first_pos)); + valid_tokens_mask_p1 = hn::And( + valid_tokens_mask_p1, hn::Le(per_lane_pos_p1, hn::Set(du, last_pos))); + x_p1 = + hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p1), x_p1, neg_inf); + }; + + if constexpr (kNumQueries >= 1) { + apply_mask_for_query(0, x0_p0, x0_p1); + } + if constexpr (kNumQueries >= 2) { + apply_mask_for_query(1, x1_p0, x1_p1); + } + if constexpr (kNumQueries >= 3) { + apply_mask_for_query(2, x2_p0, x2_p1); + } + if constexpr (kNumQueries >= 4) { + apply_mask_for_query(3, x3_p0, x3_p1); + } + if constexpr (kNumQueries >= 5) { + apply_mask_for_query(4, x4_p0, x4_p1); + } + if constexpr (kNumQueries >= 6) { + apply_mask_for_query(5, x5_p0, x5_p1); + } + if constexpr (kNumQueries >= 7) { + apply_mask_for_query(6, x6_p0, x6_p1); + } + if constexpr (kNumQueries >= 8) { + apply_mask_for_query(7, x7_p0, x7_p1); + } +} + +template +HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0, VF& x0_p1, + VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, + VF& x3_p0, VF& x3_p1, VF& x4_p0, VF& x4_p1, + VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, + VF& x7_p0, VF& x7_p1) { + const size_t kTileSize = hn::Lanes(df); + const PackedSpan scales_span = + MakeConstSpan(scales, 2 * kTileSize); + VF scales_p0, scales_p1; + Decompress2(df, scales_span, 0, scales_p0, scales_p1); + if constexpr (kNumQueries >= 1) { + x0_p0 = hn::Mul(x0_p0, scales_p0); + x0_p1 = hn::Mul(x0_p1, scales_p1); + } + if constexpr (kNumQueries >= 2) { + x1_p0 = hn::Mul(x1_p0, scales_p0); + x1_p1 = hn::Mul(x1_p1, scales_p1); + } + if constexpr (kNumQueries >= 3) { + x2_p0 = hn::Mul(x2_p0, scales_p0); + x2_p1 = hn::Mul(x2_p1, scales_p1); + } + if constexpr (kNumQueries >= 4) { + x3_p0 = hn::Mul(x3_p0, scales_p0); + x3_p1 = hn::Mul(x3_p1, scales_p1); + } + if constexpr (kNumQueries >= 5) { + x4_p0 = hn::Mul(x4_p0, scales_p0); + x4_p1 = hn::Mul(x4_p1, scales_p1); + } + if constexpr (kNumQueries >= 6) { + x5_p0 = hn::Mul(x5_p0, scales_p0); + x5_p1 = hn::Mul(x5_p1, scales_p1); + } + if constexpr (kNumQueries >= 7) { + x6_p0 = hn::Mul(x6_p0, scales_p0); + x6_p1 = hn::Mul(x6_p1, scales_p1); + } + if constexpr (kNumQueries >= 8) { + x7_p0 = hn::Mul(x7_p0, scales_p0); + x7_p1 = hn::Mul(x7_p1, scales_p1); + } +} + +template +HWY_INLINE void ApplyQuantizationScale( + DF df, const float* HWY_RESTRICT q_scales, size_t query_idx, VF& x0_p0, + VF& x0_p1, VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, + VF& x4_p0, VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, + VF& x7_p1) { + auto apply_scale = [&](size_t i, VF& x_p0, VF& x_p1) HWY_ATTR { + size_t scale_idx = query_idx + i; + VF s = hn::Set(df, q_scales[scale_idx]); + x_p0 = hn::Mul(x_p0, s); + x_p1 = hn::Mul(x_p1, s); + }; + + if constexpr (kNumQueries >= 1) { + apply_scale(0, x0_p0, x0_p1); + } + if constexpr (kNumQueries >= 2) { + apply_scale(1, x1_p0, x1_p1); + } + if constexpr (kNumQueries >= 3) { + apply_scale(2, x2_p0, x2_p1); + } + if constexpr (kNumQueries >= 4) { + apply_scale(3, x3_p0, x3_p1); + } + if constexpr (kNumQueries >= 5) { + apply_scale(4, x4_p0, x4_p1); + } + if constexpr (kNumQueries >= 6) { + apply_scale(5, x5_p0, x5_p1); + } + if constexpr (kNumQueries >= 7) { + apply_scale(6, x6_p0, x6_p1); + } + if constexpr (kNumQueries >= 8) { + apply_scale(7, x7_p0, x7_p1); + } +} + +template +HWY_INLINE V SumReduceSegments(D d, V v) { + constexpr size_t L = HWY_MAX_LANES_D(D); + if constexpr (L <= kTargetLanes) { + return v; + } else { + using D_half = hn::Half; + const D_half d_half; + auto lo = hn::LowerHalf(d_half, v); + auto hi = hn::UpperHalf(d_half, v); + auto sum = hn::Add(lo, hi); + auto reduced_half = SumReduceSegments(d_half, sum); + return hn::ZeroExtendVector(d, reduced_half); + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/util/mat.h b/util/mat.h index 256dfad0..b9bfb72c 100644 --- a/util/mat.h +++ b/util/mat.h @@ -67,6 +67,14 @@ using RowPtrsBF = RowPtrs; // Copyable, (de)serializable via `fields.h` for `model_store.h`. class MatPtr : public IFields { public: + enum class Layout { + kFlat, + kBF16MatrixAccumulation, + }; + + Layout GetLayout() const { return layout_; } + void SetLayout(Layout layout) { layout_ = layout; } + MatPtr() = default; // `name`: see `SetName`. Note that `stride` is initially `cols` and only // differs after deserializing, or calling `SetPtr`. @@ -297,6 +305,9 @@ class MatPtr : public IFields { uint32_t stride_; float scale_ = 1.0f; // multiplier for each value, for MatMul. + + private: + Layout layout_ = Layout::kFlat; }; // Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures