Fix vectorized CountFilter (#2943)

This commit is contained in:
Hannes Greule 2024-10-25 19:59:53 +02:00 committed by GitHub
parent 65de3642d6
commit 7d6643b452
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 43 additions and 13 deletions

View File

@ -3,6 +3,7 @@ package com.fastasyncworldedit.core.extent.filter;
import com.fastasyncworldedit.core.extent.filter.block.FilterBlock;
import com.fastasyncworldedit.core.internal.simd.VectorizedFilter;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
public class CountFilter extends ForkedFilter<CountFilter> implements VectorizedFilter {
@ -36,8 +37,8 @@ public class CountFilter extends ForkedFilter<CountFilter> implements Vectorized
}
@Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) {
total += set.length();
public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
total += mask.trueCount();
return set;
}

View File

@ -6,6 +6,7 @@ import com.fastasyncworldedit.core.queue.Filter;
import com.fastasyncworldedit.core.queue.IChunk;
import com.sk89q.worldedit.regions.Region;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import org.jetbrains.annotations.Nullable;
/**
@ -77,9 +78,9 @@ public sealed class LinkedFilter<L extends Filter, R extends Filter> implements
}
@Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) {
ShortVector res = getLeft().applyVector(get, set);
return getRight().applyVector(get, res);
public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
ShortVector res = getLeft().applyVector(get, set, mask);
return getRight().applyVector(get, res, mask);
}
@Override

View File

@ -82,11 +82,10 @@ public class MaskFilter<T extends Filter> extends DelegateFilter<T> {
}
@Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) {
public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
final T parent = getParent();
VectorMask<Short> masked = vectorizedMask.compareVector(set, get);
ShortVector res = parent.applyVector(get, set);
res = set.blend(res, masked);
ShortVector res = parent.applyVector(get, set, mask.and(masked));
VectorMask<Short> changed = res.compare(VectorOperators.NE, set);
changes.getAndAdd(changed.trueCount());
return res;

View File

@ -14,6 +14,7 @@ import com.sk89q.worldedit.world.block.BaseBlock;
import com.sk89q.worldedit.world.block.BlockStateHolder;
import com.sk89q.worldedit.world.block.BlockTypesCache;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import javax.annotation.Nullable;
@ -101,8 +102,9 @@ public class SimdSupport {
}
@Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) {
return ShortVector.broadcast(ShortVector.SPECIES_PREFERRED, ordinal);
public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
// only change the lanes the mask dictates us to change, keep the rest
return set.blend(ShortVector.broadcast(ShortVector.SPECIES_PREFERRED, ordinal), mask);
}
@Override

View File

@ -4,6 +4,7 @@ import com.fastasyncworldedit.core.extent.filter.block.CharFilterBlock;
import com.fastasyncworldedit.core.queue.Filter;
import com.sk89q.worldedit.extent.Extent;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies;
public class VectorizedCharFilterBlock extends CharFilterBlock {
@ -18,15 +19,17 @@ public class VectorizedCharFilterBlock extends CharFilterBlock {
throw new IllegalStateException("Unexpected VectorizedCharFilterBlock " + filter);
}
final VectorSpecies<Short> species = ShortVector.SPECIES_PREFERRED;
// TODO can we avoid eager initSet?
initSet(); // set array is null before
char[] setArr = this.setArr;
assert setArr != null;
char[] getArr = this.getArr;
// assume setArr.length == getArr.length == 4096
VectorMask<Short> affectAll = species.maskAll(true);
for (int i = 0; i < 4096; i += species.length()) {
ShortVector set = ShortVector.fromCharArray(species, setArr, i);
ShortVector get = ShortVector.fromCharArray(species, getArr, i);
ShortVector res = vecFilter.applyVector(get, set);
ShortVector res = vecFilter.applyVector(get, set, affectAll);
res.intoCharArray(setArr, i);
}
}

View File

@ -2,7 +2,18 @@ package com.fastasyncworldedit.core.internal.simd;
import com.fastasyncworldedit.core.queue.Filter;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
public interface VectorizedFilter extends Filter {
ShortVector applyVector(ShortVector get, ShortVector set);
/**
* Applies a filter to a vector pair of get and set.
*
* @param get the get vector
* @param set the set vector
* @param mask the mask with the lanes set to true which should be affected by the filter
* @return the resulting set vector.
*/
ShortVector applyVector(ShortVector get, ShortVector set, VectorMask<Short> mask);
}

View File

@ -3,6 +3,7 @@ package com.fastasyncworldedit.core.internal.simd;
import com.fastasyncworldedit.core.queue.IChunk;
import com.fastasyncworldedit.core.queue.IChunkGet;
import com.fastasyncworldedit.core.queue.IChunkSet;
import com.sk89q.worldedit.world.block.BlockTypesCache;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies;
@ -31,10 +32,22 @@ public interface VectorizedMask {
}
}
/**
* {@return the set vector with all lanes that do not match this mask set to 0}
*
* @param set the set vector
* @param get the get vector
*/
default ShortVector processVector(ShortVector set, ShortVector get) {
return set.blend(0, compareVector(set, get).not());
return set.blend(BlockTypesCache.ReservedIDs.__RESERVED__, compareVector(set, get).not());
}
/**
* {@return a mask with all lanes set that match this mask}
*
* @param set the set vector
* @param get the get vector
*/
VectorMask<Short> compareVector(ShortVector set, ShortVector get);
}