diff --git a/CI/float16-small-causal-latest.png b/CI/float16-small-causal-latest.png index 7c7f0d4..af05338 100644 Binary files a/CI/float16-small-causal-latest.png and b/CI/float16-small-causal-latest.png differ diff --git a/MetalFlashAttention.xcodeproj/project.pbxproj b/MetalFlashAttention.xcodeproj/project.pbxproj index 1435178..efd7d84 100644 --- a/MetalFlashAttention.xcodeproj/project.pbxproj +++ b/MetalFlashAttention.xcodeproj/project.pbxproj @@ -77,7 +77,6 @@ 984279D62A619008001BBD55 /* AttentionTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AttentionTest.swift; sourceTree = ""; }; 984F721A2A6EEB0E00C15D4A /* float16-small-sequences-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-small-sequences-latest.png"; sourceTree = ""; }; 984F721B2A6EEC4B00C15D4A /* float16-large-sequences-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-large-sequences-latest.png"; sourceTree = ""; }; - 984F721C2A6EEE0F00C15D4A /* float16-small-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-small-causal-latest.png"; sourceTree = ""; }; 984F721D2A6EF09000C15D4A /* float16-head-sizes-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-head-sizes-latest.png"; sourceTree = ""; }; 987E35DB2A45E4F400ACACE3 /* MetalFlashAttention */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = MetalFlashAttention; sourceTree = BUILT_PRODUCTS_DIR; }; 987E35DE2A45E4F400ACACE3 /* main.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = main.swift; sourceTree = ""; }; @@ -104,6 +103,7 @@ 98C795312A4DC1F200DB688D /* GEMMSquareBenchmark.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GEMMSquareBenchmark.swift; sourceTree = ""; }; 98DFBD0C2A72F0EC002E4B47 /* float16-large-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-large-causal-latest.png"; sourceTree = ""; }; 98DFBD0D2A72F242002E4B47 /* float32-large-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float32-large-causal-latest.png"; sourceTree = ""; }; + 98DFBD0E2A72F92C002E4B47 /* float16-small-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-small-causal-latest.png"; sourceTree = ""; }; 98F2F5DC2A60978C006216F4 /* GEMMTransposeTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GEMMTransposeTest.swift; sourceTree = ""; }; 98F7440E2A4A008C00B5E60A /* build.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = build.swift; sourceTree = ""; }; 98F7440F2A4A0CB200B5E60A /* API.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = API.md; sourceTree = ""; }; @@ -134,7 +134,7 @@ 984F721A2A6EEB0E00C15D4A /* float16-small-sequences-latest.png */, 984F721B2A6EEC4B00C15D4A /* float16-large-sequences-latest.png */, 98DFBD0D2A72F242002E4B47 /* float32-large-causal-latest.png */, - 984F721C2A6EEE0F00C15D4A /* float16-small-causal-latest.png */, + 98DFBD0E2A72F92C002E4B47 /* float16-small-causal-latest.png */, 98DFBD0C2A72F0EC002E4B47 /* float16-large-causal-latest.png */, 984F721D2A6EF09000C15D4A /* float16-head-sizes-latest.png */, 98FDDF002A5895CE0096BC27 /* float16-nt-batched-latest.png */, diff --git a/README.md b/README.md index 2b22e8f..a79a64e 100644 --- a/README.md +++ b/README.md @@ -189,16 +189,16 @@ Dense: Stable Diffusion 2 outermost attention layer @ 512x512 (sequence length = ![FlashAttention (F16, H=5, D=64)](./CI/float16-large-sequences-latest.png) -### Float32 Sequence Scaling (Causal Mask) - -![FlashAttention (F32, H=10, D=64)](./CI/float32-large-causal-latest.png) - ### Float16 Sequence Scaling (Causal Mask) ![FlashAttention (F16, H=10, D=64)](./CI/float16-small-causal-latest.png) ![FlashAttention (F16, H=10, D=64)](./CI/float16-large-causal-latest.png) +### Float32 Sequence Scaling (Causal Mask) + +![FlashAttention (F32, H=10, D=64)](./CI/float32-large-causal-latest.png) + ### Float16 Head Scaling Dense: Stable Diffusion 1 outermost attention layer @ 512x512 (head size = 40) diff --git a/Tests/Test Cases/AttentionPerfTests.swift b/Tests/Test Cases/AttentionPerfTests.swift index 5e6ef9d..2e777ea 100644 --- a/Tests/Test Cases/AttentionPerfTests.swift +++ b/Tests/Test Cases/AttentionPerfTests.swift @@ -38,23 +38,24 @@ class AttentionPerfTests: MFATestCase { // For heads scaling: // sequence length 4096 - let duration = Duration(granularity: -1, length: 2) + let duration = Duration(granularity: 2, length: 2) let (domain, ranges) = rangeSequenceScaling( duration: duration, type: .causal) var backends = SequenceType.causal.backends // let backends: [AttentionBackend] = [.mfa] - backends = backends.compactMap { - if $0.isMPS { return nil } - return $0 - } +// backends = backends.compactMap { +// if $0.isMPS { return nil } +// return $0 +// } // let duration = Duration(granularity: 1, length: 1) // let (domain, ranges) = rangeHeadScaling(duration: duration) // let backends = [AttentionBackend.mps, AttentionBackend.mfa] testAttention( - domain: domain, ranges: ranges, backends: backends, config: .triangular) + domain: domain, ranges: ranges, backends: backends, + config: .none) } enum GraphConfig { @@ -277,33 +278,33 @@ class AttentionPerfTests: MFATestCase { var domain: ClosedRange var parameters: [SIMD8] if type == .causal { - domain = 0...16384 +// domain = 0...16384 // domain = 512...1024 -// domain = 0...1024 + domain = 0...1024 parameters = [ - SIMD8( 1, 8, 256, 1, 0, 0, 0, 0), - SIMD8( 8, 192, 256, 8, 0, 0, 0, 0), - SIMD8( 192, 256, 128, 8, 0, 0, 0, 0), - SIMD8( 256, 384, 64, 8, 0, 0, 0, 0), - SIMD8( 384, 512, 32, 8, 0, 0, 0, 0), - SIMD8( 512, 768, 16, 16, 0, 0, 0, 0), - SIMD8( 768, 1024, 8, 16, 0, 0, 0, 0), - SIMD8(1024, 1536, 4, 32, 0, 0, 0, 0), - SIMD8(1536, 2048, 2, 32, 8, 0, 0, 0), - SIMD8(2048, 3072, 2, 64, 8, 0, 0, 0), - SIMD8(3072, 4096, 2, 128, 8, 0, 0, 0), - SIMD8(4096, 6144, 2, 256, 8, 0, 0, 0), - SIMD8( 6 * 1024, 8 * 1024, 2, 512, 7, 0, 0, 0), - SIMD8( 8 * 1024, 12 * 1024, 2, 1024, 6, 0, 0, 0), - SIMD8(12 * 1024, 16 * 1024 + 1, 2, 2048, 5, 0, 0, 0), +// SIMD8( 1, 8, 256, 1, 0, 0, 0, 0), +// SIMD8( 8, 192, 256, 8, 0, 0, 0, 0), +// SIMD8( 192, 256, 128, 8, 0, 0, 0, 0), +// SIMD8( 256, 384, 64, 8, 0, 0, 0, 0), +// SIMD8( 384, 512, 32, 8, 0, 0, 0, 0), +// SIMD8( 512, 768, 16, 16, 0, 0, 0, 0), +// SIMD8( 768, 1024, 8, 16, 0, 0, 0, 0), +// SIMD8(1024, 1536, 4, 32, 0, 0, 0, 0), +// SIMD8(1536, 2048, 2, 32, 8, 0, 0, 0), +// SIMD8(2048, 3072, 2, 64, 8, 0, 0, 0), +// SIMD8(3072, 4096, 2, 128, 8, 0, 0, 0), +// SIMD8(4096, 6144, 2, 256, 8, 0, 0, 0), +// SIMD8( 6 * 1024, 8 * 1024, 2, 512, 7, 0, 0, 0), +// SIMD8( 8 * 1024, 12 * 1024, 2, 1024, 6, 0, 0, 0), +// SIMD8(12 * 1024, 16 * 1024 + 1, 2, 2048, 5, 0, 0, 0), -// SIMD4(granularity, 192, 256, granularity, 0, 0, 0, 0), -// SIMD4( 192, 256, 128, granularity, 0, 0, 0, 0), -// SIMD4( 256, 384, 64, granularity, 0, 0, 0, 0), -// SIMD4( 384, 512, 32, granularity, 0, 0, 0, 0), + SIMD8(granularity, 192, 256, granularity, 0, 0, 0, 0), + SIMD8( 192, 256, 128, granularity, 0, 0, 0, 0), + SIMD8( 256, 384, 64, granularity, 0, 0, 0, 0), + SIMD8( 384, 512, 32, granularity, 0, 0, 0, 0), -// SIMD4( 512, 768, 16, granularity, 0, 0, 0, 0), -// SIMD4( 768, 1025, 8, granularity, 0, 0, 0, 0), + SIMD8( 512, 768, 16, granularity, 0, 0, 0, 0), + SIMD8( 768, 1025, 8, granularity, 0, 0, 0, 0), ] } else if type == .small { domain = 0...2048 diff --git a/Tests/Test Cases/MFATestCase.swift b/Tests/Test Cases/MFATestCase.swift index 4d98a24..df460a6 100644 --- a/Tests/Test Cases/MFATestCase.swift +++ b/Tests/Test Cases/MFATestCase.swift @@ -10,7 +10,7 @@ import Foundation class MFATestCase { // Global setting for the precision used in tests. #if arch(arm64) - typealias Real = Float32 + typealias Real = Float16 #else typealias Real = Float #endif