
reflexive.space

Speculating on the Apple M2: Local Prediction

Table of Contents

  1. Setting up the Stack
  2. About Local Prediction
  3. Test 1: A Simple Loop
  4. Test 2: Probing the Counter
  5. Test 3: Probing the Counter (Redux)
  6. Test 4: "Defeating" the Counter
  7. Breakpoint

This article explores the local branch prediction on Apple's "Avalanche" microarchitecture.

In the previous article, I wrote about running Rust binaries on top of the "proxy kernel" mode exposed by m1n1 (the Asahi Linux bootloader). I originally wanted to do a bunch of microbenchmarking with tiny Rust programs, but I decided that I didn't want to write a crate for assembling ARMv8 at runtime on the target machine.

In short: m1n1 has very useful Python bindings that let you interact with the proxy kernel, and we're going to use them for running microbenchmarks on bare metal.

With any luck, we'll be able to learn something about how branch prediction is implemented on the "Avalanche" P-cores in the Apple M2. At the same time, you should probably take any conclusions made in this article with a grain of salt. There's no official documentation about any of this.

Setting up the Stack

The fragments of Python scattered around this article represent different experiments. Since I might not explain all of the content in detail, you should at least know that they all work in the following way:

  • We're running these experiments on bare metal (via m1n1), and we're confident that our test code is the only thing running on the target P-core
  • We have a template with ARMv8 assembly that represents the body of the code we want to measure with performance-monitoring counters (PMCs)
  • We do some experiment multiple times in a loop like this:
    • We fill in any variable parts of the template
    • We compile our code into a flat binary and write it to memory
    • We determine what inputs will be fed into the test
    • We run the code with some input
    • We read back results from the PMCs

I'm not going to spend time talking about the implementation details behind my test setup unless I really need to, so hopefully this is enough to help you make sense of things. The Python should be pretty easy to digest anyway.

My experiments are also on GitHub, see eigenform/m2e.


About Local Prediction

Many branches can be predicted very accurately by keeping track of their previous outcomes. In other words, the expected outcome of a particular branch can be correlated with the local history of its previous outcomes.

In order to get this behavior, an implementation doesn't need to save a huge string of bits corresponding to previous outcomes for a particular branch. This would incur a massive storage cost in hardware, since you ideally want the ability to track predictions for many [thousands of] branches.

Instead, a local predictor typically relies on the fact that some long pattern of outcomes can be represented by the output of a simple state machine that uses only a few bits of storage. This is usually accomplished with a table of saturating counters which are indexed by some part of the program counter. Each counter has the following parts:

  • A bit for the predicted direction ('taken' or 'not-taken')
  • One or more bits to represent hysteresis/confidence in the prediction

For example, consider the possible states of this 2-bit counter:

00 - Strongly not-taken 
01 - Weakly not-taken
10 - Weakly taken
11 - Strongly taken

When a prediction is determined to be correct, our confidence in the predicted direction is strengthened [up to some limit]. Otherwise, when mispredictions occur, our confidence is weakened [until the predicted direction changes].
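To make this concrete, here's a minimal Python sketch of a textbook local predictor built from a table of 2-bit counters. This is the generic scheme described above, not a claim about Apple's actual implementation, and the table size and index function are arbitrary choices for illustration:

```python
class TwoBitCounter:
    """A 2-bit saturating counter: the high bit is the predicted
    direction, and the low bit acts as a hysteresis bit.
    0b00/0b01 predict 'not-taken'; 0b10/0b11 predict 'taken'."""
    def __init__(self):
        self.state = 0b00  # strongly not-taken
    def predict(self):
        return 1 if self.state >= 0b10 else 0
    def update(self, taken):
        # Move toward 0b11 on 'taken' outcomes and toward 0b00 on
        # 'not-taken' outcomes, saturating at both ends
        if taken:
            self.state = min(self.state + 1, 0b11)
        else:
            self.state = max(self.state - 1, 0b00)

class LocalPredictor:
    """A table of counters indexed by some low bits of the
    program counter (the index function here is made up)."""
    def __init__(self, size=1024):
        self.table = [TwoBitCounter() for _ in range(size)]
    def index(self, pc):
        return (pc >> 2) % len(self.table)
    def predict(self, pc):
        return self.table[self.index(pc)].predict()
    def update(self, pc, taken):
        self.table[self.index(pc)].update(taken)

# Starting from 'strongly not-taken', a single 'taken' outcome only
# weakens our confidence; a second one flips the predicted direction
p = LocalPredictor()
pc = 0x09_0000_0000
p.update(pc, 1)
assert p.predict(pc) == 0  # weakly not-taken
p.update(pc, 1)
assert p.predict(pc) == 1  # weakly taken
```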

Test 1: A Simple Loop

With that in mind, let's look at some very simple branch behavior and try to make sense of what we see. This branch is mostly 'not-taken', but we're changing the branch outcome to 'taken' on every 16th test run. Here's some test code:

PHT_COUNTER_EXP = MyExperiment(
    AssemblerTemplate("""
_test_brn:
    cbnz x0, #4
"""))

def pht_counter_exp():
    # Compile this test (for recording mispredicted conditional branches)
    # and write it to memory. Our code above is placed at 0x09_0000_0000. 
    code = PHT_COUNTER_EXP.compile(0xc5, addr=0x09_0000_0000)
    tgt.write_payload(code.start, code.data)

    # The output value is 1 on every 16th test run
    # 00000000000000010000000000000001 ...
    sctr = SingleCounter(16)

    # Run the test 512 times
    results = []
    x0_vals = []
    for _ in range(0, 512):
        x0 = sctr.output()
        res = tgt.smp_call_sync(code.start, x0)
        results.append(res)
        x0_vals.append(x0)
        sctr.next()

    print("[*] mispredictions")
    print_samples_history(results)
    print("[*] x0 inputs")
    print_samples_history(x0_vals)

I should also mention that in this case, each test begins with a "newly-discovered" branch instruction. This cbnz instruction lives at 0x09_0000_0000, and I've guaranteed that no branch instruction has been observed at this address in the past.

In this case (and the following ones), we're using the PMCs to measure the number of mispredicted conditional branch instructions. Since we're only measuring one branch here, each test run must record either 1 or 0 mispredictions.

Here's the output of all 512 test runs, along with the corresponding input to x0 for each run:

[*] mispredictions
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
[*] x0 inputs
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001

So, we see a single misprediction each time we change the branch outcome to 'taken'. If we switch to testing cbz instead of cbnz, we can see that the first test run is also mispredicted:

[*] mispredictions
1000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
[*] x0 inputs
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001
0000000000000001000000000000000100000000000000010000000000000001

This is good evidence that the default prediction for newly-discovered branches is 'not-taken.' This also tells us that, when we're in the initial 'not-taken' state, we can tolerate a single misprediction without changing the predicted direction (since we continue to correctly predict 'not-taken' until the 16th iteration).

This must mean there's at least a single hysteresis bit. Otherwise, the predicted direction would always correspond to the previous outcome. That would cause us to mispredict on the first test run after changing the outcome - but that's not what we've observed here.

Our 16 consecutive 'not-taken' outcomes are enough to outweigh the effect of a single 'taken' outcome, and so the behavior of this branch has essentially been captured. In other words, there's no way for the predictor to escape the 'not-taken' state for this pattern of outcomes.
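We can make the "no hysteresis" argument concrete with a tiny model: a predictor that simply predicts the previous outcome would mispredict twice per period on this pattern, once on the 'taken' run and again on the following 'not-taken' run, which is not what we observed:

```python
# A zero-hysteresis 'predictor' that always predicts the last outcome,
# running against the pattern where every 16th run is 'taken'
last_outcome = 0
mispredicts = []
for i in range(64):
    outcome = 1 if (i % 16) == 15 else 0
    mispredicts.append(1 if last_outcome != outcome else 0)
    last_outcome = outcome

# The 'taken' run at i=15 is mispredicted, and so is the
# 'not-taken' run right after it at i=16
assert mispredicts[15] == 1 and mispredicts[16] == 1
assert mispredicts[31] == 1 and mispredicts[32] == 1
```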

Test 2: Probing the Counter

In order for us to distinguish between other states of the predictor, we need to change the pattern of outcomes taken by our test branch. This time, instead of changing the outcome once per 16 iterations, what if we change it twice in a row?

In this test code, the PatternCounter object just takes an arbitrary list of inputs [corresponding to our branch outcomes] and repeats them until the tests are done.

def test_pht_counters_1():
    code = SINGLE_CBNZ.compile(0xc5, addr=0x09_0000_0000, tgt_align=2)
    tgt.write_payload(code.start, code.data)

    # 00000000000000110000000000000011 ...
    ctr = PatternCounter([
        0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1
    ])

    results = []
    x0_vals = []
    for _ in range(0, 512):
        x0 = ctr.output()
        res = tgt.smp_call_sync(code.start, x0)
        results.append(res)
        x0_vals.append(x0)
        ctr.next()

    print("[*] mispredictions")
    print_samples_history(results)
    print("[*] x0 inputs")
    print_samples_history(x0_vals)
    print()

When the outcome changes twice in a row, what should we expect? There are two possibilities:

  • If we see a misprediction on the next 'not-taken' outcome following our two 'taken' outcomes, it must mean there's a single hysteresis bit. After two mispredictions, the predicted outcome will flip.

  • If we see no misprediction following the two 'taken' outcomes, it must mean there's more than one hysteresis bit tracking our confidence in the 'not-taken' state.
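Before looking at the results, it's worth spelling out what the first possibility predicts. Feeding this pattern to a model of the 2-bit counter described earlier (using the same state encoding), the 'not-taken' outcome after the two 'taken' outcomes is always mispredicted:

```python
# Model the first possibility: a 2-bit counter (one hysteresis bit)
# running against the ..00000011 pattern from this test
state = 0b00  # strongly not-taken
mispredicts = []
for i in range(48):
    outcome = 1 if (i % 16) >= 14 else 0
    predicted = 1 if state >= 0b10 else 0
    mispredicts.append(1 if predicted != outcome else 0)
    state = min(state + 1, 0b11) if outcome else max(state - 1, 0b00)

# The two 'taken' outcomes flip the counter to 'weakly taken', so the
# 'not-taken' outcome at the start of the next period also mispredicts
assert mispredicts[14:17] == [1, 1, 1]
assert mispredicts[30:33] == [1, 1, 1]
```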

I've manually annotated these results and split them into chunks of 16 tests to make this a little easier to read:

[*] mispredictions
       notice the mispredicted not-taken branches here
                 |                |
                 v                v
0000000000000011 1000000000000011 1000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011

[*] x0 inputs
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011
0000000000000011 0000000000000011 0000000000000011 0000000000000011

Interestingly, it seems like we see both possibilities over the course of our test runs!

Notice how, for the first two times we encounter our 'taken' outcomes, we see a misprediction for the subsequent 'not-taken' outcome. But for all the subsequent iterations, this seems to change! After the first two times, we never observe a misprediction in the 'not-taken' state.

I interpreted this to mean that, for "newly-discovered" branches, the predictor only uses a single hysteresis bit for the 'not-taken' state. After a branch has been "discovered" in this way, the strategy changes to some other predictor with more hysteresis bits!

Test 3: Probing the Counter (Redux)

Again, we need to change our pattern in order to distinguish between different states in the other predictor.

We'll do the same test as before, but instead of changing the outcome a single time on each 16th test run, we'll flip the outcome and repeat it 16 times.

def pht_counter_exp():
    ...

    # This output value flips between 0 and 1 every 16 iterations.
    # 00000000000000001111111111111111 ...
    fctr = FlipCounter(16)

    # Run the test 512 times
    results = []
    x0_vals = []
    for _ in range(0, 512):
        x0 = fctr.output()
        res = tgt.smp_call_sync(code.start, x0)
        results.append(res)
        x0_vals.append(x0)
        fctr.next()
    ...

Over a total of 512 test runs, this flips the outcome of our branch every 16 test runs. This is what the output looks like:

[*] mispredictions
0000000000000000110000000000000011000000000000001100000000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
[*] x0 inputs
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111

First, notice that we start by correctly predicting 'not-taken' 16 times. Since this branch has not been discovered before, we can probably conclude again that the default prediction is 'not-taken'.

Notice how, for the first three times when the outcome flips, we only see two mispredictions before the correct direction is learned. This is consistent with what we observed before.

Testing with cbz yields the same results, but the misprediction pattern switches earlier:

[*] mispredictions
1000000000000000110000000000000011000000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
1111100000000000111000000000000011111000000000001110000000000000
[*] x0 inputs
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111

I think this indicates that transitions from the 'not-taken' to the 'taken' state are what promote a branch to the other predictor, and it looks like the promotion occurs after this happens twice.

If I had to guess, this initial predictor is intentionally simple because we're much more likely to run into branches that are only mispredicted once or twice (i.e. in simple loops). Otherwise, our branch probably has more complicated behavior, and it's worth tracking it with a more complicated predictor.

Anyhow, if we look at the mispredictions for the remaining test runs, it's clear that this behavior is different, and it shows us that the underlying saturating counters implementing the other predictor have a different structure:

  • There are five (5) confidence levels for the 'not-taken' state

    • We're fully confident in a 'not-taken' prediction after observing that a branch is not-taken five times in a row
  • There are three (3) confidence levels for the 'taken' state

    • We're fully confident in a 'taken' prediction after observing that a branch is taken three times in a row
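Both observations are consistent with a single saturating counter that has eight states and an asymmetric threshold. This is my own speculative reconstruction, not a confirmed implementation, but a quick model reproduces the steady-state misprediction pattern from the output above:

```python
# Speculative model: 8 states (0..7), where states 0..4 predict
# 'not-taken' and states 5..7 predict 'taken'. This gives five
# confidence levels for 'not-taken' and three for 'taken'.
state = 0
mispredicts = []
for i in range(64):
    outcome = (i // 16) % 2  # flip the outcome every 16 runs
    predicted = 1 if state >= 5 else 0
    mispredicts.append(1 if predicted != outcome else 0)
    state = min(state + 1, 7) if outcome else max(state - 1, 0)

# Five mispredictions when switching to 'taken', and three when
# switching back to 'not-taken', matching the steady-state output
assert mispredicts[16:32] == [1] * 5 + [0] * 11
assert mispredicts[32:48] == [1] * 3 + [0] * 13
```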

Test 4: "Defeating" the Counter

We can use facts like this to do goofy things, like devise a test that effectively "defeats" this particular prediction scheme! We'll do this by just tacking a different pattern onto the previous test:

...
def pht_counter_exp():
    code = PHT_COUNTER_EXP.compile(0xc5, addr=0x09_0000_c000)
    tgt.write_payload(code.start, code.data)

    # Like before, this set of tests ends by taking the branch 16 times. 
    # Afterwards, it should take three mispredictions to change the direction. 

    ctr = FlipCounter(16)
    results = []
    x0_vals = []
    for _ in range(0, 512):
        x0 = ctr.output()
        res = tgt.smp_call_sync(code.start, x0)
        results.append(res)
        x0_vals.append(x0)
        ctr.next()

    print("[*] mispredictions")
    print_samples_history(results)
    print("[*] x0 inputs")
    print_samples_history(x0_vals)
    print()

    # Set up the pattern of inputs for these tests.
    # Let the outcome be 'not-taken' three times, and then alternate 
    # between 'taken' and 'not-taken' for the rest of the test runs.

    ctr = PatternCounter([0,0,0] + [ 1,0 ] * 256)
    results = []
    x0_vals = []
    for _ in range(0, 512):
        x0 = ctr.output()
        res = tgt.smp_call_sync(code.start, x0)
        results.append(res)
        x0_vals.append(x0)
        ctr.next()

    print("[*] mispredictions")
    print_samples_history(results)
    print("[*] x0 inputs")
    print_samples_history(x0_vals)

The first set of tests leaves this predictor fully-confident in the 'taken' state. Since we know that the predictor has 3 levels in the 'taken' state, it's easy to undo this by forcing the 'not-taken' outcome 3 times. This should leave the predictor in the weakest 'not-taken' state.

After setting that up, all we need to do is alternate between 'taken' and 'not-taken' outcomes. This causes the predictor to remain in the weakest state for both directions, and we should observe that the branch is consistently mispredicted. Checking the output, we get:

[*] mispredictions
0000000000000000110000000000000011000000000000001100000000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
1110000000000000111110000000000011100000000000001111100000000000
[*] x0 inputs
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111
0000000000000000111111111111111100000000000000001111111111111111

[*] mispredictions
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
1111111111111111111111111111111111111111111111111111111111111111
[*] x0 inputs
0001010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
0101010101010101010101010101010101010101010101010101010101010101
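This is exactly what a counter with five 'not-taken' levels and three 'taken' levels would do. A quick model (again, my speculative reconstruction of the counter) mispredicts every run of the second pattern:

```python
# Speculative model: states 0..4 predict 'not-taken', 5..7 'taken'
state = 0

# Prime the counter: 16 'taken' outcomes saturate it at 'strongly taken'
for _ in range(16):
    state = min(state + 1, 7)

# Run the 'defeat' pattern: three 'not-taken' outcomes, then
# alternate between 'taken' and 'not-taken'
inputs = [0, 0, 0] + [1, 0] * 30
mispredicts = []
for outcome in inputs:
    predicted = 1 if state >= 5 else 0
    mispredicts.append(1 if predicted != outcome else 0)
    state = min(state + 1, 7) if outcome else max(state - 1, 0)

# The counter ping-pongs between the weakest 'not-taken' and weakest
# 'taken' states, and every single prediction is wrong
assert all(mispredicts)
```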

Breakpoint

Isn't that cool? We've learned something kinda non-trivial about how the local branch predictor is [probably!] implemented.

At some point I'd like to continue this series and try experimenting with other predictors in the machine, but I don't know when I'll get around to it.

Thanks for reading!