}
};
+template<int Size, typename Ret, typename Arg0, typename Arg1>
+struct ApplyReflect
+{
+ static ExprP<Ret> apply (ExpandContext& ctx,
+ const ExprP<Arg0>& i,
+ const ExprP<Arg1>& n)
+ {
+ const ExprP<float> dotNI = bindExpression("dotNI", ctx, dot(n, i));
+
+ return i - alternatives((n * dotNI) * constant(2.0f),
+ n * (dotNI * constant(2.0f)));
+ };
+};
+
+template<typename Ret, typename Arg0, typename Arg1>
+struct ApplyReflect<1, Ret, Arg0, Arg1>
+{
+ static ExprP<Ret> apply (ExpandContext& ctx,
+ const ExprP<Arg0>& i,
+ const ExprP<Arg1>& n)
+ {
+ return i - alternatives(alternatives((n * (n*i)) * constant(2.0f),
+ n * ((n*i) * constant(2.0f))),
+ (n * n) * (i * constant(2.0f)));
+ };
+};
+
template <int Size>
class Reflect : public DerivedFunc<
Signature<typename ContainerOf<float, Size>::Container,
{
const ExprP<Arg0>& i = args.a;
const ExprP<Arg1>& n = args.b;
- const ExprP<float> dotNI = bindExpression("dotNI", ctx, dot(n, i));
- return i - alternatives((n * dotNI) * constant(2.0f),
- n * (dotNI * constant(2.0f)));
+ return ApplyReflect<Size, Ret, Arg0, Arg1>::apply(ctx, i, n);
}
};