diff --git a/primitive/infiniteplane.cpp b/primitive/infiniteplane.cpp index 81bd973..a0851f2 100644 --- a/primitive/infiniteplane.cpp +++ b/primitive/infiniteplane.cpp @@ -28,9 +28,9 @@ bool InfinitePlane::intersect(Ray &ray) const { // Set the normal if (acos(dotProduct(ray.direction, this->normal)) > 0) { - ray.normal = this->normal; + ray.normal = normalized(this->normal); } else { - ray.normal = -1.0f * this->normal; + ray.normal = normalized((-1.0f * this->normal)); } diff --git a/primitive/triangle.cpp b/primitive/triangle.cpp index b3cb2bd..e0dd877 100644 --- a/primitive/triangle.cpp +++ b/primitive/triangle.cpp @@ -101,7 +101,7 @@ bool Triangle::intersect(Ray &ray) const { return false; // Calculate the normal - ray.normal = crossProduct(edge1, edge1); + ray.normal = normalized(crossProduct(edge1, edge1)); // Calculate the surface position ray.surface = u * this->surface[1] + v * this->surface[2] + (1 - u - v) * this->surface[0]; diff --git a/shader/mirrorshader.cpp b/shader/mirrorshader.cpp index 28780db..35b5567 100644 --- a/shader/mirrorshader.cpp +++ b/shader/mirrorshader.cpp @@ -5,12 +5,11 @@ MirrorShader::MirrorShader() {} Color MirrorShader::shade(Scene const &scene, Ray const &ray) const { if (ray.getRemainingBounces() <= 0) { - return Color(0, 0, 0); + return Color(0, 255, 0); } Vector3d newDirection = ray.direction - 2 * dotProduct(ray.normal, ray.direction) * ray.normal; - //TODO: should we reset the ray or use a new ray? Regarding the count of bounces Ray mirroredRay = Ray(ray.origin + ray.length * ray.direction, newDirection); mirroredRay.setRemainingBounces(ray.getRemainingBounces() - 1); diff --git a/shader/refractionshader.cpp b/shader/refractionshader.cpp index 8a942c2..41ef265 100644 --- a/shader/refractionshader.cpp +++ b/shader/refractionshader.cpp @@ -1,3 +1,4 @@ +#include #include "scene/scene.h" #include "shader/refractionshader.h" @@ -11,19 +12,72 @@ Color RefractionShader::shade(Scene const &scene, Ray const &ray) const { // Send out a new refracted ray into the scene; recursively call traceRay() if (ray.getRemainingBounces() < 1) { - return Color(0,0,0); + return Color(0,255,0); + } + + auto eta1 = indexOutside; + auto eta2 = indexInside; + + + Vector3d v = normalized(ray.direction); + Vector3d n = normalized(ray.normal); + + // Check whether we are entering or leaving the object + if (acos(dotProduct(v, n)) < PI/2.0f) { + eta1 = indexInside; + eta2 = indexOutside; } - Vector3d n = ray.normal; - double root = sqrt(1 - (indexOutside*indexOutside*(1 - dotProduct(ray.direction, n) - *dotProduct(ray.direction, n)))/(indexInside*indexInside)); + // Calculate refraction percentage + float denominator = 1.0f - + (pow(eta1, 2) * (1 - pow(dotProduct(v, n), 2)) / + pow(eta2, 2)); - Vector3d t = (indexOutside/indexInside)*(ray.direction - dotProduct(ray.direction, n) * n)- n * root; + // Check whether there is any refracted ray at all or whether we have total internal reflection + if (denominator <= 0) { + Vector3d newDirection = v - 2.0f * dotProduct(n, v) * n; - Ray reflectionRay = Ray(ray.origin + ray.length * ray.direction, t); - reflectionRay.setRemainingBounces(ray.getRemainingBounces()-1); + Ray mirroredRay = Ray(ray.origin + ray.length * v, newDirection); + mirroredRay.setRemainingBounces(ray.getRemainingBounces() - 1); - return scene.traceRay(reflectionRay); + return scene.traceRay(mirroredRay); + } + + // Calculate refracted ray direction t + Vector3d t = normalized( + eta1 / eta2 * + (v - dotProduct(v, n) * n) - + n * sqrt(denominator)); + + // Calculate reflected ray direction + Vector3d cosTheta1 = crossProduct(v, n); + Vector3d cosTheta2 = -1.0f * crossProduct(n, t); + + Vector3d polarizationParallel = (eta2 * cosTheta1 - eta1 * cosTheta2) / + (eta2 * cosTheta1 + eta1 * cosTheta2); + Vector3d polarizationOrthogonal = (eta1 * cosTheta1 - eta2 * cosTheta2) / + (eta1 * cosTheta1 + eta2 * cosTheta2); + + Vector3d reflectionVector = normalized( + 0.5f * (polarizationParallel * polarizationParallel + polarizationOrthogonal * polarizationOrthogonal)); + + // Calculate rays + Vector3d rayOrigin = ray.origin + ray.length * ray.direction; + Ray refractionRay = Ray(rayOrigin, t); + Ray reflectionRay = Ray(rayOrigin, reflectionVector); + + int remainingBounces = ray.getRemainingBounces() - 1; + + refractionRay.setRemainingBounces(remainingBounces); + reflectionRay.setRemainingBounces(remainingBounces); + + // Calculate Color + Color refractionColor = scene.traceRay(refractionRay); + Color reflectionColor = scene.traceRay(reflectionRay); + + Color resultingColor = denominator * refractionColor + (1-denominator) * reflectionColor; + + return resultingColor; } bool RefractionShader::isTransparent() const { return true; }