diff --git a/app/src/main/java/org/schabi/newpipe/util/StateSaver.java b/app/src/main/java/org/schabi/newpipe/util/StateSaver.java index 61fdb602f..b6877d375 100644 --- a/app/src/main/java/org/schabi/newpipe/util/StateSaver.java +++ b/app/src/main/java/org/schabi/newpipe/util/StateSaver.java @@ -35,8 +35,12 @@ import org.schabi.newpipe.MainActivity; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.util.LinkedList; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; @@ -63,13 +67,8 @@ public final class StateSaver { * @param context used to get the available cache dir */ public static void init(final Context context) { - final File externalCacheDir = context.getExternalCacheDir(); - if (externalCacheDir != null) { - cacheDirPath = externalCacheDir.getAbsolutePath(); - } - if (TextUtils.isEmpty(cacheDirPath)) { - cacheDirPath = context.getCacheDir().getAbsolutePath(); - } + // Use internal cache directory to prevent other apps from accessing/modifying the state + cacheDirPath = context.getCacheDir().getAbsolutePath(); } /** @@ -129,7 +128,7 @@ public final class StateSaver { } try (FileInputStream fileInputStream = new FileInputStream(file); - ObjectInputStream inputStream = new ObjectInputStream(fileInputStream)) { + ObjectInputStream inputStream = new ValidatingObjectInputStream(fileInputStream)) { //noinspection unchecked savedObjects = (Queue) inputStream.readObject(); } @@ -310,6 +309,34 @@ public final class StateSaver { } } + private static final class ValidatingObjectInputStream extends ObjectInputStream { + ValidatingObjectInputStream(final InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(final ObjectStreamClass desc) + throws IOException, ClassNotFoundException { + final String name = desc.getName(); + if (!isSafe(name)) { + throw new InvalidClassException("Unauthorized deserialization attempt", name); + } + return super.resolveClass(desc); + } + + private boolean isSafe(final String name) { + return name.startsWith("java.lang.") + || name.startsWith("java.util.") + || name.startsWith("org.schabi.newpipe.") + || name.startsWith("[Ljava.lang.") + || name.startsWith("[Ljava.util.") + || name.startsWith("[Lorg.schabi.newpipe.") + || name.equals("[Z") || name.equals("[B") || name.equals("[C") + || name.equals("[S") || name.equals("[I") || name.equals("[J") + || name.equals("[F") || name.equals("[D"); + } + } + /** * Used for describing how to save/read the objects. *