}
};
+template<int Size, typename Ret, typename Arg0, typename Arg1>
+struct ApplyRefract
+{
+ static ExprP<Ret> apply (ExpandContext& ctx,
+ const ExprP<Arg0>& i,
+ const ExprP<Arg1>& n,
+ const ExprP<float>& eta)
+ {
+ const ExprP<float> dotNI = bindExpression("dotNI", ctx, dot(n, i));
+ const ExprP<float> k = bindExpression("k", ctx, constant(1.0f) - eta * eta *
+ (constant(1.0f) - dotNI * dotNI));
+
+ return cond(k < constant(0.0f),
+ genXType<float, Size>(constant(0.0f)),
+ i * eta - n * (eta * dotNI + sqrt(k)));
+ };
+};
+
+template<typename Ret, typename Arg0, typename Arg1>
+struct ApplyRefract<1, Ret, Arg0, Arg1>
+{
+ static ExprP<Ret> apply (ExpandContext& ctx,
+ const ExprP<Arg0>& i,
+ const ExprP<Arg1>& n,
+ const ExprP<float>& eta)
+ {
+ const ExprP<float> dotNI = bindExpression("dotNI", ctx, dot(n, i));
+ const ExprP<float> k1 = bindExpression("k1", ctx, constant(1.0f) - eta * eta *
+ (constant(1.0f) - dotNI * dotNI));
+
+ const ExprP<float> k2 = bindExpression("k2", ctx,
+ (((dotNI * (-dotNI)) + constant(1.0f)) * eta)
+ * (-eta) + constant(1.0f));
+
+ return alternatives(cond(k1 < constant(0.0f),
+ genXType<float, 1>(constant(0.0f)),
+ i * eta - n * (eta * dotNI + sqrt(k1))),
+ cond(k2 < constant(0.0f),
+ genXType<float, 1>(constant(0.0f)),
+ i * eta - n * (eta * dotNI + sqrt(k2))));
+ };
+};
+
template <int Size>
class Refract : public DerivedFunc<
Signature<typename ContainerOf<float, Size>::Container,
const ExprP<Arg0>& i = args.a;
const ExprP<Arg1>& n = args.b;
const ExprP<float>& eta = args.c;
- const ExprP<float> dotNI = bindExpression("dotNI", ctx, dot(n, i));
- const ExprP<float> k = bindExpression("k", ctx, constant(1.0f) - eta * eta *
- (constant(1.0f) - dotNI * dotNI));
- return cond(k < constant(0.0f),
- genXType<float, Size>(constant(0.0f)),
- i * eta - n * (eta * dotNI + sqrt(k)));
+ return ApplyRefract<Size, Ret, Arg0, Arg1>::apply(ctx, i, n, eta);
}
};