From d78957d1cfc11ed9fdcd780115067f735de91729 Mon Sep 17 00:00:00 2001
From: kd-11 <karokidii@gmail.com>
Date: Fri, 6 Jul 2018 14:50:29 +0300
Subject: [PATCH] rsx/vp: CodeGen improvements - Fix double destination writes
 on conditional write masking - Fix codegen to simplify simple scalar
 comparisons vs vector functions

---
 rpcs3/Emu/RSX/Common/GLSLCommon.h             | 50 +++++++++++++------
 rpcs3/Emu/RSX/Common/ShaderParam.h            | 15 +++++-
 .../RSX/Common/VertexProgramDecompiler.cpp    | 43 ++++++++++------
 .../Emu/RSX/Common/VertexProgramDecompiler.h  |  2 +-
 .../D3D12/D3D12VertexProgramDecompiler.cpp    |  2 +-
 .../RSX/D3D12/D3D12VertexProgramDecompiler.h  |  2 +-
 rpcs3/Emu/RSX/GL/GLVertexProgram.cpp          |  4 +-
 rpcs3/Emu/RSX/GL/GLVertexProgram.h            |  2 +-
 rpcs3/Emu/RSX/VK/VKVertexProgram.cpp          |  4 +-
 rpcs3/Emu/RSX/VK/VKVertexProgram.h            |  2 +-
 10 files changed, 87 insertions(+), 39 deletions(-)

diff --git a/rpcs3/Emu/RSX/Common/GLSLCommon.h b/rpcs3/Emu/RSX/Common/GLSLCommon.h
index de26e385de..3f526e42e6 100644
--- a/rpcs3/Emu/RSX/Common/GLSLCommon.h
+++ b/rpcs3/Emu/RSX/Common/GLSLCommon.h
@@ -110,23 +110,45 @@ namespace glsl
 		}
 	}
 
-	static std::string compareFunctionImpl(COMPARE f, const std::string &Op0, const std::string &Op1)
+	static std::string compareFunctionImpl(COMPARE f, const std::string &Op0, const std::string &Op1, bool scalar = false)
 	{
-		switch (f)
+		if (scalar) // Caller flagged the operands as scalar: emit an infix relational operator (the result is not parenthesized)
 		{
-		case COMPARE::FUNCTION_SEQ:
-			return "equal(" + Op0 + ", " + Op1 + ")";
-		case COMPARE::FUNCTION_SGE:
-			return "greaterThanEqual(" + Op0 + ", " + Op1 + ")";
-		case COMPARE::FUNCTION_SGT:
-			return "greaterThan(" + Op0 + ", " + Op1 + ")";
-		case COMPARE::FUNCTION_SLE:
-			return "lessThanEqual(" + Op0 + ", " + Op1 + ")";
-		case COMPARE::FUNCTION_SLT:
-			return "lessThan(" + Op0 + ", " + Op1 + ")";
-		case COMPARE::FUNCTION_SNE:
-			return "notEqual(" + Op0 + ", " + Op1 + ")";
+			switch (f)
+			{
+			case COMPARE::FUNCTION_SEQ:
+				return Op0 + " == " + Op1;
+			case COMPARE::FUNCTION_SGE:
+				return Op0 + " >= " + Op1;
+			case COMPARE::FUNCTION_SGT:
+				return Op0 + " > " + Op1;
+			case COMPARE::FUNCTION_SLE:
+				return Op0 + " <= " + Op1;
+			case COMPARE::FUNCTION_SLT:
+				return Op0 + " < " + Op1;
+			case COMPARE::FUNCTION_SNE:
+				return Op0 + " != " + Op1;
+			}
 		}
+		else // Default (scalar == false): emit the function-call form, e.g. equal(Op0, Op1), exactly as before this change
+		{
+			switch (f)
+			{
+			case COMPARE::FUNCTION_SEQ:
+				return "equal(" + Op0 + ", " + Op1 + ")";
+			case COMPARE::FUNCTION_SGE:
+				return "greaterThanEqual(" + Op0 + ", " + Op1 + ")";
+			case COMPARE::FUNCTION_SGT:
+				return "greaterThan(" + Op0 + ", " + Op1 + ")";
+			case COMPARE::FUNCTION_SLE:
+				return "lessThanEqual(" + Op0 + ", " + Op1 + ")";
+			case COMPARE::FUNCTION_SLT:
+				return "lessThan(" + Op0 + ", " + Op1 + ")";
+			case COMPARE::FUNCTION_SNE:
+				return "notEqual(" + Op0 + ", " + Op1 + ")";
+			}
+		}
+		// Reached only when 'f' matched none of the cases in the selected switch
 		fmt::throw_exception("Unknown compare function" HERE);
 	}
 
diff --git a/rpcs3/Emu/RSX/Common/ShaderParam.h b/rpcs3/Emu/RSX/Common/ShaderParam.h
index 8d92acf307..837b677117 100644
--- a/rpcs3/Emu/RSX/Common/ShaderParam.h
+++ b/rpcs3/Emu/RSX/Common/ShaderParam.h
@@ -177,7 +177,20 @@ public:
 	ShaderVariable() = default;
 	ShaderVariable(const std::string& var)
 	{
-		auto var_blocks = fmt::split(var, { "." });
+		// Separate 'double destination' variables 'X=Y=SRC'; only the first destination (X) is parsed below
+		std::string simple_var;
+		const auto pos = var.find('=');
+
+		if (pos != std::string::npos)
+		{
+			simple_var = var.substr(0, var.find_last_not_of(" \t=", pos) + 1); // X ends before the '=' and any padding in front of it
+		}
+		else
+		{
+			simple_var = var;
+		}
+
+		auto var_blocks = fmt::split(simple_var, { "." });
 
 		verify(HERE), (var_blocks.size() != 0);
 
diff --git a/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.cpp b/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.cpp
index 47e6657b79..bb28dc1b23 100644
--- a/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.cpp
+++ b/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.cpp
@@ -283,32 +283,45 @@ void VertexProgramDecompiler::AddCodeCond(const std::string& dst, const std::str
 		COMPARE::FUNCTION_SGE,
 	};
 
-	static const char f[4] = { 'x', 'y', 'z', 'w' };
-
-	std::string swizzle;
-	swizzle += f[d0.mask_x];
-	swizzle += f[d0.mask_y];
-	swizzle += f[d0.mask_z];
-	swizzle += f[d0.mask_w];
-
-	swizzle = swizzle == "xyzw" ? "" : "." + swizzle;
-
-	std::string cond = compareFunction(cond_string_table[d0.cond], AddCondReg() + swizzle, getFloatTypeName(4) + "(0., 0., 0., 0.)");
-
 	ShaderVariable dst_var(dst);
 	dst_var.simplify();
 
-	//const char *c_mask = f;
+	static const char f[4] = { 'x', 'y', 'z', 'w' };
+	const u32 mask_index[4] = { d0.mask_x, d0.mask_y, d0.mask_z, d0.mask_w };
+
+	// Applies the component selector to both halves of a double destination: 'X = Y' -> 'X.m = Y.m'
+	auto get_masked_dst = [](const std::string& dest, const char mask)
+	{
+		const auto selector = std::string(".") + mask;
+		const auto pos = dest.find('=');
+		std::string result = dest + selector;
+
+		if (pos != std::string::npos)
+		{
+			// Insert right after the first destination; do not assume exactly one space precedes '='
+			result.insert(dest.find_last_not_of(" \t=", pos) + 1, selector);
+		}
+		return result;
+	};
+
+	auto get_cond_func = [this, &mask_index](COMPARE op, int index)
+	{
+		// Scalar test of one condition register component (x, y, z or w) against zero
+		const std::string cond_component = AddCondReg() + "." + f[mask_index[index]];
+		return compareFunction(op, cond_component, "0.", true);
+	};
 
 	if (dst_var.swizzles[0].length() == 1)
 	{
-		AddCode("if (" + cond + ".x) " + dst + " = " + src + ";");
+		const std::string cond = get_cond_func(cond_string_table[d0.cond], 0);
+		AddCode("if (" + cond + ") " + dst + " = " + src + ";");
 	}
 	else
 	{
 		for (int i = 0; i < dst_var.swizzles[0].length(); ++i)
 		{
-			AddCode("if (" + cond + "." + f[i] + ") " + dst + "." + f[i] + " = " + src + "." + f[i] + ";");
+			const std::string cond = get_cond_func(cond_string_table[d0.cond], i);
+			AddCode("if (" + cond + ") " + get_masked_dst(dst, f[i]) + " = " + src + "." + f[i] + ";");
 		}
 	}
 }
diff --git a/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.h b/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.h
index 9d5600ed70..38e022e18c 100644
--- a/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.h
+++ b/rpcs3/Emu/RSX/Common/VertexProgramDecompiler.h
@@ -98,7 +98,7 @@ protected:
 
 	/** returns string calling comparison function on 2 args passed as strings.
 	*/
-	virtual std::string compareFunction(COMPARE, const std::string &, const std::string &) = 0;
+	virtual std::string compareFunction(COMPARE, const std::string &, const std::string &, bool scalar = false) = 0; // NOTE(review): the default for 'scalar' exists only on this declaration; the overrides in this patch take it without a default
 
 	/** Insert header of shader file (eg #version, "system constants"...)
 	*/
diff --git a/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.cpp b/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.cpp
index 6476501b8c..d9908c9408 100644
--- a/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.cpp
+++ b/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.cpp
@@ -21,7 +21,7 @@ std::string D3D12VertexProgramDecompiler::getFunction(enum class FUNCTION f)
 	return getFunctionImp(f);
 }
 
-std::string D3D12VertexProgramDecompiler::compareFunction(COMPARE f, const std::string &Op0, const std::string &Op1)
+std::string D3D12VertexProgramDecompiler::compareFunction(COMPARE f, const std::string &Op0, const std::string &Op1, bool /*scalar*/) // NOTE(review): the scalar hint is ignored here - confirm compareFunctionImp's output is valid for single-component operands
 {
 	return compareFunctionImp(f, Op0, Op1);
 }
diff --git a/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.h b/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.h
index 01161b37c8..8c0eee1278 100644
--- a/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.h
+++ b/rpcs3/Emu/RSX/D3D12/D3D12VertexProgramDecompiler.h
@@ -10,7 +10,7 @@ protected:
 	virtual std::string getFloatTypeName(size_t elementCount) override;
 	std::string getIntTypeName(size_t elementCount) override;
 	virtual std::string getFunction(enum class FUNCTION) override;
-	virtual std::string compareFunction(enum class COMPARE, const std::string &, const std::string &) override;
+	virtual std::string compareFunction(enum class COMPARE, const std::string &, const std::string &, bool scalar) override; // 'scalar' has no default on this override
 
 	virtual void insertHeader(std::stringstream &OS);
 	virtual void insertInputs(std::stringstream &OS, const std::vector<ParamType> &inputs);
diff --git a/rpcs3/Emu/RSX/GL/GLVertexProgram.cpp b/rpcs3/Emu/RSX/GL/GLVertexProgram.cpp
index 239e62366c..7801e2dd15 100644
--- a/rpcs3/Emu/RSX/GL/GLVertexProgram.cpp
+++ b/rpcs3/Emu/RSX/GL/GLVertexProgram.cpp
@@ -23,9 +23,9 @@ std::string GLVertexDecompilerThread::getFunction(FUNCTION f)
 	return glsl::getFunctionImpl(f);
 }
 
-std::string GLVertexDecompilerThread::compareFunction(COMPARE f, const std::string &Op0, const std::string &Op1)
+std::string GLVertexDecompilerThread::compareFunction(COMPARE f, const std::string &Op0, const std::string &Op1, bool scalar)
 {
-	return glsl::compareFunctionImpl(f, Op0, Op1);
+	return glsl::compareFunctionImpl(f, Op0, Op1, scalar); // forward the scalar hint to the GLSL generator unchanged
 }
 
 void GLVertexDecompilerThread::insertHeader(std::stringstream &OS)
diff --git a/rpcs3/Emu/RSX/GL/GLVertexProgram.h b/rpcs3/Emu/RSX/GL/GLVertexProgram.h
index 25d01d2da6..f883be01ba 100644
--- a/rpcs3/Emu/RSX/GL/GLVertexProgram.h
+++ b/rpcs3/Emu/RSX/GL/GLVertexProgram.h
@@ -20,7 +20,7 @@ protected:
 	virtual std::string getFloatTypeName(size_t elementCount) override;
 	std::string getIntTypeName(size_t elementCount) override;
 	virtual std::string getFunction(FUNCTION) override;
-	virtual std::string compareFunction(COMPARE, const std::string&, const std::string&) override;
+	virtual std::string compareFunction(COMPARE, const std::string&, const std::string&, bool scalar) override; // 'scalar' has no default on this override
 
 	virtual void insertHeader(std::stringstream &OS) override;
 	virtual void insertInputs(std::stringstream &OS, const std::vector<ParamType> &inputs) override;
diff --git a/rpcs3/Emu/RSX/VK/VKVertexProgram.cpp b/rpcs3/Emu/RSX/VK/VKVertexProgram.cpp
index ed49514dbc..f72970e29d 100644
--- a/rpcs3/Emu/RSX/VK/VKVertexProgram.cpp
+++ b/rpcs3/Emu/RSX/VK/VKVertexProgram.cpp
@@ -20,9 +20,9 @@ std::string VKVertexDecompilerThread::getFunction(FUNCTION f)
 	return glsl::getFunctionImpl(f);
 }
 
-std::string VKVertexDecompilerThread::compareFunction(COMPARE f, const std::string &Op0, const std::string &Op1)
+std::string VKVertexDecompilerThread::compareFunction(COMPARE f, const std::string &Op0, const std::string &Op1, bool scalar)
 {
-	return glsl::compareFunctionImpl(f, Op0, Op1);
+	return glsl::compareFunctionImpl(f, Op0, Op1, scalar); // forward the scalar hint to the GLSL generator unchanged
 }
 
 void VKVertexDecompilerThread::insertHeader(std::stringstream &OS)
diff --git a/rpcs3/Emu/RSX/VK/VKVertexProgram.h b/rpcs3/Emu/RSX/VK/VKVertexProgram.h
index 744c8c3396..3c21b0ab24 100644
--- a/rpcs3/Emu/RSX/VK/VKVertexProgram.h
+++ b/rpcs3/Emu/RSX/VK/VKVertexProgram.h
@@ -14,7 +14,7 @@ protected:
 	virtual std::string getFloatTypeName(size_t elementCount) override;
 	std::string getIntTypeName(size_t elementCount) override;
 	virtual std::string getFunction(FUNCTION) override;
-	virtual std::string compareFunction(COMPARE, const std::string&, const std::string&) override;
+	virtual std::string compareFunction(COMPARE, const std::string&, const std::string&, bool scalar) override; // 'scalar' has no default on this override
 
 	virtual void insertHeader(std::stringstream &OS) override;
 	virtual void insertInputs(std::stringstream &OS, const std::vector<ParamType> &inputs) override;