Fix vectorized CountFilter (#2943)

This commit is contained in:
Hannes Greule 2024-10-25 19:59:53 +02:00 committed by GitHub
parent 65de3642d6
commit 7d6643b452
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 43 additions and 13 deletions

View File

@ -3,6 +3,7 @@ package com.fastasyncworldedit.core.extent.filter;
import com.fastasyncworldedit.core.extent.filter.block.FilterBlock; import com.fastasyncworldedit.core.extent.filter.block.FilterBlock;
import com.fastasyncworldedit.core.internal.simd.VectorizedFilter; import com.fastasyncworldedit.core.internal.simd.VectorizedFilter;
import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
public class CountFilter extends ForkedFilter<CountFilter> implements VectorizedFilter { public class CountFilter extends ForkedFilter<CountFilter> implements VectorizedFilter {
@ -36,8 +37,8 @@ public class CountFilter extends ForkedFilter<CountFilter> implements Vectorized
} }
@Override @Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) { public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
total += set.length(); total += mask.trueCount();
return set; return set;
} }

View File

@ -6,6 +6,7 @@ import com.fastasyncworldedit.core.queue.Filter;
import com.fastasyncworldedit.core.queue.IChunk; import com.fastasyncworldedit.core.queue.IChunk;
import com.sk89q.worldedit.regions.Region; import com.sk89q.worldedit.regions.Region;
import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
/** /**
@ -77,9 +78,9 @@ public sealed class LinkedFilter<L extends Filter, R extends Filter> implements
} }
@Override @Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) { public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
ShortVector res = getLeft().applyVector(get, set); ShortVector res = getLeft().applyVector(get, set, mask);
return getRight().applyVector(get, res); return getRight().applyVector(get, res, mask);
} }
@Override @Override

View File

@ -82,11 +82,10 @@ public class MaskFilter<T extends Filter> extends DelegateFilter<T> {
} }
@Override @Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) { public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
final T parent = getParent(); final T parent = getParent();
VectorMask<Short> masked = vectorizedMask.compareVector(set, get); VectorMask<Short> masked = vectorizedMask.compareVector(set, get);
ShortVector res = parent.applyVector(get, set); ShortVector res = parent.applyVector(get, set, mask.and(masked));
res = set.blend(res, masked);
VectorMask<Short> changed = res.compare(VectorOperators.NE, set); VectorMask<Short> changed = res.compare(VectorOperators.NE, set);
changes.getAndAdd(changed.trueCount()); changes.getAndAdd(changed.trueCount());
return res; return res;

View File

@ -14,6 +14,7 @@ import com.sk89q.worldedit.world.block.BaseBlock;
import com.sk89q.worldedit.world.block.BlockStateHolder; import com.sk89q.worldedit.world.block.BlockStateHolder;
import com.sk89q.worldedit.world.block.BlockTypesCache; import com.sk89q.worldedit.world.block.BlockTypesCache;
import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators; import jdk.incubator.vector.VectorOperators;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -101,8 +102,9 @@ public class SimdSupport {
} }
@Override @Override
public ShortVector applyVector(final ShortVector get, final ShortVector set) { public ShortVector applyVector(final ShortVector get, final ShortVector set, VectorMask<Short> mask) {
return ShortVector.broadcast(ShortVector.SPECIES_PREFERRED, ordinal); // only change the lanes the mask dictates us to change, keep the rest
return set.blend(ShortVector.broadcast(ShortVector.SPECIES_PREFERRED, ordinal), mask);
} }
@Override @Override

View File

@ -4,6 +4,7 @@ import com.fastasyncworldedit.core.extent.filter.block.CharFilterBlock;
import com.fastasyncworldedit.core.queue.Filter; import com.fastasyncworldedit.core.queue.Filter;
import com.sk89q.worldedit.extent.Extent; import com.sk89q.worldedit.extent.Extent;
import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies; import jdk.incubator.vector.VectorSpecies;
public class VectorizedCharFilterBlock extends CharFilterBlock { public class VectorizedCharFilterBlock extends CharFilterBlock {
@ -18,15 +19,17 @@ public class VectorizedCharFilterBlock extends CharFilterBlock {
throw new IllegalStateException("Unexpected VectorizedCharFilterBlock " + filter); throw new IllegalStateException("Unexpected VectorizedCharFilterBlock " + filter);
} }
final VectorSpecies<Short> species = ShortVector.SPECIES_PREFERRED; final VectorSpecies<Short> species = ShortVector.SPECIES_PREFERRED;
// TODO can we avoid eager initSet?
initSet(); // set array is null before initSet(); // set array is null before
char[] setArr = this.setArr; char[] setArr = this.setArr;
assert setArr != null; assert setArr != null;
char[] getArr = this.getArr; char[] getArr = this.getArr;
// assume setArr.length == getArr.length == 4096 // assume setArr.length == getArr.length == 4096
VectorMask<Short> affectAll = species.maskAll(true);
for (int i = 0; i < 4096; i += species.length()) { for (int i = 0; i < 4096; i += species.length()) {
ShortVector set = ShortVector.fromCharArray(species, setArr, i); ShortVector set = ShortVector.fromCharArray(species, setArr, i);
ShortVector get = ShortVector.fromCharArray(species, getArr, i); ShortVector get = ShortVector.fromCharArray(species, getArr, i);
ShortVector res = vecFilter.applyVector(get, set); ShortVector res = vecFilter.applyVector(get, set, affectAll);
res.intoCharArray(setArr, i); res.intoCharArray(setArr, i);
} }
} }

View File

@ -2,7 +2,18 @@ package com.fastasyncworldedit.core.internal.simd;
import com.fastasyncworldedit.core.queue.Filter; import com.fastasyncworldedit.core.queue.Filter;
import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask;
public interface VectorizedFilter extends Filter { public interface VectorizedFilter extends Filter {
ShortVector applyVector(ShortVector get, ShortVector set);
/**
* Applies a filter to a vector pair of get and set.
*
* @param get the get vector
* @param set the set vector
* @param mask the mask with the lanes set to true which should be affected by the filter
* @return the resulting set vector.
*/
ShortVector applyVector(ShortVector get, ShortVector set, VectorMask<Short> mask);
} }

View File

@ -3,6 +3,7 @@ package com.fastasyncworldedit.core.internal.simd;
import com.fastasyncworldedit.core.queue.IChunk; import com.fastasyncworldedit.core.queue.IChunk;
import com.fastasyncworldedit.core.queue.IChunkGet; import com.fastasyncworldedit.core.queue.IChunkGet;
import com.fastasyncworldedit.core.queue.IChunkSet; import com.fastasyncworldedit.core.queue.IChunkSet;
import com.sk89q.worldedit.world.block.BlockTypesCache;
import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorMask; import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorSpecies; import jdk.incubator.vector.VectorSpecies;
@ -31,10 +32,22 @@ public interface VectorizedMask {
} }
} }
/**
* {@return the set vector with all lanes that do not match this mask set to 0}
*
* @param set the set vector
* @param get the get vector
*/
default ShortVector processVector(ShortVector set, ShortVector get) { default ShortVector processVector(ShortVector set, ShortVector get) {
return set.blend(0, compareVector(set, get).not()); return set.blend(BlockTypesCache.ReservedIDs.__RESERVED__, compareVector(set, get).not());
} }
/**
* {@return a mask with all lanes set that match this mask}
*
* @param set the set vector
* @param get the get vector
*/
VectorMask<Short> compareVector(ShortVector set, ShortVector get); VectorMask<Short> compareVector(ShortVector set, ShortVector get);
} }