/*
	This program is free software; you can redistribute it and/or modify
	it under the terms of the GNU General Public License version 2 
	as published by the Free Software Foundation.

	This program is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
	GNU General Public License for more details.

	You should have received a copy of the GNU General Public License
	along with this program; if not, write to the Free Software
	Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA


	Copyright (C) 2006  Thierry Berger-Perrin <tbptbp@gmail.com>
*/
#ifndef SSE_H
#define SSE_H


#include <xmmintrin.h>	// __m128
#include <emmintrin.h>	// __m128i

namespace sse {
	// Shuffles whose selectors are written in memory order.
	// Given float m[4], selector 0 names m[0] (x) and selector 3 names m[3] (w).
	// shufps takes the two low result lanes from its first operand and the two
	// high result lanes from its second operand; shuffle2ps spells exactly that
	// out in its argument order, which is the way you think about it.
	// costs noted by the author: shufps -> 4 cycles, unpckhps -> 3 cycles, movlhps -> 2 cycles

	// 8-bit shufps immediate: two bits per selector, i0 in the lowest bits.
	#define shuffleps_mask(i0,i1,i2,i3)		(((i3)<<6)|((i2)<<4)|((i1)<<2)|(i0))
	// { a[i0], a[i1], b[i2], b[i3] }
	#define shuffle2ps(a,i0,i1, b,i2,i3)	_mm_shuffle_ps((a),(b),shuffleps_mask(i0,i1,i2,i3))
	// { v[i0], v[i1], v[i2], v[i3] } -- v is expanded twice.
	#define shuffle1ps(v,i0,i1,i2,i3)		_mm_shuffle_ps((v),(v),shuffleps_mask(i0,i1,i2,i3))

	// broadcasts a scalar to all four lanes
#ifdef NOTANISSUENOWANDFUCKSICC
	// vector input: replicates lane 0 (the input is expanded twice)
	#define broadcastps(v)		_mm_shuffle_ps((v),(v), 0)
	//#define broadcastss(ss)		broadcastps(loadss((ss)))
#else
	// _mm_load1_ps?
	//#define broadcastps(ps)		_mm_set1_ps(ps)
	// set1ps takes a float value, load1ps a pointer to float
	#define set1ps(value)		_mm_set1_ps((value))
	#define load1ps(ptr)		_mm_load1_ps((ptr))
#endif

	// Reads an (u)int32_t through a pointer and replicates it in all four lanes.
	// It is a special case because gcc otherwise writes a temporary.
	/*
	former compiler-specific variant, kept for reference:
	#ifdef __GNUC__
		// __builtin_ia32_loadd: nuked in recent gcc it seems.
		#define broadcastpi(mem)	_mm_shuffle_epi32(__builtin_ia32_loadd((const int*const)(mem)), 0)
	#else
		#define broadcastpi(mem)	_mm_set1_epi32(*(mem))
	#endif
	*/
	#define broadcastpi(ptr)	_mm_set1_epi32(*(ptr))

	// broadcasts one component: every lane of the result is lane `axis` of the input.
	//   ps / pi : source vector (__m128 for splatps, __m128i for splatpi), expanded twice.
	//   axis    : lane index 0..3; it forms the shuffle immediate, so it has to be
	//             a compile-time constant.
	// axis is parenthesized at every use so that compound arguments (a ternary,
	// a|b, ...) still produce the intended immediate instead of being torn apart
	// by operator precedence.
	#define splatps(ps, axis)	_mm_shuffle_ps((ps),(ps),	(((axis)<<6) | ((axis)<<4) | ((axis)<<2) | (axis)))
	#define splatpi(pi, axis)	_mm_shuffle_epi32((pi),		(((axis)<<6) | ((axis)<<4) | ((axis)<<2) | (axis)))

	
	// lane rotations and half swap, all built on shuffle1ps (the input is expanded twice)
	#define rotaterps(v)		shuffle1ps((v), 3,0,1,2)			// {a,b,c,d} -> {d,a,b,c}
	#define rotatelps(v)		shuffle1ps((v), 1,2,3,0)			// {a,b,c,d} -> {b,c,d,a}
	#define swapps(v)			shuffle1ps((v), 2,3,0,1)			// {a,b,c,d} -> {c,d,a,b}

	// Combines matching halves of two vectors; `low` always supplies the first
	// two lanes of the result, `high` the last two.
	//   muxps : low halves  -> low{a,b,c,d}|high{e,f,g,h} = {a,b,e,f}
	//   muxhps: high halves -> low{a,b,c,d}|high{e,f,g,h} = {c,d,g,h}
	// _mm_movelh_ps!
	//#define muxps(low,high)		shuffle2ps((low),0,1, (high),0,1)	// low{a,b,c,d}|high{e,f,g,h} = {a,b,e,f}
	#define muxps(low,high)			_mm_movelh_ps((low),(high))		// low{a,b,c,d}|high{e,f,g,h} = {a,b,e,f}
	// _mm_movehl_ps(x,y) yields {y[2],y[3],x[2],x[3]}: its SECOND operand lands in
	// the low lanes, so the operands are passed swapped to get {c,d,g,h}.
	#define muxhps(low,high)		_mm_movehl_ps((high),(low))		// low{a,b,c,d}|high{e,f,g,h} = {c,d,g,h}

	// Swaps lane `lane` (0 based) with lane 0, so the chosen value ends up in the
	// scalar position; the other lanes keep their place. `lane` feeds the shuffle
	// selectors of shuffle1ps.
	// exchangessps({1,2,3,4}, 2) = {3,2,1,4}
	#define exchangessps(v,lane)	shuffle1ps((v), (lane), ((lane)!=1?1:0), ((lane)!=2?2:0), ((lane)!=3?3:0))


	// Memory <-> register moves; any pointer type is accepted and cast to float*.
	//   loadps / storeps   : _mm_load_ps / _mm_store_ps (the aligned forms)
	//   loadpsu / storepsu : _mm_loadu_ps / _mm_storeu_ps (the unaligned forms)
	//   loadss / storess   : single float through lane 0
	// (author's note: shuffleps 4 cycles)
	#define loadps(ptr)			_mm_load_ps((const float*)(ptr))
	#define loadpsu(ptr)		_mm_loadu_ps((const float*)(ptr))
	#define loadss(ptr)			_mm_load_ss((const float*)(ptr))	// 3 cycles
	#define storeps(v, ptr)		_mm_store_ps((float*)(ptr), (v))
	#define storepsu(v, ptr)	_mm_storeu_ps((float*)(ptr), (v))
	#define storess(s, ptr)		_mm_store_ss((float*)(ptr), (s))

	// Overwrites lane 0 of vec with lane 0 of s and keeps the other three lanes:
	// s={x,...}, vec={a,b,c,d} -> {x,b,c,d}
	#define movess(s, vec)	_mm_move_ss((vec), (s))


	// loads a scalar into a specific slot (axis is 0 based) of an otherwise zero vector
	// ie: float f = 9.0f; loadss_into(&f, 1) = { 0, 9, 0, 0 }
	//
	//   loadss_into_helper(ss, axis): lane `axis` of the result is lane 0 of ss,
	//       every other lane is lane 1 of ss. ss is expanded twice.
	//   movess_into(ss, axis): same, for a value already in a register; the
	//       other lanes are only zero if lane 1 of ss is zero.
	//   loadss_into(mem, axis): reads the float at mem with loadss (which zeroes
	//       the upper lanes) and moves it into lane `axis`.
	// axis forms the shuffle immediate, so it has to be a compile-time constant;
	// it is parenthesized at every use so compound arguments are compared as a whole.
	#define loadss_into_helper(ss, axis)	_mm_shuffle_ps((ss),(ss), ((((axis)==3?0:1)<<6) | (((axis)==2?0:1)<<4) | (((axis)==1?0:1)<<2) | ((axis)==0?0:1)))
	#define movess_into(ss, axis)			loadss_into_helper((ss),(axis))
	#define loadss_into(mem, axis)			loadss_into_helper(loadss(mem), (axis))

	// Gathers four scalars into one vector: returns { s0, s1, s2, s3 }.
	// Each argument must be an lvalue: its address is taken and handed to loadss.
	// The pairs (s0,s1) and (s2,s3) are interleaved with unpcklps, then the two
	// halves are joined with movlhps.
	#define composeps(s0, s1, s2, s3) \
		_mm_movelh_ps( \
			_mm_unpacklo_ps( loadss(&(s0)), loadss(&(s1)) ), \
			_mm_unpacklo_ps( loadss(&(s2)), loadss(&(s3)) ) )


	// 128-bit integer load/store (MOVDQA) via _mm_load_si128 / _mm_store_si128;
	// any pointer type is accepted and cast to __m128i*.
	#define loadpi(ptr)			_mm_load_si128((const __m128i * const)(ptr))
	#define storepi(v, ptr)		_mm_store_si128((__m128i * const)(ptr), (v))
	 
	 

	// packed (4-lane) arithmetic and bitwise logic; cycle counts are the author's notes
	#define addps(a,b)			_mm_add_ps((a),(b))			// 5 cycles
	#define subps(a,b)			_mm_sub_ps((a),(b))			// 5 cycles
	#define mulps(a,b)			_mm_mul_ps((a),(b))			// 5 cycles
	#define divps(a,b)			_mm_div_ps((a),(b))			// 33 cycles
	#define minps(a,b)			_mm_min_ps((a),(b))			// 3 cycles
	#define maxps(a,b)			_mm_max_ps((a),(b))			// 3 cycles
	#define andps(a,b)			_mm_and_ps((a),(b))			// 3 cycles
	#define andnps(a,b)			_mm_andnot_ps((a),(b))		// 3 cycles, (~a) & b
	#define orps(a,b)			_mm_or_ps((a),(b))			// 3 cycles
	#define xorps(a,b)			_mm_xor_ps((a),(b))

	// reciprocal / square root family (packed 'ps' and lane-0-only 'ss' forms)
	// reminder: x * 1/sqrt(x) = sqrt(x)
	#define rcpps(v)		_mm_rcp_ps((v))			// 4 cycles
	#define rsqrtps(v)		_mm_rsqrt_ps((v))		// 4 cycles
	#define sqrtps(v)		_mm_sqrt_ps((v))		// 39 cycles, 19 for ss version
	#define rsqrtss(s)		_mm_rsqrt_ss((s))		// ? cycles
	#define rcpss(s)		_mm_rcp_ss((s))			// ? cycles

	// scalar forms: operate on lane 0 only, the upper lanes come from the first operand
	#define addss(a,b)			_mm_add_ss((a),(b))		// 4 cycles
	#define subss(a,b)			_mm_sub_ss((a),(b))
	#define mulss(a,b)			_mm_mul_ss((a),(b))
	#define divss(a,b)			_mm_div_ss((a),(b))		// 16 cycles
	#define minss(a,b)			_mm_min_ss((a),(b))		// 2 cycles
	#define maxss(a,b)			_mm_max_ss((a),(b))		// 2 cycles
	// the bitwise operations have no scalar instruction: plain aliases of the packed forms
	#define andss			andps
	#define andnss			andnps
	#define orss			orps


	// ANDs vec with the constant stored at rt::ps_cst_maskw (read with loadps).
	// NOTE(review): presumably that constant clears the w lane and keeps x,y,z --
	// confirm against its definition.
	#define maskw(vec)		andps((vec), loadps((const float *)rt::ps_cst_maskw))


	// comparisons: each lane of the result is all ones where the predicate holds, zero elsewhere
	#define cmpeq(a,b)			_mm_cmpeq_ps((a),(b))
	#define cmpneq(a,b)			_mm_cmpneq_ps((a),(b))
	#define cmplt(a,b)			_mm_cmplt_ps((a),(b))
	#define cmple(a,b)			_mm_cmple_ps((a),(b))
	#define cmpgt(a,b)			_mm_cmpgt_ps((a),(b))
	#define cmpge(a,b)			_mm_cmpge_ps((a),(b))

	// ordered: neither operand is NaN; unordered: at least one is
	#define cmpord(a,b)			_mm_cmpord_ps((a),(b))
	#define cmpunord(a,b)		_mm_cmpunord_ps((a),(b))

	// movemask packs the four lane sign bits into bits 0..3 of an int
	#define movemask(m)			_mm_movemask_ps((m))
	#define mask_all(m)			(movemask(m) == 15)
	#define mask_any(m)			(movemask(m) != 0)

	// integer
	// PCMPEQD __m128i _mm_cmpeq_epi32 ( __m128i a, __m128i b)
	// PMOVMSKB int _mm_movemask_epi8 ( __m128i a) -> 16 bits returned
	// to produce a mask with all bits set: pcmpeqd xmm0,xmm0 (no need to clear the register first)

	// constants & constants generation.
	//FIXCST:
	// all_zero() is produced in-register; the other names expand to entries of a
	// constant table (cst::section) that is declared elsewhere.
	// NOTE(review): negateps/absps spell the table rt::cst::section.<name>.v while
	// these expand to cst::section.<name> -- confirm both resolve as intended
	// wherever the macros are used.
	#define all_zero()			_mm_setzero_ps()
	#define all_inf()			cst::section.plus_inf	// { +inf, +inf, +inf, +inf }
	#define all_minus_inf()		cst::section.minus_inf	// { -inf, -inf, -inf, -inf }
	#define all_one()			cst::section.one		// { 1.0f, 1.0f, 1.0f, 1.0f }

	// presumably a vector with every bit set; earlier variants kept for reference:
	//#define bits_all_one()		loadpsi(rt::pi_bits_all_one)
	// hmm. could use any register to do that, but can't really express that.
	//#define bits_all_oneps()	cmpeq(all_zero(),all_zero())
	//#define bits_all_onepi()	_mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128())
	#define bits_all_set()		cst::section.all_set
	
	//#define bits_sign_mask()

/*
; change sign of four single-precision floats in XMM0
PCMPEQD XMM1,XMM1 ; generate all 1's
PSLLD XMM1,31 ; 1 in the leftmost bit of each DWORD only
XORPS XMM0,XMM1 ; change sign of XMM0

; absolute value of four single-precision floats in XMM0
PCMPEQD XMM1,XMM1 ; generate all 1's
PSRLD XMM1,1 ; 1 in all but the leftmost bit of each DWORD
ANDPS XMM0,XMM1 ; set sign bits to 0
*/
//	#define shiftllps(ps, count)	
// hmm. ask for m128 <-> m128i transitions. bad for k8.

	// __m128i support: bitwise logic on the whole 128-bit register
	#define orpi(a,b)				_mm_or_si128((a),(b))			// 3 cycles
	#define andpi(a,b)				_mm_and_si128((a),(b))			// 3 cycles
	#define andnpi(a,b)				_mm_andnot_si128((a),(b))		// 3 cycles, (~a) & b

	// byte-granular shifts of the whole register; bytes is an immediate count
	#define shiftlpi(v, bytes)		_mm_slli_si128((v), (bytes))	// shift left by 8 x bytes bits.
	#define shiftrpi(v, bytes)		_mm_srli_si128((v), (bytes))	// shift right by 8 x bytes bits.
	//? #define srl _mm_srli_si128


	// conditional moves (branchless select), bit by bit:
	//   result = (mask & update) | (~mask & original)
	// Put 1s in mask where update is wanted and 0s where original must be kept;
	// with a negated mask, exchange original and update instead.
	// mask is expanded twice. The ps and ss variants are the same expression;
	// the pi variant works on __m128i.
	#define cond_moveps(mask, original, update)	orps(andps((mask), (update)), andnps((mask), (original)))
	#define cond_movess(mask, original, update)	orps(andps((mask), (update)), andnps((mask), (original)))
	#define cond_movepi(mask, original, update)	orpi(andpi((mask), (update)), andnpi((mask), (original)))

	// composite operations built on the rt::cst::section.sign_mask constant
	//FIXCST: #define negateps(ps)	xorps((ps), loadps(&rt::ps_abs_mask))
	//#define absps(ps)		andps((ps), loadps(&rt::ps_abs_mask))
	// (the parameter must not be called v: the body contains the member name .v)
	#define negateps(x)		xorps(rt::cst::section.sign_mask.v, (x))	// flips the bits selected by sign_mask
	#define absps(x)		andnps(rt::cst::section.sign_mask.v, (x))	// clears the bits selected by sign_mask

	// castps2pi / castpi2ps: wrappers for __m128 <-> __m128i casting, one variant per compiler.
	// NOTE(review): __MSVC__ and __GCC__ are not the compilers' own predefined
	// macros (those are _MSC_VER and __GNUC__); unless the build defines them,
	// the final #else branch is the one that gets compiled -- confirm.
	#ifdef __MSVC__
		// __m128i / __m128 bridge (useful for masks etc)
		// only needed for msvc now (icc has cast*, and gcc knows better).
		union ps_pi_t {
			__m128i pi;
			__m128	ps;

			// implicit on purpose: either vector type converts to the bridge
			ps_pi_t(const __m128i bits) : pi(bits) {}
			ps_pi_t(const __m128 bits) : ps(bits) {}

			// conversions back out, for const and non-const bridges alike
			operator __m128i()	const	{ return pi; }
			operator __m128()	const	{ return ps; }
			operator __m128i()			{ return pi; }
			operator __m128()			{ return ps; }

			// one-shot helpers used by the cast macros below
			static __m128  from_pi(const __m128i bits) { const ps_pi_t bridge(bits); return bridge.ps; }
			static __m128i from_ps(const __m128 bits)  { const ps_pi_t bridge(bits); return bridge.pi; }
		};

		/*
		earlier free-function variant, kept for reference:
		namespace ps_pi_helper {
			static FINLINE const __m128i ps_pi_caster(const __m128 v) {
				const rt::ps_pi_t conv(v);
				return conv.pi;
			}
			static FINLINE const __m128 pi_ps_caster(const __m128i i) {
				const rt::ps_pi_t conv(i);
				return conv.ps;
			}
		}
		*/
		//#define castps2pi(ps) rt::ps_pi_helper::ps_pi_caster((ps))
		//#define castpi2ps(pi) rt::ps_pi_helper::pi_ps_caster((pi))
		#define castpi2ps(x)	sse::ps_pi_t::from_pi(x)
		#define castps2pi(x)	sse::ps_pi_t::from_ps(x)
	#elif defined(__GCC__)
		// direct vector-to-vector cast
		#define castpi2ps(x) ((const __m128)(x))
		#define castps2pi(x) ((const __m128i)(x))
	#else	// ICC
		// the dedicated cast intrinsics
		#define castpi2ps(x) (_mm_castsi128_ps((x)))
		#define castps2pi(x) (_mm_castps_si128((x)))
	#endif
}

#endif
